Rate this Page

Source code for monarch._src.job.slurm

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import json
import logging
import os
import subprocess
import sys
from typing import Any, Dict, FrozenSet, List, Optional, Sequence

from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
from monarch._rust_bindings.monarch_hyperactor.config import configure
from monarch._src.actor.bootstrap import attach_to_workers
from monarch._src.job.job import JobState, JobTrait


logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stderr))
logger.propagate = False

# terminal states that indicate the job is no longer active
_SLURM_TERMINAL_STATES: FrozenSet[str] = frozenset(
    ["FAILED", "CANCELLED", "TIMEOUT", "PREEMPTED", "COMPLETED"]
)


[docs] class SlurmJob(JobTrait): """ A job scheduler that uses SLURM command line tools to schedule jobs. This implementation: 1. Uses sbatch to submit SLURM jobs that start monarch workers 2. Queries job status with squeue to get allocated hostnames 3. Uses the hostnames to connect to the started workers """ def __init__( self, meshes: Dict[str, int], python_exe: str = "python", slurm_args: Sequence[str] = (), monarch_port: int = 22222, job_name: str = "monarch_job", ntasks_per_node: int = 1, time_limit: Optional[str] = None, partition: Optional[str] = None, log_dir: Optional[str] = None, exclusive: bool = True, gpus_per_node: Optional[int] = None, cpus_per_task: Optional[int] = None, mem: Optional[str] = None, job_start_timeout: Optional[int] = None, ) -> None: """ Args: meshes: Dictionary mapping mesh names to number of nodes python_exe: Python executable to use for worker processes slurm_args: Additional SLURM arguments to pass to sbatch monarch_port: Port for TCP communication between workers job_name: Name for the SLURM job ntasks_per_node: Number of tasks per node time_limit: Maximum runtime in HH:MM:SS format. If None, uses SLURM's default time limit. partition: SLURM partition to submit to log_dir: Directory for SLURM log files exclusive: Whether to request exclusive node access (no other jobs can run on the nodes). Defaults to True for predictable performance and resource isolation, but may increase queue times and waste resources if nodes are underutilized. gpus_per_node: Number of GPUs to request per node. If None, no GPU resources are requested. job_start_timeout: Maximum time in seconds to wait for the SLURM job to start running. This should account for potential queueing delays. If None (default), waits indefinitely. """ configure(default_transport=ChannelTransport.TcpWithHostname) self._meshes = meshes self._python_exe = python_exe self._slurm_args = slurm_args self._port = monarch_port self._job_name = job_name self._ntasks_per_node = ntasks_per_node self._time_limit = time_limit self._partition = partition self._log_dir: str = log_dir if log_dir is not None else os.getcwd() self._exclusive = exclusive self._gpus_per_node = gpus_per_node self._cpus_per_task = cpus_per_task self._mem = mem self._job_start_timeout = job_start_timeout # Track the single SLURM job ID and all allocated hostnames self._slurm_job_id: Optional[str] = None self._all_hostnames: List[str] = [] super().__init__()
[docs] def add_mesh(self, name: str, num_nodes: int) -> None: self._meshes[name] = num_nodes
def _create(self, client_script: Optional[str]) -> None: """Submit a single SLURM job for all meshes.""" if client_script is not None: raise RuntimeError("SlurmJob cannot run batch-mode scripts") total_nodes = sum(self._meshes.values()) self._slurm_job_id = self._submit_slurm_job(total_nodes) def _submit_slurm_job(self, num_nodes: int) -> str: """Submit a SLURM job for all nodes.""" unique_job_name = f"{self._job_name}_{os.getpid()}" # Create log directory if it doesn't exist os.makedirs(self._log_dir, exist_ok=True) log_path_out = os.path.join(self._log_dir, f"slurm_%j_{unique_job_name}.out") log_path_err = os.path.join(self._log_dir, f"slurm_%j_{unique_job_name}.err") python_command = f'import socket; from monarch.actor import run_worker_loop_forever; hostname = socket.gethostname(); run_worker_loop_forever(address=f"tcp://{{hostname}}:{self._port}", ca="trust_all_connections")' # Build SBATCH directives sbatch_directives = [ "#!/bin/bash", f"#SBATCH --job-name={unique_job_name}", f"#SBATCH --ntasks-per-node={self._ntasks_per_node}", f"#SBATCH --nodes={num_nodes}", f"#SBATCH --output={log_path_out}", f"#SBATCH --error={log_path_err}", ] if self._time_limit is not None: sbatch_directives.append(f"#SBATCH --time={self._time_limit}") if self._gpus_per_node is not None: sbatch_directives.append(f"#SBATCH --gpus-per-node={self._gpus_per_node}") if self._cpus_per_task is not None: sbatch_directives.append(f"#SBATCH --cpus-per-task={self._cpus_per_task}") if self._mem is not None: sbatch_directives.append(f"#SBATCH --mem={self._mem}") if self._exclusive: sbatch_directives.append("#SBATCH --exclusive") if self._partition is not None: sbatch_directives.append(f"#SBATCH --partition={self._partition}") if ( not self._exclusive and self._partition is not None and self._gpus_per_node is not None ): gpus_per_task = self._gpus_per_node // self._ntasks_per_node assert self._partition, ( "Slurm partition must be set for jobs that share nodes with other jobs" ) self.share_node( tasks_per_node=self._ntasks_per_node, gpus_per_task=gpus_per_task, partition=self._partition, ) # Add any additional slurm args as directives for arg in self._slurm_args: if arg.startswith("-"): sbatch_directives.append(f"#SBATCH {arg}") batch_script = "\n".join(sbatch_directives) batch_script += f"\nsrun {self._python_exe} -c '{python_command}'\n" logger.info(f"Submitting SLURM job with {num_nodes} nodes") try: result = subprocess.run( ["sbatch"], input=batch_script, capture_output=True, text=True, check=True, ) # Parse the job ID from sbatch output (typically "Submitted batch job 12345") job_id = None for line in result.stdout.strip().split("\n"): if "Submitted batch job" in line: job_id = line.split()[-1] break if not job_id: raise RuntimeError( f"Failed to parse job ID from sbatch output: {result.stdout}" ) logger.info( f"SLURM job {job_id} submitted. Logs will be written to: {self._log_dir}/slurm_{job_id}_{unique_job_name}.out" ) return job_id except subprocess.CalledProcessError as e: raise RuntimeError(f"Failed to submit SLURM job: {e.stderr}") from e def _get_job_info_json(self, job_id: str) -> Optional[Dict[str, Any]]: """Get job information using squeue --json.""" try: result = subprocess.run( ["squeue", "--job", job_id, "--json"], capture_output=True, text=True, check=True, ) if result.stdout.strip(): data = json.loads(result.stdout) jobs = data.get("jobs", []) return jobs[0] if jobs else None return None except subprocess.CalledProcessError as e: logger.warning(f"Error checking job {job_id} status: {e.stderr}") return None except (json.JSONDecodeError, KeyError) as e: logger.warning(f"Error parsing JSON response for job {job_id}: {e}") return None def _wait_for_job_start( self, job_id: str, expected_nodes: int, timeout: Optional[int] = None ) -> List[str]: """ Wait for the SLURM job to start and return the allocated hostnames. Requires Slurm 20.02+ for squeue --json support. """ import time start_time = time.time() try: while timeout is None or time.time() - start_time < timeout: job_info = self._get_job_info_json(job_id) if not job_info: raise RuntimeError(f"SLURM job {job_id} not found in queue") job_state = job_info.get("job_state", []) if "RUNNING" in job_state: # Extract hostnames from job_resources.nodes.allocation job_resources = job_info.get("job_resources", {}) nodes_info = job_resources.get("nodes", {}) allocation = nodes_info.get("allocation", []) hostnames = [node["name"] for node in allocation] logger.info( f"SLURM job {job_id} is running on {len(hostnames)} nodes: {hostnames}" ) if len(hostnames) != expected_nodes: raise RuntimeError( f"Expected {expected_nodes} nodes but got {len(hostnames)}. " f"Partial allocation not supported." ) return hostnames elif any(state in job_state for state in _SLURM_TERMINAL_STATES): raise RuntimeError( f"SLURM job {job_id} failed with status: {job_state}" ) else: logger.debug(f"SLURM job {job_id} status: {job_state}, waiting...") time.sleep(2) # Check every 2 seconds raise RuntimeError(f"Timeout waiting for SLURM job {job_id} to start") except Exception: # Cleanup on failure - reuse _kill() logic logger.error(f"Failed to start SLURM job {job_id}, cancelling job") self._kill() raise def _state(self) -> JobState: if not self._jobs_active(): raise RuntimeError("SLURM job is no longer active") # Wait for job to start and get hostnames if not already done if not self._all_hostnames: job_id = self._slurm_job_id if job_id is None: raise RuntimeError("SLURM job ID is not set") total_nodes = sum(self._meshes.values()) self._all_hostnames = self._wait_for_job_start( job_id, total_nodes, timeout=self._job_start_timeout ) # Distribute the allocated hostnames among meshes host_meshes = {} hostname_idx = 0 for mesh_name, num_nodes in self._meshes.items(): mesh_hostnames = self._all_hostnames[ hostname_idx : hostname_idx + num_nodes ] hostname_idx += num_nodes workers = [f"tcp://{hostname}:{self._port}" for hostname in mesh_hostnames] host_mesh = attach_to_workers( name=mesh_name, ca="trust_all_connections", workers=workers, # type: ignore[arg-type] ) host_meshes[mesh_name] = host_mesh return JobState(host_meshes)
[docs] def can_run(self, spec: "JobTrait") -> bool: """Check if this job can run the given spec.""" return ( isinstance(spec, SlurmJob) and spec._meshes == self._meshes and spec._python_exe == self._python_exe and spec._port == self._port and spec._slurm_args == self._slurm_args and spec._job_name == self._job_name and spec._ntasks_per_node == self._ntasks_per_node and spec._time_limit == self._time_limit and spec._partition == self._partition and spec._gpus_per_node == self._gpus_per_node and spec._cpus_per_task == self._cpus_per_task and spec._mem == self._mem and spec._job_start_timeout == self._job_start_timeout and self._jobs_active() )
def _jobs_active(self) -> bool: """Check if SLURM job is still active by querying squeue.""" if not self.active or self._slurm_job_id is None: return False job_info = self._get_job_info_json(self._slurm_job_id) if not job_info: logger.warning(f"SLURM job {self._slurm_job_id} not found in queue") return False job_state = job_info.get("job_state", []) if any(state in job_state for state in _SLURM_TERMINAL_STATES): logger.warning(f"SLURM job {self._slurm_job_id} has status: {job_state}") return False return True
[docs] def share_node( self, tasks_per_node: int, gpus_per_task: int, partition: str ) -> None: """ Share a node with other jobs. """ try: import clusterscope except ImportError: raise RuntimeError( "please install clusterscope to use share_node. `pip install clusterscope`" ) self._exclusive = False slurm_args = clusterscope.job_gen_task_slurm( partition=partition, gpus_per_task=gpus_per_task, tasks_per_node=tasks_per_node, ) self._cpus_per_task = slurm_args["cpus_per_task"] self._mem = slurm_args["memory"]
def _kill(self) -> None: """Cancel the SLURM job.""" if self._slurm_job_id is not None: try: subprocess.run( ["scancel", self._slurm_job_id], capture_output=True, text=True, check=True, ) logger.info(f"Cancelled SLURM job {self._slurm_job_id}") except subprocess.CalledProcessError as e: logger.warning( f"Failed to cancel SLURM job {self._slurm_job_id}: {e.stderr}" ) self._slurm_job_id = None self._all_hostnames.clear()