# NOTE: removed a non-code web-scrape artifact ("Spaces: Running on Zero") that preceded the source.
| from random import shuffle | |
| import random | |
| from typing import Callable, Optional | |
| import numpy as np | |
| import torch | |
# Strategy that maps the number of available frames to the list of frame
# indices the encoder should consume.
EncoderSamplingStrategy = Callable[[int], list[int]]

# Strategy that maps the number of available frames to
# (loss frame ids, renderer frame ids, optional loss->renderer mapping).
# NOTE(review): the *_with_mapping samplers below return an np.ndarray for the
# third element, not list[list[bool]] — confirm whether the alias should say so.
LossSamplingStrategy = Callable[[int], tuple[list[int], list[int], Optional[list[list[bool]]]]]
| # ============================================ ENCODING SAMPLING STRATEGIES ============================================ | |
def default_encoder_sampler() -> "EncoderSamplingStrategy":
    """Fallback encoder sampling: always encode only the first frame."""

    def _sampling_strategy(num_frames: int) -> list[int]:
        # Frame 0 is the sole encoded view, regardless of how many exist.
        return [0]

    return _sampling_strategy
def kitti_360_full_encoder_sampler(
    num_encoder_frames: int, always_use_base_frame: bool = True
) -> "EncoderSamplingStrategy":
    """Randomly choose encoder frames from all available frames.

    Draws a random permutation of the non-base frame ids 1..num_frames-1 and
    keeps a prefix of it. With ``always_use_base_frame`` the result is frame 0
    followed by ``num_encoder_frames - 1`` random ids; otherwise it is
    ``num_encoder_frames`` random non-base ids.
    """

    def _sampling_strategy(num_frames: int) -> list[int]:
        # Shuffle ids 1..num_frames-1 (the +1 shifts randperm off of 0).
        shuffled_rest = (torch.randperm(num_frames - 1) + 1).tolist()
        if always_use_base_frame:
            return [0] + shuffled_rest[: num_encoder_frames - 1]
        return shuffled_rest[:num_encoder_frames]

    return _sampling_strategy
def kitti_360_stereo_encoder_sampler(
    num_encoder_frames: int, num_stereo_frames: int, always_use_base_frame: bool = True
) -> "EncoderSamplingStrategy":
    """Randomly choose encoder frames, restricted to the first stereo frames.

    Identical to ``kitti_360_full_encoder_sampler`` except that the candidate
    pool is clamped to at most ``num_stereo_frames`` frames, so only frames
    with ids below that bound can be selected.
    """

    def _sampling_strategy(num_frames: int) -> list[int]:
        # Only the leading stereo frames are eligible for encoding.
        pool_size = min(num_frames, num_stereo_frames)
        # Shuffle ids 1..pool_size-1 (the +1 shifts randperm off of 0).
        shuffled_rest = (torch.randperm(pool_size - 1) + 1).tolist()
        if always_use_base_frame:
            return [0] + shuffled_rest[: num_encoder_frames - 1]
        return shuffled_rest[:num_encoder_frames]

    return _sampling_strategy
| def get_encoder_sampling(config) -> EncoderSamplingStrategy: | |
| strategy = config.get("name", None) | |
| match strategy: | |
| case "kitti_360_full": | |
| return kitti_360_full_encoder_sampler(**config["args"]) | |
| case "kitti_360_stereo": | |
| return kitti_360_stereo_encoder_sampler(**config["args"]) | |
| case _: | |
| return default_encoder_sampler() | |
| # =============================================== LOSS SAMPLING STRATEGIES ============================================= | |
def single_view_loss_sampler(
    shuffle_frames: bool = False, all_frames: bool = False
) -> "LossSamplingStrategy":
    """Loss on a single (optionally random) frame, render the others.

    Args:
        shuffle_frames: if True, the frame order is randomized before the split.
        all_frames: if True, the loss frame is also included in the renderer
            set; otherwise it is excluded.

    Returns:
        A strategy yielding ``([loss_frame], renderer_frames, None)``.
    """
    # Renderer set starts at index 0 (includes the loss frame) when
    # all_frames is set, otherwise at index 1 (excludes it).
    starting_frame = 0 if all_frames else 1

    def _sampling_strategy(num_frames: int) -> tuple[list[int], list[int], None]:
        # Fix: list(range(...)) instead of a comprehension shadowing builtin `id`;
        # the inner annotation now matches the actual 3-tuple return.
        frames = list(range(num_frames))
        if shuffle_frames:
            shuffle(frames)
        return frames[0:1], frames[starting_frame:], None

    return _sampling_strategy
