.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples_apps/lightning_classy_vision/train.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_examples_apps_lightning_classy_vision_train.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_apps_lightning_classy_vision_train.py:


Trainer App Example
=============================================

This is an example TorchX app that uses PyTorch Lightning and ClassyVision to
train a model.

This app only uses standard OSS libraries and has no runtime torchx
dependencies. For saving and loading data and models it uses fsspec, which
makes the app agnostic to the environment it's running in.

.. GENERATED FROM PYTHON SOURCE LINES 19-114

.. code-block:: default


    import argparse
    import sys
    import tempfile
    from typing import List

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.loggers import TensorBoardLogger

    # ensure data and module are on the path
    sys.path.append("examples/apps/lightning_classy_vision")

    from data import TinyImageNetDataModule, download_data
    from model import TinyImageNetModel, export_inference_model


    def parse_args(argv: List[str]) -> argparse.Namespace:
        """Parse the trainer app's command-line arguments.

        Args:
            argv: command-line arguments, excluding the program name.

        Returns:
            The parsed namespace. ``data_path`` and ``output_path`` are required;
            argparse exits the process if they are missing.
        """
        parser = argparse.ArgumentParser(
            description="pytorch lightning + classy vision TorchX example app"
        )
        parser.add_argument(
            "--epochs", type=int, default=3, help="number of epochs to train"
        )
        parser.add_argument(
            "--batch_size", type=int, default=32, help="batch size to use for training"
        )
        parser.add_argument(
            "--data_path",
            type=str,
            help="path to load the training data from",
            required=True,
        )
        parser.add_argument(
            "--skip_export",
            action="store_true",
            help="do not export the inference model after training",
        )
        parser.add_argument("--load_path", type=str, help="checkpoint path to load from")
        parser.add_argument(
            "--output_path",
            type=str,
            help="path to place checkpoints and model outputs",
            required=True,
        )
        parser.add_argument(
            "--log_path",
            type=str,
            help="path to place the tensorboard logs",
            default="/tmp",
        )

        return parser.parse_args(argv)


    def main(argv: List[str]) -> None:
        """Train the TinyImageNet model and, unless skipped, export it.

        Data is fetched with ``download_data`` into a scratch directory that is
        removed when this function returns. Checkpoints and the exported model
        are written to ``--output_path``; TensorBoard logs go to ``--log_path``.

        Args:
            argv: command-line arguments, excluding the program name
                (see ``parse_args``).
        """
        # Parse first: argument errors / --help need no scratch directory.
        args = parse_args(argv)

        with tempfile.TemporaryDirectory() as tmpdir:
            # Init our model
            model = TinyImageNetModel()

            # Download and setup the data module
            data_path = download_data(args.data_path, tmpdir)
            data = TinyImageNetDataModule(
                data_dir=data_path,
                batch_size=args.batch_size,
            )

            # Setup model checkpointing
            checkpoint_callback = ModelCheckpoint(
                monitor="train_loss",
                dirpath=args.output_path,
                save_last=True,
            )
            if args.load_path:
                print(f"loading checkpoint: {args.load_path}...")
                # load_from_checkpoint is a classmethod that *returns* a new
                # module; calling it on an instance and discarding the result
                # would leave `model` with its freshly initialized weights.
                model = TinyImageNetModel.load_from_checkpoint(
                    checkpoint_path=args.load_path
                )

            logger = TensorBoardLogger(
                save_dir=args.log_path, version=1, name="lightning_logs"
            )

            # Initialize a trainer
            trainer = pl.Trainer(
                logger=logger,
                max_epochs=args.epochs,
                callbacks=[checkpoint_callback],
            )

            # Train the model ⚡
            trainer.fit(model, data)

            if not args.skip_export:
                # Export the inference model
                export_inference_model(model, args.output_path, tmpdir)


    if __name__ == "__main__" and "NOTEBOOK" not in globals():
        main(sys.argv[1:])


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.000 seconds)


.. _sphx_glr_download_examples_apps_lightning_classy_vision_train.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: train.py <train.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: train.ipynb <train.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_