Spaces:
Running
on
Zero
Running
on
Zero
| from math import sin, cos | |
| import torch | |
| from torch.cuda.amp import autocast | |
def transform_pts(pts: torch.Tensor, rel_pose: torch.Tensor) -> torch.Tensor:
    """Apply a rigid-body transform to a batch of 3D points.

    Args:
        pts (torch.Tensor): points of shape (B, n_pts, 3)
        rel_pose (torch.Tensor): homogeneous poses of shape (B, 4, 4)

    Returns:
        torch.Tensor: transformed points of shape (B, n_pts, 3)
    """
    # Lift to homogeneous coordinates by appending a column of ones.
    ones = torch.ones_like(pts[..., :1])
    homogeneous = torch.cat((pts, ones), dim=-1)
    # Right-multiplying by the transposed pose equals pose @ p per point.
    transformed = homogeneous @ rel_pose.transpose(-1, -2)
    # Drop the homogeneous coordinate again.
    return transformed[..., :3]
# TODO: unify
def distance_to_z(depths: torch.Tensor, projs: torch.Tensor):
    """Convert ray-distance maps into z-depth maps.

    Args:
        depths: distance along each camera ray, shape (n, nv, h, w).
        projs: per-view intrinsics; assumed (n, nv, 3, 3) -- TODO confirm.

    Returns:
        torch.Tensor: z-depth of shape (n, nv, h, w).
    """
    n, nv, h, w = depths.shape
    dev = depths.device
    inv_K = torch.inverse(projs)
    # Pixel grid in normalized device coordinates [-1, 1].
    xs = torch.linspace(-1, 1, w, device=dev).view(1, 1, 1, w).expand(1, 1, h, w)
    ys = torch.linspace(-1, 1, h, device=dev).view(1, 1, h, 1).expand(1, 1, h, w)
    rays = torch.stack((xs, ys, torch.ones_like(xs)), dim=2).expand(n, nv, 3, h, w)
    # Back-project every pixel through the inverse intrinsics.
    cam_rays = (inv_K @ rays.view(n, nv, 3, h * w)).view(n, nv, 3, h, w)
    # Cosine of the angle between each ray and the optical axis.
    factors = cam_rays[:, :, 2, :, :] / torch.norm(cam_rays, dim=2)
    return depths * factors
def z_to_distance(z: torch.Tensor, projs: torch.Tensor):
    """Convert z-depth maps into distances along the viewing ray.

    Inverse of ``distance_to_z``.

    Args:
        z: z-depth maps, shape (n, nv, h, w).
        projs: per-view intrinsics; assumed (n, nv, 3, 3) -- TODO confirm.

    Returns:
        torch.Tensor: ray distances, shape (n, nv, h, w).
    """
    n, nv, h, w = z.shape
    dev = z.device
    inv_K = torch.inverse(projs)
    # Normalized pixel grid in [-1, 1].
    xs = torch.linspace(-1, 1, w, device=dev).view(1, 1, 1, w).expand(1, 1, h, w)
    ys = torch.linspace(-1, 1, h, device=dev).view(1, 1, h, 1).expand(1, 1, h, w)
    rays = torch.stack((xs, ys, torch.ones_like(xs)), dim=2).expand(n, nv, 3, h, w)
    cam_rays = (inv_K @ rays.view(n, nv, 3, h * w)).view(n, nv, 3, h, w)
    # Cosine of the angle between each back-projected ray and the axis.
    scale = cam_rays[:, :, 2, :, :] / torch.norm(cam_rays, dim=2)
    # Dividing by the cosine turns z-depth back into ray distance.
    return z / scale
