Shortcuts

Source code for torchx.schedulers.slurm_scheduler

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import csv
import os.path
import shlex
import subprocess
import tempfile
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional

from torchx.schedulers.api import AppDryRunInfo, DescribeAppResponse, Scheduler
from torchx.specs.api import (
    NONE,
    AppDef,
    AppState,
    Role,
    RunConfig,
    SchedulerBackend,
    macros,
)


# Translation table from the job state strings reported by slurm (the ``State``
# column of ``sacct`` output) to the corresponding TorchX ``AppState``.
# States in which the job is not currently executing but has not terminated
# (PENDING, REQUEUED, RESIZING, SUSPENDED) are all reported as PENDING; every
# abnormal termination other than CANCELLED is reported as FAILED.
SLURM_STATES: Mapping[str, AppState] = {
    "BOOT_FAIL": AppState.FAILED,
    "CANCELLED": AppState.CANCELLED,
    "COMPLETED": AppState.SUCCEEDED,
    "DEADLINE": AppState.FAILED,
    "FAILED": AppState.FAILED,
    "NODE_FAIL": AppState.FAILED,
    "OUT_OF_MEMORY": AppState.FAILED,
    "PENDING": AppState.PENDING,
    "PREEMPTED": AppState.FAILED,
    "RUNNING": AppState.RUNNING,
    "REQUEUED": AppState.PENDING,
    "RESIZING": AppState.PENDING,
    "REVOKED": AppState.FAILED,
    "SUSPENDED": AppState.PENDING,
    "TIMEOUT": AppState.FAILED,
}


def _slurm_escape(s: str) -> str:
    """
    Shell-quote ``s`` while turning every occurrence of ``macros.app_id`` into
    a reference to the ``SLURM_JOB_ID`` environment variable.

    The string is cut at each occurrence of the macro, the pieces in between
    are quoted individually with ``shlex.quote`` and then stitched back
    together with a double-quoted ``$SLURM_JOB_ID`` expression so that the
    shell substitutes the job id when the generated script runs.
    """
    job_id_expr = '"$SLURM_JOB_ID"'
    literal_chunks = s.split(macros.app_id)
    return job_id_expr.join(map(shlex.quote, literal_chunks))


@dataclass
class SlurmReplicaRequest:
    """
    Holds parameters for a single replica running on slurm and can be
    materialized down to a shell script that is submitted through sbatch.
    """

    # Working directory passed to ``srun --chdir``.
    dir: str
    # Program launched by ``srun``; written into the script verbatim.
    entrypoint: str
    # Arguments for the entrypoint; each one is shell-escaped in materialize().
    args: List[str]
    # sbatch long-option names (without the leading ``--``) mapped to values.
    opts: Dict[str, str]
    # Additional environment variables to export to the job.
    env: Dict[str, str]

    @classmethod
    def from_role(cls, role: "Role", cfg: "RunConfig") -> "SlurmReplicaRequest":
        """
        Build the request for one replica of ``role``.

        Every entry of ``cfg.cfgs`` becomes an sbatch option (values are
        stringified). When the role declares a resource, its positive
        cpu/memMB/gpu values are added as ``cpus-per-task``, ``mem`` and
        ``gpus-per-task`` and take precedence over same-named cfg entries.
        """
        opts = {k: str(v) for k, v in cfg.cfgs.items()}
        resource = role.resource

        if resource != NONE:
            if resource.cpu > 0:
                opts["cpus-per-task"] = str(resource.cpu)
            if resource.memMB > 0:
                opts["mem"] = str(resource.memMB)
            if resource.gpu > 0:
                opts["gpus-per-task"] = str(resource.gpu)

        return cls(
            dir=role.image,
            entrypoint=role.entrypoint,
            args=list(role.args),
            opts=opts,
            env=dict(role.env),
        )

    def materialize(self) -> str:
        """
        Render the replica as the text of a ``/bin/sh`` batch script.

        The script consists of one ``#SBATCH`` directive per option, at most
        one ``#SBATCH --export`` directive carrying all environment variables,
        and a single ``srun`` invocation of the entrypoint with escaped args.
        """
        sbatch_opts = [f"#SBATCH --{key}={value}" for key, value in self.opts.items()]
        if self.env:
            # All variables must share ONE --export directive: a repeated
            # --export replaces the previous one instead of adding to it, so
            # one directive per variable would leave only the last variable
            # set. ALL keeps the submitting environment (sbatch's default when
            # --export is absent) in addition to the explicit variables.
            exported = ",".join(f"{key}={value}" for key, value in self.env.items())
            sbatch_opts.append(f"#SBATCH --export=ALL,{exported}")
        sbatch_opts_str = "\n".join(sbatch_opts)

        escaped_args = " ".join(_slurm_escape(arg) for arg in self.args)

        return f"""#!/bin/sh
{sbatch_opts_str}

# exit on error
set -e

srun --chdir={self.dir} {self.entrypoint} {escaped_args}
"""


@dataclass
class SlurmBatchRequest:
    """
    Holds parameters used to launch a slurm job via sbatch.
    """

    # Leading part of the sbatch command line (program name and its options);
    # the paths of the generated replica scripts are added at schedule time.
    cmd: List[str]
    # Script file name -> request describing the replica that script launches.
    replicas: Dict[str, SlurmReplicaRequest]


class SlurmScheduler(Scheduler):
    """
    SlurmScheduler is a TorchX scheduling interface to slurm. TorchX expects
    that slurm CLI tools are locally installed and job accounting is enabled.

    Each app def is scheduled using a heterogeneous job via sbatch.
    Each replica of each role has a unique shell script generated with its
    resource allocations and args and then sbatch is used to launch all of them
    together.

    Logs are written to the default slurm log file.

    Any scheduler options passed to it are added as SBATCH arguments to each
    replica.

    For more info see:

    * https://slurm.schedmd.com/sbatch.html
    * https://slurm.schedmd.com/heterogeneous_jobs.html

    .. code-block:: bash

        $ torchx run --scheduler slurm utils.echo --msg hello
        slurm://torchx_user/1234
        $ torchx status slurm://torchx_user/1234
        $ less slurm-1234.out
        ...
    """

    def __init__(self, session_name: str) -> None:
        super().__init__("slurm", session_name)

    def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
        """
        Write every replica script to a temporary directory and submit them
        together with a single sbatch invocation.

        Returns:
            The stripped stdout of sbatch (run with ``--parsable`` by
            ``_submit_dryrun``), which is used as the app id.

        Raises:
            subprocess.CalledProcessError: if sbatch exits with a non-zero code.
        """
        req = dryrun_info.request
        # Build the final command on a copy: the script paths point into a
        # temporary directory that no longer exists once this method returns,
        # so they must not leak into (or pile up in) the shared request.
        cmd = list(req.cmd)
        with tempfile.TemporaryDirectory() as tmpdir:
            for i, (name, body) in enumerate(req.replicas.items()):
                path = os.path.join(tmpdir, name)
                with open(path, "w") as f:
                    f.write(body.materialize())

                # ":" separates the components of a heterogeneous job.
                if i > 0:
                    cmd.append(":")
                cmd.append(path)

            # sbatch has to run while the temporary scripts still exist.
            p = subprocess.run(cmd, stdout=subprocess.PIPE, check=True)
        return p.stdout.decode("utf-8").strip()

    def _submit_dryrun(
        self, app: AppDef, cfg: RunConfig
    ) -> AppDryRunInfo[SlurmBatchRequest]:
        """
        Build, without submitting anything, the sbatch request for ``app``:
        one ``SlurmReplicaRequest`` per replica of every role, keyed by the
        name of the script file it will be written to.
        """
        cmd = ["sbatch", "--parsable", "--job-name", app.name]
        replicas = {}
        for i, role in enumerate(app.roles):
            for replica_id in range(role.num_replicas):
                # app_id is substituted with the macro itself so that it
                # survives until _slurm_escape swaps it for $SLURM_JOB_ID.
                values = macros.Values(
                    img_root=role.image,
                    app_id=macros.app_id,
                    replica_id=str(replica_id),
                )
                name = f"role-{i}-{role.name}-{replica_id}.sh"
                replica_role = values.apply(role)
                replicas[name] = SlurmReplicaRequest.from_role(replica_role, cfg)
        req = SlurmBatchRequest(
            cmd=cmd,
            replicas=replicas,
        )
        return AppDryRunInfo(req, repr)

    def _validate(self, app: AppDef, scheduler: SchedulerBackend) -> None:
        # Skip validation step for slurm
        pass

    def _cancel_existing(self, app_id: str) -> None:
        """Cancel the job with ``scancel``; raises CalledProcessError on failure."""
        subprocess.run(["scancel", app_id], check=True)

    def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
        """
        Query ``sacct`` for the job and translate its state.

        Returns:
            ``None`` when sacct reports no records (empty or header-only
            output), otherwise a ``DescribeAppResponse`` whose ``msg`` is the
            raw slurm state and whose ``state`` is the mapped ``AppState``
            taken from the row whose ``JobID`` equals ``app_id``.

        Raises:
            subprocess.CalledProcessError: if sacct exits with a non-zero code.
            AssertionError: if the slurm state has no TorchX equivalent.
        """
        p = subprocess.run(
            ["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True
        )
        # Drop blank lines (notably the one produced by the trailing newline)
        # so that header-only output is recognised as "no records".
        output = [line for line in p.stdout.decode("utf-8").split("\n") if line.strip()]
        if len(output) <= 1:
            return None

        reader = csv.DictReader(output, delimiter="|")

        resp = DescribeAppResponse(
            app_id=app_id,
        )
        for row in reader:
            if row["JobID"] != app_id:
                continue

            state = row["State"]
            resp.msg = state
            # sacct may append details to the state name, e.g.
            # "CANCELLED by <uid>"; only the first token is the state itself.
            tokens = state.split()
            state_key = tokens[0] if tokens else state
            state_enum = SLURM_STATES.get(state_key)
            assert (
                state_enum is not None
            ), f"failed to translate slurm state {state} to torchx state"
            resp.state = state_enum

        return resp
def create_scheduler(session_name: str, **kwargs: Any) -> SlurmScheduler:
    """
    Factory returning a ``SlurmScheduler`` bound to ``session_name``.

    Extra keyword arguments are accepted but not used.
    """
    scheduler = SlurmScheduler(session_name=session_name)
    return scheduler

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources