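"""Evaluation entry point: loads a trained model from a checkpoint, renders the
test split with a NeRF renderer, and collects RGB/depth outputs for evaluation."""
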
import logging
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader, Subset

from scenedino.common.io.configs import load_model_config
from scenedino.models import make_model
from scenedino.datasets import make_test_dataset
from scenedino.common.geometry import distance_to_z
from scenedino.renderer import NeRFRenderer
from scenedino.common.ray_sampler import ImageRaySampler, get_ray_sampler
from scenedino.evaluation.base_evaluator import base_evaluation
from scenedino.training.trainer_downstream import BTSDownstreamWrapper

IDX = 0

logger = logging.getLogger("evaluation")


class BTSWrapper(nn.Module):
    def __init__(
        self,
        renderer,
        config,
        # evaluation_fns
    ) -> None:
        super().__init__()

        self.renderer = renderer

        # TODO: have a consistent sampling range
        self.z_near = config.get("z_near", 3.0)
        self.z_far = config.get("z_far", 80.0)
        self.sampler = ImageRaySampler(self.z_near, self.z_far)

        # self.evaluation_fns = evaluation_fns

    @staticmethod
    def get_loss_metric_names():
        return ["loss", "loss_l2", "loss_mask", "loss_temporal"]

    def forward(self, data):
        data = dict(data)
        images = torch.stack(data["imgs"], dim=1)  # n, v, c, h, w
        poses = torch.stack(data["poses"], dim=1)  # n, v, 4, 4 w2c
        projs = torch.stack(data["projs"], dim=1)  # n, v, 4, 4 (-1, 1)

        B, n_frames, c, h, w = images.shape
        device = images.device

        # Use first frame as keyframe
        to_base_pose = torch.inverse(poses[:, :1, :, :])
        poses = to_base_pose.expand(-1, n_frames, -1, -1) @ poses
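        # inv(pose_0) @ pose_i re-expresses every camera pose relative to the
        # first frame, which becomes the identity (base) frame.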

        # TODO: make configurable
        ids_encoder = [0]

        self.renderer.net.compute_grid_transforms(
            projs[:, ids_encoder], poses[:, ids_encoder]
        )
        self.renderer.net.encode(
            images,
            projs,
            poses,
            ids_encoder=ids_encoder,
            ids_render=ids_encoder,
            images_alt=images * 0.5 + 0.5,
        )
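        # Assumption: images arrive normalized to [-1, 1], so `images * 0.5 + 0.5`
        # maps them back to [0, 1] for color sampling and supervision.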

        all_rays, all_rgb_gt = self.sampler.sample(images * 0.5 + 0.5, poses, projs)

        data["fine"] = []
        data["coarse"] = []

        self.renderer.net.set_scale(0)
        render_dict = self.renderer(all_rays, want_weights=True, want_alphas=True)

        if "fine" not in render_dict:
            render_dict["fine"] = dict(render_dict["coarse"])

        render_dict["rgb_gt"] = all_rgb_gt
        render_dict["rays"] = all_rays

        # TODO: check if distance to z is needed
        render_dict = self.sampler.reconstruct(render_dict)
        render_dict["coarse"]["depth"] = distance_to_z(
            render_dict["coarse"]["depth"], projs
        )
        render_dict["fine"]["depth"] = distance_to_z(
            render_dict["fine"]["depth"], projs
        )

        data["fine"].append(render_dict["fine"])
        data["coarse"].append(render_dict["coarse"])
        data["rgb_gt"] = render_dict["rgb_gt"]
        data["rays"] = render_dict["rays"]

        data["z_near"] = torch.tensor(self.z_near, device=device)
        data["z_far"] = torch.tensor(self.z_far, device=device)

        # for eval_fn in self.evaluation_fns:
        #     data["metrics"].update(eval_fn(data, model=self.renderer.net))

        return data
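

# Usage sketch for BTSWrapper.forward (shapes follow the stacking comments in
# forward; the concrete dataset convention may differ):
#
#     wrapper = BTSWrapper(renderer, config["model"])
#     n, v, h, w = 1, 2, 192, 640
#     batch = {
#         "imgs":  [torch.randn(n, 3, h, w) for _ in range(v)],  # in [-1, 1]
#         "poses": [torch.eye(4).expand(n, 4, 4) for _ in range(v)],
#         "projs": [torch.eye(4).expand(n, 4, 4) for _ in range(v)],
#     }
#     out = wrapper(batch)  # adds "coarse"/"fine" renders, "rgb_gt", "rays", ...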


def evaluation(local_rank, config):
    # base_evaluation drives the evaluation loop, pulling the data loader from
    # get_dataflow(config) and the wrapped model from initialize(config).
    return base_evaluation(local_rank, config, get_dataflow, initialize)


def get_dataflow(config):
    test_dataset = make_test_dataset(config["dataset"])
    test_loader = DataLoader(
        test_dataset,  # Subset(test_dataset, torch.randperm(test_dataset.length)[:1000]),
        batch_size=config.get("batch_size", 1),
        num_workers=config["num_workers"],
        shuffle=False,
        drop_last=False,
    )
    return test_loader
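

# Note: the commented Subset(...) argument above evaluates on a random
# 1000-sample subset instead of the full test split (why Subset is imported).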


def initialize(config: dict):
    checkpoint = Path(config["checkpoint"])
    logger.info(f"Loading model config from {checkpoint.parent}")
    load_model_config(checkpoint.parent, config)

    net = make_model(config["model"], config["downstream"])
    # net = make_model(config["model"])
    renderer = NeRFRenderer.from_conf(config["renderer"])
    renderer = renderer.bind_parallel(net, gpus=None).eval()

    # TODO: attach evaluation functions rather than adding them to the wrapper
    # eval_fns = []
    # for eval_conf in config["evaluations"]:
    #     eval_fn = make_eval_fn(eval_conf)
    #     if eval_fn is not None:
    #         eval_fns.append(eval_fn)

    ray_sampler = get_ray_sampler(config["training"]["ray_sampler"])
    model = BTSDownstreamWrapper(renderer, ray_sampler, config["model"])
    # model = BTSWrapper(renderer, config["model"])
    # model = BTSWrapper(renderer, config["model"], eval_fns)
    return model
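

if __name__ == "__main__":
    # Minimal single-process entry point (a sketch, not the project's own CLI).
    # Assumes a YAML config providing the keys consumed above: "checkpoint",
    # "dataset", "num_workers", "model", "downstream", "renderer", and
    # "training" -> "ray_sampler". PyYAML is assumed to be available.
    import sys

    import yaml

    with open(sys.argv[1]) as f:
        cfg = yaml.safe_load(f)
    evaluation(local_rank=0, config=cfg)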