AutoUnit¶
class torchtnt.framework.auto_unit.AutoUnit(*args, **kwargs)¶
The AutoUnit is a convenience for users who are training with stochastic gradient descent and would like to have model optimization and data-parallel replication handled for them. The AutoUnit subclasses TrainUnit, EvalUnit, and PredictUnit, and implements the train_step, eval_step, and predict_step methods for the user.
For the train_step it runs:
- forward pass and loss computation
- backward pass
- optimizer step
For the eval_step it only runs forward and loss computation.
For the predict_step it only runs forward computation.
To benefit from the AutoUnit, the user must subclass it and implement the compute_loss and configure_optimizers_and_lr_scheduler methods. Additionally, the AutoUnit offers these optional hooks: on_train_step_end, on_eval_step_end, and on_predict_step_end.
Then use with the train(), evaluate(), fit(), or predict() entry point as normal.
For more advanced customization, directly use the TrainUnit, EvalUnit, and PredictUnit interfaces.
Parameters:
- module – module to be used during training/evaluation.
- device – the device to be used.
- strategy – the data parallelization strategy to be used. If a string, must be one of ddp or fsdp.
- step_lr_interval – whether to step lr_scheduler every step or every epoch. Defaults to every epoch.
- precision – the precision to use in training/evaluation, as either a string or a torch.dtype.
- gradient_accumulation_steps – how many batches to accumulate gradients over.
- detect_anomaly – whether to enable anomaly detection for the autograd engine https://pytorch.org/docs/stable/autograd.html#anomaly-detection
- clip_grad_norm – max norm of the gradients for clipping https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
- clip_grad_value – max value of the gradients for clipping https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_value_.html
- swa_params – params for stochastic weight averaging https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging
- torch_compile_params – params for Torch compile https://pytorch.org/docs/stable/generated/torch.compile.html
- activation_checkpoint_params – params for enabling activation checkpointing
- training – if True, the optimizer and, optionally, the LR scheduler will be created after the class is initialized.
Note
Stochastic Weight Averaging is currently not supported with the FSDP strategy.
Note
Torch compile support is only available in PyTorch 2.0 or higher.
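A minimal usage sketch follows. The class name MyUnit, the toy model, and the dataloader are hypothetical, and it assumes the train() entry point is re-exported from torchtnt.framework as in recent releases:

    from typing import Any, Tuple

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from torchtnt.framework import train  # assumes the entry point is re-exported here
    from torchtnt.framework.auto_unit import AutoUnit
    from torchtnt.framework.state import State


    class MyUnit(AutoUnit):  # hypothetical subclass
        def compute_loss(self, state: State, data: Any) -> Tuple[torch.Tensor, Any]:
            inputs, targets = data                 # assumes (inputs, targets) batches
            outputs = self.module(inputs)          # the module's forward pass must run here
            loss = torch.nn.functional.cross_entropy(outputs, targets)
            return loss, outputs

        def configure_optimizers_and_lr_scheduler(self, module):
            optimizer = torch.optim.SGD(module.parameters(), lr=1e-2)
            return optimizer, None                 # the LR scheduler is optional


    # toy model and dataloader, for illustration only
    model = torch.nn.Linear(16, 2)
    dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
    train_dataloader = DataLoader(dataset, batch_size=8)

    unit = MyUnit(module=model, device=torch.device("cpu"))
    train(unit, train_dataloader, max_epochs=2)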
abstract compute_loss(state: State, data: TData) → Tuple[Tensor, Any]¶
The user should implement this method with their loss computation. This will be called every train_step / eval_step.
Parameters:
- state – a State object which is passed from the train_step / eval_step
- data – a batch of data which is passed from the train_step / eval_step
Returns: Tuple containing the loss and the output of the model
Note
The module’s forward pass must be run as part of this method.
abstract configure_optimizers_and_lr_scheduler(module: Module) → Tuple[Optimizer, Optional[LRScheduler]]¶
The user should implement this method with their optimizer and learning rate scheduler construction code. This will be called upon initialization of the AutoUnit.
Parameters: module – the module with which to construct the optimizer and lr_scheduler
Returns: A tuple containing the optimizer and, optionally, the learning rate scheduler
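For example, a sketch of an implementation that also builds a scheduler (the optimizer and scheduler choices are illustrative):

    import torch
    from torch.optim.lr_scheduler import StepLR

    # defined on the AutoUnit subclass; `module` is the module passed to the AutoUnit
    def configure_optimizers_and_lr_scheduler(self, module):
        optimizer = torch.optim.AdamW(module.parameters(), lr=3e-4)
        # stepped every step or every epoch, depending on step_lr_interval
        lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.9)
        return optimizer, lr_scheduler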
eval_step(state: State, data: TData) → Tuple[Tensor, Any]¶
Core required method for the user to implement. This method will be called at each iteration of the eval dataloader, and can return any data the user wishes. It can optionally be decorated with @torch.inference_mode() for improved performance.
Parameters:
- state – a State object containing metadata about the evaluation run.
- data – one batch of evaluation data.
move_data_to_device(state: State, data: TData, non_blocking: bool) → TData¶
The user can override this method with custom code to copy data to the device. This will be called at the start of every train_step / eval_step. By default this uses the utility function copy_data_to_device(). If on GPU, this method will be called on a separate CUDA stream.
Parameters:
- state – a State object which is passed from the train_step / eval_step
- data – a batch of data which is passed from the train_step / eval_step
- non_blocking – parameter to pass to torch.Tensor.to
Returns: A batch of data which is on the device
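A minimal sketch of an override, assuming each batch is a hypothetical (inputs, targets, metadata) tuple in which only the tensors should be moved:

    from typing import Tuple

    import torch
    from torchtnt.framework.state import State

    # defined on the AutoUnit subclass
    def move_data_to_device(
        self, state: State, data: Tuple[torch.Tensor, torch.Tensor, dict], non_blocking: bool
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        inputs, targets, metadata = data
        inputs = inputs.to(self.device, non_blocking=non_blocking)
        targets = targets.to(self.device, non_blocking=non_blocking)
        return inputs, targets, metadata  # metadata stays on CPU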
on_eval_step_end(state: State, data: TData, step: int, loss: Tensor, outputs: Any) → None¶
This will be called at the end of every eval_step before returning. The user can implement this method with code to update and log their metrics, or do anything else.
Parameters:
- state – a State object which is passed from the eval_step
- data – a batch of data which is passed from the eval_step
- step – how many steps have been completed (train_steps when running fit and eval_steps when running evaluation)
- loss – the loss computed in the compute_loss function
- outputs – the outputs of the model forward pass
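For example, a sketch of a hook that logs the evaluation loss periodically (the interval and print-based logging are illustrative):

    from typing import Any

    import torch
    from torchtnt.framework.state import State

    # defined on the AutoUnit subclass
    def on_eval_step_end(
        self, state: State, data: Any, step: int, loss: torch.Tensor, outputs: Any
    ) -> None:
        if step % 100 == 0:
            print(f"eval step {step}: loss={loss.item():.4f}")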
on_predict_step_end(state: State, data: TData, step: int, outputs: Any) → None¶
This will be called at the end of every predict_step before returning. The user can implement this method with code to update and log their metrics, or do anything else.
Parameters:
- state – a State object which is passed from the predict_step
- data – a batch of data which is passed from the predict_step
- step – how many predict_steps have been completed
- outputs – the outputs of the model forward pass
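For example, a sketch of a hook that accumulates predictions on the unit (assumes the model returns a tensor and that a self.predictions list was created in the subclass's __init__):

    from typing import Any

    from torchtnt.framework.state import State

    # defined on the AutoUnit subclass; assumes self.predictions = [] was set in __init__
    def on_predict_step_end(self, state: State, data: Any, step: int, outputs: Any) -> None:
        self.predictions.append(outputs.detach().cpu())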
on_train_end(state: State) → None¶
Note that if using SWA and implementing on_train_end(), you must call super().on_train_end().
on_train_epoch_end(state: State) → None¶
Note: if overriding on_train_epoch_end, remember to call super().on_train_epoch_end().
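A sketch of an override that respects this requirement (the epoch counter is a hypothetical attribute; the abstract methods are omitted for brevity):

    from torchtnt.framework.auto_unit import AutoUnit
    from torchtnt.framework.state import State


    class MyUnit(AutoUnit):  # hypothetical subclass; compute_loss etc. omitted
        def on_train_epoch_end(self, state: State) -> None:
            super().on_train_epoch_end(state)  # lets AutoUnit run its per-epoch bookkeeping
            self.epochs_completed = getattr(self, "epochs_completed", 0) + 1  # hypothetical counter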
on_train_step_end(state: State, data: TData, step: int, loss: Tensor, outputs: Any) → None¶
This will be called at the end of every train_step before returning. The user can implement this method with code to update and log their metrics, or do anything else.
Parameters:
- state – a State object which is passed from the train_step
- data – a batch of data which is passed from the train_step
- step – how many train_steps have been completed
- loss – the loss computed in the compute_loss function
- outputs – the outputs of the model forward pass
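For example, a sketch that keeps a running average of the training loss (the smoothing factor and attribute are illustrative; assumes self.running_loss = 0.0 was set in __init__):

    from typing import Any

    import torch
    from torchtnt.framework.state import State

    # defined on the AutoUnit subclass; assumes self.running_loss = 0.0 was set in __init__
    def on_train_step_end(
        self, state: State, data: Any, step: int, loss: torch.Tensor, outputs: Any
    ) -> None:
        self.running_loss = 0.99 * self.running_loss + 0.01 * loss.item()
        if step % 50 == 0:
            print(f"train step {step}: running loss {self.running_loss:.4f}")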
predict_step(state: State, data: TData) → Any¶
Core required method for the user to implement. This method will be called at each iteration of the predict dataloader, and can return any data the user wishes. It can optionally be decorated with @torch.inference_mode() for improved performance.
Parameters:
- state – a State object containing metadata about the prediction run.
- data – one batch of prediction data.
train_step(state: State, data: Iterator[TData]) → Tuple[Tensor, Any]¶
Core required method for the user to implement. This method will be called at each iteration of the train dataloader, and can return any data the user wishes.
Parameters:
- state – a State object containing metadata about the training run.
- data – one batch of training data.
class torchtnt.framework.auto_unit.AutoPredictUnit(*, module: Module, device: Optional[device] = None, strategy: Optional[Union[Strategy, str]] = None, precision: Optional[Union[str, dtype]] = None, torch_compile_params: Optional[TorchCompileParams] = None, detect_anomaly: Optional[bool] = None)¶
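A minimal usage sketch, assuming the predict() entry point is re-exported from torchtnt.framework; the model and dataloader are toy stand-ins:

    import torch
    from torch.utils.data import DataLoader

    from torchtnt.framework import predict  # assumed import path for the entry point
    from torchtnt.framework.auto_unit import AutoPredictUnit

    model = torch.nn.Linear(16, 4)
    unit = AutoPredictUnit(module=model)

    # toy dataloader of input batches, for illustration only
    predict_dataloader = DataLoader(torch.randn(32, 16), batch_size=8)
    predict(unit, predict_dataloader)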
move_data_to_device(state: State, data: TPredictData, non_blocking: bool) → TPredictData¶
The user can override this method with custom code to copy data to the device. This will be called at the start of every predict_step. By default this uses the utility function copy_data_to_device(). If on GPU, this method will be called on a separate CUDA stream.
Parameters:
- state – a State object which is passed from the predict_step
- data – a batch of data which is passed from the predict_step
- non_blocking – parameter to pass to torch.Tensor.to
Returns: A batch of data which is on the device
on_predict_step_end(state: State, data: TPredictData, step: int, outputs: Any) → None¶
This will be called at the end of every predict_step before returning. The user can implement this method with code to update and log their metrics, or do anything else.
Parameters:
- state – a State object which is passed from the predict_step
- data – a batch of data which is passed from the predict_step
- step – how many predict_steps have been completed
- outputs – the outputs of the model forward pass
predict_step(state: State, data: Iterator[TPredictData]) → Any¶
Core required method for the user to implement. This method will be called at each iteration of the predict dataloader, and can return any data the user wishes. It can optionally be decorated with @torch.inference_mode() for improved performance.
Parameters:
- state – a State object containing metadata about the prediction run.
- data – one batch of prediction data.