Spaces:
Running
on
Zero
Running
on
Zero
"""
NeRF differentiable renderer.
References:
https://github.com/bmild/nerf
https://github.com/kwea123/nerf_pl
"""
| import torch | |
| import torch.autograd.profiler as profiler | |
| from dotmap import DotMap | |
class _RenderWrapper(torch.nn.Module):
    """
    Module that pairs a fixed network with a renderer so that the two can be
    invoked with a single ``forward(rays, ...)`` call (see
    ``NeRFRenderer.bind_parallel``).
    :param net network handed to the renderer as its first argument
    :param renderer callable renderer; must expose ``using_fine``
    :param simple_output if true, forward returns only (rgb, depth)
    """

    def __init__(self, net, renderer, simple_output):
        super().__init__()
        self.net = net
        self.renderer = renderer
        self.simple_output = simple_output

    def forward(
        self,
        rays,
        want_weights=False,
        want_alphas=False,
        want_z_samps=False,
        want_rgb_samps=False,
        sample_from_dist=None,
    ):
        """
        Render the given rays with the bound network.
        :param rays ray tensor forwarded unchanged to the renderer
        :param want_weights, want_alphas, want_z_samps, want_rgb_samps
        optional outputs; all of them are forced off when simple_output is set
        :param sample_from_dist forwarded unchanged to the renderer
        :return (rgb, depth) if simple_output (or if rays is empty along
        dim 0), otherwise the renderer output converted with ``toDict()``
        """
        # Nothing to render: hand back empty rgb / depth tensors.
        if rays.shape[0] == 0:
            empty_rgb = torch.zeros(0, 3, device=rays.device)
            empty_depth = torch.zeros(0, device=rays.device)
            return (empty_rgb, empty_depth)

        # Per-sample extras are only requested when the full output is wanted.
        full_output = not self.simple_output
        outputs = self.renderer(
            self.net,
            rays,
            want_weights=want_weights and full_output,
            want_alphas=want_alphas and full_output,
            want_z_samps=want_z_samps and full_output,
            want_rgb_samps=want_rgb_samps and full_output,
            sample_from_dist=sample_from_dist,
        )

        if not self.simple_output:
            # Make DotMap to dict to support DataParallel
            return outputs.toDict()

        level = outputs.fine if self.renderer.using_fine else outputs.coarse
        return level.rgb, level.depth
