MogensR committed on
Commit
dae1677
·
1 Parent(s): 521249b

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +194 -112
models/loaders/matanyone_loader.py CHANGED
@@ -1,10 +1,11 @@
1
  #!/usr/bin/env python3
2
  """
3
- MatAnyone Loader + Stateful Adapter
4
- - Loads the official model from Hugging Face.
5
- - Drives InferenceCore as intended: first-frame encode + warm-up, then propagation.
6
- - Normalizes inputs so conv2d never sees 5-D tensors.
7
- - Always outputs a 2-D, contiguous float32 mask [H,W] for OpenCV.
 
8
  """
9
 
10
  import os
@@ -20,23 +21,20 @@
20
 
21
  logger = logging.getLogger(__name__)
22
 
23
-
24
  # ------------------------- Shape & dtype utilities ------------------------- #
25
 
26
  def _select_device(pref: str) -> str:
27
- pref = (pref or "").lower() if pref else ""
28
  if pref.startswith("cuda"):
29
  return "cuda" if torch.cuda.is_available() else "cpu"
30
  if pref == "cpu":
31
  return "cpu"
32
  return "cuda" if torch.cuda.is_available() else "cpu"
33
 
34
-
35
  def _as_tensor_on_device(x, device: str) -> torch.Tensor:
36
  if isinstance(x, torch.Tensor):
37
- return x.to(device)
38
- return torch.from_numpy(np.asarray(x)).to(device)
39
-
40
 
41
  def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
42
  """
@@ -51,7 +49,7 @@ def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
51
  elif x.dtype in (torch.int16, torch.int32, torch.int64):
52
  x = x.float()
53
 
54
- # 5D [B,T,*,H,W] or [B,T,H,W,*] -> take first frame
55
  if x.ndim == 5:
56
  x = x[:, 0] # -> 4D
57
 
@@ -83,20 +81,16 @@ def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
83
  else:
84
  if x.shape[1] == 1:
85
  x = x.repeat(1, 3, 1, 1)
86
- x = x.clamp_(0.0, 1.0).to(torch.float32)
87
 
88
  return x
89
 
90
-
91
  def _to_chw_image(img_bchw: torch.Tensor) -> torch.Tensor:
92
- """Prefer CHW for InferenceCore.step."""
93
  if img_bchw.ndim == 4 and img_bchw.shape[0] == 1:
94
  return img_bchw[0]
95
- return img_bchw # some builds may accept batched; we try CHW first
96
-
97
 
98
- def _to_1hw_mask(msk_b1hw: torch.Tensor) -> torch.Tensor:
99
- """Non-idx path expects [1,H,W] for single target."""
100
  if msk_b1hw is None:
101
  return None
102
  if msk_b1hw.ndim == 4 and msk_b1hw.shape[1] == 1:
@@ -105,19 +99,15 @@ def _to_1hw_mask(msk_b1hw: torch.Tensor) -> torch.Tensor:
105
  return msk_b1hw
106
  raise ValueError(f"Expected B1HW or 1HW, got {tuple(msk_b1hw.shape)}")
107
 
108
-
109
- def _resize_mask_to(img_bchw: torch.Tensor, mask_b1hw: torch.Tensor) -> torch.Tensor:
110
- if mask_b1hw is None:
111
  return None
112
- if img_bchw.shape[-2:] == mask_b1hw.shape[-2:]:
113
- return mask_b1hw
114
- return F.interpolate(mask_b1hw, size=img_bchw.shape[-2:], mode="nearest")
115
-
116
 
117
  def _to_2d_alpha_numpy(x) -> np.ndarray:
118
- """
119
- Convert probabilities/mattes to 2-D float32 [H,W] contiguous.
120
- """
121
  t = torch.as_tensor(x).float()
122
  while t.ndim > 2:
123
  if t.ndim == 3:
@@ -128,7 +118,6 @@ def _to_2d_alpha_numpy(x) -> np.ndarray:
128
  out = t.detach().cpu().numpy().astype(np.float32)
129
  return np.ascontiguousarray(out)
130
 
131
-
132
  def debug_shapes(tag: str, image, mask) -> None:
133
  def _info(name, v):
134
  try:
@@ -141,35 +130,42 @@ def _info(name, v):
141
  _info("image", image)
142
  _info("mask", mask)
143
 
144
-
145
  # ------------------------------ Stateful Adapter --------------------------- #
146
 
147
  class _MatAnyoneSession:
148
  """
