| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import logging |
| |
|
| | import os |
| |
|
# Configure root logging once at import time.
# The LOGLEVEL environment variable overrides the default of INFO.
_log_level = os.environ.get("LOGLEVEL", "INFO").upper()
logging.basicConfig(
    level=_log_level,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Named logger shared by the whole training script.
logger = logging.getLogger("repcodec_train")
| |
|
| | import random |
| |
|
| | import numpy as np |
| | import torch |
| | import yaml |
| | from torch.utils.data import DataLoader |
| |
|
| | from dataloader import ReprDataset, ReprCollater |
| | from losses.repr_reconstruct_loss import ReprReconstructLoss |
| | from repcodec.RepCodec import RepCodec |
| | from trainer.autoencoder import Trainer |
| |
|
| |
|
class TrainMain:
    """Driver that wires together data loaders, model, optimizer, scheduler,
    criterion and the Trainer, then runs RepCodec training.

    Expected usage (see ``train()``): construct with parsed CLI args, then call
    ``initialize_data_loader`` -> ``define_model_optimizer_scheduler`` ->
    ``define_criterion`` -> ``define_trainer`` -> ``initialize_model`` -> ``run``.
    """

    def __init__(self, args):
        # Seed every RNG source for reproducibility.
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if not torch.cuda.is_available():
            self.device = torch.device('cpu')
            logger.info("device: cpu")
        else:
            # Single-GPU training: always the first visible device.
            self.device = torch.device('cuda:0')
            logger.info("device: gpu")
            torch.cuda.manual_seed_all(args.seed)
            # disable_cudnn is a string flag ("True"/"False") from argparse.
            if args.disable_cudnn == "False":
                # Let cuDNN auto-tune kernels; fastest for fixed-size batches.
                torch.backends.cudnn.benchmark = True

        # Load the YAML config; CLI args override/extend config entries.
        # NOTE(review): FullLoader can construct Python-tagged objects — fine
        # for trusted local configs, but never point this at untrusted input.
        with open(args.config, 'r') as f:
            self.config = yaml.load(f, Loader=yaml.FullLoader)
        self.config.update(vars(args))

        # Experiment directory: exp_root/tag/ (created if missing).
        expdir = os.path.join(args.exp_root, args.tag)
        os.makedirs(expdir, exist_ok=True)
        self.config["outdir"] = expdir

        # Persist the merged config next to the checkpoints, and log it.
        with open(os.path.join(expdir, "config.yml"), "w") as f:
            yaml.dump(self.config, f, Dumper=yaml.Dumper)
        for key, value in self.config.items():
            logger.info(f"{key} = {value}")

        # Checkpoint path to resume from ("" means no resume).
        self.resume: str = args.resume
        # Filled in by the define_*/initialize_* methods below.
        self.data_loader = None
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.criterion = None
        self.trainer = None

        # Dataset parameters pulled out of the config for convenience.
        self.batch_length: int = self.config['batch_length']
        self.data_path: str = self.config['data']['path']

    def initialize_data_loader(self):
        """Build train/valid datasets and wrap them in DataLoaders."""
        train_set = self._build_dataset("train")
        valid_set = self._build_dataset("valid")
        collater = ReprCollater()

        logger.info(f"The number of training files = {len(train_set)}.")
        logger.info(f"The number of validation files = {len(valid_set)}.")
        dataset = {"train": train_set, "dev": valid_set}
        self._set_data_loader(dataset, collater)

    def define_model_optimizer_scheduler(self):
        """Instantiate the RepCodec model, its optimizer and LR scheduler."""
        self.model = {
            "repcodec": RepCodec(**self.config["model_params"]).to(self.device)
        }
        logger.info(f"Model Arch:\n{self.model['repcodec']}")

        # Optimizer class is looked up by name from torch.optim.
        optimizer_class = getattr(
            torch.optim,
            self.config["model_optimizer_type"]
        )
        self.optimizer = {
            "repcodec": optimizer_class(
                self.model["repcodec"].parameters(),
                **self.config["model_optimizer_params"]
            )
        }

        # Scheduler class is looked up by name; defaults to StepLR.
        scheduler_class = getattr(
            torch.optim.lr_scheduler,
            self.config.get("model_scheduler_type", "StepLR"),
        )
        self.scheduler = {
            "repcodec": scheduler_class(
                optimizer=self.optimizer["repcodec"],
                **self.config["model_scheduler_params"]
            )
        }

    def define_criterion(self):
        """Instantiate the representation-reconstruction loss on the device."""
        self.criterion = {
            "repr_reconstruct_loss": ReprReconstructLoss(
                **self.config.get("repr_reconstruct_loss_params", {}),
            ).to(self.device)
        }

    def define_trainer(self):
        """Assemble the Trainer from the previously built components."""
        self.trainer = Trainer(
            steps=0,
            epochs=0,
            data_loader=self.data_loader,
            model=self.model,
            criterion=self.criterion,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            config=self.config,
            device=self.device
        )

    def initialize_model(self):
        """Load weights: resume checkpoint > initial checkpoint > scratch.

        ``resume`` restores full training state; ``initial`` (from config)
        restores parameters only.
        """
        initial = self.config.get("initial", "")
        if os.path.exists(self.resume):
            self.trainer.load_checkpoint(self.resume)
            logger.info(f"Successfully resumed from {self.resume}.")
        elif os.path.exists(initial):
            self.trainer.load_checkpoint(initial, load_only_params=True)
            logger.info(f"Successfully initialize parameters from {initial}.")
        else:
            # Fixed typo in the original log message ("scrach").
            logger.info("Train from scratch")

    def run(self):
        """Run training until train_max_steps; always save a final checkpoint."""
        assert self.trainer is not None
        self.trainer: Trainer
        try:
            logger.info(f"The current training step: {self.trainer.steps}")
            self.trainer.train_max_steps = self.config["train_max_steps"]
            if not self.trainer._check_train_finish():
                self.trainer.run()
        finally:
            # Save a checkpoint even on crash/interrupt so progress is kept.
            self.trainer.save_checkpoint(
                os.path.join(self.config["outdir"], f"checkpoint-{self.trainer.steps}steps.pkl")
            )
            logger.info(f"Successfully saved checkpoint @ {self.trainer.steps}steps.")

    def _build_dataset(
            self, subset: str
    ) -> ReprDataset:
        """Build a ReprDataset for the given subset name ("train"/"valid")."""
        data_dir = os.path.join(
            self.data_path, self.config['data']['subset'][subset]
        )
        params = {
            "data_dir": data_dir,
            "batch_len": self.batch_length
        }
        return ReprDataset(**params)

    def _set_data_loader(self, dataset, collater):
        """Wrap the datasets in DataLoaders (train shuffled, dev not)."""
        self.data_loader = {
            "train": DataLoader(
                dataset=dataset["train"],
                shuffle=True,
                collate_fn=collater,
                batch_size=self.config["batch_size"],
                num_workers=self.config["num_workers"],
                pin_memory=self.config["pin_memory"],
            ),
            # Validation uses the main process (num_workers=0) and no pinning.
            "dev": DataLoader(
                dataset=dataset["dev"],
                shuffle=False,
                collate_fn=collater,
                batch_size=self.config["batch_size"],
                num_workers=0,
                pin_memory=False,
            ),
        }
| |
|
| |
|
def train():
    """Parse command-line arguments and run the full RepCodec training pipeline.

    Side effects: creates exp_root/tag/, writes config.yml and checkpoints
    there, and trains the model (see TrainMain).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, required=True,
        help="the path of config yaml file."
    )
    parser.add_argument(
        "--tag", type=str, required=True,
        help="the outputs will be saved to exp_root/tag/"
    )
    parser.add_argument(
        "--exp_root", type=str, default="exp"
    )
    parser.add_argument(
        # Bug fix: with nargs="?" a bare `--resume` flag previously produced
        # None (no `const`), which made os.path.exists(self.resume) raise
        # TypeError in TrainMain.initialize_model. const="" keeps it a string.
        "--resume", default="", const="", type=str, nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument("--seed", default=1337, type=int)
    parser.add_argument(
        "--disable_cudnn", choices=("True", "False"), default="False",
        help="Disable CUDNN"
    )
    args = parser.parse_args()

    # Fixed pipeline order: data -> model/optim -> loss -> trainer -> weights.
    train_main = TrainMain(args)
    train_main.initialize_data_loader()
    train_main.define_model_optimizer_scheduler()
    train_main.define_criterion()
    train_main.define_trainer()
    train_main.initialize_model()
    train_main.run()
| |
|
| |
|
# Script entry point: run the training pipeline when executed directly.
if __name__ == '__main__':
    train()
| |
|