Stable-X committed on
Commit d0afedf · verified · 1 Parent(s): 0886e8b

Delete mast3r

mast3r/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
 
 
 
mast3r/catmlp_dpt_head.py DELETED
@@ -1,123 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # MASt3R heads
6
- # --------------------------------------------------------
7
- import torch
8
- import torch.nn.functional as F
9
-
10
- import mast3r.utils.path_to_dust3r # noqa
11
- from dust3r.heads.postprocess import reg_dense_depth, reg_dense_conf # noqa
12
- from dust3r.heads.dpt_head import PixelwiseTaskWithDPT # noqa
13
- import dust3r.utils.path_to_croco # noqa
14
- from models.blocks import Mlp # noqa
15
-
16
-
17
- def reg_desc(desc, mode):
18
- if 'norm' in mode:
19
- desc = desc / desc.norm(dim=-1, keepdim=True)
20
- else:
21
- raise ValueError(f"Unknown desc mode {mode}")
22
- return desc
23
-
24
-
25
- def postprocess(out, depth_mode, conf_mode, desc_dim=None, desc_mode='norm', two_confs=False, desc_conf_mode=None):
26
- if desc_conf_mode is None:
27
- desc_conf_mode = conf_mode
28
- fmap = out.permute(0, 2, 3, 1) # B,H,W,D
29
- res = dict(pts3d=reg_dense_depth(fmap[..., 0:3], mode=depth_mode))
30
- if conf_mode is not None:
31
- res['conf'] = reg_dense_conf(fmap[..., 3], mode=conf_mode)
32
- if desc_dim is not None:
33
- start = 3 + int(conf_mode is not None)
34
- res['desc'] = reg_desc(fmap[..., start:start + desc_dim], mode=desc_mode)
35
- if two_confs:
36
- res['desc_conf'] = reg_dense_conf(fmap[..., start + desc_dim], mode=desc_conf_mode)
37
- else:
38
- res['desc_conf'] = res['conf'].clone()
39
- return res
40
-
41
-
42
- class Cat_MLP_LocalFeatures_DPT_Pts3d(PixelwiseTaskWithDPT):
43
- """ Mixture between MLP and DPT head that outputs 3d points and local features (with MLP).
44
- The input for both heads is a concatenation of Encoder and Decoder outputs
45
- """
46
-
47
- def __init__(self, net, has_conf=False, local_feat_dim=16, hidden_dim_factor=4., hooks_idx=None, dim_tokens=None,
48
- num_channels=1, postprocess=None, feature_dim=256, last_dim=32, depth_mode=None, conf_mode=None, head_type="regression", **kwargs):
49
- super().__init__(num_channels=num_channels, feature_dim=feature_dim, last_dim=last_dim, hooks_idx=hooks_idx,
50
- dim_tokens=dim_tokens, depth_mode=depth_mode, postprocess=postprocess, conf_mode=conf_mode, head_type=head_type)
51
- self.local_feat_dim = local_feat_dim
52
-
53
- patch_size = net.patch_embed.patch_size
54
- if isinstance(patch_size, tuple):
55
- assert len(patch_size) == 2 and isinstance(patch_size[0], int) and isinstance(
56
- patch_size[1], int), "What is your patchsize format? Expected a single int or a tuple of two ints."
57
- assert patch_size[0] == patch_size[1], "Error, non square patches not managed"
58
- patch_size = patch_size[0]
59
- self.patch_size = patch_size
60
-
61
- self.desc_mode = net.desc_mode
62
- self.has_conf = has_conf
63
- self.two_confs = net.two_confs # independent confs for 3D regr and descs
64
- self.desc_conf_mode = net.desc_conf_mode
65
- idim = net.enc_embed_dim + net.dec_embed_dim
66
-
67
- self.head_local_features = Mlp(in_features=idim,
68
- hidden_features=int(hidden_dim_factor * idim),
69
- out_features=(self.local_feat_dim + self.two_confs) * self.patch_size**2)
70
-
71
- def forward(self, decout, img_shape):
72
- # pass through the heads
73
- pts3d = self.dpt(decout, image_size=(img_shape[0], img_shape[1]))
74
-
75
- # recover encoder and decoder outputs
76
- enc_output, dec_output = decout[0], decout[-1]
77
- cat_output = torch.cat([enc_output, dec_output], dim=-1) # concatenate
78
- H, W = img_shape
79
- B, S, D = cat_output.shape
80
-
81
- # extract local_features
82
- local_features = self.head_local_features(cat_output) # B,S,D
83
- local_features = local_features.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size)
84
- local_features = F.pixel_shuffle(local_features, self.patch_size) # B,d,H,W
85
-
86
- # post process 3D pts, descriptors and confidences
87
- out = torch.cat([pts3d, local_features], dim=1)
88
- if self.postprocess:
89
- out = self.postprocess(out,
90
- depth_mode=self.depth_mode,
91
- conf_mode=self.conf_mode,
92
- desc_dim=self.local_feat_dim,
93
- desc_mode=self.desc_mode,
94
- two_confs=self.two_confs,
95
- desc_conf_mode=self.desc_conf_mode)
96
- return out
97
-
98
-
99
- def mast3r_head_factory(head_type, output_mode, net, has_conf=False):
100
- """" build a prediction head for the decoder
101
- """
102
- if head_type == 'catmlp+dpt' and output_mode.startswith('pts3d+desc'):
103
- local_feat_dim = int(output_mode[10:])
104
- assert net.dec_depth > 9
105
- l2 = net.dec_depth
106
- feature_dim = 256
107
- last_dim = feature_dim // 2
108
- out_nchan = 3
109
- ed = net.enc_embed_dim
110
- dd = net.dec_embed_dim
111
- return Cat_MLP_LocalFeatures_DPT_Pts3d(net, local_feat_dim=local_feat_dim, has_conf=has_conf,
112
- num_channels=out_nchan + has_conf,
113
- feature_dim=feature_dim,
114
- last_dim=last_dim,
115
- hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2],
116
- dim_tokens=[ed, dd, dd, dd],
117
- postprocess=postprocess,
118
- depth_mode=net.depth_mode,
119
- conf_mode=net.conf_mode,
120
- head_type='regression')
121
- else:
122
- raise NotImplementedError(
123
- f"unexpected {head_type=} and {output_mode=}")
 
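The factory above builds a head whose raw output packs 3D points, confidences, and local descriptors into one channel dimension, which `postprocess` then slices apart. Below is a minimal, self-contained sketch of that channel layout; the dust3r regressors (`reg_dense_depth`, `reg_dense_conf`) are replaced by plain tensor ops and all sizes are illustrative.

import torch
import torch.nn.functional as F

# layout consumed by postprocess(): [pts3d (3) | conf (1) | desc (D) | desc_conf (1)]
B, H, W, D = 2, 32, 32, 24            # D = local_feat_dim, e.g. output_mode 'pts3d+desc24'
out = torch.randn(B, 3 + 1 + D + 1, H, W)

fmap = out.permute(0, 2, 3, 1)        # B,H,W,C as in postprocess()
pts3d = fmap[..., 0:3]                # would go through reg_dense_depth in the real head
conf = fmap[..., 3]                   # would go through reg_dense_conf
start = 3 + 1
desc = F.normalize(fmap[..., start:start + D], dim=-1)   # stand-in for reg_desc 'norm' mode
desc_conf = fmap[..., start + D]      # second confidence map when two_confs=True

print(pts3d.shape, conf.shape, desc.shape, desc_conf.shape)
# torch.Size([2, 32, 32, 3]) torch.Size([2, 32, 32]) torch.Size([2, 32, 32, 24]) torch.Size([2, 32, 32])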
 
 
mast3r/cloud_opt/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
 
 
 
mast3r/cloud_opt/sparse_ga.py DELETED
@@ -1,1001 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # MASt3R Sparse Global Alignment
6
- # --------------------------------------------------------
7
- from tqdm import tqdm
8
- import roma
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- import numpy as np
13
- import os
14
- from collections import namedtuple
15
- from functools import lru_cache
16
- from scipy import sparse as sp
17
-
18
- from mast3r.utils.misc import mkdir_for, hash_md5
19
- from mast3r.cloud_opt.utils.losses import gamma_loss
20
- from mast3r.cloud_opt.utils.schedules import linear_schedule, cosine_schedule
21
- from mast3r.fast_nn import fast_reciprocal_NNs, merge_corres
22
-
23
- import mast3r.utils.path_to_dust3r # noqa
24
- from dust3r.utils.geometry import inv, geotrf # noqa
25
- from dust3r.utils.device import to_cpu, to_numpy, todevice # noqa
26
- from dust3r.post_process import estimate_focal_knowing_depth # noqa
27
- from dust3r.optim_factory import adjust_learning_rate_by_lr # noqa
28
- from dust3r.cloud_opt.base_opt import clean_pointcloud
29
- from dust3r.viz import SceneViz
30
-
31
-
32
- class SparseGA():
33
- def __init__(self, img_paths, pairs_in, res_fine, anchors, canonical_paths=None):
34
- def fetch_img(im):
35
- def torgb(x): return (x[0].permute(1, 2, 0).numpy() * .5 + .5).clip(min=0., max=1.)
36
- for im1, im2 in pairs_in:
37
- if im1['instance'] == im:
38
- return torgb(im1['img'])
39
- if im2['instance'] == im:
40
- return torgb(im2['img'])
41
- self.canonical_paths = canonical_paths
42
- self.img_paths = img_paths
43
- self.imgs = [fetch_img(img) for img in img_paths]
44
- self.intrinsics = res_fine['intrinsics']
45
- self.cam2w = res_fine['cam2w']
46
- self.depthmaps = res_fine['depthmaps']
47
- self.pts3d = res_fine['pts3d']
48
- self.pts3d_colors = []
49
- self.working_device = self.cam2w.device
50
- for i in range(len(self.imgs)):
51
- im = self.imgs[i]
52
- x, y = anchors[i][0][..., :2].detach().cpu().numpy().T
53
- self.pts3d_colors.append(im[y, x])
54
- assert self.pts3d_colors[-1].shape == self.pts3d[i].shape
55
- self.n_imgs = len(self.imgs)
56
-
57
- def get_focals(self):
58
- return torch.tensor([ff[0, 0] for ff in self.intrinsics]).to(self.working_device)
59
-
60
- def get_principal_points(self):
61
- return torch.stack([ff[:2, -1] for ff in self.intrinsics]).to(self.working_device)
62
-
63
- def get_im_poses(self):
64
- return self.cam2w
65
-
66
- def get_sparse_pts3d(self):
67
- return self.pts3d
68
-
69
- def get_dense_pts3d(self, clean_depth=True, subsample=8):
70
- assert self.canonical_paths, 'cache_path is required for dense 3d points'
71
- device = self.cam2w.device
72
- confs = []
73
- base_focals = []
74
- anchors = {}
75
- for i, canon_path in enumerate(self.canonical_paths):
76
- (canon, canon2, conf), focal = torch.load(canon_path, map_location=device)
77
- confs.append(conf)
78
- base_focals.append(focal)
79
-
80
- H, W = conf.shape
81
- pixels = torch.from_numpy(np.mgrid[:W, :H].T.reshape(-1, 2)).float().to(device)
82
- idxs, offsets = anchor_depth_offsets(canon2, {i: (pixels, None)}, subsample=subsample)
83
- anchors[i] = (pixels, idxs[i], offsets[i])
84
-
85
- # densify sparse depthmaps
86
- pts3d, depthmaps = make_pts3d(anchors, self.intrinsics, self.cam2w, [
87
- d.ravel() for d in self.depthmaps], base_focals=base_focals, ret_depth=True)
88
-
89
- if clean_depth:
90
- confs = clean_pointcloud(confs, self.intrinsics, inv(self.cam2w), depthmaps, pts3d)
91
-
92
- return pts3d, depthmaps, confs
93
-
94
- def get_pts3d_colors(self):
95
- return self.pts3d_colors
96
-
97
- def get_depthmaps(self):
98
- return self.depthmaps
99
-
100
- def get_masks(self):
101
- return [slice(None, None) for _ in range(len(self.imgs))]
102
-
103
- def show(self, show_cams=True):
104
- pts3d, _, confs = self.get_dense_pts3d()
105
- show_reconstruction(self.imgs, self.intrinsics if show_cams else None, self.cam2w,
106
- [p.clip(min=-50, max=50) for p in pts3d],
107
- masks=[c > 1 for c in confs])
108
-
109
-
110
- def convert_dust3r_pairs_naming(imgs, pairs_in):
111
- for pair_id in range(len(pairs_in)):
112
- for i in range(2):
113
- pairs_in[pair_id][i]['instance'] = imgs[pairs_in[pair_id][i]['idx']]
114
- return pairs_in
115
-
116
-
117
- def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc_conf='desc_conf',
118
- device='cuda', dtype=torch.float32, **kw):
119
- """ Sparse alignment with MASt3R
120
- imgs: list of image paths
121
- cache_path: path where to dump temporary files (str)
122
-
123
- lr1, niter1: learning rate and #iterations for coarse global alignment (3D matching)
124
- lr2, niter2: learning rate and #iterations for refinement (2D reproj error)
125
-
126
- lora_depth: smart dimensionality reduction with depthmaps
127
- """
128
- # Convert pair naming convention from dust3r to mast3r
129
- pairs_in = convert_dust3r_pairs_naming(imgs, pairs_in)
130
- # forward pass
131
- pairs, cache_path = forward_mast3r(pairs_in, model,
132
- cache_path=cache_path, subsample=subsample,
133
- desc_conf=desc_conf, device=device)
134
-
135
- # extract canonical pointmaps
136
- tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21 = \
137
- prepare_canonical_data(imgs, pairs, subsample, cache_path=cache_path, mode='avg-angle', device=device)
138
-
139
- # compute minimal spanning tree
140
- mst = compute_min_spanning_tree(pairwise_scores)
141
-
142
- # remove all edges not in the spanning tree?
143
- # min_spanning_tree = {(imgs[i],imgs[j]) for i,j in mst[1]}
144
- # tmp_pairs = {(a,b):v for (a,b),v in tmp_pairs.items() if {(a,b),(b,a)} & min_spanning_tree}
145
-
146
- # smartly combine all useful data
147
- imsizes, pps, base_focals, core_depth, anchors, corres, corres2d = \
148
- condense_data(imgs, tmp_pairs, canonical_views, dtype)
149
-
150
- imgs, res_coarse, res_fine = sparse_scene_optimizer(
151
- imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21, canonical_paths,
152
- mst, cache_path=cache_path, device=device, dtype=dtype, **kw)
153
- return SparseGA(imgs, pairs_in, res_fine or res_coarse, anchors, canonical_paths)
154
-
155
-
156
- def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d,
157
- preds_21, canonical_paths, mst, cache_path,
158
- lr1=0.2, niter1=500, loss1=gamma_loss(1.1),
159
- lr2=0.02, niter2=500, loss2=gamma_loss(0.4),
160
- lossd=gamma_loss(1.1),
161
- opt_pp=True, opt_depth=True,
162
- schedule=cosine_schedule, depth_mode='add', exp_depth=False,
163
- lora_depth=False, # dict(k=96, gamma=15, min_norm=.5),
164
- init={}, device='cuda', dtype=torch.float32,
165
- matching_conf_thr=4., loss_dust3r_w=0.01,
166
- verbose=True, dbg=()):
167
-
168
- # extrinsic parameters
169
- vec0001 = torch.tensor((0, 0, 0, 1), dtype=dtype, device=device)
170
- quats = [nn.Parameter(vec0001.clone()) for _ in range(len(imgs))]
171
- trans = [nn.Parameter(torch.zeros(3, device=device, dtype=dtype)) for _ in range(len(imgs))]
172
-
173
- # initialize
174
- ones = torch.ones((len(imgs), 1), device=device, dtype=dtype)
175
- median_depths = torch.ones(len(imgs), device=device, dtype=dtype)
176
- for img in imgs:
177
- idx = imgs.index(img)
178
- init_values = init.setdefault(img, {})
179
- if verbose and init_values:
180
- print(f' >> initializing img=...{img[-25:]} [{idx}] for {set(init_values)}')
181
-
182
- K = init_values.get('intrinsics')
183
- if K is not None:
184
- K = K.detach()
185
- focal = K[:2, :2].diag().mean()
186
- pp = K[:2, 2]
187
- base_focals[idx] = focal
188
- pps[idx] = pp
189
- pps[idx] /= imsizes[idx] # default principal_point would be (0.5, 0.5)
190
-
191
- depth = init_values.get('depthmap')
192
- if depth is not None:
193
- core_depth[idx] = depth.detach()
194
-
195
- median_depths[idx] = med_depth = core_depth[idx].median()
196
- core_depth[idx] /= med_depth
197
-
198
- cam2w = init_values.get('cam2w')
199
- if cam2w is not None:
200
- rot = cam2w[:3, :3].detach()
201
- cam_center = cam2w[:3, 3].detach()
202
- quats[idx].data[:] = roma.rotmat_to_unitquat(rot)
203
- trans_offset = med_depth * torch.cat((imsizes[idx] / base_focals[idx] * (0.5 - pps[idx]), ones[:1, 0]))
204
- trans[idx].data[:] = cam_center + rot @ trans_offset
205
- del rot
206
- assert False, 'inverse kinematic chain not yet implemented'
207
-
208
- # intrinsics parameters
209
- pps = [nn.Parameter(pp.to(dtype)) for pp in pps]
210
- diags = imsizes.float().norm(dim=1)
211
- min_focals = 0.25 * diags # diag = 1.2~1.4*max(W,H) => beta >= 1/(2*1.2*tan(fov/2)) ~= 0.26
212
- max_focals = 10 * diags
213
- log_focals = [nn.Parameter(f.view(1).log().to(dtype)) for f in base_focals]
214
- assert len(mst[1]) == len(pps) - 1
215
-
216
- def make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth):
217
- # make intrinsics
218
- focals = torch.cat(log_focals).exp().clip(min=min_focals, max=max_focals)
219
- pps = torch.stack(pps)
220
- K = torch.eye(3, dtype=dtype, device=device)[None].expand(len(imgs), 3, 3).clone()
221
- K[:, 0, 0] = K[:, 1, 1] = focals
222
- K[:, 0:2, 2] = pps * imsizes
223
- if trans is None:
224
- return K
225
-
226
- # security! optimization is always trying to crush the scale down
227
- sizes = torch.cat(log_sizes).exp()
228
- global_scaling = 1 / sizes.min()
229
-
230
- # compute distance of camera to focal plane
231
- # tan(fov) = W/2 / focal
232
- z_cameras = sizes * median_depths * focals / base_focals
233
-
234
- # make extrinsic
235
- rel_cam2cam = torch.eye(4, dtype=dtype, device=device)[None].expand(len(imgs), 4, 4).clone()
236
- rel_cam2cam[:, :3, :3] = roma.unitquat_to_rotmat(F.normalize(torch.stack(quats), dim=1))
237
- rel_cam2cam[:, :3, 3] = torch.stack(trans)
238
-
239
- # cameras are defined as a kinematic chain
240
- tmp_cam2w = [None] * len(K)
241
- tmp_cam2w[mst[0]] = rel_cam2cam[mst[0]]
242
- for i, j in mst[1]:
243
- # i is the cam_i_to_world reference, j is the relative pose = cam_j_to_cam_i
244
- tmp_cam2w[j] = tmp_cam2w[i] @ rel_cam2cam[j]
245
- tmp_cam2w = torch.stack(tmp_cam2w)
246
-
247
- # smart reparameterization of cameras
248
- trans_offset = z_cameras.unsqueeze(1) * torch.cat((imsizes / focals.unsqueeze(1) * (0.5 - pps), ones), dim=-1)
249
- new_trans = global_scaling * (tmp_cam2w[:, :3, 3:4] - tmp_cam2w[:, :3, :3] @ trans_offset.unsqueeze(-1))
250
- cam2w = torch.cat((torch.cat((tmp_cam2w[:, :3, :3], new_trans), dim=2),
251
- vec0001.view(1, 1, 4).expand(len(K), 1, 4)), dim=1)
252
-
253
- depthmaps = []
254
- for i in range(len(imgs)):
255
- core_depth_img = core_depth[i]
256
- if exp_depth:
257
- core_depth_img = core_depth_img.exp()
258
- if lora_depth: # compute core_depth as a low-rank decomposition of 3d points
259
- core_depth_img = lora_depth_proj[i] @ core_depth_img
260
- if depth_mode == 'add':
261
- core_depth_img = z_cameras[i] + (core_depth_img - 1) * (median_depths[i] * sizes[i])
262
- elif depth_mode == 'mul':
263
- core_depth_img = z_cameras[i] * core_depth_img
264
- else:
265
- raise ValueError(f'Bad {depth_mode=}')
266
- depthmaps.append(global_scaling * core_depth_img)
267
-
268
- return K, (inv(cam2w), cam2w), depthmaps
269
-
270
- K = make_K_cam_depth(log_focals, pps, None, None, None, None)
271
- print('init focals =', to_numpy(K[:, 0, 0]))
272
-
273
- # spectral low-rank projection of depthmaps
274
- if lora_depth:
275
- core_depth, lora_depth_proj = spectral_projection_of_depthmaps(
276
- imgs, K, core_depth, subsample, cache_path=cache_path, **lora_depth)
277
- if exp_depth:
278
- core_depth = [d.clip(min=1e-4).log() for d in core_depth]
279
- core_depth = [nn.Parameter(d.ravel().to(dtype)) for d in core_depth]
280
- log_sizes = [nn.Parameter(torch.zeros(1, dtype=dtype, device=device)) for _ in range(len(imgs))]
281
-
282
- # Fetch img slices
283
- _, confs_sum, imgs_slices = corres
284
-
285
- # Define which pairs are fine to use with matching
286
- def matching_check(x): return x.max() > matching_conf_thr
287
- is_matching_ok = {}
288
- for s in imgs_slices:
289
- is_matching_ok[s.img1, s.img2] = matching_check(s.confs)
290
-
291
- # Subsample preds_21
292
- subsamp_preds_21 = {}
293
- for imk, imv in preds_21.items():
294
- subsamp_preds_21[imk] = {}
295
- for im2k, (pred, conf) in preds_21[imk].items():
296
- subpred = pred[::subsample, ::subsample].reshape(-1, 3) # original subsample
297
- subconf = conf[::subsample, ::subsample].ravel() # for both ptmaps and confs
298
- idxs = anchors[imgs.index(im2k)][1]
299
- subsamp_preds_21[imk][im2k] = (subpred[idxs], subconf[idxs]) # anchors subsample
300
-
301
- def loss_dust3r(cam2w, pts3d, pix_loss):
302
- # In case no correspondences could be established, fall back to the DUSt3R GA regression loss formulation (sparsified)
303
- loss = 0.
304
- cf_sum = 0.
305
- for s in imgs_slices:
306
- if not is_matching_ok[s.img1, s.img2]:
307
- # fallback to dust3r regression
308
- tgt_pts, tgt_confs = subsamp_preds_21[imgs[s.img2]][imgs[s.img1]]
309
- tgt_pts = geotrf(cam2w[s.img2], tgt_pts)
310
- cf_sum += tgt_confs.sum()
311
- loss += tgt_confs @ pix_loss(pts3d[s.img1], tgt_pts)
312
- return loss / cf_sum if cf_sum != 0. else 0.
313
-
314
- def loss_3d(K, w2cam, pts3d, pix_loss):
315
- # For each correspondence, we have two 3D points (one for each image of the pair).
316
- # For each 3D point, we have 2 reproj errors
317
- if any(v.get('freeze') for v in init.values()):
318
- pts3d_1 = []
319
- pts3d_2 = []
320
- confs = []
321
- for s in imgs_slices:
322
- if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'):
323
- continue
324
- if is_matching_ok[s.img1, s.img2]:
325
- pts3d_1.append(pts3d[s.img1][s.slice1])
326
- pts3d_2.append(pts3d[s.img2][s.slice2])
327
- confs.append(s.confs)
328
- else:
329
- pts3d_1 = [pts3d[s.img1][s.slice1] for s in imgs_slices if is_matching_ok[s.img1, s.img2]]
330
- pts3d_2 = [pts3d[s.img2][s.slice2] for s in imgs_slices if is_matching_ok[s.img1, s.img2]]
331
- confs = [s.confs for s in imgs_slices if is_matching_ok[s.img1, s.img2]]
332
-
333
- if pts3d_1 != []:
334
- confs = torch.cat(confs)
335
- pts3d_1 = torch.cat(pts3d_1)
336
- pts3d_2 = torch.cat(pts3d_2)
337
- loss = confs @ pix_loss(pts3d_1, pts3d_2)
338
- cf_sum = confs.sum()
339
- else:
340
- loss = 0.
341
- cf_sum = 1.
342
-
343
- return loss / cf_sum
344
-
345
- def loss_2d(K, w2cam, pts3d, pix_loss):
346
- # For each correspondence, we have two 3D points (one for each image of the pair).
347
- # For each 3D point, we have 2 reproj errors
348
- proj_matrix = K @ w2cam[:, :3]
349
- loss = npix = 0
350
- for img1, pix1, confs, cf_sum, imgs_slices in corres2d:
351
- if init[imgs[img1]].get('freeze', 0) >= 1:
352
- continue # no need
353
- pts3d_in_img1 = [pts3d[img2][slice2] for img2, slice2 in imgs_slices if is_matching_ok[img1, img2]]
354
- pix1_filtered = []
355
- confs_filtered = []
356
- curstep = 0
357
- for img2, slice2 in imgs_slices:
358
- if is_matching_ok[img1, img2]:
359
- tslice = slice(curstep, curstep + slice2.stop - slice2.start, slice2.step)
360
- pix1_filtered.append(pix1[tslice])
361
- confs_filtered.append(confs[tslice])
362
- curstep += slice2.stop - slice2.start
363
- if pts3d_in_img1 != []:
364
- pts3d_in_img1 = torch.cat(pts3d_in_img1)
365
- pix1_filtered = torch.cat(pix1_filtered)
366
- confs_filtered = torch.cat(confs_filtered)
367
- loss += confs_filtered @ pix_loss(pix1_filtered, reproj2d(proj_matrix[img1], pts3d_in_img1))
368
- npix += confs_filtered.sum()
369
- return loss / npix if npix != 0 else 0.
370
-
371
- def optimize_loop(loss_func, lr_base, niter, pix_loss, lr_end=0):
372
- # create optimizer
373
- params = pps + log_focals + quats + trans + log_sizes + core_depth
374
- optimizer = torch.optim.Adam(params, lr=1, weight_decay=0, betas=(0.9, 0.9))
375
- ploss = pix_loss if 'meta' in repr(pix_loss) else (lambda a: pix_loss)
376
-
377
- with tqdm(total=niter) as bar:
378
- for iter in range(niter or 1):
379
- K, (w2cam, cam2w), depthmaps = make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth)
380
- pts3d = make_pts3d(anchors, K, cam2w, depthmaps, base_focals=base_focals)
381
- if niter == 0:
382
- break
383
-
384
- alpha = (iter / niter)
385
- lr = schedule(alpha, lr_base, lr_end)
386
- adjust_learning_rate_by_lr(optimizer, lr)
387
- pix_loss = ploss(1 - alpha)
388
- optimizer.zero_grad()
389
- loss = loss_func(K, w2cam, pts3d, pix_loss) + loss_dust3r_w * loss_dust3r(cam2w, pts3d, lossd)
390
- loss.backward()
391
- optimizer.step()
392
-
393
- # make sure the pose remains well optimizable
394
- for i in range(len(imgs)):
395
- quats[i].data[:] /= quats[i].data.norm()
396
-
397
- loss = float(loss)
398
- if loss != loss:
399
- break # NaN loss
400
- bar.set_postfix_str(f'{lr=:.4f}, {loss=:.3f}')
401
- bar.update(1)
402
-
403
- if niter:
404
- print(f'>> final loss = {loss}')
405
- return dict(intrinsics=K.detach(), cam2w=cam2w.detach(),
406
- depthmaps=[d.detach() for d in depthmaps], pts3d=[p.detach() for p in pts3d])
407
-
408
- # at start, don't optimize 3d points
409
- for i, img in enumerate(imgs):
410
- trainable = not (init[img].get('freeze'))
411
- pps[i].requires_grad_(False)
412
- log_focals[i].requires_grad_(False)
413
- quats[i].requires_grad_(trainable)
414
- trans[i].requires_grad_(trainable)
415
- log_sizes[i].requires_grad_(trainable)
416
- core_depth[i].requires_grad_(False)
417
-
418
- res_coarse = optimize_loop(loss_3d, lr_base=lr1, niter=niter1, pix_loss=loss1)
419
-
420
- res_fine = None
421
- if niter2:
422
- # now we can optimize 3d points
423
- for i, img in enumerate(imgs):
424
- if init[img].get('freeze', 0) >= 1:
425
- continue
426
- pps[i].requires_grad_(bool(opt_pp))
427
- log_focals[i].requires_grad_(True)
428
- core_depth[i].requires_grad_(opt_depth)
429
-
430
- # refinement with 2d reproj
431
- res_fine = optimize_loop(loss_2d, lr_base=lr2, niter=niter2, pix_loss=loss2)
432
-
433
- return imgs, res_coarse, res_fine
434
-
435
-
436
- @lru_cache
437
- def mask110(device, dtype):
438
- return torch.tensor((1, 1, 0), device=device, dtype=dtype)
439
-
440
-
441
- def proj3d(inv_K, pixels, z):
442
- if pixels.shape[-1] == 2:
443
- pixels = torch.cat((pixels, torch.ones_like(pixels[..., :1])), dim=-1)
444
- return z.unsqueeze(-1) * (pixels * inv_K.diag() + inv_K[:, 2] * mask110(z.device, z.dtype))
445
-
446
-
447
- def make_pts3d(anchors, K, cam2w, depthmaps, base_focals=None, ret_depth=False):
448
- focals = K[:, 0, 0]
449
- invK = inv(K)
450
- all_pts3d = []
451
- depth_out = []
452
-
453
- for img, (pixels, idxs, offsets) in anchors.items():
454
- # from depthmaps to 3d points
455
- if base_focals is None:
456
- pass
457
- else:
458
- # compensate for focal
459
- # depth + depth * (offset - 1) * base_focal / focal
460
- # = depth * (1 + (offset - 1) * (base_focal / focal))
461
- offsets = 1 + (offsets - 1) * (base_focals[img] / focals[img])
462
-
463
- pts3d = proj3d(invK[img], pixels, depthmaps[img][idxs] * offsets)
464
- if ret_depth:
465
- depth_out.append(pts3d[..., 2]) # before camera rotation
466
-
467
- # rotate to world coordinate
468
- pts3d = geotrf(cam2w[img], pts3d)
469
- all_pts3d.append(pts3d)
470
-
471
- if ret_depth:
472
- return all_pts3d, depth_out
473
- return all_pts3d
474
-
475
-
476
- def make_dense_pts3d(intrinsics, cam2w, depthmaps, canonical_paths, subsample, device='cuda'):
477
- base_focals = []
478
- anchors = {}
479
- confs = []
480
- for i, canon_path in enumerate(canonical_paths):
481
- (canon, canon2, conf), focal = torch.load(canon_path, map_location=device)
482
- confs.append(conf)
483
- base_focals.append(focal)
484
- H, W = conf.shape
485
- pixels = torch.from_numpy(np.mgrid[:W, :H].T.reshape(-1, 2)).float().to(device)
486
- idxs, offsets = anchor_depth_offsets(canon2, {i: (pixels, None)}, subsample=subsample)
487
- anchors[i] = (pixels, idxs[i], offsets[i])
488
-
489
- # densify sparse depthmaps
490
- pts3d, depthmaps_out = make_pts3d(anchors, intrinsics, cam2w, [
491
- d.ravel() for d in depthmaps], base_focals=base_focals, ret_depth=True)
492
-
493
- return pts3d, depthmaps_out, confs
494
-
495
-
496
- @torch.no_grad()
497
- def forward_mast3r(pairs, model, cache_path, desc_conf='desc_conf',
498
- device='cuda', subsample=8, **matching_kw):
499
- res_paths = {}
500
-
501
- for img1, img2 in tqdm(pairs):
502
- idx1 = hash_md5(img1['instance'])
503
- idx2 = hash_md5(img2['instance'])
504
-
505
- path1 = cache_path + f'/forward/{idx1}/{idx2}.pth'
506
- path2 = cache_path + f'/forward/{idx2}/{idx1}.pth'
507
- path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{idx1}-{idx2}.pth'
508
- path_corres2 = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{idx2}-{idx1}.pth'
509
-
510
- if os.path.isfile(path_corres2) and not os.path.isfile(path_corres):
511
- score, (xy1, xy2, confs) = torch.load(path_corres2)
512
- torch.save((score, (xy2, xy1, confs)), path_corres)
513
-
514
- if not all(os.path.isfile(p) for p in (path1, path2, path_corres)):
515
- if model is None:
516
- continue
517
- res = symmetric_inference(model, img1, img2, device=device)
518
- X11, X21, X22, X12 = [r['pts3d'][0] for r in res]
519
- C11, C21, C22, C12 = [r['conf'][0] for r in res]
520
- descs = [r['desc'][0] for r in res]
521
- qonfs = [r[desc_conf][0] for r in res]
522
-
523
- # save
524
- torch.save(to_cpu((X11, C11, X21, C21)), mkdir_for(path1))
525
- torch.save(to_cpu((X22, C22, X12, C12)), mkdir_for(path2))
526
-
527
- # perform reciprocal matching
528
- corres = extract_correspondences(descs, qonfs, device=device, subsample=subsample)
529
-
530
- conf_score = (C11.mean() * C12.mean() * C21.mean() * C22.mean()).sqrt().sqrt()
531
- matching_score = (float(conf_score), float(corres[2].sum()), len(corres[2]))
532
- if cache_path is not None:
533
- torch.save((matching_score, corres), mkdir_for(path_corres))
534
-
535
- res_paths[img1['instance'], img2['instance']] = (path1, path2), path_corres
536
-
537
- del model
538
- torch.cuda.empty_cache()
539
-
540
- return res_paths, cache_path
541
-
542
-
543
- def symmetric_inference(model, img1, img2, device):
544
- shape1 = torch.from_numpy(img1['true_shape']).to(device, non_blocking=True)
545
- shape2 = torch.from_numpy(img2['true_shape']).to(device, non_blocking=True)
546
- img1 = img1['img'].to(device, non_blocking=True)
547
- img2 = img2['img'].to(device, non_blocking=True)
548
-
549
- # compute encoder only once
550
- feat1, feat2, pos1, pos2 = model._encode_image_pairs(img1, img2, shape1, shape2)
551
-
552
- def decoder(feat1, feat2, pos1, pos2, shape1, shape2):
553
- dec1, dec2 = model._decoder(feat1, pos1, feat2, pos2)
554
- with torch.cuda.amp.autocast(enabled=False):
555
- res1 = model._downstream_head(1, [tok.float() for tok in dec1], shape1)
556
- res2 = model._downstream_head(2, [tok.float() for tok in dec2], shape2)
557
- return res1, res2
558
-
559
- # decoder 1-2
560
- res11, res21 = decoder(feat1, feat2, pos1, pos2, shape1, shape2)
561
- # decoder 2-1
562
- res22, res12 = decoder(feat2, feat1, pos2, pos1, shape2, shape1)
563
-
564
- return (res11, res21, res22, res12)
565
-
566
-
567
- def extract_correspondences(feats, qonfs, subsample=8, device=None, ptmap_key='pred_desc'):
568
- feat11, feat21, feat22, feat12 = feats
569
- qonf11, qonf21, qonf22, qonf12 = qonfs
570
- assert feat11.shape[:2] == feat12.shape[:2] == qonf11.shape == qonf12.shape
571
- assert feat21.shape[:2] == feat22.shape[:2] == qonf21.shape == qonf22.shape
572
-
573
- if '3d' in ptmap_key:
574
- opt = dict(device='cpu', workers=32)
575
- else:
576
- opt = dict(device=device, dist='dot', block_size=2**13)
577
-
578
- # matching the two pairs
579
- idx1 = []
580
- idx2 = []
581
- qonf1 = []
582
- qonf2 = []
583
- # TODO add non symmetric / pixel_tol options
584
- for A, B, QA, QB in [(feat11, feat21, qonf11.cpu(), qonf21.cpu()),
585
- (feat12, feat22, qonf12.cpu(), qonf22.cpu())]:
586
- nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=subsample, ret_xy=False, **opt)
587
- nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=subsample, ret_xy=False, **opt)
588
-
589
- idx1.append(np.r_[nn1to2[0], nn2to1[1]])
590
- idx2.append(np.r_[nn1to2[1], nn2to1[0]])
591
- qonf1.append(QA.ravel()[idx1[-1]])
592
- qonf2.append(QB.ravel()[idx2[-1]])
593
-
594
- # merge corres from opposite pairs
595
- H1, W1 = feat11.shape[:2]
596
- H2, W2 = feat22.shape[:2]
597
- cat = np.concatenate
598
-
599
- xy1, xy2, idx = merge_corres(cat(idx1), cat(idx2), (H1, W1), (H2, W2), ret_xy=True, ret_index=True)
600
- corres = (xy1.copy(), xy2.copy(), np.sqrt(cat(qonf1)[idx] * cat(qonf2)[idx]))
601
-
602
- return todevice(corres, device)
603
-
604
-
605
- @torch.no_grad()
606
- def prepare_canonical_data(imgs, tmp_pairs, subsample, order_imgs=False, min_conf_thr=0,
607
- cache_path=None, device='cuda', **kw):
608
- canonical_views = {}
609
- pairwise_scores = torch.zeros((len(imgs), len(imgs)), device=device)
610
- canonical_paths = []
611
- preds_21 = {}
612
-
613
- for img in tqdm(imgs):
614
- if cache_path:
615
- cache = os.path.join(cache_path, 'canon_views', hash_md5(img) + f'_{subsample=}_{kw=}.pth')
616
- canonical_paths.append(cache)
617
- try:
618
- (canon, canon2, cconf), focal = torch.load(cache, map_location=device)
619
- except IOError:
620
- # cache does not exist yet, we create it!
621
- canon = focal = None
622
-
623
- # collect all pred1
624
- n_pairs = sum((img in pair) for pair in tmp_pairs)
625
-
626
- ptmaps11 = None
627
- pixels = {}
628
- n = 0
629
- for (img1, img2), ((path1, path2), path_corres) in tmp_pairs.items():
630
- score = None
631
- if img == img1:
632
- X, C, X2, C2 = torch.load(path1, map_location=device)
633
- score, (xy1, xy2, confs) = load_corres(path_corres, device, min_conf_thr)
634
- pixels[img2] = xy1, confs
635
- if img not in preds_21:
636
- preds_21[img] = {}
637
- preds_21[img][img2] = X2, C2
638
-
639
- if img == img2:
640
- X, C, X2, C2 = torch.load(path2, map_location=device)
641
- score, (xy1, xy2, confs) = load_corres(path_corres, device, min_conf_thr)
642
- pixels[img1] = xy2, confs
643
- if img not in preds_21:
644
- preds_21[img] = {}
645
- preds_21[img][img1] = X2, C2
646
-
647
- if score is not None:
648
- i, j = imgs.index(img1), imgs.index(img2)
649
- # score = score[0]
650
- # score = np.log1p(score[2])
651
- score = score[2]
652
- pairwise_scores[i, j] = score
653
- pairwise_scores[j, i] = score
654
-
655
- if canon is not None:
656
- continue
657
- if ptmaps11 is None:
658
- H, W = C.shape
659
- ptmaps11 = torch.empty((n_pairs, H, W, 3), device=device)
660
- confs11 = torch.empty((n_pairs, H, W), device=device)
661
-
662
- ptmaps11[n] = X
663
- confs11[n] = C
664
- n += 1
665
-
666
- if canon is None:
667
- canon, canon2, cconf = canonical_view(ptmaps11, confs11, subsample, **kw)
668
- del ptmaps11
669
- del confs11
670
-
671
- # compute focals
672
- H, W = canon.shape[:2]
673
- pp = torch.tensor([W / 2, H / 2], device=device)
674
- if focal is None:
675
- focal = estimate_focal_knowing_depth(canon[None], pp, focal_mode='weiszfeld', min_focal=0.5, max_focal=3.5)
676
- if cache:
677
- torch.save(to_cpu(((canon, canon2, cconf), focal)), mkdir_for(cache))
678
-
679
- # extract depth offsets with correspondences
680
- core_depth = canon[subsample // 2::subsample, subsample // 2::subsample, 2]
681
- idxs, offsets = anchor_depth_offsets(canon2, pixels, subsample=subsample)
682
-
683
- canonical_views[img] = (pp, (H, W), focal.view(1), core_depth, pixels, idxs, offsets)
684
-
685
- return tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21
686
-
687
-
688
- def load_corres(path_corres, device, min_conf_thr):
689
- score, (xy1, xy2, confs) = torch.load(path_corres, map_location=device)
690
- valid = confs > min_conf_thr if min_conf_thr else slice(None)
691
- # valid = (xy1 > 0).all(dim=1) & (xy2 > 0).all(dim=1) & (xy1 < 512).all(dim=1) & (xy2 < 512).all(dim=1)
692
- # print(f'keeping {valid.sum()} / {len(valid)} correspondences')
693
- return score, (xy1[valid], xy2[valid], confs[valid])
694
-
695
-
696
- PairOfSlices = namedtuple(
697
- 'ImgPair', 'img1, slice1, pix1, anchor_idxs1, img2, slice2, pix2, anchor_idxs2, confs, confs_sum')
698
-
699
-
700
- def condense_data(imgs, tmp_paths, canonical_views, dtype=torch.float32):
701
- # aggregate all data properly
702
- set_imgs = set(imgs)
703
-
704
- principal_points = []
705
- shapes = []
706
- focals = []
707
- core_depth = []
708
- img_anchors = {}
709
- tmp_pixels = {}
710
-
711
- for idx1, img1 in enumerate(imgs):
712
- # load stuff
713
- pp, shape, focal, anchors, pixels_confs, idxs, offsets = canonical_views[img1]
714
-
715
- principal_points.append(pp)
716
- shapes.append(shape)
717
- focals.append(focal)
718
- core_depth.append(anchors)
719
-
720
- img_uv1 = []
721
- img_idxs = []
722
- img_offs = []
723
- cur_n = [0]
724
-
725
- for img2, (pixels, match_confs) in pixels_confs.items():
726
- if img2 not in set_imgs:
727
- continue
728
- assert len(pixels) == len(idxs[img2]) == len(offsets[img2])
729
- img_uv1.append(torch.cat((pixels, torch.ones_like(pixels[:, :1])), dim=-1))
730
- img_idxs.append(idxs[img2])
731
- img_offs.append(offsets[img2])
732
- cur_n.append(cur_n[-1] + len(pixels))
733
- # store the position of 3d points
734
- tmp_pixels[img1, img2] = pixels.to(dtype), match_confs.to(dtype), slice(*cur_n[-2:])
735
- img_anchors[idx1] = (torch.cat(img_uv1), torch.cat(img_idxs), torch.cat(img_offs))
736
-
737
- all_confs = []
738
- imgs_slices = []
739
- corres2d = {img: [] for img in range(len(imgs))}
740
-
741
- for img1, img2 in tmp_paths:
742
- try:
743
- pix1, confs1, slice1 = tmp_pixels[img1, img2]
744
- pix2, confs2, slice2 = tmp_pixels[img2, img1]
745
- except KeyError:
746
- continue
747
- img1 = imgs.index(img1)
748
- img2 = imgs.index(img2)
749
- confs = (confs1 * confs2).sqrt()
750
-
751
- # prepare for loss_3d
752
- all_confs.append(confs)
753
- anchor_idxs1 = canonical_views[imgs[img1]][5][imgs[img2]]
754
- anchor_idxs2 = canonical_views[imgs[img2]][5][imgs[img1]]
755
- imgs_slices.append(PairOfSlices(img1, slice1, pix1, anchor_idxs1,
756
- img2, slice2, pix2, anchor_idxs2,
757
- confs, float(confs.sum())))
758
-
759
- # prepare for loss_2d
760
- corres2d[img1].append((pix1, confs, img2, slice2))
761
- corres2d[img2].append((pix2, confs, img1, slice1))
762
-
763
- all_confs = torch.cat(all_confs)
764
- corres = (all_confs, float(all_confs.sum()), imgs_slices)
765
-
766
- def aggreg_matches(img1, list_matches):
767
- pix1, confs, img2, slice2 = zip(*list_matches)
768
- all_pix1 = torch.cat(pix1).to(dtype)
769
- all_confs = torch.cat(confs).to(dtype)
770
- return img1, all_pix1, all_confs, float(all_confs.sum()), [(j, sl2) for j, sl2 in zip(img2, slice2)]
771
- corres2d = [aggreg_matches(img, m) for img, m in corres2d.items()]
772
-
773
- imsizes = torch.tensor([(W, H) for H, W in shapes], device=pp.device) # (W,H)
774
- principal_points = torch.stack(principal_points)
775
- focals = torch.cat(focals)
776
- return imsizes, principal_points, focals, core_depth, img_anchors, corres, corres2d
777
-
778
-
779
- def canonical_view(ptmaps11, confs11, subsample, mode='avg-angle'):
780
- assert len(ptmaps11) == len(confs11) > 0, 'not a single view1 for img={i}'
781
-
782
- # canonical pointmap is just a weighted average
783
- confs11 = confs11.unsqueeze(-1) - 0.999
784
- canon = (confs11 * ptmaps11).sum(0) / confs11.sum(0)
785
-
786
- canon_depth = ptmaps11[..., 2].unsqueeze(1)
787
- S = slice(subsample // 2, None, subsample)
788
- center_depth = canon_depth[:, :, S, S]
789
- assert (center_depth > 0).all()
790
- stacked_depth = F.pixel_unshuffle(canon_depth, subsample)
791
- stacked_confs = F.pixel_unshuffle(confs11[:, None, :, :, 0], subsample)
792
-
793
- if mode == 'avg-reldepth':
794
- rel_depth = stacked_depth / center_depth
795
- stacked_canon = (stacked_confs * rel_depth).sum(dim=0) / stacked_confs.sum(dim=0)
796
- canon2 = F.pixel_shuffle(stacked_canon.unsqueeze(0), subsample).squeeze()
797
-
798
- elif mode == 'avg-angle':
799
- xy = ptmaps11[..., 0:2].permute(0, 3, 1, 2)
800
- stacked_xy = F.pixel_unshuffle(xy, subsample)
801
- B, _, H, W = stacked_xy.shape
802
- stacked_radius = (stacked_xy.view(B, 2, -1, H, W) - xy[:, :, None, S, S]).norm(dim=1)
803
- stacked_radius.clip_(min=1e-8)
804
-
805
- stacked_angle = torch.arctan((stacked_depth - center_depth) / stacked_radius)
806
- avg_angle = (stacked_confs * stacked_angle).sum(dim=0) / stacked_confs.sum(dim=0)
807
-
808
- # back to depth
809
- stacked_depth = stacked_radius.mean(dim=0) * torch.tan(avg_angle)
810
-
811
- canon2 = F.pixel_shuffle((1 + stacked_depth / canon[S, S, 2]).unsqueeze(0), subsample).squeeze()
812
- else:
813
- raise ValueError(f'bad {mode=}')
814
-
815
- confs = (confs11.square().sum(dim=0) / confs11.sum(dim=0)).squeeze()
816
- return canon, canon2, confs
817
-
818
-
819
- def anchor_depth_offsets(canon_depth, pixels, subsample=8):
820
- device = canon_depth.device
821
-
822
- # create a 2D grid of anchor 3D points
823
- H1, W1 = canon_depth.shape
824
- yx = np.mgrid[subsample // 2:H1:subsample, subsample // 2:W1:subsample]
825
- H2, W2 = yx.shape[1:]
826
- cy, cx = yx.reshape(2, -1)
827
- core_depth = canon_depth[cy, cx]
828
- assert (core_depth > 0).all()
829
-
830
- # slave 3d points (attached to core 3d points)
831
- core_idxs = {} # core_idxs[img2] = {corr_idx:core_idx}
832
- core_offs = {} # core_offs[img2] = {corr_idx:3d_offset}
833
-
834
- for img2, (xy1, _confs) in pixels.items():
835
- px, py = xy1.long().T
836
-
837
- # find nearest anchor == block quantization
838
- core_idx = (py // subsample) * W2 + (px // subsample)
839
- core_idxs[img2] = core_idx.to(device)
840
-
841
- # compute relative depth offsets w.r.t. anchors
842
- ref_z = core_depth[core_idx]
843
- pts_z = canon_depth[py, px]
844
- offset = pts_z / ref_z
845
- core_offs[img2] = offset.detach().to(device)
846
-
847
- return core_idxs, core_offs
848
-
849
-
850
- def spectral_clustering(graph, k=None, normalized_cuts=False):
851
- graph.fill_diagonal_(0)
852
-
853
- # graph laplacian
854
- degrees = graph.sum(dim=-1)
855
- laplacian = torch.diag(degrees) - graph
856
- if normalized_cuts:
857
- i_inv = torch.diag(degrees.sqrt().reciprocal())
858
- laplacian = i_inv @ laplacian @ i_inv
859
-
860
- # compute eigenvectors!
861
- eigval, eigvec = torch.linalg.eigh(laplacian)
862
- return eigval[:k], eigvec[:, :k]
863
-
864
-
865
- def sim_func(p1, p2, gamma):
866
- diff = (p1 - p2).norm(dim=-1)
867
- avg_depth = (p1[:, :, 2] + p2[:, :, 2])
868
- rel_distance = diff / avg_depth
869
- sim = torch.exp(-gamma * rel_distance.square())
870
- return sim
871
-
872
-
873
- def backproj(K, depthmap, subsample):
874
- H, W = depthmap.shape
875
- uv = np.mgrid[subsample // 2:subsample * W:subsample, subsample // 2:subsample * H:subsample].T.reshape(H, W, 2)
876
- xyz = depthmap.unsqueeze(-1) * geotrf(inv(K), todevice(uv, K.device), ncol=3)
877
- return xyz
878
-
879
-
880
- def spectral_projection_depth(K, depthmap, subsample, k=64, cache_path='',
881
- normalized_cuts=True, gamma=7, min_norm=5):
882
- try:
883
- if cache_path:
884
- cache_path = cache_path + f'_{k=}_norm={normalized_cuts}_{gamma=}.pth'
885
- lora_proj = torch.load(cache_path, map_location=K.device)
886
-
887
- except IOError:
888
- # reconstruct 3d points in camera coordinates
889
- xyz = backproj(K, depthmap, subsample)
890
-
891
- # compute all distances
892
- xyz = xyz.reshape(-1, 3)
893
- graph = sim_func(xyz[:, None], xyz[None, :], gamma=gamma)
894
- _, lora_proj = spectral_clustering(graph, k, normalized_cuts=normalized_cuts)
895
-
896
- if cache_path:
897
- torch.save(lora_proj.cpu(), mkdir_for(cache_path))
898
-
899
- lora_proj, coeffs = lora_encode_normed(lora_proj, depthmap.ravel(), min_norm=min_norm)
900
-
901
- # depthmap ~= lora_proj @ coeffs
902
- return coeffs, lora_proj
903
-
904
-
905
- def lora_encode_normed(lora_proj, x, min_norm, global_norm=False):
906
- # encode the pointmap
907
- coeffs = torch.linalg.pinv(lora_proj) @ x
908
-
909
- # rectify the norm of basis vector to be ~ equal
910
- if coeffs.ndim == 1:
911
- coeffs = coeffs[:, None]
912
- if global_norm:
913
- lora_proj *= coeffs[1:].norm() * min_norm / coeffs.shape[1]
914
- elif min_norm:
915
- lora_proj *= coeffs.norm(dim=1).clip(min=min_norm)
916
- # can have rounding errors here!
917
- coeffs = (torch.linalg.pinv(lora_proj.double()) @ x.double()).float()
918
-
919
- return lora_proj.detach(), coeffs.detach()
920
-
921
-
922
- @torch.no_grad()
923
- def spectral_projection_of_depthmaps(imgs, intrinsics, depthmaps, subsample, cache_path=None, **kw):
924
- # recover 3d points
925
- core_depth = []
926
- lora_proj = []
927
-
928
- for i, img in enumerate(tqdm(imgs)):
929
- cache = os.path.join(cache_path, 'lora_depth', hash_md5(img)) if cache_path else None
930
- depth, proj = spectral_projection_depth(intrinsics[i], depthmaps[i], subsample,
931
- cache_path=cache, **kw)
932
- core_depth.append(depth)
933
- lora_proj.append(proj)
934
-
935
- return core_depth, lora_proj
936
-
937
-
938
- def reproj2d(Trf, pts3d):
939
- res = (pts3d @ Trf[:3, :3].transpose(-1, -2)) + Trf[:3, 3]
940
- clipped_z = res[:, 2:3].clip(min=1e-3) # make sure we don't have nans!
941
- uv = res[:, 0:2] / clipped_z
942
- return uv.clip(min=-1000, max=2000)
943
-
944
-
945
- def bfs(tree, start_node):
946
- order, predecessors = sp.csgraph.breadth_first_order(tree, start_node, directed=False)
947
- ranks = np.arange(len(order))
948
- ranks[order] = ranks.copy()
949
- return ranks, predecessors
950
-
951
-
952
- def compute_min_spanning_tree(pws):
953
- sparse_graph = sp.dok_array(pws.shape)
954
- for i, j in pws.nonzero().cpu().tolist():
955
- sparse_graph[i, j] = -float(pws[i, j])
956
- msp = sp.csgraph.minimum_spanning_tree(sparse_graph)
957
-
958
- # now reorder the oriented edges, starting from the central point
959
- ranks1, _ = bfs(msp, 0)
960
- ranks2, _ = bfs(msp, ranks1.argmax())
961
- ranks1, _ = bfs(msp, ranks2.argmax())
962
- # this is the point farthest from any leaf
963
- root = np.minimum(ranks1, ranks2).argmax()
964
-
965
- # find the ordered list of edges that describe the tree
966
- order, predecessors = sp.csgraph.breadth_first_order(msp, root, directed=False)
967
- order = order[1:]  # the root does not have a predecessor
968
- edges = [(predecessors[i], i) for i in order]
969
-
970
- return root, edges
971
-
972
-
973
- def show_reconstruction(shapes_or_imgs, K, cam2w, pts3d, gt_cam2w=None, gt_K=None, cam_size=None, masks=None, **kw):
974
- viz = SceneViz()
975
-
976
- cc = cam2w[:, :3, 3]
977
- cs = cam_size or float(torch.cdist(cc, cc).fill_diagonal_(np.inf).min(dim=0).values.median())
978
- colors = 64 + np.random.randint(255 - 64, size=(len(cam2w), 3))
979
-
980
- if isinstance(shapes_or_imgs, np.ndarray) and shapes_or_imgs.ndim == 2:
981
- cam_kws = dict(imsizes=shapes_or_imgs[:, ::-1], cam_size=cs)
982
- else:
983
- imgs = shapes_or_imgs
984
- cam_kws = dict(images=imgs, cam_size=cs)
985
- if K is not None:
986
- viz.add_cameras(to_numpy(cam2w), to_numpy(K), colors=colors, **cam_kws)
987
-
988
- if gt_cam2w is not None:
989
- if gt_K is None:
990
- gt_K = K
991
- viz.add_cameras(to_numpy(gt_cam2w), to_numpy(gt_K), colors=colors, marker='o', **cam_kws)
992
-
993
- if pts3d is not None:
994
- for i, p in enumerate(pts3d):
995
- if not len(p):
996
- continue
997
- if masks is None:
998
- viz.add_pointcloud(to_numpy(p), color=tuple(colors[i].tolist()))
999
- else:
1000
- viz.add_pointcloud(to_numpy(p), mask=masks[i], color=imgs[i])
1001
- viz.show(**kw)
 
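For reference, a toy reconstruction (random poses, not the library code) of the "kinematic chain" parameterization used by `make_K_cam_depth` above: each image stores a pose relative to its parent in the minimum spanning tree, and absolute cam2world matrices are recovered by composing the relative poses outwards from the root.

import torch

def random_pose():
    # hypothetical test data: a random proper rotation (via QR) plus a random translation
    q, _ = torch.linalg.qr(torch.randn(3, 3))
    T = torch.eye(4)
    T[:3, :3] = q * torch.sign(torch.det(q))   # flip the sign if needed so det = +1
    T[:3, 3] = torch.randn(3)
    return T

n_imgs = 4
root, edges = 0, [(0, 1), (1, 2), (0, 3)]      # (parent, child) edges in BFS order, as from compute_min_spanning_tree
rel_cam2cam = [random_pose() for _ in range(n_imgs)]   # rel_cam2cam[j]: cam_j -> cam_parent(j)

cam2w = [None] * n_imgs
cam2w[root] = rel_cam2cam[root]                # the root pose is used as-is
for i, j in edges:                             # parent i is always resolved before child j
    cam2w[j] = cam2w[i] @ rel_cam2cam[j]
cam2w = torch.stack(cam2w)
print(cam2w.shape)                             # torch.Size([4, 4, 4])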
 
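Likewise, a self-contained sketch (random data, illustrative names) of the anchor/offset encoding computed by `anchor_depth_offsets` above: every correspondence pixel is snapped to the nearest cell of a coarse anchor grid (block quantization), and its depth is stored as a ratio to the anchor depth, so optimizing only the sparse anchor depths still moves every correspondence.

import torch

subsample, H1, W1 = 8, 64, 96
canon_depth = torch.rand(H1, W1) + 0.5                  # strictly positive depths
W2 = W1 // subsample                                    # width of the anchor grid

# anchor depths sampled at the centre of each subsample x subsample block
core_depth = canon_depth[subsample // 2::subsample, subsample // 2::subsample].reshape(-1)

xy1 = torch.randint(0, min(H1, W1), (100, 2))           # some correspondence pixels (x, y)
px, py = xy1.long().T

core_idx = (py // subsample) * W2 + (px // subsample)   # nearest anchor == block quantization
offset = canon_depth[py, px] / core_depth[core_idx]     # relative depth w.r.t. the anchor

# depth is recovered exactly from anchor depth * offset
assert torch.allclose(core_depth[core_idx] * offset, canon_depth[py, px])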
 
mast3r/cloud_opt/triangulation.py DELETED
@@ -1,80 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # Matches Triangulation Utils
6
- # --------------------------------------------------------
7
-
8
- import numpy as np
9
- import torch
10
-
11
- # Batched Matches Triangulation
12
- def batched_triangulate(pts2d, # [B, Ncams, Npts, 2]
13
- proj_mats): # [B, Ncams, 3, 4] I@E projection matrix
14
- B, Ncams, Npts, two = pts2d.shape
15
- assert two==2
16
- assert proj_mats.shape == (B, Ncams, 3, 4)
17
- # P - xP
18
- x = proj_mats[...,0,:][...,None,:] - torch.einsum('bij,bik->bijk', pts2d[...,0], proj_mats[...,2,:]) # [B, Ncams, Npts, 4]
19
- y = proj_mats[...,1,:][...,None,:] - torch.einsum('bij,bik->bijk', pts2d[...,1], proj_mats[...,2,:]) # [B, Ncams, Npts, 4]
20
- eq = torch.cat([x, y], dim=1).transpose(1, 2) # [B, Npts, 2xNcams, 4]
21
- return torch.linalg.lstsq(eq[...,:3], -eq[...,3]).solution
22
-
23
- def matches_to_depths(intrinsics, # input camera intrinsics [B, Ncams, 3, 3]
24
- extrinsics, # input camera extrinsics [B, Ncams, 3, 4]
25
- matches, # input correspondences [B, Ncams, Npts, 2]
26
- batchsize=16, # bs for batched processing
27
- min_num_valids_ratio=.3 # at least this ratio of image pairs need to predict a match for a given pixel of img1
28
- ):
29
- B, Nv, H, W, five = matches.shape
30
- min_num_valids = np.floor(Nv*min_num_valids_ratio)
31
- out_aggregated_points, out_depths, out_confs = [], [], []
32
- for b in range(B//batchsize+1): # batched processing
33
- start, stop = b*batchsize,min(B,(b+1)*batchsize)
34
- sub_batch=slice(start,stop)
35
- sub_batchsize = stop-start
36
- if sub_batchsize==0:continue
37
- points1, points2, confs = matches[sub_batch, ..., :2], matches[sub_batch, ..., 2:4], matches[sub_batch, ..., -1]
38
- allpoints = torch.cat([points1.view([sub_batchsize*Nv,1,H*W,2]), points2.view([sub_batchsize*Nv,1,H*W,2])],dim=1) # [BxNv, 2, HxW, 2]
39
-
40
- allcam_Ps = intrinsics[sub_batch] @ extrinsics[sub_batch,:,:3,:]
41
- cam_Ps1, cam_Ps2 = allcam_Ps[:,[0]].repeat([1,Nv,1,1]), allcam_Ps[:,1:] # [B, Nv, 3, 4]
42
- formatted_camPs = torch.cat([cam_Ps1.reshape([sub_batchsize*Nv,1,3,4]), cam_Ps2.reshape([sub_batchsize*Nv,1,3,4])],dim=1) # [BxNv, 2, 3, 4]
43
-
44
- # Triangulate matches to 3D
45
- points_3d_world = batched_triangulate(allpoints, formatted_camPs) # [BxNv, HxW, three]
46
-
47
- # Aggregate pairwise predictions
48
- points_3d_world = points_3d_world.view([sub_batchsize,Nv,H,W,3])
49
- valids = points_3d_world.isfinite()
50
- valids_sum = valids.sum(dim=-1)
51
- validsuni=valids_sum.unique()
52
- assert torch.all(torch.logical_or(validsuni == 0 , validsuni == 3)), "Error, can only be nan for none or all XYZ values, not a subset"
53
- confs[valids_sum==0] = 0.
54
- points_3d_world = points_3d_world*confs[...,None]
55
-
56
- # Take care of NaNs
57
- normalization = confs.sum(dim=1)[:,None].repeat(1,Nv,1,1)
58
- normalization[normalization <= 1e-5] = 1.
59
- points_3d_world[valids] /= normalization[valids_sum==3][:,None].repeat(1,3).view(-1)
60
- points_3d_world[~valids] = 0.
61
- aggregated_points = points_3d_world.sum(dim=1) # weighted average (by confidence value) ignoring nans
62
-
63
- # Reset invalid values to nans, with a min visibility threshold
64
- aggregated_points[valids_sum.sum(dim=1)/3 <= min_num_valids] = torch.nan
65
-
66
- # From 3D to depths
67
- refcamE = extrinsics[sub_batch, 0]
68
- points_3d_camera = (refcamE[:,:3, :3] @ aggregated_points.view(sub_batchsize,-1,3).transpose(-2,-1) + refcamE[:,:3,[3]]).transpose(-2,-1) # [B,HxW,3]
69
- depths = points_3d_camera.view(sub_batchsize,H,W,3)[..., 2] # [B,H,W]
70
-
71
- # Cat results
72
- out_aggregated_points.append(aggregated_points.cpu())
73
- out_depths.append(depths.cpu())
74
- out_confs.append(confs.sum(dim=1).cpu())
75
-
76
- out_aggregated_points = torch.cat(out_aggregated_points,dim=0)
77
- out_depths = torch.cat(out_depths,dim=0)
78
- out_confs = torch.cat(out_confs,dim=0)
79
-
80
- return out_aggregated_points, out_depths, out_confs
 
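A small toy check (assumed setup, not part of the deleted file) of the DLT system that `batched_triangulate` solves: for a camera with 3x4 projection matrix P and observation (u, v), the rows P[0] - u*P[2] and P[1] - v*P[2] are linear in the 3D point; stacking them over cameras and solving in the least-squares sense recovers the point.

import torch

X = torch.tensor([0.3, -0.2, 4.0])                       # ground-truth 3D point
K = torch.tensor([[100., 0., 64.], [0., 100., 48.], [0., 0., 1.]])
E1 = torch.eye(3, 4)                                     # camera 1 at the world origin
E2 = torch.eye(3, 4)
E2[0, 3] = -1.0                                          # camera 2 shifted along x
projs = torch.stack([K @ E1, K @ E2])                    # [Ncams, 3, 4], I @ E as in the code above

uvw = projs @ torch.cat([X, torch.ones(1)])              # project to homogeneous pixel coords
uv = uvw[:, :2] / uvw[:, 2:3]                            # observed (u, v) per camera

rows = []                                                # P - x*P rows, as in batched_triangulate
for P, (u, v) in zip(projs, uv):
    rows.append(P[0] - u * P[2])
    rows.append(P[1] - v * P[2])
eq = torch.stack(rows)                                   # [2*Ncams, 4]
X_hat = torch.linalg.lstsq(eq[:, :3], -eq[:, 3:]).solution[:, 0]
print(X_hat)                                             # ~ tensor([ 0.3000, -0.2000,  4.0000])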
 
 
mast3r/cloud_opt/tsdf_optimizer.py DELETED
@@ -1,273 +0,0 @@
1
- import torch
2
- from torch import nn
3
- import numpy as np
4
- from tqdm import tqdm
5
- from matplotlib import pyplot as pl
6
-
7
- import mast3r.utils.path_to_dust3r # noqa
8
- from dust3r.utils.geometry import depthmap_to_pts3d, geotrf, inv
9
- from dust3r.cloud_opt.base_opt import clean_pointcloud
10
-
11
-
12
- class TSDFPostProcess:
13
- """ Optimizes a signed distance-function to improve depthmaps.
14
- """
15
-
16
- def __init__(self, optimizer, subsample=8, TSDF_thresh=0., TSDF_batchsize=int(1e7)):
17
- self.TSDF_thresh = TSDF_thresh # None -> no TSDF
18
- self.TSDF_batchsize = TSDF_batchsize
19
- self.optimizer = optimizer
20
-
21
- pts3d, depthmaps, confs = optimizer.get_dense_pts3d(clean_depth=False, subsample=subsample)
22
- pts3d, depthmaps = self._TSDF_postprocess_or_not(pts3d, depthmaps, confs)
23
- self.pts3d = pts3d
24
- self.depthmaps = depthmaps
25
- self.confs = confs
26
-
27
- def _get_depthmaps(self, TSDF_filtering_thresh=None):
28
- if TSDF_filtering_thresh:
29
- self._refine_depths_with_TSDF(self.optimizer, TSDF_filtering_thresh) # compute refined depths if needed
30
- dms = self.TSDF_im_depthmaps if TSDF_filtering_thresh else self.im_depthmaps
31
- return [d.exp() for d in dms]
32
-
33
- @torch.no_grad()
34
- def _refine_depths_with_TSDF(self, TSDF_filtering_thresh, niter=1, nsamples=1000):
35
- """
36
- Leverage TSDF to post-process estimated depths
37
- for each pixel, find zero level of TSDF along ray (or closest to 0)
38
- """
39
- print("Post-Processing Depths with TSDF fusion.")
40
- self.TSDF_im_depthmaps = []
41
- alldepths, allposes, allfocals, allpps, allimshapes = self._get_depthmaps(), self.optimizer.get_im_poses(
42
- ), self.optimizer.get_focals(), self.optimizer.get_principal_points(), self.imshapes
43
- for vi in tqdm(range(self.optimizer.n_imgs)):
44
- dm, pose, focal, pp, imshape = alldepths[vi], allposes[vi], allfocals[vi], allpps[vi], allimshapes[vi]
45
- minvals = torch.full(dm.shape, 1e20)
46
-
47
- for it in range(niter):
48
- H, W = dm.shape
49
- curthresh = (niter - it) * TSDF_filtering_thresh
50
- dm_offsets = (torch.randn(H, W, nsamples).to(dm) - 1.) * \
51
- curthresh # decreasing search std along with iterations
52
- newdm = dm[..., None] + dm_offsets # [H,W,Nsamp]
53
- curproj = self._backproj_pts3d(in_depths=[newdm], in_im_poses=pose[None], in_focals=focal[None], in_pps=pp[None], in_imshapes=[
54
- imshape])[0] # [H,W,Nsamp,3]
55
- # Batched TSDF eval
56
- curproj = curproj.view(-1, 3)
57
- tsdf_vals = []
58
- valids = []
59
- for batch in range(0, len(curproj), self.TSDF_batchsize):
60
- values, valid = self._TSDF_query(
61
- curproj[batch:min(batch + self.TSDF_batchsize, len(curproj))], curthresh)
62
- tsdf_vals.append(values)
63
- valids.append(valid)
64
- tsdf_vals = torch.cat(tsdf_vals, dim=0)
65
- valids = torch.cat(valids, dim=0)
66
-
67
- tsdf_vals = tsdf_vals.view([H, W, nsamples])
68
- valids = valids.view([H, W, nsamples])
69
-
70
- # keep depth value that got us the closest to 0
71
- tsdf_vals[~valids] = torch.inf # ignore invalid values
72
- tsdf_vals = tsdf_vals.abs()
73
- mins = torch.argmin(tsdf_vals, dim=-1, keepdim=True)
74
- # when all samples live on a very flat zone, do nothing
75
- allbad = (tsdf_vals == curthresh).sum(dim=-1) == nsamples
76
- dm[~allbad] = torch.gather(newdm, -1, mins)[..., 0][~allbad]
77
-
78
- # Save refined depth map
79
- self.TSDF_im_depthmaps.append(dm.log())
80
-
81
- def _TSDF_query(self, qpoints, TSDF_filtering_thresh, weighted=True):
82
- """
83
- TSDF query call: returns the weighted TSDF value for each query point [N, 3]
84
- """
85
- N, three = qpoints.shape
86
- assert three == 3
87
- qpoints = qpoints[None].repeat(self.optimizer.n_imgs, 1, 1) # [B,N,3]
88
- # get projection coordinates and depths onto images
89
- coords_and_depth = self._proj_pts3d(pts3d=qpoints, cam2worlds=self.optimizer.get_im_poses(
90
- ), focals=self.optimizer.get_focals(), pps=self.optimizer.get_principal_points())
91
- image_coords = coords_and_depth[..., :2].round().to(int) # for now, there's no interpolation...
92
- proj_depths = coords_and_depth[..., -1]
93
- # recover depth values after scene optim
94
- pred_depths, pred_confs, valids = self._get_pixel_depths(image_coords)
95
- # Gather TSDF scores
96
- all_SDF_scores = pred_depths - proj_depths # SDF
97
- unseen = all_SDF_scores < -TSDF_filtering_thresh # handle visibility
98
- # all_TSDF_scores = all_SDF_scores.clip(-TSDF_filtering_thresh,TSDF_filtering_thresh) # SDF -> TSDF
99
- all_TSDF_scores = all_SDF_scores.clip(-TSDF_filtering_thresh, 1e20) # SDF -> TSDF
100
- # Gather TSDF confidences and ignore points that are unseen, either OOB during reproj or too far behind seen depth
101
- all_TSDF_weights = (~unseen).float() * valids.float()
102
- if weighted:
103
- all_TSDF_weights = pred_confs.exp() * all_TSDF_weights
104
- # Aggregate all votes, ignoring zeros
105
- TSDF_weights = all_TSDF_weights.sum(dim=0)
106
- valids = TSDF_weights != 0.
107
- TSDF_wsum = (all_TSDF_weights * all_TSDF_scores).sum(dim=0)
108
- TSDF_wsum[valids] /= TSDF_weights[valids]
109
- return TSDF_wsum, valids
110
-
111
- def _get_pixel_depths(self, image_coords, TSDF_filtering_thresh=None, with_normals_conf=False):
112
- """ Recover depth value for each input pixel coordinate, along with OOB validity mask
113
- """
114
- B, N, two = image_coords.shape
115
- assert B == self.optimizer.n_imgs and two == 2
116
- depths = torch.zeros([B, N], device=image_coords.device)
117
- valids = torch.zeros([B, N], dtype=bool, device=image_coords.device)
118
- confs = torch.zeros([B, N], device=image_coords.device)
119
- curconfs = self._get_confs_with_normals() if with_normals_conf else self.im_conf
120
- for ni, (imc, depth, conf) in enumerate(zip(image_coords, self._get_depthmaps(TSDF_filtering_thresh), curconfs)):
121
- H, W = depth.shape
122
- valids[ni] = torch.logical_and(0 <= imc[:, 1], imc[:, 1] <
123
- H) & torch.logical_and(0 <= imc[:, 0], imc[:, 0] < W)
124
- imc[~valids[ni]] = 0
125
- depths[ni] = depth[imc[:, 1], imc[:, 0]]
126
- confs[ni] = conf.cuda()[imc[:, 1], imc[:, 0]]
127
- return depths, confs, valids
128
-
129
- def _get_confs_with_normals(self):
130
- outconfs = []
131
- # Confidence based on depth gradient
132
-
133
- class Sobel(nn.Module):
134
- def __init__(self):
135
- super().__init__()
136
- self.filter = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False)
137
- Gx = torch.tensor([[2.0, 0.0, -2.0], [4.0, 0.0, -4.0], [2.0, 0.0, -2.0]])
138
- Gy = torch.tensor([[2.0, 4.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -4.0, -2.0]])
139
- G = torch.cat([Gx.unsqueeze(0), Gy.unsqueeze(0)], 0)
140
- G = G.unsqueeze(1)
141
- self.filter.weight = nn.Parameter(G, requires_grad=False)
142
-
143
- def forward(self, img):
144
- x = self.filter(img)
145
- x = torch.mul(x, x)
146
- x = torch.sum(x, dim=1, keepdim=True)
147
- x = torch.sqrt(x)
148
- return x
149
-
150
- grad_op = Sobel().to(self.im_depthmaps[0].device)
151
- for conf, depth in zip(self.im_conf, self.im_depthmaps):
152
- grad_confs = (1. - grad_op(depth[None, None])[0, 0]).clip(0)
153
- if not 'dbg show':
154
- pl.imshow(grad_confs.cpu())
155
- pl.show()
156
- outconfs.append(conf * grad_confs.to(conf))
157
- return outconfs
158
-
159
- def _proj_pts3d(self, pts3d, cam2worlds, focals, pps):
160
- """
161
- Projection operation: from 3D points to 2D coordinates + depths
162
- """
163
- B = pts3d.shape[0]
164
- assert pts3d.shape[0] == cam2worlds.shape[0]
165
- # prepare extrinsics
166
- R, t = cam2worlds[:, :3, :3], cam2worlds[:, :3, -1]
167
- Rinv = R.transpose(-2, -1)
168
- tinv = -Rinv @ t[..., None]
169
-
170
- # prepare intrinsics
171
- intrinsics = torch.eye(3).to(cam2worlds)[None].repeat(focals.shape[0], 1, 1)
172
- if len(focals.shape) == 1:
173
- focals = torch.stack([focals, focals], dim=-1)
174
- intrinsics[:, 0, 0] = focals[:, 0]
175
- intrinsics[:, 1, 1] = focals[:, 1]
176
- intrinsics[:, :2, -1] = pps
177
- # Project
178
- projpts = intrinsics @ (Rinv @ pts3d.transpose(-2, -1) + tinv) # I(RX+t) : [B,3,N]
179
- projpts = projpts.transpose(-2, -1) # [B,N,3]
180
- projpts[..., :2] /= projpts[..., [-1]] # [B,N,3] (X/Z , Y/Z, Z)
181
- return projpts
182
-
183
- def _backproj_pts3d(self, in_depths=None, in_im_poses=None,
184
- in_focals=None, in_pps=None, in_imshapes=None):
185
- """
186
- Backprojection operation: from image depths to 3D points
187
- """
188
- # Get depths and projection params if not provided
189
- focals = self.optimizer.get_focals() if in_focals is None else in_focals
190
- im_poses = self.optimizer.get_im_poses() if in_im_poses is None else in_im_poses
191
- depth = self._get_depthmaps() if in_depths is None else in_depths
192
- pp = self.optimizer.get_principal_points() if in_pps is None else in_pps
193
- imshapes = self.imshapes if in_imshapes is None else in_imshapes
194
- def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *imshapes[i])
195
- dm_to_3d = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[[i]]) for i in range(im_poses.shape[0])]
196
-
197
- def autoprocess(x):
198
- x = x[0]
199
- return x.transpose(-2, -1) if len(x.shape) == 4 else x
200
- return [geotrf(pose, autoprocess(pt)) for pose, pt in zip(im_poses, dm_to_3d)]
201
-
202
- def _pts3d_to_depth(self, pts3d, cam2worlds, focals, pps):
203
- """
204
- Projection operation: from 3D points to 2D coordinates + depths
205
- """
206
- B = pts3d.shape[0]
207
- assert pts3d.shape[0] == cam2worlds.shape[0]
208
- # prepare extrinsics
209
- R, t = cam2worlds[:, :3, :3], cam2worlds[:, :3, -1]
210
- Rinv = R.transpose(-2, -1)
211
- tinv = -Rinv @ t[..., None]
212
-
213
- # prepare intrinsics
214
- intrinsics = torch.eye(3).to(cam2worlds)[None].repeat(self.optimizer.n_imgs, 1, 1)
215
- if len(focals.shape) == 1:
216
- focals = torch.stack([focals, focals], dim=-1)
217
- intrinsics[:, 0, 0] = focals[:, 0]
218
- intrinsics[:, 1, 1] = focals[:, 1]
219
- intrinsics[:, :2, -1] = pps
220
- # Project
221
- projpts = intrinsics @ (Rinv @ pts3d.transpose(-2, -1) + tinv) # I(RX+t) : [B,3,N]
222
- projpts = projpts.transpose(-2, -1) # [B,N,3]
223
- projpts[..., :2] /= projpts[..., [-1]] # [B,N,3] (X/Z , Y/Z, Z)
224
- return projpts
225
-
226
- def _depth_to_pts3d(self, in_depths=None, in_im_poses=None, in_focals=None, in_pps=None, in_imshapes=None):
227
- """
228
- Backprojection operation: from image depths to 3D points
229
- """
230
- # Get depths and projection params if not provided
231
- focals = self.optimizer.get_focals() if in_focals is None else in_focals
232
- im_poses = self.optimizer.get_im_poses() if in_im_poses is None else in_im_poses
233
- depth = self._get_depthmaps() if in_depths is None else in_depths
234
- pp = self.optimizer.get_principal_points() if in_pps is None else in_pps
235
- imshapes = self.imshapes if in_imshapes is None else in_imshapes
236
-
237
- def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *imshapes[i])
238
-
239
- dm_to_3d = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i + 1]) for i in range(im_poses.shape[0])]
240
-
241
- def autoprocess(x):
242
- x = x[0]
243
- H, W, three = x.shape[:3]
244
- return x.transpose(-2, -1) if len(x.shape) == 4 else x
245
- return [geotrf(pp, autoprocess(pt)) for pp, pt in zip(im_poses, dm_to_3d)]
246
-
247
- def _get_pts3d(self, TSDF_filtering_thresh=None, **kw):
248
- """
249
- return 3D points (possibly filtering depths with TSDF)
250
- """
251
- return self._backproj_pts3d(in_depths=self._get_depthmaps(TSDF_filtering_thresh=TSDF_filtering_thresh), **kw)
252
-
253
- def _TSDF_postprocess_or_not(self, pts3d, depthmaps, confs, niter=1):
254
- # Setup inner variables
255
- self.imshapes = [im.shape[:2] for im in self.optimizer.imgs]
256
- self.im_depthmaps = [dd.log().view(imshape) for dd, imshape in zip(depthmaps, self.imshapes)]
257
- self.im_conf = confs
258
-
259
- if self.TSDF_thresh > 0.:
260
- # Create or update self.TSDF_im_depthmaps that contain logdepths filtered with TSDF
261
- self._refine_depths_with_TSDF(self.TSDF_thresh, niter=niter)
262
- depthmaps = [dd.exp() for dd in self.TSDF_im_depthmaps]
263
- # Turn them into 3D points
264
- pts3d = self._backproj_pts3d(in_depths=depthmaps)
265
- depthmaps = [dd.flatten() for dd in depthmaps]
266
- pts3d = [pp.view(-1, 3) for pp in pts3d]
267
- return pts3d, depthmaps
268
-
269
- def get_dense_pts3d(self, clean_depth=True):
270
- if clean_depth:
271
- confs = clean_pointcloud(self.confs, self.optimizer.intrinsics, inv(self.optimizer.cam2w),
272
- self.depthmaps, self.pts3d)
273
- return self.pts3d, self.depthmaps, confs
 
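Note: the weighted vote computed in _TSDF_query above reduces, per query point, to a confidence-weighted average of truncated signed distances over the views that actually see the point. A minimal standalone sketch of that aggregation on made-up numbers (not part of the deleted file):

import torch

thresh = 0.1                                     # TSDF truncation threshold
pred_depths = torch.tensor([2.00, 2.05, 0.50])   # depth predicted at the reprojected pixel, per view
proj_depths = torch.tensor([1.98, 2.10, 2.00])   # depth of the query point along each view's ray
confs = torch.tensor([3.0, 2.0, 1.0])            # per-pixel confidences

sdf = pred_depths - proj_depths                  # signed distance to the observed surface
unseen = sdf < -thresh                           # query point is far behind the surface seen in that view
tsdf = sdf.clip(-thresh, float('inf'))           # truncate on the negative side only, as in the code above
weights = (~unseen).float() * confs.exp()        # drop unseen views, weight the rest by confidence
value = (weights * tsdf).sum() / weights.sum()   # > 0: in front of the surface, < 0: behind it
print(float(value))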
 
 
mast3r/cloud_opt/utils/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
 
 
 
mast3r/cloud_opt/utils/losses.py DELETED
@@ -1,32 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # losses for sparse ga
6
- # --------------------------------------------------------
7
- import torch
8
- import numpy as np
9
-
10
-
11
- def l05_loss(x, y):
12
- return torch.linalg.norm(x - y, dim=-1).sqrt()
13
-
14
-
15
- def l1_loss(x, y):
16
- return torch.linalg.norm(x - y, dim=-1)
17
-
18
-
19
- def gamma_loss(gamma, mul=1, offset=None, clip=np.inf):
20
- if offset is None:
21
- if gamma == 1:
22
- return l1_loss
23
- # d(x**p)/dx = 1 ==> p * x**(p-1) == 1 ==> x = (1/p)**(1/(p-1))
24
- offset = (1 / gamma)**(1 / (gamma - 1))
25
-
26
- def loss_func(x, y):
27
- return (mul * l1_loss(x, y).clip(max=clip) + offset) ** gamma - offset ** gamma
28
- return loss_func
29
-
30
-
31
- def meta_gamma_loss():
32
- return lambda alpha: gamma_loss(alpha)
 
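Note: the offset used by gamma_loss above is chosen so that the loss has unit slope at zero error, i.e. it matches l1_loss locally while growing sub-linearly for large residuals. A quick standalone numerical check (a sketch, not code from the repository):

import torch

def gamma_loss_scalar(d, gamma):
    # offset solves gamma * offset**(gamma - 1) == 1, the derivation quoted in the comment above
    offset = (1 / gamma) ** (1 / (gamma - 1))
    return (d + offset) ** gamma - offset ** gamma

d = torch.tensor(1e-4, requires_grad=True)
gamma_loss_scalar(d, gamma=0.5).backward()
print(d.grad)   # ~1.0: same slope as l1_loss near zero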
 
 
mast3r/cloud_opt/utils/schedules.py DELETED
@@ -1,17 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # lr schedules for sparse ga
6
- # --------------------------------------------------------
7
- import numpy as np
8
-
9
-
10
- def linear_schedule(alpha, lr_base, lr_end=0):
11
- lr = (1 - alpha) * lr_base + alpha * lr_end
12
- return lr
13
-
14
-
15
- def cosine_schedule(alpha, lr_base, lr_end=0):
16
- lr = lr_end + (lr_base - lr_end) * (1 + np.cos(alpha * np.pi)) / 2
17
- return lr
 
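Note: both schedules above map a progress value alpha in [0, 1] from lr_base down to lr_end; the cosine variant is flatter near the endpoints and steeper in the middle. A standalone sanity check with assumed values lr_base=1e-2 and lr_end=1e-5:

import numpy as np

lr_base, lr_end = 1e-2, 1e-5
for alpha in np.linspace(0, 1, 5):
    lin = (1 - alpha) * lr_base + alpha * lr_end                          # linear_schedule
    cos = lr_end + (lr_base - lr_end) * (1 + np.cos(alpha * np.pi)) / 2   # cosine_schedule
    print(f"alpha={alpha:.2f}  linear={lin:.2e}  cosine={cos:.2e}")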
 
 
mast3r/colmap/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
 
 
 
mast3r/colmap/database.py DELETED
@@ -1,383 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # MASt3R to colmap export functions
6
- # --------------------------------------------------------
7
- import os
8
- import torch
9
- import copy
10
- import numpy as np
11
- import torchvision
12
- import numpy as np
13
- from tqdm import tqdm
14
- from scipy.cluster.hierarchy import DisjointSet
15
- from scipy.spatial.transform import Rotation as R
16
-
17
- from mast3r.utils.misc import hash_md5
18
-
19
- from mast3r.fast_nn import extract_correspondences_nonsym, bruteforce_reciprocal_nns
20
-
21
- import mast3r.utils.path_to_dust3r # noqa
22
- from dust3r.utils.geometry import find_reciprocal_matches, xy_grid # noqa
23
-
24
-
25
- def convert_im_matches_pairs(img0, img1, image_to_colmap, im_keypoints, matches_im0, matches_im1, viz):
26
- if viz:
27
- from matplotlib import pyplot as pl
28
-
29
- image_mean = torch.as_tensor(
30
- [0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
31
- image_std = torch.as_tensor(
32
- [0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
33
- rgb0 = img0['img'] * image_std + image_mean
34
- rgb0 = torchvision.transforms.functional.to_pil_image(rgb0[0])
35
- rgb0 = np.array(rgb0)
36
-
37
- rgb1 = img1['img'] * image_std + image_mean
38
- rgb1 = torchvision.transforms.functional.to_pil_image(rgb1[0])
39
- rgb1 = np.array(rgb1)
40
-
41
- imgs = [rgb0, rgb1]
42
- # visualize a few matches
43
- n_viz = 100
44
- num_matches = matches_im0.shape[0]
45
- match_idx_to_viz = np.round(np.linspace(
46
- 0, num_matches - 1, n_viz)).astype(int)
47
- viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]
48
-
49
- H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]
50
- rgb0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)),
51
- (0, 0), (0, 0)), 'constant', constant_values=0)
52
- rgb1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)),
53
- (0, 0), (0, 0)), 'constant', constant_values=0)
54
- img = np.concatenate((rgb0, rgb1), axis=1)
55
- pl.figure()
56
- pl.imshow(img)
57
- cmap = pl.get_cmap('jet')
58
- for ii in range(n_viz):
59
- (x0, y0), (x1,
60
- y1) = viz_matches_im0[ii].T, viz_matches_im1[ii].T
61
- pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(ii /
62
- (n_viz - 1)), scalex=False, scaley=False)
63
- pl.show(block=True)
64
-
65
- matches = [matches_im0.astype(np.float64), matches_im1.astype(np.float64)]
66
- imgs = [img0, img1]
67
- imidx0 = img0['idx']
68
- imidx1 = img1['idx']
69
- ravel_matches = []
70
- for j in range(2):
71
- H, W = imgs[j]['true_shape'][0]
72
- with np.errstate(invalid='ignore'):
73
- qx, qy = matches[j].round().astype(np.int32).T
74
- ravel_matches_j = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(min=0, max=H - 1, out=qy)
75
- ravel_matches.append(ravel_matches_j)
76
- imidxj = imgs[j]['idx']
77
- for m in ravel_matches_j:
78
- if m not in im_keypoints[imidxj]:
79
- im_keypoints[imidxj][m] = 0
80
- im_keypoints[imidxj][m] += 1
81
- imid0 = copy.deepcopy(image_to_colmap[imidx0]['colmap_imid'])
82
- imid1 = copy.deepcopy(image_to_colmap[imidx1]['colmap_imid'])
83
- if imid0 > imid1:
84
- colmap_matches = np.stack([ravel_matches[1], ravel_matches[0]], axis=-1)
85
- imid0, imid1 = imid1, imid0
86
- imidx0, imidx1 = imidx1, imidx0
87
- else:
88
- colmap_matches = np.stack([ravel_matches[0], ravel_matches[1]], axis=-1)
89
- colmap_matches = np.unique(colmap_matches, axis=0)
90
- return imidx0, imidx1, colmap_matches
91
-
92
-
93
- def get_im_matches(pred1, pred2, pairs, image_to_colmap, im_keypoints, conf_thr,
94
- is_sparse=True, subsample=8, pixel_tol=0, viz=False, device='cuda'):
95
- im_matches = {}
96
- for i in range(len(pred1['pts3d'])):
97
- imidx0 = pairs[i][0]['idx']
98
- imidx1 = pairs[i][1]['idx']
99
- if 'desc' in pred1: # mast3r
100
- descs = [pred1['desc'][i], pred2['desc'][i]]
101
- confidences = [pred1['desc_conf'][i], pred2['desc_conf'][i]]
102
- desc_dim = descs[0].shape[-1]
103
-
104
- if is_sparse:
105
- corres = extract_correspondences_nonsym(descs[0], descs[1], confidences[0], confidences[1],
106
- device=device, subsample=subsample, pixel_tol=pixel_tol)
107
- conf = corres[2]
108
- mask = conf >= conf_thr
109
- matches_im0 = corres[0][mask].cpu().numpy()
110
- matches_im1 = corres[1][mask].cpu().numpy()
111
- else:
112
- confidence_masks = [confidences[0] >=
113
- conf_thr, confidences[1] >= conf_thr]
114
- pts2d_list, desc_list = [], []
115
- for j in range(2):
116
- conf_j = confidence_masks[j].cpu().numpy().flatten()
117
- true_shape_j = pairs[i][j]['true_shape'][0]
118
- pts2d_j = xy_grid(
119
- true_shape_j[1], true_shape_j[0]).reshape(-1, 2)[conf_j]
120
- desc_j = descs[j].detach().cpu(
121
- ).numpy().reshape(-1, desc_dim)[conf_j]
122
- pts2d_list.append(pts2d_j)
123
- desc_list.append(desc_j)
124
- if len(desc_list[0]) == 0 or len(desc_list[1]) == 0:
125
- continue
126
-
127
- nn0, nn1 = bruteforce_reciprocal_nns(desc_list[0], desc_list[1],
128
- device=device, dist='dot', block_size=2**13)
129
- reciprocal_in_P0 = (nn1[nn0] == np.arange(len(nn0)))
130
-
131
- matches_im1 = pts2d_list[1][nn0][reciprocal_in_P0]
132
- matches_im0 = pts2d_list[0][reciprocal_in_P0]
133
- else:
134
- pts3d = [pred1['pts3d'][i], pred2['pts3d_in_other_view'][i]]
135
- confidences = [pred1['conf'][i], pred2['conf'][i]]
136
-
137
- if is_sparse:
138
- corres = extract_correspondences_nonsym(pts3d[0], pts3d[1], confidences[0], confidences[1],
139
- device=device, subsample=subsample, pixel_tol=pixel_tol,
140
- ptmap_key='3d')
141
- conf = corres[2]
142
- mask = conf >= conf_thr
143
- matches_im0 = corres[0][mask].cpu().numpy()
144
- matches_im1 = corres[1][mask].cpu().numpy()
145
- else:
146
- confidence_masks = [confidences[0] >=
147
- conf_thr, confidences[1] >= conf_thr]
148
- # find 2D-2D matches between the two images
149
- pts2d_list, pts3d_list = [], []
150
- for j in range(2):
151
- conf_j = confidence_masks[j].cpu().numpy().flatten()
152
- true_shape_j = pairs[i][j]['true_shape'][0]
153
- pts2d_j = xy_grid(true_shape_j[1], true_shape_j[0]).reshape(-1, 2)[conf_j]
154
- pts3d_j = pts3d[j].detach().cpu().numpy().reshape(-1, 3)[conf_j]
155
- pts2d_list.append(pts2d_j)
156
- pts3d_list.append(pts3d_j)
157
-
158
- PQ, PM = pts3d_list[0], pts3d_list[1]
159
- if len(PQ) == 0 or len(PM) == 0:
160
- continue
161
- reciprocal_in_PM, nnM_in_PQ, num_matches = find_reciprocal_matches(
162
- PQ, PM)
163
-
164
- matches_im1 = pts2d_list[1][reciprocal_in_PM]
165
- matches_im0 = pts2d_list[0][nnM_in_PQ][reciprocal_in_PM]
166
-
167
- if len(matches_im0) == 0:
168
- continue
169
- imidx0, imidx1, colmap_matches = convert_im_matches_pairs(pairs[i][0], pairs[i][1],
170
- image_to_colmap, im_keypoints,
171
- matches_im0, matches_im1, viz)
172
- im_matches[(imidx0, imidx1)] = colmap_matches
173
- return im_matches
174
-
175
-
176
- def get_im_matches_from_cache(pairs, cache_path, desc_conf, subsample,
177
- image_to_colmap, im_keypoints, conf_thr,
178
- viz=False, device='cuda'):
179
- im_matches = {}
180
- for i in range(len(pairs)):
181
- imidx0 = pairs[i][0]['idx']
182
- imidx1 = pairs[i][1]['idx']
183
-
184
- corres_idx1 = hash_md5(pairs[i][0]['instance'])
185
- corres_idx2 = hash_md5(pairs[i][1]['instance'])
186
-
187
- path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{corres_idx1}-{corres_idx2}.pth'
188
- if os.path.isfile(path_corres):
189
- score, (xy1, xy2, confs) = torch.load(path_corres, map_location=device)
190
- else:
191
- path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{corres_idx2}-{corres_idx1}.pth'
192
- score, (xy2, xy1, confs) = torch.load(path_corres, map_location=device)
193
- mask = confs >= conf_thr
194
- matches_im0 = xy1[mask].cpu().numpy()
195
- matches_im1 = xy2[mask].cpu().numpy()
196
-
197
- if len(matches_im0) == 0:
198
- continue
199
- imidx0, imidx1, colmap_matches = convert_im_matches_pairs(pairs[i][0], pairs[i][1],
200
- image_to_colmap, im_keypoints,
201
- matches_im0, matches_im1, viz)
202
- im_matches[(imidx0, imidx1)] = colmap_matches
203
- return im_matches
204
-
205
-
206
- def export_images(db, images, image_paths, focals, ga_world_to_cam, camera_model):
207
- # add cameras/images to the db
208
- # with the output of ga as prior
209
- image_to_colmap = {}
210
- im_keypoints = {}
211
- for idx in range(len(image_paths)):
212
- im_keypoints[idx] = {}
213
- H, W = images[idx]["orig_shape"]
214
- if focals is None:
215
- focal_x = focal_y = 1.2 * max(W, H)
216
- prior_focal_length = False
217
- cx = W / 2.0
218
- cy = H / 2.0
219
- elif isinstance(focals[idx], np.ndarray) and len(focals[idx].shape) == 2:
220
- # intrinsics
221
- focal_x = focals[idx][0, 0]
222
- focal_y = focals[idx][1, 1]
223
- cx = focals[idx][0, 2] * images[idx]["to_orig"][0, 0]
224
- cy = focals[idx][1, 2] * images[idx]["to_orig"][1, 1]
225
- prior_focal_length = True
226
- else:
227
- focal_x = focal_y = float(focals[idx])
228
- prior_focal_length = True
229
- cx = W / 2.0
230
- cy = H / 2.0
231
- focal_x = focal_x * images[idx]["to_orig"][0, 0]
232
- focal_y = focal_y * images[idx]["to_orig"][1, 1]
233
-
234
- if camera_model == "SIMPLE_PINHOLE":
235
- model_id = 0
236
- focal = (focal_x + focal_y) / 2.0
237
- params = np.asarray([focal, cx, cy], np.float64)
238
- elif camera_model == "PINHOLE":
239
- model_id = 1
240
- params = np.asarray([focal_x, focal_y, cx, cy], np.float64)
241
- elif camera_model == "SIMPLE_RADIAL":
242
- model_id = 2
243
- focal = (focal_x + focal_y) / 2.0
244
- params = np.asarray([focal, cx, cy, 0.0], np.float64)
245
- elif camera_model == "OPENCV":
246
- model_id = 4
247
- params = np.asarray([focal_x, focal_y, cx, cy, 0.0, 0.0, 0.0, 0.0], np.float64)
248
- else:
249
- raise ValueError(f"invalid camera model {camera_model}")
250
-
251
- H, W = int(H), int(W)
252
- # register the camera with the selected model
253
- camid = db.add_camera(
254
- model_id, W, H, params, prior_focal_length=prior_focal_length)
255
- if ga_world_to_cam is None:
256
- prior_t = np.zeros(3)
257
- prior_q = np.zeros(4)
258
- else:
259
- q = R.from_matrix(ga_world_to_cam[idx][:3, :3]).as_quat()
260
- prior_t = ga_world_to_cam[idx][:3, 3]
261
- prior_q = np.array([q[-1], q[0], q[1], q[2]])
262
- imid = db.add_image(
263
- image_paths[idx], camid, prior_q=prior_q, prior_t=prior_t)
264
- image_to_colmap[idx] = {
265
- 'colmap_imid': imid,
266
- 'colmap_camid': camid
267
- }
268
- return image_to_colmap, im_keypoints
269
-
270
-
271
- def export_matches(db, images, image_to_colmap, im_keypoints, im_matches, min_len_track, skip_geometric_verification):
272
- colmap_image_pairs = []
273
- # 2D-2D matches are quite dense
274
- # we want to remove the very small tracks
275
- # and export only keypoints for which we have values
276
- # build tracks
277
- print("building tracks")
278
- keypoints_to_track_id = {}
279
- track_id_to_kpt_list = []
280
- to_merge = []
281
- for (imidx0, imidx1), colmap_matches in tqdm(im_matches.items()):
282
- if imidx0 not in keypoints_to_track_id:
283
- keypoints_to_track_id[imidx0] = {}
284
- if imidx1 not in keypoints_to_track_id:
285
- keypoints_to_track_id[imidx1] = {}
286
-
287
- for m in colmap_matches:
288
- if m[0] not in keypoints_to_track_id[imidx0] and m[1] not in keypoints_to_track_id[imidx1]:
289
- # new pair of kpts never seen before
290
- track_idx = len(track_id_to_kpt_list)
291
- keypoints_to_track_id[imidx0][m[0]] = track_idx
292
- keypoints_to_track_id[imidx1][m[1]] = track_idx
293
- track_id_to_kpt_list.append(
294
- [(imidx0, m[0]), (imidx1, m[1])])
295
- elif m[1] not in keypoints_to_track_id[imidx1]:
296
- # 0 has a track, not 1
297
- track_idx = keypoints_to_track_id[imidx0][m[0]]
298
- keypoints_to_track_id[imidx1][m[1]] = track_idx
299
- track_id_to_kpt_list[track_idx].append((imidx1, m[1]))
300
- elif m[0] not in keypoints_to_track_id[imidx0]:
301
- # 1 has a track, not 0
302
- track_idx = keypoints_to_track_id[imidx1][m[1]]
303
- keypoints_to_track_id[imidx0][m[0]] = track_idx
304
- track_id_to_kpt_list[track_idx].append((imidx0, m[0]))
305
- else:
306
- # both have tracks, merge them
307
- track_idx0 = keypoints_to_track_id[imidx0][m[0]]
308
- track_idx1 = keypoints_to_track_id[imidx1][m[1]]
309
- if track_idx0 != track_idx1:
310
- # let's deal with them later
311
- to_merge.append((track_idx0, track_idx1))
312
-
313
- # regroup merge targets
314
- print("merging tracks")
315
- unique = np.unique(to_merge)
316
- tree = DisjointSet(unique)
317
- for track_idx0, track_idx1 in tqdm(to_merge):
318
- tree.merge(track_idx0, track_idx1)
319
-
320
- subsets = tree.subsets()
321
- print("applying merge")
322
- for setvals in tqdm(subsets):
323
- new_trackid = len(track_id_to_kpt_list)
324
- kpt_list = []
325
- for track_idx in setvals:
326
- kpt_list.extend(track_id_to_kpt_list[track_idx])
327
- for imidx, kpid in track_id_to_kpt_list[track_idx]:
328
- keypoints_to_track_id[imidx][kpid] = new_trackid
329
- track_id_to_kpt_list.append(kpt_list)
330
-
331
- # binc = np.bincount([len(v) for v in track_id_to_kpt_list])
332
- # nonzero = np.nonzero(binc)
333
- # nonzerobinc = binc[nonzero[0]]
334
- # print(nonzero[0].tolist())
335
- # print(nonzerobinc)
336
- num_valid_tracks = sum(
337
- [1 for v in track_id_to_kpt_list if len(v) >= min_len_track])
338
-
339
- keypoints_to_idx = {}
340
- print(f"squashing keypoints - {num_valid_tracks} valid tracks")
341
- for imidx, keypoints_imid in tqdm(im_keypoints.items()):
342
- imid = image_to_colmap[imidx]['colmap_imid']
343
- keypoints_kept = []
344
- keypoints_to_idx[imidx] = {}
345
- for kp in keypoints_imid.keys():
346
- if kp not in keypoints_to_track_id[imidx]:
347
- continue
348
- track_idx = keypoints_to_track_id[imidx][kp]
349
- track_length = len(track_id_to_kpt_list[track_idx])
350
- if track_length < min_len_track:
351
- continue
352
- keypoints_to_idx[imidx][kp] = len(keypoints_kept)
353
- keypoints_kept.append(kp)
354
- if len(keypoints_kept) == 0:
355
- continue
356
- keypoints_kept = np.array(keypoints_kept)
357
- keypoints_kept = np.unravel_index(keypoints_kept, images[imidx]['true_shape'][0])[
358
- 0].base[:, ::-1].copy().astype(np.float32)
359
- # rescale coordinates
360
- keypoints_kept[:, 0] += 0.5
361
- keypoints_kept[:, 1] += 0.5
362
- keypoints_kept = geotrf(images[imidx]['to_orig'], keypoints_kept, norm=True)
363
-
364
- H, W = images[imidx]['orig_shape']
365
- keypoints_kept[:, 0] = keypoints_kept[:, 0].clip(min=0, max=W - 0.01)
366
- keypoints_kept[:, 1] = keypoints_kept[:, 1].clip(min=0, max=H - 0.01)
367
-
368
- db.add_keypoints(imid, keypoints_kept)
369
-
370
- print("exporting im_matches")
371
- for (imidx0, imidx1), colmap_matches in im_matches.items():
372
- imid0, imid1 = image_to_colmap[imidx0]['colmap_imid'], image_to_colmap[imidx1]['colmap_imid']
373
- assert imid0 < imid1
374
- final_matches = np.array([[keypoints_to_idx[imidx0][m[0]], keypoints_to_idx[imidx1][m[1]]]
375
- for m in colmap_matches
376
- if m[0] in keypoints_to_idx[imidx0] and m[1] in keypoints_to_idx[imidx1]])
377
- if len(final_matches) > 0:
378
- colmap_image_pairs.append(
379
- (images[imidx0]['instance'], images[imidx1]['instance']))
380
- db.add_matches(imid0, imid1, final_matches)
381
- if skip_geometric_verification:
382
- db.add_two_view_geometry(imid0, imid1, final_matches)
383
- return colmap_image_pairs
 
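Note: the track-merging step in export_matches above relies on scipy's DisjointSet: whenever a match links two keypoints that were already assigned to different tracks, the two track ids are unioned and the resulting subsets are squashed into new tracks. A toy standalone illustration with made-up track ids:

from scipy.cluster.hierarchy import DisjointSet

to_merge = [(0, 2), (2, 5), (3, 4)]                      # pairs of track ids flagged for merging
tree = DisjointSet(sorted({i for pair in to_merge for i in pair}))
for a, b in to_merge:
    tree.merge(a, b)
print([sorted(s) for s in tree.subsets()])               # two groups: [0, 2, 5] and [3, 4]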
 
 
mast3r/datasets/__init__.py DELETED
@@ -1,62 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
-
4
- from .base.mast3r_base_stereo_view_dataset import MASt3RBaseStereoViewDataset
5
-
6
- import mast3r.utils.path_to_dust3r # noqa
7
- from dust3r.datasets.arkitscenes import ARKitScenes as DUSt3R_ARKitScenes # noqa
8
- from dust3r.datasets.blendedmvs import BlendedMVS as DUSt3R_BlendedMVS # noqa
9
- from dust3r.datasets.co3d import Co3d as DUSt3R_Co3d # noqa
10
- from dust3r.datasets.megadepth import MegaDepth as DUSt3R_MegaDepth # noqa
11
- from dust3r.datasets.scannetpp import ScanNetpp as DUSt3R_ScanNetpp # noqa
12
- from dust3r.datasets.staticthings3d import StaticThings3D as DUSt3R_StaticThings3D # noqa
13
- from dust3r.datasets.waymo import Waymo as DUSt3R_Waymo # noqa
14
- from dust3r.datasets.wildrgbd import WildRGBD as DUSt3R_WildRGBD # noqa
15
-
16
-
17
- class ARKitScenes(DUSt3R_ARKitScenes, MASt3RBaseStereoViewDataset):
18
- def __init__(self, *args, split, ROOT, **kwargs):
19
- super().__init__(*args, split=split, ROOT=ROOT, **kwargs)
20
- self.is_metric_scale = True
21
-
22
-
23
- class BlendedMVS(DUSt3R_BlendedMVS, MASt3RBaseStereoViewDataset):
24
- def __init__(self, *args, ROOT, split=None, **kwargs):
25
- super().__init__(*args, ROOT=ROOT, split=split, **kwargs)
26
- self.is_metric_scale = False
27
-
28
-
29
- class Co3d(DUSt3R_Co3d, MASt3RBaseStereoViewDataset):
30
- def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
31
- super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs)
32
- self.is_metric_scale = False
33
-
34
-
35
- class MegaDepth(DUSt3R_MegaDepth, MASt3RBaseStereoViewDataset):
36
- def __init__(self, *args, split, ROOT, **kwargs):
37
- super().__init__(*args, split=split, ROOT=ROOT, **kwargs)
38
- self.is_metric_scale = False
39
-
40
-
41
- class ScanNetpp(DUSt3R_ScanNetpp, MASt3RBaseStereoViewDataset):
42
- def __init__(self, *args, ROOT, **kwargs):
43
- super().__init__(*args, ROOT=ROOT, **kwargs)
44
- self.is_metric_scale = True
45
-
46
-
47
- class StaticThings3D(DUSt3R_StaticThings3D, MASt3RBaseStereoViewDataset):
48
- def __init__(self, ROOT, *args, mask_bg='rand', **kwargs):
49
- super().__init__(ROOT, *args, mask_bg=mask_bg, **kwargs)
50
- self.is_metric_scale = False
51
-
52
-
53
- class Waymo(DUSt3R_Waymo, MASt3RBaseStereoViewDataset):
54
- def __init__(self, *args, ROOT, **kwargs):
55
- super().__init__(*args, ROOT=ROOT, **kwargs)
56
- self.is_metric_scale = True
57
-
58
-
59
- class WildRGBD(DUSt3R_WildRGBD, MASt3RBaseStereoViewDataset):
60
- def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
61
- super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs)
62
- self.is_metric_scale = True
 
 
 
mast3r/datasets/base/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
 
 
 
mast3r/datasets/base/mast3r_base_stereo_view_dataset.py DELETED
@@ -1,355 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # base class for implementing datasets
6
- # --------------------------------------------------------
7
- import PIL.Image
8
- import PIL.Image as Image
9
- import numpy as np
10
- import torch
11
- import copy
12
-
13
- from mast3r.datasets.utils.cropping import (extract_correspondences_from_pts3d,
14
- gen_random_crops, in2d_rect, crop_to_homography)
15
-
16
- import mast3r.utils.path_to_dust3r # noqa
17
- from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset, view_name, is_good_type # noqa
18
- from dust3r.datasets.utils.transforms import ImgNorm
19
- from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates, geotrf, depthmap_to_camera_coordinates
20
- import dust3r.datasets.utils.cropping as cropping
21
-
22
-
23
- class MASt3RBaseStereoViewDataset(BaseStereoViewDataset):
24
- def __init__(self, *, # only keyword arguments
25
- split=None,
26
- resolution=None, # square_size or (width, height) or list of [(width,height), ...]
27
- transform=ImgNorm,
28
- aug_crop=False,
29
- aug_swap=False,
30
- aug_monocular=False,
31
- aug_portrait_or_landscape=True, # automatic choice between landscape/portrait when possible
32
- aug_rot90=False,
33
- n_corres=0,
34
- nneg=0,
35
- n_tentative_crops=4,
36
- seed=None):
37
- super().__init__(split=split, resolution=resolution, transform=transform, aug_crop=aug_crop, seed=seed)
38
- self.is_metric_scale = False # by default a dataset is not metric scale, subclasses can override this
39
-
40
- self.aug_swap = aug_swap
41
- self.aug_monocular = aug_monocular
42
- self.aug_portrait_or_landscape = aug_portrait_or_landscape
43
- self.aug_rot90 = aug_rot90
44
-
45
- self.n_corres = n_corres
46
- self.nneg = nneg
47
- assert self.n_corres == 'all' or isinstance(self.n_corres, int) or (isinstance(self.n_corres, list) and len(
48
- self.n_corres) == self.num_views), f"Error, n_corres should either be 'all', a single integer or a list of length {self.num_views}"
49
- assert self.nneg == 0 or self.n_corres != 'all'
50
- self.n_tentative_crops = n_tentative_crops
51
-
52
- def _swap_view_aug(self, views):
53
- if self._rng.random() < 0.5:
54
- views.reverse()
55
-
56
- def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None):
57
- """ This function:
58
- - first downsizes the image with LANCZOS interpolation,
59
- which is better than bilinear interpolation in
60
- """
61
- if not isinstance(image, PIL.Image.Image):
62
- image = PIL.Image.fromarray(image)
63
-
64
- # transpose the resolution if necessary
65
- W, H = image.size # new size
66
- assert resolution[0] >= resolution[1]
67
- if H > 1.1 * W:
68
- # image is portrait mode
69
- resolution = resolution[::-1]
70
- elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
71
- # image is square, so we choose (portrait, landscape) randomly
72
- if rng.integers(2) and self.aug_portrait_or_landscape:
73
- resolution = resolution[::-1]
74
-
75
- # high-quality Lanczos down-scaling
76
- target_resolution = np.array(resolution)
77
- image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
78
-
79
- # actual cropping (if necessary) with bilinear interpolation
80
- offset_factor = 0.5
81
- intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=offset_factor)
82
- crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
83
- image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
84
-
85
- return image, depthmap, intrinsics2
86
-
87
- def generate_crops_from_pair(self, view1, view2, resolution, aug_crop_arg, n_crops=4, rng=np.random):
88
- views = [view1, view2]
89
-
90
- if aug_crop_arg is False:
91
- # compatibility
92
- for i in range(2):
93
- view = views[i]
94
- view['img'], view['depthmap'], view['camera_intrinsics'] = self._crop_resize_if_necessary(view['img'],
95
- view['depthmap'],
96
- view['camera_intrinsics'],
97
- resolution,
98
- rng=rng)
99
- view['pts3d'], view['valid_mask'] = depthmap_to_absolute_camera_coordinates(view['depthmap'],
100
- view['camera_intrinsics'],
101
- view['camera_pose'])
102
- return
103
-
104
- # extract correspondences
105
- corres = extract_correspondences_from_pts3d(*views, target_n_corres=None, rng=rng)
106
-
107
- # generate 4 random crops in each view
108
- view_crops = []
109
- crops_resolution = []
110
- corres_msks = []
111
- for i in range(2):
112
-
113
- if aug_crop_arg == 'auto':
114
- S = min(views[i]['img'].size)
115
- R = min(resolution)
116
- aug_crop = S * (S - R) // R
117
- aug_crop = max(.1 * S, aug_crop) # for cropping: augment scale of at least 10%, and more if possible
118
- else:
119
- aug_crop = aug_crop_arg
120
-
121
- # transpose the target resolution if necessary
122
- assert resolution[0] >= resolution[1]
123
- W, H = imsize = views[i]['img'].size
124
- crop_resolution = resolution
125
- if H > 1.1 * W:
126
- # image is portrait mode
127
- crop_resolution = resolution[::-1]
128
- elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
129
- # image is square, so we choose (portrait, landscape) randomly
130
- if rng.integers(2):
131
- crop_resolution = resolution[::-1]
132
-
133
- crops = gen_random_crops(imsize, n_crops, crop_resolution, aug_crop=aug_crop, rng=rng)
134
- view_crops.append(crops)
135
- crops_resolution.append(crop_resolution)
136
-
137
- # compute correspondences
138
- corres_msks.append(in2d_rect(corres[i], crops))
139
-
140
- # compute IoU for each
141
- intersection = np.float32(corres_msks[0]).T @ np.float32(corres_msks[1])
142
- # select best pair of crops
143
- best = np.unravel_index(intersection.argmax(), (n_crops, n_crops))
144
- crops = [view_crops[i][c] for i, c in enumerate(best)]
145
-
146
- # crop with the homography
147
- for i in range(2):
148
- view = views[i]
149
- imsize, K_new, R, H = crop_to_homography(view['camera_intrinsics'], crops[i], crops_resolution[i])
150
- # imsize, K_new, H = upscale_homography(imsize, resolution, K_new, H)
151
-
152
- # update camera params
153
- K_old = view['camera_intrinsics']
154
- view['camera_intrinsics'] = K_new
155
- view['camera_pose'] = view['camera_pose'].copy()
156
- view['camera_pose'][:3, :3] = view['camera_pose'][:3, :3] @ R
157
-
158
- # apply homography to image and depthmap
159
- homo8 = (H / H[2, 2]).ravel().tolist()[:8]
160
- view['img'] = view['img'].transform(imsize, Image.Transform.PERSPECTIVE,
161
- homo8,
162
- resample=Image.Resampling.BICUBIC)
163
-
164
- depthmap2 = depthmap_to_camera_coordinates(view['depthmap'], K_old)[0] @ R[:, 2]
165
- view['depthmap'] = np.array(Image.fromarray(depthmap2).transform(
166
- imsize, Image.Transform.PERSPECTIVE, homo8))
167
-
168
- if 'track_labels' in view:
169
- # convert from uint64 --> uint32, because PIL.Image cannot handle uint64
170
- mapping, track_labels = np.unique(view['track_labels'], return_inverse=True)
171
- track_labels = track_labels.astype(np.uint32).reshape(view['track_labels'].shape)
172
-
173
- # homography transformation
174
- res = np.array(Image.fromarray(track_labels).transform(imsize, Image.Transform.PERSPECTIVE, homo8))
175
- view['track_labels'] = mapping[res] # mapping back to uint64
176
-
177
- # recompute 3d points from scratch
178
- view['pts3d'], view['valid_mask'] = depthmap_to_absolute_camera_coordinates(view['depthmap'],
179
- view['camera_intrinsics'],
180
- view['camera_pose'])
181
-
182
- def __getitem__(self, idx):
183
- if isinstance(idx, tuple):
184
- # the idx is specifying the aspect-ratio
185
- idx, ar_idx = idx
186
- else:
187
- assert len(self._resolutions) == 1
188
- ar_idx = 0
189
-
190
- # set-up the rng
191
- if self.seed: # reseed for each __getitem__
192
- self._rng = np.random.default_rng(seed=self.seed + idx)
193
- elif not hasattr(self, '_rng'):
194
- seed = torch.initial_seed() # this is different for each dataloader process
195
- self._rng = np.random.default_rng(seed=seed)
196
-
197
- # over-loaded code
198
- resolution = self._resolutions[ar_idx] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
199
- views = self._get_views(idx, resolution, self._rng)
200
- assert len(views) == self.num_views
201
-
202
- for v, view in enumerate(views):
203
- assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
204
- view['idx'] = (idx, ar_idx, v)
205
- view['is_metric_scale'] = self.is_metric_scale
206
-
207
- assert 'camera_intrinsics' in view
208
- if 'camera_pose' not in view:
209
- view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32)
210
- else:
211
- assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}'
212
- assert 'pts3d' not in view
213
- assert 'valid_mask' not in view
214
- assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}'
215
-
216
- pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
217
-
218
- view['pts3d'] = pts3d
219
- view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
220
-
221
- self.generate_crops_from_pair(views[0], views[1], resolution=resolution,
222
- aug_crop_arg=self.aug_crop,
223
- n_crops=self.n_tentative_crops,
224
- rng=self._rng)
225
- for v, view in enumerate(views):
226
- # encode the image
227
- width, height = view['img'].size
228
- view['true_shape'] = np.int32((height, width))
229
- view['img'] = self.transform(view['img'])
230
- # Pixels for which depth is fundamentally undefined
231
- view['sky_mask'] = (view['depthmap'] < 0)
232
-
233
- if self.aug_swap:
234
- self._swap_view_aug(views)
235
-
236
- if self.aug_monocular:
237
- if self._rng.random() < self.aug_monocular:
238
- views = [copy.deepcopy(views[0]) for _ in range(len(views))]
239
-
240
- # automatic extraction of correspondences from pts3d + pose
241
- if self.n_corres > 0 and ('corres' not in view):
242
- corres1, corres2, valid = extract_correspondences_from_pts3d(*views, self.n_corres,
243
- self._rng, nneg=self.nneg)
244
- views[0]['corres'] = corres1
245
- views[1]['corres'] = corres2
246
- views[0]['valid_corres'] = valid
247
- views[1]['valid_corres'] = valid
248
-
249
- if self.aug_rot90 is False:
250
- pass
251
- elif self.aug_rot90 == 'same':
252
- rotate_90(views, k=self._rng.choice(4))
253
- elif self.aug_rot90 == 'diff':
254
- rotate_90(views[:1], k=self._rng.choice(4))
255
- rotate_90(views[1:], k=self._rng.choice(4))
256
- else:
257
- raise ValueError(f'Bad value for {self.aug_rot90=}')
258
-
259
- # check data-types metric_scale
260
- for v, view in enumerate(views):
261
- if 'corres' not in view:
262
- view['corres'] = np.full((self.n_corres, 2), np.nan, dtype=np.float32)
263
-
264
- # check all datatypes
265
- for key, val in view.items():
266
- res, err_msg = is_good_type(key, val)
267
- assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
268
- K = view['camera_intrinsics']
269
-
270
- # check shapes
271
- assert view['depthmap'].shape == view['img'].shape[1:]
272
- assert view['depthmap'].shape == view['pts3d'].shape[:2]
273
- assert view['depthmap'].shape == view['valid_mask'].shape
274
-
275
- # last thing done!
276
- for view in views:
277
- # transpose to make sure all views are the same size
278
- transpose_to_landscape(view)
279
- # this allows checking whether the RNG is in the same state each time
280
- view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
281
-
282
- return views
283
-
284
-
285
- def transpose_to_landscape(view, revert=False):
286
- height, width = view['true_shape']
287
-
288
- if width < height:
289
- if revert:
290
- height, width = width, height
291
-
292
- # rectify portrait to landscape
293
- assert view['img'].shape == (3, height, width)
294
- view['img'] = view['img'].swapaxes(1, 2)
295
-
296
- assert view['valid_mask'].shape == (height, width)
297
- view['valid_mask'] = view['valid_mask'].swapaxes(0, 1)
298
-
299
- assert view['sky_mask'].shape == (height, width)
300
- view['sky_mask'] = view['sky_mask'].swapaxes(0, 1)
301
-
302
- assert view['depthmap'].shape == (height, width)
303
- view['depthmap'] = view['depthmap'].swapaxes(0, 1)
304
-
305
- assert view['pts3d'].shape == (height, width, 3)
306
- view['pts3d'] = view['pts3d'].swapaxes(0, 1)
307
-
308
- # transpose x and y pixels
309
- view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]]
310
-
311
- # transpose correspondences x and y
312
- view['corres'] = view['corres'][:, [1, 0]]
313
-
314
-
315
- def rotate_90(views, k=1):
316
- from scipy.spatial.transform import Rotation
317
- # print('rotation =', k)
318
-
319
- RT = np.eye(4, dtype=np.float32)
320
- RT[:3, :3] = Rotation.from_euler('z', 90 * k, degrees=True).as_matrix()
321
-
322
- for view in views:
323
- view['img'] = torch.rot90(view['img'], k=k, dims=(-2, -1)) # WARNING!! dims=(-1,-2) != dims=(-2,-1)
324
- view['depthmap'] = np.rot90(view['depthmap'], k=k).copy()
325
- view['camera_pose'] = view['camera_pose'] @ RT
326
-
327
- RT2 = np.eye(3, dtype=np.float32)
328
- RT2[:2, :2] = RT[:2, :2] * ((1, -1), (-1, 1))
329
- H, W = view['depthmap'].shape
330
- if k % 4 == 0:
331
- pass
332
- elif k % 4 == 1:
333
- # top-left (0,0) pixel becomes (0,H-1)
334
- RT2[:2, 2] = (0, H - 1)
335
- elif k % 4 == 2:
336
- # top-left (0,0) pixel becomes (W-1,H-1)
337
- RT2[:2, 2] = (W - 1, H - 1)
338
- elif k % 4 == 3:
339
- # top-left (0,0) pixel becomes (W-1,0)
340
- RT2[:2, 2] = (W - 1, 0)
341
- else:
342
- raise ValueError(f'Bad value for {k=}')
343
-
344
- view['camera_intrinsics'][:2, 2] = geotrf(RT2, view['camera_intrinsics'][:2, 2])
345
- if k % 2 == 1:
346
- K = view['camera_intrinsics']
347
- np.fill_diagonal(K, K.diagonal()[[1, 0, 2]])
348
-
349
- pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
350
- view['pts3d'] = pts3d
351
- view['valid_mask'] = np.rot90(view['valid_mask'], k=k).copy()
352
- view['sky_mask'] = np.rot90(view['sky_mask'], k=k).copy()
353
-
354
- view['corres'] = geotrf(RT2, view['corres']).round().astype(view['corres'].dtype)
355
- view['true_shape'] = np.int32((H, W))
 
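Note: in rotate_90 above, RT2 re-expresses pixel coordinates in the rotated image; for k=1 the original top-left pixel lands at (0, H-1), where H is the height measured after rotation. A standalone numeric check on a 2x3 toy depthmap (not part of the deleted file):

import numpy as np

depth = np.arange(6, dtype=np.float32).reshape(2, 3)   # original shape (H0=2, W0=3)
rot = np.rot90(depth, k=1)                             # rotated shape (3, 2)
H, W = rot.shape
RT2 = np.array([[0., 1., 0.],
                [-1., 0., H - 1],
                [0., 0., 1.]])                         # same 2D map as the k=1 branch above
x, y = 0, 0                                            # original top-left pixel
new_xy = RT2[:2, :2] @ (x, y) + RT2[:2, 2]
assert rot[int(new_xy[1]), int(new_xy[0])] == depth[y, x]
print(new_xy)                                          # -> [0. 2.], i.e. (0, H-1)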
 
 
mast3r/datasets/utils/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
 
 
 
mast3r/datasets/utils/cropping.py DELETED
@@ -1,219 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # cropping/match extraction
6
- # --------------------------------------------------------
7
- import numpy as np
8
- import mast3r.utils.path_to_dust3r # noqa
9
- from dust3r.utils.device import to_numpy
10
- from dust3r.utils.geometry import inv, geotrf
11
-
12
-
13
- def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
14
- is_reciprocal1 = (corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2)))
15
- pos1 = is_reciprocal1.nonzero()[0]
16
- pos2 = corres_1_to_2[pos1]
17
- if ret_recip:
18
- return is_reciprocal1, pos1, pos2
19
- return pos1, pos2
20
-
21
-
22
- def extract_correspondences_from_pts3d(view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0):
23
- view1, view2 = to_numpy((view1, view2))
24
- # project pixels from image1 --> 3d points --> image2 pixels
25
- shape1, corres1_to_2 = reproject_view(view1['pts3d'], view2)
26
- shape2, corres2_to_1 = reproject_view(view2['pts3d'], view1)
27
-
28
- # compute reciprocal correspondences:
29
- # pos1 == valid pixels (correspondences) in image1
30
- is_reciprocal1, pos1, pos2 = reciprocal_1d(corres1_to_2, corres2_to_1, ret_recip=True)
31
- is_reciprocal2 = (corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1)))
32
-
33
- if target_n_corres is None:
34
- if ret_xy:
35
- pos1 = unravel_xy(pos1, shape1)
36
- pos2 = unravel_xy(pos2, shape2)
37
- return pos1, pos2
38
-
39
- available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum())
40
- target_n_positives = int(target_n_corres * (1 - nneg))
41
- n_positives = min(len(pos1), target_n_positives)
42
- n_negatives = min(target_n_corres - n_positives, available_negatives)
43
-
44
- if n_negatives + n_positives != target_n_corres:
45
- # should be really rare => when there are not enough negatives
46
- # in that case, break nneg and add a few more positives ?
47
- n_positives = target_n_corres - n_negatives
48
- assert n_positives <= len(pos1)
49
-
50
- assert n_positives <= len(pos1)
51
- assert n_positives <= len(pos2)
52
- assert n_negatives <= (~is_reciprocal1).sum()
53
- assert n_negatives <= (~is_reciprocal2).sum()
54
- assert n_positives + n_negatives == target_n_corres
55
-
56
- valid = np.ones(n_positives, dtype=bool)
57
- if n_positives < len(pos1):
58
- # random sub-sampling of valid correspondences
59
- perm = rng.permutation(len(pos1))[:n_positives]
60
- pos1 = pos1[perm]
61
- pos2 = pos2[perm]
62
-
63
- if n_negatives > 0:
64
- # add false correspondences if not enough
65
- def norm(p): return p / p.sum()
66
- pos1 = np.r_[pos1, rng.choice(shape1[0] * shape1[1], size=n_negatives, replace=False, p=norm(~is_reciprocal1))]
67
- pos2 = np.r_[pos2, rng.choice(shape2[0] * shape2[1], size=n_negatives, replace=False, p=norm(~is_reciprocal2))]
68
- valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]
69
-
70
- # convert (x+W*y) back to 2d (x,y) coordinates
71
- if ret_xy:
72
- pos1 = unravel_xy(pos1, shape1)
73
- pos2 = unravel_xy(pos2, shape2)
74
- return pos1, pos2, valid
75
-
76
-
77
- def reproject_view(pts3d, view2):
78
- shape = view2['pts3d'].shape[:2]
79
- return reproject(pts3d, view2['camera_intrinsics'], inv(view2['camera_pose']), shape)
80
-
81
-
82
- def reproject(pts3d, K, world2cam, shape):
83
- H, W, THREE = pts3d.shape
84
- assert THREE == 3
85
-
86
- # reproject in camera2 space
87
- with np.errstate(divide='ignore', invalid='ignore'):
88
- pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2)
89
-
90
- # quantize to pixel positions
91
- return (H, W), ravel_xy(pos, shape)
92
-
93
-
94
- def ravel_xy(pos, shape):
95
- H, W = shape
96
- with np.errstate(invalid='ignore'):
97
- qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T
98
- quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(min=0, max=H - 1, out=qy)
99
- return quantized_pos
100
-
101
-
102
- def unravel_xy(pos, shape):
103
- # convert (x+W*y) back to 2d (x,y) coordinates
104
- return np.unravel_index(pos, shape)[0].base[:, ::-1].copy()
105
-
106
-
107
- def _rotation_origin_to_pt(target):
108
- """ Align the origin (0,0,1) with the target point (x,y,1) in projective space.
109
- Method: rotate z to put target on (x'+,0,1), then rotate on Y to get (0,0,1) and un-rotate z.
110
- """
111
- from scipy.spatial.transform import Rotation
112
- x, y = target
113
- rot_z = np.arctan2(y, x)
114
- rot_y = np.arctan(np.linalg.norm(target))
115
- R = Rotation.from_euler('ZYZ', [rot_z, rot_y, -rot_z]).as_matrix()
116
- return R
117
-
118
-
119
- def _dotmv(Trf, pts, ncol=None, norm=False):
120
- assert Trf.ndim >= 2
121
- ncol = ncol or pts.shape[-1]
122
-
123
- # adapt shape if necessary
124
- output_reshape = pts.shape[:-1]
125
- if Trf.ndim >= 3:
126
- n = Trf.ndim - 2
127
- assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
128
- Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
129
-
130
- if pts.ndim > Trf.ndim:
131
- # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
132
- pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
133
- elif pts.ndim == 2:
134
- # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
135
- pts = pts[:, None, :]
136
-
137
- if pts.shape[-1] + 1 == Trf.shape[-1]:
138
- Trf = Trf.swapaxes(-1, -2) # transpose Trf
139
- pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
140
-
141
- elif pts.shape[-1] == Trf.shape[-1]:
142
- Trf = Trf.swapaxes(-1, -2) # transpose Trf
143
- pts = pts @ Trf
144
- else:
145
- pts = Trf @ pts.T
146
- if pts.ndim >= 2:
147
- pts = pts.swapaxes(-1, -2)
148
-
149
- if norm:
150
- pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
151
- if norm != 1:
152
- pts *= norm
153
-
154
- res = pts[..., :ncol].reshape(*output_reshape, ncol)
155
- return res
156
-
157
-
158
- def crop_to_homography(K, crop, target_size=None):
159
- """ Given an image and its intrinsics,
160
- we want to replicate a rectangular crop with a homography,
161
- so that the principal point of the new 'crop' is centered.
162
- """
163
- # build intrinsics for the crop
164
- crop = np.round(crop)
165
- crop_size = crop[2:] - crop[:2]
166
- K2 = K.copy() # same focal
167
- K2[:2, 2] = crop_size / 2 # new principal point is perfectly centered
168
-
169
- # find which corner is the most far-away from current principal point
170
- # so that the final homography does not go over the image borders
171
- corners = crop.reshape(-1, 2)
172
- corner_idx = np.abs(corners - K[:2, 2]).argmax(0)
173
- corner = corners[corner_idx, [0, 1]]
174
- # align with the corresponding corner from the target view
175
- corner2 = np.c_[[0, 0], crop_size][[0, 1], corner_idx]
176
-
177
- old_pt = _dotmv(np.linalg.inv(K), corner, norm=1)
178
- new_pt = _dotmv(np.linalg.inv(K2), corner2, norm=1)
179
- R = _rotation_origin_to_pt(old_pt) @ np.linalg.inv(_rotation_origin_to_pt(new_pt))
180
-
181
- if target_size is not None:
182
- imsize = target_size
183
- target_size = np.asarray(target_size)
184
- scaling = min(target_size / crop_size)
185
- K2[:2] *= scaling
186
- K2[:2, 2] = target_size / 2
187
- else:
188
- imsize = tuple(np.int32(crop_size).tolist())
189
-
190
- return imsize, K2, R, K @ R @ np.linalg.inv(K2)
191
-
192
-
193
- def gen_random_crops(imsize, n_crops, resolution, aug_crop, rng=np.random):
194
- """ Generate random crops of size=resolution,
195
- for an input image upscaled to (imsize + randint(0 , aug_crop))
196
- """
197
- resolution_crop = np.array(resolution) * min(np.array(imsize) / resolution)
198
-
199
- # (virtually) upscale the input image
200
- # scaling = rng.uniform(1, 1+(aug_crop+1)/min(imsize))
201
- scaling = np.exp(rng.uniform(0, np.log(1 + aug_crop / min(imsize))))
202
- imsize2 = np.int32(np.array(imsize) * scaling)
203
-
204
- # generate some random crops
205
- topleft = rng.random((n_crops, 2)) * (imsize2 - resolution_crop)
206
- crops = np.c_[topleft, topleft + resolution_crop]
207
- # print(f"{scaling=}, {topleft=}")
208
- # reduce the resolution to come back to original size
209
- crops /= scaling
210
- return crops
211
-
212
-
213
- def in2d_rect(corres, crops):
214
- # corres = (N,2)
215
- # crops = (M,4)
216
- # output = (N, M)
217
- is_sup = (corres[:, None] >= crops[None, :, 0:2])
218
- is_inf = (corres[:, None] < crops[None, :, 2:4])
219
- return (is_sup & is_inf).all(axis=-1)
 
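Note: in2d_rect above broadcasts an (N, 2) array of pixel coordinates against an (M, 4) array of crops given as (x0, y0, x1, y1) and returns an (N, M) boolean containment mask. A standalone check with made-up numbers:

import numpy as np

corres = np.array([[10, 10], [100, 50]])     # two candidate correspondences
crops = np.array([[0, 0, 64, 64],            # crop 0 contains (10, 10) only
                  [50, 0, 150, 100]])        # crop 1 contains (100, 50) only
is_sup = corres[:, None] >= crops[None, :, 0:2]
is_inf = corres[:, None] < crops[None, :, 2:4]
print((is_sup & is_inf).all(axis=-1))
# [[ True False]
#  [False  True]]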
 
 
mast3r/fast_nn.py DELETED
@@ -1,221 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # MASt3R Fast Nearest Neighbor
6
- # --------------------------------------------------------
7
- import torch
8
- import numpy as np
9
- import math
10
- from scipy.spatial import KDTree
11
-
12
- import mast3r.utils.path_to_dust3r # noqa
13
- from ..dust3r.dust3r.utils.device import to_numpy, todevice # noqa
14
-
15
-
16
- @torch.no_grad()
17
- def bruteforce_reciprocal_nns(A, B, device='cuda', block_size=None, dist='l2'):
18
- if isinstance(A, np.ndarray):
19
- A = torch.from_numpy(A).to(device)
20
- if isinstance(B, np.ndarray):
21
- B = torch.from_numpy(B).to(device)
22
-
23
- A = A.to(device)
24
- B = B.to(device)
25
-
26
- if dist == 'l2':
27
- dist_func = torch.cdist
28
- argmin = torch.min
29
- elif dist == 'dot':
30
- def dist_func(A, B):
31
- return A @ B.T
32
-
33
- def argmin(X, dim):
34
- sim, nn = torch.max(X, dim=dim)
35
- return sim.neg_(), nn
36
- else:
37
- raise ValueError(f'Unknown {dist=}')
38
-
39
- if block_size is None or len(A) * len(B) <= block_size**2:
40
- dists = dist_func(A, B)
41
- _, nn_A = argmin(dists, dim=1)
42
- _, nn_B = argmin(dists, dim=0)
43
- else:
44
- dis_A = torch.full((A.shape[0],), float('inf'), device=device, dtype=A.dtype)
45
- dis_B = torch.full((B.shape[0],), float('inf'), device=device, dtype=B.dtype)
46
- nn_A = torch.full((A.shape[0],), -1, device=device, dtype=torch.int64)
47
- nn_B = torch.full((B.shape[0],), -1, device=device, dtype=torch.int64)
48
- number_of_iteration_A = math.ceil(A.shape[0] / block_size)
49
- number_of_iteration_B = math.ceil(B.shape[0] / block_size)
50
-
51
- for i in range(number_of_iteration_A):
52
- A_i = A[i * block_size:(i + 1) * block_size]
53
- for j in range(number_of_iteration_B):
54
- B_j = B[j * block_size:(j + 1) * block_size]
55
- dists_blk = dist_func(A_i, B_j) # A, B, 1
56
- # dists_blk = dists[i * block_size:(i+1)*block_size, j * block_size:(j+1)*block_size]
57
- min_A_i, argmin_A_i = argmin(dists_blk, dim=1)
58
- min_B_j, argmin_B_j = argmin(dists_blk, dim=0)
59
-
60
- col_mask = min_A_i < dis_A[i * block_size:(i + 1) * block_size]
61
- line_mask = min_B_j < dis_B[j * block_size:(j + 1) * block_size]
62
-
63
- dis_A[i * block_size:(i + 1) * block_size][col_mask] = min_A_i[col_mask]
64
- dis_B[j * block_size:(j + 1) * block_size][line_mask] = min_B_j[line_mask]
65
-
66
- nn_A[i * block_size:(i + 1) * block_size][col_mask] = argmin_A_i[col_mask] + (j * block_size)
67
- nn_B[j * block_size:(j + 1) * block_size][line_mask] = argmin_B_j[line_mask] + (i * block_size)
68
- nn_A = nn_A.cpu().numpy()
69
- nn_B = nn_B.cpu().numpy()
70
- return nn_A, nn_B
71
-
72
-
73
- class cdistMatcher:
74
- def __init__(self, db_pts, device='cuda'):
75
- self.db_pts = db_pts.to(device)
76
- self.device = device
77
-
78
- def query(self, queries, k=1, **kw):
79
- assert k == 1
80
- if queries.numel() == 0:
81
- return None, []
82
- nnA, nnB = bruteforce_reciprocal_nns(queries, self.db_pts, device=self.device, **kw)
83
- dis = None
84
- return dis, nnA
85
-
86
-
87
- def merge_corres(idx1, idx2, shape1=None, shape2=None, ret_xy=True, ret_index=False):
88
- assert idx1.dtype == idx2.dtype == np.int32
89
-
90
- # unique and sort along idx1
91
- corres = np.unique(np.c_[idx2, idx1].view(np.int64), return_index=ret_index)
92
- if ret_index:
93
- corres, indices = corres
94
- xy2, xy1 = corres[:, None].view(np.int32).T
95
-
96
- if ret_xy:
97
- assert shape1 and shape2
98
- xy1 = np.unravel_index(xy1, shape1)
99
- xy2 = np.unravel_index(xy2, shape2)
100
- if ret_xy != 'y_x':
101
- xy1 = xy1[0].base[:, ::-1]
102
- xy2 = xy2[0].base[:, ::-1]
103
-
104
- if ret_index:
105
- return xy1, xy2, indices
106
- return xy1, xy2
107
-
108
-
109
- def fast_reciprocal_NNs(pts1, pts2, subsample_or_initxy1=8, ret_xy=True, pixel_tol=0, ret_basin=False,
110
- device='cuda', **matcher_kw):
111
- H1, W1, DIM1 = pts1.shape
112
- H2, W2, DIM2 = pts2.shape
113
- assert DIM1 == DIM2
114
-
115
- pts1 = pts1.reshape(-1, DIM1)
116
- pts2 = pts2.reshape(-1, DIM2)
117
-
118
- if isinstance(subsample_or_initxy1, int) and pixel_tol == 0:
119
- S = subsample_or_initxy1
120
- y1, x1 = np.mgrid[S // 2:H1:S, S // 2:W1:S].reshape(2, -1)
121
- max_iter = 10
122
- else:
123
- x1, y1 = subsample_or_initxy1
124
- if isinstance(x1, torch.Tensor):
125
- x1 = x1.cpu().numpy()
126
- if isinstance(y1, torch.Tensor):
127
- y1 = y1.cpu().numpy()
128
- max_iter = 1
129
-
130
- xy1 = np.int32(np.unique(x1 + W1 * y1)) # make sure there's no doublons
131
- xy2 = np.full_like(xy1, -1)
132
- old_xy1 = xy1.copy()
133
- old_xy2 = xy2.copy()
134
-
135
- if (isinstance(device, str) and device.startswith('cuda')) or (isinstance(device, torch.device) and device.type.startswith('cuda')):
136
- pts1 = pts1.to(device)
137
- pts2 = pts2.to(device)
138
- tree1 = cdistMatcher(pts1, device=device)
139
- tree2 = cdistMatcher(pts2, device=device)
140
- else:
141
- pts1, pts2 = to_numpy((pts1, pts2))
142
- tree1 = KDTree(pts1)
143
- tree2 = KDTree(pts2)
144
-
145
- notyet = np.ones(len(xy1), dtype=bool)
146
- basin = np.full((H1 * W1 + 1,), -1, dtype=np.int32) if ret_basin else None
147
-
148
- niter = 0
149
- # n_notyet = [len(notyet)]
150
- while notyet.any():
151
- _, xy2[notyet] = to_numpy(tree2.query(pts1[xy1[notyet]], **matcher_kw))
152
- if not ret_basin:
153
- notyet &= (old_xy2 != xy2) # remove points that have converged
154
-
155
- _, xy1[notyet] = to_numpy(tree1.query(pts2[xy2[notyet]], **matcher_kw))
156
- if ret_basin:
157
- basin[old_xy1[notyet]] = xy1[notyet]
158
- notyet &= (old_xy1 != xy1) # remove points that have converged
159
-
160
- # n_notyet.append(notyet.sum())
161
- niter += 1
162
- if niter >= max_iter:
163
- break
164
-
165
- old_xy2[:] = xy2
166
- old_xy1[:] = xy1
167
-
168
- # print('notyet_stats:', ' '.join(map(str, (n_notyet+[0]*10)[:max_iter])))
169
-
170
- if pixel_tol > 0:
171
- # in case we only want to match some specific points
172
- # and still have some way of checking reciprocity
173
- old_yx1 = np.unravel_index(old_xy1, (H1, W1))[0].base
174
- new_yx1 = np.unravel_index(xy1, (H1, W1))[0].base
175
- dis = np.linalg.norm(old_yx1 - new_yx1, axis=-1)
176
- converged = dis < pixel_tol
177
- if not isinstance(subsample_or_initxy1, int):
178
- xy1 = old_xy1 # replace new points by old ones
179
- else:
180
- converged = ~notyet # converged correspondences
181
-
182
- # keep only unique correspondences, and sort on xy1
183
- xy1, xy2 = merge_corres(xy1[converged], xy2[converged], (H1, W1), (H2, W2), ret_xy=ret_xy)
184
- if ret_basin:
185
- return xy1, xy2, basin
186
- return xy1, xy2
187
-
188
-
189
- def extract_correspondences_nonsym(A, B, confA, confB, subsample=8, device=None, ptmap_key='pred_desc', pixel_tol=0):
190
- if '3d' in ptmap_key:
191
- opt = dict(device='cpu', workers=32)
192
- else:
193
- opt = dict(device=device, dist='dot', block_size=2**13)
194
-
195
- # matching the two pairs
196
- idx1 = []
197
- idx2 = []
198
- # merge corres from opposite pairs
199
- HA, WA = A.shape[:2]
200
- HB, WB = B.shape[:2]
201
- if pixel_tol == 0:
202
- nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=subsample, ret_xy=False, **opt)
203
- nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=subsample, ret_xy=False, **opt)
204
- else:
205
- S = subsample
206
- yA, xA = np.mgrid[S // 2:HA:S, S // 2:WA:S].reshape(2, -1)
207
- yB, xB = np.mgrid[S // 2:HB:S, S // 2:WB:S].reshape(2, -1)
208
-
209
- nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=(xA, yA), ret_xy=False, pixel_tol=pixel_tol, **opt)
210
- nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=(xB, yB), ret_xy=False, pixel_tol=pixel_tol, **opt)
211
-
212
- idx1 = np.r_[nn1to2[0], nn2to1[1]]
213
- idx2 = np.r_[nn1to2[1], nn2to1[0]]
214
-
215
- c1 = confA.ravel()[idx1]
216
- c2 = confB.ravel()[idx2]
217
-
218
- xy1, xy2, idx = merge_corres(idx1, idx2, (HA, WA), (HB, WB), ret_xy=True, ret_index=True)
219
- conf = np.minimum(c1[idx], c2[idx])
220
- corres = (xy1.copy(), xy2.copy(), conf)
221
- return todevice(corres, device)
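A minimal sketch of how the deleted `fast_reciprocal_NNs` was typically called on two dense descriptor maps; the tensor shapes and keyword arguments follow the signatures above, and the random descriptors are placeholders.

import torch
from mast3r.fast_nn import fast_reciprocal_NNs  # import path as it existed before this deletion

descA = torch.randn(384, 512, 24)  # per-pixel descriptors, shape (H, W, D)
descB = torch.randn(384, 512, 24)

# seed one query every 8 pixels, then iterate reciprocal matching with dot-product similarity on GPU
xyA, xyB = fast_reciprocal_NNs(descA, descB, subsample_or_initxy1=8,
                               device='cuda', dist='dot', block_size=2**13)
# xyA, xyB: (N, 2) arrays of mutually nearest (x, y) pixel correspondences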
 
 
 
mast3r/losses.py DELETED
@@ -1,514 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # Implementation of MASt3R training losses
6
- # --------------------------------------------------------
7
- import torch
8
- import torch.nn as nn
9
- import numpy as np
10
- from sklearn.metrics import average_precision_score
11
-
12
- import mast3r.utils.path_to_dust3r # noqa
13
- from dust3r.losses import BaseCriterion, Criterion, MultiLoss, Sum, ConfLoss
14
- from dust3r.losses import Regr3D as Regr3D_dust3r
15
- from dust3r.utils.geometry import (geotrf, inv, normalize_pointcloud)
16
- from dust3r.inference import get_pred_pts3d
17
- from dust3r.utils.geometry import get_joint_pointcloud_depth, get_joint_pointcloud_center_scale
18
-
19
-
20
- def apply_log_to_norm(xyz):
21
- d = xyz.norm(dim=-1, keepdim=True)
22
- xyz = xyz / d.clip(min=1e-8)
23
- xyz = xyz * torch.log1p(d)
24
- return xyz
25
-
26
-
27
- class Regr3D (Regr3D_dust3r):
28
- def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False, opt_fit_gt=False,
29
- sky_loss_value=2, max_metric_scale=False, loss_in_log=False):
30
- self.loss_in_log = loss_in_log
31
- if norm_mode.startswith('?'):
32
- # do no norm pts from metric scale datasets
33
- self.norm_all = False
34
- self.norm_mode = norm_mode[1:]
35
- else:
36
- self.norm_all = True
37
- self.norm_mode = norm_mode
38
- super().__init__(criterion, self.norm_mode, gt_scale)
39
-
40
- self.sky_loss_value = sky_loss_value
41
- self.max_metric_scale = max_metric_scale
42
-
43
- def get_all_pts3d(self, gt1, gt2, pred1, pred2, dist_clip=None):
44
- # everything is normalized w.r.t. camera of view1
45
- in_camera1 = inv(gt1['camera_pose'])
46
- gt_pts1 = geotrf(in_camera1, gt1['pts3d']) # B,H,W,3
47
- gt_pts2 = geotrf(in_camera1, gt2['pts3d']) # B,H,W,3
48
-
49
- valid1 = gt1['valid_mask'].clone()
50
- valid2 = gt2['valid_mask'].clone()
51
-
52
- if dist_clip is not None:
53
- # points that are too far-away == invalid
54
- dis1 = gt_pts1.norm(dim=-1) # (B, H, W)
55
- dis2 = gt_pts2.norm(dim=-1) # (B, H, W)
56
- valid1 = valid1 & (dis1 <= dist_clip)
57
- valid2 = valid2 & (dis2 <= dist_clip)
58
-
59
- if self.loss_in_log == 'before':
60
- # this only make sense when depth_mode == 'linear'
61
- gt_pts1 = apply_log_to_norm(gt_pts1)
62
- gt_pts2 = apply_log_to_norm(gt_pts2)
63
-
64
- pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False).clone()
65
- pr_pts2 = get_pred_pts3d(gt2, pred2, use_pose=True).clone()
66
-
67
- if not self.norm_all:
68
- if self.max_metric_scale:
69
- B = valid1.shape[0]
70
- # valid1: B, H, W
71
- # torch.linalg.norm(gt_pts1, dim=-1) -> B, H, W
72
- # dist1_to_cam1 -> reshape to B, H*W
73
- dist1_to_cam1 = torch.where(valid1, torch.linalg.norm(gt_pts1, dim=-1), 0).view(B, -1)
74
- dist2_to_cam1 = torch.where(valid2, torch.linalg.norm(gt_pts2, dim=-1), 0).view(B, -1)
75
-
76
- # is_metric_scale: B
77
- # dist1_to_cam1.max(dim=-1).values -> B
78
- gt1['is_metric_scale'] = gt1['is_metric_scale'] \
79
- & (dist1_to_cam1.max(dim=-1).values < self.max_metric_scale) \
80
- & (dist2_to_cam1.max(dim=-1).values < self.max_metric_scale)
81
- gt2['is_metric_scale'] = gt1['is_metric_scale']
82
-
83
- mask = ~gt1['is_metric_scale']
84
- else:
85
- mask = torch.ones_like(gt1['is_metric_scale'])
86
- # normalize 3d points
87
- if self.norm_mode and mask.any():
88
- pr_pts1[mask], pr_pts2[mask] = normalize_pointcloud(pr_pts1[mask], pr_pts2[mask], self.norm_mode,
89
- valid1[mask], valid2[mask])
90
-
91
- if self.norm_mode and not self.gt_scale:
92
- gt_pts1, gt_pts2, norm_factor = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode,
93
- valid1, valid2, ret_factor=True)
94
- # apply the same normalization to prediction
95
- pr_pts1[~mask] = pr_pts1[~mask] / norm_factor[~mask]
96
- pr_pts2[~mask] = pr_pts2[~mask] / norm_factor[~mask]
97
-
98
- # return sky segmentation, making sure they don't include any labelled 3d points
99
- sky1 = gt1['sky_mask'] & (~valid1)
100
- sky2 = gt2['sky_mask'] & (~valid2)
101
- return gt_pts1, gt_pts2, pr_pts1, pr_pts2, valid1, valid2, sky1, sky2, {}
102
-
103
- def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
104
- gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring = \
105
- self.get_all_pts3d(gt1, gt2, pred1, pred2, **kw)
106
-
107
- if self.sky_loss_value > 0:
108
- assert self.criterion.reduction == 'none', 'sky_loss_value should be 0 if no conf loss'
109
- # add the sky pixel as "valid" pixels...
110
- mask1 = mask1 | sky1
111
- mask2 = mask2 | sky2
112
-
113
- # loss on img1 side
114
- pred_pts1 = pred_pts1[mask1]
115
- gt_pts1 = gt_pts1[mask1]
116
- if self.loss_in_log and self.loss_in_log != 'before':
117
- # this only make sense when depth_mode == 'exp'
118
- pred_pts1 = apply_log_to_norm(pred_pts1)
119
- gt_pts1 = apply_log_to_norm(gt_pts1)
120
- l1 = self.criterion(pred_pts1, gt_pts1)
121
-
122
- # loss on gt2 side
123
- pred_pts2 = pred_pts2[mask2]
124
- gt_pts2 = gt_pts2[mask2]
125
- if self.loss_in_log and self.loss_in_log != 'before':
126
- pred_pts2 = apply_log_to_norm(pred_pts2)
127
- gt_pts2 = apply_log_to_norm(gt_pts2)
128
- l2 = self.criterion(pred_pts2, gt_pts2)
129
-
130
- if self.sky_loss_value > 0:
131
- assert self.criterion.reduction == 'none', 'sky_loss_value should be 0 if no conf loss'
132
- # ... but force the loss to be high there
133
- l1 = torch.where(sky1[mask1], self.sky_loss_value, l1)
134
- l2 = torch.where(sky2[mask2], self.sky_loss_value, l2)
135
- self_name = type(self).__name__
136
- details = {self_name + '_pts3d_1': float(l1.mean()), self_name + '_pts3d_2': float(l2.mean())}
137
- return Sum((l1, mask1), (l2, mask2)), (details | monitoring)
138
-
139
-
140
- class Regr3D_ShiftInv (Regr3D):
141
- """ Same than Regr3D but invariant to depth shift.
142
- """
143
-
144
- def get_all_pts3d(self, gt1, gt2, pred1, pred2):
145
- # compute unnormalized points
146
- gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring = \
147
- super().get_all_pts3d(gt1, gt2, pred1, pred2)
148
-
149
- # compute median depth
150
- gt_z1, gt_z2 = gt_pts1[..., 2], gt_pts2[..., 2]
151
- pred_z1, pred_z2 = pred_pts1[..., 2], pred_pts2[..., 2]
152
- gt_shift_z = get_joint_pointcloud_depth(gt_z1, gt_z2, mask1, mask2)[:, None, None]
153
- pred_shift_z = get_joint_pointcloud_depth(pred_z1, pred_z2, mask1, mask2)[:, None, None]
154
-
155
- # subtract the median depth
156
- gt_z1 -= gt_shift_z
157
- gt_z2 -= gt_shift_z
158
- pred_z1 -= pred_shift_z
159
- pred_z2 -= pred_shift_z
160
-
161
- # monitoring = dict(monitoring, gt_shift_z=gt_shift_z.mean().detach(), pred_shift_z=pred_shift_z.mean().detach())
162
- return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring
163
-
164
-
165
- class Regr3D_ScaleInv (Regr3D):
166
- """ Same than Regr3D but invariant to depth scale.
167
- if gt_scale == True: enforce the prediction to take the same scale than GT
168
- """
169
-
170
- def get_all_pts3d(self, gt1, gt2, pred1, pred2):
171
- # compute depth-normalized points
172
- gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring = \
173
- super().get_all_pts3d(gt1, gt2, pred1, pred2)
174
-
175
- # measure scene scale
176
- _, gt_scale = get_joint_pointcloud_center_scale(gt_pts1, gt_pts2, mask1, mask2)
177
- _, pred_scale = get_joint_pointcloud_center_scale(pred_pts1, pred_pts2, mask1, mask2)
178
-
179
- # prevent predictions to be in a ridiculous range
180
- pred_scale = pred_scale.clip(min=1e-3, max=1e3)
181
-
182
- # subtract the median depth
183
- if self.gt_scale:
184
- pred_pts1 *= gt_scale / pred_scale
185
- pred_pts2 *= gt_scale / pred_scale
186
- # monitoring = dict(monitoring, pred_scale=(pred_scale/gt_scale).mean())
187
- else:
188
- gt_pts1 /= gt_scale
189
- gt_pts2 /= gt_scale
190
- pred_pts1 /= pred_scale
191
- pred_pts2 /= pred_scale
192
- # monitoring = dict(monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach())
193
-
194
- return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring
195
-
196
-
197
- class Regr3D_ScaleShiftInv (Regr3D_ScaleInv, Regr3D_ShiftInv):
198
- # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv
199
- pass
200
-
201
-
202
- def get_similarities(desc1, desc2, euc=False):
203
- if euc: # euclidean distance in same range than similarities
204
- dists = (desc1[:, :, None] - desc2[:, None]).norm(dim=-1)
205
- sim = 1 / (1 + dists)
206
- else:
207
- # Compute similarities
208
- sim = desc1 @ desc2.transpose(-2, -1)
209
- return sim
210
-
211
-
212
- class MatchingCriterion(BaseCriterion):
213
- def __init__(self, reduction='mean', fp=torch.float32):
214
- super().__init__(reduction)
215
- self.fp = fp
216
-
217
- def forward(self, a, b, valid_matches=None, euc=False):
218
- assert a.ndim >= 2 and 1 <= a.shape[-1], f'Bad shape = {a.shape}'
219
- dist = self.loss(a.to(self.fp), b.to(self.fp), valid_matches, euc=euc)
220
- # one dimension less or reduction to single value
221
- assert (valid_matches is None and dist.ndim == a.ndim -
222
- 1) or self.reduction in ['mean', 'sum', '1-mean', 'none']
223
- if self.reduction == 'none':
224
- return dist
225
- if self.reduction == 'sum':
226
- return dist.sum()
227
- if self.reduction == 'mean':
228
- return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
229
- if self.reduction == '1-mean':
230
- return 1. - dist.mean() if dist.numel() > 0 else dist.new_ones(())
231
- raise ValueError(f'bad {self.reduction=} mode')
232
-
233
- def loss(self, a, b, valid_matches=None):
234
- raise NotImplementedError
235
-
236
-
237
- class InfoNCE(MatchingCriterion):
238
- def __init__(self, temperature=0.07, eps=1e-8, mode='all', **kwargs):
239
- super().__init__(**kwargs)
240
- self.temperature = temperature
241
- self.eps = eps
242
- assert mode in ['all', 'proper', 'dual']
243
- self.mode = mode
244
-
245
- def loss(self, desc1, desc2, valid_matches=None, euc=False):
246
- # valid positives are along diagonals
247
- B, N, D = desc1.shape
248
- B2, N2, D2 = desc2.shape
249
- assert B == B2 and D == D2
250
- if valid_matches is None:
251
- valid_matches = torch.ones([B, N], dtype=bool)
252
- # torch.all(valid_matches.sum(dim=-1) > 0) some pairs have no matches????
253
- assert valid_matches.shape == torch.Size([B, N]) and valid_matches.sum() > 0
254
-
255
- # Tempered similarities
256
- sim = get_similarities(desc1, desc2, euc) / self.temperature
257
- sim[sim.isnan()] = -torch.inf # ignore nans
258
- # Softmax of positives with temperature
259
- sim = sim.exp_() # save peak memory
260
- positives = sim.diagonal(dim1=-2, dim2=-1)
261
-
262
- # Loss
263
- if self.mode == 'all': # Previous InfoNCE
264
- loss = -torch.log((positives / sim.sum(dim=-1).sum(dim=-1, keepdim=True)).clip(self.eps))
265
- elif self.mode == 'proper': # Proper InfoNCE
266
- loss = -(torch.log((positives / sim.sum(dim=-2)).clip(self.eps)) +
267
- torch.log((positives / sim.sum(dim=-1)).clip(self.eps)))
268
- elif self.mode == 'dual': # Dual Softmax
269
- loss = -(torch.log((positives**2 / sim.sum(dim=-1) / sim.sum(dim=-2)).clip(self.eps)))
270
- else:
271
- raise ValueError("This should not happen...")
272
- return loss[valid_matches]
273
-
274
-
275
- class APLoss (MatchingCriterion):
276
- """ AP loss.
277
-
278
- Input: (N, M) values in [min, max]
279
- label: (N, M) values in {0, 1}
280
-
281
- Returns: 1 - mAP (mean AP for each n in {1..N})
282
- Note: typically, this is what you wanna minimize
283
- """
284
-
285
- def __init__(self, nq='torch', min=0, max=1, euc=False, **kw):
286
- super().__init__(**kw)
287
- # Exact/True AP loss (not differentiable)
288
- if nq == 0:
289
- nq = 'sklearn' # special case
290
- try:
291
- self.compute_AP = eval('self.compute_true_AP_' + nq)
292
- except:
293
- raise ValueError("Unknown mode %s for AP loss" % nq)
294
-
295
- @staticmethod
296
- def compute_true_AP_sklearn(scores, labels):
297
- def compute_AP(label, score):
298
- return average_precision_score(label, score)
299
-
300
- aps = scores.new_zeros((scores.shape[0], scores.shape[1]))
301
- label_np = labels.cpu().numpy().astype(bool)
302
- scores_np = scores.cpu().numpy()
303
- for bi in range(scores_np.shape[0]):
304
- for i in range(scores_np.shape[1]):
305
- labels = label_np[bi, i, :]
306
- if labels.sum() < 1:
307
- continue
308
- aps[bi, i] = compute_AP(labels, scores_np[bi, i, :])
309
- return aps
310
-
311
- @staticmethod
312
- def compute_true_AP_torch(scores, labels):
313
- assert scores.shape == labels.shape
314
- B, N, M = labels.shape
315
- dev = labels.device
316
- with torch.no_grad():
317
- # sort scores
318
- _, order = scores.sort(dim=-1, descending=True)
319
- # sort labels accordingly
320
- labels = labels[torch.arange(B, device=dev)[:, None, None].expand(order.shape),
321
- torch.arange(N, device=dev)[None, :, None].expand(order.shape),
322
- order]
323
- # compute number of positives per query
324
- npos = labels.sum(dim=-1)
325
- assert torch.all(torch.isclose(npos, npos[0, 0])
326
- ), "only implemented for constant number of positives per query"
327
- npos = int(npos[0, 0])
328
- # compute precision at each recall point
329
- posrank = labels.nonzero()[:, -1].view(B, N, npos)
330
- recall = torch.arange(1, 1 + npos, dtype=torch.float32, device=dev)[None, None, :].expand(B, N, npos)
331
- precision = recall / (1 + posrank).float()
332
- # average precision values at all recall points
333
- aps = precision.mean(dim=-1)
334
-
335
- return aps
336
-
337
- def loss(self, desc1, desc2, valid_matches=None, euc=False): # if matches is None, positives are the diagonal
338
- B, N1, D = desc1.shape
339
- B2, N2, D2 = desc2.shape
340
- assert B == B2 and D == D2
341
-
342
- scores = get_similarities(desc1, desc2, euc)
343
-
344
- labels = torch.zeros([B, N1, N2], dtype=scores.dtype, device=scores.device)
345
-
346
- # allow all diagonal positives and only mask afterwards
347
- labels.diagonal(dim1=-2, dim2=-1)[...] = 1.
348
- apscore = self.compute_AP(scores, labels)
349
- if valid_matches is not None:
350
- apscore = apscore[valid_matches]
351
- return apscore
352
-
353
-
354
- class MatchingLoss (Criterion, MultiLoss):
355
- """
356
- Matching loss per image
357
- only compare pixels inside an image but not in the whole batch as what would be done usually
358
- """
359
-
360
- def __init__(self, criterion, withconf=False, use_pts3d=False, negatives_padding=0, blocksize=4096):
361
- super().__init__(criterion)
362
- self.negatives_padding = negatives_padding
363
- self.use_pts3d = use_pts3d
364
- self.blocksize = blocksize
365
- self.withconf = withconf
366
-
367
- def add_negatives(self, outdesc2, desc2, batchid, x2, y2):
368
- if self.negatives_padding:
369
- B, H, W, D = desc2.shape
370
- negatives = torch.ones([B, H, W], device=desc2.device, dtype=bool)
371
- negatives[batchid, y2, x2] = False
372
- sel = negatives & (negatives.view([B, -1]).cumsum(dim=-1).view(B, H, W)
373
- <= self.negatives_padding) # take the N-first negatives
374
- outdesc2 = torch.cat([outdesc2, desc2[sel].view([B, -1, D])], dim=1)
375
- return outdesc2
376
-
377
- def get_confs(self, pred1, pred2, sel1, sel2):
378
- if self.withconf:
379
- if self.use_pts3d:
380
- outconfs1 = pred1['conf'][sel1]
381
- outconfs2 = pred2['conf'][sel2]
382
- else:
383
- outconfs1 = pred1['desc_conf'][sel1]
384
- outconfs2 = pred2['desc_conf'][sel2]
385
- else:
386
- outconfs1 = outconfs2 = None
387
- return outconfs1, outconfs2
388
-
389
- def get_descs(self, pred1, pred2):
390
- if self.use_pts3d:
391
- desc1, desc2 = pred1['pts3d'], pred2['pts3d_in_other_view']
392
- else:
393
- desc1, desc2 = pred1['desc'], pred2['desc']
394
- return desc1, desc2
395
-
396
- def get_matching_descs(self, gt1, gt2, pred1, pred2, **kw):
397
- outdesc1 = outdesc2 = outconfs1 = outconfs2 = None
398
- # Recover descs, GT corres and valid mask
399
- desc1, desc2 = self.get_descs(pred1, pred2)
400
-
401
- (x1, y1), (x2, y2) = gt1['corres'].unbind(-1), gt2['corres'].unbind(-1)
402
- valid_matches = gt1['valid_corres']
403
-
404
- # Select descs that have GT matches
405
- B, N = x1.shape
406
- batchid = torch.arange(B)[:, None].repeat(1, N) # B, N
407
- outdesc1, outdesc2 = desc1[batchid, y1, x1], desc2[batchid, y2, x2] # B, N, D
408
-
409
- # Padd with unused negatives
410
- outdesc2 = self.add_negatives(outdesc2, desc2, batchid, x2, y2)
411
-
412
- # Gather confs if needed
413
- sel1 = batchid, y1, x1
414
- sel2 = batchid, y2, x2
415
- outconfs1, outconfs2 = self.get_confs(pred1, pred2, sel1, sel2)
416
-
417
- return outdesc1, outdesc2, outconfs1, outconfs2, valid_matches, {'use_euclidean_dist': self.use_pts3d}
418
-
419
- def blockwise_criterion(self, descs1, descs2, confs1, confs2, valid_matches, euc, rng=np.random, shuffle=True):
420
- loss = None
421
- details = {}
422
- B, N, D = descs1.shape
423
-
424
- if N <= self.blocksize: # Blocks are larger than provided descs, compute regular loss
425
- loss = self.criterion(descs1, descs2, valid_matches, euc=euc)
426
- else: # Compute criterion on the blockdiagonal only, after shuffling
427
- # Shuffle if necessary
428
- matches_perm = slice(None)
429
- if shuffle:
430
- matches_perm = np.stack([rng.choice(range(N), size=N, replace=False) for _ in range(B)])
431
- batchid = torch.tile(torch.arange(B), (N, 1)).T
432
- matches_perm = batchid, matches_perm
433
-
434
- descs1 = descs1[matches_perm]
435
- descs2 = descs2[matches_perm]
436
- valid_matches = valid_matches[matches_perm]
437
-
438
- assert N % self.blocksize == 0, "Error, can't chunk block-diagonal, please check blocksize"
439
- n_chunks = N // self.blocksize
440
- descs1 = descs1.reshape([B * n_chunks, self.blocksize, D]) # [B*(N//blocksize), blocksize, D]
441
- descs2 = descs2.reshape([B * n_chunks, self.blocksize, D]) # [B*(N//blocksize), blocksize, D]
442
- valid_matches = valid_matches.view([B * n_chunks, self.blocksize])
443
- loss = self.criterion(descs1, descs2, valid_matches, euc=euc)
444
- if self.withconf:
445
- confs1, confs2 = map(lambda x: x[matches_perm], (confs1, confs2)) # apply perm to confidences if needed
446
-
447
- if self.withconf:
448
- # split confidences between positives/negatives for loss computation
449
- details['conf_pos'] = map(lambda x: x[valid_matches.view(B, -1)], (confs1, confs2))
450
- details['conf_neg'] = map(lambda x: x[~valid_matches.view(B, -1)], (confs1, confs2))
451
- details['Conf1_std'] = confs1.std()
452
- details['Conf2_std'] = confs2.std()
453
-
454
- return loss, details
455
-
456
- def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
457
- # Gather preds and GT
458
- descs1, descs2, confs1, confs2, valid_matches, monitoring = self.get_matching_descs(
459
- gt1, gt2, pred1, pred2, **kw)
460
-
461
- # loss on matches
462
- loss, details = self.blockwise_criterion(descs1, descs2, confs1, confs2,
463
- valid_matches, euc=monitoring.pop('use_euclidean_dist', False))
464
-
465
- details[type(self).__name__] = float(loss.mean())
466
- return loss, (details | monitoring)
467
-
468
-
469
- class ConfMatchingLoss(ConfLoss):
470
- """ Weight matching by learned confidence. Same as ConfLoss but for a matching criterion
471
- Assuming the input matching_loss is a match-level loss.
472
- """
473
-
474
- def __init__(self, pixel_loss, alpha=1., confmode='prod', neg_conf_loss_quantile=False):
475
- super().__init__(pixel_loss, alpha)
476
- self.pixel_loss.withconf = True
477
- self.confmode = confmode
478
- self.neg_conf_loss_quantile = neg_conf_loss_quantile
479
-
480
- def aggregate_confs(self, confs1, confs2): # get the confidences resulting from the two view predictions
481
- if self.confmode == 'prod':
482
- confs = confs1 * confs2 if confs1 is not None and confs2 is not None else 1.
483
- elif self.confmode == 'mean':
484
- confs = .5 * (confs1 + confs2) if confs1 is not None and confs2 is not None else 1.
485
- else:
486
- raise ValueError(f"Unknown conf mode {self.confmode}")
487
- return confs
488
-
489
- def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
490
- # compute per-pixel loss
491
- loss, details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw)
492
- # Recover confidences for positive and negative samples
493
- conf1_pos, conf2_pos = details.pop('conf_pos')
494
- conf1_neg, conf2_neg = details.pop('conf_neg')
495
- conf_pos = self.aggregate_confs(conf1_pos, conf2_pos)
496
-
497
- # weight Matching loss by confidence on positives
498
- conf_pos, log_conf_pos = self.get_conf_log(conf_pos)
499
- conf_loss = loss * conf_pos - self.alpha * log_conf_pos
500
- # average + nan protection (in case of no valid pixels at all)
501
- conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else 0
502
- # Add negative confs loss to give some supervision signal to confidences for pixels that are not matched in GT
503
- if self.neg_conf_loss_quantile:
504
- conf_neg = torch.cat([conf1_neg, conf2_neg])
505
- conf_neg, log_conf_neg = self.get_conf_log(conf_neg)
506
-
507
- # recover quantile that will be used for negatives loss value assignment
508
- neg_loss_value = torch.quantile(loss, self.neg_conf_loss_quantile).detach()
509
- neg_loss = neg_loss_value * conf_neg - self.alpha * log_conf_neg
510
-
511
- neg_loss = neg_loss.mean() if neg_loss.numel() > 0 else 0
512
- conf_loss = conf_loss + neg_loss
513
-
514
- return conf_loss, dict(matching_conf_loss=float(conf_loss), **details)
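A sketch of how the deleted losses compose into a training criterion. `L21` and `ConfLoss` come from dust3r, as in the imports above; the weights, temperature and blocksize are illustrative values, not the settings of any released checkpoint.

from dust3r.losses import L21, ConfLoss
from mast3r.losses import Regr3D, MatchingLoss, InfoNCE, ConfMatchingLoss

# confidence-weighted 3D regression ('?' prefix skips normalization for metric-scale data)
regression = ConfLoss(Regr3D(L21, norm_mode='?avg_dis'), alpha=0.2)

# confidence-weighted descriptor matching with a per-image InfoNCE criterion
matching = ConfMatchingLoss(
    MatchingLoss(InfoNCE(mode='proper', temperature=0.05), blocksize=4096),
    alpha=0.)

train_criterion = regression + 0.075 * matching  # MultiLoss supports + and scalar *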
 
 
 
mast3r/model.py DELETED
@@ -1,68 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # MASt3R model class
6
- # --------------------------------------------------------
7
- import torch
8
- import torch.nn.functional as F
9
- import os
10
-
11
- from mast3r.catmlp_dpt_head import mast3r_head_factory
12
-
13
- import mast3r.utils.path_to_dust3r # noqa
14
- from ..dust3r.dust3r.model import AsymmetricCroCo3DStereo # noqa
15
- from ..dust3r.dust3r.utils.misc import transpose_to_landscape # noqa
16
-
17
-
18
- inf = float('inf')
19
-
20
-
21
- def load_model(model_path, device, verbose=True):
22
- if verbose:
23
- print('... loading model from', model_path)
24
- ckpt = torch.load(model_path, map_location='cpu')
25
- args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
26
- if 'landscape_only' not in args:
27
- args = args[:-1] + ', landscape_only=False)'
28
- else:
29
- args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
30
- assert "landscape_only=False" in args
31
- if verbose:
32
- print(f"instantiating : {args}")
33
- net = eval(args)
34
- s = net.load_state_dict(ckpt['model'], strict=False)
35
- if verbose:
36
- print(s)
37
- return net.to(device)
38
-
39
-
40
- class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
41
- def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs):
42
- self.desc_mode = desc_mode
43
- self.two_confs = two_confs
44
- self.desc_conf_mode = desc_conf_mode
45
- super().__init__(**kwargs)
46
-
47
- @classmethod
48
- def from_pretrained(cls, pretrained_model_name_or_path, **kw):
49
- if os.path.isfile(pretrained_model_name_or_path):
50
- return load_model(pretrained_model_name_or_path, device='cpu')
51
- else:
52
- return super(AsymmetricMASt3R, cls).from_pretrained(pretrained_model_name_or_path, **kw)
53
-
54
- def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
55
- # assert img_size[0] % patch_size == 0 and img_size[
56
- # 1] % patch_size == 0, f'{img_size=} must be multiple of {patch_size=}'
57
- self.output_mode = output_mode
58
- self.head_type = head_type
59
- self.depth_mode = depth_mode
60
- self.conf_mode = conf_mode
61
- if self.desc_conf_mode is None:
62
- self.desc_conf_mode = conf_mode
63
- # allocate heads
64
- self.downstream_head1 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
65
- self.downstream_head2 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
66
- # magic wrapper
67
- self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
68
- self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)
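A minimal sketch of instantiating the class deleted above; the Hugging Face model id is illustrative (a local .pth checkpoint path would instead be routed through `load_model`).

import torch
from mast3r.model import AsymmetricMASt3R  # import path as it existed before this deletion

# hypothetical model id; any local checkpoint file goes through load_model() instead
model = AsymmetricMASt3R.from_pretrained("naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric")
model = model.to('cuda' if torch.cuda.is_available() else 'cpu').eval()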
 
 
 
mast3r/utils/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
 
 
 
mast3r/utils/coarse_to_fine.py DELETED
@@ -1,214 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # coarse to fine utilities
6
- # --------------------------------------------------------
7
- import numpy as np
8
-
9
-
10
- def crop_tag(cell):
11
- return f'[{cell[1]}:{cell[3]},{cell[0]}:{cell[2]}]'
12
-
13
-
14
- def crop_slice(cell):
15
- return slice(cell[1], cell[3]), slice(cell[0], cell[2])
16
-
17
-
18
- def _start_pos(total_size, win_size, overlap):
19
- # we must have AT LEAST overlap between segments
20
- # first segment starts at 0, last segment starts at total_size-win_size
21
- assert 0 <= overlap < 1
22
- assert total_size >= win_size
23
- spacing = win_size * (1 - overlap)
24
- last_pt = total_size - win_size
25
- n_windows = 2 + int((last_pt - 1) // spacing)
26
- return np.linspace(0, last_pt, n_windows).round().astype(int)
27
-
28
-
29
- def multiple_of_16(x):
30
- return (x // 16) * 16
31
-
32
-
33
- def _make_overlapping_grid(H, W, size, overlap):
34
- H_win = multiple_of_16(H * size // max(H, W))
35
- W_win = multiple_of_16(W * size // max(H, W))
36
- x = _start_pos(W, W_win, overlap)
37
- y = _start_pos(H, H_win, overlap)
38
- grid = np.stack(np.meshgrid(x, y, indexing='xy'), axis=-1)
39
- grid = np.concatenate((grid, grid + (W_win, H_win)), axis=-1)
40
- return grid.reshape(-1, 4)
41
-
42
-
43
- def _cell_size(cell2):
44
- width, height = cell2[:, 2] - cell2[:, 0], cell2[:, 3] - cell2[:, 1]
45
- assert width.min() >= 0
46
- assert height.min() >= 0
47
- return width, height
48
-
49
-
50
- def _norm_windows(cell2, H2, W2, forced_resolution=None):
51
- # make sure the window aspect ratio is 3/4, or the output resolution is forced_resolution if defined
52
- outcell = cell2.copy()
53
- width, height = _cell_size(cell2)
54
- width2, height2 = width.clip(max=W2), height.clip(max=H2)
55
- if forced_resolution is None:
56
- width2[width < height] = (height2[width < height] * 3.01 / 4).clip(max=W2)
57
- height2[width >= height] = (width2[width >= height] * 3.01 / 4).clip(max=H2)
58
- else:
59
- forced_H, forced_W = forced_resolution
60
- width2[:] = forced_W
61
- height2[:] = forced_H
62
-
63
- half = (width2 - width) / 2
64
- outcell[:, 0] -= half
65
- outcell[:, 2] += half
66
- half = (height2 - height) / 2
67
- outcell[:, 1] -= half
68
- outcell[:, 3] += half
69
-
70
- # proj to integers
71
- outcell = np.floor(outcell).astype(int)
72
- # Take care of flooring errors
73
- tmpw, tmph = _cell_size(outcell)
74
- outcell[:, 0] += tmpw.astype(tmpw.dtype) - width2.astype(tmpw.dtype)
75
- outcell[:, 1] += tmph.astype(tmpw.dtype) - height2.astype(tmpw.dtype)
76
-
77
- # make sure 0 <= x < W2 and 0 <= y < H2
78
- outcell[:, 0::2] -= outcell[:, [0]].clip(max=0)
79
- outcell[:, 1::2] -= outcell[:, [1]].clip(max=0)
80
- outcell[:, 0::2] -= outcell[:, [2]].clip(min=W2) - W2
81
- outcell[:, 1::2] -= outcell[:, [3]].clip(min=H2) - H2
82
-
83
- width, height = _cell_size(outcell)
84
- assert np.all(width == width2.astype(width.dtype)) and np.all(
85
- height == height2.astype(height.dtype)), "Error, output is not of the expected shape."
86
- assert np.all(width <= W2)
87
- assert np.all(height <= H2)
88
- return outcell
89
-
90
-
91
- def _weight_pixels(cell, pix, assigned, gauss_var=2):
92
- center = cell.reshape(-1, 2, 2).mean(axis=1)
93
- width, height = _cell_size(cell)
94
-
95
- # square distance between each cell center and each point
96
- dist = (center[:, None] - pix[None]) / np.c_[width, height][:, None]
97
- dist2 = np.square(dist).sum(axis=-1)
98
-
99
- assert assigned.shape == dist2.shape
100
- res = np.where(assigned, np.exp(-gauss_var * dist2), 0)
101
- return res
102
-
103
-
104
- def pos2d_in_rect(p1, cell1):
105
- x, y = p1.T
106
- l, t, r, b = cell1
107
- assigned = (l <= x) & (x < r) & (t <= y) & (y < b)
108
- return assigned
109
-
110
-
111
- def _score_cell(cell1, H2, W2, p1, p2, min_corres=10, forced_resolution=None):
112
- assert p1.shape == p2.shape
113
-
114
- # compute keypoint assignment
115
- assigned = pos2d_in_rect(p1, cell1[None].T)
116
- assert assigned.shape == (len(cell1), len(p1))
117
-
118
- # remove cells without correspondences
119
- valid_cells = assigned.sum(axis=1) >= min_corres
120
- cell1 = cell1[valid_cells]
121
- assigned = assigned[valid_cells]
122
- if not valid_cells.any():
123
- return cell1, cell1, assigned
124
-
125
- # fill-in the assigned points in both image
126
- assigned_p1 = np.empty((len(cell1), len(p1), 2), dtype=np.float32)
127
- assigned_p2 = np.empty((len(cell1), len(p2), 2), dtype=np.float32)
128
- assigned_p1[:] = p1[None]
129
- assigned_p2[:] = p2[None]
130
- assigned_p1[~assigned] = np.nan
131
- assigned_p2[~assigned] = np.nan
132
-
133
- # find the median center and scale of assigned points in each cell
134
- # cell_center1 = np.nanmean(assigned_p1, axis=1)
135
- cell_center2 = np.nanmean(assigned_p2, axis=1)
136
- im1_q25, im1_q75 = np.nanquantile(assigned_p1, (0.1, 0.9), axis=1)
137
- im2_q25, im2_q75 = np.nanquantile(assigned_p2, (0.1, 0.9), axis=1)
138
-
139
- robust_std1 = (im1_q75 - im1_q25).clip(20.)
140
- robust_std2 = (im2_q75 - im2_q25).clip(20.)
141
-
142
- cell_size1 = (cell1[:, 2:4] - cell1[:, 0:2])
143
- cell_size2 = cell_size1 * robust_std2 / robust_std1
144
- cell2 = np.c_[cell_center2 - cell_size2 / 2, cell_center2 + cell_size2 / 2]
145
-
146
- # make sure cell bounds are valid
147
- cell2 = _norm_windows(cell2, H2, W2, forced_resolution=forced_resolution)
148
-
149
- # compute correspondence weights
150
- corres_weights = _weight_pixels(cell1, p1, assigned) * _weight_pixels(cell2, p2, assigned)
151
-
152
- # return a list of window pairs and assigned correspondences
153
- return cell1, cell2, corres_weights
154
-
155
-
156
- def greedy_selection(corres_weights, target=0.9):
157
- # corres_weight = (n_cell_pair, n_corres) matrix.
158
- # If corres_weight[c,p]>0, means that correspondence p is visible in cell pair p
159
- assert 0 < target <= 1
160
- corres_weights = corres_weights.copy()
161
-
162
- total = corres_weights.max(axis=0).sum()
163
- target *= total
164
-
165
- # init = empty
166
- res = []
167
- cur = np.zeros(corres_weights.shape[1]) # current selection
168
-
169
- while cur.sum() < target:
170
- # pick the nex best cell pair
171
- best = corres_weights.sum(axis=1).argmax()
172
- res.append(best)
173
-
174
- # update current
175
- cur += corres_weights[best]
176
- # print('appending', best, 'with score', corres_weights[best].sum(), '-->', cur.sum())
177
-
178
- # remove from all other views
179
- corres_weights = (corres_weights - corres_weights[best]).clip(min=0)
180
-
181
- return res
182
-
183
-
184
- def select_pairs_of_crops(img_q, img_b, pos2d_in_query, pos2d_in_ref, maxdim=512, overlap=.5, forced_resolution=None):
185
- # prepare the overlapping cells
186
- grid_q = _make_overlapping_grid(*img_q.shape[:2], maxdim, overlap)
187
- grid_b = _make_overlapping_grid(*img_b.shape[:2], maxdim, overlap)
188
-
189
- assert forced_resolution is None or len(forced_resolution) == 2
190
- if isinstance(forced_resolution[0], int) or not len(forced_resolution[0]) == 2:
191
- forced_resolution1 = forced_resolution2 = forced_resolution
192
- else:
193
- assert len(forced_resolution[1]) == 2
194
- forced_resolution1 = forced_resolution[0]
195
- forced_resolution2 = forced_resolution[1]
196
-
197
- # Make sure crops respect constraints
198
- grid_q = _norm_windows(grid_q.astype(float), *img_q.shape[:2], forced_resolution=forced_resolution1)
199
- grid_b = _norm_windows(grid_b.astype(float), *img_b.shape[:2], forced_resolution=forced_resolution2)
200
-
201
- # score cells
202
- pairs_q = _score_cell(grid_q, *img_b.shape[:2], pos2d_in_query, pos2d_in_ref, forced_resolution=forced_resolution2)
203
- pairs_b = _score_cell(grid_b, *img_q.shape[:2], pos2d_in_ref, pos2d_in_query, forced_resolution=forced_resolution1)
204
- pairs_b = pairs_b[1], pairs_b[0], pairs_b[2] # cellq, cellb, corres_weights
205
-
206
- # greedy selection until all correspondences are generated
207
- cell1, cell2, corres_weights = map(np.concatenate, zip(pairs_q, pairs_b))
208
- if len(corres_weights) == 0:
209
- return # tolerated for empty generators
210
- order = greedy_selection(corres_weights, target=0.9)
211
-
212
- for i in order:
213
- def pair_tag(qi, bi): return (str(qi) + crop_tag(cell1[i]), str(bi) + crop_tag(cell2[i]))
214
- yield cell1[i], cell2[i], pair_tag
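A sketch of iterating the deleted `select_pairs_of_crops` generator to obtain fine-level crop pairs from coarse correspondences; the images and correspondences are placeholders, and the forced resolution is an arbitrary example.

import numpy as np
from mast3r.utils.coarse_to_fine import select_pairs_of_crops, crop_slice

imgA = np.zeros((768, 1024, 3), dtype=np.uint8)   # query / reference images (placeholders)
imgB = np.zeros((768, 1024, 3), dtype=np.uint8)
xyA = np.random.rand(500, 2) * (1024, 768)        # coarse 2D correspondences, one (N, 2) array per image
xyB = np.random.rand(500, 2) * (1024, 768)

for cellA, cellB, pair_tag in select_pairs_of_crops(imgA, imgB, xyA, xyB,
                                                    maxdim=512, overlap=0.5,
                                                    forced_resolution=(384, 512)):
    cropA = imgA[crop_slice(cellA)]   # matching windows to re-run inference on at higher resolution
    cropB = imgB[crop_slice(cellB)]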
 
 
 
mast3r/utils/collate.py DELETED
@@ -1,62 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # Collate extensions
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import collections
10
- from torch.utils.data._utils.collate import default_collate_fn_map, default_collate_err_msg_format
11
- from typing import Callable, Dict, Optional, Tuple, Type, Union, List
12
-
13
-
14
- def cat_collate_tensor_fn(batch, *, collate_fn_map):
15
- return torch.cat(batch, dim=0)
16
-
17
-
18
- def cat_collate_list_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
19
- return [item for bb in batch for item in bb] # concatenate all lists
20
-
21
-
22
- cat_collate_fn_map = default_collate_fn_map.copy()
23
- cat_collate_fn_map[torch.Tensor] = cat_collate_tensor_fn
24
- cat_collate_fn_map[List] = cat_collate_list_fn
25
- cat_collate_fn_map[type(None)] = lambda _, **kw: None # When some Nones, simply return a single None
26
-
27
-
28
- def cat_collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
29
- r"""Custom collate function that concatenates stuff instead of stacking them, and handles NoneTypes """
30
- elem = batch[0]
31
- elem_type = type(elem)
32
-
33
- if collate_fn_map is not None:
34
- if elem_type in collate_fn_map:
35
- return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
36
-
37
- for collate_type in collate_fn_map:
38
- if isinstance(elem, collate_type):
39
- return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map)
40
-
41
- if isinstance(elem, collections.abc.Mapping):
42
- try:
43
- return elem_type({key: cat_collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
44
- except TypeError:
45
- # The mapping type may not support `__init__(iterable)`.
46
- return {key: cat_collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
47
- elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
48
- return elem_type(*(cat_collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))
49
- elif isinstance(elem, collections.abc.Sequence):
50
- transposed = list(zip(*batch)) # It may be accessed twice, so we use a list.
51
-
52
- if isinstance(elem, tuple):
53
- # Backwards compatibility.
54
- return [cat_collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]
55
- else:
56
- try:
57
- return elem_type([cat_collate(samples, collate_fn_map=collate_fn_map) for samples in transposed])
58
- except TypeError:
59
- # The sequence type may not support `__init__(iterable)` (e.g., `range`).
60
- return [cat_collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]
61
-
62
- raise TypeError(default_collate_err_msg_format.format(elem_type))
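A sketch of plugging the deleted `cat_collate` into a PyTorch DataLoader; `pair_dataset` is a placeholder for any dataset whose items already carry a leading batch-like dimension.

from functools import partial
from torch.utils.data import DataLoader
from mast3r.utils.collate import cat_collate, cat_collate_fn_map

# items are concatenated along dim 0 (and lists flattened) instead of being stacked;
# pair_dataset is a hypothetical dataset object
loader = DataLoader(pair_dataset, batch_size=4,
                    collate_fn=partial(cat_collate, collate_fn_map=cat_collate_fn_map))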
 
 
 
mast3r/utils/misc.py DELETED
@@ -1,17 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # utilitary functions for MASt3R
6
- # --------------------------------------------------------
7
- import os
8
- import hashlib
9
-
10
-
11
- def mkdir_for(f):
12
- os.makedirs(os.path.dirname(f), exist_ok=True)
13
- return f
14
-
15
-
16
- def hash_md5(s):
17
- return hashlib.md5(s.encode('utf-8')).hexdigest()
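The two deleted helpers in use; the paths and strings are placeholders.

from mast3r.utils.misc import mkdir_for, hash_md5

out_path = mkdir_for('cache/corres/scene0/pair_000.npy')  # creates cache/corres/scene0/ and returns the path
cache_key = hash_md5('imgA.jpg-imgB.jpg')                 # stable md5 key, e.g. for naming cached matches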
 
 
 
mast3r/utils/path_to_dust3r.py DELETED
@@ -1,19 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # dust3r submodule import
6
- # --------------------------------------------------------
7
-
8
- import sys
9
- import os.path as path
10
- HERE_PATH = path.normpath(path.dirname(__file__))
11
- DUSt3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../dust3r'))
12
- DUSt3R_LIB_PATH = path.join(DUSt3R_REPO_PATH, 'dust3r')
13
- # check the presence of models directory in repo to be sure its cloned
14
- if path.isdir(DUSt3R_LIB_PATH):
15
- # workaround for sibling import
16
- sys.path.insert(0, DUSt3R_REPO_PATH)
17
- else:
18
- raise ImportError(f"dust3r is not initialized, could not find: {DUSt3R_LIB_PATH}.\n "
19
- "Did you forget to run 'git submodule update --init --recursive' ?")
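The module's only job is the sys.path side effect, which is why the other deleted files import it bare; a minimal sketch of that pattern.

import mast3r.utils.path_to_dust3r  # noqa: prepends the dust3r submodule to sys.path
from dust3r.utils.device import to_numpy  # dust3r imports resolve only after the line above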