get_full_finetune_fsdp_wrap_policy
- torchtune.training.get_full_finetune_fsdp_wrap_policy(memory_efficient_fsdp_wrap: bool, modules_to_wrap: Set[Type]) → Callable[[Module, bool, int], bool]
Retrieves an FSDP wrapping policy based on the arguments memory_efficient_fsdp_wrap and modules_to_wrap. Specifically, if memory_efficient_fsdp_wrap is set to True, the returned policy will wrap the model's token embedding and output projection in addition to the specified modules, to maximize memory savings.

- Parameters:
  - memory_efficient_fsdp_wrap (bool) – If True, the policy will also wrap the embedding and output projection layers with FSDP.
  - modules_to_wrap (Set[Type]) – Set of module types to wrap.
Note

The memory improvements from memory_efficient_fsdp_wrap have currently only been verified on Llama3 workloads, where they provide a ~15% memory improvement (when used alongside activation-checkpointing memory-efficient wrapping). Other workloads have not been verified and may not see the same improvements.

- Returns:
Wrapping policy that can be passed into FullyShardedDataParallel as the auto_wrap_policy argument. Please see documentation for FSDPPolicyType for additional details.

- Return type:

FSDPPolicyType
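
For illustration, below is a minimal sketch of retrieving the policy and passing it to FullyShardedDataParallel. The llama3_8b model builder and the TransformerSelfAttentionLayer layer type are assumptions for the example; substitute whichever builder and transformer layer class your recipe uses.

```python
# A minimal sketch, assuming a Llama3 model built from torchtune's
# llama3_8b builder and TransformerSelfAttentionLayer as the module
# type to wrap; adjust both to match your own model.
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchtune import training
from torchtune.models.llama3 import llama3_8b                 # assumed model builder
from torchtune.modules import TransformerSelfAttentionLayer   # assumed layer type

dist.init_process_group("nccl")  # FSDP requires an initialized process group

model = llama3_8b()

# Wrap each transformer layer; with memory_efficient_fsdp_wrap=True the
# policy also wraps the token embedding and output projection.
wrap_policy = training.get_full_finetune_fsdp_wrap_policy(
    memory_efficient_fsdp_wrap=True,
    modules_to_wrap={TransformerSelfAttentionLayer},
)

model = FSDP(model, auto_wrap_policy=wrap_policy)
```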