Shortcuts

Trainer Component Example

This file contains the component definitions for the example lightning_classy_vision app: a trainer component and a model interpretability component.

from typing import Optional

import torchx.specs.api as torchx
from torchx.components.base import named_resource
from torchx.components.base.binary_component import binary_component


def trainer(
    image: str,
    output_path: str,
    data_path: str,
    load_path: str = "",
    log_path: str = "/logs",
    resource: Optional[str] = None,
) -> torchx.AppDef:
    """Runs the example lightning_classy_vision app.

    Each path argument is forwarded to ``lightning_classy_vision/train.py``
    as the command-line flag of the same name.

    Args:
        image: image to run (e.g. foobar:latest)
        output_path: output path for model checkpoints (e.g. file:///foo/bar)
        data_path: path to the data to load
        load_path: path to load pretrained model from
        log_path: path to save tensorboard logs to
        resource: the resources to use; when empty or ``None`` a default of
            1 CPU, 0 GPUs and 1024 MB of memory is requested

    Returns:
        The ``torchx.AppDef`` produced by ``binary_component`` for the trainer.
    """
    return binary_component(
        name="examples-lightning_classy_vision-trainer",
        entrypoint="lightning_classy_vision/train.py",
        args=[
            "--output_path",
            output_path,
            "--load_path",
            load_path,
            # Spell the flag in full so it matches the parameter name and the
            # sibling flags instead of relying on parser prefix matching.
            "--log_path",
            log_path,
            "--data_path",
            data_path,
        ],
        image=image,
        # A named resource is looked up only when one was requested;
        # otherwise fall back to a small CPU-only default.
        resource=named_resource(resource)
        if resource
        else torchx.Resource(cpu=1, gpu=0, memMB=1024),
    )


def interpret(
    image: str,
    load_path: str,
    data_path: str,
    output_path: str,
    resource: Optional[str] = None,
) -> torchx.AppDef:
    """Runs the model interpretability app on the model outputted by the
    training component.

    Each path argument is forwarded to ``lightning_classy_vision/interpret.py``
    as the command-line flag of the same name.

    Args:
        image: image to run (e.g. foobar:latest)
        load_path: path to load pretrained model from
        data_path: path to the data to load
        output_path: output location passed to the app as ``--output_path``
            (e.g. file:///foo/bar)
        resource: the resources to use; when empty or ``None`` a default of
            1 CPU, 0 GPUs and 1024 MB of memory is requested

    Returns:
        The ``torchx.AppDef`` produced by ``binary_component`` for the
        interpretability app.
    """
    # Pick the resource first: a named one when requested, otherwise a small
    # CPU-only default.
    if resource:
        app_resource = named_resource(resource)
    else:
        app_resource = torchx.Resource(cpu=1, gpu=0, memMB=1024)

    # Flag/value pairs, in the order they are handed to the entrypoint.
    cli_args = [
        "--load_path",
        load_path,
        "--data_path",
        data_path,
        "--output_path",
        output_path,
    ]

    return binary_component(
        name="examples-lightning_classy_vision-interpret",
        entrypoint="lightning_classy_vision/interpret.py",
        args=cli_args,
        image=image,
        resource=app_resource,
    )

Total running time of the script: ( 0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources