Fit
Fit Entry Point
torchtnt.framework.fit.fit(unit: TrainUnit[TTrainData], train_dataloader: Iterable[TTrainData], eval_dataloader: Iterable[TEvalData], *, max_epochs: Optional[int] = None, max_steps: Optional[int] = None, max_train_steps_per_epoch: Optional[int] = None, max_eval_steps_per_epoch: Optional[int] = None, evaluate_every_n_steps: Optional[int] = None, evaluate_every_n_epochs: Optional[int] = 1, callbacks: Optional[List[Callback]] = None, timer: Optional[TimerProtocol] = None) -> None

The fit entry point interleaves training and evaluation loops. It takes an object that subclasses both TrainUnit and EvalUnit, train and eval dataloaders (any iterables), and optional arguments that modify loop execution, then runs the fit loop.

Parameters:
- unit – an instance that subclasses both TrainUnit and EvalUnit, implementing train_step() and eval_step().
- train_dataloader – dataloader to be used during training, which can be any iterable, including PyTorch DataLoader, DataLoader2, etc.
- eval_dataloader – dataloader to be used during evaluation, which can be any iterable, including PyTorch DataLoader, DataLoader2, etc.
- max_epochs – the max number of epochs to run for training. None means no limit (infinite training) unless stopped by max_steps.
- max_steps – the max number of steps to run for training. None means no limit (infinite training) unless stopped by max_epochs.
- max_train_steps_per_epoch – the max number of steps to run per epoch for training. None means train until train_dataloader is exhausted.
- max_eval_steps_per_epoch – the max number of steps to run per epoch for evaluation. None means evaluate until eval_dataloader is exhausted.
- evaluate_every_n_steps – how often to run the evaluation loop in terms of training steps.
- evaluate_every_n_epochs – how often to run the evaluation loop in terms of training epochs.
- callbacks – an optional list of callbacks.
- timer – an optional Timer which will be used to time key events (using a Timer with CUDA synchronization may degrade performance).
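To make the interaction between the limit and cadence parameters concrete, here is a minimal pure-Python sketch that counts steps, epochs, and evaluation triggers for a toy schedule. It does not call torchtnt. The function simulate_schedule and its steps_per_epoch argument are hypothetical, and the modulo checks used for "every n steps" and "every n epochs" are assumptions about how the cadence might be decided, not the library's confirmed logic.

```python
from typing import List, Optional, Tuple


def simulate_schedule(
    steps_per_epoch: int,  # hypothetical: length of the toy train dataloader
    max_epochs: Optional[int] = None,
    max_steps: Optional[int] = None,
    max_train_steps_per_epoch: Optional[int] = None,
    evaluate_every_n_steps: Optional[int] = None,
    evaluate_every_n_epochs: Optional[int] = 1,
) -> Tuple[int, int, List[str]]:
    """Return (steps run, epochs run, evaluation trigger log)."""
    if max_epochs is None and max_steps is None:
        raise ValueError("this sketch needs a limit, otherwise it never ends")

    steps = 0
    epochs = 0
    evals: List[str] = []

    def done() -> bool:
        # Training stops when either configured limit is reached.
        return (max_epochs is not None and epochs >= max_epochs) or (
            max_steps is not None and steps >= max_steps
        )

    while not done():
        epoch_limit = steps_per_epoch
        if max_train_steps_per_epoch is not None:
            epoch_limit = min(epoch_limit, max_train_steps_per_epoch)
        steps_this_epoch = 0
        while steps_this_epoch < epoch_limit and not done():
            steps += 1
            steps_this_epoch += 1
            # Assumption: "every n steps" is a modulo check on completed steps.
            if evaluate_every_n_steps and steps % evaluate_every_n_steps == 0:
                evals.append(f"step {steps}")
        # Simplification: an epoch cut short by max_steps still counts here.
        epochs += 1
        # Assumption: "every n epochs" is a modulo check on completed epochs.
        if evaluate_every_n_epochs and epochs % evaluate_every_n_epochs == 0:
            evals.append(f"epoch {epochs}")
    return steps, epochs, evals
```

Under these assumptions, simulate_schedule(steps_per_epoch=3, max_epochs=4) runs 12 steps over 4 epochs and triggers evaluation once after each epoch, which matches the default evaluate_every_n_epochs=1.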
Below is an example of calling fit().

    import torch
    from torchtnt.framework.fit import fit

    # MyFitUnit is a user-defined class that subclasses both TrainUnit and EvalUnit
    fit_unit = MyFitUnit(module=..., optimizer=..., lr_scheduler=...)
    train_dataloader = torch.utils.data.DataLoader(...)
    eval_dataloader = torch.utils.data.DataLoader(...)
    fit(fit_unit, train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, max_epochs=4)
Below is pseudocode of what the fit() entry point does.

    set unit's tracked modules to train mode
    call on_train_start on unit first and then callbacks
    while training is not done:
        while epoch is not done:
            call on_train_epoch_start on unit first and then callbacks
            try:
                data = next(dataloader)
                call on_train_step_start on callbacks
                call train_step on unit
                increment step counter
                call on_train_step_end on callbacks
                if should evaluate after this step:
                    run eval loop
            except StopIteration:
                break
        increment epoch counter
        call on_train_epoch_end on unit first and then callbacks
        if should evaluate after this epoch:
            run eval loop
    call on_train_end on unit first and then callbacks
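The ordering described by the pseudocode can be traced with a small runnable sketch. It is a simplified stand-in rather than torchtnt's implementation: RecordingUnit, toy_eval, and toy_fit are hypothetical names, callbacks and the timer are omitted, the per-step hooks are not modeled, and the modulo checks for when to evaluate are assumptions.

```python
from typing import Any, Iterable, List, Optional


class RecordingUnit:
    """Hypothetical unit that records each call instead of training a model."""

    def __init__(self) -> None:
        self.log: List[str] = []

    def train_step(self, data: Any) -> None:
        self.log.append(f"train_step({data})")

    def eval_step(self, data: Any) -> None:
        self.log.append(f"eval_step({data})")


def toy_eval(unit: RecordingUnit, eval_dataloader: Iterable[Any]) -> None:
    """Run one full pass over the eval dataloader."""
    unit.log.append("eval_start")
    for batch in eval_dataloader:
        unit.eval_step(batch)
    unit.log.append("eval_end")


def toy_fit(
    unit: RecordingUnit,
    train_dataloader: Iterable[Any],
    eval_dataloader: Iterable[Any],
    *,
    max_epochs: int,
    evaluate_every_n_steps: Optional[int] = None,
    evaluate_every_n_epochs: Optional[int] = 1,
) -> List[str]:
    """Interleave training and evaluation, returning the recorded event log."""
    steps = 0
    epochs = 0
    unit.log.append("on_train_start")
    while epochs < max_epochs:
        unit.log.append("on_train_epoch_start")
        data_iter = iter(train_dataloader)
        while True:
            try:
                data = next(data_iter)
            except StopIteration:
                break  # the train dataloader is exhausted, so the epoch ends
            unit.train_step(data)
            steps += 1
            # Assumption: step-based evaluation uses a modulo check.
            if evaluate_every_n_steps and steps % evaluate_every_n_steps == 0:
                toy_eval(unit, eval_dataloader)
        epochs += 1
        unit.log.append("on_train_epoch_end")
        # Assumption: epoch-based evaluation uses a modulo check.
        if evaluate_every_n_epochs and epochs % evaluate_every_n_epochs == 0:
            toy_eval(unit, eval_dataloader)
    unit.log.append("on_train_end")
    return unit.log
```

Calling toy_fit(RecordingUnit(), ["a", "b"], ["x"], max_epochs=1) records both training steps, then on_train_epoch_end, then a full evaluation pass, and finally on_train_end. Setting evaluate_every_n_steps=1 and evaluate_every_n_epochs=None instead places an evaluation pass after every training step.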