| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Dict, Optional |
| |
|
| | import torch |
| |
|
| | |
# Key fragments marking state-dict entries that process_state_dict drops
# (TransformerEngine adds "_extra_state" entries for FP8).
substrings_to_ignore = ["_extra_state"]
| |
|
| |
|
def get_partial_state_dict(
    state_dict: Dict[str, torch.Tensor],
    prefix: str,
) -> Dict[str, torch.Tensor]:
    """
    Select the entries of a state dict whose keys begin with ``prefix``.

    Args:
        state_dict (Dict[str, torch.Tensor]): The state dict to filter
        prefix (str): Only keys starting with this string are kept

    Returns:
        Dict[str, torch.Tensor]: A new dict holding the matching entries, in
            their original order. Keys are kept in full (the prefix is not
            stripped) and values are the same objects as in ``state_dict``.
    """
    selected = {}
    for name, tensor in state_dict.items():
        if not name.startswith(prefix):
            continue
        selected[name] = tensor
    return selected
| |
|
| |
|
def process_state_dict(
    state_dict: Dict[str, torch.Tensor],
    device: Optional[str] = None,
    dtype: Optional[torch.dtype] = None,
    prefix_to_remove: Optional[str] = None,
) -> Dict[str, torch.Tensor]:
    """
    Build a cleaned copy of a state dict.

    - Remove items with substring "_extra_state" in keys (TransformerEngine adds these for FP8)
    - Move tensors to specified device and dtype if provided
    - Strip ``prefix_to_remove`` from the start of keys that begin with it

    The input dict itself is not modified; a new dict is returned.

    Args:
        state_dict (Dict[str, torch.Tensor]): The state dict to process
        device (str, optional): The device to move tensors to. Defaults to None.
        dtype (torch.dtype, optional): The dtype to move tensors to. Defaults to None.
        prefix_to_remove (str, optional): The prefix to remove from the keys of the state dict.
            Keys that do not start with it are kept unchanged. Defaults to None.

    Returns:
        Dict[str, torch.Tensor]: The processed state dict
    """
    # Only forward the conversion arguments the caller actually supplied, so
    # values are passed through untouched when neither is given.
    tensor_kwargs = {}
    if device is not None:
        tensor_kwargs["device"] = device
    if dtype is not None:
        tensor_kwargs["dtype"] = dtype

    new_state_dict = {}
    for key, value in state_dict.items():
        # Drop entries whose key contains any ignored fragment.
        if any(substr in key for substr in substrings_to_ignore):
            continue
        if tensor_kwargs:
            # NOTE(review): the conversion is applied to every remaining entry,
            # whatever its current dtype -- confirm this is intended for
            # non-floating-point entries.
            value = value.to(**tensor_kwargs)
        if prefix_to_remove is not None and key.startswith(prefix_to_remove):
            key = key[len(prefix_to_remove):]
        new_state_dict[key] = value
    return new_state_dict
| |
|