149
- Minimal stateful controller around InferenceCore.
150
-
151
  Usage:
152
- # frame 0 (has initial coarse mask):
153
- alpha0 = session(frame0_rgb, mask0) # encode + warm-up predict
154
  # frames 1..N (no mask):
155
- alpha = session(frame_rgb) # propagate/refine
156
  """
157
- def __init__(self, core, device: str):
 
 
 
 
 
 
 
 
 
158
  self.core = core
159
  self.device = device
 
 
 
 
 
160
  self.started = False
161
 
162
- # discover supported step() kwargs
163
  try:
164
- self._step_sig = inspect.signature(self.core.step)
165
- self._has_first_frame_pred = "first_frame_pred" in self._step_sig.parameters
166
- self._has_idx_mask = "idx_mask" in self._step_sig.parameters
167
  except Exception:
168
- self._step_sig = None
169
  self._has_first_frame_pred = True
170
- self._has_idx_mask = True
171
-
172
- # discover output conversion helper
173
  self._has_prob_to_mask = hasattr(self.core, "output_prob_to_mask")
174
 
175
  def reset(self):
@@ -180,63 +176,23 @@ def reset(self):
180
  pass
181
  self.started = False
182
 
183
- def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
184
- """
185
- Returns a 2-D float32 alpha [H,W] suitable for OpenCV.
186
- Expects RGB image in HWC or similar; mask as [H,W] or broadcastable.
187
- """
188
- # Normalize inputs
189
- img_bchw = _to_bchw(image, self.device, is_mask=False) # [B,C,H,W]
190
- msk_b1hw = _to_bchw(mask, self.device, is_mask=True) if mask is not None else None
191
- if msk_b1hw is not None:
192
- msk_b1hw = _resize_mask_to(img_bchw, msk_b1hw)
193
- img_chw = _to_chw_image(img_bchw)
194
- m_1hw = _to_1hw_mask(msk_b1hw) if msk_b1hw is not None else None
195
-
196
- try:
197
- if not self.started:
198
- if m_1hw is None:
199
- logger.warning("First frame arrived without a mask; returning neutral alpha.")
200
- return np.full(img_chw.shape[-2:], 0.5, dtype=np.float32)
201
-
202
- # 1) Encode target on first frame
203
- kwargs1 = {}
204
- if self._has_idx_mask:
205
- kwargs1["idx_mask"] = False
206
- _ = self.core.step(image=img_chw, mask=m_1hw, **kwargs1)
207
-
208
- # 2) First-frame warm-up prediction + memorize
209
- kwargs2 = {}
210
- if self._has_first_frame_pred:
211
- kwargs2["first_frame_pred"] = True
212
- out_prob = self.core.step(image=img_chw, **kwargs2)
213
-
214
- alpha = self._to_alpha(out_prob)
215
- self.started = True
216
- return _to_2d_alpha_numpy(alpha)
217
-
218
- # Subsequent frames: propagate without mask
219
- out_prob = self.core.step(image=img_chw)
220
- alpha = self._to_alpha(out_prob)
221
- return _to_2d_alpha_numpy(alpha)
222
-
223
- except Exception as e:
224
- logger.debug(traceback.format_exc())
225
- logger.warning(f"MatAnyone call failed; returning input mask as fallback: {e}")
226
- if m_1hw is not None:
227
- return _to_2d_alpha_numpy(m_1hw)
228
- return np.full(img_chw.shape[-2:], 0.5, dtype=np.float32)
229
 
230
  def _to_alpha(self, out_prob):
231
- """
232
- Convert core output to alpha. Prefer core.output_prob_to_mask(matting=True) if available.
233
- """
234
  if self._has_prob_to_mask:
235
  try:
236
  return self.core.output_prob_to_mask(out_prob, matting=True)
237
  except Exception:
238
  pass
239
- # Fallback heuristics
240
  t = torch.as_tensor(out_prob).float()
241
  if t.ndim == 3 and t.shape[0] >= 1:
242
  return t[0]
@@ -244,12 +200,123 @@ def _to_alpha(self, out_prob):
244
  return t
245
  return torch.full((1, 1), 0.5, dtype=torch.float32, device=t.device if t.is_cuda else "cpu")
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
  # -------------------------------- Loader ---------------------------------- #
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  class MatAnyoneLoader:
251
  """
252
- Official MatAnyone loader with stateful adapter.
253
  """
254
 
255
  def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache"):
@@ -267,16 +334,14 @@ def _import_model_and_core(self):
267
  """
268
  Import MatAnyone + InferenceCore with resilient fallbacks (different dist layouts).
269
  """
270
- # Try several possible import paths to be robust
271
  model_cls = core_cls = None
272
  err_msgs = []
273
 
274
  # Candidates for model class
275
- model_paths = [
276
  ("matanyone.model.matanyone", "MatAnyone"),
277
  ("matanyone", "MatAnyone"),
278
- ]
279
- for mod, cls in model_paths:
280
  try:
281
  m = __import__(mod, fromlist=[cls])
282
  model_cls = getattr(m, cls)
@@ -285,11 +350,10 @@ def _import_model_and_core(self):
285
  err_msgs.append(f"model {mod}.{cls}: {e}")
286
 
287
  # Candidates for InferenceCore
288
- core_paths = [
289
  ("matanyone.inference.inference_core", "InferenceCore"),
290
  ("matanyone", "InferenceCore"),
291
- ]
292
- for mod, cls in core_paths:
293
  try:
294
  m = __import__(mod, fromlist=[cls])
295
  core_cls = getattr(m, cls)
@@ -312,9 +376,21 @@ def load(self) -> Optional[Any]:
312
  try:
313
  model_cls, core_cls = self._import_model_and_core()
314
 
 
 
 
 
315
  # Official pattern: model -> eval -> core(model, cfg=model.cfg)
316
  self.model = model_cls.from_pretrained(self.model_id)
317
- self.model = self.model.to(self.device).eval()
 
 
 
 
 
 
 
 
318
 
319
  # Some builds require cfg; fall back if not present
320
  try:
@@ -324,17 +400,28 @@ def load(self) -> Optional[Any]:
324
  else:
325
  self.core = core_cls(self.model)
326
  except TypeError:
327
- # signature without cfg
328
  self.core = core_cls(self.model)
329
 
330
- # Move core to device if it supports .to
331
  try:
332
  if hasattr(self.core, "to"):
333
  self.core.to(self.device)
334
  except Exception:
335
  pass
336
 
337
- self.adapter = _MatAnyoneSession(self.core, self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  self.load_time = time.time() - start
339
  logger.info(f"MatAnyone loaded in {self.load_time:.2f}s")
340
  return self.adapter
@@ -345,11 +432,6 @@ def load(self) -> Optional[Any]:
345
  return None
346
 
347
  def cleanup(self):
348
- if self.adapter:
349
- try:
350
- self.adapter.reset()
351
- except Exception:
352
- pass
353
  self.adapter = None
354
  self.core = None
355
  if self.model:
 
1
  #!/usr/bin/env python3
2
  """
3
+ MatAnyone Loader + Stateful Adapter (OOM-resilient)
4
+ - Canonical HF load (MatAnyone.from_pretrained -> InferenceCore(model, cfg))
5
+ - Mixed precision (bf16/fp16) with safe fallback to fp32
6
+ - Autocast + inference_mode around every call
7
+ - Auto downscale with progressive retry on OOM, then upsample alpha back
8
+ - Returns 2-D float32 [H,W] alpha for OpenCV
9
  """
10
 
11
  import os
 
21
 
22
  logger = logging.getLogger(__name__)
23
 
 
24
  # ------------------------- Shape & dtype utilities ------------------------- #
25
 
26
  def _select_device(pref: str) -> str:
27
+ pref = (pref or "").lower()
28
  if pref.startswith("cuda"):
29
  return "cuda" if torch.cuda.is_available() else "cpu"
30
  if pref == "cpu":
31
  return "cpu"
32
  return "cuda" if torch.cuda.is_available() else "cpu"
33
 
 
34
  def _as_tensor_on_device(x, device: str) -> torch.Tensor:
35
  if isinstance(x, torch.Tensor):
36
+ return x.to(device, non_blocking=True)
37
+ return torch.from_numpy(np.asarray(x)).to(device, non_blocking=True)
 
38
 
39
  def _to_bchw(x, device: str, is_mask: bool = False) -> torch.Tensor:
40
  """
 
49
  elif x.dtype in (torch.int16, torch.int32, torch.int64):
50
  x = x.float()
51
 
52
+ # 5D -> take first time slice
53
  if x.ndim == 5:
54
  x = x[:, 0] # -> 4D