def azimuth_elevation_to_rotation(azimuth: float, elevation: float) -> torch.Tensor:
    """Build a 3x3 rotation matrix from azimuth and elevation angles.

    Args:
        azimuth: rotation angle about the z axis, in radians.
        elevation: rotation angle about the x axis, in radians.

    Returns:
        torch.Tensor: 3x3 rotation matrix ``rot_x(elevation) @ rot_z(azimuth)``
        (azimuth applied first).
    """
    rot_z = torch.tensor(
        [
            [cos(azimuth), -sin(azimuth), 0.0],
            [sin(azimuth), cos(azimuth), 0.0],
            [0.0, 0.0, 1.0],
        ]
    )
    # BUG FIX: the x-axis rotation previously used `azimuth` as its angle,
    # so the `elevation` argument was silently ignored.
    rot_x = torch.tensor(
        [
            [1.0, 0.0, 0.0],
            [0.0, cos(elevation), -sin(elevation)],
            [0.0, sin(elevation), cos(elevation)],
        ]
    )
    return rot_x @ rot_z
def estimate_frustum_overlap(proj_source: torch.Tensor, pose_source: torch.Tensor, proj_target: torch.Tensor, pose_target: torch.Tensor, dist_lim=50):
    """Heuristically estimate the overlapping volume of two camera frustums.

    Args:
        proj_source: source-view intrinsics; assumed (B, 3, 3) -- TODO confirm.
        pose_source: source camera-to-world pose; assumed (B, 4, 4) -- TODO confirm.
        proj_target: target-view intrinsics; assumed (B, 3, 3).
        pose_target: target camera-to-world pose; assumed (B, 4, 4).
        dist_lim: cap applied to per-plane intersection distances.

    Returns:
        torch.Tensor: per-batch scalar volume estimate.

    NOTE(review): when a source camera sits behind its target camera
    (negative z of the relative translation), the two cameras are swapped
    *in place*, mutating the caller's proj/pose tensors, and a debug print
    is emitted -- confirm this side effect is intended.
    """
    device = proj_source.device
    dtype = proj_source.dtype
    # Check which camera has higher z value in target coordinate system
    # (matrix inverse is kept in fp32 by disabling autocast).
    with autocast(enabled=False):
        src2tgt = torch.inverse(pose_target) @ pose_source
    for i in range(len(src2tgt)):
        if src2tgt[i, 2, 3] < 0:
            print("SWAP", i)
            proj_ = proj_target[i].clone()
            pose_ = pose_target[i].clone()
            proj_target[i] = proj_source[i]
            pose_target[i] = pose_source[i]
            proj_source[i] = proj_
            pose_source[i] = pose_
    # Homogeneous NDC corners of the image plane at z = 1.
    points = torch.tensor([[
        [-1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, -1, 1, 1],
        [-1, -1, 1, 1],
    ]], device=device, dtype=dtype)
    with autocast(enabled=False):
        K_src_inv = torch.inverse(proj_source)
        K_tgt_inv = torch.inverse(proj_target)
    # Embed each 3x3 inverse intrinsics matrix into a homogeneous 4x4.
    _ = K_src_inv.new_zeros(K_src_inv.shape[0], 4, 4)
    _[:, 3, 3] = 1
    _[:, :3, :3] = K_src_inv
    K_src_inv = _
    _ = K_tgt_inv.new_zeros(K_tgt_inv.shape[0], 4, 4)
    _[:, 3, 3] = 1
    _[:, :3, :3] = K_tgt_inv
    K_tgt_inv = _
    # Unproject the image corners into each camera's own frame.
    points_src = K_src_inv @ points.permute(0, 2, 1)
    points_tgt = K_tgt_inv @ points.permute(0, 2, 1)
    # Normals of the four side planes of the target frustum: cross products
    # of adjacent corner rays, normalized to unit length.
    normals_tgt = torch.cross(points_tgt[..., :3, :], torch.roll(points_tgt[..., :3, :], shifts=-1, dims=-2), dim=-2)
    normals_tgt = normals_tgt / torch.norm(normals_tgt, dim=-2, keepdim=True)
    with autocast(enabled=False):
        src2tgt = torch.inverse(pose_target) @ pose_source
    # Source camera center and corner rays expressed in the target frame.
    base = src2tgt[:, :3, 3, None]
    points_src_tgt = src2tgt @ points_src
    dirs = points_src_tgt[..., :3, :] - base
    # dirs = dirs / torch.norm(dirs, dim=-2) #dirs should have z length 1
    # Ray/plane intersection distances; the denominator is clamped to avoid
    # division by ~0 for rays (near-)parallel to a plane.
    dists = - (base[..., None] * normals_tgt[..., None, :]).sum(dim=-3) / (dirs[..., None] * normals_tgt[..., None, :]).sum(dim=-3).clamp_min(1e-4)
    # print(dists)
    # Ignore all non-positive
    mask = (dists <= 0) | (dists > dist_lim)
    dists[mask] = dist_lim
    # print(dists)
    # Nearest bounding plane per corner ray, then averaged over corners.
    dists = torch.min(dists, dim=-1)[0]
    mean_dist = dists.mean(dim=-1)
    # print(mean_dist, (torch.max(points_src[..., 0], dim=-1)[0] - torch.min(points_src[..., 0], dim=-1)[0]), (torch.max(points_src[..., 1], dim=-1)[0] - torch.min(points_src[..., 1], dim=-1)[0]))
    # Pyramid-style volume: 1/3 * width * height * depth, with the image-plane
    # extent scaled by the mean intersection distance.
    volume_estimate = \
        1/3 * \
        (torch.max(points_src[..., 0], dim=-1)[0] - torch.min(points_src[..., 0], dim=-1)[0]) * mean_dist * \
        (torch.max(points_src[..., 1], dim=-1)[0] - torch.min(points_src[..., 1], dim=-1)[0]) * mean_dist * \
        mean_dist
    return volume_estimate
