# NOTE(review): removed page-scrape residue ("Spaces: Running on Zero ...")
# that preceded the module source; it was not part of the code.
from math import isqrt

import torch
import torch.nn.functional as F
from omegaconf import ListConfig
from torch.cuda.amp import autocast

from scenedino.common import util
from scenedino.common.cameras.pinhole import outside_frustum, project_to_image, pts_into_camera
class RaySampler:
    """Abstract base for ray sampling strategies.

    Holds the near/far clipping distances shared by every concrete sampler
    and declares the two-method interface subclasses must implement.
    """

    def __init__(self, z_near: float, z_far: float) -> None:
        # Near/far plane distances forwarded to ray generation.
        self.z_near = z_near
        self.z_far = z_far

    def sample(self, images, poses, projs):
        """Draw rays (and matching ground-truth targets) from the given views."""
        raise NotImplementedError

    def reconstruct(self, render_dict):
        """Reshape flat rendering outputs back into structured tensors."""
        raise NotImplementedError
class RandomRaySampler(RaySampler):
    """Samples ``ray_batch_size`` rays uniformly at random over all views and pixels."""

    def __init__(
        self, z_near: float, z_far: float, ray_batch_size: int, channels: int = 3
    ) -> None:
        super().__init__(z_near, z_far)
        self.ray_batch_size = ray_batch_size
        self.channels = channels

    def sample(self, images, poses, projs, image_ids=None):
        """Generate per-pixel rays for every view, then keep a random subset.

        Args:
            images: (n, v, c, h, w) ground-truth images.
            poses: (n, v, 4, 4) camera-to-world poses.
            projs: per-view intrinsics; focal lengths are read from the
                diagonal, principal point from the last column.
            image_ids: optional explicit frame ids appended to each ray.

        Returns:
            Tuple ``(rays, rgb_gt)`` with shapes (n, ray_batch_size, r_dim)
            and (n, ray_batch_size, c).
        """
        n, v, c, h, w = images.shape
        all_rgb_gt = []
        all_rays = []
        for n_ in range(n):
            focals = projs[n_, :, [0, 1], [0, 1]]
            centers = projs[n_, :, [0, 1], [2, 2]]
            rays, xy = util.gen_rays(
                poses[n_].view(-1, 4, 4),
                w,
                h,
                focal=focals,
                c=centers,
                z_near=self.z_near,
                z_far=self.z_far,
            )
            # Append frame id and pixel coordinates so downstream code can
            # recover which view/pixel a ray came from.
            if image_ids is None:
                ids = torch.arange(v, device=images.device, dtype=images.dtype)
            else:
                ids = torch.tensor(image_ids, device=images.device, dtype=images.dtype)
            ids = ids.view(v, 1, 1, 1).expand(v, h, w, 1)
            rays = torch.cat((rays, ids, xy), dim=-1)
            r_dim = rays.shape[-1]
            rays = rays.view(-1, r_dim)
            rgb_gt = images[n_].view(-1, c, h, w)
            rgb_gt = rgb_gt.permute(0, 2, 3, 1).contiguous().reshape(-1, c)
            # Uniform random pixel indices over all views of this batch item.
            pix_inds = torch.randint(0, v * h * w, (self.ray_batch_size,))
            all_rgb_gt.append(rgb_gt[pix_inds])
            all_rays.append(rays[pix_inds])
        return torch.stack(all_rays), torch.stack(all_rgb_gt)

    def reconstruct(self, render_dict, channels=None):
        """Reshape flat render outputs into (n, n_pts, ...) tensors, in place.

        Args:
            render_dict: mapping of render-pass name -> dict of flat tensors;
                entries that are not dicts with an "rgb" key are left untouched.
            channels: color channel count; defaults to ``self.channels``.

        Returns:
            The same dict with its tensors reshaped.
        """
        n = n_pts = None  # remembered from the last processed render part
        for name, render_dict_part in render_dict.items():
            if not isinstance(render_dict_part, dict) or "rgb" not in render_dict_part:
                continue
            if channels is None:
                channels = self.channels
            rgb = render_dict_part["rgb"]  # (n, n_pts, v * channels)
            depth = render_dict_part["depth"]
            invalid = render_dict_part["invalid"]
            n, n_pts, v_c = rgb.shape
            v = v_c // channels
            n_smps = invalid.shape[-2]
            render_dict_part["rgb"] = rgb.view(n, n_pts, v, channels)
            render_dict_part["depth"] = depth.view(n, n_pts)
            render_dict_part["invalid"] = invalid.view(n, n_pts, n_smps, v)
            if "invalid_features" in render_dict_part:
                render_dict_part["invalid_features"] = render_dict_part[
                    "invalid_features"
                ].view(n, n_pts, n_smps, v)
            if "weights" in render_dict_part:
                render_dict_part["weights"] = render_dict_part["weights"].view(
                    n, n_pts, n_smps
                )
            if "alphas" in render_dict_part:
                render_dict_part["alphas"] = render_dict_part["alphas"].view(
                    n, n_pts, n_smps
                )
            if "z_samps" in render_dict_part:
                render_dict_part["z_samps"] = render_dict_part["z_samps"].view(
                    n, n_pts, n_smps
                )
            if "rgb_samps" in render_dict_part:
                render_dict_part["rgb_samps"] = render_dict_part["rgb_samps"].view(
                    n, n_pts, n_smps, v, channels
                )
            if "ray_info" in render_dict_part:
                ri_shape = render_dict_part["ray_info"].shape[-1]
                render_dict_part["ray_info"] = render_dict_part["ray_info"].view(
                    n, n_pts, ri_shape
                )
            if "extras" in render_dict_part:
                extras_shape = render_dict_part["extras"].shape[-1]
                render_dict_part["extras"] = render_dict_part["extras"].view(
                    n, n_pts, extras_shape
                )
            render_dict[name] = render_dict_part
        # Fix: only reshape rgb_gt once a render part has established
        # (n, n_pts); previously this used loop-local names unconditionally
        # and raised NameError/KeyError on dicts without renderable parts.
        if n is not None and "rgb_gt" in render_dict:
            render_dict["rgb_gt"] = render_dict["rgb_gt"].view(n, n_pts, channels)
        return render_dict