55
 
 
81
  else:
82
  if x.shape[1] == 1:
83
  x = x.repeat(1, 3, 1, 1)
84
+ x = x.clamp_(0.0, 1.0)
85
 
86
  return x
87
 
 
88
  def _to_chw_image(img_bchw: torch.Tensor) -> torch.Tensor:
 
89
  if img_bchw.ndim == 4 and img_bchw.shape[0] == 1:
90
  return img_bchw[0]
91
+ return img_bchw
 
92
 
93
+ def _to_1hw_mask(msk_b1hw: torch.Tensor) -> Optional[torch.Tensor]:
 
94
  if msk_b1hw is None:
95
  return None
96
  if msk_b1hw.ndim == 4 and msk_b1hw.shape[1] == 1:
 
99
  return msk_b1hw
100
  raise ValueError(f"Expected B1HW or 1HW, got {tuple(msk_b1hw.shape)}")
101
 
102
+ def _resize_bchw(x: Optional[torch.Tensor], size_hw: Tuple[int, int], is_mask=False) -> Optional[torch.Tensor]:
103
+ if x is None:
 
104
  return None
105
+ if x.shape[-2:] == size_hw:
106
+ return x
107
+ mode = "nearest" if is_mask else "bilinear"
108
+ return F.interpolate(x, size=size_hw, mode=mode, align_corners=False if mode == "bilinear" else None)
109
 
110
  def _to_2d_alpha_numpy(x) -> np.ndarray:
 
 
 
111
  t = torch.as_tensor(x).float()
112
  while t.ndim > 2:
113
  if t.ndim == 3:
 
118
  out = t.detach().cpu().numpy().astype(np.float32)
119
  return np.ascontiguousarray(out)
120
 
 
121
  def debug_shapes(tag: str, image, mask) -> None:
122
  def _info(name, v):
123
  try:
 
130
  _info("image", image)
131
  _info("mask", mask)
132
 
 
133
  # ------------------------------ Stateful Adapter --------------------------- #
134
 
135
  class _MatAnyoneSession:
136
  """
137
+ Stateful controller around InferenceCore with OOM-resilient inference.
 
138
  Usage:
139
+ # frame 0 (has mask):
140
+ alpha0 = session(frame0_rgb01, mask01)
141
  # frames 1..N (no mask):
142
+ alpha = session(frame_rgb01)
143
  """
144
+ def __init__(
145
+ self,
146
+ core,
147
+ device: str,
148
+ model_dtype: torch.dtype,
149
+ use_autocast: bool,
150
+ autocast_dtype: Optional[torch.dtype],
151
+ max_edge: int = 768,
152
+ target_pixels: int = 600_000, # ~775x775 cap by area
153
+ ):
154
  self.core = core
155
  self.device = device
156
+ self.model_dtype = model_dtype
157
+ self.use_autocast = use_autocast and (device == "cuda")
158
+ self.autocast_dtype = autocast_dtype if self.use_autocast else None
159
+ self.max_edge = int(max_edge)
160
+ self.target_pixels = int(target_pixels)
161
  self.started = False
162
 
163
+ # feature detection
164
  try:
165
+ sig = inspect.signature(self.core.step)
166
+ self._has_first_frame_pred = "first_frame_pred" in sig.parameters
 
167
  except Exception:
 
168
  self._has_first_frame_pred = True
 
 
 
169
  self._has_prob_to_mask = hasattr(self.core, "output_prob_to_mask")
170
 
171
  def reset(self):
 
176
  pass
177
  self.started = False
178
 
179
+ # ---- helpers ----
180
+ def _compute_scaled_size(self, h: int, w: int) -> Tuple[int, int, float]:
181
+ if h <= 0 or w <= 0:
182
+ return h, w, 1.0
183
+ s1 = min(1.0, self.max_edge / max(h, w))
184
+ s2 = min(1.0, (self.target_pixels / (h * w)) ** 0.5) if self.target_pixels > 0 else 1.0
185
+ s = min(s1, s2)
186
+ nh = max(1, int(round(h * s)))
187
+ nw = max(1, int(round(w * s)))
188
+ return nh, nw, s
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  def _to_alpha(self, out_prob):
 
 
 
191
  if self._has_prob_to_mask:
192
  try:
193
  return self.core.output_prob_to_mask(out_prob, matting=True)
194
  except Exception:
195
  pass
 
196
  t = torch.as_tensor(out_prob).float()
