import torch
from typing import Any, Dict, Optional
from pytorch_lightning import strategies
from lightning_fabric.utilities.types import _PATH
from deepspeed.runtime.data_pipeline.data_routing.helper import remove_random_ltd_state_dict
# Overwrite DeepSpeedEngine.module_state_dict so that frozen parameters can be
# excluded from the saved state_dict.
### start overwrite ###
def module_state_dict(self, destination=None, prefix="", keep_vars=False, exclude_frozen_parameters=False):
    sd = self.module.state_dict(destination, prefix, keep_vars)
    # Remove frozen parameter weights from the state_dict if requested.
    if exclude_frozen_parameters:
        to_be_removed = []
        for n in sd:
            try:
                if not self.module.get_parameter(n).requires_grad:
                    to_be_removed.append(n)
            except AttributeError:
                # Keys that do not resolve to a parameter (e.g. buffers) are
                # dropped as well.
                to_be_removed.append(n)
        for key in to_be_removed:
            sd.pop(key)
    if self.random_ltd_enabled():
        sd = remove_random_ltd_state_dict(sd)
    return sd
from deepspeed import DeepSpeedEngine

DeepSpeedEngine.module_state_dict = module_state_dict
### end overwrite ###
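
# A minimal sanity check for the patch above (a sketch, assuming `deepspeed`
# is importable in this environment): after the monkeypatch, the engine-level
# method should expose the `exclude_frozen_parameters` keyword.
import inspect

assert "exclude_frozen_parameters" in inspect.signature(DeepSpeedEngine.module_state_dict).parameters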
class MyDeepSpeedStrategy(strategies.DeepSpeedStrategy):

    def save_checkpoint_v1(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ) -> None:
| """Save model/training states as a checkpoint file through state-dump and file-write. | |
| Args: | |
| checkpoint: dict containing model and trainer state | |
| filepath: write-target file's path | |
| storage_options: parameter for how to save to st | |
| orage, passed to ``CheckpointIO`` plugin | |
| """ | |
| if self.is_global_zero: | |
| self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options) | |
    def load_model_state_dict(self, checkpoint: Dict[str, Any]) -> None:
        assert self.lightning_module is not None
        # `strict=False` because checkpoints saved with `exclude_frozen_parameters=True`
        # omit the frozen weights; those keys are expected to be missing on load.
        self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=False)
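
    # Loading note: because checkpoints written by this strategy omit frozen
    # weights, any manual reload onto a fresh module must also tolerate missing
    # keys. A hypothetical sketch (`model` and `partial_sd` are placeholders):
    #
    #   missing, unexpected = model.load_state_dict(partial_sd, strict=False)
    #   # `missing` lists the frozen parameters that were excluded at save time.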
    def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: the checkpoint state dictionary
            filepath: write-target file's path
            storage_options: not used for ``DeepSpeedStrategy`` as ``CheckpointIO`` is not used

        Raises:
            TypeError:
                If the ``storage_options`` arg is passed in.
        """
        # Broadcast the filepath from rank 0 to ensure all states are saved to a common filepath.
        filepath = self.broadcast(filepath)
        if storage_options is not None:
            raise TypeError(
                "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
                f" is not supported for `{self.__class__.__name__}` as `CheckpointIO` is not used."
            )
        if self.zero_stage_3 and self._multi_device and self.is_global_zero:
            print(
                "Warning: When saving the DeepSpeed Stage 3 checkpoint, "
                "each worker will save a shard of the checkpoint within a directory. "
                "If a single file is required after training, "
                "see https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#"
                "deepspeed-zero-stage-3-single-file for instructions."
            )
        # Use DeepSpeed's internal checkpointing function to handle partitioned weights
        # across processes. Model and optimizer states are stripped here because the
        # engine saves them itself; the rest is passed through as client state.
        _exclude_keys = ["state_dict", "optimizer_states"]
        checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
        self.deepspeed_engine.save_checkpoint(
            filepath, client_state=checkpoint, tag="checkpoint", exclude_frozen_parameters=True
        )
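
# Example usage (a minimal sketch; `MyLightningModule` is a hypothetical
# LightningModule with some frozen submodules, not part of this file):
#
#   import pytorch_lightning as pl
#
#   trainer = pl.Trainer(
#       accelerator="gpu",
#       devices=2,
#       strategy=MyDeepSpeedStrategy(stage=3),
#       precision="bf16-mixed",
#   )
#   trainer.fit(MyLightningModule())
#   # Each rank writes a shard under this directory; frozen weights are excluded.
#   trainer.save_checkpoint("checkpoints/last")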