class NeRFRenderer(torch.nn.Module):
    """
    NeRF differentiable renderer
    :param n_coarse number of coarse (binned uniform) samples
    :param n_fine number of fine (importance) samples
    :param n_fine_depth number of expected depth samples
    :param noise_std noise to add to sigma. We do not use it
    :param depth_std noise for depth samples
    :param eval_batch_size ray batch size for evaluation
    :param white_bkgd if true, background color is white; else black
    :param lindisp if to use samples linear in disparity instead of distance
    :param sched ray sampling schedule. list containing 3 lists of equal length.
    sched[0] is list of iteration numbers,
    sched[1] is list of coarse sample numbers,
    sched[2] is list of fine sample numbers
    :param hard_alpha_cap if true, alpha of the last sample on every ray is set to 1
    :param render_mode one of "volumetric", "surface", "neus"
    :param surface_sigmoid_scale scale applied to the sample-index distance in the
    Gaussian re-weighting used by the "neus" render mode
    :param render_flow forwarded to the model as the ``render_flow`` keyword
    :param normalize_dino stored on the instance; not read inside this class
    """

    def __init__(
        self,
        n_coarse=128,
        n_fine=0,
        n_fine_depth=0,
        noise_std=0.0,
        depth_std=0.01,
        eval_batch_size=100000,
        white_bkgd=False,
        lindisp=False,
        sched=None,  # ray sampling schedule for coarse and fine rays
        hard_alpha_cap=False,
        render_mode="volumetric",
        surface_sigmoid_scale=.1,
        render_flow=False,
        normalize_dino=False,
    ):
        super().__init__()
        self.n_coarse, self.n_fine = n_coarse, n_fine
        self.n_fine_depth = n_fine_depth

        self.noise_std = noise_std
        self.depth_std = depth_std

        self.eval_batch_size = eval_batch_size
        self.white_bkgd = white_bkgd
        self.lindisp = lindisp
        if lindisp:
            print("Using linear displacement rays")
        self.using_fine = n_fine > 0
        self.sched = sched
        # An empty schedule is treated like no schedule at all.
        if sched is not None and len(sched) == 0:
            self.sched = None
        # Schedule progress is kept in persistent buffers so it is part of the
        # module's state_dict.
        self.register_buffer(
            "iter_idx", torch.tensor(0, dtype=torch.long), persistent=True
        )
        self.register_buffer(
            "last_sched", torch.tensor(0, dtype=torch.long), persistent=True
        )
        self.hard_alpha_cap = hard_alpha_cap

        assert render_mode in ("volumetric", "surface", "neus")
        self.render_mode = render_mode
        # In "surface" mode the model is only queried for density per sample;
        # color is sampled once at the expected depth (see composite).
        self.only_surface_color = (self.render_mode == "surface")
        self.surface_sigmoid_scale = surface_sigmoid_scale
        self.render_flow = render_flow
        self.normalize_dino = normalize_dino

    def sample_coarse(self, rays):
        """
        Stratified sampling. Note this is different from original NeRF slightly.
        :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
        :return (B, Kc)
        """
        device = rays.device
        near, far = rays[:, 6:7], rays[:, 7:8]  # (B, 1)

        step = 1.0 / self.n_coarse
        B = rays.shape[0]
        z_steps = torch.linspace(0, 1 - step, self.n_coarse, device=device)  # (Kc)
        z_steps = z_steps.unsqueeze(0).repeat(B, 1)  # (B, Kc)
        # Jitter every sample uniformly inside its own bin.
        z_steps += torch.rand_like(z_steps) * step

        if not self.lindisp:  # Use linear sampling in depth space
            return near * (1 - z_steps) + far * z_steps  # (B, Kc)
        # Use linear sampling in disparity space
        return 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)  # (B, Kc)

    def sample_coarse_from_dist(self, rays, weights, z_samp):
        """
        Draw n_coarse samples per ray from a piecewise-constant distribution.
        :param rays only used for its device and batch size (B, ...)
        :param weights unnormalized bin weights (B, Nb)
        :param z_samp bin centers matching weights (B, Nb)
        :return (B, Kc); not sorted
        """
        device = rays.device
        B = rays.shape[0]

        num_bins = weights.shape[-1]
        num_samples = self.n_coarse

        weights = weights.detach() + 1e-5  # Prevent division by zero
        pdf = weights / torch.sum(weights, -1, keepdim=True)  # (B, Nb)
        cdf = torch.cumsum(pdf, -1)  # (B, Nb)
        cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)  # (B, Nb+1)

        u = torch.rand(B, num_samples, dtype=torch.float32, device=device)  # (B, Kc)
        interval_ids = torch.searchsorted(cdf, u, right=True) - 1  # (B, Kc)
        # There are num_bins intervals, so valid ids are 0..num_bins-1. Clamping
        # by the number of *samples* would cut off far bins (or allow an
        # out-of-range gather below) whenever num_samples != num_bins.
        interval_ids = torch.clamp(interval_ids, 0, num_bins - 1)
        interval_interp = torch.rand_like(interval_ids, dtype=torch.float32)

        # z_samps describe the centers of the respective histogram bins.
        # Therefore, we have to extend them to the left and right
        if self.lindisp:
            z_samp = 1 / z_samp

        centers = 0.5 * (z_samp[:, 1:] + z_samp[:, :-1])
        interval_borders = torch.cat(
            (z_samp[:, :1], centers, z_samp[:, -1:]), dim=-1
        )  # (B, Nb+1)

        left_border = torch.gather(interval_borders, dim=-1, index=interval_ids)
        right_border = torch.gather(interval_borders, dim=-1, index=interval_ids + 1)

        z_samp_new = (
            left_border * (1 - interval_interp) + right_border * interval_interp
        )

        if self.lindisp:
            z_samp_new = 1 / z_samp_new

        assert not torch.any(torch.isnan(z_samp_new))

        return z_samp_new

    def sample_fine(self, rays, weights):
        """
        Weighted stratified (importance) sample
        :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
        :param weights (B, Kc)
        :return (B, Kf-Kfd)
        """
        device = rays.device
        B = rays.shape[0]

        weights = weights.detach() + 1e-5  # Prevent division by zero
        pdf = weights / torch.sum(weights, -1, keepdim=True)  # (B, Kc)
        cdf = torch.cumsum(pdf, -1)  # (B, Kc)
        cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)  # (B, Kc+1)

        u = torch.rand(
            B, self.n_fine - self.n_fine_depth, dtype=torch.float32, device=device
        )  # (B, Kf)
        inds = torch.searchsorted(cdf, u, right=True).float() - 1.0  # (B, Kf)
        inds = torch.clamp_min(inds, 0.0)

        # Jitter inside the selected coarse bin, then normalize by the bin count.
        z_steps = (inds + torch.rand_like(inds)) / self.n_coarse  # (B, Kf)

        near, far = rays[:, 6:7], rays[:, 7:8]  # (B, 1)
        if not self.lindisp:  # Use linear sampling in depth space
            z_samp = near * (1 - z_steps) + far * z_steps  # (B, Kf)
        else:  # Use linear sampling in disparity space
            z_samp = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)  # (B, Kf)

        assert not torch.any(torch.isnan(z_samp))

        return z_samp

    def sample_fine_depth(self, rays, depth):
        """
        Sample around specified depth
        :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
        :param depth (B)
        :return (B, Kfd)
        """
        z_samp = depth.unsqueeze(1).repeat((1, self.n_fine_depth))
        z_samp += torch.randn_like(z_samp) * self.depth_std
        # Clamp does not support tensor bounds
        z_samp = torch.max(torch.min(z_samp, rays[:, 7:8]), rays[:, 6:7])

        assert not torch.any(torch.isnan(z_samp))

        return z_samp

    def composite(self, model, rays, z_samp, coarse=True, sb=0):
        """
        Render RGB and depth for each ray using NeRF alpha-compositing formula,
        given sampled positions along each ray (see sample_*)
        :param model called per chunk of points with keyword arguments
        coarse, only_density, ray_info, render_flow (and viewdirs if the model
        has a truthy ``use_viewdirs`` attribute); must return the 5-tuple
        (rgbs, invalid, sigmas, extras, state_dict). In "surface" render mode
        it must also provide ``sample_colors``.
        :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8);
        any columns past index 8 are passed to the model as ray_info
        :param z_samp z positions sampled for each ray (B, K)
        :param coarse whether to evaluate using coarse NeRF
        :param sb super-batch dimension; 0 = disable
        :return tuple (weights (B, K), rgb_final, depth_final (B), alphas (B, K),
        invalid, z_samp, rgbs (None in "surface" mode), ray_info
        (rays[:, None, 8:] or None), extras_final (or None), state_dicts
        (dict or None))
        """
        with profiler.record_function("renderer_composite"):
            B, K = z_samp.shape

            deltas = z_samp[:, 1:] - z_samp[:, :-1]  # (B, K-1)
            delta_inf = 1e10 * torch.ones_like(deltas[:, :1])  # infty (B, 1)
            deltas = torch.cat([deltas, delta_inf], -1)  # (B, K)

            # (B, K, 3)
            points = rays[:, None, :3] + z_samp.unsqueeze(2) * rays[:, None, 3:6]
            points = points.reshape(-1, 3)  # (B*K, 3)

            if rays.shape[-1] > 8:
                ray_info = rays[:, None, 8:].expand(-1, K, -1)
            else:
                ray_info = None

            use_viewdirs = getattr(model, "use_viewdirs", None)

            rgbs_all, invalid_all, sigmas_all = [], [], []
            extras_all, state_dicts_all = [], []

            if sb > 0:
                points = points.reshape(
                    sb, -1, 3
                )  # (SB, B'*K, 3) B' is real ray batch size
                if ray_info is not None:
                    ray_info = ray_info.reshape(sb, -1, ray_info.shape[-1])
                eval_batch_dim = 1
                eval_batch_size = (self.eval_batch_size - 1) // sb + 1
            else:
                eval_batch_size = self.eval_batch_size
                eval_batch_dim = 0

            split_points = torch.split(points, eval_batch_size, dim=eval_batch_dim)
            if ray_info is not None:
                split_ray_infos = torch.split(
                    ray_info, eval_batch_size, dim=eval_batch_dim
                )
            else:
                split_ray_infos = [None for _ in split_points]

            if use_viewdirs:
                viewdirs = rays[:, None, 3:6].expand(-1, K, -1)
                if sb > 0:
                    viewdirs = viewdirs.reshape(sb, -1, 3)  # (SB, B'*K, 3)
                else:
                    viewdirs = viewdirs.reshape(-1, 3)  # (B*K, 3)
                split_viewdirs = torch.split(
                    viewdirs, eval_batch_size, dim=eval_batch_dim
                )
            else:
                split_viewdirs = None

            for i, pnts in enumerate(split_points):
                # Always hand the model the ray_info chunk that matches this
                # chunk of points, never the full tensor.
                infos = split_ray_infos[i]
                if use_viewdirs:
                    rgbs, invalid, sigmas, extras, state_dict = model(
                        pnts,
                        coarse=coarse,
                        viewdirs=split_viewdirs[i],
                        only_density=self.only_surface_color,
                        ray_info=infos,
                        render_flow=self.render_flow,
                    )
                else:
                    rgbs, invalid, sigmas, extras, state_dict = model(
                        pnts,
                        coarse=coarse,
                        only_density=self.only_surface_color,
                        ray_info=infos,
                        render_flow=self.render_flow,
                    )
                rgbs_all.append(rgbs)
                invalid_all.append(invalid)
                sigmas_all.append(sigmas)
                extras_all.append(extras)
                if state_dict is not None:
                    state_dicts_all.append(state_dict)

            # Drop references that are no longer needed.
            points, viewdirs, split_viewdirs = None, None, None

            # (B*K, 4) OR (SB, B'*K, 4)
            if not self.only_surface_color:
                rgbs = torch.cat(rgbs_all, dim=eval_batch_dim)
            else:
                rgbs = None
            invalid = torch.cat(invalid_all, dim=eval_batch_dim)
            sigmas = torch.cat(sigmas_all, dim=eval_batch_dim)

            if extras_all[0] is not None:
                extras = torch.cat(extras_all, dim=eval_batch_dim)
            else:
                extras = None

            deltas = deltas.float()
            sigmas = sigmas.float()

            if state_dicts_all:
                # Merge the per-chunk dicts key by key.
                state_dicts = {
                    key: torch.cat(
                        [chunk[key] for chunk in state_dicts_all],
                        dim=eval_batch_dim,
                    )
                    for key in state_dicts_all[0].keys()
                }
            else:
                state_dicts = None

            if rgbs is not None:
                rgbs = rgbs.reshape(B, K, -1)  # (B, K, 4 or 5)
            invalid = invalid.reshape(B, K, -1)
            sigmas = sigmas.reshape(B, K)

            if extras is not None:
                extras = extras.reshape(B, K, -1)

            if state_dicts is not None:
                state_dicts = {
                    key: value.reshape(B, K, *value.shape[2:])
                    for key, value in state_dicts.items()
                }  # BxKx... (BxKxn_viewsx...)

            if self.training and self.noise_std > 0.0:
                sigmas = sigmas + torch.randn_like(sigmas) * self.noise_std

            alphas = 1 - torch.exp(
                -deltas.abs() * torch.relu(sigmas)
            )  # (B, K) (delta should be positive anyways)

            if self.hard_alpha_cap:
                alphas[:, -1] = 1

            deltas, sigmas = None, None
            alphas_shifted = torch.cat(
                [torch.ones_like(alphas[:, :1]), 1 - alphas + 1e-10], -1
            )  # (B, K+1) = [1, a1, a2, ...]
            T = torch.cumprod(alphas_shifted, -1)  # (B, K+1)
            weights = alphas * T[:, :-1]  # (B, K)
            alphas_shifted = None

            depth_final = torch.sum(weights * z_samp, -1)  # (B)

            # The model may return no state dict at all (state_dicts is None),
            # so only composite the features when they are actually present.
            if state_dicts is not None and "dino_features" in state_dicts:
                state_dicts["dino_features"] = torch.sum(
                    state_dicts["dino_features"].mul_(weights.unsqueeze(-1)), -2
                )

            if self.render_mode == "neus":
                # Replace the weights by a normalized Gaussian over sample
                # indices centered at the expected sample index.
                indices = torch.arange(
                    0, weights.shape[-1], device=weights.device, dtype=weights.dtype
                ).unsqueeze(0)
                surface_index = torch.sum(weights * indices, dim=-1, keepdim=True)
                dist_from_surf = surface_index - indices
                weights = torch.exp(
                    -.5 * (dist_from_surf * self.surface_sigmoid_scale) ** 2
                )
                weights = weights / torch.sum(weights, dim=-1, keepdim=True)

            if not self.only_surface_color:
                rgb_final = torch.sum(weights.unsqueeze(-1) * rgbs, -2)  # (B, 3)
            else:
                # Query color once per ray, at the expected depth.
                surface_points = (
                    rays[:, None, :3] + depth_final[:, None, None] * rays[:, None, 3:6]
                )
                surface_points = surface_points.reshape(sb, -1, 3)
                if ray_info is not None:
                    ray_info = ray_info.reshape(sb, -1, K, ray_info.shape[-1])[
                        :, :, 0, :
                    ]
                rgb_final, invalid_colors = model.sample_colors(
                    surface_points, ray_info=ray_info, render_flow=self.render_flow
                )
                rgb_final = rgb_final.permute(0, 2, 1, 3).reshape(B, -1)
                invalid_colors = invalid_colors.permute(0, 2, 1, 3).reshape(B, 1, -1)
                invalid = ((invalid > .5) | invalid_colors).float()

            if self.white_bkgd:
                # White background
                pix_alpha = weights.sum(dim=1)  # (B), pixel alpha
                rgb_final = rgb_final + 1 - pix_alpha.unsqueeze(-1)  # (B, 3)

            if extras is not None:
                extras_final = torch.sum(
                    weights.unsqueeze(-1) * extras, -2
                )  # (B, extras)
            else:
                extras_final = None

            checked = [
                ("weights", weights),
                ("rgb_final", rgb_final),
                ("depth_final", depth_final),
                ("alphas", alphas),
                ("invalid", invalid),
                ("z_samp", z_samp),
            ]
            for name, x in checked:
                if torch.any(torch.isnan(x)):
                    print(f"Detected NaN in {name} ({x.dtype}):")
                    print(x)
                    # Abort the process with a failing exit status; bare
                    # exit() depends on the site module and reports success.
                    raise SystemExit(1)

            if ray_info is not None:
                ray_info = rays[:, None, 8:]

            return (
                weights,
                rgb_final,
                depth_final,
                alphas,
                invalid,
                z_samp,
                rgbs,
                ray_info,
                extras_final,
                state_dicts,
            )

    def forward(
        self,
        model,
        rays,
        want_weights=False,
        want_alphas=False,
        want_z_samps=False,
        want_rgb_samps=False,
        sample_from_dist=None,
    ):
        """
        :model nerf model, should return (SB, B, (r, g, b, sigma))
        when called with (SB, B, (x, y, z)), for multi-object:
        SB = 'super-batch' = size of object batch,
        B = size of per-object ray batch.
        Should also support 'coarse' boolean argument for coarse NeRF.
        :param rays ray spec [origins (3), directions (3), near (1), far (1)] (SB, B, 8)
        :param want_weights if true, returns compositing weights (SB, B, K)
        :param want_alphas if true, returns per-sample alphas (SB, B, K)
        :param want_z_samps if true, returns sample depths (SB, B, K)
        :param want_rgb_samps if true, returns per-sample colors
        :param sample_from_dist optional (weights, z_samp) pair; if given, coarse
        samples are drawn with sample_coarse_from_dist instead of sample_coarse
        :return render dict
        """
        with profiler.record_function("renderer_forward"):
            # Re-apply the sample counts of the last schedule entry reached.
            if self.sched is not None and self.last_sched.item() > 0:
                self.n_coarse = self.sched[1][self.last_sched.item() - 1]
                self.n_fine = self.sched[2][self.last_sched.item() - 1]

            assert len(rays.shape) == 3
            superbatch_size = rays.shape[0]
            r_dim = rays.shape[-1]

            rays = rays.reshape(-1, r_dim)  # (SB * B, 8)

            if sample_from_dist is None:
                z_coarse = self.sample_coarse(rays)  # (B, Kc)
            else:
                prop_weights, prop_z_samp = sample_from_dist
                n_samples = prop_weights.shape[-1]
                prop_weights = prop_weights.reshape(-1, n_samples)
                prop_z_samp = prop_z_samp.reshape(-1, n_samples)
                z_coarse = self.sample_coarse_from_dist(
                    rays, prop_weights, prop_z_samp
                )
                z_coarse, _ = torch.sort(z_coarse, dim=-1)

            coarse_composite = self.composite(
                model,
                rays,
                z_coarse,
                coarse=True,
                sb=superbatch_size,
            )

            outputs = DotMap(
                coarse=self._format_outputs(
                    coarse_composite,
                    superbatch_size,
                    want_weights=want_weights,
                    want_alphas=want_alphas,
                    want_z_samps=want_z_samps,
                    want_rgb_samps=want_rgb_samps,
                ),
            )

            outputs.state_dict = coarse_composite[-1]

            if self.using_fine:
                all_samps = [z_coarse]
                if self.n_fine - self.n_fine_depth > 0:
                    all_samps.append(
                        self.sample_fine(rays, coarse_composite[0].detach())
                    )  # (B, Kf - Kfd)
                if self.n_fine_depth > 0:
                    all_samps.append(
                        self.sample_fine_depth(rays, coarse_composite[2])
                    )  # (B, Kfd)
                z_combine = torch.cat(all_samps, dim=-1)  # (B, Kc + Kf)
                z_combine_sorted, _ = torch.sort(z_combine, dim=-1)
                fine_composite = self.composite(
                    model,
                    rays,
                    z_combine_sorted,
                    coarse=False,
                    sb=superbatch_size,
                )
                outputs.fine = self._format_outputs(
                    fine_composite,
                    superbatch_size,
                    want_weights=want_weights,
                    want_alphas=want_alphas,
                    want_z_samps=want_z_samps,
                    want_rgb_samps=want_rgb_samps,
                )

            return outputs

    def _format_outputs(
        self,
        rendered_outputs,
        superbatch_size,
        want_weights=False,
        want_alphas=False,
        want_z_samps=False,
        want_rgb_samps=False,
    ):
        """
        Reshape the tuple returned by composite back to super-batch layout.
        :param rendered_outputs 10-tuple as returned by composite
        :param superbatch_size leading dimension to restore
        :param want_weights, want_alphas, want_z_samps, want_rgb_samps
        select which optional per-sample tensors are added
        :return DotMap with rgb, depth, invalid and, when available/requested,
        ray_info, extras, weights, alphas, z_samps, rgb_samps, dino_features,
        invalid_features
        """
        (
            weights,
            rgb_final,
            depth,
            alphas,
            invalid,
            z_samps,
            rgb_samps,
            ray_info,
            extras,
            state_dict,
        ) = rendered_outputs

        n_smps = weights.shape[-1]
        out_d_rgb = rgb_final.shape[-1]
        out_d_i = invalid.shape[-1]

        if superbatch_size > 0:
            rgb_final = rgb_final.reshape(superbatch_size, -1, out_d_rgb)
            depth = depth.reshape(superbatch_size, -1)
            invalid = invalid.reshape(superbatch_size, -1, n_smps, out_d_i)
        ret_dict = DotMap(rgb=rgb_final, depth=depth, invalid=invalid)

        if ray_info is not None:
            ri_shape = ray_info.shape[-1]
            ray_info = ray_info.reshape(superbatch_size, -1, ri_shape)
            ret_dict.ray_info = ray_info
        if extras is not None:
            extras_shape = extras.shape[-1]
            extras = extras.reshape(superbatch_size, -1, extras_shape)
            ret_dict.extras = extras
        if want_weights:
            weights = weights.reshape(superbatch_size, -1, n_smps)
            ret_dict.weights = weights
        if want_alphas:
            alphas = alphas.reshape(superbatch_size, -1, n_smps)
            ret_dict.alphas = alphas
        if want_z_samps:
            z_samps = z_samps.reshape(superbatch_size, -1, n_smps)
            ret_dict.z_samps = z_samps
        if want_rgb_samps:
            rgb_samps = rgb_samps.reshape(superbatch_size, -1, n_smps, out_d_rgb)
            ret_dict.rgb_samps = rgb_samps

        # composite returns state_dicts=None when the model produced none, so
        # the feature entries are optional.
        if state_dict is not None:
            if "dino_features" in state_dict:
                dino_features = state_dict["dino_features"]
                ret_dict.dino_features = dino_features.reshape(
                    superbatch_size, -1, dino_features.shape[-1]
                )
            if "invalid_features" in state_dict:
                invalid_features = state_dict["invalid_features"].reshape(
                    superbatch_size, -1, n_smps, out_d_i
                )
                ret_dict.invalid_features = invalid_features
        return ret_dict

    def sched_step(self, steps=1):
        """
        Called each training iteration to update sample numbers
        according to schedule
        """
        if self.sched is None:
            return
        self.iter_idx += steps
        # Advance through every schedule entry whose iteration has been reached.
        while (
            self.last_sched.item() < len(self.sched[0])
            and self.iter_idx.item() >= self.sched[0][self.last_sched.item()]
        ):
            self.n_coarse = self.sched[1][self.last_sched.item()]
            self.n_fine = self.sched[2][self.last_sched.item()]
            print(
                "INFO: NeRF sampling resolution changed on schedule ==> c",
                self.n_coarse,
                "f",
                self.n_fine,
            )
            self.last_sched += 1

    @classmethod
    def from_conf(cls, conf, white_bkgd=False, eval_batch_size=100000):
        """
        Alternate constructor: build a renderer from a config object.
        :param conf object with a dict-style ``get(key, default)`` method
        :param white_bkgd fallback used when conf has no "white_bkgd"
        :param eval_batch_size fallback used when conf has no "eval_batch_size"
        :return new instance of cls
        """
        # NOTE: the fallbacks for "lindisp" (True) and "surface_sigmoid_scale"
        # (1) used here differ from the __init__ defaults (False and .1).
        return cls(
            conf.get("n_coarse", 128),
            conf.get("n_fine", 0),
            n_fine_depth=conf.get("n_fine_depth", 0),
            noise_std=conf.get("noise_std", 0.0),
            depth_std=conf.get("depth_std", 0.01),
            white_bkgd=conf.get("white_bkgd", white_bkgd),
            lindisp=conf.get("lindisp", True),
            eval_batch_size=conf.get("eval_batch_size", eval_batch_size),
            sched=conf.get("sched", None),
            hard_alpha_cap=conf.get("hard_alpha_cap", False),
            render_mode=conf.get("render_mode", "volumetric"),
            surface_sigmoid_scale=conf.get("surface_sigmoid_scale", 1),
            render_flow=conf.get("render_flow", False),
            normalize_dino=conf.get("normalize_dino", False),
        )

    def bind_parallel(self, net, gpus=None, simple_output=False):
        """
        Returns a wrapper module compatible with DataParallel.
        Specifically, it renders rays with this renderer
        but always using the given network instance.
        Specify a list of GPU ids in 'gpus' to apply DataParallel automatically.
        :param net A PixelNeRF network
        :param gpus list of GPU ids to parallelize to. If length is 1,
        does not parallelize
        :param simple_output only returns rendered (rgb, depth) instead of the
        full render output map. Saves data transfer cost.
        :return torch module
        """
        wrapped = _RenderWrapper(net, self, simple_output=simple_output)
        if gpus is not None and len(gpus) > 1:
            print("Using multi-GPU", gpus)
            wrapped = torch.nn.DataParallel(wrapped, gpus, dim=1)
        return wrapped