197
  if t.ndim == 3 and t.shape[0] >= 1:
198
  return t[0]
 
200
  return t
201
  return torch.full((1, 1), 0.5, dtype=torch.float32, device=t.device if t.is_cuda else "cpu")
202
 
203
+ # ---- main call ----
204
+ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
205
+ """
206
+ Returns a 2-D float32 alpha [H,W]. On first call, provide a coarse mask.
207
+ Subsequent calls propagate without a mask.
208
+ """
209
+ # Boundary normalization
210
+ img_bchw = _to_bchw(image, self.device, is_mask=False) # [1,C,H,W]
211
+ msk_b1hw = _to_bchw(mask, self.device, is_mask=True) if mask is not None else None
212
+
213
+ H, W = img_bchw.shape[-2], img_bchw.shape[-1]
214
+ if msk_b1hw is not None:
215
+ msk_b1hw = _resize_bchw(msk_b1hw, (H, W), is_mask=True)
216
+
217
+ # dtype alignment for activations
218
+ img_bchw = img_bchw.to(self.model_dtype, non_blocking=True)
219
+
220
+ # initial scale + fallbacks
221
+ nh, nw, s = self._compute_scaled_size(H, W)
222
+ scales = [(nh, nw)]
223
+ if s < 1.0:
224
+ scales.append((max(1, int(nh * 0.85)), max(1, int(nw * 0.85))))
225
+ scales.append((max(1, int(nh * 0.70)), max(1, int(nw * 0.70))))
226
+
227
+ last_exc = None
228
+
229
+ for (th, tw) in scales:
230
+ try:
231
+ # downscale for inference if needed
232
+ img_in = _resize_bchw(img_bchw, (th, tw), is_mask=False)
233
+ msk_in = _resize_bchw(msk_b1hw, (th, tw), is_mask=True) if msk_b1hw is not None else None
234
+
235
+ img_chw = _to_chw_image(img_in)
236
+ m_1hw = _to_1hw_mask(msk_in) if msk_in is not None else None
237
+
238
+ # inference with autocast + inference_mode
239
+ with torch.inference_mode():
240
+ if self.use_autocast:
241
+ amp_ctx = torch.cuda.amp.autocast(dtype=self.autocast_dtype)
242
+ else:
243
+ class _NoOp:
244
+ def __enter__(self): return None
245
+ def __exit__(self, *args): return False
246
+ amp_ctx = _NoOp()
247
+
248
+ with amp_ctx:
249
+ if not self.started:
250
+ if m_1hw is None:
251
+ logger.warning("First frame arrived without a mask; returning neutral alpha.")
252
+ return np.full((H, W), 0.5, dtype=np.float32)
253
+
254
+ # encode/memorize
255
+ _ = self.core.step(image=img_chw, mask=m_1hw)
256
+ # warm-up predict
257
+ if self._has_first_frame_pred:
258
+ out_prob = self.core.step(image=img_chw, first_frame_pred=True)
259
+ else:
260
+ out_prob = self.core.step(image=img_chw)
261
+ alpha = self._to_alpha(out_prob)
262
+ self.started = True
263
+ else:
264
+ out_prob = self.core.step(image=img_chw)
265
+ alpha = self._to_alpha(out_prob)
266
+
267
+ # upsample back to original resolution if scaled
268
+ if (th, tw) != (H, W):
269
+ alpha = torch.as_tensor(alpha).unsqueeze(0).unsqueeze(0).float()
270
+ alpha = F.interpolate(alpha, size=(H, W), mode="bilinear", align_corners=False)
271
+ alpha = alpha.squeeze(0).squeeze(0)
272
+
273
+ return _to_2d_alpha_numpy(alpha)
274
+
275
+ except torch.cuda.OutOfMemoryError as e:
276
+ last_exc = e
277
+ logger.warning(f"MatAnyone OOM at {th}x{tw}; retrying smaller. {e}")
278
+ torch.cuda.empty_cache()
279
+ continue
280
+ except Exception as e:
281
+ last_exc = e
282
+ logger.debug(traceback.format_exc())
283
+ logger.warning(f"MatAnyone call failed at {th}x{tw}; retrying smaller. {e}")
284
+ torch.cuda.empty_cache()
285
+ continue
286
+
287
+ # All attempts failed → return fallback
288
+ logger.warning(f"MatAnyone calls failed; returning input mask as fallback. {last_exc}")
289
+ if msk_b1hw is not None:
290
+ return _to_2d_alpha_numpy(msk_b1hw)
291
+ return np.full((H, W), 0.5, dtype=np.float32)
292
 
