| import argparse |
| import math |
| import sys |
| sys.path.append("..") |
| import numpy as np |
| import os |
| import torch |
|
|
| import trimesh |
|
|
| from datasets import Object_Occ,Scale_Shift_Rotate |
| from models import get_model |
| from pathlib import Path |
| import open3d as o3d |
| from configs.config_utils import CONFIG |
| import tqdm |
| from util import misc |
| from datasets.taxonomy import synthetic_arkit_category_combined |
|
|
# Number of full passes over the train/val loaders. Every pass writes one more
# file per object (see the file-index logic in main()).
NUM_AUGMENT_PASSES = 5
# Worker processes used by each DataLoader.
NUM_WORKERS = 10
# Sub-folder (below each category) that receives the exported features.
FEATURE_DIRNAME = "9_triplane_kl25_64"


def parse_args():
    """Parse the command-line options of the export script.

    Returns:
        The ``argparse.Namespace`` produced by the parser. It also carries the
        distributed-setup fields that ``misc.init_distributed_mode`` is handed.
    """
    parser = argparse.ArgumentParser('', add_help=False)
    parser.add_argument('--configs', type=str, required=True)
    # The checkpoint path and the category list are dereferenced
    # unconditionally in main(), so fail early with a clear argparse error
    # instead of a late crash on ``None``.
    parser.add_argument('--ae-pth', type=str, required=True)
    parser.add_argument("--category", nargs='+', type=str, required=True)
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--data-pth", default="../data", type=str)
    return parser.parse_args()


def build_dataloader(data_path, split, categories, transform, surface_size, batch_size):
    """Create the non-shuffled, rank-sharded DataLoader for one dataset split.

    Args:
        data_path: Dataset root handed to ``Object_Occ``.
        split: Split name, e.g. ``"train"`` or ``"val"``.
        categories: Category identifiers to load.
        transform: Transform applied by the dataset.
        surface_size: Value forwarded as ``surface_size`` to ``Object_Occ``.
        batch_size: Batch size of the returned loader.

    Returns:
        A ``torch.utils.data.DataLoader`` driven by a ``DistributedSampler``
        with ``shuffle=False`` and ``drop_last=False``.
    """
    dataset = Object_Occ(
        data_path, split=split,
        categories=categories,
        transform=transform, sampling=True,
        num_samples=1024, return_surface=True,
        surface_sampling=True, surface_size=surface_size, replica=1,
    )
    sampler = torch.utils.data.DistributedSampler(
        dataset, num_replicas=misc.get_world_size(), rank=misc.get_rank(),
        shuffle=False,
    )
    return torch.utils.data.DataLoader(
        dataset, sampler=sampler,
        batch_size=batch_size,
        num_workers=NUM_WORKERS,
        shuffle=False,
        drop_last=False,
    )


def downsample_posterior(means, logvars):
    """Halve the spatial resolution of a mean / log-variance pair.

    Both tensors are resized with bilinear interpolation at
    ``scale_factor=0.5``, which requires 4-D input. The variance is
    interpolated in linear space (``exp(logvars)``) and divided by 4 before
    being converted back to a log-variance.
    NOTE(review): the ``/ 4`` presumably models each output cell as the mean
    of four independent Gaussians -- confirm with the model authors.

    Args:
        means: 4-D tensor of means.
        logvars: 4-D tensor of log-variances with the same layout as ``means``.

    Returns:
        Tuple ``(means, logvars)`` at half the spatial resolution.
    """
    variances = torch.exp(logvars)
    means = torch.nn.functional.interpolate(means, scale_factor=0.5, mode="bilinear")
    variances = torch.nn.functional.interpolate(
        variances, scale_factor=0.5, mode="bilinear") / 4
    return means, torch.log(variances)


def main():
    """Encode every train/val object and save its downsampled features.

    For each sample the mean, log-variance and transformation matrix are
    written to
    ``<data-pth>/other_data/<category>/9_triplane_kl25_64/<model_id>/triplane_feat_<k>.npz``
    where ``k`` is the number of entries already present in that folder.
    """
    args = parse_args()
    misc.init_distributed_mode(args)
    device = torch.device(args.device)

    config = CONFIG(args.configs)
    dataset_config = config.config['dataset']
    dataset_config['data_path'] = args.data_pth

    transform = Scale_Shift_Rotate(rot_shift_surface=True, use_scale=True)
    # The single keyword "all" expands to the combined category list.
    if len(args.category) == 1 and args.category[0] == "all":
        category = synthetic_arkit_category_combined["all"]
    else:
        category = args.category

    dataloader_list = [
        build_dataloader(
            dataset_config['data_path'], split, category, transform,
            dataset_config['surface_size'], args.batch_size,
        )
        for split in ("train", "val")
    ]

    output_dir = os.path.join(dataset_config['data_path'], "other_data")

    model = get_model(config.config['model'])
    # Load onto the CPU first so the checkpoint is not restored onto whatever
    # device it was saved from; the model is moved to ``device`` right after.
    checkpoint = torch.load(args.ae_pth, map_location="cpu")
    model.load_state_dict(checkpoint['model'])
    model.eval().float().to(device)

    with torch.no_grad():
        for _ in range(NUM_AUGMENT_PASSES):
            for dataloader in dataloader_list:
                # Wrapping the loader itself lets tqdm show the total length.
                for data_batch in tqdm.tqdm(dataloader):
                    surface = data_batch['surface'].to(device, non_blocking=True)
                    model_ids = data_batch['model_id']
                    tran_mats = data_batch['tran_mat']
                    categories = data_batch['category']

                    # Only the third and fourth outputs of encode() are exported.
                    _, _, means, logvars = model.encode(surface)
                    means, sample_logvars = downsample_posterior(means, logvars)

                    for j in range(means.shape[0]):
                        mean = means[j].float().cpu().numpy()
                        logvar = sample_logvars[j].float().cpu().numpy()
                        tran_mat = tran_mats[j].float().cpu().numpy()

                        output_folder = Path(output_dir) / categories[j] / FEATURE_DIRNAME / model_ids[j]
                        output_folder.mkdir(parents=True, exist_ok=True)
                        # The file index is the current entry count, so every
                        # pass (and every re-run) appends a new file.
                        # NOTE(review): concurrent writers to the same folder
                        # could pick the same index -- confirm ranks never
                        # share a model_id.
                        exist_len = len(os.listdir(output_folder))
                        save_filepath = output_folder / f"triplane_feat_{exist_len}.npz"
                        np.savez_compressed(str(save_filepath), mean=mean,
                                            logvar=logvar, tran_mat=tran_mat)


if __name__ == "__main__":
    main()
|
|