from typing import Any

import ignite
import ignite.distributed as idist
import torch
from ignite.engine import Engine, Events, EventsList
from omegaconf import OmegaConf


# TODO: move to utils or similar
def event_list_from_config(config) -> EventsList:
    """Build an EventsList from an int (epoch interval) or a list of event specs."""
    events = EventsList()
    if isinstance(config, int):
        events = events | Events.EPOCH_COMPLETED(every=config) | Events.COMPLETED
    else:
        for event in config:
            # event filter args (e.g. every=N) are optional
            if event.get("args"):
                events = events | Events[event["type"]](**event["args"])
            else:
                events = events | Events[event["type"]]
    return events
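
# Illustrative usage (a sketch; the exact config schema is assumed, not taken
# from this file): an int is shorthand for "every N epochs plus on completion",
# while a list of dicts selects arbitrary ignite Events by name.
#
#   event_list_from_config(5)
#   # equivalent to Events.EPOCH_COMPLETED(every=5) | Events.COMPLETED
#
#   event_list_from_config([
#       {"type": "ITERATION_COMPLETED", "args": {"every": 100}},
#       {"type": "COMPLETED", "args": None},
#   ])
#   # equivalent to Events.ITERATION_COMPLETED(every=100) | Events.COMPLETED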
def global_step_fn(trainer: Engine, config: dict[str, Any]):
    """Return a global_step_transform callable selected by config["type"]."""
    match config.get("type"):
        case "trainer epoch":
            return lambda engine, event_name: trainer.state.epoch
        case "trainer iteration":
            return lambda engine, event_name: trainer.state.iteration
        case _:
            raise ValueError(f"Unknown global step type: {config.get('type')}")


# Alternative global-step transforms, kept for reference; the first duplicates
# the "trainer iteration" case above, and the last two assume a variable
# `every` in scope:
#
# trainer iteration
# gst = lambda engine, event_name: trainer.state.iteration
#
# iteration within the current epoch
# gst_it_epoch = (
#     lambda engine, event_name: (trainer.state.epoch - 1)
#     * engine.state.epoch_length
#     + engine.state.iteration
#     - 1
# )
# gst_it_iters = (
#     lambda engine, event_name: (
#         (
#             (trainer.state.epoch - 1) * trainer.state.epoch_length
#             + trainer.state.iteration
#         )
#         // every
#     )
#     * engine.state.epoch_length
#     + engine.state.iteration
#     - 1
# )
# gst_ep_iters = lambda engine, event_name: (
#     (
#         (trainer.state.epoch - 1) * trainer.state.epoch_length
#         + trainer.state.iteration
#     )
#     // every
# )
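
# Illustrative wiring (a sketch; `tb_logger`, `evaluator`, and the log_dir are
# hypothetical, and in older ignite releases TensorboardLogger lives under
# ignite.contrib.handlers instead): the returned callable plugs into ignite's
# global_step_transform hook so that evaluator metrics are logged against the
# trainer's epoch counter rather than the evaluator's own iteration count.
#
#   from ignite.handlers import TensorboardLogger
#
#   tb_logger = TensorboardLogger(log_dir="logs")
#   tb_logger.attach_output_handler(
#       evaluator,
#       event_name=Events.COMPLETED,
#       tag="validation",
#       metric_names="all",
#       global_step_transform=global_step_fn(trainer, {"type": "trainer epoch"}),
#   )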
def log_basic_info(logger, config):
    """Log environment details and the run configuration on startup."""
    logger.info(f"Run {config['name']}")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"Ignite version: {ignite.__version__}")
    if torch.cuda.is_available():
        # explicitly import cudnn as
        # torch.backends.cudnn can not be pickled with hvd spawning procs
        from torch.backends import cudnn

        logger.info(f"GPU Device: {torch.cuda.get_device_name(idist.get_local_rank())}")
        logger.info(f"CUDA version: {torch.version.cuda}")
        logger.info(f"CUDNN version: {cudnn.version()}")

    if idist.get_world_size() > 1:
        logger.info("\nDistributed setting:")
        logger.info(f"\tbackend: {idist.backend()}")
        logger.info(f"\tworld size: {idist.get_world_size()}")
        logger.info("\n")

    logger.info("\n")
    logger.info(f"Configuration: \n{OmegaConf.to_yaml(config)}")