
torchx.plugins

Plugin registration, discovery, and diagnostics for TorchX.

Decorate factory functions with register and place them in torchx_plugins.* namespace packages for automatic discovery:

# torchx_plugins/schedulers/my_scheduler.py
from torchx.plugins import register
from torchx.schedulers.api import Scheduler

@register.scheduler()
def my_scheduler(session_name: str, **kwargs) -> Scheduler:
    ...

Discover plugins and print a diagnostic report:

from torchx import plugins
reg = plugins.registry()
scheds = reg.get(plugins.PluginType.SCHEDULER)
print(reg)

Deprecated: Entry-point based registration ([torchx.*] in pyproject.toml). Set TORCHX_PLUGINS_SOURCE=1 (namespace packages only) to opt out of entry-point discovery early.

Core API

torchx.plugins.registry() -> PluginRegistry

Return the cached PluginRegistry singleton.

The registry lazily discovers plugins per-group on first get() access and caches the results.

The TORCHX_PLUGINS_SOURCE environment variable selects which discovery channels are enabled. Its value is parsed as the integer representation of a PluginSource bitmask: 0 for none, 1 for namespace packages only, 2 for entry points only, 3 for both. When the variable is unset, all channels are enabled.
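To make the bitmask concrete, here is a minimal sketch of how such a value could be parsed. The `PluginSource` flag below is a local stand-in that mirrors the documented values (0, 1, 2, 3). It is not imported from TorchX, and `parse_plugin_sources` is a hypothetical helper, not part of the public API.

```python
import enum
import os
from typing import Mapping, Optional


class PluginSource(enum.Flag):
    # Local stand-in mirroring the documented bit values; not the real torchx enum.
    NONE = 0
    NAMESPACE_PKG = 1
    ENTRYPOINT = 2


def parse_plugin_sources(env: Optional[Mapping[str, str]] = None) -> PluginSource:
    """Hypothetical parser for TORCHX_PLUGINS_SOURCE (illustration only)."""
    env = os.environ if env is None else env
    raw = env.get("TORCHX_PLUGINS_SOURCE")
    if raw is None:
        # Unset: enable every discovery channel.
        return PluginSource.NAMESPACE_PKG | PluginSource.ENTRYPOINT
    # The value is the integer form of the bitmask, e.g. "3" means both channels.
    return PluginSource(int(raw))
```

Because registry() returns a cached singleton, the variable presumably has to be set before the first call for it to take effect.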

Returns:

The cached PluginRegistry instance.

Example:

from torchx import plugins

reg = plugins.registry()
scheds = reg.get(plugins.PluginType.SCHEDULER)
named = reg.get(plugins.PluginType.NAMED_RESOURCE)
print(reg)

class torchx.plugins.PluginRegistry(*, plugin_sources: PluginSource = PluginSource.NAMESPACE_PKG | PluginSource.ENTRYPOINT)

Immutable, lazily-populated plugin registry.

Created by registry(). Each plugin group is discovered on first access via get() and cached for subsequent calls.

Usage:

from torchx import plugins

reg = plugins.registry()
scheds = reg.get(plugins.PluginType.SCHEDULER)
print(reg)

Discovery channels are selected via plugin_sources, a PluginSource bitmask. All channels are enabled by default; pass PluginSource.NONE for an empty registry, or any combination of NAMESPACE_PKG and ENTRYPOINT to enable a subset.

Parameters:

plugin_sources – Bitmask of discovery channels to enable. Defaults to NAMESPACE_PKG | ENTRYPOINT.

get(plugin_type: PluginType) -> dict[str, Callable[..., Any]]

Discover plugins for plugin_type. The result is cached after the first call.

Returns a dict mapping plugin names to their factory callables. Returns an empty dict when no plugins are found (never None).
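The lazy, per-group caching described for get() can be modeled in a few lines. This is a minimal sketch, not the TorchX implementation. `LazyGroupCache` and the `discover` callback are hypothetical names, and real discovery (importing namespace packages, reading entry points) is replaced by a plain function.

```python
from typing import Any, Callable, Dict, Hashable, List, Optional

Plugins = Dict[str, Callable[..., Any]]


class LazyGroupCache:
    """Hypothetical model: discover each group on first access, then serve it from a cache."""

    def __init__(self, discover: Callable[[Hashable], Optional[Plugins]]) -> None:
        self._discover = discover
        self._cache: Dict[Hashable, Plugins] = {}

    def get(self, group: Hashable) -> Plugins:
        if group not in self._cache:
            # First access for this group: run discovery once, mapping "nothing found" to {}.
            self._cache[group] = self._discover(group) or {}
        return self._cache[group]


# Usage with a fake discovery function that records how often it runs.
calls: List[Hashable] = []


def fake_discover(group: Hashable) -> Optional[Plugins]:
    calls.append(group)
    return {"local": print} if group == "schedulers" else None


cache = LazyGroupCache(fake_discover)
cache.get("schedulers")
cache.get("schedulers")  # served from the cache; fake_discover is not called again
cache.get("trackers")
```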

info() -> dict[PluginType, dict[str, Callable[..., Any]]]
info(plugin_type: PluginType) -> dict[str, Callable[..., Any]]

Return discovered plugins.

When called with no arguments, triggers discovery for all plugin groups and returns a defensive copy of the full cache keyed by PluginType.

When called with a plugin_type, returns a defensive copy of the plugins dict for that single group.

Example:

>>> from torchx import plugins
>>> all_plugins = plugins.registry().info()  
>>> scheds = plugins.registry().info(plugins.PluginType.SCHEDULER)
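"Defensive copy" means callers can modify what info() returns without corrupting the registry's cache. The sketch below models only that copying behavior. `info_sketch` is a hypothetical function operating on a pre-populated dict, and discovery itself is omitted.

```python
from typing import Any, Callable, Dict, Optional, Union

Plugins = Dict[str, Callable[..., Any]]


def info_sketch(
    cache: Dict[str, Plugins], group: Optional[str] = None
) -> Union[Dict[str, Plugins], Plugins]:
    """Hypothetical model of info(): hand out copies so callers cannot mutate the cache."""
    if group is None:
        # Copy the outer dict and each inner dict; the factory callables themselves are shared.
        return {name: dict(plugins) for name, plugins in cache.items()}
    return dict(cache.get(group, {}))


cache = {"schedulers": {"local": len}}
snapshot = info_sketch(cache)
snapshot["schedulers"]["extra"] = abs  # mutates only the copy
one_group = info_sketch(cache, "schedulers")
one_group.clear()  # clears only the single-group copy
```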

class torchx.plugins.PluginType

Type of TorchX plugin.

Values are the entry-point group names (e.g. "torchx.schedulers").

This enum is used internally to tag decorated functions and to filter plugins during discovery. It is part of the public API only to support advanced subclassing of register.

Registration

class torchx.plugins.register(type: PluginType, name: str | None = None)

Decorator that tags a function as a TorchX plugin.

Sets _plugin_type and _plugin_name attributes on the decorated function. The discovery scanner (PluginRegistry) imports each submodule under torchx_plugins.* and collects any callable with a _plugin_type attribute.
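The tag-then-scan mechanism can be sketched without TorchX installed. In this minimal model, `collect_plugins` is a hypothetical scanner that takes an already-imported module instead of walking `torchx_plugins.*`. Plugin types are plain strings rather than PluginType members, and the attributes are set by hand where register would normally set them.

```python
import types
from typing import Any, Callable, Dict


def collect_plugins(module: types.ModuleType, plugin_type: str) -> Dict[str, Callable[..., Any]]:
    """Hypothetical scanner: keep callables whose _plugin_type matches the requested type."""
    found: Dict[str, Callable[..., Any]] = {}
    for obj in vars(module).values():
        if callable(obj) and getattr(obj, "_plugin_type", None) == plugin_type:
            found[obj._plugin_name] = obj
    return found


# Build a throwaway module in memory instead of importing torchx_plugins.* submodules.
demo = types.ModuleType("demo_plugins")


def my_scheduler(session_name: str, **kwargs: Any) -> str:
    return session_name


def helper() -> None:
    """Untagged, so the scanner ignores it."""


# Tag the function the way the register decorator is documented to.
my_scheduler._plugin_type = "scheduler"  # type: ignore[attr-defined]
my_scheduler._plugin_name = "my_scheduler"  # type: ignore[attr-defined]
demo.my_scheduler = my_scheduler
demo.helper = helper
```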

Usage:

from torchx.plugins import register
from torchx.schedulers.api import Scheduler

@register.scheduler()
def my_scheduler(session_name: str, **kwargs) -> Scheduler:
    ...

@register.scheduler(name="custom_name")
def create_custom(session_name: str, **kwargs) -> Scheduler:
    ...

Each PluginType has a corresponding classmethod: scheduler(), tracker(), and named_resource().

The explicit constructor register(PluginType.SCHEDULER, name=...) is still supported for advanced use cases.
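One way the classmethod shortcuts can relate to the explicit constructor is sketched below: each shortcut simply builds an instance with a fixed plugin type, and the instance acts as the decorator. `register_sketch` and the `Kind` enum are hypothetical stand-ins for register and PluginType, not the TorchX source.

```python
import enum
from typing import Any, Callable, Optional


class Kind(enum.Enum):
    # Hypothetical stand-in for PluginType.
    SCHEDULER = "scheduler"
    TRACKER = "tracker"


class register_sketch:
    """Hypothetical model: classmethod shortcuts delegate to the explicit constructor."""

    def __init__(self, type: Kind, name: Optional[str] = None) -> None:
        self.type = type
        self.name = name

    def __call__(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        # Tag the function; the name falls back to the function's __name__.
        fn._plugin_type = self.type  # type: ignore[attr-defined]
        fn._plugin_name = self.name or fn.__name__  # type: ignore[attr-defined]
        return fn

    @classmethod
    def scheduler(cls, name: Optional[str] = None) -> "register_sketch":
        return cls(Kind.SCHEDULER, name)

    @classmethod
    def tracker(cls, name: Optional[str] = None) -> "register_sketch":
        return cls(Kind.TRACKER, name)


@register_sketch.scheduler()
def my_scheduler(session_name: str, **kwargs: Any) -> str:
    return session_name


@register_sketch(Kind.SCHEDULER, name="custom_name")
def create_custom(session_name: str, **kwargs: Any) -> str:
    return session_name
```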

Fractional resource helpers

powers_of_two_gpus() and halve_mem_down_to() are module-level functions that plugin authors can use directly:

from torchx.plugins import halve_mem_down_to, powers_of_two_gpus, register
from torchx.specs import Resource

@register.named_resource(fractionals=powers_of_two_gpus)
def my_gpu(fractional: float = 1.0) -> Resource:
    ...

@register.named_resource(fractionals=halve_mem_down_to(minGiB=16))
def my_cpu(fractional: float = 1.0) -> Resource:
    ...

Parameters:
  • type – The PluginType declaring what kind of plugin this function provides.

  • name – Plugin name. Defaults to the decorated function’s __name__.

classmethod named_resource(name: str | None = None, aliases: list[str] | None = None, fractionals: Callable[..., dict[float, str]] | dict[float, str] | None = None) -> _register_named_resource

Register a named resource factory.

Parameters:
  • name – Resource name. Defaults to the decorated function’s __name__.

  • aliases – Additional names that point to the same factory.

  • fractionals – Either a callable (Resource) -> {fraction: suffix}, such as powers_of_two_gpus or the callable returned by halve_mem_down_to(), or a literal {fraction: suffix} dict. When provided, the decorated function must accept a fractional: float parameter.
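A minimal sketch of how one decorated factory could fan out into several registry names. Several details are assumptions rather than documented behavior. The `<name>_<suffix>` scheme is inferred from the my_gpu_8 / my_gpu_4 example in powers_of_two_gpus. The base name is assumed to stay registered. A callable `fractionals` is assumed to receive the whole-host resource from a default call. `expand_named_resource` is hypothetical, and plain dicts stand in for Resource.

```python
import functools
from typing import Any, Callable, Dict, List, Optional, Union

Fractionals = Union[Callable[[Any], Dict[float, str]], Dict[float, str]]


def expand_named_resource(
    name: str,
    factory: Callable[..., Any],
    aliases: Optional[List[str]] = None,
    fractionals: Optional[Fractionals] = None,
) -> Dict[str, Callable[..., Any]]:
    """Hypothetical model of how one decorated factory could fan out into registry names."""
    entries: Dict[str, Callable[..., Any]] = {name: factory}
    for alias in aliases or []:
        entries[alias] = factory  # every alias points at the same factory
    if fractionals is not None:
        if callable(fractionals):
            # Assumption: the helper receives the whole-host resource from the default call.
            fractionals = fractionals(factory())
        for fraction, suffix in fractionals.items():
            # Bind the fraction so the variant needs no arguments, e.g. "my_gpu_4".
            entries[f"{name}_{suffix}"] = functools.partial(factory, fractional=fraction)
    return entries


def my_gpu(fractional: float = 1.0) -> Dict[str, int]:
    # Stand-in for a Resource: an 8-GPU host scaled by the fraction.
    return {"gpu": int(8 * fractional)}


entries = expand_named_resource(
    "my_gpu", my_gpu, aliases=["gpu_host"], fractionals={1.0: "8", 0.5: "4"}
)
```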

classmethod scheduler(name: str | None = None) -> register

Register a Scheduler factory.

classmethod tracker(name: str | None = None) -> register

Register a tracker factory.

Fractional Helpers

torchx.plugins.powers_of_two_gpus(resource: Any) -> dict[float, str]

Return fractional specs for every power-of-two GPU slice of resource.

For example, an 8-GPU resource produces:

{1.0: "8", 0.5: "4", 0.25: "2", 0.125: "1"}

When passed as the fractionals argument of register.named_resource(), this auto-generates and registers all power-of-two fractional variants (e.g. my_gpu_8, my_gpu_4, …).

Parameters:

resource – The base (whole-host) resource. Must have gpu > 0 and gpu must be a power of two.

Raises:

ValueError – if resource.gpu is zero or not a power of two.
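The documented behavior can be reproduced with a short loop, which also shows why the power-of-two requirement exists: repeated halving must stay integral all the way down to one GPU. This is an illustrative re-implementation, not the TorchX source. `FakeResource` is a hypothetical stand-in exposing only the `gpu` field.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class FakeResource:
    # Hypothetical stand-in exposing only the field the helper reads.
    gpu: int


def powers_of_two_gpus_sketch(resource: Any) -> Dict[float, str]:
    """Illustrative re-implementation of the documented behavior (not the TorchX source)."""
    gpu = resource.gpu
    # n > 0 and n & (n - 1) == 0 is the standard power-of-two test.
    if gpu <= 0 or gpu & (gpu - 1):
        raise ValueError(f"gpu must be a positive power of two, got {gpu}")
    out: Dict[float, str] = {}
    count, fraction = gpu, 1.0
    while count >= 1:
        out[fraction] = str(count)
        count //= 2
        fraction /= 2
    return out
```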

torchx.plugins.halve_mem_down_to(*, minGiB: int) -> Callable[..., dict[float, str]]

Generate fractional suffixes by halving memory in GiB.

Returns a callable that produces a geometric series with r=1/2 starting from the resource’s total memory in GiB down to minGiB. Each entry maps a fractional float to the corresponding memory suffix string.

Example: a 64 GiB host with minGiB=8 produces:

{1.0: "64", 0.5: "32", 0.25: "16", 0.125: "8"}

Usage:

from torchx.plugins import WHOLE, halve_mem_down_to, register
from torchx.specs import Resource

@register.named_resource(fractionals=halve_mem_down_to(minGiB=16))
def t1(fractional: float = WHOLE) -> Resource: ...

Parameters:

minGiB – Stop generating fractionals once memory would drop below this threshold in GiB. Must be >= the odd part of memGiB (the resource's memory in GiB with all factors of 2 removed); otherwise halving would produce non-integer GiB values.

Raises:

ValueError – if resource.memMB is zero, not GiB-aligned, or minGiB is below the odd part of memGiB.
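The halving rule and the odd-part constraint fit together as follows: halving stops either at minGiB or at the odd part, and requiring minGiB >= odd part guarantees the threshold is hit before an odd value would have to be halved. This is an illustrative re-implementation of the documented behavior, not the TorchX source. `FakeResource`, `odd_part`, and `halve_mem_down_to_sketch` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass
class FakeResource:
    # Hypothetical stand-in: only memMB is read.
    memMB: int


def odd_part(n: int) -> int:
    # Strip every factor of 2, e.g. 48 -> 3, 64 -> 1.
    while n > 0 and n % 2 == 0:
        n //= 2
    return n


def halve_mem_down_to_sketch(*, minGiB: int) -> Callable[[Any], Dict[float, str]]:
    """Illustrative re-implementation of the documented behavior (not the TorchX source)."""

    def fractionals(resource: Any) -> Dict[float, str]:
        mem_mb = resource.memMB
        if mem_mb <= 0 or mem_mb % 1024:
            raise ValueError(f"memMB must be a positive multiple of 1024, got {mem_mb}")
        mem_gib = mem_mb // 1024
        if minGiB < odd_part(mem_gib):
            raise ValueError(f"minGiB={minGiB} is below the odd part of {mem_gib} GiB")
        out: Dict[float, str] = {}
        gib, fraction = mem_gib, 1.0
        while gib >= minGiB:
            out[fraction] = str(gib)
            if gib % 2:
                break  # halving an odd GiB count would not give an integer
            gib //= 2
            fraction /= 2
        return out

    return fractionals
```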

Constants

torchx.plugins.WHOLE = 1.0

Fractional value for a whole host; used as the default fractional argument in the examples above.

torchx.plugins.HALF = 0.5

Fractional value for half a host.

torchx.plugins.QUARTER = 0.25

Fractional value for a quarter of a host.

torchx.plugins.EIGHTH = 0.125

Fractional value for an eighth of a host.

torchx.plugins.SIXTEENTH = 0.0625

Fractional value for a sixteenth of a host.
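Each constant is exactly half of the previous one, so they line up with the keys produced by the halving helpers and can serve as keys of a literal fractionals dict. The sketch uses local copies of the documented values (real code would import them from torchx.plugins) and a hypothetical 16-GPU host.

```python
# Local copies of the documented constants; real code would import them from torchx.plugins.
WHOLE, HALF, QUARTER, EIGHTH, SIXTEENTH = 1.0, 0.5, 0.25, 0.125, 0.0625

# Successive halvings, all exact binary fractions, so float equality is safe here.
CONSTANTS = [WHOLE, HALF, QUARTER, EIGHTH, SIXTEENTH]

# A literal {fraction: suffix} dict for a hypothetical 16-GPU host, keyed by the constants.
literal_fractionals = {WHOLE: "16", HALF: "8", QUARTER: "4", EIGHTH: "2", SIXTEENTH: "1"}
```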
