| """ |
| Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124 |
| """ |
|
|
| from typing import Callable, Iterable, Sequence, Union |
|
|
| import torch |
|
|
|
|
def checkpoint(
    func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
    inputs: Sequence[torch.Tensor],
    params: Iterable[torch.Tensor],
    flag: bool,
):
    """
    Run ``func(*inputs)``, optionally with gradient checkpointing.

    With checkpointing enabled, intermediate activations of ``func`` are not
    kept for the backward pass; ``func`` is re-run during backward instead,
    trading extra compute for lower memory.

    :param func: the callable to evaluate.
    :param inputs: positional arguments forwarded to ``func``.
    :param params: tensors ``func`` reads without receiving them as arguments
                   (e.g. module parameters); they are passed through to the
                   checkpointing function so gradients can be computed for
                   them.
    :param flag: when False, skip checkpointing and call ``func`` directly.
    :return: whatever ``func`` returns.
    """
    if not flag:
        return func(*inputs)

    # The autograd function takes one flat argument list; ``num_inputs`` marks
    # where the positional inputs end and the extra params begin.
    num_inputs = len(inputs)
    flat_args = (*inputs, *params)
    return CheckpointFunction.apply(func, num_inputs, *flat_args)
|
|
|
|
class CheckpointFunction(torch.autograd.Function):
    """
    Autograd function implementing gradient checkpointing.

    Called as ``CheckpointFunction.apply(run_function, length, *args)``: the
    first ``length`` entries of ``args`` are passed positionally to
    ``run_function``; the remaining entries are extra tensors (e.g.
    parameters) that ``run_function`` uses without taking them as arguments.
    The forward pass runs without building a graph; the backward pass re-runs
    ``run_function`` with grad enabled and differentiates the recomputed
    outputs.

    Only tensors that autograd reports as needing a gradient are
    differentiated, so frozen parameters and integer inputs are supported;
    their gradient slots are returned as ``None``.

    NOTE(review): neither the RNG state nor the autocast state from the
    forward pass is saved and restored for the recomputation, so a
    ``run_function`` that uses random ops (e.g. dropout) or relies on autocast
    may produce different values when recomputed -- confirm this is
    acceptable for callers (``torch.utils.checkpoint`` handles RNG state).
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Run without recording a graph: this is where the memory is saved.
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        num_inputs = len(ctx.input_tensors)
        # One flag per forward() argument after ctx; drop the flags for
        # run_function and length, which are never differentiable.
        needs_grad = ctx.needs_input_grad[2:]

        # Detach so recomputation builds a fresh graph rooted at these
        # tensors. Only mark tensors whose gradient autograd actually wants:
        # requires_grad_(True) raises on integer tensors (e.g. timesteps).
        ctx.input_tensors = [
            x.detach().requires_grad_(needed)
            for x, needed in zip(ctx.input_tensors, needs_grad[:num_inputs])
        ]
        with torch.enable_grad():
            # view_as hands run_function new tensor objects, so an in-place op
            # at its start does not modify the detached leaves, which autograd
            # forbids for leaves that require grad.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)

        if isinstance(output_tensors, torch.Tensor):
            output_tensors = (output_tensors,)
        # torch.autograd.grad raises for outputs that do not require grad, so
        # keep only the differentiable outputs and their incoming gradients.
        diff_outputs = []
        diff_grads = []
        for out, grad in zip(output_tensors, output_grads):
            if isinstance(out, torch.Tensor) and out.requires_grad:
                diff_outputs.append(out)
                diff_grads.append(grad)

        # torch.autograd.grad also raises for inputs that do not require grad
        # (e.g. frozen parameters), so only differentiate w.r.t. the tensors
        # autograd asked for and report None for the rest.
        candidates = ctx.input_tensors + ctx.input_params
        wanted = [i for i, needed in enumerate(needs_grad) if needed]
        input_grads = [None] * len(candidates)
        if diff_outputs and wanted:
            computed = torch.autograd.grad(
                diff_outputs,
                [candidates[i] for i in wanted],
                diff_grads,
                allow_unused=True,
            )
            for i, grad in zip(wanted, computed):
                input_grads[i] = grad

        # Drop references so the recomputed graph and saved inputs can be freed.
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # No gradients for run_function and length.
        return (None, None) + tuple(input_grads)
|
|