def estimate_frustum_overlap_2(proj_source: torch.Tensor, pose_source: torch.Tensor, proj_target: torch.Tensor, pose_target: torch.Tensor, z_range=(3, 40), res=(8, 8, 16)):
    """Estimate frustum overlap as the fraction of sampled source-frustum
    points that reproject into the target view.

    Args:
        proj_source: source intrinsics; assumed (n, 3, 3) -- TODO confirm.
        pose_source: source camera-to-world pose; assumed (n, 4, 4) -- TODO confirm.
        proj_target: target intrinsics; assumed (n, 3, 3).
        pose_target: target camera-to-world pose; assumed (n, 4, 4).
        z_range: (near, far) depths between which samples are placed.
        res: (w, h, d) sampling resolution inside the source frustum.

    Returns:
        torch.Tensor: per-batch fraction in [0, 1] of samples that land
        inside the target frustum.
    """
    device = proj_source.device
    dtype = proj_source.dtype
    # Keep the matrix inverse in fp32 by disabling autocast.
    with autocast(enabled=False):
        K_src_inv = torch.inverse(proj_source)
    n = proj_source.shape[0]
    w, h, d = res
    pixel_width = 2 / w
    pixel_height = 2 / h
    # Sample pixel centers on a w x h grid in NDC, replicated over d depths.
    x = torch.linspace(-1 + .5 * pixel_width, 1 - .5 * pixel_width, w, dtype=dtype, device=device).view(1, 1, 1, w).expand(n, d, h, w)
    y = torch.linspace(-1 + .5 * pixel_height, 1 - .5 * pixel_height, h, dtype=dtype, device=device).view(1, 1, h, 1).expand(n, d, h, w)
    z = torch.ones_like(x)
    xyz = torch.stack((x, y, z), dim=-1)
    # Unproject the samples to rays in the source camera frame.
    xyz = K_src_inv @ xyz.reshape(n, -1, 3).permute(0, 2, 1)
    xyz = xyz.reshape(n, 3, d, h, w)
    # xyz = xyz * (1 / torch.linspace(1 / z_range[0], 1 / z_range[1], d, dtype=dtype, device=device).view(1, 1, d, 1, 1).expand(n, 1, d, h, w))
    # Scale rays to d linearly spaced depths within z_range (the inverse-depth
    # spacing variant above is kept commented out).
    xyz = xyz * torch.linspace(z_range[0], z_range[1], d, dtype=dtype, device=device).view(1, 1, d, 1, 1).expand(n, 1, d, h, w)
    # Homogenize and flatten to (n, 4, num_samples).
    xyz = torch.cat((xyz, torch.ones_like(xyz[:, :1])), dim=1)
    xyz = xyz.reshape(n, 4, -1)
    with autocast(enabled=False):
        src2tgt = torch.inverse(pose_target) @ pose_source
    # Move the sample points into the target camera frame, then project.
    xyz = src2tgt @ xyz
    # print(xyz)
    xyz = proj_target @ xyz[:, :3, :]
    # Perspective divide on x, y only; z keeps its camera-space value.
    xyz[:, :2] = xyz[:, :2] / xyz[:, 2:3, :]
    # print(xyz)
    # A sample counts when it projects inside the NDC square and lies beyond
    # the near depth; the far-plane check is commented out -- presumably
    # intentional, TODO confirm.
    valid = (xyz[:, 0].abs() < 1) & (xyz[:, 1].abs() < 1) & (xyz[:, 2].abs() > z_range[0])# & (xyz[:, 2].abs() < z_range[1])
    # print(valid)
    # Fraction of valid samples approximates the overlap volume ratio.
    volume_estimate = valid.to(dtype).mean(-1)
    return volume_estimate