class PatchRaySampler(RaySampler):
    """Samples rays in rectangular patches so that patch-based losses (e.g. on
    DINO features) see spatially coherent pixel neighborhoods."""

    def __init__(
        self,
        z_near: float,
        z_far: float,
        ray_batch_size: int,
        patch_size: int,
        channels: int = 3,
        snap_to_grid: bool = False,
        dino_upscaled: bool = False,
    ) -> None:
        super().__init__(z_near, z_far)
        self.ray_batch_size = ray_batch_size
        self.channels = channels
        # Align patch origins to a (patch_size_y, patch_size_x) grid when True.
        self.snap_to_grid = snap_to_grid
        # True when DINO feature maps come at full image resolution.
        self.dino_upscaled = dino_upscaled
        if isinstance(patch_size, int):
            self.patch_size_x, self.patch_size_y = patch_size, patch_size
        elif isinstance(patch_size, (tuple, list, ListConfig)):
            self.patch_size_y = patch_size[0]
            self.patch_size_x = patch_size[1]
        else:
            raise ValueError("Invalid format for patch size")
        assert (ray_batch_size % (self.patch_size_x * self.patch_size_y)) == 0
        self._patch_count = self.ray_batch_size // (
            self.patch_size_x * self.patch_size_y
        )

    def sample(
        self, images, poses, projs, image_ids=None, dino_features=None, loss_feature_grid_shift=None,
    ):  ### dim(images) == nv (ids_loss nv randomly sampled)
        """Sample ``_patch_count`` random patches of rays plus matching targets.

        Args:
            images: (n, v, c, h, w) ground-truth images.
            poses: (n, v, 4, 4) camera-to-world poses.
            projs: per-view intrinsics.
            image_ids: optional explicit frame ids appended to each ray.
            dino_features: optional (n, v, dino_c, fh, fw) feature maps.
            loss_feature_grid_shift: optional (dy, dx) sub-grid shift applied
                to snapped patch origins.

        Returns:
            ``(rays, rgb_gt)`` or ``(rays, rgb_gt, dino_gt)`` when DINO
            features were provided.
        """
        n, v, c, h, w = images.shape
        self.channels = c
        if dino_features is not None:
            _, _, dino_channels, _, _ = dino_features.shape
            # Channels-last layout makes patch slicing below straightforward.
            dino_features = dino_features.permute(0, 1, 3, 4, 2)
        images = images.permute(0, 1, 3, 4, 2)
        all_rgb_gt, all_rays, all_dino_gt = [], [], []
        for n_ in range(n):
            focals = projs[n_, :, [0, 1], [0, 1]]
            centers = projs[n_, :, [0, 1], [2, 2]]
            rays, xy = util.gen_rays(
                poses[n_].view(-1, 4, 4),
                w,
                h,
                focal=focals,
                c=centers,
                z_near=self.z_near,
                z_far=self.z_far,
            )
            # Append frame id and pixel coordinates to every ray.
            if image_ids is None:
                ids = torch.arange(v, device=images.device, dtype=images.dtype)
            else:
                ids = torch.tensor(image_ids, device=images.device, dtype=images.dtype)
            ids = ids.view(v, 1, 1, 1).expand(v, h, w, 1)
            rays = torch.cat((rays, ids, xy), dim=-1)
            r_dim = rays.shape[-1]
            patch_coords_v = torch.randint(0, v, (self._patch_count,))
            if self.snap_to_grid:
                if loss_feature_grid_shift is not None:
                    # Keep a one-patch margin so the shifted window stays in bounds.
                    patch_coords_y = torch.randint(0, h // self.patch_size_y - 1, (self._patch_count,))
                    patch_coords_x = torch.randint(0, w // self.patch_size_x - 1, (self._patch_count,))
                else:
                    patch_coords_y = torch.randint(0, h // self.patch_size_y, (self._patch_count,))
                    patch_coords_x = torch.randint(0, w // self.patch_size_x, (self._patch_count,))
            else:
                patch_coords_y = torch.randint(0, h - self.patch_size_y, (self._patch_count,))
                patch_coords_x = torch.randint(0, w - self.patch_size_x, (self._patch_count,))
            sample_rgb_gt = []
            sample_rays = []
            sample_dino_gt = []
            for v_, coord_y, coord_x in zip(patch_coords_v, patch_coords_y, patch_coords_x):
                if self.snap_to_grid:
                    patch_y, patch_x = coord_y, coord_x
                    if loss_feature_grid_shift is not None:
                        y = (loss_feature_grid_shift[0] % self.patch_size_y) + self.patch_size_y * coord_y
                        x = (loss_feature_grid_shift[1] % self.patch_size_x) + self.patch_size_x * coord_x
                        # Negative shifts move the pixel window one feature cell over.
                        if loss_feature_grid_shift[0] < 0:
                            patch_y += 1
                        if loss_feature_grid_shift[1] < 0:
                            patch_x += 1
                    else:
                        y = self.patch_size_y * coord_y
                        x = self.patch_size_x * coord_x
                else:
                    # Off-grid patches cannot be matched to DINO feature cells.
                    raise NotImplementedError
                rgb_gt_patch = images[n_][
                    v_, y : y + self.patch_size_y, x : x + self.patch_size_x, :
                ].reshape(-1, self.channels)
                rays_patch = rays[
                    v_, y : y + self.patch_size_y, x : x + self.patch_size_x, :
                ].reshape(-1, r_dim)
                sample_rgb_gt.append(rgb_gt_patch)
                sample_rays.append(rays_patch)
                if dino_features is not None:
                    if self.dino_upscaled:
                        # Full-resolution features: take the same pixel window.
                        dino_gt_patch = dino_features[n_][
                            v_, y : y + self.patch_size_y, x : x + self.patch_size_x, :
                        ].reshape(-1, dino_channels)
                    else:
                        # One feature vector per patch cell.
                        dino_gt_patch = dino_features[n_][
                            v_, patch_y, patch_x, :
                        ].reshape(-1, dino_channels)
                    sample_dino_gt.append(dino_gt_patch)
            all_rgb_gt.append(torch.cat(sample_rgb_gt, dim=0))
            all_rays.append(torch.cat(sample_rays, dim=0))
            if dino_features is not None:
                all_dino_gt.append(torch.cat(sample_dino_gt, dim=0))
        all_rgb_gt = torch.stack(all_rgb_gt)
        all_rays = torch.stack(all_rays)
        if dino_features is not None:
            return all_rays, all_rgb_gt, torch.stack(all_dino_gt)
        return all_rays, all_rgb_gt

    def reconstruct(self, render_dict, channels=None, dino_channels=None):
        """Reshape flat render outputs into (n, patch_count, py, px, ...) tensors."""
        n = None  # set once a renderable part has been processed
        for name, render_dict_part in render_dict.items():
            if not isinstance(render_dict_part, dict) or "rgb" not in render_dict_part:
                continue
            if channels is None:
                channels = self.channels
            n, n_pts, v_c = render_dict_part["rgb"].shape
            # (This can be a different v from the sample method)
            v = v_c // channels
            n_smps = render_dict_part["weights"].shape[-1]
            render_dict_part["rgb"] = render_dict_part["rgb"].view(
                n, self._patch_count, self.patch_size_y, self.patch_size_x, v, channels
            )
            render_dict_part["weights"] = render_dict_part["weights"].view(
                n, self._patch_count, self.patch_size_y, self.patch_size_x, n_smps
            )
            render_dict_part["depth"] = render_dict_part["depth"].view(
                n, self._patch_count, self.patch_size_y, self.patch_size_x
            )
            render_dict_part["invalid"] = render_dict_part["invalid"].view(
                n, self._patch_count, self.patch_size_y, self.patch_size_x, n_smps, v
            )
            # TODO: Figure out DINO invalid policy
            # if "invalid_features" in render_dict_part:
            #     render_dict_part["invalid_features"] = render_dict_part["invalid_features"].view(
            #         n, self._patch_count, self.patch_size_y, self.patch_size_x, n_smps, v
            #     )
            if "alphas" in render_dict_part:
                render_dict_part["alphas"] = render_dict_part["alphas"].view(
                    n, self._patch_count, self.patch_size_y, self.patch_size_x, n_smps
                )
            if "z_samps" in render_dict_part:
                render_dict_part["z_samps"] = render_dict_part["z_samps"].view(
                    n, self._patch_count, self.patch_size_y, self.patch_size_x, n_smps
                )
            if "rgb_samps" in render_dict_part:
                render_dict_part["rgb_samps"] = render_dict_part["rgb_samps"].view(
                    n,
                    self._patch_count,
                    self.patch_size_y,
                    self.patch_size_x,
                    n_smps,
                    v,
                    channels,
                )
            if "ray_info" in render_dict_part:
                ri_shape = render_dict_part["ray_info"].shape[-1]
                render_dict_part["ray_info"] = render_dict_part["ray_info"].view(
                    n, self._patch_count, self.patch_size_y, self.patch_size_x, ri_shape
                )
            if "extras" in render_dict_part:
                extras_shape = render_dict_part["extras"].shape[-1]
                render_dict_part["extras"] = render_dict_part["extras"].view(
                    n, self._patch_count, self.patch_size_y, self.patch_size_x, extras_shape
                )
            if "dino_features" in render_dict_part:
                dino_shape = render_dict_part["dino_features"].shape[-1]
                render_dict_part["dino_features"] = render_dict_part["dino_features"].view(
                    n, self._patch_count, self.patch_size_y, self.patch_size_x, 1, dino_shape
                )
            render_dict[name] = render_dict_part
        if n is not None:
            render_dict["rgb_gt"] = render_dict["rgb_gt"].view(
                n, self._patch_count, self.patch_size_y, self.patch_size_x, channels
            )
            # Fix: "dino_gt" only exists when sample() received DINO features;
            # the unconditional access previously raised KeyError for
            # RGB-only training (ImageRaySampler.reconstruct already guards this).
            dino_gt = render_dict.get("dino_gt")
            if dino_gt is not None:
                dino_gt_shape = dino_gt.shape[-1]
                if self.dino_upscaled:
                    render_dict["dino_gt"] = dino_gt.view(
                        n, self._patch_count, self.patch_size_y, self.patch_size_x, dino_gt_shape
                    )
                else:
                    render_dict["dino_gt"] = dino_gt.view(
                        n, self._patch_count, dino_gt_shape
                    )
                if "dino_artifacts" in render_dict:
                    render_dict["dino_artifacts"] = render_dict["dino_artifacts"].view(
                        n, self._patch_count, dino_gt_shape
                    )
        return render_dict
class PointBasedRaySampler(RandomRaySampler):
    """Builds one ray per externally supplied 3D point instead of per pixel.

    Rays start at the (single) camera center and point through the image
    projection of each query point; ground-truth colors are bilinearly
    sampled from the image at the projected sub-pixel locations.
    """

    def sample(
        self, images, poses, projs, xyz, image_ids=None
    ):  ### dim(images) == nv (ids_loss nv randomly sampled)
        # images: (n, 1, c, h, w); poses: (n, 1, 4, 4); xyz: (n, n_pts, 3) —
        # presumably world-space points; TODO confirm against pts_into_camera.
        n, v, c, h, w = images.shape
        assert v == 1
        # Matrix inversion is numerically fragile in reduced precision, so
        # force full precision for the inverses.
        with autocast(enabled=False):
            poses_w2c = torch.inverse(poses)
            inv_K = torch.inverse(projs[:, :, :3, :3])
        B, n_pts, _ = xyz.shape
        xyz_projected = pts_into_camera(xyz, poses_w2c)
        # Distance from camera center to each point (norm over the coordinate dim).
        distance = torch.norm(xyz_projected, dim=-2, keepdim=True)
        xy, z = project_to_image(xyz_projected, projs)
        # Drop the singleton view dimension (v == 1 asserted above).
        xy = xy[:, 0]
        z = z[:, 0]
        distance = distance[:, 0, 0, :, None]
        # For numerical stability with AMP. Should not affect training outcome
        xy = xy.clamp(-2, 2)
        # Build rays
        cam_centers = poses[:, 0, None, :3, 3].expand(-1, n_pts, -1)
        # Back-project (normalized) image coordinates into world-space directions.
        ray_dir = ((poses[:, 0, :3, :3] @ inv_K[:, 0]) @ torch.cat((xy, torch.ones_like(xy[..., :1])), dim=-1).permute(0, 2, 1)).permute(0, 2, 1)
        cam_nears = torch.ones_like(cam_centers[..., :1]) * self.z_near
        cam_fars = torch.ones_like(cam_centers[..., :1]) * self.z_far
        # Single view -> frame id 0 for every ray.
        ids = torch.zeros_like(cam_nears)
        rays = torch.cat((cam_centers, ray_dir, cam_nears, cam_fars, ids, xy, distance), dim=-1)
        # Sample GT colors at the projected locations; border padding keeps
        # slightly out-of-frustum points usable.
        rgb_gt = F.grid_sample(images[:, 0], xy.reshape(n, -1, 1, 2).to(images.dtype), padding_mode="border", align_corners=False)[..., 0].permute(0, 2, 1)
        return rays, rgb_gt
class ImageRaySampler(RaySampler):
    """Samples one ray per pixel for whole images (full-frame rendering/eval)."""

    def __init__(
        self,
        z_near: float,
        z_far: float,
        height: int | None = None,
        width: int | None = None,
        channels: int = 3,
        norm_dir: bool = True,
        dino_upscaled: bool = False,
    ) -> None:
        super().__init__(z_near, z_far)
        # Target resolution; if None it is inferred (and cached) from the
        # first images passed to sample().
        self.height = height
        self.width = width
        self.channels = channels
        # Forwarded to util.gen_rays; presumably normalizes ray directions.
        self.norm_dir = norm_dir
        # True when DINO GT features come at full image resolution.
        self.dino_upscaled = dino_upscaled

    def sample(self, images, poses, projs, image_ids=None, dino_features=None, dino_artifacts=None):
        """Generate rays for every pixel of every view.

        Args:
            images: optional (n, v, c, h, w) images; when given, flattened
                ground-truth colors are returned alongside the rays.
            poses: (n, v, 4, 4) camera-to-world poses.
            projs: per-view intrinsics.
            image_ids: optional explicit frame ids appended to each ray.
            dino_features: optional (n, v, dino_c, fh, fw) feature maps.
            dino_artifacts: NOTE(review): accepted but never read here —
                confirm whether callers rely on passing it.

        Returns:
            ``(rays, rgb_gt)`` or ``(rays, rgb_gt, dino_gt)`` when DINO
            features are given; ``rgb_gt`` is None when ``images`` is None.
        """
        n, v, _, _ = poses.shape
        device = poses.device
        dtype = poses.dtype
        if images is not None:
            self.channels = images.shape[2]
            if self.height is None:
                # Lazily adopt the input resolution (persists across calls).
                self.height, self.width = images.shape[-2:]
        if dino_features is not None:
            _, _, dino_channels, _, _ = dino_features.shape
        h = self.height
        w = self.width
        all_rgb_gt = []
        all_dino_gt = []
        all_rays = []
        for n_ in range(n):
            focals = projs[n_, :, [0, 1], [0, 1]]
            centers = projs[n_, :, [0, 1], [2, 2]]
            rays, xy = util.gen_rays(
                poses[n_].view(-1, 4, 4),
                self.width,
                self.height,
                focal=focals,
                c=centers,
                z_near=self.z_near,
                z_far=self.z_far,
                norm_dir=self.norm_dir,
            )
            # Append frame id to the ray
            if image_ids is None:
                ids = torch.arange(v, device=device, dtype=dtype)
            else:
                ids = torch.tensor(image_ids, device=device, dtype=dtype)
            ids = ids.view(v, 1, 1, 1).expand(v, h, w, 1)
            rays = torch.cat((rays, ids), dim=-1)
            rays = torch.cat((rays, xy), dim=-1)
            r_dim = rays.shape[-1]
            rays = rays.view(-1, r_dim)
            all_rays.append(rays)
            if images is not None:
                # Flatten to one RGB row per pixel, matching ray order.
                rgb_gt = images[n_].view(-1, self.channels, self.height, self.width)
                rgb_gt = (
                    rgb_gt.permute(0, 2, 3, 1).contiguous().reshape(-1, self.channels)
                )
                all_rgb_gt.append(rgb_gt)
            if dino_features is not None:
                # Feature maps keep their own (possibly coarser) resolution.
                patch_h, patch_w = dino_features[n_].shape[-2], dino_features[n_].shape[-1]
                dino_gt = dino_features[n_].view(-1, dino_channels, patch_h, patch_w)
                dino_gt = (
                    dino_gt.permute(0, 2, 3, 1).contiguous().reshape(-1, dino_channels)
                )
                all_dino_gt.append(dino_gt)
        all_rays = torch.stack(all_rays)
        if images is not None:
            all_rgb_gt = torch.stack(all_rgb_gt)
        else:
            all_rgb_gt = None
        if dino_features is not None:
            all_dino_gt = torch.stack(all_dino_gt)
            return all_rays, all_rgb_gt, all_dino_gt
        else:
            return all_rays, all_rgb_gt

    def reconstruct(self, render_dict, channels=None, dino_channels=None):
        """Reshape flat outputs back into (n, v_in, height, width, ...) grids."""
        for name, render_dict_part in render_dict.items():
            if not type(render_dict_part) == dict or not "rgb" in render_dict_part:
                continue
            if channels is None:
                channels = self.channels
            rgb = render_dict_part["rgb"]  # n, n_pts, v * 3
            weights = render_dict_part["weights"]
            depth = render_dict_part["depth"]
            invalid = render_dict_part["invalid"]
            n, n_pts, v_c = rgb.shape
            # Number of rendered input views, derived from the flat point count.
            v_in = n_pts // (self.height * self.width)
            v_render = v_c // channels
            n_smps = weights.shape[-1]
            # (This can be a different v from the sample method)
            render_dict_part["rgb"] = rgb.view(n, v_in, self.height, self.width, v_render, channels)
            render_dict_part["weights"] = weights.view(n, v_in, self.height, self.width, n_smps)
            render_dict_part["depth"] = depth.view(n, v_in, self.height, self.width)
            render_dict_part["invalid"] = invalid.view(
                n, v_in, self.height, self.width, n_smps, v_render
            )
            if "invalid_features" in render_dict_part:
                invalid_features = render_dict_part["invalid_features"]
                render_dict_part["invalid_features"] = invalid_features.view(
                    n, v_in, self.height, self.width, n_smps, v_render
                )
            if "alphas" in render_dict_part:
                alphas = render_dict_part["alphas"]
                render_dict_part["alphas"] = alphas.view(n, v_in, self.height, self.width, n_smps)
            if "z_samps" in render_dict_part:
                z_samps = render_dict_part["z_samps"]
                render_dict_part["z_samps"] = z_samps.view(
                    n, v_in, self.height, self.width, n_smps
                )
            if "rgb_samps" in render_dict_part:
                rgb_samps = render_dict_part["rgb_samps"]
                render_dict_part["rgb_samps"] = rgb_samps.view(
                    n, v_in, self.height, self.width, n_smps, v_render, channels
                )
            if "ray_info" in render_dict_part:
                ri_shape = render_dict_part["ray_info"].shape[-1]
                ray_info = render_dict_part["ray_info"]
                render_dict_part["ray_info"] = ray_info.view(n, v_in, self.height, self.width, ri_shape)
            if "extras" in render_dict_part:
                ex_shape = render_dict_part["extras"].shape[-1]
                extras = render_dict_part["extras"]
                render_dict_part["extras"] = extras.view(n, v_in, self.height, self.width, ex_shape)
            if "dino_features" in render_dict_part:
                dino_shape = render_dict_part["dino_features"].shape[-1]
                dino = render_dict_part["dino_features"]
                render_dict_part["dino_features"] = dino.view(n, v_in, self.height, self.width, 1, dino_shape)
            render_dict[name] = render_dict_part
        # NOTE(review): the tail below reuses n/v_in/channels from the loop;
        # it assumes at least one renderable part was present — confirm callers
        # never pass a dict with "rgb_gt" but no render parts.
        if "rgb_gt" in render_dict:
            rgb_gt = render_dict["rgb_gt"]
            render_dict["rgb_gt"] = rgb_gt.view(
                n, v_in, self.height, self.width, channels
            )
        if "dino_gt" in render_dict:
            dino_gt = render_dict["dino_gt"]
            dino_gt_shape = dino_gt.shape[-1]
            if self.dino_upscaled:
                render_dict["dino_gt"] = dino_gt.view(
                    n, v_in, self.height, self.width, dino_gt_shape
                )
            else:
                # TODO: patch size should not be inferred like this, but parameter
                patch_size = isqrt((n * v_in * self.height * self.width * dino_gt_shape) // dino_gt.numel())
                render_dict["dino_gt"] = dino_gt.view(
                    n, v_in, self.height // patch_size, self.width // patch_size, dino_gt_shape
                )
            if "dino_artifacts" in render_dict:
                dino_artifacts = render_dict["dino_artifacts"]
                # TODO: patch size should not be inferred like this, but parameter
                patch_size = isqrt((n * v_in * self.height * self.width * dino_gt_shape) // dino_artifacts.numel())
                render_dict["dino_artifacts"] = dino_artifacts.view(
                    n, v_in, self.height // patch_size, self.width // patch_size, dino_gt_shape
                )
        return render_dict
class JitteredPatchRaySampler(PatchRaySampler):
    """PatchRaySampler variant that shifts ray origins by a random sub-pixel offset.

    The same offset is applied both to ray generation and to the bilinear
    sampling of the ground-truth image, so rays and targets stay aligned.
    """

    def __init__(
        self,
        z_near: float,
        z_far: float,
        ray_batch_size: int,
        patch_size: int,
        jitter_strength: float,  # In pixels, max [0, 1)
        channels: int = 3,
    ) -> None:
        super().__init__(z_near, z_far, ray_batch_size, patch_size, channels)
        assert 0 <= jitter_strength < 1, "Jitter strength is invalid."
        self.jitter_strength = jitter_strength
        # Precomputed (1, py, px, 2) grid of integer pixel offsets within a patch.
        x = torch.arange(0, self.patch_size_x).view(1, 1, -1, 1).expand(-1, self.patch_size_y, -1, -1)
        y = torch.arange(0, self.patch_size_y).view(1, -1, 1, 1).expand(-1, -1, self.patch_size_x, -1)
        self._grid = torch.cat((x, y), dim=-1)

    def sample(
        self, images, poses, projs, image_ids=None
    ):  ### dim(images) == nv (ids_loss nv randomly sampled)
        """Sample jittered patches of rays with bilinearly resampled GT colors."""
        n, v, c, h, w = images.shape
        device = images.device
        all_rgb_gt, all_rays = [], []
        # One shared sub-pixel offset per call, in (-0.5, 0.5) * jitter_strength.
        xy_offset = ((torch.rand(2) - .5) * self.jitter_strength)
        for n_ in range(n):
            focals = projs[n_, :, [0, 1], [0, 1]]
            centers = projs[n_, :, [0, 1], [2, 2]]
            rays, xy = util.gen_rays(
                poses[n_].view(-1, 4, 4),
                w,
                h,
                focal=focals,
                c=centers,
                z_near=self.z_near,
                z_far=self.z_far,
                xy_offset=xy_offset,
            )
            # Append frame id to the ray
            if image_ids is None:
                ids = torch.arange(v, device=images.device, dtype=images.dtype)
            else:
                ids = torch.tensor(image_ids, device=images.device, dtype=images.dtype)
            ids = ids.view(v, 1, 1, 1).expand(v, h, w, 1)
            rays = torch.cat((rays, ids), dim=-1)
            r_dim = rays.shape[-1]
            patch_coords_v = torch.randint(0, v, (self._patch_count,))
            patch_coords_y = torch.randint(
                0, h - self.patch_size_y, (self._patch_count,)
            )
            patch_coords_x = torch.randint(
                0, w - self.patch_size_x, (self._patch_count,)
            )
            sample_rgb_gt = []
            sample_rays = []
            for v_, y, x in zip(patch_coords_v, patch_coords_y, patch_coords_x):
                xy_patch = torch.tensor((x, y)).view(1, 1, 1, 2)
                # Jittered, pixel-centered sample positions, normalized to
                # [-1, 1] as expected by grid_sample.
                patch_grid = self._grid + xy_patch + xy_offset.view(1, 1, 1, 2) + .5
                patch_grid = patch_grid.to(images.device)
                patch_grid[..., 0] = (patch_grid[..., 0] / w) * 2 - 1
                patch_grid[..., 1] = (patch_grid[..., 1] / h) * 2 - 1
                rgb_gt_patch = F.grid_sample(images[n_:n_+1, v_], patch_grid, padding_mode="border", align_corners=False)
                rgb_gt_patch = rgb_gt_patch.permute(0, 2, 3, 1).reshape(-1, self.channels)
                rays_patch = rays[
                    v_, y : y + self.patch_size_y, x : x + self.patch_size_x, :
                ].reshape(-1, r_dim)
                sample_rgb_gt.append(rgb_gt_patch)
                sample_rays.append(rays_patch)
            sample_rgb_gt = torch.cat(sample_rgb_gt, dim=0)
            sample_rays = torch.cat(sample_rays, dim=0)
            all_rgb_gt.append(sample_rgb_gt)
            all_rays.append(sample_rays)
        all_rgb_gt = torch.stack(all_rgb_gt)
        all_rays = torch.stack(all_rays)
        return all_rays, all_rgb_gt
| def get_ray_sampler(config) -> RaySampler: | |
| z_near = config["z_near"] | |
| z_far = config["z_far"] | |
| sample_mode = config.get("sample_mode", "random") | |
| # TODO: check channel size | |
| match sample_mode: | |
| case "random": | |
| return RandomRaySampler(z_near, z_far, **config["args"]) | |
| case "patch": | |
| return PatchRaySampler(z_near, z_far, **config["args"]) | |
| case "jitteredpatch": | |
| return JitteredPatchRaySampler(z_near, z_far, **config["args"]) | |
| case "image": | |
| return ImageRaySampler(z_near, z_far) | |
| case _: | |
| raise NotImplementedError | |