Rate this Page

Source code for torchcomms

# Copyright (c) Meta Platforms, Inc. and affiliates.

# pyre-strict
# patternlint-disable fbcode-nonempty-init-py
import ctypes
import os
import sys
from datetime import timedelta
from importlib.metadata import entry_points

# We need to load this upfront since libtorchcomms depends on libtorch
import torch  # noqa: F401
from torchcomms.functional import (
    is_torch_compile_supported,
    is_torch_compile_supported_and_enabled,
)

if is_torch_compile_supported():
    from torch._opaque_base import OpaqueBaseMeta

    # Make the metaclass available to the pybind module: register a minimal
    # stand-in object under the name "torchcomms._opaque_meta" in sys.modules
    # whose only attribute is OpaqueBaseMeta.
    sys.modules["torchcomms._opaque_meta"] = type(
        "module", (), {"OpaqueBaseMeta": OpaqueBaseMeta}
    )()

    # to support opaque registration for time delta.
    class Timeout(timedelta, metaclass=OpaqueBaseMeta):
        """``datetime.timedelta`` subclass exported as ``torchcomms.Timeout``.

        This variant is created with the ``OpaqueBaseMeta`` metaclass and is
        the one defined when ``is_torch_compile_supported()`` is true.
        """
else:
    # When compile support is disabled, define Timeout without the metaclass
    class Timeout(timedelta):
        """``datetime.timedelta`` subclass exported as ``torchcomms.Timeout``.

        Plain variant (no custom metaclass), defined when
        ``is_torch_compile_supported()`` is false.
        """
def _load_libtorchcomms() -> None:
    """Pre-load ``libtorchcomms.so`` from this package's directory, if present.

    Does nothing when the file does not exist next to this module.
    """
    libtorchcomms_path = os.path.join(os.path.dirname(__file__), "libtorchcomms.so")
    # OSS build, buck native linking links everything together so this is not needed
    if os.path.exists(libtorchcomms_path):
        # load this using RTLD_LOCAL so that we don't pollute the global namespace
        # We need to load this upfront since _comms and _comms_* depend on it
        # and won't be able to find it themselves.
        ctypes.CDLL(libtorchcomms_path, mode=ctypes.RTLD_LOCAL)


# Must run before the imports below: they pull in the extension modules that
# depend on the shared library loaded here.
_load_libtorchcomms()
from torchcomms._comms import *  # noqa: E402, F401, F403
import torchcomms.hooks as hooks  # noqa: E402, F401
import torchcomms.objcol as objcol  # noqa: E402, F401, F403

if is_torch_compile_supported_and_enabled():
    # Import collectives first to ensure all operations are registered
    # This must happen before patch_torchcomm() so that window operations
    # and other collectives are registered and can be patched
    from torchcomms.functional import collectives  # noqa: F401

# The documentation uses __all__ to determine what is documented and in what
# order.
__all__ = [  # noqa: F405
    "new_comm",
    "TorchComm",
    "ReduceOp",
    "TorchWork",
    "Timeout",
    "BatchP2POptions",
    "BatchSendRecv",
    "P2POp",
    "CommOptions",
    "TorchCommWindow",
    "register_backend",
    "TorchCommBackend",
]

# Rebind __module__ on every exported object so it reports "torchcomms".
# NOTE(review): presumably so the docs show these under the public package
# rather than the module that defines them - confirm.
for name in __all__:
    cls = globals()[name]
    cls.__module__ = "torchcomms"


def _load_backend(backend: str) -> None:
    """Used to load backends lazily from C++

    C++ calls this only when the backend is not already registered via
    register_backend.

    Args:
        backend: Entry-point name to look up in the ``torchcomms.backends``
            group.

    Raises:
        ModuleNotFoundError: If no entry point with that name exists in the
            ``torchcomms.backends`` group.
    """
    found = entry_points(group="torchcomms.backends", name=backend)
    if not found:
        raise ModuleNotFoundError(
            f"failed to find backend {backend}, is it registered via entry_points.txt?"
        )
    # Only the first matching entry point is loaded.
    wheel = next(iter(found))
    wheel.load()


def is_backend_built(backend: str) -> bool:
    """True if the wheel was built with this backend's extension.

    Determined by whether an entry point named ``backend`` exists in the
    ``torchcomms.backends`` group; the entry point itself is not loaded.
    """
    return bool(entry_points(group="torchcomms.backends", name=backend))


def built_backends() -> list[str]:
    """Names of all backends the wheel was built with.

    Returns the name of every entry point in the ``torchcomms.backends``
    group.
    """
    return [ep.name for ep in entry_points(group="torchcomms.backends")]