def compute_occlusions(flow0, flow1):
    """Compute covisibility masks from a pair of optical flow fields.

    Args:
        flow0: flow from image 0 to image 1 in pixel units, shape (n, 2, h, w).
        flow1: flow from image 1 to image 0 in pixel units, shape (n, 2, h, w).

    Returns:
        tuple: (mask0, mask1), each (n, 1, h, w); a pixel is 1 where some
        pixel of the other image lands according to the other image's flow.
    """
    b, _, ht, wd = flow0.shape
    dev = flow0.device
    # Base pixel grid in normalized [-1, 1] coordinates (x channel first).
    col = torch.linspace(-1, 1, wd, device=dev).view(1, 1, wd).expand(1, ht, wd)
    row = torch.linspace(-1, 1, ht, device=dev).view(1, ht, 1).expand(1, ht, wd)
    grid = torch.cat((col, row), dim=0).view(1, 2, ht, wd).expand(b, 2, ht, wd)

    def _to_normalized(flow):
        # Rescale pixel displacements into the [-1, 1] grid units.
        return torch.cat((flow[:, 0:1, :, :] * 2 / wd, flow[:, 1:2, :, :] * 2 / ht), dim=1)

    # Flow-advected target coordinates, flattened per batch.
    tgt0 = (grid + _to_normalized(flow0)).view(b, 2, -1)
    tgt1 = (grid + _to_normalized(flow1)).view(b, 2, -1)
    # Prepend the batch index as an extra channel for scatter indexing.
    batch_idx = torch.arange(b, device=dev, dtype=tgt0.dtype)
    idx0 = torch.cat((batch_idx.view(b, 1, 1).expand(-1, 1, tgt0.shape[-1]), tgt0), dim=1)
    idx1 = torch.cat((batch_idx.view(b, 1, 1).expand(-1, 1, tgt1.shape[-1]), tgt1), dim=1)

    def _scatter_hits(template, idx):
        # Mark each landing position (rounded, clamped to bounds) with a 1.
        hit = torch.zeros_like(template[:, :1, :, :])
        rows = ((idx[:, 2, :] * .5 + .5) * ht).round().long().clamp(0, ht - 1)
        cols = ((idx[:, 1, :] * .5 + .5) * wd).round().long().clamp(0, wd - 1)
        hit[idx[:, 0, :].long(), 0, rows, cols] = 1
        return hit

    # mask0 collects hits from image 1's flow, and vice versa.
    return _scatter_hits(flow0, idx1), _scatter_hits(flow1, idx0)