import json
import os
import shutil
import tarfile
from collections import defaultdict
from pathlib import Path
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import torch
import torch.utils.data as torchdata
from omegaconf import DictConfig

from ... import logger
from ..schema import MIADataConfiguration
from ..sequential import chunk_sequence
from ..torch import collate, worker_init_fn
from .dataset import MapLocDataset


def pack_dump_dict(dump):
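    """Convert the lists of a freshly loaded dump dict to numpy arrays, in place.

    Points become float64 arrays; per-view rotations, angles, and camera
    parameters become float32 arrays; per-view "chunk_id" entries are dropped.
    """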
    for per_seq in dump.values():
        if "points" in per_seq:
            for chunk in list(per_seq["points"]):
                points = per_seq["points"].pop(chunk)
                if points is not None:
                    # Re-insert the popped list as an array (indexing the dict
                    # again after pop() would raise a KeyError).
                    per_seq["points"][chunk] = np.array(points, np.float64)
        for view in per_seq["views"].values():
            for k in ["R_c2w", "roll_pitch_yaw"]:
                view[k] = np.array(view[k], np.float32)
            for k in ["chunk_id"]:
                if k in view:
                    view.pop(k)
            if "observations" in view:
                view["observations"] = np.array(view["observations"])
        for camera in per_seq["cameras"].values():
            for k in ["params"]:
                camera[k] = np.array(camera[k], np.float32)
    return dump


class MapillaryDataModule(pl.LightningDataModule):
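    """Lightning data module for the Mapillary-based MIA scenes.

    Expects each scene directory under `cfg.data_dir` to contain a dump file,
    an image folder (or archive), semantic masks, and flood-fill masks.
    """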
    dump_filename = "dump.json"
    images_archive = "images.tar.gz"
    images_dirname = "images/"
    semantic_masks_dirname = "semantic_masks/"
    flood_dirname = "flood_fill/"

    def __init__(self, cfg: MIADataConfiguration):
        super().__init__()
        self.cfg = cfg
        self.root = self.cfg.data_dir
        self.local_dir = None

    def prepare_data(self):
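        """Verify that every scene has the expected files on disk and, when a
        local cache directory is configured, extract the image archives there."""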
        for scene in self.cfg.scenes:
            dump_dir = self.root / scene
            assert (dump_dir / self.dump_filename).exists(), dump_dir
            if self.local_dir is None:
                assert (dump_dir / self.images_dirname).exists(), dump_dir
                continue
            assert (dump_dir / self.semantic_masks_dirname).exists(), dump_dir
            assert (dump_dir / self.flood_dirname).exists(), dump_dir

            # Cache the folder of images locally to speed up reading.
            local_dir = self.local_dir / scene
            if local_dir.exists():
                shutil.rmtree(local_dir)
            local_dir.mkdir(exist_ok=True, parents=True)
            images_archive = dump_dir / self.images_archive
            logger.info("Extracting the image archive %s.", images_archive)
            with tarfile.open(images_archive) as fp:
                fp.extractall(local_dir)

    def setup(self, stage: Optional[str] = None):
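        """Load the per-scene dumps, resolve image/mask directories, and collect
        the (scene, sequence, image name) tuples that form the splits."""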
        self.dumps = {}
        self.image_dirs = {}
        self.seg_masks_dir = {}
        self.flood_masks_dir = {}
        names = []

        for scene in self.cfg.scenes:
            logger.info("Loading scene %s.", scene)
            dump_dir = self.root / scene

            logger.info("Loading dump json file %s.", self.dump_filename)
            with (dump_dir / self.dump_filename).open("r") as fp:
                self.dumps[scene] = pack_dump_dict(json.load(fp))
            for seq, per_seq in self.dumps[scene].items():
                for cam_id, cam_dict in per_seq["cameras"].items():
                    if cam_dict["model"] != "PINHOLE":
                        raise ValueError(
                            f"Unsupported camera model: {cam_dict['model']} "
                            f"for {scene},{seq},{cam_id}"
                        )

            self.image_dirs[scene] = (
                (self.local_dir or self.root) / scene / self.images_dirname
            )
            assert self.image_dirs[scene].exists(), self.image_dirs[scene]

            self.seg_masks_dir[scene] = (
                (self.local_dir or self.root) / scene / self.semantic_masks_dirname
            )
            assert self.seg_masks_dir[scene].exists(), self.seg_masks_dir[scene]

            self.flood_masks_dir[scene] = (
                (self.local_dir or self.root) / scene / self.flood_dirname
            )
            assert self.flood_masks_dir[scene].exists(), self.flood_masks_dir[scene]

            # Keep only views that have an image, a flood mask, and a semantic
            # mask on disk.
            images = {x.split(".")[0] for x in os.listdir(self.image_dirs[scene])}
            flood_masks = {
                x.split(".")[0] for x in os.listdir(self.flood_masks_dir[scene])
            }
            semantic_masks = {
                x.split(".")[0] for x in os.listdir(self.seg_masks_dir[scene])
            }

            for seq, data in self.dumps[scene].items():
                for name in data["views"]:
                    key = name.split("_")[0]
                    if name in images and key in flood_masks and key in semantic_masks:
                        names.append((scene, seq, name))

        self.parse_splits(self.cfg.split, names)
        if self.cfg.filter_for is not None:
            self.filter_elements()
        self.pack_data()

    def pack_data(self):
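        """Pack the per-view metadata of each split into numpy arrays and torch
        tensors, and attach the camera and point data of every sequence."""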
        exclude = {
            "compass_angle",
            "compass_accuracy",
            "gps_accuracy",
            "chunk_key",
            "panorama_offset",
        }
        cameras = {
            scene: {seq: per_seq["cameras"] for seq, per_seq in per_scene.items()}
            for scene, per_scene in self.dumps.items()
        }
        points = {
            scene: {
                seq: {
                    i: torch.from_numpy(p)
                    for i, p in per_seq.get("points", {}).items()
                }
                for seq, per_seq in per_scene.items()
            }
            for scene, per_scene in self.dumps.items()
        }
        self.data = {}

        if self.cfg.split == "splits_MGL_13loc.json":
            # Move the last 20% of the training samples into the validation split.
            num_samples_to_move = int(len(self.splits["train"]) * 0.2)
            samples_to_move = self.splits["train"][-num_samples_to_move:]
            self.splits["val"].extend(samples_to_move)
            self.splits["train"] = self.splits["train"][:-num_samples_to_move]
            logger.info(
                "Dataset lengths (train, val): %d, %d",
                len(self.splits["train"]),
                len(self.splits["val"]),
            )
        elif self.cfg.split == "splits_MGL_soma_70k_mappred_random.json":
            for stage in self.splits:
                logger.info("Length of split %s: %d", stage, len(self.splits[stage]))

        for stage, names in self.splits.items():
            view = self.dumps[names[0][0]][names[0][1]]["views"][names[0][2]]
            data = {k: [] for k in view.keys() - exclude}
            for scene, seq, name in names:
                for k in data:
                    data[k].append(self.dumps[scene][seq]["views"][name].get(k, None))
            for k in data:
                v = np.array(data[k])
                if np.issubdtype(v.dtype, np.integer) or np.issubdtype(
                    v.dtype, np.floating
                ):
                    v = torch.from_numpy(v)
                data[k] = v
            data["cameras"] = cameras
            data["points"] = points
            self.data[stage] = data
            self.splits[stage] = np.array(names)

    def filter_elements(self):
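        """Remove views that fail the `cfg.filter_for` criterion
        ("ground_plane" or "pointcloud") from every split."""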
        for stage, names in self.splits.items():
            names_select = []
            for scene, seq, name in names:
                view = self.dumps[scene][seq]["views"][name]
                if self.cfg.filter_for == "ground_plane":
                    if not (1.0 <= view["height"] <= 3.0):
                        continue
                    planes = self.dumps[scene][seq].get("plane")
                    if planes is not None:
                        inliers = planes[str(view["chunk_id"])][-1]
                        if inliers < 10:
                            continue
                    if self.cfg.filter_by_ground_angle is not None:
                        plane = np.array(view["plane_params"])
                        normal = plane[:3] / np.linalg.norm(plane[:3])
                        angle = np.rad2deg(np.arccos(np.abs(normal[-1])))
                        if angle > self.cfg.filter_by_ground_angle:
                            continue
                elif self.cfg.filter_for == "pointcloud":
                    if len(view["observations"]) < self.cfg.min_num_points:
                        continue
                elif self.cfg.filter_for is not None:
                    raise ValueError(f"Unknown filtering: {self.cfg.filter_for}")
                names_select.append((scene, seq, name))
            logger.info(
                "%s: keep %d/%d images after filtering for %s.",
                stage,
                len(names_select),
                len(names),
                self.cfg.filter_for,
            )
            self.splits[stage] = names_select

    def parse_splits(self, split_arg, names):
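        """Populate `self.splits` from `split_arg`: None (train = val = all),
        an int or float (random hold-out count or fraction), a DictConfig of
        per-split scene lists, or the path of a JSON split file."""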
        if split_arg is None:
            self.splits = {
                "train": names,
                "val": names,
            }
        elif isinstance(split_arg, int):
            names = np.random.RandomState(self.cfg.seed).permutation(names).tolist()
            self.splits = {
                "train": names[split_arg:],
                "val": names[:split_arg],
            }
        elif isinstance(split_arg, float):
            names = np.random.RandomState(self.cfg.seed).permutation(names).tolist()
            self.splits = {
                "train": names[int(split_arg * len(names)) :],
                "val": names[: int(split_arg * len(names))],
            }
        elif isinstance(split_arg, DictConfig):
            scenes_val = set(split_arg.val)
            scenes_train = set(split_arg.train)
            assert len(scenes_val - set(self.cfg.scenes)) == 0
            assert len(scenes_train - set(self.cfg.scenes)) == 0
            self.splits = {
                "train": [n for n in names if n[0] in scenes_train],
                "val": [n for n in names if n[0] in scenes_val],
            }
        elif isinstance(split_arg, str):
            if "/" in split_arg:
                split_path = self.root / split_arg
            else:
                split_path = Path(split_arg)
            with split_path.open("r") as fp:
                splits = json.load(fp)
            splits = {
                k: {loc: set(ids) for loc, ids in split.items()}
                for k, split in splits.items()
            }
            self.splits = {}
            for k, split in splits.items():
                self.splits[k] = [
                    n
                    for n in names
                    if n[0] in split and int(n[-1].rsplit("_", 1)[0]) in split[n[0]]
                ]
        else:
            raise ValueError(split_arg)

    def dataset(self, stage: str):
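        """Create the MapLocDataset for the given stage."""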
        return MapLocDataset(
            stage,
            self.cfg,
            self.splits[stage],
            self.data[stage],
            self.image_dirs,
            self.seg_masks_dir,
            self.flood_masks_dir,
            image_ext=".jpg",
        )

    def sequence_dataset(self, stage: str, **kwargs):
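        """Split each sequence into chunks and record, for every image, the
        index of the chunk it belongs to. Returns the dataset and a mapping
        from (sequence, chunk index) to image indices."""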
        keys = self.splits[stage]
        seq2indices = defaultdict(list)
        for index, (_, seq, _) in enumerate(keys):
            seq2indices[seq].append(index)
        # Chunk each sequence to the requested length.
        chunk2indices = {}
        for seq, indices in seq2indices.items():
            chunks = chunk_sequence(self.data[stage], indices, **kwargs)
            for i, sub_indices in enumerate(chunks):
                chunk2indices[seq, i] = sub_indices
        # Store, for every image, the index of its chunk within the sequence.
        chunk_indices = torch.full((len(keys),), -1)
        for (_, chunk_index), idx in chunk2indices.items():
            chunk_indices[idx] = chunk_index
        self.data[stage]["chunk_index"] = chunk_indices
        dataset = self.dataset(stage)
        return dataset, chunk2indices

    def sequence_dataloader(self, stage: str, shuffle: bool = False, **kwargs):
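        """Create a DataLoader that iterates chunk by chunk, optionally
        shuffling the order of chunks but never the order within a chunk."""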
        dataset, chunk2idx = self.sequence_dataset(stage, **kwargs)
        chunk_keys = sorted(chunk2idx)
        if shuffle:
            # Shuffle the chunks, not the images within each chunk.
            perm = torch.randperm(len(chunk_keys))
            chunk_keys = [chunk_keys[i] for i in perm]
        key_indices = [i for key in chunk_keys for i in chunk2idx[key]]
        num_workers = self.cfg.loading[stage]["num_workers"]
        loader = torchdata.DataLoader(
            dataset,
            batch_size=None,
            sampler=key_indices,
            num_workers=num_workers,
            shuffle=False,
            pin_memory=True,
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
        )
        return loader, chunk_keys, chunk2idx
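

# Minimal usage sketch (illustrative only, not part of the module): assumes a
# populated MIADataConfiguration `cfg` with at least `data_dir`, `scenes`,
# `split`, and `loading` set, e.g. built via Hydra/OmegaConf:
#
#     data_module = MapillaryDataModule(cfg)
#     data_module.prepare_data()
#     data_module.setup()
#     loader, chunk_keys, chunk2idx = data_module.sequence_dataloader("val")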