torchtnt.utils.distributed.revert_sync_batchnorm
torchtnt.utils.distributed.revert_sync_batchnorm(module: Module, device: Optional[Union[str, device]] = None) → Module

Helper function to convert all torch.nn.SyncBatchNorm layers in the module to BatchNorm*D layers. This function reverts torch.nn.SyncBatchNorm.convert_sync_batchnorm().

Parameters:
- module (nn.Module) – module containing one or more torch.nn.SyncBatchNorm layers
- device (optional) – device on which the BatchNorm*D layers should be created; default is CPU

Returns:
The original module with the converted BatchNorm*D layers. If the original module is itself a torch.nn.SyncBatchNorm layer, a new BatchNorm*D layer object is returned instead. Note that the returned BatchNorm*D layers will not have input dimension information.

Example:
>>> # Network with nn.BatchNorm layer
>>> module = torch.nn.Sequential(
>>>     torch.nn.Linear(20, 100),
>>>     torch.nn.BatchNorm1d(100),
>>> ).cuda()
>>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
>>> reverted_module = revert_sync_batchnorm(sync_bn_module, torch.device("cuda"))
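A minimal CPU round trip can make the reversion visible without a GPU; convert_sync_batchnorm does not require an initialized process group. This sketch uses only the public APIs named above; the isinstance checks are illustrative and simply confirm that the SyncBatchNorm layer was replaced (the replacement is a BatchNorm*D variant rather than BatchNorm1d, since input dimension information is not restored):

>>> import torch
>>> from torchtnt.utils.distributed import revert_sync_batchnorm
>>> module = torch.nn.Sequential(
>>>     torch.nn.Linear(20, 100),
>>>     torch.nn.BatchNorm1d(100),
>>> )
>>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
>>> isinstance(sync_bn_module[1], torch.nn.SyncBatchNorm)
True
>>> reverted_module = revert_sync_batchnorm(sync_bn_module)  # device defaults to CPU
>>> isinstance(reverted_module[1], torch.nn.SyncBatchNorm)
False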