set_activation_checkpointing
- torchtune.training.set_activation_checkpointing(model: Module, auto_wrap_policy: Union[Set[Type], Callable[[Module, bool, int], bool]], **kwargs) → None
Utility to apply activation checkpointing to the passed-in model. The model is modified in place, and the function returns None.
- Parameters:
model (nn.Module) – Model to apply activation checkpointing to.
auto_wrap_policy (ACWrapPolicyType) – Policy that selects which modules to wrap. This can either be a set of nn.Module types, in which case modules of the specified type(s) are wrapped individually with activation checkpointing, or a callable policy describing how to wrap the model with activation checkpointing. For more information on authoring custom policies, please see this tutorial: https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html#transformer-wrapping-policy.
**kwargs – Additional arguments to pass to torch.distributed activation checkpointing.