Shortcuts

Source code for torchx.components.dist

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
For distributed training, TorchX relies on the scheduler's gang scheduling
capabilities to schedule ``n`` copies of nodes. Once launched, the application
is expected to be written in a way that leverages this topology, for instance,
with PyTorch's
`DDP <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`_.
You can express a variety of node topologies with TorchX by specifying multiple
:py:class:`torchx.specs.Role` in your component's AppDef. Each role maps to
a homogeneous group of nodes that performs a "role" (function) in the overall
training. Scheduling-wise, TorchX launches each role as a sub-gang.


A DDP-style training job has a single role: trainers, whereas a
training job that uses parameter servers has two roles: parameter server and trainer.
You can specify a different entrypoint (executable), number of replicas, resource requirements,
and more for each role.


DDP Builtin
----------------

DDP-style trainers are common and easy to templatize since they are homogeneous
single-role AppDefs, so there is a builtin: ``dist.ddp``. Assuming your DDP
training script is called ``main.py``, launch it as:

.. code:: shell-session

    # locally, 1 node x 4 workers
    $ torchx run -s local_cwd dist.ddp -j 1x4 --script main.py

    # locally, 2 nodes x 4 workers (8 total)
    $ torchx run -s local_cwd dist.ddp -j 2x4 --script main.py

    # remote (optionally pass --rdzv_port to use a different master port than the default 29500)
    $ torchx run -s kubernetes -cfg queue=default dist.ddp \\
        -j 2x4 \\
        --script main.py


Note that the only difference compared to the local launch is the scheduler (``-s``).
The ``dist.ddp`` builtin uses ``torchelastic`` (more specifically ``torch.distributed.run``)
under the hood. Read more about torchelastic `here <https://pytorch.org/docs/stable/elastic/run.html>`_.

Components APIs
-----------------
"""
import os
import shlex
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

import torchx
import torchx.specs as specs
from torchx.specs import macros


def ddp(
    *script_args: str,
    script: Optional[str] = None,
    m: Optional[str] = None,
    image: str = torchx.IMAGE,
    name: Optional[str] = None,
    h: Optional[str] = None,
    cpu: int = 2,
    gpu: int = 0,
    memMB: int = 1024,
    j: str = "1x2",
    env: Optional[Dict[str, str]] = None,
    max_retries: int = 0,
    rdzv_port: int = 29500,
    mounts: Optional[List[str]] = None,
) -> specs.AppDef:
    """
    Distributed data parallel style application (one role, multi-replica).
    Uses `torch.distributed.run <https://pytorch.org/docs/stable/distributed.elastic.html>`_
    to launch and coordinate PyTorch worker processes. Defaults to using ``c10d`` rendezvous backend
    on rendezvous_endpoint ``$rank_0_host:$rdzv_port``. Note that ``rdzv_port`` parameter is ignored
    when running on single node, and instead we use port 0 which instructs torchelastic to choose
    a free random port on the host.

    Note: (cpu, gpu, memMB) parameters are mutually exclusive with ``h`` (named resource) where
          ``h`` takes precedence if specified for setting resource requirements.
          See `registering named resources <https://pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.

    Args:
        script_args: arguments to the main module
        script: script or binary to run within the image
        m: the python module path to run
        image: image (e.g. docker)
        name: job name override (uses the script name if not specified)
        cpu: number of cpus per replica
        gpu: number of gpus per replica
        memMB: cpu memory in MB per replica
        h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
        j: {nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus.
            A bare integer is treated as {nproc_per_node} on a single node; both values must be positive.
        env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3).
            ``LOGLEVEL`` defaults to the launcher's ``$LOGLEVEL`` (or ``WARNING``) when not set.
            The passed-in dict is copied and never modified.
        max_retries: the number of scheduler retries allowed
        rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
                   Only takes effect when running multi-node. When running single node, this parameter
                   is ignored and a random free port is chosen. This port is also the one declared
                   as ``c10d`` in the role's port map.
        mounts: mounts to mount into the worker environment/container (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
                See scheduler documentation for more info.

    Returns:
        A single-role AppDef whose role runs ``torch.distributed.run`` under ``bash -c``
        with ``nnodes`` replicas.

    Raises:
        ValueError: if not exactly one of ``script`` / ``m`` is given, or ``j`` is malformed.
    """
    if (script is None) == (m is None):
        raise ValueError("exactly one of --script and -m must be specified")

    nnodes, nproc_per_node = _parse_j(j)

    if script:
        # script name/module no extension
        role_name = Path(script).stem
    elif m:
        role_name = m.rpartition(".")[2]
    else:
        raise ValueError("failed to compute role_name")

    rdzv_backend = "c10d"
    if nnodes == 1:
        # using port 0 makes elastic choose a free random port which is ok
        # for single-node jobs since all workers run under a single agent
        rdzv_endpoint = "localhost:0"
    else:
        # _noquote keeps shlex from quoting this token so bash expands the env var.
        # `$$` presumably escapes `$` through torchx macro substitution — verify
        # against torchx.specs.macros.
        rdzv_endpoint = _noquote(f"$${macros.rank0_env}:{rdzv_port}")

    # Copy so the LOGLEVEL default below never leaks into the caller's dict.
    env = dict(env) if env else {}
    env.setdefault("LOGLEVEL", os.getenv("LOGLEVEL", "WARNING"))

    cmd = [
        "python",
        "-m",
        "torch.distributed.run",
        "--rdzv_backend",
        rdzv_backend,
        "--rdzv_endpoint",
        rdzv_endpoint,
        "--rdzv_id",
        f"{macros.app_id}",
        "--nnodes",
        str(nnodes),
        "--nproc_per_node",
        str(nproc_per_node),
        # 3 = tee both stdout and stderr (per torch.distributed.run docs)
        "--tee",
        "3",
        # NOTE(review): empty role name, presumably to keep worker log prefixes short — confirm
        "--role",
        "",
    ]
    if script is not None:
        cmd += [script]
    elif m is not None:
        cmd += ["-m", m]
    cmd += script_args

    return specs.AppDef(
        name=name or role_name,
        roles=[
            specs.Role(
                name=role_name,
                image=image,
                entrypoint="bash",
                num_replicas=nnodes,
                resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
                args=["-c", _args_join(cmd)],
                env=env,
                # Declare the port the c10d store actually listens on, rather than
                # a hard-coded default that diverges when rdzv_port is overridden.
                port_map={
                    "c10d": rdzv_port,
                },
                max_retries=max_retries,
                mounts=specs.parse_mounts(mounts) if mounts else [],
            )
        ],
    )


def _parse_j(j: str) -> Tuple[int, int]:
    """
    Parse the ``-j`` argument of ``ddp`` into ``(nnodes, nproc_per_node)``.

    Accepts ``"{nnodes}x{nproc_per_node}"`` or a bare ``"{nproc_per_node}"``
    (which implies a single node).

    Raises:
        ValueError: if ``j`` is not in one of those forms, or either count is
            not a positive integer.
    """
    usage = f"Invalid format for -j, usage example: 1x4. Given: {j}"
    parts = j.split("x")
    try:
        if len(parts) == 1:
            nnodes, nproc_per_node = 1, int(parts[0])
        elif len(parts) == 2:
            nnodes, nproc_per_node = int(parts[0]), int(parts[1])
        else:
            raise ValueError(usage)
    except ValueError:
        # Replace int()'s opaque "invalid literal" message with the usage hint.
        raise ValueError(usage) from None
    if nnodes < 1 or nproc_per_node < 1:
        raise ValueError(f"-j values must be positive integers, given: {j}")
    return nnodes, nproc_per_node
def _args_join(args: Iterable[str]) -> str: """ _args_join is like shlex.join but if the argument is wrapped in _noquote it'll not quote that argument. """ quoted = [arg if isinstance(arg, _noquote) else shlex.quote(arg) for arg in args] return " ".join(quoted) class _noquote(str): """ _noquote is a wrapper around str that indicates that the argument shouldn't be passed through shlex.quote. """ pass

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources