Distributed training¶
The core TNT framework makes no assumptions about distributed training or devices; configuring distributed training is left to the user. As a convenience, the framework offers the AutoUnit for users who prefer to have this handled automatically. The framework-provided checkpointing callbacks handle distributed model checkpointing and loading.
If you are using the TrainUnit/EvalUnit/PredictUnit interface, you are expected to initialize the CUDA device, if applicable, along with the global process group from torch.distributed. We offer the convenience function init_from_env(), which works with TorchElastic to handle these settings for you automatically; invoke it at the beginning of your script.
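For example, a minimal sketch of a script entry point (the torchtnt.utils import path and the main() structure are assumptions to adapt to your setup):

from torchtnt.utils import init_from_env  # assumed import path

def main() -> None:
    # Initializes the CUDA device (if available) and the global process group,
    # using the environment variables set by TorchElastic / torchrun.
    device = init_from_env()
    # ... construct your module, unit, and dataloaders, then run train/evaluate/predict

if __name__ == "__main__":
    main()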
Distributed Data Parallel¶
If you are using the TrainUnit/EvalUnit/PredictUnit interface, DDP can simply be wrapped around your model like so:
import torch
from torch import nn

from torchtnt.utils import init_from_env

device = init_from_env()
module = nn.Linear(input_dim, 1)
# move module to device
module = module.to(device)
# wrap module in DDP, passing the local device index
device_ids = [device.index]
model = torch.nn.parallel.DistributedDataParallel(module, device_ids=device_ids)
We also offer prepare_ddp(), which can assist in wrapping the model for you.
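A minimal sketch, assuming prepare_ddp accepts the module and the target device (the import path is an assumption; check the API reference for the exact signature):

from torch import nn

from torchtnt.utils import init_from_env
from torchtnt.utils.prepare_module import prepare_ddp  # assumed import path

device = init_from_env()
module = nn.Linear(input_dim, 1)
# wraps the module in DistributedDataParallel on the given device
model = prepare_ddp(module, device)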
The AutoUnit automatically wraps the module in DDP when either:

The string "ddp" is passed in the strategy argument:

module = nn.Linear(input_dim, 1)
my_auto_unit = MyAutoUnit(module=module, strategy="ddp")

The dataclass DDPStrategy is passed in to the strategy argument. This is helpful when wanting to customize the settings in DDP:

module = nn.Linear(input_dim, 1)
ddp_strategy = DDPStrategy(broadcast_buffers=False, check_reduction=True)
my_auto_unit = MyAutoUnit(module=module, strategy=ddp_strategy)
Fully Sharded Data Parallel¶
If you are using the TrainUnit, EvalUnit, or PredictUnit interface, FSDP can simply be wrapped around the model like so:
import torch.distributed.fsdp
from torch import nn

from torchtnt.utils import init_from_env

device = init_from_env()
module = nn.Linear(input_dim, 1)
# move module to device
module = module.to(device)
# wrap module in FSDP
model = torch.distributed.fsdp.FullyShardedDataParallel(module, device_id=device)
We also offer prepare_fsdp(), which can assist in wrapping the model for you.
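As with DDP, a minimal sketch, assuming prepare_fsdp accepts the module and the target device (the import path is an assumption; check the API reference for the exact signature):

from torch import nn

from torchtnt.utils import init_from_env
from torchtnt.utils.prepare_module import prepare_fsdp  # assumed import path

device = init_from_env()
module = nn.Linear(input_dim, 1)
# wraps the module in FullyShardedDataParallel on the given device
model = prepare_fsdp(module, device)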
The AutoUnit automatically wraps the module in FSDP when either:

The string "fsdp" is passed in the strategy argument:

module = nn.Linear(input_dim, 1)
my_auto_unit = MyAutoUnit(module=module, strategy="fsdp")

The dataclass FSDPStrategy is passed in to the strategy argument. This is helpful when wanting to customize the settings in FSDP:

module = nn.Linear(input_dim, 1)
fsdp_strategy = FSDPStrategy(forward_prefetch=True, limit_all_gathers=True)
my_auto_unit = MyAutoUnit(module=module, strategy=fsdp_strategy)