# Source code for torchcomms

# Copyright (c) Meta Platforms, Inc. and affiliates.

# pyre-strict
import ctypes
import os
import sys
from datetime import timedelta
from importlib.metadata import entry_points

# We need to load this upfront since libtorchcomms depend on libtorch
import torch  # noqa: F401
from torchcomms.functional import is_torch_compile_supported_and_enabled

# Evaluated once at import time; every conditional below keys off this flag.
torch_compile_supported_and_enabled: bool = is_torch_compile_supported_and_enabled()

if torch_compile_supported_and_enabled:
    from torch._opaque_base import OpaqueBaseMeta

    # make the metaclass available to the pybind module
    # (a throwaway object exposing a single ``OpaqueBaseMeta`` attribute is
    # registered under the name ``torchcomms._opaque_meta``).
    sys.modules["torchcomms._opaque_meta"] = type(
        "module", (), {"OpaqueBaseMeta": OpaqueBaseMeta}
    )()

    # to support opaque registration for time delta.
    class Timeout(timedelta, metaclass=OpaqueBaseMeta):
        """``datetime.timedelta`` subclass built with ``OpaqueBaseMeta``.

        Defined this way only when torch.compile support is enabled, so the
        type can take part in opaque registration.
        """

else:
    # When compile support is disabled, define Timeout without the metaclass
    class Timeout(timedelta):
        """Plain ``datetime.timedelta`` subclass used when compile support is off."""
def _load_libtorchcomms() -> None:
    """Preload ``libtorchcomms.so`` from this package's directory, if present.

    The library is looked up next to this file. When it does not exist the
    function does nothing.
    """
    libtorchcomms_path = os.path.join(os.path.dirname(__file__), "libtorchcomms.so")

    # OSS build, buck native linking links everything together so this is not needed
    if os.path.exists(libtorchcomms_path):
        # load this using RTLD_LOCAL so that we don't pollute the global namespace
        # We need to load this upfront since _comms and _comms_* depend on it
        # and won't be able to find it themselves.
        ctypes.CDLL(libtorchcomms_path, mode=ctypes.RTLD_LOCAL)


# Must run before the extension-module imports below.
_load_libtorchcomms()

from torchcomms._comms import *  # noqa: E402, F401, F403

import torchcomms.hooks as hooks  # noqa: E402, F401
import torchcomms.objcol as objcol  # noqa: E402, F401, F403

if torch_compile_supported_and_enabled:
    # Import collectives first to ensure all operations are registered
    # This must happen before patch_torchcomm() so that window operations
    # and other collectives are registered and can be patched
    from torchcomms.functional import collectives  # noqa: F401

# The documentation uses __all__ to determine what is documented and in what
# order.
__all__ = [  # noqa: F405
    "new_comm",
    "TorchComm",
    "ReduceOp",
    "TorchWork",
    "Timeout",
    "BatchP2POptions",
    "BatchSendRecv",
    "P2POp",
    "CommOptions",
    "TorchCommWindow",
]

# Re-home every public object so it reports ``torchcomms`` as its module
# rather than the module it was originally defined in.
for name in __all__:
    cls = globals()[name]
    cls.__module__ = "torchcomms"


def _load_backend(backend: str) -> None:
    """Used to load backends lazily from C++

    If a backend is already loaded, this function is a no-op.

    Args:
        backend: Entry-point name to look up in the ``torchcomms.backends``
            entry-point group.

    Raises:
        ModuleNotFoundError: If no entry point named ``backend`` is
            registered in the ``torchcomms.backends`` group.
    """
    found = entry_points(group="torchcomms.backends", name=backend)
    if not found:
        raise ModuleNotFoundError(
            f"failed to find backend {backend}, is it registered via entry_points.txt?"
        )
    # Only the first matching entry point is loaded.
    wheel = next(iter(found))
    wheel.load()