
Source code for torchx.schedulers.slurm_scheduler

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import csv
import os.path
import shlex
import subprocess
import tempfile
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional, Tuple

from torchx.schedulers.api import AppDryRunInfo, DescribeAppResponse, Scheduler
from torchx.specs.api import (
    NONE,
    AppDef,
    AppState,
    Role,
    RunConfig,
    SchedulerBackend,
    macros,
    RoleStatus,
    ReplicaStatus,
    runopts,
)


SLURM_STATES: Mapping[str, AppState] = {
    "BOOT_FAIL": AppState.FAILED,
    "CANCELLED": AppState.CANCELLED,
    "COMPLETED": AppState.SUCCEEDED,
    "DEADLINE": AppState.FAILED,
    "FAILED": AppState.FAILED,
    "NODE_FAIL": AppState.FAILED,
    "OUT_OF_MEMORY": AppState.FAILED,
    "PENDING": AppState.PENDING,
    "PREEMPTED": AppState.FAILED,
    "RUNNING": AppState.RUNNING,
    "REQUEUED": AppState.PENDING,
    "RESIZING": AppState.PENDING,
    "REVOKED": AppState.FAILED,
    "SUSPENDED": AppState.PENDING,
    "TIMEOUT": AppState.FAILED,
}


def _apply_app_id_env(s: str) -> str:
    """
    _apply_app_id_env escapes the argument and replaces macros.app_id with
    a shell expression that fills in SLURM_JOB_ID from the environment.
    """
    escaped_parts = [shlex.quote(part) for part in s.split(macros.app_id)]
    return '"$SLURM_JOB_ID"'.join(escaped_parts)


@dataclass
class SlurmReplicaRequest:
    """
    Holds parameters for a single replica running on slurm and can be materialized down to a bash script.
    """

    name: str
    dir: str
    entrypoint: str
    args: List[str]
    srun_opts: Dict[str, str]
    sbatch_opts: Dict[str, str]
    env: Dict[str, str]

    @classmethod
    def from_role(cls, name: str, role: Role, cfg: RunConfig) -> "SlurmReplicaRequest":
        sbatch_opts = {k: str(v) for k, v in cfg.cfgs.items() if v is not None}
        sbatch_opts.setdefault("ntasks-per-node", "1")
        resource = role.resource

        if resource != NONE:
            if resource.cpu > 0:
                sbatch_opts.setdefault("cpus-per-task", str(resource.cpu))
            if resource.memMB > 0:
                sbatch_opts.setdefault("mem", str(resource.memMB))
            if resource.gpu > 0:
                sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))

        return cls(
            name=name,
            dir=role.image,
            entrypoint=role.entrypoint,
            args=list(role.args),
            sbatch_opts=sbatch_opts,
            srun_opts={},
            env=dict(role.env),
        )

    def materialize(self) -> Tuple[List[str], List[str]]:
        """
        materialize returns the sbatch and srun groups for this role. They
        should be combined using `:` per slurm heterogeneous job groups.
        """
        sbatch_args = [
            f"--job-name={self.name}",
        ] + [f"--{key}={value}" for key, value in self.sbatch_opts.items()]
        srun_args = (
            [f"--chdir={self.dir}"]
            + [f"--{key}={value}" for key, value in self.srun_opts.items()]
            + [f"--export={key}={value}" for key, value in self.env.items()]
        )

        srun_group = srun_args + [self.entrypoint] + self.args
        srun_group = [_apply_app_id_env(arg) for arg in srun_group]

        return sbatch_args, srun_group
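
    # A hedged example of the two argument groups (not part of the original
    # source); a hand-built request shows what sbatch and srun each consume:
    #
    #   >>> req = SlurmReplicaRequest(
    #   ...     name="echo-worker-0", dir="/app", entrypoint="echo",
    #   ...     args=["hello"], srun_opts={}, env={"FOO": "bar"},
    #   ...     sbatch_opts={"ntasks-per-node": "1"},
    #   ... )
    #   >>> sbatch_args, srun_group = req.materialize()
    #   >>> sbatch_args
    #   ['--job-name=echo-worker-0', '--ntasks-per-node=1']
    #   >>> srun_group
    #   ['--chdir=/app', '--export=FOO=bar', 'echo', 'hello']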


@dataclass
class SlurmBatchRequest:
    """
    Holds parameters used to launch a slurm job via sbatch.
    """

    cmd: List[str]
    replicas: Dict[str, SlurmReplicaRequest]

    def materialize(self) -> str:
        """
        materialize returns the contents of the script that can be passed to
        sbatch to run the job.
        """

        sbatch_groups = []
        srun_groups = []
        for i, replica in enumerate(self.replicas.values()):
            if i > 0:
                srun_groups.append(":\\\n    ")

            sbatch_group, srun_group = replica.materialize()
            sbatch_groups.append(" ".join(sbatch_group))
            srun_groups += srun_group

        sbatch_opts = "#SBATCH hetjob\n".join(
            f"#SBATCH {group}\n" for group in sbatch_groups
        )
        script = f"""#!/bin/bash
{sbatch_opts}
# exit on error
set -e

srun {" ".join(srun_groups)}
"""
        return script
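
    # Sketch of the generated script for two replicas (values illustrative):
    # heterogeneous groups are separated by "#SBATCH hetjob" in the header and
    # by ":" on the srun command line.
    #
    #   #!/bin/bash
    #   #SBATCH --job-name=app-role-0 --ntasks-per-node=1
    #   #SBATCH hetjob
    #   #SBATCH --job-name=app-role-1 --ntasks-per-node=1
    #
    #   # exit on error
    #   set -e
    #
    #   srun --chdir=/app echo hello :\
    #        --chdir=/app echo hello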


class SlurmScheduler(Scheduler):
    """
    SlurmScheduler is a TorchX scheduling interface to slurm. TorchX expects
    that slurm CLI tools are locally installed and job accounting is enabled.

    Each app def is scheduled using a heterogeneous job via sbatch. Each
    replica of each role has a unique shell script generated with its
    resource allocations and args and then sbatch is used to launch all of
    them together.

    Logs are written to the default slurm log file.

    Any scheduler options passed to it are added as SBATCH arguments to each
    replica. See https://slurm.schedmd.com/sbatch.html#SECTION_OPTIONS for
    info on the arguments.

    For more info see:

    * https://slurm.schedmd.com/sbatch.html
    * https://slurm.schedmd.com/heterogeneous_jobs.html

    .. code-block:: bash

        $ torchx run --scheduler slurm utils.echo --msg hello
        slurm://torchx_user/1234
        $ torchx status slurm://torchx_user/1234
        $ less slurm-1234.out
        ...

    .. compatibility::
        type: scheduler
        features:
            cancel: true
            logs: |
                Logs are accessible via the default slurm log file but not
                the programmatic API.
            distributed: true
            describe: |
                Partial support. SlurmScheduler will return job and replica
                status but does not provide the complete original AppSpec.
    """

    def __init__(self, session_name: str) -> None:
        super().__init__("slurm", session_name)

    def run_opts(self) -> runopts:
        opts = runopts()
        opts.add(
            "partition",
            type_=str,
            help="The partition to run the job in.",
            default=None,
        )
        opts.add(
            "time",
            type_=str,
            default=None,
            help="The maximum time the job is allowed to run for.",
        )
        return opts

    def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
        req = dryrun_info.request
        with tempfile.TemporaryDirectory() as tmpdir:
            script = req.materialize()
            path = os.path.join(tmpdir, "job.sh")

            with open(path, "w") as f:
                f.write(script)

            cmd = req.cmd + [path]
            p = subprocess.run(cmd, stdout=subprocess.PIPE, check=True)
            return p.stdout.decode("utf-8").strip()

    def _submit_dryrun(
        self, app: AppDef, cfg: RunConfig
    ) -> AppDryRunInfo[SlurmBatchRequest]:
        replicas = {}
        for role in app.roles:
            for replica_id in range(role.num_replicas):
                values = macros.Values(
                    img_root=role.image,
                    app_id=macros.app_id,
                    replica_id=str(replica_id),
                )
                name = f"{app.name}-{role.name}-{replica_id}"
                replica_role = values.apply(role)
                replicas[name] = SlurmReplicaRequest.from_role(
                    name, replica_role, cfg
                )
        req = SlurmBatchRequest(
            cmd=["sbatch", "--parsable"],
            replicas=replicas,
        )
        return AppDryRunInfo(req, repr)

    def _validate(self, app: AppDef, scheduler: SchedulerBackend) -> None:
        # Skip validation step for slurm
        pass

    def _cancel_existing(self, app_id: str) -> None:
        subprocess.run(["scancel", app_id], check=True)

    def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
        p = subprocess.run(
            ["sacct", "--parsable2", "-j", app_id],
            stdout=subprocess.PIPE,
            check=True,
        )
        output = p.stdout.decode("utf-8").split("\n")
        if len(output) <= 1:
            return None

        reader = csv.DictReader(output, delimiter="|")

        roles = {}
        roles_statuses = {}
        msg = ""
        app_state = AppState.UNKNOWN
        for row in reader:
            job_id, *parts = row["JobID"].split("+")
            if job_id != app_id:
                continue
            if len(parts) > 0 and "." in parts[0]:
                # we only care about the worker not the child jobs
                continue

            state = row["State"]
            msg = state
            state_enum = SLURM_STATES.get(state)
            assert (
                state_enum
            ), f"failed to translate slurm state {state} to torchx state"
            app_state = state_enum

            name_parts = row["JobName"].split("-")
            if len(name_parts) < 3:
                # the name should always have at least 3 parts but sometimes
                # sacct is slow to update
                continue
            role = name_parts[-2]
            replica_id = int(name_parts[-1])
            if role not in roles:
                roles[role] = Role(name=role, num_replicas=0, image="")
                roles_statuses[role] = RoleStatus(role, [])
            roles[role].num_replicas += 1
            roles_statuses[role].replicas.append(
                ReplicaStatus(
                    id=replica_id, role=role, state=app_state, hostname=""
                ),
            )

        return DescribeAppResponse(
            app_id=app_id,
            roles=list(roles.values()),
            roles_statuses=list(roles_statuses.values()),
            state=app_state,
            msg=msg,
        )
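
    # For reference, a hedged sample of the ``sacct --parsable2`` output that
    # describe() walks (the exact field set depends on site configuration):
    #
    #   JobID|JobName|...|State|...
    #   1234+0|app-role-0|...|RUNNING|...
    #   1234+0.batch|batch|...|RUNNING|...   <- skipped: "." marks a job step
    #   1234+1|app-role-1|...|RUNNING|...
    #
    # Role and replica id are recovered from the trailing "-role-replica_id"
    # segments of JobName.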


def create_scheduler(session_name: str, **kwargs: Any) -> SlurmScheduler:
    return SlurmScheduler(
        session_name=session_name,
    )
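
# End-to-end sketch (hedged: assumes a local slurm install and the public
# ``submit_dryrun``/``schedule`` methods of the Scheduler API; ``app`` is an
# AppDef built elsewhere):
#
#   >>> scheduler = create_scheduler("my_session")
#   >>> info = scheduler.submit_dryrun(app, RunConfig())   # doctest: +SKIP
#   >>> print(info.request.materialize())                  # doctest: +SKIP
#   >>> app_id = scheduler.schedule(info)                  # doctest: +SKIP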