def single_view_renderer_sampler(
    shuffle_frames: bool = False, all_frames: bool = False
) -> "LossSamplingStrategy":
    """Loss on all (or all-but-last) frames, render a single frame.

    Args:
        shuffle_frames: if True, the frame order is randomized before the split.
        all_frames: if True, every frame takes the loss; otherwise the last
            frame (after any shuffling) is excluded from the loss set.

    Returns:
        A strategy yielding ``(loss_frames, [rendered_frame], None)``.
    """

    def _sampling_strategy(num_frames: int) -> tuple[list[int], list[int], None]:
        # Fix: list(range(...)) instead of a comprehension shadowing builtin `id`;
        # the inner annotation now matches the actual 3-tuple return.
        frames = list(range(num_frames))
        if shuffle_frames:
            shuffle(frames)
        ids_loss = frames if all_frames else frames[0:-1]
        return ids_loss, frames[0:1], None

    return _sampling_strategy
def stereo_view_loss_sampler(shuffle_frames: bool = False) -> "LossSamplingStrategy":
    """Split frames into two contiguous halves and pick which half takes the loss.

    Args:
        shuffle_frames: if True, the decision of which half takes the loss is
            randomized (via the first element of a shuffled order); if False,
            the first half always takes the loss.

    Returns:
        A strategy yielding ``(loss_half, renderer_half, None)``.
    """

    def _sampling_strategy(num_frames: int) -> tuple[list[int], list[int], None]:
        # Fix: renamed the local (previously `all_frames`, colliding with the
        # sibling samplers' parameter name), avoided shadowing builtin `id`,
        # and corrected the annotation to the actual 3-tuple return.
        frame_order = list(range(num_frames))
        if shuffle_frames:
            shuffle(frame_order)
        half = num_frames // 2
        first_half = list(range(half))
        second_half = list(range(half, num_frames))
        # The (possibly shuffled) leading frame decides which half takes the loss.
        if frame_order[0] < half:
            return first_half, second_half, None
        return second_half, first_half, None

    return _sampling_strategy
def kitti_360_loss_sampler() -> "LossSamplingStrategy":
    """For each stereo pair (2i, 2i+1), randomly send one frame to the loss
    set and the other to the renderer set.

    Fix: the file contained two byte-identical definitions of this function;
    the duplicate (which silently shadowed the first) has been removed.

    NOTE(review): ``random.randint(0, 2)`` is truthy 2/3 of the time, so the
    even frame of each pair takes the loss with probability 2/3, not 1/2 —
    confirm whether a fair ``randint(0, 1)`` was intended. Behavior preserved.

    Returns:
        A strategy yielding ``(ids_loss, ids_renderer, None)`` with one entry
        per stereo pair in each list.
    """

    def _sampling_strategy(num_frames: int) -> tuple[list[int], list[int], None]:
        ids_loss: list[int] = []
        ids_renderer: list[int] = []
        for base in range(0, num_frames, 2):
            if random.randint(0, 2):
                loss_id, render_id = base, base + 1
            else:
                loss_id, render_id = base + 1, base
            ids_loss.append(loss_id)
            ids_renderer.append(render_id)
        return ids_loss, ids_renderer, None

    return _sampling_strategy
def kitti_360_with_mapping_loss_sampler() -> "LossSamplingStrategy":
    """Like ``kitti_360_loss_sampler`` but also returns a mapping array.

    The third element is a ``(num_pairs, 1)`` int64 array where row i gives
    the index into ``ids_renderer`` of the frame paired with ``ids_loss[i]``.

    NOTE(review): ``random.randint(0, 2)`` is biased 2:1 toward the even
    frame of each pair — confirm intent. Behavior preserved.
    """

    def _sampling_strategy(num_frames: int) -> tuple[list[int], list[int], np.ndarray]:
        ids_loss: list[int] = []
        ids_renderer: list[int] = []
        mapping: list[list[int]] = []
        for base in range(0, num_frames, 2):
            if random.randint(0, 2):
                ids_loss.append(base)
                ids_renderer.append(base + 1)
            else:
                ids_loss.append(base + 1)
                ids_renderer.append(base)
            # Fix: both branches appended the identical mapping row, so the
            # bookkeeping is hoisted out of the conditional.
            mapping.append([len(ids_renderer) - 1])
        return ids_loss, ids_renderer, np.array(mapping, dtype=np.int64)

    return _sampling_strategy
def waymo_with_mapping_loss_sampler() -> "LossSamplingStrategy":
    """Per-pair random loss/renderer split with a doubled mapping array.

    The third element is a ``(2 * num_pairs, 1)`` int64 array: each pair's
    renderer index appears twice in consecutive rows (Waymo-specific layout;
    presumably two loss views share one rendered view — TODO confirm against
    the caller).

    NOTE(review): ``random.randint(0, 2)`` is biased 2:1 toward the even
    frame of each pair — confirm intent. Behavior preserved.
    """

    def _sampling_strategy(num_frames: int) -> tuple[list[int], list[int], np.ndarray]:
        ids_loss: list[int] = []
        ids_renderer: list[int] = []
        mapping: list[list[int]] = []
        for base in range(0, num_frames, 2):
            if random.randint(0, 2):
                ids_loss.append(base)
                ids_renderer.append(base + 1)
            else:
                ids_loss.append(base + 1)
                ids_renderer.append(base)
            # Fix: both branches extended the identical mapping rows, so the
            # bookkeeping is hoisted out of the conditional.
            renderer_idx = len(ids_renderer) - 1
            mapping.extend([[renderer_idx], [renderer_idx]])
        return ids_loss, ids_renderer, np.array(mapping, dtype=np.int64)

    return _sampling_strategy
def alternate_loss_sampler() -> "LossSamplingStrategy":
    """Randomly give even-indexed frames the loss and odd-indexed frames the
    renderer, or vice versa.

    NOTE(review): ``random.randint(0, 2)`` picks the even-loss split with
    probability 2/3, not 1/2 — confirm intent. Behavior preserved.

    Returns:
        A strategy yielding ``(ids_loss, ids_renderer, None)``.
    """

    def _sampling_strategy(num_frames: int) -> tuple[list[int], list[int], None]:
        # Fix: dropped an unused `frames` list that was rebuilt on every call;
        # the inner annotation now matches the actual 3-tuple return.
        evens = list(range(0, num_frames, 2))
        odds = list(range(1, num_frames, 2))
        if random.randint(0, 2):
            return evens, odds, None
        return odds, evens, None

    return _sampling_strategy
| def get_loss_renderer_sampling(config) -> EncoderSamplingStrategy: | |
| strategy = config.get("name", None) | |
| match strategy: | |
| case "single_loss": | |
| return single_view_loss_sampler(**config.get("args", {})) | |
| case "single_renderer": | |
| return single_view_renderer_sampler(**config.get("args", {})) | |
| case "stereo_loss": | |
| return stereo_view_loss_sampler(**config.get("args", {})) | |
| case "kitti_360": | |
| return kitti_360_loss_sampler() | |
| case "kitti_360_with_mapping": | |
| return kitti_360_with_mapping_loss_sampler() | |
| case "waymo_with_mapping": | |
| return waymo_with_mapping_loss_sampler() | |
| case "alternate": | |
| return alternate_loss_sampler() | |
| case _: | |
| return single_view_loss_sampler(False) | |
# --- Legacy sampling strategies (commented-out reference implementation; superseded by the samplers above) ---
| # if self.training: | |
| # frame_perm = torch.randperm(v) | |
| # else: | |
| # frame_perm = torch.arange(v) ## eval | |
| # if self.enc_style == "random": ## encoded views | |
| # encoder_perm = (torch.randperm(v - 1) + 1)[ | |
| # : self.nv_ - 1 | |
| # ].tolist() ## nv-1 for mono [0] idx | |
| # ids_encoder = [0] ## always starts sampling from mono cam | |
| # ids_encoder.extend(encoder_perm) ## add more cam_views randomly incl. fe | |
| # elif self.enc_style == "default": | |
| # ids_encoder = [ | |
| # v_ for v_ in range(self.nv_) | |
| # ] ## iterating view(v_) over num_views(nv_) | |
| # elif self.enc_style == "stereo": | |
| # if self.training: | |
| # # if v < 8: raise RuntimeError(f"__number of views should be more than 4 when excluding fisheye views") | |
| # # if v < 8: raise RuntimeError(f"__number of views should be more than 4 when excluding fisheye views") | |
| # encoder_perm = (torch.randperm(v - (1 + 4)) + 1)[ | |
| # : self.nv_ - 1 | |
| # ].tolist() | |
| # ids_encoder = [0] | |
| # ids_encoder.extend(encoder_perm) | |
| # else: | |
| # ids_encoder = [0] | |
| # else: | |
| # raise NotImplementedError(f"__unrecognized enc_style: {self.enc_style}") | |
| # ## default: ids_encoder = [0,1,2,3] <=> front stereo for 1st + 2nd time stamps | |
| # if ( | |
| # not self.training and self.ids_enc_viz_eval | |
| # ): ## when eval in viz to be standardized with test: it's eval from line 354, base_trainer.py | |
| # ids_encoder = self.ids_enc_viz_eval ## fixed during eval | |
| # ids_render = torch.sort( | |
| # frame_perm[[i for i in self.frames_render if i < v]] | |
| # ).values ## ? ### tensor([0, 4]) | |
| # combine_ids = None | |
| # if self.training: | |
| # if self.frame_sample_mode == "only": | |
| # ids_loss = [0] | |
| # ids_render = ids_render[ids_render != 0] | |
| # elif self.frame_sample_mode == "not": | |
| # frame_perm = torch.randperm(v - 1) + 1 | |
| # ids_loss = torch.sort( | |
| # frame_perm[[i for i in self.frames_render if i < v - 1]] | |
| # ).values | |
| # ids_render = [i for i in range(v) if i not in ids_loss] | |
| # elif self.frame_sample_mode == "stereo": | |
| # if frame_perm[0] < v // 2: | |
| # ids_loss = list(range(v // 2)) | |
| # ids_render = list(range(v // 2, v)) | |
| # else: | |
| # ids_loss = list(range(v // 2, v)) | |
| # ids_render = list(range(v // 2)) | |
| # elif self.frame_sample_mode == "mono": | |
| # split_i = v // 2 | |
| # if frame_perm[0] < v // 2: | |
| # ids_loss = list(range(0, split_i, 2)) + list( | |
| # range(split_i + 1, v, 2) | |
| # ) | |
| # ids_render = list(range(1, split_i, 2)) + list(range(split_i, v, 2)) | |
| # else: | |
| # ids_loss = list(range(1, split_i, 2)) + list(range(split_i, v, 2)) | |
| # ids_render = list(range(0, split_i, 2)) + list( | |
| # range(split_i + 1, v, 2) | |
| # ) | |
| # elif self.frame_sample_mode == "kitti360-mono": | |
| # steps = v // 4 | |
| # start_from = 0 if frame_perm[0] < v // 2 else 1 | |
| # ids_loss, ids_render = [], [] | |
| # for cam in range( | |
| # 4 | |
| # ): ## stereo cam sampled for each time ## ! c.f. paper: N_{render}, N_{loss} | |
| # ids_loss += [cam * steps + i for i in range(start_from, steps, 2)] | |
| # ids_render += [ | |
| # cam * steps + i for i in range(1 - start_from, steps, 2) | |
| # ] | |
| # start_from = 1 - start_from | |
| # if self.enc_style == "test": | |
| # ids_encoder = ids_loss[: self.nv_] | |
| # elif self.frame_sample_mode.startswith("waymo"): | |
| # num_views = int(self.frame_sample_mode.split("-")[-1]) | |
| # steps = v // num_views | |
| # split = steps // 2 | |
| # # Predict features from half-left, center, half-right | |
| # ids_encoder = [0, steps, steps * 2] | |
| # # Combine all frames half-left, center, half-right for efficiency reasons | |
| # combine_ids = [(i, steps + i, steps * 2 + i) for i in range(steps)] | |
| # if self.training: | |
| # step_perm = torch.randperm(steps) | |
| # else: | |
| # step_perm = torch.arange(steps) ## eval | |
| # step_perm = step_perm.tolist() | |
| # ids_loss = sum( | |
| # [ | |
| # [i + j * steps for j in range(num_views)] | |
| # for i in step_perm[:split] | |
| # ], | |
| # [], | |
| # ) | |
| # ids_render = sum( | |
| # [ | |
| # [i + j * steps for j in range(num_views)] | |
| # for i in step_perm[split:] | |
| # ], | |
| # [], | |
| # ) | |
| # elif self.frame_sample_mode == "default": | |
| # ids_loss = frame_perm[ | |
| # [i for i in range(v) if frame_perm[i] not in ids_render] | |
| # ] | |
| # else: | |
| # raise NotImplementedError | |
| # else: ## eval (!= self.training) | |
| # ids_loss = torch.arange(v) | |
| # ids_render = [0] | |
| # if self.frame_sample_mode.startswith("waymo"): | |
| # num_views = int(self.frame_sample_mode.split("-")[-1]) | |
| # steps = v // num_views | |
| # split = steps // 2 | |
| # # Predict features from half-left, center, half-right | |
| # ids_encoder = [0, steps, steps * 2] | |
| # ids_render = [0, steps, steps * 2] | |
| # combine_ids = [(i, steps + i, steps * 2 + i) for i in range(steps)] | |