.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "examples_apps/lightning_classy_vision/data.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note Click :ref:`here ` to download the full example code .. rst-class:: sphx-glr-example-title .. _sphx_glr_examples_apps_lightning_classy_vision_data.py: Trainer Datasets Example ======================== This is the datasets used for the training example. It's using stock Pytorch Lightning + Classy Vision libraries. .. GENERATED FROM PYTHON SOURCE LINES 14-25 .. code-block:: default import os.path import tarfile from typing import Optional, Callable import fsspec import pytorch_lightning as pl from classy_vision.dataset.classy_dataset import ClassyDataset from torch.utils.data import DataLoader from torchvision import datasets, transforms .. GENERATED FROM PYTHON SOURCE LINES 26-28 This uses classy vision to define a dataset that we will then later use in our Pytorch Lightning data module. .. GENERATED FROM PYTHON SOURCE LINES 28-50 .. code-block:: default class TinyImageNetDataset(ClassyDataset): """ TinyImageNetDataset is a ClassyDataset for the tiny imagenet dataset. """ def __init__(self, data_path: str, transform: Callable[[object], object]) -> None: batchsize_per_replica = 16 shuffle = False num_samples = 1000 dataset = datasets.ImageFolder(data_path) super().__init__( # pyre-fixme[6] dataset, batchsize_per_replica, shuffle, transform, num_samples, ) .. GENERATED FROM PYTHON SOURCE LINES 51-53 For easy of use, we define a lightning data module so we can reuse it across our trainer and other components that need to load data. .. GENERATED FROM PYTHON SOURCE LINES 53-109 .. code-block:: default # pyre-fixme[13]: Attribute `test_ds` is never initialized. # pyre-fixme[13]: Attribute `train_ds` is never initialized. # pyre-fixme[13]: Attribute `val_ds` is never initialized. class TinyImageNetDataModule(pl.LightningDataModule): """ TinyImageNetDataModule is a pytorch LightningDataModule for the tiny imagenet dataset. """ train_ds: TinyImageNetDataset val_ds: TinyImageNetDataset test_ds: TinyImageNetDataset def __init__(self, data_dir: str, batch_size: int = 16) -> None: super().__init__() self.data_dir = data_dir self.batch_size = batch_size def setup(self, stage: Optional[str] = None) -> None: # Setup data loader and transforms img_transform = transforms.Compose( [ transforms.Grayscale(), transforms.ToTensor(), ] ) self.train_ds = TinyImageNetDataset( data_path=os.path.join(self.data_dir, "train"), transform=lambda x: (img_transform(x[0]), x[1]), ) self.val_ds = TinyImageNetDataset( data_path=os.path.join(self.data_dir, "val"), transform=lambda x: (img_transform(x[0]), x[1]), ) self.test_ds = TinyImageNetDataset( data_path=os.path.join(self.data_dir, "test"), transform=lambda x: (img_transform(x[0]), x[1]), ) def train_dataloader(self) -> DataLoader: # pyre-fixme[6] return DataLoader(self.train_ds, batch_size=self.batch_size) def val_dataloader(self) -> DataLoader: # pyre-fixme[6]: return DataLoader(self.val_ds, batch_size=self.batch_size) def test_dataloader(self) -> DataLoader: # pyre-fixme[6] return DataLoader(self.test_ds, batch_size=self.batch_size) def teardown(self, stage: Optional[str] = None) -> None: pass .. GENERATED FROM PYTHON SOURCE LINES 110-112 To pass data between the different components we use fsspec which allows us to read/write to cloud or local file storage. .. GENERATED FROM PYTHON SOURCE LINES 112-131 .. code-block:: default def download_data(remote_path: str, tmpdir: str) -> str: """ download_data downloads the training data from the specified remote path via fsspec and places it in the tmpdir unextracted. """ tar_path = os.path.join(tmpdir, "data.tar.gz") print(f"downloading dataset from {remote_path} to {tar_path}...") fs, _, rpaths = fsspec.get_fs_token_paths(remote_path) assert len(rpaths) == 1, "must have single path" fs.get(rpaths[0], tar_path) data_path = os.path.join(tmpdir, "data") print(f"extracting {tar_path} to {data_path}...") with tarfile.open(tar_path, mode="r") as f: f.extractall(data_path) return data_path .. rst-class:: sphx-glr-timing **Total running time of the script:** ( 0 minutes 0.000 seconds) .. _sphx_glr_download_examples_apps_lightning_classy_vision_data.py: .. only :: html .. container:: sphx-glr-footer :class: sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: data.py ` .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: data.ipynb ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_