
torchx.plugins

Plugin registration, discovery, and diagnostics for TorchX.

Decorate factory functions with register and place them in torchx_plugins.* namespace packages for automatic discovery:

# torchx_plugins/schedulers/my_scheduler.py
from torchx.plugins import register
from torchx.schedulers.api import Scheduler

@register.scheduler()
def my_scheduler(session_name: str, **kwargs) -> Scheduler:
    ...

Discover plugins and print a diagnostic report:

from torchx import plugins
reg = plugins.registry()
scheds = reg.get(plugins.PluginType.SCHEDULER)
print(reg)

Deprecated: Entry-point based registration ([torchx.*] entry-point groups in pyproject.toml). Set TORCHX_NO_ENTRYPOINTS=1 to opt out of entry-point loading early.

Core API

torchx.plugins.registry() → PluginRegistry

Return the cached PluginRegistry singleton.

The registry lazily discovers each plugin group on the first get() call for that group and caches the result.

Namespace plugins (torchx_plugins.*) are always loaded. Entry points from importlib.metadata are additionally merged in unless TORCHX_NO_ENTRYPOINTS=1 is set.

Returns:

The cached PluginRegistry instance.

Example:

from torchx import plugins

reg = plugins.registry()
scheds = reg.get(plugins.PluginType.SCHEDULER)
named = reg.get(plugins.PluginType.NAMED_RESOURCE)
print(reg)
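The namespace-plus-entry-point loading rule described above can be approximated in plain Python. This is a sketch, not TorchX's implementation: entrypoints_enabled() and merge_plugins() are hypothetical helpers, plain dicts stand in for real discovery results, and letting a namespace plugin win on a name collision is an assumption (the precedence is not documented here).

```python
import os
from typing import Any, Callable, Mapping

Factory = Callable[..., Any]


def entrypoints_enabled(env: Mapping[str, str] = os.environ) -> bool:
    """Hypothetical helper: entry points are merged unless TORCHX_NO_ENTRYPOINTS=1."""
    return env.get("TORCHX_NO_ENTRYPOINTS") != "1"


def merge_plugins(
    namespace_plugins: dict[str, Factory],
    entrypoint_plugins: dict[str, Factory],
    load_entrypoints: bool,
) -> dict[str, Factory]:
    # Namespace plugins (torchx_plugins.*) are always included.
    merged = dict(namespace_plugins)
    if load_entrypoints:
        for name, factory in entrypoint_plugins.items():
            # Assumption: an existing namespace plugin wins on a name collision.
            merged.setdefault(name, factory)
    return merged
```

Passing the environment as a mapping keeps the helper easy to test without touching os.environ.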

class torchx.plugins.PluginRegistry(*, load_entrypoints: bool = True)

Immutable, lazily-populated plugin registry.

Created by registry(). Each plugin group is discovered on first access via get() and cached for subsequent calls.

Usage:

from torchx import plugins

reg = plugins.registry()
scheds = reg.get(plugins.PluginType.SCHEDULER)
print(reg)

Namespace plugins (torchx_plugins.*) are always loaded. The load_entrypoints flag only controls whether importlib.metadata entry points are additionally merged in.

Parameters:

load_entrypoints – Whether to also load plugins from importlib.metadata entry points. Set to False to disable entry-point loading (namespace plugins are still discovered). Defaults to True for backward compatibility.

get(plugin_type: PluginType) → dict[str, Callable[..., Any]]

Discover plugins for plugin_type. Cached after first call.

Returns a dict mapping plugin names to their factory callables. Returns an empty dict when no plugins are found (never None).
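A minimal sketch of the lazy, per-group caching that get() describes, assuming a hypothetical LazyRegistry class and a caller-supplied discover(group) function in place of TorchX's real discovery. Groups are plain strings here rather than PluginType members.

```python
from typing import Any, Callable, Optional

Factory = Callable[..., Any]


class LazyRegistry:
    """Hypothetical stand-in: discovers each group once, on first get()."""

    def __init__(self, discover: Callable[[str], Optional[dict[str, Factory]]]) -> None:
        self._discover = discover  # called with a group name
        self._cache: dict[str, dict[str, Factory]] = {}

    def get(self, group: str) -> dict[str, Factory]:
        if group not in self._cache:
            # First access for this group: run discovery and cache the result.
            # `or {}` turns a None result into an empty dict (never None).
            self._cache[group] = dict(self._discover(group) or {})
        return self._cache[group]
```

Each group is discovered at most once; later calls return the cached dict without re-running discovery.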

info() → dict[PluginType, dict[str, Callable[..., Any]]]
info(plugin_type: PluginType) → dict[str, Callable[..., Any]]

Return discovered plugins.

When called with no arguments, triggers discovery for all plugin groups and returns a defensive copy of the full cache keyed by PluginType.

When called with a plugin_type, returns a defensive copy of the plugins dict for that single group.

Example:

>>> from torchx import plugins
>>> all_plugins = plugins.registry().info()
>>> scheds = plugins.registry().info(plugins.PluginType.SCHEDULER)
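The two info() call forms can be sketched with typing.overload. InfoRegistry, its string group keys, and the discover callable are hypothetical stand-ins; the point is that both forms return copies, so mutating the result leaves the registry's cache untouched.

```python
from typing import Any, Callable, Optional, overload

Factory = Callable[..., Any]


class InfoRegistry:
    """Hypothetical stand-in showing info()'s two call forms."""

    def __init__(self, groups: list[str], discover: Callable[[str], dict[str, Factory]]) -> None:
        self._groups = groups
        self._discover = discover
        self._cache: dict[str, dict[str, Factory]] = {}

    def get(self, group: str) -> dict[str, Factory]:
        if group not in self._cache:
            self._cache[group] = dict(self._discover(group))
        return self._cache[group]

    @overload
    def info(self) -> dict[str, dict[str, Factory]]: ...

    @overload
    def info(self, plugin_type: str) -> dict[str, Factory]: ...

    def info(self, plugin_type: Optional[str] = None) -> Any:
        if plugin_type is not None:
            # Single group: copy so callers cannot mutate the cached dict.
            return dict(self.get(plugin_type))
        # No argument: discover every group, then copy each inner dict.
        return {group: dict(self.get(group)) for group in self._groups}
```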

class torchx.plugins.PluginType

Type of TorchX plugin.

Values are the entry-point group names (e.g. "torchx.schedulers").

This enum is used internally to tag decorated functions and to filter plugins during discovery. Callers also use its members as keys for PluginRegistry.get() and info(), and it is public to support advanced subclassing of register.

Registration

class torchx.plugins.register(type: PluginType, name: str | None = None)

Decorator that tags a function as a TorchX plugin.

Sets _plugin_type and _plugin_name attributes on the decorated function. The discovery scanner (PluginRegistry) imports each submodule under torchx_plugins.* and collects any callable with a _plugin_type attribute.
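The scanning step can be sketched with pkgutil and importlib from the standard library. collect_tagged() and scan_namespace() are hypothetical names, and the plugin type is compared as a plain value; TorchX's own scanner may differ in details such as error handling and how it treats duplicate names.

```python
import importlib
import pkgutil
from types import ModuleType
from typing import Any, Callable

Factory = Callable[..., Any]


def collect_tagged(module: ModuleType, plugin_type: Any) -> dict[str, Factory]:
    """Collect callables in `module` whose _plugin_type matches `plugin_type`."""
    found: dict[str, Factory] = {}
    for obj in vars(module).values():
        if callable(obj) and getattr(obj, "_plugin_type", None) == plugin_type:
            found[obj._plugin_name] = obj
    return found


def scan_namespace(package_name: str, plugin_type: Any) -> dict[str, Factory]:
    """Import each submodule of a (namespace) package and collect tagged callables."""
    try:
        package = importlib.import_module(package_name)
    except ModuleNotFoundError:
        return {}  # nothing installed under this namespace
    found: dict[str, Factory] = {}
    for info in pkgutil.iter_modules(package.__path__, prefix=package_name + "."):
        found.update(collect_tagged(importlib.import_module(info.name), plugin_type))
    return found
```

Because pkgutil.iter_modules() walks every directory in the package's __path__, submodules contributed by different installed distributions to the same namespace package are all visited.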

Usage:

from torchx.plugins import register
from torchx.schedulers.api import Scheduler

@register.scheduler()
def my_scheduler(session_name: str, **kwargs) -> Scheduler:
    ...

@register.scheduler(name="custom_name")
def create_custom(session_name: str, **kwargs) -> Scheduler:
    ...

Each PluginType has a corresponding classmethod: scheduler(), tracker(), and named_resource().

The explicit constructor register(PluginType.SCHEDULER, name=...) is still supported for advanced use-cases.
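A stripped-down version of the decorator shows how a classmethod shortcut can delegate to the explicit constructor. This sketch is not TorchX's register class: it defines only scheduler(), and it uses the string "torchx.schedulers" (the group name given for PluginType) in place of PluginType.SCHEDULER.

```python
from typing import Any, Callable, Optional, TypeVar

F = TypeVar("F", bound=Callable[..., Any])

# Hypothetical stand-in for PluginType.SCHEDULER.
SCHEDULER = "torchx.schedulers"


class register:
    """Sketch of a tagging decorator with a classmethod shortcut."""

    def __init__(self, type: Any, name: Optional[str] = None) -> None:
        self.type = type
        self.name = name

    def __call__(self, fn: F) -> F:
        # Tag the function; the function itself is returned unchanged.
        fn._plugin_type = self.type  # type: ignore[attr-defined]
        fn._plugin_name = self.name or fn.__name__  # type: ignore[attr-defined]
        return fn

    @classmethod
    def scheduler(cls, name: Optional[str] = None) -> "register":
        # Shortcut for register(SCHEDULER, name=name).
        return cls(SCHEDULER, name=name)
```

Since the decorator returns the original function, decorated factories stay directly callable in tests.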

Fractional resource helpers

powers_of_two_gpus() and halve_mem_down_to() are exposed as static methods on register, so plugin authors need only a single import:

from torchx.plugins import register
from torchx.specs import Resource

@register.named_resource(fractionals=register.powers_of_two_gpus)
def my_gpu(fractional: float = 1.0) -> Resource:
    ...

@register.named_resource(fractionals=register.halve_mem_down_to(minGiB=16))
def my_cpu(fractional: float = 1.0) -> Resource:
    ...

Parameters:
  • type – The PluginType declaring what kind of plugin this function provides.

  • name – Plugin name. Defaults to the decorated function’s __name__.

static halve_mem_down_to(*, minGiB: int) → Callable[..., dict[float, str]]

Generate fractional suffixes by halving memory in GiB.

Returns a callable that produces a geometric series with r=1/2 starting from the resource’s total memory in GiB down to minGiB. Each entry maps a fractional float to the corresponding memory suffix string.

Example — 64 GiB host with minGiB=8:

{1.0: "64", 0.5: "32", 0.25: "16", 0.125: "8"}

Usage:

@register.named_resource(fractionals=register.halve_mem_down_to(minGiB=16))
def t1(fractional: float = WHOLE) -> Resource: ...

Parameters:

minGiB – Stop generating fractionals once memory would drop below this threshold in GiB. Must be >= the odd part of memGiB, the resource’s memory in GiB (i.e. memGiB with all factors of 2 removed); otherwise halving would produce non-integer GiB values.

Raises:

ValueError – if resource.memMB is zero, not GiB-aligned, or minGiB is below the odd part of memGiB.
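A sketch of the halving rule, written against a hypothetical Resource stand-in that has only a memMB field. It reproduces the documented 64 GiB example and the odd-part check; it is an illustration, not the library's code.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Resource:
    """Minimal stand-in for a resource; only memMB is used here."""
    memMB: int


def halve_mem_down_to(*, minGiB: int) -> Callable[[Resource], dict[float, str]]:
    def fractionals(resource: Resource) -> dict[float, str]:
        memMB = resource.memMB
        if memMB <= 0 or memMB % 1024 != 0:
            raise ValueError(f"memMB={memMB} must be a positive multiple of 1024")
        memGiB = memMB // 1024
        odd = memGiB
        while odd % 2 == 0:  # strip all factors of 2
            odd //= 2
        if minGiB < odd:
            raise ValueError(f"minGiB={minGiB} is below the odd part ({odd}) of {memGiB} GiB")
        result: dict[float, str] = {}
        fraction, gib = 1.0, memGiB
        while gib >= minGiB:
            result[fraction] = str(gib)
            fraction /= 2
            # Floors only once gib is odd, and that value is < minGiB.
            gib //= 2
        return result

    return fractionals
```

For a 96 GiB host the odd part is 3 (96 → 48 → 24 → 12 → 6 → 3), so minGiB=3 yields six entries while minGiB=2 raises ValueError.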

classmethod named_resource(name: str | None = None, aliases: list[str] | None = None, fractionals: Callable[..., dict[float, str]] | dict[float, str] | None = None) → _register_named_resource

Register a named resource factory.

Parameters:
  • name – Resource name. Defaults to the decorated function’s __name__.

  • aliases – Additional names that point to the same factory.

  • fractionals – Either a callable that maps the whole (base) Resource to a {fraction: suffix} dict, or a literal {fraction: suffix} dict. When provided, the decorated function must accept a fractional: float parameter.
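How the fractionals argument could expand into registered names is sketched below. expand_named_resource() is hypothetical; the {name}_{suffix} naming follows the my_gpu_8 example in the powers_of_two_gpus() description, and evaluating a callable fractionals against factory(fractional=1.0) as the whole resource is an assumption.

```python
import functools
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union

Fractionals = Union[Callable[[Any], dict[float, str]], dict[float, str]]


@dataclass
class Resource:
    """Minimal stand-in resource with only a gpu field."""
    gpu: int


def expand_named_resource(
    factory: Callable[..., Any],
    name: Optional[str] = None,
    aliases: Optional[list[str]] = None,
    fractionals: Optional[Fractionals] = None,
) -> dict[str, Callable[[], Any]]:
    name = name or factory.__name__
    entries: dict[str, Callable[[], Any]] = {}
    for n in [name, *(aliases or [])]:
        entries[n] = factory  # every alias points at the same factory
    if fractionals is not None:
        # Assumption: a callable is evaluated against the whole (fractional=1.0) resource.
        specs = fractionals(factory(fractional=1.0)) if callable(fractionals) else fractionals
        for fraction, suffix in specs.items():
            # Bind the fraction so each variant is a zero-argument factory.
            entries[f"{name}_{suffix}"] = functools.partial(factory, fractional=fraction)
    return entries
```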

static powers_of_two_gpus(resource: Any) → dict[float, str]

Return fractional specs for every power-of-two GPU slice of resource.

For example, an 8-GPU resource produces:

{1.0: "8", 0.5: "4", 0.25: "2", 0.125: "1"}

When passed as the fractionals argument of register.named_resource(), this auto-generates and registers all power-of-two fractional variants (e.g. my_gpu_8, my_gpu_4, …).

Parameters:

resource – The base (whole-host) resource. Must have gpu > 0 and gpu must be a power of two.

Raises:

ValueError – if resource.gpu is zero or not a power of two.
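A sketch of the power-of-two rule; any object with a gpu attribute (for example types.SimpleNamespace) stands in for a resource. It reproduces the documented 8-GPU output but is not the library's implementation.

```python
from typing import Any


def powers_of_two_gpus(resource: Any) -> dict[float, str]:
    gpu = resource.gpu
    # A positive power of two has exactly one bit set.
    if gpu <= 0 or (gpu & (gpu - 1)) != 0:
        raise ValueError(f"gpu={gpu} must be a positive power of two")
    result: dict[float, str] = {}
    fraction, count = 1.0, gpu
    while count >= 1:
        result[fraction] = str(count)
        fraction /= 2
        count //= 2
    return result
```

A single-GPU resource yields just {1.0: "1"}, so no fractional variants beyond the whole host are produced.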

classmethod scheduler(name: str | None = None) → register

Register a Scheduler factory.

classmethod tracker(name: str | None = None) → register

Register a tracker factory.

Constants

Named fractional values for the fractional parameter of named-resource factories, e.g. def t1(fractional: float = WHOLE) -> Resource.

torchx.plugins.WHOLE = 1.0

torchx.plugins.HALF = 0.5

torchx.plugins.QUARTER = 0.25

torchx.plugins.EIGHTH = 0.125

torchx.plugins.SIXTEENTH = 0.0625
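How a factory consumes these constants is up to the plugin author. The sketch below assumes a hypothetical 8-GPU host and a minimal Resource stand-in (cpu, gpu, memMB), and simply scales every field by fractional; it also checks that each constant is half the previous one.

```python
from dataclasses import dataclass

# Values as listed above.
WHOLE, HALF, QUARTER, EIGHTH, SIXTEENTH = 1.0, 0.5, 0.25, 0.125, 0.0625


@dataclass
class Resource:
    """Minimal stand-in; only the fields used here."""
    cpu: int
    gpu: int
    memMB: int


def my_gpu(fractional: float = WHOLE) -> Resource:
    """Hypothetical 8-GPU host; each field is scaled by `fractional`."""
    host = Resource(cpu=96, gpu=8, memMB=1024 * 1024)
    return Resource(
        cpu=int(host.cpu * fractional),
        gpu=int(host.gpu * fractional),
        memMB=int(host.memMB * fractional),
    )
```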
