Delete mast3r
- mast3r/__init__.py +0 -2
- mast3r/catmlp_dpt_head.py +0 -123
- mast3r/cloud_opt/__init__.py +0 -2
- mast3r/cloud_opt/sparse_ga.py +0 -1001
- mast3r/cloud_opt/triangulation.py +0 -80
- mast3r/cloud_opt/tsdf_optimizer.py +0 -273
- mast3r/cloud_opt/utils/__init__.py +0 -2
- mast3r/cloud_opt/utils/losses.py +0 -32
- mast3r/cloud_opt/utils/schedules.py +0 -17
- mast3r/colmap/__init__.py +0 -2
- mast3r/colmap/database.py +0 -383
- mast3r/datasets/__init__.py +0 -62
- mast3r/datasets/base/__init__.py +0 -2
- mast3r/datasets/base/mast3r_base_stereo_view_dataset.py +0 -355
- mast3r/datasets/utils/__init__.py +0 -2
- mast3r/datasets/utils/cropping.py +0 -219
- mast3r/fast_nn.py +0 -221
- mast3r/losses.py +0 -514
- mast3r/model.py +0 -68
- mast3r/utils/__init__.py +0 -2
- mast3r/utils/coarse_to_fine.py +0 -214
- mast3r/utils/collate.py +0 -62
- mast3r/utils/misc.py +0 -17
- mast3r/utils/path_to_dust3r.py +0 -19
mast3r/__init__.py
DELETED
@@ -1,2 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/catmlp_dpt_head.py
DELETED
@@ -1,123 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# MASt3R heads
-# --------------------------------------------------------
-import torch
-import torch.nn.functional as F
-
-import mast3r.utils.path_to_dust3r  # noqa
-from dust3r.heads.postprocess import reg_dense_depth, reg_dense_conf  # noqa
-from dust3r.heads.dpt_head import PixelwiseTaskWithDPT  # noqa
-import dust3r.utils.path_to_croco  # noqa
-from models.blocks import Mlp  # noqa
-
-
-def reg_desc(desc, mode):
-    if 'norm' in mode:
-        desc = desc / desc.norm(dim=-1, keepdim=True)
-    else:
-        raise ValueError(f"Unknown desc mode {mode}")
-    return desc
-
-
-def postprocess(out, depth_mode, conf_mode, desc_dim=None, desc_mode='norm', two_confs=False, desc_conf_mode=None):
-    if desc_conf_mode is None:
-        desc_conf_mode = conf_mode
-    fmap = out.permute(0, 2, 3, 1)  # B,H,W,D
-    res = dict(pts3d=reg_dense_depth(fmap[..., 0:3], mode=depth_mode))
-    if conf_mode is not None:
-        res['conf'] = reg_dense_conf(fmap[..., 3], mode=conf_mode)
-    if desc_dim is not None:
-        start = 3 + int(conf_mode is not None)
-        res['desc'] = reg_desc(fmap[..., start:start + desc_dim], mode=desc_mode)
-        if two_confs:
-            res['desc_conf'] = reg_dense_conf(fmap[..., start + desc_dim], mode=desc_conf_mode)
-        else:
-            res['desc_conf'] = res['conf'].clone()
-    return res
-
-
-class Cat_MLP_LocalFeatures_DPT_Pts3d(PixelwiseTaskWithDPT):
-    """ Mixture between MLP and DPT head that outputs 3d points and local features (with MLP).
-    The input for both heads is a concatenation of Encoder and Decoder outputs
-    """
-
-    def __init__(self, net, has_conf=False, local_feat_dim=16, hidden_dim_factor=4., hooks_idx=None, dim_tokens=None,
-                 num_channels=1, postprocess=None, feature_dim=256, last_dim=32, depth_mode=None, conf_mode=None, head_type="regression", **kwargs):
-        super().__init__(num_channels=num_channels, feature_dim=feature_dim, last_dim=last_dim, hooks_idx=hooks_idx,
-                         dim_tokens=dim_tokens, depth_mode=depth_mode, postprocess=postprocess, conf_mode=conf_mode, head_type=head_type)
-        self.local_feat_dim = local_feat_dim
-
-        patch_size = net.patch_embed.patch_size
-        if isinstance(patch_size, tuple):
-            assert len(patch_size) == 2 and isinstance(patch_size[0], int) and isinstance(
-                patch_size[1], int), "What is your patchsize format? Expected a single int or a tuple of two ints."
-            assert patch_size[0] == patch_size[1], "Error, non square patches not managed"
-            patch_size = patch_size[0]
-        self.patch_size = patch_size
-
-        self.desc_mode = net.desc_mode
-        self.has_conf = has_conf
-        self.two_confs = net.two_confs  # independent confs for 3D regr and descs
-        self.desc_conf_mode = net.desc_conf_mode
-        idim = net.enc_embed_dim + net.dec_embed_dim
-
-        self.head_local_features = Mlp(in_features=idim,
-                                       hidden_features=int(hidden_dim_factor * idim),
-                                       out_features=(self.local_feat_dim + self.two_confs) * self.patch_size**2)
-
-    def forward(self, decout, img_shape):
-        # pass through the heads
-        pts3d = self.dpt(decout, image_size=(img_shape[0], img_shape[1]))
-
-        # recover encoder and decoder outputs
-        enc_output, dec_output = decout[0], decout[-1]
-        cat_output = torch.cat([enc_output, dec_output], dim=-1)  # concatenate
-        H, W = img_shape
-        B, S, D = cat_output.shape
-
-        # extract local_features
-        local_features = self.head_local_features(cat_output)  # B,S,D
-        local_features = local_features.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size)
-        local_features = F.pixel_shuffle(local_features, self.patch_size)  # B,d,H,W
-
-        # post process 3D pts, descriptors and confidences
-        out = torch.cat([pts3d, local_features], dim=1)
-        if self.postprocess:
-            out = self.postprocess(out,
-                                   depth_mode=self.depth_mode,
-                                   conf_mode=self.conf_mode,
-                                   desc_dim=self.local_feat_dim,
-                                   desc_mode=self.desc_mode,
-                                   two_confs=self.two_confs,
-                                   desc_conf_mode=self.desc_conf_mode)
-        return out
-
-
-def mast3r_head_factory(head_type, output_mode, net, has_conf=False):
-    """ build a prediction head for the decoder
-    """
-    if head_type == 'catmlp+dpt' and output_mode.startswith('pts3d+desc'):
-        local_feat_dim = int(output_mode[10:])
-        assert net.dec_depth > 9
-        l2 = net.dec_depth
-        feature_dim = 256
-        last_dim = feature_dim // 2
-        out_nchan = 3
-        ed = net.enc_embed_dim
-        dd = net.dec_embed_dim
-        return Cat_MLP_LocalFeatures_DPT_Pts3d(net, local_feat_dim=local_feat_dim, has_conf=has_conf,
-                                               num_channels=out_nchan + has_conf,
-                                               feature_dim=feature_dim,
-                                               last_dim=last_dim,
-                                               hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2],
-                                               dim_tokens=[ed, dd, dd, dd],
-                                               postprocess=postprocess,
-                                               depth_mode=net.depth_mode,
-                                               conf_mode=net.conf_mode,
-                                               head_type='regression')
    else:
        raise NotImplementedError(
            f"unexpected {head_type=} and {output_mode=}")
mast3r/cloud_opt/__init__.py
DELETED
@@ -1,2 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/cloud_opt/sparse_ga.py
DELETED
@@ -1,1001 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# MASt3R Sparse Global Alignment
-# --------------------------------------------------------
-from tqdm import tqdm
-import roma
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-import os
-from collections import namedtuple
-from functools import lru_cache
-from scipy import sparse as sp
-
-from mast3r.utils.misc import mkdir_for, hash_md5
-from mast3r.cloud_opt.utils.losses import gamma_loss
-from mast3r.cloud_opt.utils.schedules import linear_schedule, cosine_schedule
-from mast3r.fast_nn import fast_reciprocal_NNs, merge_corres
-
-import mast3r.utils.path_to_dust3r  # noqa
-from dust3r.utils.geometry import inv, geotrf  # noqa
-from dust3r.utils.device import to_cpu, to_numpy, todevice  # noqa
-from dust3r.post_process import estimate_focal_knowing_depth  # noqa
-from dust3r.optim_factory import adjust_learning_rate_by_lr  # noqa
-from dust3r.cloud_opt.base_opt import clean_pointcloud
-from dust3r.viz import SceneViz
-
-
-class SparseGA():
-    def __init__(self, img_paths, pairs_in, res_fine, anchors, canonical_paths=None):
-        def fetch_img(im):
-            def torgb(x): return (x[0].permute(1, 2, 0).numpy() * .5 + .5).clip(min=0., max=1.)
-            for im1, im2 in pairs_in:
-                if im1['instance'] == im:
-                    return torgb(im1['img'])
-                if im2['instance'] == im:
-                    return torgb(im2['img'])
-        self.canonical_paths = canonical_paths
-        self.img_paths = img_paths
-        self.imgs = [fetch_img(img) for img in img_paths]
-        self.intrinsics = res_fine['intrinsics']
-        self.cam2w = res_fine['cam2w']
-        self.depthmaps = res_fine['depthmaps']
-        self.pts3d = res_fine['pts3d']
-        self.pts3d_colors = []
-        self.working_device = self.cam2w.device
-        for i in range(len(self.imgs)):
-            im = self.imgs[i]
-            x, y = anchors[i][0][..., :2].detach().cpu().numpy().T
-            self.pts3d_colors.append(im[y, x])
-            assert self.pts3d_colors[-1].shape == self.pts3d[i].shape
-        self.n_imgs = len(self.imgs)
-
-    def get_focals(self):
-        return torch.tensor([ff[0, 0] for ff in self.intrinsics]).to(self.working_device)
-
-    def get_principal_points(self):
-        return torch.stack([ff[:2, -1] for ff in self.intrinsics]).to(self.working_device)
-
-    def get_im_poses(self):
-        return self.cam2w
-
-    def get_sparse_pts3d(self):
-        return self.pts3d
-
-    def get_dense_pts3d(self, clean_depth=True, subsample=8):
-        assert self.canonical_paths, 'cache_path is required for dense 3d points'
-        device = self.cam2w.device
-        confs = []
-        base_focals = []
-        anchors = {}
-        for i, canon_path in enumerate(self.canonical_paths):
-            (canon, canon2, conf), focal = torch.load(canon_path, map_location=device)
-            confs.append(conf)
-            base_focals.append(focal)
-
-            H, W = conf.shape
-            pixels = torch.from_numpy(np.mgrid[:W, :H].T.reshape(-1, 2)).float().to(device)
-            idxs, offsets = anchor_depth_offsets(canon2, {i: (pixels, None)}, subsample=subsample)
-            anchors[i] = (pixels, idxs[i], offsets[i])
-
-        # densify sparse depthmaps
-        pts3d, depthmaps = make_pts3d(anchors, self.intrinsics, self.cam2w, [
-            d.ravel() for d in self.depthmaps], base_focals=base_focals, ret_depth=True)
-
-        if clean_depth:
-            confs = clean_pointcloud(confs, self.intrinsics, inv(self.cam2w), depthmaps, pts3d)
-
-        return pts3d, depthmaps, confs
-
-    def get_pts3d_colors(self):
-        return self.pts3d_colors
-
-    def get_depthmaps(self):
-        return self.depthmaps
-
-    def get_masks(self):
-        return [slice(None, None) for _ in range(len(self.imgs))]
-
-    def show(self, show_cams=True):
-        pts3d, _, confs = self.get_dense_pts3d()
-        show_reconstruction(self.imgs, self.intrinsics if show_cams else None, self.cam2w,
-                            [p.clip(min=-50, max=50) for p in pts3d],
-                            masks=[c > 1 for c in confs])
-
-
-def convert_dust3r_pairs_naming(imgs, pairs_in):
-    for pair_id in range(len(pairs_in)):
-        for i in range(2):
-            pairs_in[pair_id][i]['instance'] = imgs[pairs_in[pair_id][i]['idx']]
-    return pairs_in
-
-
-def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc_conf='desc_conf',
-                            device='cuda', dtype=torch.float32, **kw):
-    """ Sparse alignment with MASt3R
-        imgs: list of image paths
-        cache_path: path where to dump temporary files (str)
-
-        lr1, niter1: learning rate and #iterations for coarse global alignment (3D matching)
-        lr2, niter2: learning rate and #iterations for refinement (2D reproj error)
-
-        lora_depth: smart dimensionality reduction with depthmaps
-    """
-    # Convert pair naming convention from dust3r to mast3r
-    pairs_in = convert_dust3r_pairs_naming(imgs, pairs_in)
-    # forward pass
-    pairs, cache_path = forward_mast3r(pairs_in, model,
-                                       cache_path=cache_path, subsample=subsample,
-                                       desc_conf=desc_conf, device=device)
-
-    # extract canonical pointmaps
-    tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21 = \
-        prepare_canonical_data(imgs, pairs, subsample, cache_path=cache_path, mode='avg-angle', device=device)
-
-    # compute minimal spanning tree
-    mst = compute_min_spanning_tree(pairwise_scores)
-
-    # remove all edges not in the spanning tree?
-    # min_spanning_tree = {(imgs[i],imgs[j]) for i,j in mst[1]}
-    # tmp_pairs = {(a,b):v for (a,b),v in tmp_pairs.items() if {(a,b),(b,a)} & min_spanning_tree}
-
-    # smartly combine all useful data
-    imsizes, pps, base_focals, core_depth, anchors, corres, corres2d = \
-        condense_data(imgs, tmp_pairs, canonical_views, dtype)
-
-    imgs, res_coarse, res_fine = sparse_scene_optimizer(
-        imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21, canonical_paths,
-        mst, cache_path=cache_path, device=device, dtype=dtype, **kw)
-    return SparseGA(imgs, pairs_in, res_fine or res_coarse, anchors, canonical_paths)
-
-
-def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d,
-                           preds_21, canonical_paths, mst, cache_path,
-                           lr1=0.2, niter1=500, loss1=gamma_loss(1.1),
-                           lr2=0.02, niter2=500, loss2=gamma_loss(0.4),
-                           lossd=gamma_loss(1.1),
-                           opt_pp=True, opt_depth=True,
-                           schedule=cosine_schedule, depth_mode='add', exp_depth=False,
-                           lora_depth=False,  # dict(k=96, gamma=15, min_norm=.5),
-                           init={}, device='cuda', dtype=torch.float32,
-                           matching_conf_thr=4., loss_dust3r_w=0.01,
-                           verbose=True, dbg=()):
-
-    # extrinsic parameters
-    vec0001 = torch.tensor((0, 0, 0, 1), dtype=dtype, device=device)
-    quats = [nn.Parameter(vec0001.clone()) for _ in range(len(imgs))]
-    trans = [nn.Parameter(torch.zeros(3, device=device, dtype=dtype)) for _ in range(len(imgs))]
-
-    # initialize
-    ones = torch.ones((len(imgs), 1), device=device, dtype=dtype)
-    median_depths = torch.ones(len(imgs), device=device, dtype=dtype)
-    for img in imgs:
-        idx = imgs.index(img)
-        init_values = init.setdefault(img, {})
-        if verbose and init_values:
-            print(f' >> initializing img=...{img[-25:]} [{idx}] for {set(init_values)}')
-
-        K = init_values.get('intrinsics')
-        if K is not None:
-            K = K.detach()
-            focal = K[:2, :2].diag().mean()
-            pp = K[:2, 2]
-            base_focals[idx] = focal
-            pps[idx] = pp
-        pps[idx] /= imsizes[idx]  # default principal_point would be (0.5, 0.5)
-
-        depth = init_values.get('depthmap')
-        if depth is not None:
-            core_depth[idx] = depth.detach()
-
-        median_depths[idx] = med_depth = core_depth[idx].median()
-        core_depth[idx] /= med_depth
-
-        cam2w = init_values.get('cam2w')
-        if cam2w is not None:
-            rot = cam2w[:3, :3].detach()
-            cam_center = cam2w[:3, 3].detach()
-            quats[idx].data[:] = roma.rotmat_to_unitquat(rot)
-            trans_offset = med_depth * torch.cat((imsizes[idx] / base_focals[idx] * (0.5 - pps[idx]), ones[:1, 0]))
-            trans[idx].data[:] = cam_center + rot @ trans_offset
-            del rot
-            assert False, 'inverse kinematic chain not yet implemented'
-
-    # intrinsics parameters
-    pps = [nn.Parameter(pp.to(dtype)) for pp in pps]
-    diags = imsizes.float().norm(dim=1)
-    min_focals = 0.25 * diags  # diag = 1.2~1.4*max(W,H) => beta >= 1/(2*1.2*tan(fov/2)) ~= 0.26
-    max_focals = 10 * diags
-    log_focals = [nn.Parameter(f.view(1).log().to(dtype)) for f in base_focals]
-    assert len(mst[1]) == len(pps) - 1
-
-    def make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth):
-        # make intrinsics
-        focals = torch.cat(log_focals).exp().clip(min=min_focals, max=max_focals)
-        pps = torch.stack(pps)
-        K = torch.eye(3, dtype=dtype, device=device)[None].expand(len(imgs), 3, 3).clone()
-        K[:, 0, 0] = K[:, 1, 1] = focals
-        K[:, 0:2, 2] = pps * imsizes
-        if trans is None:
-            return K
-
-        # security! optimization is always trying to crush the scale down
-        sizes = torch.cat(log_sizes).exp()
-        global_scaling = 1 / sizes.min()
-
-        # compute distance of camera to focal plane
-        # tan(fov) = W/2 / focal
-        z_cameras = sizes * median_depths * focals / base_focals
-
-        # make extrinsic
-        rel_cam2cam = torch.eye(4, dtype=dtype, device=device)[None].expand(len(imgs), 4, 4).clone()
-        rel_cam2cam[:, :3, :3] = roma.unitquat_to_rotmat(F.normalize(torch.stack(quats), dim=1))
-        rel_cam2cam[:, :3, 3] = torch.stack(trans)
-
-        # cameras are defined as a kinematic chain
-        tmp_cam2w = [None] * len(K)
-        tmp_cam2w[mst[0]] = rel_cam2cam[mst[0]]
-        for i, j in mst[1]:
-            # i is the cam_i_to_world reference, j is the relative pose = cam_j_to_cam_i
-            tmp_cam2w[j] = tmp_cam2w[i] @ rel_cam2cam[j]
-        tmp_cam2w = torch.stack(tmp_cam2w)
-
-        # smart reparameterization of cameras
-        trans_offset = z_cameras.unsqueeze(1) * torch.cat((imsizes / focals.unsqueeze(1) * (0.5 - pps), ones), dim=-1)
-        new_trans = global_scaling * (tmp_cam2w[:, :3, 3:4] - tmp_cam2w[:, :3, :3] @ trans_offset.unsqueeze(-1))
-        cam2w = torch.cat((torch.cat((tmp_cam2w[:, :3, :3], new_trans), dim=2),
-                           vec0001.view(1, 1, 4).expand(len(K), 1, 4)), dim=1)
-
-        depthmaps = []
-        for i in range(len(imgs)):
-            core_depth_img = core_depth[i]
-            if exp_depth:
-                core_depth_img = core_depth_img.exp()
-            if lora_depth:  # compute core_depth as a low-rank decomposition of 3d points
-                core_depth_img = lora_depth_proj[i] @ core_depth_img
-            if depth_mode == 'add':
-                core_depth_img = z_cameras[i] + (core_depth_img - 1) * (median_depths[i] * sizes[i])
-            elif depth_mode == 'mul':
-                core_depth_img = z_cameras[i] * core_depth_img
-            else:
-                raise ValueError(f'Bad {depth_mode=}')
-            depthmaps.append(global_scaling * core_depth_img)
-
-        return K, (inv(cam2w), cam2w), depthmaps
-
-    K = make_K_cam_depth(log_focals, pps, None, None, None, None)
-    print('init focals =', to_numpy(K[:, 0, 0]))
-
-    # spectral low-rank projection of depthmaps
-    if lora_depth:
-        core_depth, lora_depth_proj = spectral_projection_of_depthmaps(
-            imgs, K, core_depth, subsample, cache_path=cache_path, **lora_depth)
-    if exp_depth:
-        core_depth = [d.clip(min=1e-4).log() for d in core_depth]
-    core_depth = [nn.Parameter(d.ravel().to(dtype)) for d in core_depth]
-    log_sizes = [nn.Parameter(torch.zeros(1, dtype=dtype, device=device)) for _ in range(len(imgs))]
-
-    # Fetch img slices
-    _, confs_sum, imgs_slices = corres
-
-    # Define which pairs are fine to use with matching
-    def matching_check(x): return x.max() > matching_conf_thr
-    is_matching_ok = {}
-    for s in imgs_slices:
-        is_matching_ok[s.img1, s.img2] = matching_check(s.confs)
-
-    # Subsample preds_21
-    subsamp_preds_21 = {}
-    for imk, imv in preds_21.items():
-        subsamp_preds_21[imk] = {}
-        for im2k, (pred, conf) in preds_21[imk].items():
-            subpred = pred[::subsample, ::subsample].reshape(-1, 3)  # original subsample
-            subconf = conf[::subsample, ::subsample].ravel()  # for both ptmaps and confs
-            idxs = anchors[imgs.index(im2k)][1]
-            subsamp_preds_21[imk][im2k] = (subpred[idxs], subconf[idxs])  # anchors subsample
-
-    def loss_dust3r(cam2w, pts3d, pix_loss):
-        # In the case no correspondence could be established, fallback to DUSt3R GA regression loss formulation (sparsified)
-        loss = 0.
-        cf_sum = 0.
-        for s in imgs_slices:
-            if not is_matching_ok[s.img1, s.img2]:
-                # fallback to dust3r regression
-                tgt_pts, tgt_confs = subsamp_preds_21[imgs[s.img2]][imgs[s.img1]]
-                tgt_pts = geotrf(cam2w[s.img2], tgt_pts)
-                cf_sum += tgt_confs.sum()
-                loss += tgt_confs @ pix_loss(pts3d[s.img1], tgt_pts)
-        return loss / cf_sum if cf_sum != 0. else 0.
-
-    def loss_3d(K, w2cam, pts3d, pix_loss):
-        # For each correspondence, we have two 3D points (one for each image of the pair).
-        # For each 3D point, we have 2 reproj errors
-        if any(v.get('freeze') for v in init.values()):
-            pts3d_1 = []
-            pts3d_2 = []
-            confs = []
-            for s in imgs_slices:
-                if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'):
-                    continue
-                if is_matching_ok[s.img1, s.img2]:
-                    pts3d_1.append(pts3d[s.img1][s.slice1])
-                    pts3d_2.append(pts3d[s.img2][s.slice2])
-                    confs.append(s.confs)
-        else:
-            pts3d_1 = [pts3d[s.img1][s.slice1] for s in imgs_slices if is_matching_ok[s.img1, s.img2]]
-            pts3d_2 = [pts3d[s.img2][s.slice2] for s in imgs_slices if is_matching_ok[s.img1, s.img2]]
-            confs = [s.confs for s in imgs_slices if is_matching_ok[s.img1, s.img2]]
-
-        if pts3d_1 != []:
-            confs = torch.cat(confs)
-            pts3d_1 = torch.cat(pts3d_1)
-            pts3d_2 = torch.cat(pts3d_2)
-            loss = confs @ pix_loss(pts3d_1, pts3d_2)
-            cf_sum = confs.sum()
-        else:
-            loss = 0.
-            cf_sum = 1.
-
-        return loss / cf_sum
-
-    def loss_2d(K, w2cam, pts3d, pix_loss):
-        # For each correspondence, we have two 3D points (one for each image of the pair).
-        # For each 3D point, we have 2 reproj errors
-        proj_matrix = K @ w2cam[:, :3]
-        loss = npix = 0
-        for img1, pix1, confs, cf_sum, imgs_slices in corres2d:
-            if init[imgs[img1]].get('freeze', 0) >= 1:
-                continue  # no need
-            pts3d_in_img1 = [pts3d[img2][slice2] for img2, slice2 in imgs_slices if is_matching_ok[img1, img2]]
-            pix1_filtered = []
-            confs_filtered = []
-            curstep = 0
-            for img2, slice2 in imgs_slices:
-                if is_matching_ok[img1, img2]:
-                    tslice = slice(curstep, curstep + slice2.stop - slice2.start, slice2.step)
-                    pix1_filtered.append(pix1[tslice])
-                    confs_filtered.append(confs[tslice])
-                curstep += slice2.stop - slice2.start
-            if pts3d_in_img1 != []:
-                pts3d_in_img1 = torch.cat(pts3d_in_img1)
-                pix1_filtered = torch.cat(pix1_filtered)
-                confs_filtered = torch.cat(confs_filtered)
-                loss += confs_filtered @ pix_loss(pix1_filtered, reproj2d(proj_matrix[img1], pts3d_in_img1))
-                npix += confs_filtered.sum()
-        return loss / npix if npix != 0 else 0.
-
-    def optimize_loop(loss_func, lr_base, niter, pix_loss, lr_end=0):
-        # create optimizer
-        params = pps + log_focals + quats + trans + log_sizes + core_depth
-        optimizer = torch.optim.Adam(params, lr=1, weight_decay=0, betas=(0.9, 0.9))
-        ploss = pix_loss if 'meta' in repr(pix_loss) else (lambda a: pix_loss)
-
-        with tqdm(total=niter) as bar:
-            for iter in range(niter or 1):
-                K, (w2cam, cam2w), depthmaps = make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth)
-                pts3d = make_pts3d(anchors, K, cam2w, depthmaps, base_focals=base_focals)
-                if niter == 0:
-                    break
-
-                alpha = (iter / niter)
-                lr = schedule(alpha, lr_base, lr_end)
-                adjust_learning_rate_by_lr(optimizer, lr)
-                pix_loss = ploss(1 - alpha)
-                optimizer.zero_grad()
-                loss = loss_func(K, w2cam, pts3d, pix_loss) + loss_dust3r_w * loss_dust3r(cam2w, pts3d, lossd)
-                loss.backward()
-                optimizer.step()
-
-                # make sure the pose remains well optimizable
-                for i in range(len(imgs)):
-                    quats[i].data[:] /= quats[i].data.norm()
-
-                loss = float(loss)
-                if loss != loss:
-                    break  # NaN loss
-                bar.set_postfix_str(f'{lr=:.4f}, {loss=:.3f}')
-                bar.update(1)
-
-        if niter:
-            print(f'>> final loss = {loss}')
-        return dict(intrinsics=K.detach(), cam2w=cam2w.detach(),
-                    depthmaps=[d.detach() for d in depthmaps], pts3d=[p.detach() for p in pts3d])
-
-    # at start, don't optimize 3d points
-    for i, img in enumerate(imgs):
-        trainable = not (init[img].get('freeze'))
-        pps[i].requires_grad_(False)
-        log_focals[i].requires_grad_(False)
-        quats[i].requires_grad_(trainable)
-        trans[i].requires_grad_(trainable)
-        log_sizes[i].requires_grad_(trainable)
-        core_depth[i].requires_grad_(False)
-
-    res_coarse = optimize_loop(loss_3d, lr_base=lr1, niter=niter1, pix_loss=loss1)
-
-    res_fine = None
-    if niter2:
-        # now we can optimize 3d points
-        for i, img in enumerate(imgs):
-            if init[img].get('freeze', 0) >= 1:
-                continue
-            pps[i].requires_grad_(bool(opt_pp))
-            log_focals[i].requires_grad_(True)
-            core_depth[i].requires_grad_(opt_depth)
-
-        # refinement with 2d reproj
-        res_fine = optimize_loop(loss_2d, lr_base=lr2, niter=niter2, pix_loss=loss2)
-
-    return imgs, res_coarse, res_fine
-
-
-@lru_cache
-def mask110(device, dtype):
-    return torch.tensor((1, 1, 0), device=device, dtype=dtype)
-
-
-def proj3d(inv_K, pixels, z):
-    if pixels.shape[-1] == 2:
-        pixels = torch.cat((pixels, torch.ones_like(pixels[..., :1])), dim=-1)
-    return z.unsqueeze(-1) * (pixels * inv_K.diag() + inv_K[:, 2] * mask110(z.device, z.dtype))
-
-
-def make_pts3d(anchors, K, cam2w, depthmaps, base_focals=None, ret_depth=False):
-    focals = K[:, 0, 0]
-    invK = inv(K)
-    all_pts3d = []
-    depth_out = []
-
-    for img, (pixels, idxs, offsets) in anchors.items():
-        # from depthmaps to 3d points
-        if base_focals is None:
-            pass
-        else:
-            # compensate for focal
-            # depth + depth * (offset - 1) * base_focal / focal
-            # = depth * (1 + (offset - 1) * (base_focal / focal))
-            offsets = 1 + (offsets - 1) * (base_focals[img] / focals[img])
-
-        pts3d = proj3d(invK[img], pixels, depthmaps[img][idxs] * offsets)
-        if ret_depth:
-            depth_out.append(pts3d[..., 2])  # before camera rotation
-
-        # rotate to world coordinate
-        pts3d = geotrf(cam2w[img], pts3d)
-        all_pts3d.append(pts3d)
-
-    if ret_depth:
-        return all_pts3d, depth_out
-    return all_pts3d
-
-
-def make_dense_pts3d(intrinsics, cam2w, depthmaps, canonical_paths, subsample, device='cuda'):
-    base_focals = []
-    anchors = {}
-    confs = []
-    for i, canon_path in enumerate(canonical_paths):
-        (canon, canon2, conf), focal = torch.load(canon_path, map_location=device)
-        confs.append(conf)
-        base_focals.append(focal)
-        H, W = conf.shape
-        pixels = torch.from_numpy(np.mgrid[:W, :H].T.reshape(-1, 2)).float().to(device)
-        idxs, offsets = anchor_depth_offsets(canon2, {i: (pixels, None)}, subsample=subsample)
-        anchors[i] = (pixels, idxs[i], offsets[i])
-
-    # densify sparse depthmaps
-    pts3d, depthmaps_out = make_pts3d(anchors, intrinsics, cam2w, [
-        d.ravel() for d in depthmaps], base_focals=base_focals, ret_depth=True)
-
-    return pts3d, depthmaps_out, confs
-
-
-@torch.no_grad()
-def forward_mast3r(pairs, model, cache_path, desc_conf='desc_conf',
-                   device='cuda', subsample=8, **matching_kw):
-    res_paths = {}
-
-    for img1, img2 in tqdm(pairs):
-        idx1 = hash_md5(img1['instance'])
-        idx2 = hash_md5(img2['instance'])
-
-        path1 = cache_path + f'/forward/{idx1}/{idx2}.pth'
-        path2 = cache_path + f'/forward/{idx2}/{idx1}.pth'
-        path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{idx1}-{idx2}.pth'
-        path_corres2 = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{idx2}-{idx1}.pth'
-
-        if os.path.isfile(path_corres2) and not os.path.isfile(path_corres):
-            score, (xy1, xy2, confs) = torch.load(path_corres2)
-            torch.save((score, (xy2, xy1, confs)), path_corres)
-
-        if not all(os.path.isfile(p) for p in (path1, path2, path_corres)):
-            if model is None:
-                continue
-            res = symmetric_inference(model, img1, img2, device=device)
-            X11, X21, X22, X12 = [r['pts3d'][0] for r in res]
-            C11, C21, C22, C12 = [r['conf'][0] for r in res]
-            descs = [r['desc'][0] for r in res]
-            qonfs = [r[desc_conf][0] for r in res]
-
-            # save
-            torch.save(to_cpu((X11, C11, X21, C21)), mkdir_for(path1))
-            torch.save(to_cpu((X22, C22, X12, C12)), mkdir_for(path2))
-
-            # perform reciprocal matching
-            corres = extract_correspondences(descs, qonfs, device=device, subsample=subsample)
-
-            conf_score = (C11.mean() * C12.mean() * C21.mean() * C22.mean()).sqrt().sqrt()
-            matching_score = (float(conf_score), float(corres[2].sum()), len(corres[2]))
-            if cache_path is not None:
-                torch.save((matching_score, corres), mkdir_for(path_corres))
-
-        res_paths[img1['instance'], img2['instance']] = (path1, path2), path_corres
-
-    del model
-    torch.cuda.empty_cache()
-
-    return res_paths, cache_path
-
-
-def symmetric_inference(model, img1, img2, device):
-    shape1 = torch.from_numpy(img1['true_shape']).to(device, non_blocking=True)
-    shape2 = torch.from_numpy(img2['true_shape']).to(device, non_blocking=True)
-    img1 = img1['img'].to(device, non_blocking=True)
-    img2 = img2['img'].to(device, non_blocking=True)
-
-    # compute encoder only once
-    feat1, feat2, pos1, pos2 = model._encode_image_pairs(img1, img2, shape1, shape2)
-
-    def decoder(feat1, feat2, pos1, pos2, shape1, shape2):
-        dec1, dec2 = model._decoder(feat1, pos1, feat2, pos2)
-        with torch.cuda.amp.autocast(enabled=False):
-            res1 = model._downstream_head(1, [tok.float() for tok in dec1], shape1)
-            res2 = model._downstream_head(2, [tok.float() for tok in dec2], shape2)
-        return res1, res2
-
-    # decoder 1-2
-    res11, res21 = decoder(feat1, feat2, pos1, pos2, shape1, shape2)
-    # decoder 2-1
-    res22, res12 = decoder(feat2, feat1, pos2, pos1, shape2, shape1)
-
-    return (res11, res21, res22, res12)
-
-
-def extract_correspondences(feats, qonfs, subsample=8, device=None, ptmap_key='pred_desc'):
-    feat11, feat21, feat22, feat12 = feats
-    qonf11, qonf21, qonf22, qonf12 = qonfs
-    assert feat11.shape[:2] == feat12.shape[:2] == qonf11.shape == qonf12.shape
-    assert feat21.shape[:2] == feat22.shape[:2] == qonf21.shape == qonf22.shape
-
-    if '3d' in ptmap_key:
-        opt = dict(device='cpu', workers=32)
-    else:
-        opt = dict(device=device, dist='dot', block_size=2**13)
-
-    # matching the two pairs
-    idx1 = []
-    idx2 = []
-    qonf1 = []
-    qonf2 = []
-    # TODO add non symmetric / pixel_tol options
-    for A, B, QA, QB in [(feat11, feat21, qonf11.cpu(), qonf21.cpu()),
-                         (feat12, feat22, qonf12.cpu(), qonf22.cpu())]:
-        nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=subsample, ret_xy=False, **opt)
-        nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=subsample, ret_xy=False, **opt)
-
-        idx1.append(np.r_[nn1to2[0], nn2to1[1]])
-        idx2.append(np.r_[nn1to2[1], nn2to1[0]])
-        qonf1.append(QA.ravel()[idx1[-1]])
-        qonf2.append(QB.ravel()[idx2[-1]])
-
-    # merge corres from opposite pairs
-    H1, W1 = feat11.shape[:2]
-    H2, W2 = feat22.shape[:2]
-    cat = np.concatenate
-
-    xy1, xy2, idx = merge_corres(cat(idx1), cat(idx2), (H1, W1), (H2, W2), ret_xy=True, ret_index=True)
-    corres = (xy1.copy(), xy2.copy(), np.sqrt(cat(qonf1)[idx] * cat(qonf2)[idx]))
-
-    return todevice(corres, device)
-
-
-@torch.no_grad()
-def prepare_canonical_data(imgs, tmp_pairs, subsample, order_imgs=False, min_conf_thr=0,
-                           cache_path=None, device='cuda', **kw):
-    canonical_views = {}
-    pairwise_scores = torch.zeros((len(imgs), len(imgs)), device=device)
-    canonical_paths = []
-    preds_21 = {}
-
-    for img in tqdm(imgs):
-        if cache_path:
-            cache = os.path.join(cache_path, 'canon_views', hash_md5(img) + f'_{subsample=}_{kw=}.pth')
-            canonical_paths.append(cache)
-        try:
-            (canon, canon2, cconf), focal = torch.load(cache, map_location=device)
-        except IOError:
-            # cache does not exist yet, we create it!
-            canon = focal = None
-
-        # collect all pred1
-        n_pairs = sum((img in pair) for pair in tmp_pairs)
-
-        ptmaps11 = None
-        pixels = {}
-        n = 0
-        for (img1, img2), ((path1, path2), path_corres) in tmp_pairs.items():
-            score = None
-            if img == img1:
-                X, C, X2, C2 = torch.load(path1, map_location=device)
-                score, (xy1, xy2, confs) = load_corres(path_corres, device, min_conf_thr)
-                pixels[img2] = xy1, confs
-                if img not in preds_21:
-                    preds_21[img] = {}
-                preds_21[img][img2] = X2, C2
-
-            if img == img2:
-                X, C, X2, C2 = torch.load(path2, map_location=device)
-                score, (xy1, xy2, confs) = load_corres(path_corres, device, min_conf_thr)
-                pixels[img1] = xy2, confs
-                if img not in preds_21:
-                    preds_21[img] = {}
-                preds_21[img][img1] = X2, C2
-
-            if score is not None:
-                i, j = imgs.index(img1), imgs.index(img2)
-                # score = score[0]
-                # score = np.log1p(score[2])
-                score = score[2]
-                pairwise_scores[i, j] = score
-                pairwise_scores[j, i] = score
-
-                if canon is not None:
-                    continue
-            if ptmaps11 is None:
-                H, W = C.shape
-                ptmaps11 = torch.empty((n_pairs, H, W, 3), device=device)
-                confs11 = torch.empty((n_pairs, H, W), device=device)
-
-            ptmaps11[n] = X
-            confs11[n] = C
-            n += 1
-
-        if canon is None:
-            canon, canon2, cconf = canonical_view(ptmaps11, confs11, subsample, **kw)
-            del ptmaps11
-            del confs11
-
-        # compute focals
-        H, W = canon.shape[:2]
-        pp = torch.tensor([W / 2, H / 2], device=device)
-        if focal is None:
-            focal = estimate_focal_knowing_depth(canon[None], pp, focal_mode='weiszfeld', min_focal=0.5, max_focal=3.5)
-            if cache:
-                torch.save(to_cpu(((canon, canon2, cconf), focal)), mkdir_for(cache))
-
-        # extract depth offsets with correspondences
-        core_depth = canon[subsample // 2::subsample, subsample // 2::subsample, 2]
-        idxs, offsets = anchor_depth_offsets(canon2, pixels, subsample=subsample)
-
-        canonical_views[img] = (pp, (H, W), focal.view(1), core_depth, pixels, idxs, offsets)
-
-    return tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21
-
-
-def load_corres(path_corres, device, min_conf_thr):
-    score, (xy1, xy2, confs) = torch.load(path_corres, map_location=device)
-    valid = confs > min_conf_thr if min_conf_thr else slice(None)
-    # valid = (xy1 > 0).all(dim=1) & (xy2 > 0).all(dim=1) & (xy1 < 512).all(dim=1) & (xy2 < 512).all(dim=1)
-    # print(f'keeping {valid.sum()} / {len(valid)} correspondences')
-    return score, (xy1[valid], xy2[valid], confs[valid])
-
-
-PairOfSlices = namedtuple(
-    'ImgPair', 'img1, slice1, pix1, anchor_idxs1, img2, slice2, pix2, anchor_idxs2, confs, confs_sum')
-
-
-def condense_data(imgs, tmp_paths, canonical_views, dtype=torch.float32):
-    # aggregate all data properly
-    set_imgs = set(imgs)
-
-    principal_points = []
-    shapes = []
-    focals = []
-    core_depth = []
-    img_anchors = {}
-    tmp_pixels = {}
-
-    for idx1, img1 in enumerate(imgs):
-        # load stuff
-        pp, shape, focal, anchors, pixels_confs, idxs, offsets = canonical_views[img1]
-
-        principal_points.append(pp)
-        shapes.append(shape)
-        focals.append(focal)
-        core_depth.append(anchors)
-
-        img_uv1 = []
-        img_idxs = []
-        img_offs = []
-        cur_n = [0]
-
-        for img2, (pixels, match_confs) in pixels_confs.items():
-            if img2 not in set_imgs:
-                continue
-            assert len(pixels) == len(idxs[img2]) == len(offsets[img2])
-            img_uv1.append(torch.cat((pixels, torch.ones_like(pixels[:, :1])), dim=-1))
-            img_idxs.append(idxs[img2])
-            img_offs.append(offsets[img2])
-            cur_n.append(cur_n[-1] + len(pixels))
-            # store the position of 3d points
-            tmp_pixels[img1, img2] = pixels.to(dtype), match_confs.to(dtype), slice(*cur_n[-2:])
-        img_anchors[idx1] = (torch.cat(img_uv1), torch.cat(img_idxs), torch.cat(img_offs))
-
-    all_confs = []
-    imgs_slices = []
-    corres2d = {img: [] for img in range(len(imgs))}
-
-    for img1, img2 in tmp_paths:
-        try:
-            pix1, confs1, slice1 = tmp_pixels[img1, img2]
-            pix2, confs2, slice2 = tmp_pixels[img2, img1]
-        except KeyError:
-            continue
-        img1 = imgs.index(img1)
-        img2 = imgs.index(img2)
-        confs = (confs1 * confs2).sqrt()
-
-        # prepare for loss_3d
-        all_confs.append(confs)
-        anchor_idxs1 = canonical_views[imgs[img1]][5][imgs[img2]]
-        anchor_idxs2 = canonical_views[imgs[img2]][5][imgs[img1]]
-        imgs_slices.append(PairOfSlices(img1, slice1, pix1, anchor_idxs1,
-                                        img2, slice2, pix2, anchor_idxs2,
-                                        confs, float(confs.sum())))
-
-        # prepare for loss_2d
-        corres2d[img1].append((pix1, confs, img2, slice2))
-        corres2d[img2].append((pix2, confs, img1, slice1))
-
-    all_confs = torch.cat(all_confs)
-    corres = (all_confs, float(all_confs.sum()), imgs_slices)
-
-    def aggreg_matches(img1, list_matches):
-        pix1, confs, img2, slice2 = zip(*list_matches)
-        all_pix1 = torch.cat(pix1).to(dtype)
-        all_confs = torch.cat(confs).to(dtype)
-        return img1, all_pix1, all_confs, float(all_confs.sum()), [(j, sl2) for j, sl2 in zip(img2, slice2)]
-    corres2d = [aggreg_matches(img, m) for img, m in corres2d.items()]
-
-    imsizes = torch.tensor([(W, H) for H, W in shapes], device=pp.device)  # (W,H)
-    principal_points = torch.stack(principal_points)
-    focals = torch.cat(focals)
-    return imsizes, principal_points, focals, core_depth, img_anchors, corres, corres2d
-
-
-def canonical_view(ptmaps11, confs11, subsample, mode='avg-angle'):
-    assert len(ptmaps11) == len(confs11) > 0, 'not a single view1 for img={i}'
-
-    # canonical pointmap is just a weighted average
-    confs11 = confs11.unsqueeze(-1) - 0.999
-    canon = (confs11 * ptmaps11).sum(0) / confs11.sum(0)
-
-    canon_depth = ptmaps11[..., 2].unsqueeze(1)
-    S = slice(subsample // 2, None, subsample)
-    center_depth = canon_depth[:, :, S, S]
-    assert (center_depth > 0).all()
-    stacked_depth = F.pixel_unshuffle(canon_depth, subsample)
-    stacked_confs = F.pixel_unshuffle(confs11[:, None, :, :, 0], subsample)
-
-    if mode == 'avg-reldepth':
-        rel_depth = stacked_depth / center_depth
-        stacked_canon = (stacked_confs * rel_depth).sum(dim=0) / stacked_confs.sum(dim=0)
-        canon2 = F.pixel_shuffle(stacked_canon.unsqueeze(0), subsample).squeeze()
-
-    elif mode == 'avg-angle':
-        xy = ptmaps11[..., 0:2].permute(0, 3, 1, 2)
-        stacked_xy = F.pixel_unshuffle(xy, subsample)
-        B, _, H, W = stacked_xy.shape
-        stacked_radius = (stacked_xy.view(B, 2, -1, H, W) - xy[:, :, None, S, S]).norm(dim=1)
-        stacked_radius.clip_(min=1e-8)
-
-        stacked_angle = torch.arctan((stacked_depth - center_depth) / stacked_radius)
-        avg_angle = (stacked_confs * stacked_angle).sum(dim=0) / stacked_confs.sum(dim=0)
-
-        # back to depth
-        stacked_depth = stacked_radius.mean(dim=0) * torch.tan(avg_angle)
-
-        canon2 = F.pixel_shuffle((1 + stacked_depth / canon[S, S, 2]).unsqueeze(0), subsample).squeeze()
-    else:
-        raise ValueError(f'bad {mode=}')
-
-    confs = (confs11.square().sum(dim=0) / confs11.sum(dim=0)).squeeze()
-    return canon, canon2, confs
-
-
-def anchor_depth_offsets(canon_depth, pixels, subsample=8):
-    device = canon_depth.device
-
-    # create a 2D grid of anchor 3D points
-    H1, W1 = canon_depth.shape
-    yx = np.mgrid[subsample // 2:H1:subsample, subsample // 2:W1:subsample]
-    H2, W2 = yx.shape[1:]
-    cy, cx = yx.reshape(2, -1)
-    core_depth = canon_depth[cy, cx]
-    assert (core_depth > 0).all()
-
-    # slave 3d points (attached to core 3d points)
-    core_idxs = {}  # core_idxs[img2] = {corr_idx:core_idx}
-    core_offs = {}  # core_offs[img2] = {corr_idx:3d_offset}
-
-    for img2, (xy1, _confs) in pixels.items():
-        px, py = xy1.long().T
-
-        # find nearest anchor == block quantization
-        core_idx = (py // subsample) * W2 + (px // subsample)
-        core_idxs[img2] = core_idx.to(device)
-
-        # compute relative depth offsets w.r.t. anchors
-        ref_z = core_depth[core_idx]
-        pts_z = canon_depth[py, px]
-        offset = pts_z / ref_z
-        core_offs[img2] = offset.detach().to(device)
-
-    return core_idxs, core_offs
-
-
-def spectral_clustering(graph, k=None, normalized_cuts=False):
-    graph.fill_diagonal_(0)
-
-    # graph laplacian
-    degrees = graph.sum(dim=-1)
-    laplacian = torch.diag(degrees) - graph
-    if normalized_cuts:
-        i_inv = torch.diag(degrees.sqrt().reciprocal())
-        laplacian = i_inv @ laplacian @ i_inv
-
-    # compute eigenvectors!
-    eigval, eigvec = torch.linalg.eigh(laplacian)
-    return eigval[:k], eigvec[:, :k]
-
-
-def sim_func(p1, p2, gamma):
-    diff = (p1 - p2).norm(dim=-1)
-    avg_depth = (p1[:, :, 2] + p2[:, :, 2])
-    rel_distance = diff / avg_depth
-    sim = torch.exp(-gamma * rel_distance.square())
-    return sim
-
-
-def backproj(K, depthmap, subsample):
-    H, W = depthmap.shape
-    uv = np.mgrid[subsample // 2:subsample * W:subsample, subsample // 2:subsample * H:subsample].T.reshape(H, W, 2)
-    xyz = depthmap.unsqueeze(-1) * geotrf(inv(K), todevice(uv, K.device), ncol=3)
-    return xyz
-
-
-def spectral_projection_depth(K, depthmap, subsample, k=64, cache_path='',
-                              normalized_cuts=True, gamma=7, min_norm=5):
-    try:
-        if cache_path:
-            cache_path = cache_path + f'_{k=}_norm={normalized_cuts}_{gamma=}.pth'
-        lora_proj = torch.load(cache_path, map_location=K.device)
-
-    except IOError:
-        # reconstruct 3d points in camera coordinates
-        xyz = backproj(K, depthmap, subsample)
-
-        # compute all distances
-        xyz = xyz.reshape(-1, 3)
-        graph = sim_func(xyz[:, None], xyz[None, :], gamma=gamma)
-        _, lora_proj = spectral_clustering(graph, k, normalized_cuts=normalized_cuts)
-
-        if cache_path:
-            torch.save(lora_proj.cpu(), mkdir_for(cache_path))
-
-    lora_proj, coeffs = lora_encode_normed(lora_proj, depthmap.ravel(), min_norm=min_norm)
-
-    # depthmap ~= lora_proj @ coeffs
-    return coeffs, lora_proj
-
-
-def lora_encode_normed(lora_proj, x, min_norm, global_norm=False):
-    # encode the pointmap
-    coeffs = torch.linalg.pinv(lora_proj) @ x
-
-    # rectify the norm of basis vector to be ~ equal
-    if coeffs.ndim == 1:
-        coeffs = coeffs[:, None]
-    if global_norm:
-        lora_proj *= coeffs[1:].norm() * min_norm / coeffs.shape[1]
-    elif min_norm:
-        lora_proj *= coeffs.norm(dim=1).clip(min=min_norm)
-    # can have rounding errors here!
-    coeffs = (torch.linalg.pinv(lora_proj.double()) @ x.double()).float()
-
-    return lora_proj.detach(), coeffs.detach()
-
-
-@torch.no_grad()
-def spectral_projection_of_depthmaps(imgs, intrinsics, depthmaps, subsample, cache_path=None, **kw):
-    # recover 3d points
-    core_depth = []
-    lora_proj = []
-
-    for i, img in enumerate(tqdm(imgs)):
-        cache = os.path.join(cache_path, 'lora_depth', hash_md5(img)) if cache_path else None
-        depth, proj = spectral_projection_depth(intrinsics[i], depthmaps[i], subsample,
-                                                cache_path=cache, **kw)
-        core_depth.append(depth)
-        lora_proj.append(proj)
-
-    return core_depth, lora_proj
-
-
-def reproj2d(Trf, pts3d):
-    res = (pts3d @ Trf[:3, :3].transpose(-1, -2)) + Trf[:3, 3]
-    clipped_z = res[:, 2:3].clip(min=1e-3)  # make sure we don't have nans!
-    uv = res[:, 0:2] / clipped_z
-    return uv.clip(min=-1000, max=2000)
-
-
-def bfs(tree, start_node):
-    order, predecessors = sp.csgraph.breadth_first_order(tree, start_node, directed=False)
-    ranks = np.arange(len(order))
-    ranks[order] = ranks.copy()
-    return ranks, predecessors
-
-
-def compute_min_spanning_tree(pws):
-    sparse_graph = sp.dok_array(pws.shape)
-    for i, j in pws.nonzero().cpu().tolist():
-        sparse_graph[i, j] = -float(pws[i, j])
-    msp = sp.csgraph.minimum_spanning_tree(sparse_graph)
-
-    # now reorder the oriented edges, starting from the central point
-    ranks1, _ = bfs(msp, 0)
-    ranks2, _ = bfs(msp, ranks1.argmax())
-    ranks1, _ = bfs(msp, ranks2.argmax())
-    # this is the point farthest from any leaf
-    root = np.minimum(ranks1, ranks2).argmax()
-
-    # find the ordered list of edges that describe the tree
-    order, predecessors = sp.csgraph.breadth_first_order(msp, root, directed=False)
-    order = order[1:]  # the root does not have a predecessor
-    edges = [(predecessors[i], i) for i in order]
-
-    return root, edges
-
-
-def show_reconstruction(shapes_or_imgs, K, cam2w, pts3d, gt_cam2w=None, gt_K=None, cam_size=None, masks=None, **kw):
-    viz = SceneViz()
-
-    cc = cam2w[:, :3, 3]
-    cs = cam_size or float(torch.cdist(cc, cc).fill_diagonal_(np.inf).min(dim=0).values.median())
-    colors = 64 + np.random.randint(255 - 64, size=(len(cam2w), 3))
-
-    if isinstance(shapes_or_imgs, np.ndarray) and shapes_or_imgs.ndim == 2:
-        cam_kws = dict(imsizes=shapes_or_imgs[:, ::-1], cam_size=cs)
-    else:
-        imgs = shapes_or_imgs
-        cam_kws = dict(images=imgs, cam_size=cs)
-    if K is not None:
-        viz.add_cameras(to_numpy(cam2w), to_numpy(K), colors=colors, **cam_kws)
-
-    if gt_cam2w is not None:
-        if gt_K is None:
-            gt_K = K
-        viz.add_cameras(to_numpy(gt_cam2w), to_numpy(gt_K), colors=colors, marker='o', **cam_kws)
-
-    if pts3d is not None:
-        for i, p in enumerate(pts3d):
-            if not len(p):
-                continue
-            if masks is None:
-                viz.add_pointcloud(to_numpy(p), color=tuple(colors[i].tolist()))
-            else:
-                viz.add_pointcloud(to_numpy(p), mask=masks[i], color=imgs[i])
-    viz.show(**kw)
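Note on the deleted optimizer: sparse_global_alignment above was the entry point (a coarse 3D-matching pass driven by loss_3d, then 2D-reprojection refinement driven by loss_2d). A usage sketch adapted from the MASt3R demo, not part of this commit; the checkpoint name refers to the published MASt3R weights, and the image paths and cache directory are placeholders:

import torch
from mast3r.model import AsymmetricMASt3R
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AsymmetricMASt3R.from_pretrained(
    'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric').to(device)

filelist = ['scene/0001.jpg', 'scene/0002.jpg', 'scene/0003.jpg']  # placeholder paths
imgs = load_images(filelist, size=512)
pairs = make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True)

# coarse 3D matching (lr1/niter1), then 2D reprojection refinement (lr2/niter2)
scene = sparse_global_alignment(filelist, pairs, '/tmp/mast3r_cache', model,
                                lr1=0.07, niter1=500, lr2=0.014, niter2=200,
                                device=device)
pts3d, depthmaps, confs = scene.get_dense_pts3d(clean_depth=True)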
mast3r/cloud_opt/triangulation.py
DELETED
@@ -1,80 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# Matches Triangulation Utils
-# --------------------------------------------------------
-
-import numpy as np
-import torch
-
-# Batched Matches Triangulation
-def batched_triangulate(pts2d,       # [B, Ncams, Npts, 2]
-                        proj_mats):  # [B, Ncams, 3, 4] I@E projection matrix
-    B, Ncams, Npts, two = pts2d.shape
-    assert two == 2
-    assert proj_mats.shape == (B, Ncams, 3, 4)
-    # P - xP
-    x = proj_mats[..., 0, :][..., None, :] - torch.einsum('bij,bik->bijk', pts2d[..., 0], proj_mats[..., 2, :])  # [B, Ncams, Npts, 4]
-    y = proj_mats[..., 1, :][..., None, :] - torch.einsum('bij,bik->bijk', pts2d[..., 1], proj_mats[..., 2, :])  # [B, Ncams, Npts, 4]
-    eq = torch.cat([x, y], dim=1).transpose(1, 2)  # [B, Npts, 2xNcams, 4]
-    return torch.linalg.lstsq(eq[..., :3], -eq[..., 3]).solution
-
-def matches_to_depths(intrinsics,              # input camera intrinsics [B, Ncams, 3, 3]
-                      extrinsics,              # input camera extrinsics [B, Ncams, 3, 4]
-                      matches,                 # input correspondences [B, Ncams, Npts, 2]
-                      batchsize=16,            # bs for batched processing
-                      min_num_valids_ratio=.3  # at least this ratio of image pairs need to predict a match for a given pixel of img1
-                      ):
-    B, Nv, H, W, five = matches.shape
-    min_num_valids = np.floor(Nv * min_num_valids_ratio)
-    out_aggregated_points, out_depths, out_confs = [], [], []
-    for b in range(B // batchsize + 1):  # batched processing
-        start, stop = b * batchsize, min(B, (b + 1) * batchsize)
-        sub_batch = slice(start, stop)
-        sub_batchsize = stop - start
-        if sub_batchsize == 0: continue
-        points1, points2, confs = matches[sub_batch, ..., :2], matches[sub_batch, ..., 2:4], matches[sub_batch, ..., -1]
-        allpoints = torch.cat([points1.view([sub_batchsize * Nv, 1, H * W, 2]), points2.view([sub_batchsize * Nv, 1, H * W, 2])], dim=1)  # [BxNv, 2, HxW, 2]
-
-        allcam_Ps = intrinsics[sub_batch] @ extrinsics[sub_batch, :, :3, :]
-        cam_Ps1, cam_Ps2 = allcam_Ps[:, [0]].repeat([1, Nv, 1, 1]), allcam_Ps[:, 1:]  # [B, Nv, 3, 4]
-        formatted_camPs = torch.cat([cam_Ps1.reshape([sub_batchsize * Nv, 1, 3, 4]), cam_Ps2.reshape([sub_batchsize * Nv, 1, 3, 4])], dim=1)  # [BxNv, 2, 3, 4]
-
-        # Triangulate matches to 3D
-        points_3d_world = batched_triangulate(allpoints, formatted_camPs)  # [BxNv, HxW, three]
-
-        # Aggregate pairwise predictions
-        points_3d_world = points_3d_world.view([sub_batchsize, Nv, H, W, 3])
-        valids = points_3d_world.isfinite()
-        valids_sum = valids.sum(dim=-1)
-        validsuni = valids_sum.unique()
-        assert torch.all(torch.logical_or(validsuni == 0, validsuni == 3)), "Error, can only be nan for none or all XYZ values, not a subset"
-        confs[valids_sum == 0] = 0.
-        points_3d_world = points_3d_world * confs[..., None]
-
-        # Take care of NaNs
-        normalization = confs.sum(dim=1)[:, None].repeat(1, Nv, 1, 1)
-        normalization[normalization <= 1e-5] = 1.
-        points_3d_world[valids] /= normalization[valids_sum == 3][:, None].repeat(1, 3).view(-1)
-        points_3d_world[~valids] = 0.
-        aggregated_points = points_3d_world.sum(dim=1)  # weighted average (by confidence value) ignoring nans
-
-        # Reset invalid values to nans, with a min visibility threshold
-        aggregated_points[valids_sum.sum(dim=1) / 3 <= min_num_valids] = torch.nan
-
-        # From 3D to depths
-        refcamE = extrinsics[sub_batch, 0]
-        points_3d_camera = (refcamE[:, :3, :3] @ aggregated_points.view(sub_batchsize, -1, 3).transpose(-2, -1) + refcamE[:, :3, [3]]).transpose(-2, -1)  # [B,HxW,3]
-        depths = points_3d_camera.view(sub_batchsize, H, W, 3)[..., 2]  # [B,H,W]
-
-        # Cat results
-        out_aggregated_points.append(aggregated_points.cpu())
-        out_depths.append(depths.cpu())
-        out_confs.append(confs.sum(dim=1).cpu())
-
-    out_aggregated_points = torch.cat(out_aggregated_points, dim=0)
-    out_depths = torch.cat(out_depths, dim=0)
-    out_confs = torch.cat(out_confs, dim=0)
-
-    return out_aggregated_points, out_depths, out_confs
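
batched_triangulate is a batched linear (DLT) triangulation: each 2D observation (u, v) under projection matrix P contributes the two rows P_0 - u*P_2 and P_1 - v*P_2 of a homogeneous system, the rows are stacked over cameras, and the system is solved in least squares with the homogeneous coordinate pinned to 1 (hence solving eq[..., :3] against -eq[..., 3]). A quick self-check under toy shapes; all values below are invented for illustration:

import torch

# Toy setup: 1 batch, 2 cameras, 5 points.
B, Ncams, Npts = 1, 2, 5
K = torch.eye(3).expand(B, Ncams, 3, 3)           # identity intrinsics
E = torch.zeros(B, Ncams, 3, 4)
E[..., :3, :3] = torch.eye(3)                     # identity rotations
E[:, 1, 0, 3] = 1.0                               # second camera shifted 1 unit in x
proj_mats = K @ E                                 # [B, Ncams, 3, 4]

pts3d_gt = torch.randn(B, Npts, 3) + torch.tensor([0., 0., 5.])  # points in front of both cameras
homog = torch.cat([pts3d_gt, torch.ones(B, Npts, 1)], dim=-1)    # [B, Npts, 4]
proj = torch.einsum('bcij,bpj->bcpi', proj_mats, homog)          # project into each camera
pts2d = proj[..., :2] / proj[..., 2:3]                           # [B, Ncams, Npts, 2]

est = batched_triangulate(pts2d, proj_mats)                      # [B, Npts, 3]
print(torch.allclose(est, pts3d_gt, atol=1e-4))                  # expected: True (up to numerical tolerance)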
mast3r/cloud_opt/tsdf_optimizer.py
DELETED
@@ -1,273 +0,0 @@
-import torch
-from torch import nn
-import numpy as np
-from tqdm import tqdm
-from matplotlib import pyplot as pl
-
-import mast3r.utils.path_to_dust3r  # noqa
-from dust3r.utils.geometry import depthmap_to_pts3d, geotrf, inv
-from dust3r.cloud_opt.base_opt import clean_pointcloud
-
-
-class TSDFPostProcess:
-    """ Optimizes a signed distance-function to improve depthmaps.
-    """
-
-    def __init__(self, optimizer, subsample=8, TSDF_thresh=0., TSDF_batchsize=int(1e7)):
-        self.TSDF_thresh = TSDF_thresh  # None -> no TSDF
-        self.TSDF_batchsize = TSDF_batchsize
-        self.optimizer = optimizer
-
-        pts3d, depthmaps, confs = optimizer.get_dense_pts3d(clean_depth=False, subsample=subsample)
-        pts3d, depthmaps = self._TSDF_postprocess_or_not(pts3d, depthmaps, confs)
-        self.pts3d = pts3d
-        self.depthmaps = depthmaps
-        self.confs = confs
-
-    def _get_depthmaps(self, TSDF_filtering_thresh=None):
-        if TSDF_filtering_thresh:
-            self._refine_depths_with_TSDF(TSDF_filtering_thresh)  # compute refined depths if needed
-        dms = self.TSDF_im_depthmaps if TSDF_filtering_thresh else self.im_depthmaps
-        return [d.exp() for d in dms]
-
-    @torch.no_grad()
-    def _refine_depths_with_TSDF(self, TSDF_filtering_thresh, niter=1, nsamples=1000):
-        """
-        Leverage TSDF to post-process estimated depths
-        for each pixel, find zero level of TSDF along ray (or closest to 0)
-        """
-        print("Post-Processing Depths with TSDF fusion.")
-        self.TSDF_im_depthmaps = []
-        alldepths, allposes, allfocals, allpps, allimshapes = self._get_depthmaps(), self.optimizer.get_im_poses(
-        ), self.optimizer.get_focals(), self.optimizer.get_principal_points(), self.imshapes
-        for vi in tqdm(range(self.optimizer.n_imgs)):
-            dm, pose, focal, pp, imshape = alldepths[vi], allposes[vi], allfocals[vi], allpps[vi], allimshapes[vi]
-            minvals = torch.full(dm.shape, 1e20)
-
-            for it in range(niter):
-                H, W = dm.shape
-                curthresh = (niter - it) * TSDF_filtering_thresh
-                dm_offsets = (torch.randn(H, W, nsamples).to(dm) - 1.) * \
-                    curthresh  # decreasing search std along with iterations
-                newdm = dm[..., None] + dm_offsets  # [H,W,Nsamp]
-                curproj = self._backproj_pts3d(in_depths=[newdm], in_im_poses=pose[None], in_focals=focal[None], in_pps=pp[None], in_imshapes=[
-                    imshape])[0]  # [H,W,Nsamp,3]
-                # Batched TSDF eval
-                curproj = curproj.view(-1, 3)
-                tsdf_vals = []
-                valids = []
-                for batch in range(0, len(curproj), self.TSDF_batchsize):
-                    values, valid = self._TSDF_query(
-                        curproj[batch:min(batch + self.TSDF_batchsize, len(curproj))], curthresh)
-                    tsdf_vals.append(values)
-                    valids.append(valid)
-                tsdf_vals = torch.cat(tsdf_vals, dim=0)
-                valids = torch.cat(valids, dim=0)
-
-                tsdf_vals = tsdf_vals.view([H, W, nsamples])
-                valids = valids.view([H, W, nsamples])
-
-                # keep depth value that got us the closest to 0
-                tsdf_vals[~valids] = torch.inf  # ignore invalid values
-                tsdf_vals = tsdf_vals.abs()
-                mins = torch.argmin(tsdf_vals, dim=-1, keepdim=True)
-                # when all samples live on a very flat zone, do nothing
-                allbad = (tsdf_vals == curthresh).sum(dim=-1) == nsamples
-                dm[~allbad] = torch.gather(newdm, -1, mins)[..., 0][~allbad]
-
-            # Save refined depth map
-            self.TSDF_im_depthmaps.append(dm.log())
-
-    def _TSDF_query(self, qpoints, TSDF_filtering_thresh, weighted=True):
-        """
-        TSDF query call: returns the weighted TSDF value for each query point [N, 3]
-        """
-        N, three = qpoints.shape
-        assert three == 3
-        qpoints = qpoints[None].repeat(self.optimizer.n_imgs, 1, 1)  # [B,N,3]
-        # get projection coordinates and depths onto images
-        coords_and_depth = self._proj_pts3d(pts3d=qpoints, cam2worlds=self.optimizer.get_im_poses(
-        ), focals=self.optimizer.get_focals(), pps=self.optimizer.get_principal_points())
-        image_coords = coords_and_depth[..., :2].round().to(int)  # for now, there's no interpolation...
-        proj_depths = coords_and_depth[..., -1]
-        # recover depth values after scene optim
-        pred_depths, pred_confs, valids = self._get_pixel_depths(image_coords)
-        # Gather TSDF scores
-        all_SDF_scores = pred_depths - proj_depths  # SDF
-        unseen = all_SDF_scores < -TSDF_filtering_thresh  # handle visibility
-        # all_TSDF_scores = all_SDF_scores.clip(-TSDF_filtering_thresh,TSDF_filtering_thresh)  # SDF -> TSDF
-        all_TSDF_scores = all_SDF_scores.clip(-TSDF_filtering_thresh, 1e20)  # SDF -> TSDF
-        # Gather TSDF confidences and ignore points that are unseen, either OOB during reproj or too far behind seen depth
-        all_TSDF_weights = (~unseen).float() * valids.float()
-        if weighted:
-            all_TSDF_weights = pred_confs.exp() * all_TSDF_weights
-        # Aggregate all votes, ignoring zeros
-        TSDF_weights = all_TSDF_weights.sum(dim=0)
-        valids = TSDF_weights != 0.
-        TSDF_wsum = (all_TSDF_weights * all_TSDF_scores).sum(dim=0)
-        TSDF_wsum[valids] /= TSDF_weights[valids]
-        return TSDF_wsum, valids
-
-    def _get_pixel_depths(self, image_coords, TSDF_filtering_thresh=None, with_normals_conf=False):
-        """ Recover depth value for each input pixel coordinate, along with OOB validity mask
-        """
-        B, N, two = image_coords.shape
-        assert B == self.optimizer.n_imgs and two == 2
-        depths = torch.zeros([B, N], device=image_coords.device)
-        valids = torch.zeros([B, N], dtype=bool, device=image_coords.device)
-        confs = torch.zeros([B, N], device=image_coords.device)
-        curconfs = self._get_confs_with_normals() if with_normals_conf else self.im_conf
-        for ni, (imc, depth, conf) in enumerate(zip(image_coords, self._get_depthmaps(TSDF_filtering_thresh), curconfs)):
-            H, W = depth.shape
-            valids[ni] = torch.logical_and(0 <= imc[:, 1], imc[:, 1] <
-                                           H) & torch.logical_and(0 <= imc[:, 0], imc[:, 0] < W)
-            imc[~valids[ni]] = 0
-            depths[ni] = depth[imc[:, 1], imc[:, 0]]
-            confs[ni] = conf.cuda()[imc[:, 1], imc[:, 0]]
-        return depths, confs, valids
-
-    def _get_confs_with_normals(self):
-        outconfs = []
-        # Confidence based on depth gradient
-
-        class Sobel(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.filter = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False)
-                Gx = torch.tensor([[2.0, 0.0, -2.0], [4.0, 0.0, -4.0], [2.0, 0.0, -2.0]])
-                Gy = torch.tensor([[2.0, 4.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -4.0, -2.0]])
-                G = torch.cat([Gx.unsqueeze(0), Gy.unsqueeze(0)], 0)
-                G = G.unsqueeze(1)
-                self.filter.weight = nn.Parameter(G, requires_grad=False)
-
-            def forward(self, img):
-                x = self.filter(img)
-                x = torch.mul(x, x)
-                x = torch.sum(x, dim=1, keepdim=True)
-                x = torch.sqrt(x)
-                return x
-
-        grad_op = Sobel().to(self.im_depthmaps[0].device)
-        for conf, depth in zip(self.im_conf, self.im_depthmaps):
-            grad_confs = (1. - grad_op(depth[None, None])[0, 0]).clip(0)
-            if not 'dbg show':
-                pl.imshow(grad_confs.cpu())
-                pl.show()
-            outconfs.append(conf * grad_confs.to(conf))
-        return outconfs
-
-    def _proj_pts3d(self, pts3d, cam2worlds, focals, pps):
-        """
-        Projection operation: from 3D points to 2D coordinates + depths
-        """
-        B = pts3d.shape[0]
-        assert pts3d.shape[0] == cam2worlds.shape[0]
-        # prepare Extrinsics
-        R, t = cam2worlds[:, :3, :3], cam2worlds[:, :3, -1]
-        Rinv = R.transpose(-2, -1)
-        tinv = -Rinv @ t[..., None]
-
-        # prepare intrinsics
-        intrinsics = torch.eye(3).to(cam2worlds)[None].repeat(focals.shape[0], 1, 1)
-        if len(focals.shape) == 1:
-            focals = torch.stack([focals, focals], dim=-1)
-        intrinsics[:, 0, 0] = focals[:, 0]
-        intrinsics[:, 1, 1] = focals[:, 1]
-        intrinsics[:, :2, -1] = pps
-        # Project
-        projpts = intrinsics @ (Rinv @ pts3d.transpose(-2, -1) + tinv)  # I(RX+t) : [B,3,N]
-        projpts = projpts.transpose(-2, -1)  # [B,N,3]
-        projpts[..., :2] /= projpts[..., [-1]]  # [B,N,3] (X/Z , Y/Z, Z)
-        return projpts
-
-    def _backproj_pts3d(self, in_depths=None, in_im_poses=None,
-                        in_focals=None, in_pps=None, in_imshapes=None):
-        """
-        Backprojection operation: from image depths to 3D points
-        """
-        # Get depths and projection params if not provided
-        focals = self.optimizer.get_focals() if in_focals is None else in_focals
-        im_poses = self.optimizer.get_im_poses() if in_im_poses is None else in_im_poses
-        depth = self._get_depthmaps() if in_depths is None else in_depths
-        pp = self.optimizer.get_principal_points() if in_pps is None else in_pps
-        imshapes = self.imshapes if in_imshapes is None else in_imshapes
-        def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *imshapes[i])
-        dm_to_3d = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[[i]]) for i in range(im_poses.shape[0])]
-
-        def autoprocess(x):
-            x = x[0]
-            return x.transpose(-2, -1) if len(x.shape) == 4 else x
-        return [geotrf(pose, autoprocess(pt)) for pose, pt in zip(im_poses, dm_to_3d)]
-
-    def _pts3d_to_depth(self, pts3d, cam2worlds, focals, pps):
-        """
-        Projection operation: from 3D points to 2D coordinates + depths
-        """
-        B = pts3d.shape[0]
-        assert pts3d.shape[0] == cam2worlds.shape[0]
-        # prepare Extrinsics
-        R, t = cam2worlds[:, :3, :3], cam2worlds[:, :3, -1]
-        Rinv = R.transpose(-2, -1)
-        tinv = -Rinv @ t[..., None]
-
-        # prepare intrinsics
-        intrinsics = torch.eye(3).to(cam2worlds)[None].repeat(self.optimizer.n_imgs, 1, 1)
-        if len(focals.shape) == 1:
-            focals = torch.stack([focals, focals], dim=-1)
-        intrinsics[:, 0, 0] = focals[:, 0]
-        intrinsics[:, 1, 1] = focals[:, 1]
-        intrinsics[:, :2, -1] = pps
-        # Project
-        projpts = intrinsics @ (Rinv @ pts3d.transpose(-2, -1) + tinv)  # I(RX+t) : [B,3,N]
-        projpts = projpts.transpose(-2, -1)  # [B,N,3]
-        projpts[..., :2] /= projpts[..., [-1]]  # [B,N,3] (X/Z , Y/Z, Z)
-        return projpts
-
-    def _depth_to_pts3d(self, in_depths=None, in_im_poses=None, in_focals=None, in_pps=None, in_imshapes=None):
-        """
-        Backprojection operation: from image depths to 3D points
-        """
-        # Get depths and projection params if not provided
-        focals = self.optimizer.get_focals() if in_focals is None else in_focals
-        im_poses = self.optimizer.get_im_poses() if in_im_poses is None else in_im_poses
-        depth = self._get_depthmaps() if in_depths is None else in_depths
-        pp = self.optimizer.get_principal_points() if in_pps is None else in_pps
-        imshapes = self.imshapes if in_imshapes is None else in_imshapes
-
-        def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *imshapes[i])
-
-        dm_to_3d = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i + 1]) for i in range(im_poses.shape[0])]
-
-        def autoprocess(x):
-            x = x[0]
-            H, W, three = x.shape[:3]
-            return x.transpose(-2, -1) if len(x.shape) == 4 else x
-        return [geotrf(pp, autoprocess(pt)) for pp, pt in zip(im_poses, dm_to_3d)]
-
-    def _get_pts3d(self, TSDF_filtering_thresh=None, **kw):
-        """
-        return 3D points (possibly filtering depths with TSDF)
-        """
-        return self._backproj_pts3d(in_depths=self._get_depthmaps(TSDF_filtering_thresh=TSDF_filtering_thresh), **kw)
-
-    def _TSDF_postprocess_or_not(self, pts3d, depthmaps, confs, niter=1):
-        # Setup inner variables
-        self.imshapes = [im.shape[:2] for im in self.optimizer.imgs]
-        self.im_depthmaps = [dd.log().view(imshape) for dd, imshape in zip(depthmaps, self.imshapes)]
-        self.im_conf = confs
-
-        if self.TSDF_thresh > 0.:
-            # Create or update self.TSDF_im_depthmaps that contain logdepths filtered with TSDF
-            self._refine_depths_with_TSDF(self.TSDF_thresh, niter=niter)
-            depthmaps = [dd.exp() for dd in self.TSDF_im_depthmaps]
-        # Turn them into 3D points
-        pts3d = self._backproj_pts3d(in_depths=depthmaps)
-        depthmaps = [dd.flatten() for dd in depthmaps]
-        pts3d = [pp.view(-1, 3) for pp in pts3d]
-        return pts3d, depthmaps
-
-    def get_dense_pts3d(self, clean_depth=True):
-        if clean_depth:
-            confs = clean_pointcloud(self.confs, self.optimizer.intrinsics, inv(self.optimizer.cam2w),
-                                     self.depthmaps, self.pts3d)
-        return self.pts3d, self.depthmaps, confs
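
TSDFPostProcess wraps the output of the sparse global aligner: at construction it pulls dense points, depths and confidences from the optimizer, and when TSDF_thresh > 0 it re-samples each depth along its camera ray and snaps it to the sample whose truncated signed distance (a confidence-weighted vote over all other views) is closest to zero. A sketch of the intended call pattern, assuming scene is a sparse-GA optimizer exposing the get_dense_pts3d / get_im_poses / get_focals accessors used above:

# Sketch, assuming `scene` comes from MASt3R's sparse global alignment.
tsdf = TSDFPostProcess(scene, subsample=8, TSDF_thresh=0.01)   # TSDF_thresh=0. disables the refinement
pts3d, depthmaps, confs = tsdf.get_dense_pts3d(clean_depth=True)
# pts3d: one [H*W, 3] tensor per image; depthmaps: flattened per-image depths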
mast3r/cloud_opt/utils/__init__.py
DELETED
@@ -1,2 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/cloud_opt/utils/losses.py
DELETED
@@ -1,32 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# losses for sparse ga
-# --------------------------------------------------------
-import torch
-import numpy as np
-
-
-def l05_loss(x, y):
-    return torch.linalg.norm(x - y, dim=-1).sqrt()
-
-
-def l1_loss(x, y):
-    return torch.linalg.norm(x - y, dim=-1)
-
-
-def gamma_loss(gamma, mul=1, offset=None, clip=np.inf):
-    if offset is None:
-        if gamma == 1:
-            return l1_loss
-        # d(x**p)/dx = 1 ==> p * x**(p-1) == 1 ==> x = (1/p)**(1/(p-1))
-        offset = (1 / gamma)**(1 / (gamma - 1))
-
-    def loss_func(x, y):
-        return (mul * l1_loss(x, y).clip(max=clip) + offset) ** gamma - offset ** gamma
-    return loss_func
-
-
-def meta_gamma_loss():
-    return lambda alpha: gamma_loss(alpha)
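
The offset in gamma_loss is the point where x**gamma has unit slope (gamma * x**(gamma-1) = 1), so the shifted loss (d + offset)**gamma - offset**gamma is zero at d = 0, behaves like L1 for small residuals, and grows sub-linearly for large ones when gamma < 1. A quick numerical check of that behavior:

import torch

loss = gamma_loss(0.5)   # offset = (1/0.5)**(1/(0.5-1)) = 0.25
x = torch.zeros(4, 3)
y = torch.tensor([[0., 0., 0.], [1e-3, 0., 0.], [1., 0., 0.], [10., 0., 0.]])
print(loss(x, y))  # ~[0.0000, 0.0010, 0.6180, 2.7016]: L1-like near 0, flattened for large residuals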
mast3r/cloud_opt/utils/schedules.py
DELETED
@@ -1,17 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# lr schedules for sparse ga
-# --------------------------------------------------------
-import numpy as np
-
-
-def linear_schedule(alpha, lr_base, lr_end=0):
-    lr = (1 - alpha) * lr_base + alpha * lr_end
-    return lr
-
-
-def cosine_schedule(alpha, lr_base, lr_end=0):
-    lr = lr_end + (lr_base - lr_end) * (1 + np.cos(alpha * np.pi)) / 2
-    return lr
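
Both schedules map a normalized progress alpha in [0, 1] from lr_base down to lr_end; the cosine variant stays near lr_base longer early on, crosses the linear schedule at alpha = 0.5, and then hugs lr_end. A quick comparison at a few checkpoints:

for alpha in (0., 0.25, 0.5, 1.):
    print(alpha, linear_schedule(alpha, lr_base=0.07), cosine_schedule(alpha, lr_base=0.07))
# 0.0  -> 0.07     0.07
# 0.25 -> 0.0525   ~0.0597  (cosine decays more slowly at first)
# 0.5  -> 0.035    0.035    (the two schedules cross)
# 1.0  -> 0.0      0.0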
mast3r/colmap/__init__.py
DELETED
@@ -1,2 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/colmap/database.py
DELETED
@@ -1,383 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# MASt3R to colmap export functions
-# --------------------------------------------------------
-import os
-import torch
-import copy
-import numpy as np
-import torchvision
-import numpy as np
-from tqdm import tqdm
-from scipy.cluster.hierarchy import DisjointSet
-from scipy.spatial.transform import Rotation as R
-
-from mast3r.utils.misc import hash_md5
-
-from mast3r.fast_nn import extract_correspondences_nonsym, bruteforce_reciprocal_nns
-
-import mast3r.utils.path_to_dust3r  # noqa
-from dust3r.utils.geometry import find_reciprocal_matches, xy_grid  # noqa
-
-
-def convert_im_matches_pairs(img0, img1, image_to_colmap, im_keypoints, matches_im0, matches_im1, viz):
-    if viz:
-        from matplotlib import pyplot as pl
-
-        image_mean = torch.as_tensor(
-            [0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
-        image_std = torch.as_tensor(
-            [0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
-        rgb0 = img0['img'] * image_std + image_mean
-        rgb0 = torchvision.transforms.functional.to_pil_image(rgb0[0])
-        rgb0 = np.array(rgb0)
-
-        rgb1 = img1['img'] * image_std + image_mean
-        rgb1 = torchvision.transforms.functional.to_pil_image(rgb1[0])
-        rgb1 = np.array(rgb1)
-
-        imgs = [rgb0, rgb1]
-        # visualize a few matches
-        n_viz = 100
-        num_matches = matches_im0.shape[0]
-        match_idx_to_viz = np.round(np.linspace(
-            0, num_matches - 1, n_viz)).astype(int)
-        viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]
-
-        H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]
-        rgb0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)),
-                                (0, 0), (0, 0)), 'constant', constant_values=0)
-        rgb1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)),
-                                (0, 0), (0, 0)), 'constant', constant_values=0)
-        img = np.concatenate((rgb0, rgb1), axis=1)
-        pl.figure()
-        pl.imshow(img)
-        cmap = pl.get_cmap('jet')
-        for ii in range(n_viz):
-            (x0, y0), (x1,
-                       y1) = viz_matches_im0[ii].T, viz_matches_im1[ii].T
-            pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(ii /
-                                                              (n_viz - 1)), scalex=False, scaley=False)
-        pl.show(block=True)
-
-    matches = [matches_im0.astype(np.float64), matches_im1.astype(np.float64)]
-    imgs = [img0, img1]
-    imidx0 = img0['idx']
-    imidx1 = img1['idx']
-    ravel_matches = []
-    for j in range(2):
-        H, W = imgs[j]['true_shape'][0]
-        with np.errstate(invalid='ignore'):
-            qx, qy = matches[j].round().astype(np.int32).T
-            ravel_matches_j = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(min=0, max=H - 1, out=qy)
-        ravel_matches.append(ravel_matches_j)
-        imidxj = imgs[j]['idx']
-        for m in ravel_matches_j:
-            if m not in im_keypoints[imidxj]:
-                im_keypoints[imidxj][m] = 0
-            im_keypoints[imidxj][m] += 1
-    imid0 = copy.deepcopy(image_to_colmap[imidx0]['colmap_imid'])
-    imid1 = copy.deepcopy(image_to_colmap[imidx1]['colmap_imid'])
-    if imid0 > imid1:
-        colmap_matches = np.stack([ravel_matches[1], ravel_matches[0]], axis=-1)
-        imid0, imid1 = imid1, imid0
-        imidx0, imidx1 = imidx1, imidx0
-    else:
-        colmap_matches = np.stack([ravel_matches[0], ravel_matches[1]], axis=-1)
-    colmap_matches = np.unique(colmap_matches, axis=0)
-    return imidx0, imidx1, colmap_matches
-
-
-def get_im_matches(pred1, pred2, pairs, image_to_colmap, im_keypoints, conf_thr,
-                   is_sparse=True, subsample=8, pixel_tol=0, viz=False, device='cuda'):
-    im_matches = {}
-    for i in range(len(pred1['pts3d'])):
-        imidx0 = pairs[i][0]['idx']
-        imidx1 = pairs[i][1]['idx']
-        if 'desc' in pred1:  # mast3r
-            descs = [pred1['desc'][i], pred2['desc'][i]]
-            confidences = [pred1['desc_conf'][i], pred2['desc_conf'][i]]
-            desc_dim = descs[0].shape[-1]
-
-            if is_sparse:
-                corres = extract_correspondences_nonsym(descs[0], descs[1], confidences[0], confidences[1],
-                                                        device=device, subsample=subsample, pixel_tol=pixel_tol)
-                conf = corres[2]
-                mask = conf >= conf_thr
-                matches_im0 = corres[0][mask].cpu().numpy()
-                matches_im1 = corres[1][mask].cpu().numpy()
-            else:
-                confidence_masks = [confidences[0] >=
-                                    conf_thr, confidences[1] >= conf_thr]
-                pts2d_list, desc_list = [], []
-                for j in range(2):
-                    conf_j = confidence_masks[j].cpu().numpy().flatten()
-                    true_shape_j = pairs[i][j]['true_shape'][0]
-                    pts2d_j = xy_grid(
-                        true_shape_j[1], true_shape_j[0]).reshape(-1, 2)[conf_j]
-                    desc_j = descs[j].detach().cpu(
-                    ).numpy().reshape(-1, desc_dim)[conf_j]
-                    pts2d_list.append(pts2d_j)
-                    desc_list.append(desc_j)
-                if len(desc_list[0]) == 0 or len(desc_list[1]) == 0:
-                    continue
-
-                nn0, nn1 = bruteforce_reciprocal_nns(desc_list[0], desc_list[1],
-                                                     device=device, dist='dot', block_size=2**13)
-                reciprocal_in_P0 = (nn1[nn0] == np.arange(len(nn0)))
-
-                matches_im1 = pts2d_list[1][nn0][reciprocal_in_P0]
-                matches_im0 = pts2d_list[0][reciprocal_in_P0]
-        else:
-            pts3d = [pred1['pts3d'][i], pred2['pts3d_in_other_view'][i]]
-            confidences = [pred1['conf'][i], pred2['conf'][i]]
-
-            if is_sparse:
-                corres = extract_correspondences_nonsym(pts3d[0], pts3d[1], confidences[0], confidences[1],
-                                                        device=device, subsample=subsample, pixel_tol=pixel_tol,
-                                                        ptmap_key='3d')
-                conf = corres[2]
-                mask = conf >= conf_thr
-                matches_im0 = corres[0][mask].cpu().numpy()
-                matches_im1 = corres[1][mask].cpu().numpy()
-            else:
-                confidence_masks = [confidences[0] >=
-                                    conf_thr, confidences[1] >= conf_thr]
-                # find 2D-2D matches between the two images
-                pts2d_list, pts3d_list = [], []
-                for j in range(2):
-                    conf_j = confidence_masks[j].cpu().numpy().flatten()
-                    true_shape_j = pairs[i][j]['true_shape'][0]
-                    pts2d_j = xy_grid(true_shape_j[1], true_shape_j[0]).reshape(-1, 2)[conf_j]
-                    pts3d_j = pts3d[j].detach().cpu().numpy().reshape(-1, 3)[conf_j]
-                    pts2d_list.append(pts2d_j)
-                    pts3d_list.append(pts3d_j)
-
-                PQ, PM = pts3d_list[0], pts3d_list[1]
-                if len(PQ) == 0 or len(PM) == 0:
-                    continue
-                reciprocal_in_PM, nnM_in_PQ, num_matches = find_reciprocal_matches(
-                    PQ, PM)
-
-                matches_im1 = pts2d_list[1][reciprocal_in_PM]
-                matches_im0 = pts2d_list[0][nnM_in_PQ][reciprocal_in_PM]
-
-        if len(matches_im0) == 0:
-            continue
-        imidx0, imidx1, colmap_matches = convert_im_matches_pairs(pairs[i][0], pairs[i][1],
-                                                                  image_to_colmap, im_keypoints,
-                                                                  matches_im0, matches_im1, viz)
-        im_matches[(imidx0, imidx1)] = colmap_matches
-    return im_matches
-
-
-def get_im_matches_from_cache(pairs, cache_path, desc_conf, subsample,
-                              image_to_colmap, im_keypoints, conf_thr,
-                              viz=False, device='cuda'):
-    im_matches = {}
-    for i in range(len(pairs)):
-        imidx0 = pairs[i][0]['idx']
-        imidx1 = pairs[i][1]['idx']
-
-        corres_idx1 = hash_md5(pairs[i][0]['instance'])
-        corres_idx2 = hash_md5(pairs[i][1]['instance'])
-
-        path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{corres_idx1}-{corres_idx2}.pth'
-        if os.path.isfile(path_corres):
-            score, (xy1, xy2, confs) = torch.load(path_corres, map_location=device)
-        else:
-            path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{corres_idx2}-{corres_idx1}.pth'
-            score, (xy2, xy1, confs) = torch.load(path_corres, map_location=device)
-        mask = confs >= conf_thr
-        matches_im0 = xy1[mask].cpu().numpy()
-        matches_im1 = xy2[mask].cpu().numpy()
-
-        if len(matches_im0) == 0:
-            continue
-        imidx0, imidx1, colmap_matches = convert_im_matches_pairs(pairs[i][0], pairs[i][1],
-                                                                  image_to_colmap, im_keypoints,
-                                                                  matches_im0, matches_im1, viz)
-        im_matches[(imidx0, imidx1)] = colmap_matches
-    return im_matches
-
-
-def export_images(db, images, image_paths, focals, ga_world_to_cam, camera_model):
-    # add cameras/images to the db
-    # with the output of ga as prior
-    image_to_colmap = {}
-    im_keypoints = {}
-    for idx in range(len(image_paths)):
-        im_keypoints[idx] = {}
-        H, W = images[idx]["orig_shape"]
-        if focals is None:
-            focal_x = focal_y = 1.2 * max(W, H)
-            prior_focal_length = False
-            cx = W / 2.0
-            cy = H / 2.0
-        elif isinstance(focals[idx], np.ndarray) and len(focals[idx].shape) == 2:
-            # intrinsics
-            focal_x = focals[idx][0, 0]
-            focal_y = focals[idx][1, 1]
-            cx = focals[idx][0, 2] * images[idx]["to_orig"][0, 0]
-            cy = focals[idx][1, 2] * images[idx]["to_orig"][1, 1]
-            prior_focal_length = True
-        else:
-            focal_x = focal_y = float(focals[idx])
-            prior_focal_length = True
-            cx = W / 2.0
-            cy = H / 2.0
-            focal_x = focal_x * images[idx]["to_orig"][0, 0]
-            focal_y = focal_y * images[idx]["to_orig"][1, 1]
-
-        if camera_model == "SIMPLE_PINHOLE":
-            model_id = 0
-            focal = (focal_x + focal_y) / 2.0
-            params = np.asarray([focal, cx, cy], np.float64)
-        elif camera_model == "PINHOLE":
-            model_id = 1
-            params = np.asarray([focal_x, focal_y, cx, cy], np.float64)
-        elif camera_model == "SIMPLE_RADIAL":
-            model_id = 2
-            focal = (focal_x + focal_y) / 2.0
-            params = np.asarray([focal, cx, cy, 0.0], np.float64)
-        elif camera_model == "OPENCV":
-            model_id = 4
-            params = np.asarray([focal_x, focal_y, cx, cy, 0.0, 0.0, 0.0, 0.0], np.float64)
-        else:
-            raise ValueError(f"invalid camera model {camera_model}")
-
-        H, W = int(H), int(W)
-        # OPENCV camera model
-        camid = db.add_camera(
-            model_id, W, H, params, prior_focal_length=prior_focal_length)
-        if ga_world_to_cam is None:
-            prior_t = np.zeros(3)
-            prior_q = np.zeros(4)
-        else:
-            q = R.from_matrix(ga_world_to_cam[idx][:3, :3]).as_quat()
-            prior_t = ga_world_to_cam[idx][:3, 3]
-            prior_q = np.array([q[-1], q[0], q[1], q[2]])
-        imid = db.add_image(
-            image_paths[idx], camid, prior_q=prior_q, prior_t=prior_t)
-        image_to_colmap[idx] = {
-            'colmap_imid': imid,
-            'colmap_camid': camid
-        }
-    return image_to_colmap, im_keypoints
-
-
-def export_matches(db, images, image_to_colmap, im_keypoints, im_matches, min_len_track, skip_geometric_verification):
-    colmap_image_pairs = []
-    # 2D-2D are quite dense
-    # we want to remove the very small tracks
-    # and export only kpt for which we have values
-    # build tracks
-    print("building tracks")
-    keypoints_to_track_id = {}
-    track_id_to_kpt_list = []
-    to_merge = []
-    for (imidx0, imidx1), colmap_matches in tqdm(im_matches.items()):
-        if imidx0 not in keypoints_to_track_id:
-            keypoints_to_track_id[imidx0] = {}
-        if imidx1 not in keypoints_to_track_id:
-            keypoints_to_track_id[imidx1] = {}
-
-        for m in colmap_matches:
-            if m[0] not in keypoints_to_track_id[imidx0] and m[1] not in keypoints_to_track_id[imidx1]:
-                # new pair of kpts never seen before
-                track_idx = len(track_id_to_kpt_list)
-                keypoints_to_track_id[imidx0][m[0]] = track_idx
-                keypoints_to_track_id[imidx1][m[1]] = track_idx
-                track_id_to_kpt_list.append(
-                    [(imidx0, m[0]), (imidx1, m[1])])
-            elif m[1] not in keypoints_to_track_id[imidx1]:
-                # 0 has a track, not 1
-                track_idx = keypoints_to_track_id[imidx0][m[0]]
-                keypoints_to_track_id[imidx1][m[1]] = track_idx
-                track_id_to_kpt_list[track_idx].append((imidx1, m[1]))
-            elif m[0] not in keypoints_to_track_id[imidx0]:
-                # 1 has a track, not 0
-                track_idx = keypoints_to_track_id[imidx1][m[1]]
-                keypoints_to_track_id[imidx0][m[0]] = track_idx
-                track_id_to_kpt_list[track_idx].append((imidx0, m[0]))
-            else:
-                # both have tracks, merge them
-                track_idx0 = keypoints_to_track_id[imidx0][m[0]]
-                track_idx1 = keypoints_to_track_id[imidx1][m[1]]
-                if track_idx0 != track_idx1:
-                    # let's deal with them later
-                    to_merge.append((track_idx0, track_idx1))
-
-    # regroup merge targets
-    print("merging tracks")
-    unique = np.unique(to_merge)
-    tree = DisjointSet(unique)
-    for track_idx0, track_idx1 in tqdm(to_merge):
-        tree.merge(track_idx0, track_idx1)
-
-    subsets = tree.subsets()
-    print("applying merge")
-    for setvals in tqdm(subsets):
-        new_trackid = len(track_id_to_kpt_list)
-        kpt_list = []
-        for track_idx in setvals:
-            kpt_list.extend(track_id_to_kpt_list[track_idx])
-            for imidx, kpid in track_id_to_kpt_list[track_idx]:
-                keypoints_to_track_id[imidx][kpid] = new_trackid
-        track_id_to_kpt_list.append(kpt_list)
-
-    # binc = np.bincount([len(v) for v in track_id_to_kpt_list])
-    # nonzero = np.nonzero(binc)
-    # nonzerobinc = binc[nonzero[0]]
-    # print(nonzero[0].tolist())
-    # print(nonzerobinc)
-    num_valid_tracks = sum(
-        [1 for v in track_id_to_kpt_list if len(v) >= min_len_track])
-
-    keypoints_to_idx = {}
-    print(f"squashing keypoints - {num_valid_tracks} valid tracks")
-    for imidx, keypoints_imid in tqdm(im_keypoints.items()):
-        imid = image_to_colmap[imidx]['colmap_imid']
-        keypoints_kept = []
-        keypoints_to_idx[imidx] = {}
-        for kp in keypoints_imid.keys():
-            if kp not in keypoints_to_track_id[imidx]:
-                continue
-            track_idx = keypoints_to_track_id[imidx][kp]
-            track_length = len(track_id_to_kpt_list[track_idx])
-            if track_length < min_len_track:
-                continue
-            keypoints_to_idx[imidx][kp] = len(keypoints_kept)
-            keypoints_kept.append(kp)
-        if len(keypoints_kept) == 0:
-            continue
-        keypoints_kept = np.array(keypoints_kept)
-        keypoints_kept = np.unravel_index(keypoints_kept, images[imidx]['true_shape'][0])[
-            0].base[:, ::-1].copy().astype(np.float32)
-        # rescale coordinates
-        keypoints_kept[:, 0] += 0.5
-        keypoints_kept[:, 1] += 0.5
-        keypoints_kept = geotrf(images[imidx]['to_orig'], keypoints_kept, norm=True)
-
-        H, W = images[imidx]['orig_shape']
-        keypoints_kept[:, 0] = keypoints_kept[:, 0].clip(min=0, max=W - 0.01)
-        keypoints_kept[:, 1] = keypoints_kept[:, 1].clip(min=0, max=H - 0.01)
-
-        db.add_keypoints(imid, keypoints_kept)
-
-    print("exporting im_matches")
-    for (imidx0, imidx1), colmap_matches in im_matches.items():
-        imid0, imid1 = image_to_colmap[imidx0]['colmap_imid'], image_to_colmap[imidx1]['colmap_imid']
-        assert imid0 < imid1
-        final_matches = np.array([[keypoints_to_idx[imidx0][m[0]], keypoints_to_idx[imidx1][m[1]]]
-                                  for m in colmap_matches
-                                  if m[0] in keypoints_to_idx[imidx0] and m[1] in keypoints_to_idx[imidx1]])
-        if len(final_matches) > 0:
-            colmap_image_pairs.append(
                (images[imidx0]['instance'], images[imidx1]['instance']))
-            db.add_matches(imid0, imid1, final_matches)
-            if skip_geometric_verification:
-                db.add_two_view_geometry(imid0, imid1, final_matches)
-    return colmap_image_pairs
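
The track-merging step in export_matches is plain union-find: whenever one keypoint already belongs to track A and its match to track B, the pair (A, B) is queued, and scipy.cluster.hierarchy.DisjointSet later collapses the connected components in near-linear time. A self-contained illustration of that pattern, using toy track ids rather than real data:

import numpy as np
from scipy.cluster.hierarchy import DisjointSet

# Toy example: tracks 0-4, with observed merges chaining 0-1-2 and 3-4.
to_merge = [(0, 1), (1, 2), (3, 4)]
tree = DisjointSet(np.unique(to_merge))   # np.unique flattens the pairs into the element set
for a, b in to_merge:
    tree.merge(a, b)
print([sorted(s) for s in tree.subsets()])   # e.g. [[0, 1, 2], [3, 4]]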
mast3r/datasets/__init__.py
DELETED
@@ -1,62 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-
-from .base.mast3r_base_stereo_view_dataset import MASt3RBaseStereoViewDataset
-
-import mast3r.utils.path_to_dust3r  # noqa
-from dust3r.datasets.arkitscenes import ARKitScenes as DUSt3R_ARKitScenes  # noqa
-from dust3r.datasets.blendedmvs import BlendedMVS as DUSt3R_BlendedMVS  # noqa
-from dust3r.datasets.co3d import Co3d as DUSt3R_Co3d  # noqa
-from dust3r.datasets.megadepth import MegaDepth as DUSt3R_MegaDepth  # noqa
-from dust3r.datasets.scannetpp import ScanNetpp as DUSt3R_ScanNetpp  # noqa
-from dust3r.datasets.staticthings3d import StaticThings3D as DUSt3R_StaticThings3D  # noqa
-from dust3r.datasets.waymo import Waymo as DUSt3R_Waymo  # noqa
-from dust3r.datasets.wildrgbd import WildRGBD as DUSt3R_WildRGBD  # noqa
-
-
-class ARKitScenes(DUSt3R_ARKitScenes, MASt3RBaseStereoViewDataset):
-    def __init__(self, *args, split, ROOT, **kwargs):
-        super().__init__(*args, split=split, ROOT=ROOT, **kwargs)
-        self.is_metric_scale = True
-
-
-class BlendedMVS(DUSt3R_BlendedMVS, MASt3RBaseStereoViewDataset):
-    def __init__(self, *args, ROOT, split=None, **kwargs):
-        super().__init__(*args, ROOT=ROOT, split=split, **kwargs)
-        self.is_metric_scale = False
-
-
-class Co3d(DUSt3R_Co3d, MASt3RBaseStereoViewDataset):
-    def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
-        super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs)
-        self.is_metric_scale = False
-
-
-class MegaDepth(DUSt3R_MegaDepth, MASt3RBaseStereoViewDataset):
-    def __init__(self, *args, split, ROOT, **kwargs):
-        super().__init__(*args, split=split, ROOT=ROOT, **kwargs)
-        self.is_metric_scale = False
-
-
-class ScanNetpp(DUSt3R_ScanNetpp, MASt3RBaseStereoViewDataset):
-    def __init__(self, *args, ROOT, **kwargs):
-        super().__init__(*args, ROOT=ROOT, **kwargs)
-        self.is_metric_scale = True
-
-
-class StaticThings3D(DUSt3R_StaticThings3D, MASt3RBaseStereoViewDataset):
-    def __init__(self, ROOT, *args, mask_bg='rand', **kwargs):
-        super().__init__(ROOT, *args, mask_bg=mask_bg, **kwargs)
-        self.is_metric_scale = False
-
-
-class Waymo(DUSt3R_Waymo, MASt3RBaseStereoViewDataset):
-    def __init__(self, *args, ROOT, **kwargs):
-        super().__init__(*args, ROOT=ROOT, **kwargs)
-        self.is_metric_scale = True
-
-
-class WildRGBD(DUSt3R_WildRGBD, MASt3RBaseStereoViewDataset):
-    def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
-        super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs)
-        self.is_metric_scale = True
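
Each wrapper above pairs a DUSt3R dataset with MASt3RBaseStereoViewDataset via multiple inheritance: the DUSt3R loader keeps providing the views, the MASt3R base (next in the MRO) contributes the cropping and correspondence logic, and the subclass only pins down is_metric_scale. The same pattern in miniature, with hypothetical stand-in classes for illustration only:

class Loader:                          # stands in for a DUSt3R dataset
    def _get_views(self, idx):
        return f"views[{idx}]"

class Base:                            # stands in for MASt3RBaseStereoViewDataset
    def __getitem__(self, idx):
        return self._get_views(idx)    # resolved on the subclass via the MRO

class MyDataset(Loader, Base):
    def __init__(self):
        self.is_metric_scale = True

print(MyDataset()[3])   # "views[3]"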
mast3r/datasets/base/__init__.py
DELETED
@@ -1,2 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/datasets/base/mast3r_base_stereo_view_dataset.py
DELETED
@@ -1,355 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# base class for implementing datasets
-# --------------------------------------------------------
-import PIL.Image
-import PIL.Image as Image
-import numpy as np
-import torch
-import copy
-
-from mast3r.datasets.utils.cropping import (extract_correspondences_from_pts3d,
-                                            gen_random_crops, in2d_rect, crop_to_homography)
-
-import mast3r.utils.path_to_dust3r  # noqa
-from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset, view_name, is_good_type  # noqa
-from dust3r.datasets.utils.transforms import ImgNorm
-from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates, geotrf, depthmap_to_camera_coordinates
-import dust3r.datasets.utils.cropping as cropping
-
-
-class MASt3RBaseStereoViewDataset(BaseStereoViewDataset):
-    def __init__(self, *,  # only keyword arguments
-                 split=None,
-                 resolution=None,  # square_size or (width, height) or list of [(width,height), ...]
-                 transform=ImgNorm,
-                 aug_crop=False,
-                 aug_swap=False,
-                 aug_monocular=False,
-                 aug_portrait_or_landscape=True,  # automatic choice between landscape/portrait when possible
-                 aug_rot90=False,
-                 n_corres=0,
-                 nneg=0,
-                 n_tentative_crops=4,
-                 seed=None):
-        super().__init__(split=split, resolution=resolution, transform=transform, aug_crop=aug_crop, seed=seed)
-        self.is_metric_scale = False  # by default a dataset is not metric scale, subclasses can overwrite this
-
-        self.aug_swap = aug_swap
-        self.aug_monocular = aug_monocular
-        self.aug_portrait_or_landscape = aug_portrait_or_landscape
-        self.aug_rot90 = aug_rot90
-
-        self.n_corres = n_corres
-        self.nneg = nneg
-        assert self.n_corres == 'all' or isinstance(self.n_corres, int) or (isinstance(self.n_corres, list) and len(
-            self.n_corres) == self.num_views), f"Error, n_corres should either be 'all', a single integer or a list of length {self.num_views}"
-        assert self.nneg == 0 or self.n_corres != 'all'
-        self.n_tentative_crops = n_tentative_crops
-
-    def _swap_view_aug(self, views):
-        if self._rng.random() < 0.5:
-            views.reverse()
-
-    def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None):
-        """ This function:
-            - first downsizes the image with LANCZOS interpolation,
-              which is better than bilinear interpolation for downscaling
-        """
-        if not isinstance(image, PIL.Image.Image):
-            image = PIL.Image.fromarray(image)
-
-        # transpose the resolution if necessary
-        W, H = image.size  # new size
-        assert resolution[0] >= resolution[1]
-        if H > 1.1 * W:
-            # image is portrait mode
-            resolution = resolution[::-1]
-        elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
-            # image is square, so we chose (portrait, landscape) randomly
-            if rng.integers(2) and self.aug_portrait_or_landscape:
-                resolution = resolution[::-1]
-
-        # high-quality Lanczos down-scaling
-        target_resolution = np.array(resolution)
-        image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
-
-        # actual cropping (if necessary) with bilinear interpolation
-        offset_factor = 0.5
-        intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=offset_factor)
-        crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
-        image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
-
-        return image, depthmap, intrinsics2
-
-    def generate_crops_from_pair(self, view1, view2, resolution, aug_crop_arg, n_crops=4, rng=np.random):
-        views = [view1, view2]
-
-        if aug_crop_arg is False:
-            # compatibility
-            for i in range(2):
-                view = views[i]
-                view['img'], view['depthmap'], view['camera_intrinsics'] = self._crop_resize_if_necessary(view['img'],
-                                                                                                          view['depthmap'],
-                                                                                                          view['camera_intrinsics'],
-                                                                                                          resolution,
-                                                                                                          rng=rng)
-                view['pts3d'], view['valid_mask'] = depthmap_to_absolute_camera_coordinates(view['depthmap'],
-                                                                                            view['camera_intrinsics'],
-                                                                                            view['camera_pose'])
-            return
-
-        # extract correspondences
-        corres = extract_correspondences_from_pts3d(*views, target_n_corres=None, rng=rng)
-
-        # generate 4 random crops in each view
-        view_crops = []
-        crops_resolution = []
-        corres_msks = []
-        for i in range(2):
-
-            if aug_crop_arg == 'auto':
-                S = min(views[i]['img'].size)
-                R = min(resolution)
-                aug_crop = S * (S - R) // R
-                aug_crop = max(.1 * S, aug_crop)  # for cropping: augment scale of at least 10%, and more if possible
-            else:
-                aug_crop = aug_crop_arg
-
-            # transpose the target resolution if necessary
-            assert resolution[0] >= resolution[1]
-            W, H = imsize = views[i]['img'].size
-            crop_resolution = resolution
-            if H > 1.1 * W:
-                # image is portrait mode
-                crop_resolution = resolution[::-1]
-            elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
-                # image is square, so we chose (portrait, landscape) randomly
-                if rng.integers(2):
-                    crop_resolution = resolution[::-1]
-
-            crops = gen_random_crops(imsize, n_crops, crop_resolution, aug_crop=aug_crop, rng=rng)
-            view_crops.append(crops)
-            crops_resolution.append(crop_resolution)
-
-            # compute correspondences
-            corres_msks.append(in2d_rect(corres[i], crops))
-
-        # compute IoU for each
-        intersection = np.float32(corres_msks[0]).T @ np.float32(corres_msks[1])
-        # select best pair of crops
-        best = np.unravel_index(intersection.argmax(), (n_crops, n_crops))
-        crops = [view_crops[i][c] for i, c in enumerate(best)]
-
-        # crop with the homography
-        for i in range(2):
-            view = views[i]
-            imsize, K_new, R, H = crop_to_homography(view['camera_intrinsics'], crops[i], crops_resolution[i])
-            # imsize, K_new, H = upscale_homography(imsize, resolution, K_new, H)
-
-            # update camera params
-            K_old = view['camera_intrinsics']
-            view['camera_intrinsics'] = K_new
-            view['camera_pose'] = view['camera_pose'].copy()
-            view['camera_pose'][:3, :3] = view['camera_pose'][:3, :3] @ R
-
-            # apply homography to image and depthmap
-            homo8 = (H / H[2, 2]).ravel().tolist()[:8]
-            view['img'] = view['img'].transform(imsize, Image.Transform.PERSPECTIVE,
-                                                homo8,
-                                                resample=Image.Resampling.BICUBIC)
-
-            depthmap2 = depthmap_to_camera_coordinates(view['depthmap'], K_old)[0] @ R[:, 2]
-            view['depthmap'] = np.array(Image.fromarray(depthmap2).transform(
-                imsize, Image.Transform.PERSPECTIVE, homo8))
-
-            if 'track_labels' in view:
-                # convert from uint64 --> uint32, because PIL.Image cannot handle uint64
-                mapping, track_labels = np.unique(view['track_labels'], return_inverse=True)
-                track_labels = track_labels.astype(np.uint32).reshape(view['track_labels'].shape)
-
-                # homography transformation
-                res = np.array(Image.fromarray(track_labels).transform(imsize, Image.Transform.PERSPECTIVE, homo8))
-                view['track_labels'] = mapping[res]  # mapping back to uint64
-
-            # recompute 3d points from scratch
-            view['pts3d'], view['valid_mask'] = depthmap_to_absolute_camera_coordinates(view['depthmap'],
-                                                                                        view['camera_intrinsics'],
-                                                                                        view['camera_pose'])
-
-    def __getitem__(self, idx):
-        if isinstance(idx, tuple):
-            # the idx is specifying the aspect-ratio
-            idx, ar_idx = idx
-        else:
-            assert len(self._resolutions) == 1
-            ar_idx = 0
-
-        # set-up the rng
-        if self.seed:  # reseed for each __getitem__
-            self._rng = np.random.default_rng(seed=self.seed + idx)
-        elif not hasattr(self, '_rng'):
-            seed = torch.initial_seed()  # this is different for each dataloader process
-            self._rng = np.random.default_rng(seed=seed)
-
-        # over-loaded code
-        resolution = self._resolutions[ar_idx]  # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
-        views = self._get_views(idx, resolution, self._rng)
|
| 200 |
-
assert len(views) == self.num_views
|
| 201 |
-
|
| 202 |
-
for v, view in enumerate(views):
|
| 203 |
-
assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
|
| 204 |
-
view['idx'] = (idx, ar_idx, v)
|
| 205 |
-
view['is_metric_scale'] = self.is_metric_scale
|
| 206 |
-
|
| 207 |
-
assert 'camera_intrinsics' in view
|
| 208 |
-
if 'camera_pose' not in view:
|
| 209 |
-
view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32)
|
| 210 |
-
else:
|
| 211 |
-
assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}'
|
| 212 |
-
assert 'pts3d' not in view
|
| 213 |
-
assert 'valid_mask' not in view
|
| 214 |
-
assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}'
|
| 215 |
-
|
| 216 |
-
pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
|
| 217 |
-
|
| 218 |
-
view['pts3d'] = pts3d
|
| 219 |
-
view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
|
| 220 |
-
|
| 221 |
-
self.generate_crops_from_pair(views[0], views[1], resolution=resolution,
|
| 222 |
-
aug_crop_arg=self.aug_crop,
|
| 223 |
-
n_crops=self.n_tentative_crops,
|
| 224 |
-
rng=self._rng)
|
| 225 |
-
for v, view in enumerate(views):
|
| 226 |
-
# encode the image
|
| 227 |
-
width, height = view['img'].size
|
| 228 |
-
view['true_shape'] = np.int32((height, width))
|
| 229 |
-
view['img'] = self.transform(view['img'])
|
| 230 |
-
# Pixels for which depth is fundamentally undefined
|
| 231 |
-
view['sky_mask'] = (view['depthmap'] < 0)
|
| 232 |
-
|
| 233 |
-
if self.aug_swap:
|
| 234 |
-
self._swap_view_aug(views)
|
| 235 |
-
|
| 236 |
-
if self.aug_monocular:
|
| 237 |
-
if self._rng.random() < self.aug_monocular:
|
| 238 |
-
views = [copy.deepcopy(views[0]) for _ in range(len(views))]
|
| 239 |
-
|
| 240 |
-
# automatic extraction of correspondences from pts3d + pose
|
| 241 |
-
if self.n_corres > 0 and ('corres' not in view):
|
| 242 |
-
corres1, corres2, valid = extract_correspondences_from_pts3d(*views, self.n_corres,
|
| 243 |
-
self._rng, nneg=self.nneg)
|
| 244 |
-
views[0]['corres'] = corres1
|
| 245 |
-
views[1]['corres'] = corres2
|
| 246 |
-
views[0]['valid_corres'] = valid
|
| 247 |
-
views[1]['valid_corres'] = valid
|
| 248 |
-
|
| 249 |
-
if self.aug_rot90 is False:
|
| 250 |
-
pass
|
| 251 |
-
elif self.aug_rot90 == 'same':
|
| 252 |
-
rotate_90(views, k=self._rng.choice(4))
|
| 253 |
-
elif self.aug_rot90 == 'diff':
|
| 254 |
-
rotate_90(views[:1], k=self._rng.choice(4))
|
| 255 |
-
rotate_90(views[1:], k=self._rng.choice(4))
|
| 256 |
-
else:
|
| 257 |
-
raise ValueError(f'Bad value for {self.aug_rot90=}')
|
| 258 |
-
|
| 259 |
-
# check data-types metric_scale
|
| 260 |
-
for v, view in enumerate(views):
|
| 261 |
-
if 'corres' not in view:
|
| 262 |
-
view['corres'] = np.full((self.n_corres, 2), np.nan, dtype=np.float32)
|
| 263 |
-
|
| 264 |
-
# check all datatypes
|
| 265 |
-
for key, val in view.items():
|
| 266 |
-
res, err_msg = is_good_type(key, val)
|
| 267 |
-
assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
|
| 268 |
-
K = view['camera_intrinsics']
|
| 269 |
-
|
| 270 |
-
# check shapes
|
| 271 |
-
assert view['depthmap'].shape == view['img'].shape[1:]
|
| 272 |
-
assert view['depthmap'].shape == view['pts3d'].shape[:2]
|
| 273 |
-
assert view['depthmap'].shape == view['valid_mask'].shape
|
| 274 |
-
|
| 275 |
-
# last thing done!
|
| 276 |
-
for view in views:
|
| 277 |
-
# transpose to make sure all views are the same size
|
| 278 |
-
transpose_to_landscape(view)
|
| 279 |
-
# this allows to check whether the RNG is is the same state each time
|
| 280 |
-
view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
|
| 281 |
-
|
| 282 |
-
return views
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
def transpose_to_landscape(view, revert=False):
|
| 286 |
-
height, width = view['true_shape']
|
| 287 |
-
|
| 288 |
-
if width < height:
|
| 289 |
-
if revert:
|
| 290 |
-
height, width = width, height
|
| 291 |
-
|
| 292 |
-
# rectify portrait to landscape
|
| 293 |
-
assert view['img'].shape == (3, height, width)
|
| 294 |
-
view['img'] = view['img'].swapaxes(1, 2)
|
| 295 |
-
|
| 296 |
-
assert view['valid_mask'].shape == (height, width)
|
| 297 |
-
view['valid_mask'] = view['valid_mask'].swapaxes(0, 1)
|
| 298 |
-
|
| 299 |
-
assert view['sky_mask'].shape == (height, width)
|
| 300 |
-
view['sky_mask'] = view['sky_mask'].swapaxes(0, 1)
|
| 301 |
-
|
| 302 |
-
assert view['depthmap'].shape == (height, width)
|
| 303 |
-
view['depthmap'] = view['depthmap'].swapaxes(0, 1)
|
| 304 |
-
|
| 305 |
-
assert view['pts3d'].shape == (height, width, 3)
|
| 306 |
-
view['pts3d'] = view['pts3d'].swapaxes(0, 1)
|
| 307 |
-
|
| 308 |
-
# transpose x and y pixels
|
| 309 |
-
view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]]
|
| 310 |
-
|
| 311 |
-
# transpose correspondences x and y
|
| 312 |
-
view['corres'] = view['corres'][:, [1, 0]]
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
def rotate_90(views, k=1):
|
| 316 |
-
from scipy.spatial.transform import Rotation
|
| 317 |
-
# print('rotation =', k)
|
| 318 |
-
|
| 319 |
-
RT = np.eye(4, dtype=np.float32)
|
| 320 |
-
RT[:3, :3] = Rotation.from_euler('z', 90 * k, degrees=True).as_matrix()
|
| 321 |
-
|
| 322 |
-
for view in views:
|
| 323 |
-
view['img'] = torch.rot90(view['img'], k=k, dims=(-2, -1)) # WARNING!! dims=(-1,-2) != dims=(-2,-1)
|
| 324 |
-
view['depthmap'] = np.rot90(view['depthmap'], k=k).copy()
|
| 325 |
-
view['camera_pose'] = view['camera_pose'] @ RT
|
| 326 |
-
|
| 327 |
-
RT2 = np.eye(3, dtype=np.float32)
|
| 328 |
-
RT2[:2, :2] = RT[:2, :2] * ((1, -1), (-1, 1))
|
| 329 |
-
H, W = view['depthmap'].shape
|
| 330 |
-
if k % 4 == 0:
|
| 331 |
-
pass
|
| 332 |
-
elif k % 4 == 1:
|
| 333 |
-
# top-left (0,0) pixel becomes (0,H-1)
|
| 334 |
-
RT2[:2, 2] = (0, H - 1)
|
| 335 |
-
elif k % 4 == 2:
|
| 336 |
-
# top-left (0,0) pixel becomes (W-1,H-1)
|
| 337 |
-
RT2[:2, 2] = (W - 1, H - 1)
|
| 338 |
-
elif k % 4 == 3:
|
| 339 |
-
# top-left (0,0) pixel becomes (W-1,0)
|
| 340 |
-
RT2[:2, 2] = (W - 1, 0)
|
| 341 |
-
else:
|
| 342 |
-
raise ValueError(f'Bad value for {k=}')
|
| 343 |
-
|
| 344 |
-
view['camera_intrinsics'][:2, 2] = geotrf(RT2, view['camera_intrinsics'][:2, 2])
|
| 345 |
-
if k % 2 == 1:
|
| 346 |
-
K = view['camera_intrinsics']
|
| 347 |
-
np.fill_diagonal(K, K.diagonal()[[1, 0, 2]])
|
| 348 |
-
|
| 349 |
-
pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
|
| 350 |
-
view['pts3d'] = pts3d
|
| 351 |
-
view['valid_mask'] = np.rot90(view['valid_mask'], k=k).copy()
|
| 352 |
-
view['sky_mask'] = np.rot90(view['sky_mask'], k=k).copy()
|
| 353 |
-
|
| 354 |
-
view['corres'] = geotrf(RT2, view['corres']).round().astype(view['corres'].dtype)
|
| 355 |
-
view['true_shape'] = np.int32((H, W))
|
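
Note: the crop-pair selection in generate_crops_from_pair above boils down to a few lines of NumPy. Below is a minimal, self-contained sketch with random stand-in masks (in the deleted code the masks come from in2d_rect in mast3r/datasets/utils/cropping.py; the names here are illustrative only):

import numpy as np

# Each tentative crop gets a boolean mask over the correspondences it contains;
# the pair of crops (one per view) sharing the most correspondences is kept.
n_corres, n_crops = 1000, 4
rng = np.random.default_rng(0)
msk1 = rng.random((n_corres, n_crops)) < 0.5  # corres inside each crop of view1
msk2 = rng.random((n_corres, n_crops)) < 0.5  # corres inside each crop of view2

# count correspondences visible in both crops, for every pair of crops
intersection = np.float32(msk1).T @ np.float32(msk2)  # (n_crops, n_crops)
best = np.unravel_index(intersection.argmax(), (n_crops, n_crops))
print('best crop pair:', best, '->', int(intersection[best]), 'shared correspondences')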
mast3r/datasets/utils/__init__.py
DELETED
@@ -1,2 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/datasets/utils/cropping.py
DELETED
@@ -1,219 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# cropping/match extraction
-# --------------------------------------------------------
-import numpy as np
-import mast3r.utils.path_to_dust3r  # noqa
-from dust3r.utils.device import to_numpy
-from dust3r.utils.geometry import inv, geotrf
-
-
-def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
-    is_reciprocal1 = (corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2)))
-    pos1 = is_reciprocal1.nonzero()[0]
-    pos2 = corres_1_to_2[pos1]
-    if ret_recip:
-        return is_reciprocal1, pos1, pos2
-    return pos1, pos2
-
-
-def extract_correspondences_from_pts3d(view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0):
-    view1, view2 = to_numpy((view1, view2))
-    # project pixels from image1 --> 3d points --> image2 pixels
-    shape1, corres1_to_2 = reproject_view(view1['pts3d'], view2)
-    shape2, corres2_to_1 = reproject_view(view2['pts3d'], view1)
-
-    # compute reciprocal correspondences:
-    # pos1 == valid pixels (correspondences) in image1
-    is_reciprocal1, pos1, pos2 = reciprocal_1d(corres1_to_2, corres2_to_1, ret_recip=True)
-    is_reciprocal2 = (corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1)))
-
-    if target_n_corres is None:
-        if ret_xy:
-            pos1 = unravel_xy(pos1, shape1)
-            pos2 = unravel_xy(pos2, shape2)
-        return pos1, pos2
-
-    available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum())
-    target_n_positives = int(target_n_corres * (1 - nneg))
-    n_positives = min(len(pos1), target_n_positives)
-    n_negatives = min(target_n_corres - n_positives, available_negatives)
-
-    if n_negatives + n_positives != target_n_corres:
-        # should be really rare => when there are not enough negatives
-        # in that case, break nneg and add a few more positives ?
-        n_positives = target_n_corres - n_negatives
-        assert n_positives <= len(pos1)
-
-    assert n_positives <= len(pos1)
-    assert n_positives <= len(pos2)
-    assert n_negatives <= (~is_reciprocal1).sum()
-    assert n_negatives <= (~is_reciprocal2).sum()
-    assert n_positives + n_negatives == target_n_corres
-
-    valid = np.ones(n_positives, dtype=bool)
-    if n_positives < len(pos1):
-        # random sub-sampling of valid correspondences
-        perm = rng.permutation(len(pos1))[:n_positives]
-        pos1 = pos1[perm]
-        pos2 = pos2[perm]
-
-    if n_negatives > 0:
-        # add false correspondences if not enough
-        def norm(p): return p / p.sum()
-        pos1 = np.r_[pos1, rng.choice(shape1[0] * shape1[1], size=n_negatives, replace=False, p=norm(~is_reciprocal1))]
-        pos2 = np.r_[pos2, rng.choice(shape2[0] * shape2[1], size=n_negatives, replace=False, p=norm(~is_reciprocal2))]
-        valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]
-
-    # convert (x+W*y) back to 2d (x,y) coordinates
-    if ret_xy:
-        pos1 = unravel_xy(pos1, shape1)
-        pos2 = unravel_xy(pos2, shape2)
-    return pos1, pos2, valid
-
-
-def reproject_view(pts3d, view2):
-    shape = view2['pts3d'].shape[:2]
-    return reproject(pts3d, view2['camera_intrinsics'], inv(view2['camera_pose']), shape)
-
-
-def reproject(pts3d, K, world2cam, shape):
-    H, W, THREE = pts3d.shape
-    assert THREE == 3
-
-    # reproject in camera2 space
-    with np.errstate(divide='ignore', invalid='ignore'):
-        pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2)
-
-    # quantize to pixel positions
-    return (H, W), ravel_xy(pos, shape)
-
-
-def ravel_xy(pos, shape):
-    H, W = shape
-    with np.errstate(invalid='ignore'):
-        qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T
-    quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(min=0, max=H - 1, out=qy)
-    return quantized_pos
-
-
-def unravel_xy(pos, shape):
-    # convert (x+W*y) back to 2d (x,y) coordinates
-    return np.unravel_index(pos, shape)[0].base[:, ::-1].copy()
-
-
-def _rotation_origin_to_pt(target):
-    """ Align the origin (0,0,1) with the target point (x,y,1) in projective space.
-    Method: rotate z to put target on (x'+,0,1), then rotate on Y to get (0,0,1) and un-rotate z.
-    """
-    from scipy.spatial.transform import Rotation
-    x, y = target
-    rot_z = np.arctan2(y, x)
-    rot_y = np.arctan(np.linalg.norm(target))
-    R = Rotation.from_euler('ZYZ', [rot_z, rot_y, -rot_z]).as_matrix()
-    return R
-
-
-def _dotmv(Trf, pts, ncol=None, norm=False):
-    assert Trf.ndim >= 2
-    ncol = ncol or pts.shape[-1]
-
-    # adapt shape if necessary
-    output_reshape = pts.shape[:-1]
-    if Trf.ndim >= 3:
-        n = Trf.ndim - 2
-        assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
-        Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
-
-        if pts.ndim > Trf.ndim:
-            # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
-            pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
-        elif pts.ndim == 2:
-            # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
-            pts = pts[:, None, :]
-
-    if pts.shape[-1] + 1 == Trf.shape[-1]:
-        Trf = Trf.swapaxes(-1, -2)  # transpose Trf
-        pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
-
-    elif pts.shape[-1] == Trf.shape[-1]:
-        Trf = Trf.swapaxes(-1, -2)  # transpose Trf
-        pts = pts @ Trf
-    else:
-        pts = Trf @ pts.T
-        if pts.ndim >= 2:
-            pts = pts.swapaxes(-1, -2)
-
-    if norm:
-        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
-        if norm != 1:
-            pts *= norm
-
-    res = pts[..., :ncol].reshape(*output_reshape, ncol)
-    return res
-
-
-def crop_to_homography(K, crop, target_size=None):
-    """ Given an image and its intrinsics,
-    we want to replicate a rectangular crop with a homography,
-    so that the principal point of the new 'crop' is centered.
-    """
-    # build intrinsics for the crop
-    crop = np.round(crop)
-    crop_size = crop[2:] - crop[:2]
-    K2 = K.copy()  # same focal
-    K2[:2, 2] = crop_size / 2  # new principal point is perfectly centered
-
-    # find which corner is farthest from the current principal point
-    # so that the final homography does not go over the image borders
-    corners = crop.reshape(-1, 2)
-    corner_idx = np.abs(corners - K[:2, 2]).argmax(0)
-    corner = corners[corner_idx, [0, 1]]
-    # align with the corresponding corner from the target view
-    corner2 = np.c_[[0, 0], crop_size][[0, 1], corner_idx]
-
-    old_pt = _dotmv(np.linalg.inv(K), corner, norm=1)
-    new_pt = _dotmv(np.linalg.inv(K2), corner2, norm=1)
-    R = _rotation_origin_to_pt(old_pt) @ np.linalg.inv(_rotation_origin_to_pt(new_pt))
-
-    if target_size is not None:
-        imsize = target_size
-        target_size = np.asarray(target_size)
-        scaling = min(target_size / crop_size)
-        K2[:2] *= scaling
-        K2[:2, 2] = target_size / 2
-    else:
-        imsize = tuple(np.int32(crop_size).tolist())
-
-    return imsize, K2, R, K @ R @ np.linalg.inv(K2)
-
-
-def gen_random_crops(imsize, n_crops, resolution, aug_crop, rng=np.random):
-    """ Generate random crops of size=resolution,
-    for an input image upscaled to (imsize + randint(0 , aug_crop))
-    """
-    resolution_crop = np.array(resolution) * min(np.array(imsize) / resolution)
-
-    # (virtually) upscale the input image
-    # scaling = rng.uniform(1, 1+(aug_crop+1)/min(imsize))
-    scaling = np.exp(rng.uniform(0, np.log(1 + aug_crop / min(imsize))))
-    imsize2 = np.int32(np.array(imsize) * scaling)
-
-    # generate some random crops
-    topleft = rng.random((n_crops, 2)) * (imsize2 - resolution_crop)
-    crops = np.c_[topleft, topleft + resolution_crop]
-    # print(f"{scaling=}, {topleft=}")
-    # reduce the resolution to come back to original size
-    crops /= scaling
-    return crops
-
-
-def in2d_rect(corres, crops):
-    # corres = (N,2)
-    # crops = (M,4)
-    # output = (N, M)
-    is_sup = (corres[:, None] >= crops[None, :, 0:2])
-    is_inf = (corres[:, None] < crops[None, :, 2:4])
-    return (is_sup & is_inf).all(axis=-1)
mast3r/fast_nn.py
DELETED
@@ -1,221 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# MASt3R Fast Nearest Neighbor
-# --------------------------------------------------------
-import torch
-import numpy as np
-import math
-from scipy.spatial import KDTree
-
-import mast3r.utils.path_to_dust3r  # noqa
-from dust3r.utils.device import to_numpy, todevice  # noqa
-
-
-@torch.no_grad()
-def bruteforce_reciprocal_nns(A, B, device='cuda', block_size=None, dist='l2'):
-    if isinstance(A, np.ndarray):
-        A = torch.from_numpy(A).to(device)
-    if isinstance(B, np.ndarray):
-        B = torch.from_numpy(B).to(device)
-
-    A = A.to(device)
-    B = B.to(device)
-
-    if dist == 'l2':
-        dist_func = torch.cdist
-        argmin = torch.min
-    elif dist == 'dot':
-        def dist_func(A, B):
-            return A @ B.T
-
-        def argmin(X, dim):
-            sim, nn = torch.max(X, dim=dim)
-            return sim.neg_(), nn
-    else:
-        raise ValueError(f'Unknown {dist=}')
-
-    if block_size is None or len(A) * len(B) <= block_size**2:
-        dists = dist_func(A, B)
-        _, nn_A = argmin(dists, dim=1)
-        _, nn_B = argmin(dists, dim=0)
-    else:
-        dis_A = torch.full((A.shape[0],), float('inf'), device=device, dtype=A.dtype)
-        dis_B = torch.full((B.shape[0],), float('inf'), device=device, dtype=B.dtype)
-        nn_A = torch.full((A.shape[0],), -1, device=device, dtype=torch.int64)
-        nn_B = torch.full((B.shape[0],), -1, device=device, dtype=torch.int64)
-        number_of_iteration_A = math.ceil(A.shape[0] / block_size)
-        number_of_iteration_B = math.ceil(B.shape[0] / block_size)
-
-        for i in range(number_of_iteration_A):
-            A_i = A[i * block_size:(i + 1) * block_size]
-            for j in range(number_of_iteration_B):
-                B_j = B[j * block_size:(j + 1) * block_size]
-                dists_blk = dist_func(A_i, B_j)  # A, B, 1
-                # dists_blk = dists[i * block_size:(i+1)*block_size, j * block_size:(j+1)*block_size]
-                min_A_i, argmin_A_i = argmin(dists_blk, dim=1)
-                min_B_j, argmin_B_j = argmin(dists_blk, dim=0)
-
-                col_mask = min_A_i < dis_A[i * block_size:(i + 1) * block_size]
-                line_mask = min_B_j < dis_B[j * block_size:(j + 1) * block_size]
-
-                dis_A[i * block_size:(i + 1) * block_size][col_mask] = min_A_i[col_mask]
-                dis_B[j * block_size:(j + 1) * block_size][line_mask] = min_B_j[line_mask]
-
-                nn_A[i * block_size:(i + 1) * block_size][col_mask] = argmin_A_i[col_mask] + (j * block_size)
-                nn_B[j * block_size:(j + 1) * block_size][line_mask] = argmin_B_j[line_mask] + (i * block_size)
-    nn_A = nn_A.cpu().numpy()
-    nn_B = nn_B.cpu().numpy()
-    return nn_A, nn_B
-
-
-class cdistMatcher:
-    def __init__(self, db_pts, device='cuda'):
-        self.db_pts = db_pts.to(device)
-        self.device = device
-
-    def query(self, queries, k=1, **kw):
-        assert k == 1
-        if queries.numel() == 0:
-            return None, []
-        nnA, nnB = bruteforce_reciprocal_nns(queries, self.db_pts, device=self.device, **kw)
-        dis = None
-        return dis, nnA
-
-
-def merge_corres(idx1, idx2, shape1=None, shape2=None, ret_xy=True, ret_index=False):
-    assert idx1.dtype == idx2.dtype == np.int32
-
-    # unique and sort along idx1
-    corres = np.unique(np.c_[idx2, idx1].view(np.int64), return_index=ret_index)
-    if ret_index:
-        corres, indices = corres
-    xy2, xy1 = corres[:, None].view(np.int32).T
-
-    if ret_xy:
-        assert shape1 and shape2
-        xy1 = np.unravel_index(xy1, shape1)
-        xy2 = np.unravel_index(xy2, shape2)
-        if ret_xy != 'y_x':
-            xy1 = xy1[0].base[:, ::-1]  # from (y, x) to (x, y)
-            xy2 = xy2[0].base[:, ::-1]
-
-    if ret_index:
-        return xy1, xy2, indices
-    return xy1, xy2
-
-
-def fast_reciprocal_NNs(pts1, pts2, subsample_or_initxy1=8, ret_xy=True, pixel_tol=0, ret_basin=False,
-                        device='cuda', **matcher_kw):
-    H1, W1, DIM1 = pts1.shape
-    H2, W2, DIM2 = pts2.shape
-    assert DIM1 == DIM2
-
-    pts1 = pts1.reshape(-1, DIM1)
-    pts2 = pts2.reshape(-1, DIM2)
-
-    if isinstance(subsample_or_initxy1, int) and pixel_tol == 0:
-        S = subsample_or_initxy1
-        y1, x1 = np.mgrid[S // 2:H1:S, S // 2:W1:S].reshape(2, -1)
-        max_iter = 10
-    else:
-        x1, y1 = subsample_or_initxy1
-        if isinstance(x1, torch.Tensor):
-            x1 = x1.cpu().numpy()
-        if isinstance(y1, torch.Tensor):
-            y1 = y1.cpu().numpy()
-        max_iter = 1
-
-    xy1 = np.int32(np.unique(x1 + W1 * y1))  # make sure there are no duplicates
-    xy2 = np.full_like(xy1, -1)
-    old_xy1 = xy1.copy()
-    old_xy2 = xy2.copy()
-
-    if (isinstance(device, str) and device.startswith('cuda')) or (isinstance(device, torch.device) and device.type.startswith('cuda')):
-        pts1 = pts1.to(device)
-        pts2 = pts2.to(device)
-        tree1 = cdistMatcher(pts1, device=device)
-        tree2 = cdistMatcher(pts2, device=device)
-    else:
-        pts1, pts2 = to_numpy((pts1, pts2))
-        tree1 = KDTree(pts1)
-        tree2 = KDTree(pts2)
-
-    notyet = np.ones(len(xy1), dtype=bool)
-    basin = np.full((H1 * W1 + 1,), -1, dtype=np.int32) if ret_basin else None
-
-    niter = 0
-    # n_notyet = [len(notyet)]
-    while notyet.any():
-        _, xy2[notyet] = to_numpy(tree2.query(pts1[xy1[notyet]], **matcher_kw))
-        if not ret_basin:
-            notyet &= (old_xy2 != xy2)  # remove points that have converged
-
-        _, xy1[notyet] = to_numpy(tree1.query(pts2[xy2[notyet]], **matcher_kw))
-        if ret_basin:
-            basin[old_xy1[notyet]] = xy1[notyet]
-        notyet &= (old_xy1 != xy1)  # remove points that have converged
-
-        # n_notyet.append(notyet.sum())
-        niter += 1
-        if niter >= max_iter:
-            break
-
-        old_xy2[:] = xy2
-        old_xy1[:] = xy1
-
-    # print('notyet_stats:', ' '.join(map(str, (n_notyet+[0]*10)[:max_iter])))
-
-    if pixel_tol > 0:
-        # in case we only want to match some specific points
-        # and still have some way of checking reciprocity
-        old_yx1 = np.unravel_index(old_xy1, (H1, W1))[0].base
-        new_yx1 = np.unravel_index(xy1, (H1, W1))[0].base
-        dis = np.linalg.norm(old_yx1 - new_yx1, axis=-1)
-        converged = dis < pixel_tol
-        if not isinstance(subsample_or_initxy1, int):
-            xy1 = old_xy1  # replace new points by old ones
-    else:
-        converged = ~notyet  # converged correspondences
-
-    # keep only unique correspondences, and sort on xy1
-    xy1, xy2 = merge_corres(xy1[converged], xy2[converged], (H1, W1), (H2, W2), ret_xy=ret_xy)
-    if ret_basin:
-        return xy1, xy2, basin
-    return xy1, xy2
-
-
-def extract_correspondences_nonsym(A, B, confA, confB, subsample=8, device=None, ptmap_key='pred_desc', pixel_tol=0):
-    if '3d' in ptmap_key:
-        opt = dict(device='cpu', workers=32)
-    else:
-        opt = dict(device=device, dist='dot', block_size=2**13)
-
-    # matching the two pairs
-    idx1 = []
-    idx2 = []
-    # merge corres from opposite pairs
-    HA, WA = A.shape[:2]
-    HB, WB = B.shape[:2]
-    if pixel_tol == 0:
-        nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=subsample, ret_xy=False, **opt)
-        nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=subsample, ret_xy=False, **opt)
-    else:
-        S = subsample
-        yA, xA = np.mgrid[S // 2:HA:S, S // 2:WA:S].reshape(2, -1)
-        yB, xB = np.mgrid[S // 2:HB:S, S // 2:WB:S].reshape(2, -1)
-
-        nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=(xA, yA), ret_xy=False, pixel_tol=pixel_tol, **opt)
-        nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=(xB, yB), ret_xy=False, pixel_tol=pixel_tol, **opt)
-
-    idx1 = np.r_[nn1to2[0], nn2to1[1]]
-    idx2 = np.r_[nn1to2[1], nn2to1[0]]
-
-    c1 = confA.ravel()[idx1]
-    c2 = confB.ravel()[idx2]
-
-    xy1, xy2, idx = merge_corres(idx1, idx2, (HA, WA), (HB, WB), ret_xy=True, ret_index=True)
-    conf = np.minimum(c1[idx], c2[idx])
-    corres = (xy1.copy(), xy2.copy(), conf)
-    return todevice(corres, device)
mast3r/losses.py
DELETED
|
@@ -1,514 +0,0 @@
|
|
| 1 |
-
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
-
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
-
#
|
| 4 |
-
# --------------------------------------------------------
|
| 5 |
-
# Implementation of MASt3R training losses
|
| 6 |
-
# --------------------------------------------------------
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn as nn
|
| 9 |
-
import numpy as np
|
| 10 |
-
from sklearn.metrics import average_precision_score
|
| 11 |
-
|
| 12 |
-
import mast3r.utils.path_to_dust3r # noqa
|
| 13 |
-
from dust3r.losses import BaseCriterion, Criterion, MultiLoss, Sum, ConfLoss
|
| 14 |
-
from dust3r.losses import Regr3D as Regr3D_dust3r
|
| 15 |
-
from dust3r.utils.geometry import (geotrf, inv, normalize_pointcloud)
|
| 16 |
-
from dust3r.inference import get_pred_pts3d
|
| 17 |
-
from dust3r.utils.geometry import get_joint_pointcloud_depth, get_joint_pointcloud_center_scale
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def apply_log_to_norm(xyz):
|
| 21 |
-
d = xyz.norm(dim=-1, keepdim=True)
|
| 22 |
-
xyz = xyz / d.clip(min=1e-8)
|
| 23 |
-
xyz = xyz * torch.log1p(d)
|
| 24 |
-
return xyz
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
class Regr3D (Regr3D_dust3r):
|
| 28 |
-
def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False, opt_fit_gt=False,
|
| 29 |
-
sky_loss_value=2, max_metric_scale=False, loss_in_log=False):
|
| 30 |
-
self.loss_in_log = loss_in_log
|
| 31 |
-
if norm_mode.startswith('?'):
|
| 32 |
-
# do no norm pts from metric scale datasets
|
| 33 |
-
self.norm_all = False
|
| 34 |
-
self.norm_mode = norm_mode[1:]
|
| 35 |
-
else:
|
| 36 |
-
self.norm_all = True
|
| 37 |
-
self.norm_mode = norm_mode
|
| 38 |
-
super().__init__(criterion, self.norm_mode, gt_scale)
|
| 39 |
-
|
| 40 |
-
self.sky_loss_value = sky_loss_value
|
| 41 |
-
self.max_metric_scale = max_metric_scale
|
| 42 |
-
|
| 43 |
-
def get_all_pts3d(self, gt1, gt2, pred1, pred2, dist_clip=None):
|
| 44 |
-
# everything is normalized w.r.t. camera of view1
|
| 45 |
-
in_camera1 = inv(gt1['camera_pose'])
|
| 46 |
-
gt_pts1 = geotrf(in_camera1, gt1['pts3d']) # B,H,W,3
|
| 47 |
-
gt_pts2 = geotrf(in_camera1, gt2['pts3d']) # B,H,W,3
|
| 48 |
-
|
| 49 |
-
valid1 = gt1['valid_mask'].clone()
|
| 50 |
-
valid2 = gt2['valid_mask'].clone()
|
| 51 |
-
|
| 52 |
-
if dist_clip is not None:
|
| 53 |
-
# points that are too far-away == invalid
|
| 54 |
-
dis1 = gt_pts1.norm(dim=-1) # (B, H, W)
|
| 55 |
-
dis2 = gt_pts2.norm(dim=-1) # (B, H, W)
|
| 56 |
-
valid1 = valid1 & (dis1 <= dist_clip)
|
| 57 |
-
valid2 = valid2 & (dis2 <= dist_clip)
|
| 58 |
-
|
| 59 |
-
if self.loss_in_log == 'before':
|
| 60 |
-
# this only make sense when depth_mode == 'linear'
|
| 61 |
-
gt_pts1 = apply_log_to_norm(gt_pts1)
|
| 62 |
-
gt_pts2 = apply_log_to_norm(gt_pts2)
|
| 63 |
-
|
| 64 |
-
pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False).clone()
|
| 65 |
-
pr_pts2 = get_pred_pts3d(gt2, pred2, use_pose=True).clone()
|
| 66 |
-
|
| 67 |
-
if not self.norm_all:
|
| 68 |
-
if self.max_metric_scale:
|
| 69 |
-
B = valid1.shape[0]
|
| 70 |
-
# valid1: B, H, W
|
| 71 |
-
# torch.linalg.norm(gt_pts1, dim=-1) -> B, H, W
|
| 72 |
-
# dist1_to_cam1 -> reshape to B, H*W
|
| 73 |
-
dist1_to_cam1 = torch.where(valid1, torch.linalg.norm(gt_pts1, dim=-1), 0).view(B, -1)
|
| 74 |
-
dist2_to_cam1 = torch.where(valid2, torch.linalg.norm(gt_pts2, dim=-1), 0).view(B, -1)
|
| 75 |
-
|
| 76 |
-
# is_metric_scale: B
|
| 77 |
-
# dist1_to_cam1.max(dim=-1).values -> B
|
| 78 |
-
gt1['is_metric_scale'] = gt1['is_metric_scale'] \
|
| 79 |
-
& (dist1_to_cam1.max(dim=-1).values < self.max_metric_scale) \
|
| 80 |
-
& (dist2_to_cam1.max(dim=-1).values < self.max_metric_scale)
|
| 81 |
-
gt2['is_metric_scale'] = gt1['is_metric_scale']
|
| 82 |
-
|
| 83 |
-
mask = ~gt1['is_metric_scale']
|
| 84 |
-
else:
|
| 85 |
-
mask = torch.ones_like(gt1['is_metric_scale'])
|
| 86 |
-
# normalize 3d points
|
| 87 |
-
if self.norm_mode and mask.any():
|
| 88 |
-
pr_pts1[mask], pr_pts2[mask] = normalize_pointcloud(pr_pts1[mask], pr_pts2[mask], self.norm_mode,
|
| 89 |
-
valid1[mask], valid2[mask])
|
| 90 |
-
|
| 91 |
-
if self.norm_mode and not self.gt_scale:
|
| 92 |
-
gt_pts1, gt_pts2, norm_factor = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode,
|
| 93 |
-
valid1, valid2, ret_factor=True)
|
| 94 |
-
# apply the same normalization to prediction
|
| 95 |
-
pr_pts1[~mask] = pr_pts1[~mask] / norm_factor[~mask]
|
| 96 |
-
pr_pts2[~mask] = pr_pts2[~mask] / norm_factor[~mask]
|
| 97 |
-
|
| 98 |
-
# return sky segmentation, making sure they don't include any labelled 3d points
|
| 99 |
-
sky1 = gt1['sky_mask'] & (~valid1)
|
| 100 |
-
sky2 = gt2['sky_mask'] & (~valid2)
|
| 101 |
-
return gt_pts1, gt_pts2, pr_pts1, pr_pts2, valid1, valid2, sky1, sky2, {}
|
| 102 |
-
|
| 103 |
-
def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
|
| 104 |
-
gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring = \
|
| 105 |
-
self.get_all_pts3d(gt1, gt2, pred1, pred2, **kw)
|
| 106 |
-
|
| 107 |
-
if self.sky_loss_value > 0:
|
| 108 |
-
assert self.criterion.reduction == 'none', 'sky_loss_value should be 0 if no conf loss'
|
| 109 |
-
# add the sky pixel as "valid" pixels...
|
| 110 |
-
mask1 = mask1 | sky1
|
| 111 |
-
mask2 = mask2 | sky2
|
| 112 |
-
|
| 113 |
-
# loss on img1 side
|
| 114 |
-
pred_pts1 = pred_pts1[mask1]
|
| 115 |
-
gt_pts1 = gt_pts1[mask1]
|
| 116 |
-
if self.loss_in_log and self.loss_in_log != 'before':
|
| 117 |
-
# this only make sense when depth_mode == 'exp'
|
| 118 |
-
pred_pts1 = apply_log_to_norm(pred_pts1)
|
| 119 |
-
gt_pts1 = apply_log_to_norm(gt_pts1)
|
| 120 |
-
l1 = self.criterion(pred_pts1, gt_pts1)
|
| 121 |
-
|
| 122 |
-
# loss on gt2 side
|
| 123 |
-
pred_pts2 = pred_pts2[mask2]
|
| 124 |
-
gt_pts2 = gt_pts2[mask2]
|
| 125 |
-
if self.loss_in_log and self.loss_in_log != 'before':
|
| 126 |
-
pred_pts2 = apply_log_to_norm(pred_pts2)
|
| 127 |
-
gt_pts2 = apply_log_to_norm(gt_pts2)
|
| 128 |
-
l2 = self.criterion(pred_pts2, gt_pts2)
|
| 129 |
-
|
| 130 |
-
if self.sky_loss_value > 0:
|
| 131 |
-
assert self.criterion.reduction == 'none', 'sky_loss_value should be 0 if no conf loss'
|
| 132 |
-
# ... but force the loss to be high there
|
| 133 |
-
l1 = torch.where(sky1[mask1], self.sky_loss_value, l1)
|
| 134 |
-
l2 = torch.where(sky2[mask2], self.sky_loss_value, l2)
|
| 135 |
-
self_name = type(self).__name__
|
| 136 |
-
details = {self_name + '_pts3d_1': float(l1.mean()), self_name + '_pts3d_2': float(l2.mean())}
|
| 137 |
-
return Sum((l1, mask1), (l2, mask2)), (details | monitoring)
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
class Regr3D_ShiftInv (Regr3D):
|
| 141 |
-
""" Same than Regr3D but invariant to depth shift.
|
| 142 |
-
"""
|
| 143 |
-
|
| 144 |
-
def get_all_pts3d(self, gt1, gt2, pred1, pred2):
|
| 145 |
-
# compute unnormalized points
|
| 146 |
-
gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring = \
|
| 147 |
-
super().get_all_pts3d(gt1, gt2, pred1, pred2)
|
| 148 |
-
|
| 149 |
-
# compute median depth
|
| 150 |
-
gt_z1, gt_z2 = gt_pts1[..., 2], gt_pts2[..., 2]
|
| 151 |
-
pred_z1, pred_z2 = pred_pts1[..., 2], pred_pts2[..., 2]
|
| 152 |
-
gt_shift_z = get_joint_pointcloud_depth(gt_z1, gt_z2, mask1, mask2)[:, None, None]
|
| 153 |
-
pred_shift_z = get_joint_pointcloud_depth(pred_z1, pred_z2, mask1, mask2)[:, None, None]
|
| 154 |
-
|
| 155 |
-
# subtract the median depth
|
| 156 |
-
gt_z1 -= gt_shift_z
|
| 157 |
-
gt_z2 -= gt_shift_z
|
| 158 |
-
pred_z1 -= pred_shift_z
|
| 159 |
-
pred_z2 -= pred_shift_z
|
| 160 |
-
|
| 161 |
-
# monitoring = dict(monitoring, gt_shift_z=gt_shift_z.mean().detach(), pred_shift_z=pred_shift_z.mean().detach())
|
| 162 |
-
return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
class Regr3D_ScaleInv (Regr3D):
|
| 166 |
-
""" Same than Regr3D but invariant to depth scale.
|
| 167 |
-
if gt_scale == True: enforce the prediction to take the same scale than GT
|
| 168 |
-
"""
|
| 169 |
-
|
| 170 |
-
def get_all_pts3d(self, gt1, gt2, pred1, pred2):
|
| 171 |
-
# compute depth-normalized points
|
| 172 |
-
gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring = \
|
| 173 |
-
super().get_all_pts3d(gt1, gt2, pred1, pred2)
|
| 174 |
-
|
| 175 |
-
# measure scene scale
|
| 176 |
-
_, gt_scale = get_joint_pointcloud_center_scale(gt_pts1, gt_pts2, mask1, mask2)
|
| 177 |
-
_, pred_scale = get_joint_pointcloud_center_scale(pred_pts1, pred_pts2, mask1, mask2)
|
| 178 |
-
|
| 179 |
-
# prevent predictions to be in a ridiculous range
|
| 180 |
-
pred_scale = pred_scale.clip(min=1e-3, max=1e3)
|
| 181 |
-
|
| 182 |
-
# subtract the median depth
|
| 183 |
-
if self.gt_scale:
|
| 184 |
-
pred_pts1 *= gt_scale / pred_scale
|
| 185 |
-
pred_pts2 *= gt_scale / pred_scale
|
| 186 |
-
# monitoring = dict(monitoring, pred_scale=(pred_scale/gt_scale).mean())
|
| 187 |
-
else:
|
| 188 |
-
gt_pts1 /= gt_scale
|
| 189 |
-
gt_pts2 /= gt_scale
|
| 190 |
-
pred_pts1 /= pred_scale
|
| 191 |
-
pred_pts2 /= pred_scale
|
| 192 |
-
# monitoring = dict(monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach())
|
| 193 |
-
|
| 194 |
-
return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
class Regr3D_ScaleShiftInv (Regr3D_ScaleInv, Regr3D_ShiftInv):
|
| 198 |
-
# calls Regr3D_ShiftInv first, then Regr3D_ScaleInv
|
| 199 |
-
pass
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
def get_similarities(desc1, desc2, euc=False):
|
| 203 |
-
if euc: # euclidean distance in same range than similarities
|
| 204 |
-
dists = (desc1[:, :, None] - desc2[:, None]).norm(dim=-1)
|
| 205 |
-
sim = 1 / (1 + dists)
|
| 206 |
-
else:
|
| 207 |
-
# Compute similarities
|
| 208 |
-
sim = desc1 @ desc2.transpose(-2, -1)
|
| 209 |
-
return sim
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
class MatchingCriterion(BaseCriterion):
|
| 213 |
-
def __init__(self, reduction='mean', fp=torch.float32):
|
| 214 |
-
super().__init__(reduction)
|
| 215 |
-
self.fp = fp
|
| 216 |
-
|
| 217 |
-
def forward(self, a, b, valid_matches=None, euc=False):
|
| 218 |
-
assert a.ndim >= 2 and 1 <= a.shape[-1], f'Bad shape = {a.shape}'
|
| 219 |
-
dist = self.loss(a.to(self.fp), b.to(self.fp), valid_matches, euc=euc)
|
| 220 |
-
# one dimension less or reduction to single value
|
| 221 |
-
assert (valid_matches is None and dist.ndim == a.ndim -
|
| 222 |
-
1) or self.reduction in ['mean', 'sum', '1-mean', 'none']
|
| 223 |
-
if self.reduction == 'none':
|
| 224 |
-
return dist
|
| 225 |
-
if self.reduction == 'sum':
|
| 226 |
-
return dist.sum()
|
| 227 |
-
if self.reduction == 'mean':
|
| 228 |
-
return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
|
| 229 |
-
if self.reduction == '1-mean':
|
| 230 |
-
return 1. - dist.mean() if dist.numel() > 0 else dist.new_ones(())
|
| 231 |
-
raise ValueError(f'bad {self.reduction=} mode')
|
| 232 |
-
|
| 233 |
-
def loss(self, a, b, valid_matches=None):
|
| 234 |
-
raise NotImplementedError
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
class InfoNCE(MatchingCriterion):
|
| 238 |
-
def __init__(self, temperature=0.07, eps=1e-8, mode='all', **kwargs):
|
| 239 |
-
super().__init__(**kwargs)
|
| 240 |
-
self.temperature = temperature
|
| 241 |
-
self.eps = eps
|
| 242 |
-
assert mode in ['all', 'proper', 'dual']
|
| 243 |
-
self.mode = mode
|
| 244 |
-
|
| 245 |
-
def loss(self, desc1, desc2, valid_matches=None, euc=False):
|
| 246 |
-
# valid positives are along diagonals
|
| 247 |
-
B, N, D = desc1.shape
|
| 248 |
-
B2, N2, D2 = desc2.shape
|
| 249 |
-
assert B == B2 and D == D2
|
| 250 |
-
if valid_matches is None:
|
| 251 |
-
valid_matches = torch.ones([B, N], dtype=bool)
|
| 252 |
-
# torch.all(valid_matches.sum(dim=-1) > 0) some pairs have no matches????
|
| 253 |
-
assert valid_matches.shape == torch.Size([B, N]) and valid_matches.sum() > 0
|
| 254 |
-
|
| 255 |
-
# Tempered similarities
|
| 256 |
-
sim = get_similarities(desc1, desc2, euc) / self.temperature
|
| 257 |
-
sim[sim.isnan()] = -torch.inf # ignore nans
|
| 258 |
-
# Softmax of positives with temperature
|
| 259 |
-
sim = sim.exp_() # save peak memory
|
| 260 |
-
positives = sim.diagonal(dim1=-2, dim2=-1)
|
| 261 |
-
|
| 262 |
-
# Loss
|
| 263 |
-
if self.mode == 'all': # Previous InfoNCE
|
| 264 |
-
loss = -torch.log((positives / sim.sum(dim=-1).sum(dim=-1, keepdim=True)).clip(self.eps))
|
| 265 |
-
elif self.mode == 'proper': # Proper InfoNCE
|
| 266 |
-
loss = -(torch.log((positives / sim.sum(dim=-2)).clip(self.eps)) +
|
| 267 |
-
torch.log((positives / sim.sum(dim=-1)).clip(self.eps)))
|
| 268 |
-
elif self.mode == 'dual': # Dual Softmax
|
| 269 |
-
loss = -(torch.log((positives**2 / sim.sum(dim=-1) / sim.sum(dim=-2)).clip(self.eps)))
|
| 270 |
-
else:
|
| 271 |
-
raise ValueError("This should not happen...")
|
| 272 |
-
return loss[valid_matches]
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
class APLoss (MatchingCriterion):
|
| 276 |
-
""" AP loss.
|
| 277 |
-
|
| 278 |
-
Input: (N, M) values in [min, max]
|
| 279 |
-
label: (N, M) values in {0, 1}
|
| 280 |
-
|
| 281 |
-
Returns: 1 - mAP (mean AP for each n in {1..N})
|
| 282 |
-
Note: typically, this is what you wanna minimize
|
| 283 |
-
"""
|
| 284 |
-
|
| 285 |
-
def __init__(self, nq='torch', min=0, max=1, euc=False, **kw):
|
| 286 |
-
super().__init__(**kw)
|
| 287 |
-
# Exact/True AP loss (not differentiable)
|
| 288 |
-
if nq == 0:
|
| 289 |
-
nq = 'sklearn' # special case
|
| 290 |
-
try:
|
| 291 |
-
self.compute_AP = eval('self.compute_true_AP_' + nq)
|
| 292 |
-
except:
|
| 293 |
-
raise ValueError("Unknown mode %s for AP loss" % nq)
|
| 294 |
-
|
| 295 |
-
@staticmethod
|
| 296 |
-
def compute_true_AP_sklearn(scores, labels):
|
| 297 |
-
def compute_AP(label, score):
|
| 298 |
-
return average_precision_score(label, score)
|
| 299 |
-
|
| 300 |
-
aps = scores.new_zeros((scores.shape[0], scores.shape[1]))
|
| 301 |
-
label_np = labels.cpu().numpy().astype(bool)
|
| 302 |
-
scores_np = scores.cpu().numpy()
|
| 303 |
-
for bi in range(scores_np.shape[0]):
|
| 304 |
-
for i in range(scores_np.shape[1]):
|
| 305 |
-
labels = label_np[bi, i, :]
|
| 306 |
-
if labels.sum() < 1:
|
| 307 |
-
continue
|
| 308 |
-
aps[bi, i] = compute_AP(labels, scores_np[bi, i, :])
|
| 309 |
-
return aps
|
| 310 |
-
|
| 311 |
-
@staticmethod
|
| 312 |
-
def compute_true_AP_torch(scores, labels):
|
| 313 |
-
assert scores.shape == labels.shape
|
| 314 |
-
B, N, M = labels.shape
|
| 315 |
-
dev = labels.device
|
| 316 |
-
with torch.no_grad():
|
| 317 |
-
# sort scores
|
| 318 |
-
_, order = scores.sort(dim=-1, descending=True)
|
| 319 |
-
# sort labels accordingly
|
| 320 |
-
labels = labels[torch.arange(B, device=dev)[:, None, None].expand(order.shape),
|
| 321 |
-
torch.arange(N, device=dev)[None, :, None].expand(order.shape),
|
| 322 |
-
order]
|
| 323 |
-
# compute number of positives per query
|
| 324 |
-
npos = labels.sum(dim=-1)
|
| 325 |
-
assert torch.all(torch.isclose(npos, npos[0, 0])
|
| 326 |
-
), "only implemented for constant number of positives per query"
|
| 327 |
-
npos = int(npos[0, 0])
|
| 328 |
-
# compute precision at each recall point
|
| 329 |
-
posrank = labels.nonzero()[:, -1].view(B, N, npos)
|
| 330 |
-
recall = torch.arange(1, 1 + npos, dtype=torch.float32, device=dev)[None, None, :].expand(B, N, npos)
|
| 331 |
-
precision = recall / (1 + posrank).float()
|
| 332 |
-
# average precision values at all recall points
|
| 333 |
-
aps = precision.mean(dim=-1)
|
| 334 |
-
|
| 335 |
-
return aps
|
| 336 |
-
|
| 337 |
-
def loss(self, desc1, desc2, valid_matches=None, euc=False): # if matches is None, positives are the diagonal
|
| 338 |
-
B, N1, D = desc1.shape
|
| 339 |
-
B2, N2, D2 = desc2.shape
|
| 340 |
-
assert B == B2 and D == D2
|
| 341 |
-
|
| 342 |
-
scores = get_similarities(desc1, desc2, euc)
|
| 343 |
-
|
| 344 |
-
labels = torch.zeros([B, N1, N2], dtype=scores.dtype, device=scores.device)
|
| 345 |
-
|
| 346 |
-
# allow all diagonal positives and only mask afterwards
|
| 347 |
-
labels.diagonal(dim1=-2, dim2=-1)[...] = 1.
|
| 348 |
-
apscore = self.compute_AP(scores, labels)
|
| 349 |
-
if valid_matches is not None:
|
| 350 |
-
apscore = apscore[valid_matches]
|
| 351 |
-
return apscore
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
class MatchingLoss (Criterion, MultiLoss):
|
| 355 |
-
"""
|
| 356 |
-
Matching loss per image
|
| 357 |
-
only compare pixels inside an image but not in the whole batch as what would be done usually
|
| 358 |
-
"""
|
| 359 |
-
|
| 360 |
-
def __init__(self, criterion, withconf=False, use_pts3d=False, negatives_padding=0, blocksize=4096):
|
| 361 |
-
super().__init__(criterion)
|
| 362 |
-
self.negatives_padding = negatives_padding
|
| 363 |
-
self.use_pts3d = use_pts3d
|
| 364 |
-
self.blocksize = blocksize
|
| 365 |
-
self.withconf = withconf
|
| 366 |
-
|
| 367 |
-
def add_negatives(self, outdesc2, desc2, batchid, x2, y2):
|
| 368 |
-
if self.negatives_padding:
|
| 369 |
-
B, H, W, D = desc2.shape
|
| 370 |
-
negatives = torch.ones([B, H, W], device=desc2.device, dtype=bool)
|
| 371 |
-
negatives[batchid, y2, x2] = False
|
| 372 |
-
sel = negatives & (negatives.view([B, -1]).cumsum(dim=-1).view(B, H, W)
|
| 373 |
-
<= self.negatives_padding) # take the N-first negatives
|
| 374 |
-
outdesc2 = torch.cat([outdesc2, desc2[sel].view([B, -1, D])], dim=1)
|
| 375 |
-
return outdesc2
|
| 376 |
-
|
| 377 |
-
def get_confs(self, pred1, pred2, sel1, sel2):
|
| 378 |
-
if self.withconf:
|
| 379 |
-
if self.use_pts3d:
|
| 380 |
-
outconfs1 = pred1['conf'][sel1]
|
| 381 |
-
outconfs2 = pred2['conf'][sel2]
|
| 382 |
-
else:
|
| 383 |
-
outconfs1 = pred1['desc_conf'][sel1]
|
| 384 |
-
outconfs2 = pred2['desc_conf'][sel2]
|
| 385 |
-
else:
|
| 386 |
-
outconfs1 = outconfs2 = None
|
| 387 |
-
return outconfs1, outconfs2
|
| 388 |
-
|
| 389 |
-
def get_descs(self, pred1, pred2):
|
| 390 |
-
if self.use_pts3d:
|
| 391 |
-
desc1, desc2 = pred1['pts3d'], pred2['pts3d_in_other_view']
|
| 392 |
-
else:
|
| 393 |
-
desc1, desc2 = pred1['desc'], pred2['desc']
|
| 394 |
-
return desc1, desc2
|
| 395 |
-
|
| 396 |
-
def get_matching_descs(self, gt1, gt2, pred1, pred2, **kw):
|
| 397 |
-
outdesc1 = outdesc2 = outconfs1 = outconfs2 = None
|
| 398 |
-
# Recover descs, GT corres and valid mask
|
| 399 |
-
desc1, desc2 = self.get_descs(pred1, pred2)
|
| 400 |
-
|
| 401 |
-
(x1, y1), (x2, y2) = gt1['corres'].unbind(-1), gt2['corres'].unbind(-1)
|
-        valid_matches = gt1['valid_corres']
-
-        # Select descs that have GT matches
-        B, N = x1.shape
-        batchid = torch.arange(B)[:, None].repeat(1, N)  # B, N
-        outdesc1, outdesc2 = desc1[batchid, y1, x1], desc2[batchid, y2, x2]  # B, N, D
-
-        # Pad with unused negatives
-        outdesc2 = self.add_negatives(outdesc2, desc2, batchid, x2, y2)
-
-        # Gather confs if needed
-        sel1 = batchid, y1, x1
-        sel2 = batchid, y2, x2
-        outconfs1, outconfs2 = self.get_confs(pred1, pred2, sel1, sel2)
-
-        return outdesc1, outdesc2, outconfs1, outconfs2, valid_matches, {'use_euclidean_dist': self.use_pts3d}
-
-    def blockwise_criterion(self, descs1, descs2, confs1, confs2, valid_matches, euc, rng=np.random, shuffle=True):
-        loss = None
-        details = {}
-        B, N, D = descs1.shape
-
-        if N <= self.blocksize:  # Blocks are larger than provided descs, compute regular loss
-            loss = self.criterion(descs1, descs2, valid_matches, euc=euc)
-        else:  # Compute criterion on the block-diagonal only, after shuffling
-            # Shuffle if necessary
-            matches_perm = slice(None)
-            if shuffle:
-                matches_perm = np.stack([rng.choice(range(N), size=N, replace=False) for _ in range(B)])
-                batchid = torch.tile(torch.arange(B), (N, 1)).T
-                matches_perm = batchid, matches_perm
-
-            descs1 = descs1[matches_perm]
-            descs2 = descs2[matches_perm]
-            valid_matches = valid_matches[matches_perm]
-
-            assert N % self.blocksize == 0, "Error, can't chunk block-diagonal, please check blocksize"
-            n_chunks = N // self.blocksize
-            descs1 = descs1.reshape([B * n_chunks, self.blocksize, D])  # [B*(N//blocksize), blocksize, D]
-            descs2 = descs2.reshape([B * n_chunks, self.blocksize, D])  # [B*(N//blocksize), blocksize, D]
-            valid_matches = valid_matches.view([B * n_chunks, self.blocksize])
-            loss = self.criterion(descs1, descs2, valid_matches, euc=euc)
-            if self.withconf:
-                confs1, confs2 = map(lambda x: x[matches_perm], (confs1, confs2))  # apply perm to confidences if needed
-
-        if self.withconf:
-            # split confidences between positives/negatives for loss computation
-            details['conf_pos'] = map(lambda x: x[valid_matches.view(B, -1)], (confs1, confs2))
-            details['conf_neg'] = map(lambda x: x[~valid_matches.view(B, -1)], (confs1, confs2))
-            details['Conf1_std'] = confs1.std()
-            details['Conf2_std'] = confs2.std()
-
-        return loss, details
-
-    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
-        # Gather preds and GT
-        descs1, descs2, confs1, confs2, valid_matches, monitoring = self.get_matching_descs(
-            gt1, gt2, pred1, pred2, **kw)
-
-        # loss on matches
-        loss, details = self.blockwise_criterion(descs1, descs2, confs1, confs2,
-                                                 valid_matches, euc=monitoring.pop('use_euclidean_dist', False))
-
-        details[type(self).__name__] = float(loss.mean())
-        return loss, (details | monitoring)
-
-
-class ConfMatchingLoss(ConfLoss):
-    """ Weight matching by learned confidence. Same as ConfLoss, but for a matching criterion.
-    Assumes the input matching loss is a match-level loss.
-    """
-
-    def __init__(self, pixel_loss, alpha=1., confmode='prod', neg_conf_loss_quantile=False):
-        super().__init__(pixel_loss, alpha)
-        self.pixel_loss.withconf = True
-        self.confmode = confmode
-        self.neg_conf_loss_quantile = neg_conf_loss_quantile
-
-    def aggregate_confs(self, confs1, confs2):  # get the confidences resulting from the two view predictions
-        if self.confmode == 'prod':
-            confs = confs1 * confs2 if confs1 is not None and confs2 is not None else 1.
-        elif self.confmode == 'mean':
-            confs = .5 * (confs1 + confs2) if confs1 is not None and confs2 is not None else 1.
-        else:
-            raise ValueError(f"Unknown conf mode {self.confmode}")
-        return confs
-
-    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
-        # compute per-match loss
-        loss, details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw)
-        # Recover confidences for positive and negative samples
-        conf1_pos, conf2_pos = details.pop('conf_pos')
-        conf1_neg, conf2_neg = details.pop('conf_neg')
-        conf_pos = self.aggregate_confs(conf1_pos, conf2_pos)
-
-        # weight matching loss by confidence on positives
-        conf_pos, log_conf_pos = self.get_conf_log(conf_pos)
-        conf_loss = loss * conf_pos - self.alpha * log_conf_pos
-        # average + nan protection (in case of no valid pixels at all)
-        conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else 0
-        # Add negative confs loss to give some supervision signal to confidences for pixels that are not matched in GT
-        if self.neg_conf_loss_quantile:
-            conf_neg = torch.cat([conf1_neg, conf2_neg])
-            conf_neg, log_conf_neg = self.get_conf_log(conf_neg)
-
-            # recover quantile that will be used for negatives loss value assignment
-            neg_loss_value = torch.quantile(loss, self.neg_conf_loss_quantile).detach()
-            neg_loss = neg_loss_value * conf_neg - self.alpha * log_conf_neg
-
-            neg_loss = neg_loss.mean() if neg_loss.numel() > 0 else 0
-            conf_loss = conf_loss + neg_loss
-
-        return conf_loss, dict(matching_conf_loss=float(conf_loss), **details)
mast3r/model.py
DELETED
@@ -1,68 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# MASt3R model class
-# --------------------------------------------------------
-import torch
-import torch.nn.functional as F
-import os
-
-from mast3r.catmlp_dpt_head import mast3r_head_factory
-
-import mast3r.utils.path_to_dust3r  # noqa
-from ..dust3r.dust3r.model import AsymmetricCroCo3DStereo  # noqa
-from ..dust3r.dust3r.utils.misc import transpose_to_landscape  # noqa
-
-
-inf = float('inf')
-
-
-def load_model(model_path, device, verbose=True):
-    if verbose:
-        print('... loading model from', model_path)
-    ckpt = torch.load(model_path, map_location='cpu')
-    args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
-    if 'landscape_only' not in args:
-        args = args[:-1] + ', landscape_only=False)'
-    else:
-        args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
-    assert "landscape_only=False" in args
-    if verbose:
-        print(f"instantiating : {args}")
-    net = eval(args)
-    s = net.load_state_dict(ckpt['model'], strict=False)
-    if verbose:
-        print(s)
-    return net.to(device)
-
-
-class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
-    def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs):
-        self.desc_mode = desc_mode
-        self.two_confs = two_confs
-        self.desc_conf_mode = desc_conf_mode
-        super().__init__(**kwargs)
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
-        if os.path.isfile(pretrained_model_name_or_path):
-            return load_model(pretrained_model_name_or_path, device='cpu')
-        else:
-            return super(AsymmetricMASt3R, cls).from_pretrained(pretrained_model_name_or_path, **kw)
-
-    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
-        # assert img_size[0] % patch_size == 0 and img_size[
-        #     1] % patch_size == 0, f'{img_size=} must be multiple of {patch_size=}'
-        self.output_mode = output_mode
-        self.head_type = head_type
-        self.depth_mode = depth_mode
-        self.conf_mode = conf_mode
-        if self.desc_conf_mode is None:
-            self.desc_conf_mode = conf_mode
-        # allocate heads
-        self.downstream_head1 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
-        self.downstream_head2 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
-        # magic wrapper
-        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
-        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)
mast3r/utils/__init__.py
DELETED
@@ -1,2 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/utils/coarse_to_fine.py
DELETED
@@ -1,214 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# coarse to fine utilities
-# --------------------------------------------------------
-import numpy as np
-
-
-def crop_tag(cell):
-    return f'[{cell[1]}:{cell[3]},{cell[0]}:{cell[2]}]'
-
-
-def crop_slice(cell):
-    return slice(cell[1], cell[3]), slice(cell[0], cell[2])
-
-
-def _start_pos(total_size, win_size, overlap):
-    # we must have AT LEAST overlap between segments
-    # first segment starts at 0, last segment starts at total_size - win_size
-    assert 0 <= overlap < 1
-    assert total_size >= win_size
-    spacing = win_size * (1 - overlap)
-    last_pt = total_size - win_size
-    n_windows = 2 + int((last_pt - 1) // spacing)
-    return np.linspace(0, last_pt, n_windows).round().astype(int)
-
-
-def multiple_of_16(x):
-    return (x // 16) * 16
-
-
-def _make_overlapping_grid(H, W, size, overlap):
-    H_win = multiple_of_16(H * size // max(H, W))
-    W_win = multiple_of_16(W * size // max(H, W))
-    x = _start_pos(W, W_win, overlap)
-    y = _start_pos(H, H_win, overlap)
-    grid = np.stack(np.meshgrid(x, y, indexing='xy'), axis=-1)
-    grid = np.concatenate((grid, grid + (W_win, H_win)), axis=-1)
-    return grid.reshape(-1, 4)
-
-
-def _cell_size(cell2):
-    width, height = cell2[:, 2] - cell2[:, 0], cell2[:, 3] - cell2[:, 1]
-    assert width.min() >= 0
-    assert height.min() >= 0
-    return width, height
-
-
-def _norm_windows(cell2, H2, W2, forced_resolution=None):
-    # make sure the window aspect ratio is 3/4, or the output resolution is forced_resolution if defined
-    outcell = cell2.copy()
-    width, height = _cell_size(cell2)
-    width2, height2 = width.clip(max=W2), height.clip(max=H2)
-    if forced_resolution is None:
-        width2[width < height] = (height2[width < height] * 3.01 / 4).clip(max=W2)
-        height2[width >= height] = (width2[width >= height] * 3.01 / 4).clip(max=H2)
-    else:
-        forced_H, forced_W = forced_resolution
-        width2[:] = forced_W
-        height2[:] = forced_H
-
-    half = (width2 - width) / 2
-    outcell[:, 0] -= half
-    outcell[:, 2] += half
-    half = (height2 - height) / 2
-    outcell[:, 1] -= half
-    outcell[:, 3] += half
-
-    # proj to integers
-    outcell = np.floor(outcell).astype(int)
-    # Take care of flooring errors
-    tmpw, tmph = _cell_size(outcell)
-    outcell[:, 0] += tmpw.astype(tmpw.dtype) - width2.astype(tmpw.dtype)
-    outcell[:, 1] += tmph.astype(tmpw.dtype) - height2.astype(tmpw.dtype)
-
-    # make sure 0 <= x < W2 and 0 <= y < H2
-    outcell[:, 0::2] -= outcell[:, [0]].clip(max=0)
-    outcell[:, 1::2] -= outcell[:, [1]].clip(max=0)
-    outcell[:, 0::2] -= outcell[:, [2]].clip(min=W2) - W2
-    outcell[:, 1::2] -= outcell[:, [3]].clip(min=H2) - H2
-
-    width, height = _cell_size(outcell)
-    assert np.all(width == width2.astype(width.dtype)) and np.all(
-        height == height2.astype(height.dtype)), "Error, output is not of the expected shape."
-    assert np.all(width <= W2)
-    assert np.all(height <= H2)
-    return outcell
-
-
-def _weight_pixels(cell, pix, assigned, gauss_var=2):
-    center = cell.reshape(-1, 2, 2).mean(axis=1)
-    width, height = _cell_size(cell)
-
-    # square distance between each cell center and each point
-    dist = (center[:, None] - pix[None]) / np.c_[width, height][:, None]
-    dist2 = np.square(dist).sum(axis=-1)
-
-    assert assigned.shape == dist2.shape
-    res = np.where(assigned, np.exp(-gauss_var * dist2), 0)
-    return res
-
-
-def pos2d_in_rect(p1, cell1):
-    x, y = p1.T
-    l, t, r, b = cell1
-    assigned = (l <= x) & (x < r) & (t <= y) & (y < b)
-    return assigned
-
-
-def _score_cell(cell1, H2, W2, p1, p2, min_corres=10, forced_resolution=None):
-    assert p1.shape == p2.shape
-
-    # compute keypoint assignment
-    assigned = pos2d_in_rect(p1, cell1[None].T)
-    assert assigned.shape == (len(cell1), len(p1))
-
-    # remove cells without correspondences
-    valid_cells = assigned.sum(axis=1) >= min_corres
-    cell1 = cell1[valid_cells]
-    assigned = assigned[valid_cells]
-    if not valid_cells.any():
-        return cell1, cell1, assigned
-
-    # fill-in the assigned points in both images
-    assigned_p1 = np.empty((len(cell1), len(p1), 2), dtype=np.float32)
-    assigned_p2 = np.empty((len(cell1), len(p2), 2), dtype=np.float32)
-    assigned_p1[:] = p1[None]
-    assigned_p2[:] = p2[None]
-    assigned_p1[~assigned] = np.nan
-    assigned_p2[~assigned] = np.nan
-
-    # find the median center and scale of assigned points in each cell
-    # cell_center1 = np.nanmean(assigned_p1, axis=1)
-    cell_center2 = np.nanmean(assigned_p2, axis=1)
-    im1_q25, im1_q75 = np.nanquantile(assigned_p1, (0.1, 0.9), axis=1)
-    im2_q25, im2_q75 = np.nanquantile(assigned_p2, (0.1, 0.9), axis=1)
-
-    robust_std1 = (im1_q75 - im1_q25).clip(20.)
-    robust_std2 = (im2_q75 - im2_q25).clip(20.)
-
-    cell_size1 = (cell1[:, 2:4] - cell1[:, 0:2])
-    cell_size2 = cell_size1 * robust_std2 / robust_std1
-    cell2 = np.c_[cell_center2 - cell_size2 / 2, cell_center2 + cell_size2 / 2]
-
-    # make sure cell bounds are valid
-    cell2 = _norm_windows(cell2, H2, W2, forced_resolution=forced_resolution)
-
-    # compute correspondence weights
-    corres_weights = _weight_pixels(cell1, p1, assigned) * _weight_pixels(cell2, p2, assigned)
-
-    # return a list of window pairs and assigned correspondences
-    return cell1, cell2, corres_weights
-
-
-def greedy_selection(corres_weights, target=0.9):
-    # corres_weights = (n_cell_pair, n_corres) matrix.
-    # If corres_weights[c, p] > 0, correspondence p is visible in cell pair c
-    assert 0 < target <= 1
-    corres_weights = corres_weights.copy()
-
-    total = corres_weights.max(axis=0).sum()
-    target *= total
-
-    # init = empty
-    res = []
-    cur = np.zeros(corres_weights.shape[1])  # current selection
-
-    while cur.sum() < target:
-        # pick the next best cell pair
-        best = corres_weights.sum(axis=1).argmax()
-        res.append(best)
-
-        # update current
-        cur += corres_weights[best]
-        # print('appending', best, 'with score', corres_weights[best].sum(), '-->', cur.sum())
-
-        # remove from all other views
-        corres_weights = (corres_weights - corres_weights[best]).clip(min=0)
-
-    return res
-
-
-def select_pairs_of_crops(img_q, img_b, pos2d_in_query, pos2d_in_ref, maxdim=512, overlap=.5, forced_resolution=None):
-    # prepare the overlapping cells
-    grid_q = _make_overlapping_grid(*img_q.shape[:2], maxdim, overlap)
-    grid_b = _make_overlapping_grid(*img_b.shape[:2], maxdim, overlap)
-
-    assert forced_resolution is None or len(forced_resolution) == 2
-    if isinstance(forced_resolution[0], int) or not len(forced_resolution[0]) == 2:
-        forced_resolution1 = forced_resolution2 = forced_resolution
-    else:
-        assert len(forced_resolution[1]) == 2
-        forced_resolution1 = forced_resolution[0]
-        forced_resolution2 = forced_resolution[1]
-
-    # Make sure crops respect constraints
-    grid_q = _norm_windows(grid_q.astype(float), *img_q.shape[:2], forced_resolution=forced_resolution1)
-    grid_b = _norm_windows(grid_b.astype(float), *img_b.shape[:2], forced_resolution=forced_resolution2)
-
-    # score cells
-    pairs_q = _score_cell(grid_q, *img_b.shape[:2], pos2d_in_query, pos2d_in_ref, forced_resolution=forced_resolution2)
-    pairs_b = _score_cell(grid_b, *img_q.shape[:2], pos2d_in_ref, pos2d_in_query, forced_resolution=forced_resolution1)
-    pairs_b = pairs_b[1], pairs_b[0], pairs_b[2]  # cellq, cellb, corres_weights
-
-    # greedy selection until all correspondences are generated
-    cell1, cell2, corres_weights = map(np.concatenate, zip(pairs_q, pairs_b))
-    if len(corres_weights) == 0:
-        return  # tolerated for empty generators
-    order = greedy_selection(corres_weights, target=0.9)
-
-    for i in order:
-        def pair_tag(qi, bi): return (str(qi) + crop_tag(cell1[i]), str(bi) + crop_tag(cell2[i]))
-        yield cell1[i], cell2[i], pair_tag
mast3r/utils/collate.py
DELETED
@@ -1,62 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# Collate extensions
-# --------------------------------------------------------
-
-import torch
-import collections
-from torch.utils.data._utils.collate import default_collate_fn_map, default_collate_err_msg_format
-from typing import Callable, Dict, Optional, Tuple, Type, Union, List
-
-
-def cat_collate_tensor_fn(batch, *, collate_fn_map):
-    return torch.cat(batch, dim=0)
-
-
-def cat_collate_list_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
-    return [item for bb in batch for item in bb]  # concatenate all lists
-
-
-cat_collate_fn_map = default_collate_fn_map.copy()
-cat_collate_fn_map[torch.Tensor] = cat_collate_tensor_fn
-cat_collate_fn_map[List] = cat_collate_list_fn
-cat_collate_fn_map[type(None)] = lambda _, **kw: None  # if the batch holds Nones, simply return a single None
-
-
-def cat_collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
-    r"""Custom collate function that concatenates items instead of stacking them, and handles NoneTypes."""
-    elem = batch[0]
-    elem_type = type(elem)
-
-    if collate_fn_map is not None:
-        if elem_type in collate_fn_map:
-            return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
-
-        for collate_type in collate_fn_map:
-            if isinstance(elem, collate_type):
-                return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map)
-
-    if isinstance(elem, collections.abc.Mapping):
-        try:
-            return elem_type({key: cat_collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
-        except TypeError:
-            # The mapping type may not support `__init__(iterable)`.
-            return {key: cat_collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
-    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
-        return elem_type(*(cat_collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))
-    elif isinstance(elem, collections.abc.Sequence):
-        transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.
-
-        if isinstance(elem, tuple):
-            # Backwards compatibility.
-            return [cat_collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]
-        else:
-            try:
-                return elem_type([cat_collate(samples, collate_fn_map=collate_fn_map) for samples in transposed])
-            except TypeError:
-                # The sequence type may not support `__init__(iterable)` (e.g., `range`).
-                return [cat_collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]
-
-    raise TypeError(default_collate_err_msg_format.format(elem_type))
mast3r/utils/misc.py
DELETED
@@ -1,17 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# utility functions for MASt3R
-# --------------------------------------------------------
-import os
-import hashlib
-
-
-def mkdir_for(f):
-    os.makedirs(os.path.dirname(f), exist_ok=True)
-    return f
-
-
-def hash_md5(s):
-    return hashlib.md5(s.encode('utf-8')).hexdigest()
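A short sketch of how these two helpers compose, e.g. to write a cache file keyed by an MD5 digest (the paths and payload are illustrative, not from this repo):

    import numpy as np

    key = hash_md5('img1.jpg/img2.jpg')            # deterministic 32-char hex key
    path = mkdir_for(f'cache/corres/{key}.npy')    # creates cache/corres/ if needed
    np.save(path, np.zeros(128))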
mast3r/utils/path_to_dust3r.py
DELETED
@@ -1,19 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# dust3r submodule import
-# --------------------------------------------------------
-
-import sys
-import os.path as path
-HERE_PATH = path.normpath(path.dirname(__file__))
-DUSt3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../dust3r'))
-DUSt3R_LIB_PATH = path.join(DUSt3R_REPO_PATH, 'dust3r')
-# check the presence of models directory in repo to be sure it's cloned
-if path.isdir(DUSt3R_LIB_PATH):
-    # workaround for sibling import
-    sys.path.insert(0, DUSt3R_REPO_PATH)
-else:
-    raise ImportError(f"dust3r is not initialized, could not find: {DUSt3R_LIB_PATH}.\n "
-                      "Did you forget to run 'git submodule update --init --recursive' ?")