293
  # -------------------------------- Loader ---------------------------------- #
294
 
295
+ def _choose_precision(device: str) -> Tuple[torch.dtype, bool, Optional[torch.dtype]]:
296
+ """
297
+ Decide model+autocast dtypes.
298
+ Strategy:
299
+ - Prefer bf16 autocast if supported (Ampere+), keep weights bf16 if possible.
300
+ - Else use fp16 autocast, keep weights fp16 if safe.
301
+ - Else fp32 without autocast.
302
+ """
303
+ if device != "cuda":
304
+ return torch.float32, False, None
305
+
306
+ bf16_ok = hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported()
307
+ cc = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
308
+ fp16_ok = cc[0] >= 7 # Volta+
309
+
310
+ if bf16_ok:
311
+ return torch.bfloat16, True, torch.bfloat16
312
+ if fp16_ok:
313
+ return torch.float16, True, torch.float16
314
+ return torch.float32, False, None
315
+
316
+
317
  class MatAnyoneLoader:
318
  """
319
+ Official MatAnyone loader with stateful, OOM-resilient adapter.
320
  """
321
 
322
  def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache"):
 
334
  """
335
  Import MatAnyone + InferenceCore with resilient fallbacks (different dist layouts).
336
  """
 
337
  model_cls = core_cls = None
338
  err_msgs = []
339
 
340
  # Candidates for model class
341
+ for mod, cls in [
342
  ("matanyone.model.matanyone", "MatAnyone"),
343
  ("matanyone", "MatAnyone"),
344
+ ]:
 
345
  try:
346
  m = __import__(mod, fromlist=[cls])
347
  model_cls = getattr(m, cls)
 
350
  err_msgs.append(f"model {mod}.{cls}: {e}")
351
 
352
  # Candidates for InferenceCore
353
+ for mod, cls in [
354
  ("matanyone.inference.inference_core", "InferenceCore"),
355
  ("matanyone", "InferenceCore"),
356
+ ]:
 
357
  try:
358
  m = __import__(mod, fromlist=[cls])
359
  core_cls = getattr(m, cls)
 
376
  try:
377
  model_cls, core_cls = self._import_model_and_core()
378
 
379
+ # pick precision strategy
380
+ model_dtype, use_autocast, autocast_dtype = _choose_precision(self.device)
381
+ logger.info(f"MatAnyone precision: weights={model_dtype}, autocast={use_autocast and autocast_dtype}")
382
+
383
  # Official pattern: model -> eval -> core(model, cfg=model.cfg)
384
  self.model = model_cls.from_pretrained(self.model_id)
385
+
386
+ # Try to move weights to selected dtype (safe try)
387
+ try:
388
+ self.model = self.model.to(self.device).to(model_dtype)
389
+ except Exception:
390
+ self.model = self.model.to(self.device)
391
+ # keep weights fp32; still benefit from autocast
392
+
393
+ self.model.eval()
394
 
395
  # Some builds require cfg; fall back if not present
396
  try:
 
400
  else:
401
  self.core = core_cls(self.model)
402
  except TypeError:
 
403
  self.core = core_cls(self.model)
404
 
 
405
  try:
406
  if hasattr(self.core, "to"):
407
  self.core.to(self.device)
408
  except Exception:
409
  pass
410
 
411
+ # tune scaling from env (optional)
412
+ max_edge = int(os.environ.get("MATANYONE_MAX_EDGE", "768"))
413
+ target_pixels = int(os.environ.get("MATANYONE_TARGET_PIXELS", "600000"))
414
+
415
+ self.adapter = _MatAnyoneSession(
416
+ self.core,
417
+ device=self.device,
418
+ model_dtype=model_dtype,
419
+ use_autocast=use_autocast,
420
+ autocast_dtype=autocast_dtype,
421
+ max_edge=max_edge,
422
+ target_pixels=target_pixels,
423
+ )
424
+
425
  self.load_time = time.time() - start
426
  logger.info(f"MatAnyone loaded in {self.load_time:.2f}s")
427
  return self.adapter
 
432
  return None
433
 
434
  def cleanup(self):
 
 
 
 
 
435
  self.adapter = None
436
  self.core = None
437
  if self.model: