Source code for torchx.schedulers.api
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
import abc
import inspect
import re
import types
from collections.abc import Iterator, Mapping
from dataclasses import dataclass, field, fields, MISSING
from datetime import datetime
from enum import Enum
from typing import (
    Generic,
    get_args,
    get_origin,
    get_type_hints,
    Iterable,
    List,
    Optional,
    TypeVar,
    Union,
)
from torchx.specs import (
AppDef,
AppDryRunInfo,
AppState,
cases,
CfgVal,
NONE,
NULL_RESOURCE,
Role,
RoleStatus,
runopts,
Workspace,
)
from torchx.workspace import WorkspaceMixin
from typing_extensions import Self
# Number of days in a two-week span. Not referenced in the visible part of
# this module.
DAYS_IN_2_WEEKS = 14
# =============================================================================
# STRUCTURED OPTIONS BASE CLASS
# =============================================================================
# pyre-fixme[24]: Generic type `type` expects 1 type parameter.
def _unwrap_optional(tp: type) -> type:
"""Strip ``None`` from union types (e.g. ``str | None`` -> ``str``)."""
args = [a for a in get_args(tp) if a is not types.NoneType]
if args and len(args) < len(get_args(tp)):
return args[0] if len(args) == 1 else Union[tuple(args)]
return tp
# pyre-fixme[24]: Generic type `type` expects 1 type parameter.
def _is_structured_opts(tp: type) -> bool:
    """Tell whether *tp* is a proper subclass of ``StructuredOpts``.

    Returns ``False`` for anything that is not a class, for
    ``StructuredOpts`` itself, and for objects that ``issubclass()`` rejects.
    """
    if not isinstance(tp, type):
        return False
    try:
        return issubclass(tp, StructuredOpts) and tp is not StructuredOpts
    except TypeError:
        # On some Python versions generic aliases (list[str], dict[str, str])
        # pass the isinstance() check above yet make issubclass() raise.
        return False
class StructuredOpts(Mapping[str, CfgVal]):
    """Typed, dataclass-backed container for scheduler run options.

    Subclasses are expected to be ``@dataclass`` decorated, with one dataclass
    field per config option. This replaces hand-built
    :py:class:`~torchx.specs.runopts` definitions with declared fields.

    What the base class provides:

    - :py:meth:`as_runopts` derives a ``runopts`` from the dataclass fields
    - :py:meth:`from_cfg` turns a raw config mapping into a typed instance
    - snake_case field names are also reachable through camelCase aliases
    - help text is taken from the docstring that follows each field
    - nested ``StructuredOpts`` fields are flattened with dot-prefixed keys
      (e.g., ``k8s.context``)

    Example:

    .. doctest::

        >>> from dataclasses import dataclass
        >>> from torchx.schedulers.api import StructuredOpts
        >>>
        >>> @dataclass
        ... class MyOpts(StructuredOpts):
        ...     cluster_name: str
        ...     '''Name of the cluster to submit to.'''
        ...
        ...     num_retries: int = 3
        ...     '''Number of retry attempts.'''
        ...
        >>> # Use in scheduler:
        >>> # def _run_opts(self) -> runopts:
        >>> #     return MyOpts.as_runopts()
        >>> #
        >>> # def _submit_dryrun(self, app, cfg):
        >>> #     opts = MyOpts.from_cfg(cfg)
        >>> #     # opts.cluster_name, opts.num_retries are typed

    """

    @classmethod
    def from_cfg(cls, cfg: Mapping[str, CfgVal]) -> Self:
        """Build an instance of ``cls`` from a raw config mapping.

        Each field is looked up under its snake_case name first and, failing
        that, under its camelCase alias (e.g., ``hpc_identity`` may be given
        as ``hpcIdentity``). Fields absent from ``cfg`` are left to the
        dataclass defaults.

        A field whose type is a nested :py:class:`StructuredOpts` is rebuilt
        from the keys carrying its dot prefix (e.g., ``k8s.context``).
        """
        hints = get_type_hints(cls)
        init_kwargs = {}
        for fld in fields(cls):
            name = fld.name
            declared = _unwrap_optional(hints.get(name, str))

            if not _is_structured_opts(declared):
                # Plain option: exact field name wins over the camelCase alias.
                if name in cfg:
                    init_kwargs[name] = cfg[name]
                    continue
                alias = cases.snake_to_camel(name)
                if alias in cfg:
                    init_kwargs[name] = cfg[alias]
                continue

            # Nested group: collect "<name>.<key>" entries with the prefix cut off.
            prefix = f"{name}."
            cut = len(prefix)
            sub_cfg = {
                key[cut:]: value
                for key, value in cfg.items()
                if key.startswith(prefix)
            }
            if sub_cfg:
                init_kwargs[name] = declared.from_cfg(sub_cfg)
            elif fld.default is MISSING and fld.default_factory is MISSING:
                # Required nested group — construct so its own validation runs.
                init_kwargs[name] = declared.from_cfg({})
        return cls(**init_kwargs)

    # -------------------------------------------------------------------------
    # Mapping Protocol Methods (for backwards compatibility)
    #
    # These methods allow StructuredOpts instances to be used in places that
    # expect a dict-like interface (e.g., plugins that do cfg.get("key") or
    # cfg["key"]). Once all plugins are migrated to use typed field access
    # (e.g., cfg.field_name), these methods can be removed.
    #
    # TODO(T252193642): Remove these methods after migrating plugins to use
    # StructuredOpts field access instead of dict-like access.
    # -------------------------------------------------------------------------

    # pyre-fixme[14]: Inconsistent override - Mapping.get accepts a default parameter
    def __getitem__(self, key: str) -> CfgVal:
        # The part before the first "." (or the whole key when there is no
        # dot) names an attribute; camelCase spellings are accepted.
        head, dot, tail = key.partition(".")
        attr = cases.camel_to_snake(head)
        if dot:
            # Dotted key: delegate the remainder to the nested options object.
            child = getattr(self, attr, None)
            if not isinstance(child, StructuredOpts):
                raise KeyError(key) from None
            return child[tail]
        if not hasattr(self, attr):
            raise KeyError(key) from None
        return getattr(self, attr)

    def __len__(self) -> int:
        # Count by iterating; materializing via list(self) would consult
        # __len__ as a length hint and recurse.
        count = 0
        for _ in self:
            count += 1
        return count

    def __iter__(self) -> Iterator[str]:
        # Yields flat field names; nested groups contribute "<field>.<key>"
        # entries and contribute nothing when the nested value is None.
        hints = get_type_hints(type(self))
        for fld in fields(self):
            declared = _unwrap_optional(hints.get(fld.name, str))
            if not _is_structured_opts(declared):
                yield fld.name
                continue
            child = getattr(self, fld.name)
            if child is not None:
                for sub_key in child:
                    yield f"{fld.name}.{sub_key}"

    # pyre-fixme[14]: Inconsistent override - Mapping uses PyreReadOnly[object]
    def __contains__(self, key: object) -> bool:
        # Membership is defined by whether __getitem__ succeeds.
        if isinstance(key, str):
            try:
                self[key]
            except KeyError:
                pass
            else:
                return True
        return False

    @classmethod
    def get_docstrings(cls) -> dict[str, str]:
        # Scrapes the class source for string literals that directly follow an
        # annotated field and uses them as per-field help text. Returns an
        # empty dict when the source cannot be retrieved.
        try:
            source = inspect.getsource(cls)
        except (OSError, TypeError):
            return {}

        # Match: field_name: type...\n    <triple-double-quoted docstring>
        attr_doc = re.compile(
            r'^\s+(\w+):\s*[^\n]+\n\s+"""([^"]+)"""',
            re.MULTILINE,
        )
        found: dict[str, str] = {
            hit.group(1): hit.group(2).strip() for hit in attr_doc.finditer(source)
        }

        # Docstrings of nested groups are merged in under "<field>.<key>".
        hints = get_type_hints(cls)
        for fld in fields(cls):
            declared = _unwrap_optional(hints.get(fld.name, str))
            if _is_structured_opts(declared):
                found.update(
                    (f"{fld.name}.{sub_key}", text)
                    for sub_key, text in declared.get_docstrings().items()
                )
        return found

    @classmethod
    def as_runopts(cls) -> runopts:
        """Derive a :py:class:`~torchx.specs.runopts` from the dataclass fields.

        A field is marked required when it has neither a default nor a default
        factory. Default factories are never invoked, so such options are
        registered with a ``None`` default. Help text falls back to the field
        name when no field docstring was found.

        Nested :py:class:`StructuredOpts` fields are flattened with
        dot-prefixed keys (e.g., field ``k8s: K8sOpts`` with sub-field
        ``context`` becomes ``k8s.context``).
        """
        result = runopts()
        hints = get_type_hints(cls)
        help_by_field = cls.get_docstrings()
        for fld in fields(cls):
            name = fld.name
            declared = _unwrap_optional(hints.get(name, str))

            if _is_structured_opts(declared):
                # Re-register each option of the nested group under "<field>.<key>".
                for sub_key, sub_opt in declared.as_runopts():
                    result.add(
                        f"{name}.{sub_key}",
                        type_=sub_opt.opt_type,
                        default=sub_opt.default,
                        required=sub_opt.is_required,
                        help=sub_opt.help,
                    )
                continue

            no_default = fld.default is MISSING
            no_factory = fld.default_factory is MISSING
            result.add(
                name,
                type_=declared,
                default=None if no_default else fld.default,
                required=no_default and no_factory,
                help=help_by_field.get(name, name),
            )
        return result

    # pyre-fixme[15]: Inconsistent override - __or__ returns dict, not UnionType
    def __or__(self, other: StructuredOpts) -> dict[str, CfgVal]:
        """Combine two StructuredOpts instances into one plain cfg dict.

        Keys of ``other`` override equal keys of ``self``.

        Example:

        .. doctest::

            >>> from dataclasses import dataclass
            >>> from torchx.schedulers.api import StructuredOpts
            >>> @dataclass
            ... class OptsA(StructuredOpts):
            ...     foo: str = "a"
            >>> @dataclass
            ... class OptsB(StructuredOpts):
            ...     bar: int = 1
            >>> cfg = OptsA(foo="x") | OptsB(bar=2)
            >>> cfg["foo"], cfg["bar"]
            ('x', 2)

        """
        combined: dict[str, CfgVal] = {key: self[key] for key in self}
        combined.update((key, other[key]) for key in other)
        return combined
# =============================================================================
# STREAM AND RESPONSE TYPES
# =============================================================================
@dataclass
class DescribeAppResponse:
    """Result of :py:meth:`Scheduler.describe`: an app's status, roles, and metadata.

    Every field has a default, so an instance can be created empty and filled
    in incrementally.
    """

    # Identifier of the described app; "<NOT_SET>" until populated.
    app_id: str = "<NOT_SET>"
    # App state; starts out as AppState.UNSUBMITTED.
    state: AppState = AppState.UNSUBMITTED
    # Restart count; -1 is the unset default (presumably "unknown" — confirm
    # against scheduler implementations).
    num_restarts: int = -1
    # Plain and structured messages; both default to the NONE placeholder
    # imported from torchx.specs.
    msg: str = NONE
    structured_error_msg: str = NONE
    # Optional UI link; None when not provided.
    ui_url: Optional[str] = None
    # Container fields use default_factory so instances never share one object.
    metadata: dict[str, str] = field(default_factory=dict)
    roles_statuses: List[RoleStatus] = field(default_factory=list)
    roles: List[Role] = field(default_factory=list)
@dataclass
class ListAppResponse:
    """Entry returned by :py:meth:`Scheduler.list` / :py:meth:`~torchx.runner.api.Runner.list`."""

    app_id: str
    state: AppState
    app_handle: str = "<NOT_SET>"
    name: str = ""

    # An explicit __hash__() keeps ListAppResponse hashable, which makes it
    # easier in tests to check whether a given response is present in a
    # collection. ``name`` does not take part in the hash.
    def __hash__(self) -> int:
        identity = (self.app_id, self.app_handle, self.state)
        return hash(identity)
# Type variable for the scheduler-specific run config; it is the type of the
# ``cfg`` parameter on the generic ``Scheduler`` methods below.
T = TypeVar("T")
class Scheduler(abc.ABC, Generic[T]):
    """Abstract interface every job scheduler implements.

    Concrete schedulers must override each ``@abc.abstractmethod``.
    See :py:class:`StructuredOpts` for typed config and
    :py:mod:`torchx.schedulers` for built-in implementations.
    """

    def __init__(self, backend: str, session_name: str) -> None:
        self.backend = backend
        self.session_name = session_name

    def close(self) -> None:
        """Releases local resources. Safe to call multiple times.

        The base implementation does nothing. Only override for schedulers
        with local state (e.g. ``local_scheduler``).
        """

    def submit(
        self,
        app: AppDef,
        cfg: T,
        workspace: str | Workspace | None = None,
    ) -> str:
        """Submits an app directly. Prefer :py:meth:`~torchx.runner.api.Runner.run` for production use.

        The cfg is resolved against :py:meth:`run_opts`. When a workspace is
        given, it is attached to the first role and the workspaces are built
        before the app is dry-run and scheduled. Returns whatever
        :py:meth:`schedule` returns (the app_id).
        """
        # pyre-fixme: Generic cfg type passed to resolve
        resolved = self.run_opts().resolve(cfg)

        if workspace:
            assert isinstance(self, WorkspaceMixin)
            # String workspaces are parsed; Workspace objects are used as-is.
            app.roles[0].workspace = (
                Workspace.from_str(workspace)
                if isinstance(workspace, str)
                else workspace
            )
            self.build_workspaces(app.roles, resolved)

        # pyre-fixme: submit_dryrun takes Generic type for resolved_cfg
        request = self.submit_dryrun(app, resolved)
        return self.schedule(request)

    @abc.abstractmethod
    def schedule(self, dryrun_info: AppDryRunInfo) -> str:
        """Submits a previously dry-run request. Returns the app_id."""
        raise NotImplementedError()

    def submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo:
        """Returns the scheduler request without submitting.

        The request from :py:meth:`_submit_dryrun` is passed through each
        role's ``pre_proc`` hook in role order, then tagged with the app and
        the resolved cfg.
        """
        # pyre-fixme: Generic cfg type passed to resolve
        resolved = self.run_opts().resolve(cfg)
        # pyre-fixme: _submit_dryrun takes Generic type for resolved_cfg
        request = self._submit_dryrun(app, resolved)
        for role in app.roles:
            request = role.pre_proc(self.backend, request)
        request._app = app
        request._cfg = resolved
        return request

    @abc.abstractmethod
    def _submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo:
        raise NotImplementedError()

    def run_opts(self) -> runopts:
        """Returns accepted run configuration options (``torchx runopts <scheduler>``).

        Workspace options are merged in when the scheduler is also a
        ``WorkspaceMixin``.
        """
        accepted = self._run_opts()
        if isinstance(self, WorkspaceMixin):
            accepted.update(self.workspace_opts())
        return accepted

    def _run_opts(self) -> runopts:
        return runopts()

    @abc.abstractmethod
    def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
        """Returns app description, or ``None`` if it no longer exists."""
        raise NotImplementedError()

    @abc.abstractmethod
    def list(self, cfg: Mapping[str, CfgVal] | None = None) -> List[ListAppResponse]:
        """Lists jobs on this scheduler."""
        raise NotImplementedError()

    def exists(self, app_id: str) -> bool:
        # An app exists exactly when describe() has something to report.
        return self.describe(app_id) is not None

    @abc.abstractmethod
    def _cancel_existing(self, app_id: str) -> None:
        raise NotImplementedError()

    def cancel(self, app_id: str) -> None:
        """Cancels the app. Idempotent — safe to call multiple times.

        Does not block. Use :py:meth:`~torchx.runner.api.Runner.wait` to
        await the terminal state.
        """
        if not self.exists(app_id):
            # do nothing if the app does not exist
            return
        self._cancel_existing(app_id)

    def delete(self, app_id: str) -> None:
        """Deletes the job definition from the scheduler's data-plane.

        On schedulers with persistent job definitions (e.g. Kubernetes, AWS Batch),
        this purges the definition. On others (e.g. Slurm), this is equivalent to
        :py:meth:`cancel`. Calling on a live job cancels it first.
        """
        if not self.exists(app_id):
            return
        self._delete_existing(app_id)

    def _delete_existing(self, app_id: str) -> None:
        # Default: deleting is the same as cancelling.
        self._cancel_existing(app_id)

    def log_iter(
        self,
        app_id: str,
        role_name: str,
        k: int = 0,
        regex: Optional[str] = None,
        since: Optional[datetime] = None,
        until: Optional[datetime] = None,
        should_tail: bool = False,
        streams: Optional[Stream] = None,
    ) -> Iterable[str]:
        """Returns an iterator over log lines for the ``k``-th replica of ``role_name``.

        .. important:: Not all schedulers support log iteration, tailing, or
                       time-based cursors. Check the specific scheduler docs.

        Lines include trailing whitespace (``\\n``). When ``should_tail=True``,
        the iterator blocks until the app reaches a terminal state.

        Args:
            k: replica (node) index
            regex: optional filter pattern
            since: start cursor (scheduler-dependent)
            until: end cursor (scheduler-dependent)
            should_tail: if ``True``, follow output like ``tail -f``
            streams: ``stdout``, ``stderr``, or ``combined``

        Raises:
            NotImplementedError: if the scheduler does not support log iteration
        """
        raise NotImplementedError(
            f"{self.__class__.__qualname__} does not support application log iteration"
        )

    def _pre_build_validate(self, app: AppDef, scheduler: str, cfg: T) -> None:
        # Hook for pre-workspace-build validation. Override to add checks.
        pass

    def _validate(self, app: AppDef, scheduler: str, cfg: T) -> None:
        # Hook for post-workspace-build validation: every role must carry a
        # real resource, i.e. anything other than the NULL_RESOURCE placeholder.
        for role in app.roles:
            if role.resource != NULL_RESOURCE:
                continue
            raise ValueError(
                f"No resource for role: {role.image}. Did you forget to attach resource to the role"
            )
def filter_regex(regex: str, data: Iterable[str]) -> Iterable[str]:
    """Lazily keeps only the strings of ``data`` in which ``regex`` is found.

    Matching uses ``search`` semantics, so the pattern may match anywhere in
    a string.
    """
    pattern = re.compile(regex)
    return filter(pattern.search, data)
def split_lines(text: str) -> List[str]:
    """Splits ``text`` by newlines, preserving the ``\\n`` characters.

    Only ``"\\n"`` counts as a line boundary (unlike ``str.splitlines``, which
    also breaks on ``\\r`` and other separators). A trailing fragment without
    a newline is returned as the last element.

    Args:
        text: the text to split.

    Returns:
        The lines of ``text`` in order, each still ending in ``"\\n"`` except
        possibly the last. Empty input gives an empty list.
    """
    # A single split is linear in len(text); the last piece is whatever
    # follows the final newline (empty when the text ends with one).
    *terminated, tail = text.split("\n")
    lines = [segment + "\n" for segment in terminated]
    if tail:
        lines.append(tail)
    return lines
def split_lines_iterator(chunks: Iterable[str]) -> Iterable[str]:
    """Yields the lines of every chunk in turn, as cut by ``split_lines``.

    Each chunk is split independently, so a line spanning two chunks is
    yielded as two pieces.
    """
    for chunk in chunks:
        yield from split_lines(chunk)