API Reference#
- torchcomms.new_comm(backend: str, device: torch.device, name: str, abort_process_on_timeout_or_error: bool | None = None, timeout: datetime.timedelta | None = None, high_priority_stream: bool | None = None, store: torch.distributed.distributed_c10d.Store | None = None, hints: collections.abc.Mapping[str, str] | None = None) torchcomms.TorchComm #
Create a new communicator.
This requires that all ranks that will be part of the communicator call this function simultaneously.
Ranks and world size will be derived from environment variables set by launchers such as torchrun (i.e. RANK, WORLD_SIZE).
Backends typically use a store to initialize, which can either be provided or automatically instantiated from environment variables such as MASTER_ADDR and MASTER_PORT. Backends are not required to use the store to initialize if more performant options are available.
Subcommunicators can be instantiated by using the split method.
- Parameters:
backend (str) – The backend to use for the communicator.
device (torch.device) – The device to use for the communicator.
name (str) – The name of the communicator. This must be unique within the process.
abort_process_on_timeout_or_error (bool) – Whether to abort process on timeout or error.
timeout (timedelta) – Timeout for initialization.
high_priority_stream (bool) – Whether to use high priority stream.
store (torch.distributed.Store) – Store used to initialize the communicator between processes.
hints (dict) – Dictionary of string hints for backend-specific options.
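A minimal sketch of the setup described above, assuming a launcher such as torchrun has exported RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT. The helper names `read_launcher_env` and `make_comm`, the `"nccl"` backend string, the `"cuda"` device and the 60-second timeout are illustrative assumptions, not values taken from this reference.

```python
import os
from datetime import timedelta


def read_launcher_env(env=None):
    """Collect the variables a launcher such as torchrun is expected to set."""
    env = os.environ if env is None else env
    return {
        "rank": int(env["RANK"]),
        "world_size": int(env["WORLD_SIZE"]),
        # Only relevant when the store is instantiated from the environment.
        "master_addr": env.get("MASTER_ADDR"),
        "master_port": env.get("MASTER_PORT"),
    }


def make_comm(name="main_comm"):
    """Hypothetical wrapper; every participating rank must call it together."""
    # Imported lazily so read_launcher_env works without torch installed.
    import torch
    import torchcomms

    return torchcomms.new_comm(
        backend="nccl",                  # assumption: backend name
        device=torch.device("cuda"),     # assumption: one GPU per rank
        name=name,                       # must be unique within the process
        timeout=timedelta(seconds=60),   # assumption: initialization timeout
    )
```

Since `store` is left as None here, the backend would fall back to whatever initialization path it supports, which may or may not involve MASTER_ADDR and MASTER_PORT.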
- class torchcomms.TorchComm#
Bases:
pybind11_object
- __init__(*args, **kwargs)#
- all_gather(self: torchcomms.TorchComm, tensor_list: collections.abc.Sequence[torch.Tensor], tensor: torch.Tensor, async_op: bool, hints: collections.abc.Mapping[str, str] | None = None, timeout: datetime.timedelta | None = None) torchcomms.TorchWork #
Gather a tensor from all ranks in the communicator.
Output will be available on all ranks.
- Parameters:
tensor_list – the list of tensors to gather into
tensor – the input tensor to share
async_op – whether to perform the operation asynchronously
hints – dictionary of string hints for backend-specific options
timeout – timeout for the operation
- all_gather_single(self: torchcomms.TorchComm, output: torch.Tensor, input: torch.Tensor, async_op: bool, hints: collections.abc.Mapping[str, str] | None = None, timeout: datetime.timedelta | None = None) torchcomms.TorchWork #
Gather a single tensor from all ranks in the communicator.
The output tensor must be of size (world_size * input.numel()).
- Parameters:
output – the output tensor to gather into
input – the input tensor to share
async_op – whether to perform the operation asynchronously
hints – dictionary of string hints for backend-specific options
timeout – timeout for the operation
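The size requirement can be illustrated with a pure-Python simulation that does not call torchcomms at all. It assumes the conventional layout in which rank r's input occupies the r-th chunk of the output; `simulate_all_gather_single` is a hypothetical helper, and plain lists stand in for tensors.

```python
def simulate_all_gather_single(inputs_per_rank):
    """Return the flat output every rank would hold after the gather."""
    numel = len(inputs_per_rank[0])
    if any(len(x) != numel for x in inputs_per_rank):
        raise ValueError("every rank must contribute the same number of elements")
    output = []
    for rank_input in inputs_per_rank:  # rank order: 0, 1, ..., world_size - 1
        output.extend(rank_input)
    return output


inputs = [[0, 1], [10, 11], [20, 21]]  # world_size = 3, input.numel() = 2
output = simulate_all_gather_single(inputs)
# output == [0, 1, 10, 11, 20, 21], i.e. world_size * input.numel() elements
```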
- all_reduce(self: torchcomms.TorchComm, tensor: torch.Tensor, op: torchcomms.ReduceOp, async_op: bool, hints: collections.abc.Mapping[str, str] | None = None, timeout: datetime.timedelta | None = None) torchcomms.TorchWork #
Reduce a tensor across all ranks in the communicator.
- Parameters:
tensor – the tensor to all-reduce
op – the reduction operation
async_op – whether to perform the operation asynchronously
hints – dictionary of string hints for backend-specific options
timeout – timeout for the operation
- all_to_all(self: torchcomms.TorchComm, output_tensor_list: collections.abc.Sequence[torch.Tensor], input_tensor_list: collections.abc.Sequence[torch.Tensor], async_op: bool, hints: collections.abc.Mapping[str, str] | None = None, timeout: datetime.timedelta | None = None) torchcomms.TorchWork #
Scatter the split list to all ranks.
- Parameters:
output_tensor_list – Output tensor list.
input_tensor_list – Input tensor list to scatter.
async_op – Whether to perform the operation asynchronously.
hints – Dictionary of string hints for backend-specific options.
timeout – Timeout for the operation.
- all_to_all_single(self: torchcomms.TorchComm, output: torch.Tensor, input: torch.Tensor, async_op: bool, hints: collections.abc.Mapping[str, str] | None = None, timeout: datetime.timedelta | None = None) torchcomms.TorchWork #
Split input tensor and then scatter the split list to all ranks.
Later the received tensors are concatenated and returned as a single output tensor.
The input and output tensor sizes must be a multiple of world_size.
- Parameters:
output – Output tensor.
input – Input tensor to split and scatter.
async_op – Whether to perform the operation asynchronously.
hints – Dictionary of string hints for backend-specific options.
timeout – Timeout for the operation.
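The split, scatter and concatenate steps described above can be sketched without torchcomms. The helper `simulate_all_to_all_single` is hypothetical, uses plain lists in place of tensors, and assumes equal-sized chunks that are concatenated in source-rank order.

```python
def simulate_all_to_all_single(inputs_per_rank):
    """Return one output list per rank after an equal-split all-to-all."""
    world_size = len(inputs_per_rank)
    outputs = []
    for dst in range(world_size):
        received = []
        for src in range(world_size):
            data = inputs_per_rank[src]
            if len(data) % world_size != 0:
                raise ValueError("input size must be a multiple of world_size")
            chunk = len(data) // world_size
            # Rank src sends its dst-th chunk to rank dst.
            received.extend(data[dst * chunk:(dst + 1) * chunk])
        outputs.append(received)
    return outputs


outputs = simulate_all_to_all_single([[0, 1, 2, 3], [10, 11, 12, 13]])
# outputs == [[0, 1, 10, 11], [2, 3, 12, 13]]
```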
- all_to_all_v_single(self: torchcomms.TorchComm, output: torch.Tensor, input: torch.Tensor, output_split_sizes: collections.abc.Sequence[SupportsInt], input_split_sizes: collections.abc.Sequence[SupportsInt], async_op: bool, hints: collections.abc.Mapping[str, str] | None = None, timeout: datetime.timedelta | None = None) torchcomms.TorchWork #
All-to-all single tensor operation with variable split sizes.
- Parameters:
output – Output tensor.
input – Input tensor to split and scatter.
output_split_sizes – List of output split sizes.
input_split_sizes – List of input split sizes.
async_op – Whether to perform the operation asynchronously.
hints – Dictionary of string hints for backend-specific options.
timeout – Timeout for the operation.
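How the two split-size lists relate can be sketched in pure Python. The simulation assumes the usual all-to-all-v convention, in which entry s of a rank's output split sizes equals the amount rank s sends to it; `simulate_all_to_all_v_single` is a hypothetical helper and does not call torchcomms.

```python
def simulate_all_to_all_v_single(inputs_per_rank, input_split_sizes_per_rank):
    """Return (outputs, output_split_sizes), one entry per rank."""
    world_size = len(inputs_per_rank)
    chunks = []  # chunks[src][dst] is what rank src sends to rank dst
    for data, splits in zip(inputs_per_rank, input_split_sizes_per_rank):
        if len(splits) != world_size or sum(splits) != len(data):
            raise ValueError("split sizes must cover the input, one entry per rank")
        row, start = [], 0
        for size in splits:
            row.append(data[start:start + size])
            start += size
        chunks.append(row)
    outputs, output_split_sizes = [], []
    for dst in range(world_size):
        received = [chunks[src][dst] for src in range(world_size)]
        output_split_sizes.append([len(c) for c in received])
        outputs.append([x for c in received for x in c])
    return outputs, output_split_sizes


outs, out_splits = simulate_all_to_all_v_single(
    [[1, 2, 3], [4, 5, 6, 7]],  # inputs on rank 0 and rank 1
    [[1, 2], [3, 1]],           # input_split_sizes on rank 0 and rank 1
)
# outs == [[1, 4, 5, 6], [2, 3, 7]]
# out_splits == [[1, 3], [2, 1]]
```

Under this assumption, each rank has to know how much every peer will send before the call so it can size `output` and fill `output_split_sizes` accordingly.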
- barrier(self: torchcomms.TorchComm, async_op: bool, hints: collections.abc.Mapping[str, str] | None = None, timeout: datetime.timedelta | None = None) torchcomms.TorchWork #
Block until all ranks have reached this call.
- Parameters:
async_op – Whether to perform the operation asynchronously.
hints – Dictionary of string hints for backend-specific options.
timeout – Timeout for the operation.
- batch_op_create(self: torchcomms.TorchComm) torchcomms.BatchSendRecv #
Create a batch operation object for batched P2P operations.
- broadcast(self: torchcomms.TorchComm, tensor: torch.Tensor, root: SupportsInt, async_op: bool, hints: collections.abc.Mapping[str, str] | None = None, timeout: datetime.timedelta | None = None) torchcomms.TorchWork #
Broadcast tensor to all ranks in the communicator.
- Parameters:
tensor – the tensor to broadcast if root or receive into if not root
root – the root rank
async_op – whether to perform the operation asynchronously
hints – dictionary of string hints for backend-specific options
timeout – timeout for the operation
- finalize(self: torchcomms.TorchComm) None #
Finalize and free all resources. This must be called prior to destruction.
- gather(self: torchcomms.TorchComm, output_tensor_list: collections.abc.Sequence[torch.Tensor], input_tensor: torch.Tensor, root: SupportsInt, async_op: bool, hints: collections.abc.Mapping[str, str] | None = None, timeout: datetime.timedelta | None = None) torchcomms.TorchWork #
Gather the input tensor from all ranks to the root.
- Parameters:
output_tensor_list – Output tensor list. Will be empty on non-root ranks.
input_tensor – Input tensor to gather.
root – The root rank.
async_op – Whether to perform the operation asynchronously.
hints – Dictionary of string hints for backend-specific options.
timeout – Timeout for the operation.
- get_backend(self: torchcomms.TorchComm) str #
Get communicator backend name
- get_device(self: torchcomms.TorchComm) torch.device #
Get the communicator device
- get_name(self: torchcomms.TorchComm) str #
Get the name of the communicator
- get_options(self: torchcomms.TorchComm) torchcomms.CommOptions #
Get the communicator options
- get_rank(self: torchcomms.TorchComm) int #
Get the rank of this process
- get_size(self: torchcomms.TorchComm) int #
Get the world size
- recv(self: torchcomms.TorchComm, tensor: torch.Tensor, src: SupportsInt, async_op: bool, hints: collections.abc.Mapping[str, str] | None = None, timeout: datetime.timedelta | None = None) torchcomms.TorchWork #
Receive tensor from source rank.
This will not run concurrently with other operations (including send/recv) on the same stream.
- Parameters:
tensor – the tensor to receive into
src – the source rank
async_op – whether to perform the operation asynchronously
hints – dictionary of string hints for backend-specific options
timeout – timeout for the operation
- reduce(self: torchcomms.TorchComm, tensor: torch.Tensor, root: SupportsInt, op: torchcomms.ReduceOp, async_op: bool, hints: collections.abc.Mapping[str, str] | None = None, timeout: datetime.timedelta | None = None) torchcomms.TorchWork #
Reduce a tensor from all ranks to a single rank in the communicator.
Output will only be available on the root rank.
- Parameters:
tensor – the tensor to reduce
root – the root rank
op – the reduction operation
async_op – whether to perform the operation asynchronously
hints – dictionary of string hints for backend-specific options
timeout – timeout for the operation
- reduce_scatter(self: torchcomms.TorchComm, output: torch.Tensor, input_list: collections.abc.Sequence[torch.Tensor], op: torchcomms.ReduceOp, async_op: bool, hints: collections.abc.Mapping[str, str] | None = None, timeout: datetime.timedelta | None = None) torchcomms.TorchWork #
Reduce, then scatter a list of tensors to all ranks.
- Parameters:
output – Output tensor.
input_list – List of tensors to reduce and scatter.
op – Reduction operation.
async_op – Whether to perform the operation asynchronously.
hints – Dictionary of string hints for backend-specific options.
timeout – Timeout for the operation.
- reduce_scatter_single(self: torchcomms.TorchComm, output: torch.Tensor, input: torch.Tensor, op: torchcomms.ReduceOp, async_op: bool, hints: collections.abc.Mapping[str, str] | None = None, timeout: datetime.timedelta | None = None) torchcomms.TorchWork #
Reduce, then scatter a single tensor to all ranks.
The input tensor must be of size (world_size * output.numel()).
- Parameters:
output – Output tensor.
input – Input tensor to reduce and scatter.
op – Reduction operation.
async_op – Whether to perform the operation asynchronously.
hints – Dictionary of string hints for backend-specific options.
timeout – Timeout for the operation.
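A pure-Python sketch of the size relationship, assuming a SUM reduction and that rank r receives the r-th chunk of the reduced result. `simulate_reduce_scatter_single_sum` is a hypothetical helper; plain lists stand in for tensors.

```python
def simulate_reduce_scatter_single_sum(inputs_per_rank):
    """Return one output list per rank after an element-wise SUM and scatter."""
    world_size = len(inputs_per_rank)
    total = len(inputs_per_rank[0])
    if any(len(x) != total for x in inputs_per_rank):
        raise ValueError("every rank must provide an input of the same size")
    if total % world_size != 0:
        raise ValueError("input size must be world_size * output.numel()")
    # Reduce element-wise across ranks, then hand each rank one chunk.
    reduced = [sum(column) for column in zip(*inputs_per_rank)]
    chunk = total // world_size
    return [reduced[r * chunk:(r + 1) * chunk] for r in range(world_size)]


outputs = simulate_reduce_scatter_single_sum([[1, 2, 3, 4], [10, 20, 30, 40]])
# outputs == [[11, 22], [33, 44]]
```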
- scatter(self: torchcomms.TorchComm, output_tensor: torch.Tensor, input_tensor_list: collections.abc.Sequence[torch.Tensor], root: SupportsInt, async_op: bool, hints: collections.abc.Mapping[str, str] | None = None, timeout: datetime.timedelta | None = None) torchcomms.TorchWork #
Scatter the split list to all ranks from the root.
- Parameters:
output_tensor – Output tensor.
input_tensor_list – Input tensor list to scatter.
root – The root rank.
async_op – Whether to perform the operation asynchronously.
hints – Dictionary of string hints for backend-specific options.
timeout – Timeout for the operation.
- send(self: torchcomms.TorchComm, tensor: torch.Tensor, dst: SupportsInt, async_op: bool, hints: collections.abc.Mapping[str, str] | None = None, timeout: datetime.timedelta | None = None) torchcomms.TorchWork #
Send tensor to destination rank.
This will not run concurrently with other operations (including send/recv) on the same stream.
- Parameters:
tensor – the tensor to send
dst – the destination rank
async_op – whether to perform the operation asynchronously
hints – dictionary of string hints for backend-specific options
timeout – timeout for the operation
- split(self: torchcomms.TorchComm, ranks: collections.abc.Sequence[SupportsInt], name: str, hints: collections.abc.Mapping[str, str] | None = None, timeout: datetime.timedelta | None = None) torchcomms.TorchComm #
Split communicator into a subgroup.
- Parameters:
ranks – List of ranks to include in the new subgroup. If the list is empty, None will be returned. If the list is non-empty but does not include the current rank, an exception will be thrown.
name – Name for the new communicator.
hints – Dictionary of string hints for backend-specific options.
timeout – Timeout for the operation.
Returns: A new communicator for the subgroup, or None if the ranks list is empty.
Raises: RuntimeError if the ranks list is non-empty and the current rank is not included.
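One way to build the `ranks` argument is sketched below, assuming the parent communicator is partitioned into contiguous groups of equal size. The helpers `contiguous_group` and `split_into_groups` and the `group_<first rank>` naming scheme are hypothetical; only `get_rank`, `get_size` and `split` come from this reference.

```python
def contiguous_group(rank, world_size, group_size):
    """Return the ranks of the contiguous group that contains `rank`."""
    if world_size % group_size != 0:
        raise ValueError("world_size must be divisible by group_size")
    start = (rank // group_size) * group_size
    return list(range(start, start + group_size))


def split_into_groups(comm, group_size):
    """Split `comm` so that each rank lands in its contiguous group."""
    ranks = contiguous_group(comm.get_rank(), comm.get_size(), group_size)
    # The list always includes the calling rank, so split() should not raise
    # the RuntimeError documented above.
    return comm.split(ranks, name=f"group_{ranks[0]}")
```

With 8 ranks and `group_size=4`, ranks 0 to 3 would pass `[0, 1, 2, 3]` and ranks 4 to 7 would pass `[4, 5, 6, 7]`.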
- unsafe_get_backend(self: torchcomms.TorchComm) torchcomms.TorchCommBackend #
Get communicator backend implementation.
WARNING: This is intended as an escape hatch for experimentation and development. Direct backend access provides no backwards compatibility guarantees. Users depending on unsafe_get_backend should expect their code to break as interfaces change.
- window_allocate(self: torchcomms.TorchComm, window_size: SupportsInt, cpu_buf: bool = False, signal_size: SupportsInt = 256) torchcomms.TorchCommWindow #
Allocate a shared Window with current TorchComm Communicator.
- Parameters:
window_size – The byte size of the window to be created.
cpu_buf – If True, the buffer is allocated on the CPU; otherwise it is allocated on the same device as the communicator.
Returns: The window object.
- class torchcomms.ReduceOp#
Bases:
pybind11_object
Operation to perform during reduction.
- AVG = <torchcomms.ReduceOp object>#
- BAND = <torchcomms.ReduceOp object>#
- BOR = <torchcomms.ReduceOp object>#
- BXOR = <torchcomms.ReduceOp object>#
- MAX = <torchcomms.ReduceOp object>#
- MIN = <torchcomms.ReduceOp object>#
- static PREMUL_SUM(factor: torch.Tensor | SupportsFloat) torchcomms.ReduceOp #
- PRODUCT = <torchcomms.ReduceOp object>#
- SUM = <torchcomms.ReduceOp object>#
- __init__(self: torchcomms.ReduceOp, opType: torchcomms.RedOpType) None #
Create default ReduceOp
- property type#
Get the type of the operation
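This reference does not spell out what PREMUL_SUM computes. The sketch below assumes it follows the pre-multiplied sum found in NCCL, where each rank's input is scaled by `factor` before the sum is taken. `simulate_premul_sum` is a hypothetical pure-Python helper and does not call torchcomms.

```python
def simulate_premul_sum(values_per_rank, factor):
    """Element-wise sum across ranks of factor * value."""
    return [sum(factor * v for v in column) for column in zip(*values_per_rank)]


result = simulate_premul_sum([[1.0, 2.0], [3.0, 4.0]], factor=0.5)
# result == [2.0, 3.0]; with factor = 1 / world_size this matches an average
```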
- class torchcomms.TorchWork#
Bases:
pybind11_object
TorchWork allows you to track whether an asynchronous operation has completed.
When async_op=True, the operation is enqueued on a background stream and a TorchWork object is returned. This work object must be waited on before using the output tensor.
This is intended to make it easier to write efficient code that can overlap communication with computation.
Example usage:
    tensor = ...
    # run all_reduce on a background stream and return a TorchWork object
    work = torchcomms.all_reduce(tensor, ReduceOp.SUM, async_op=True)
    # Schedule some other work on the current stream
    a = b * 2
    # block the current stream until the all_reduce is complete
    work.wait()
    # safely use the tensor after the collective completes
    tensor.sum()
    # block CPU until stream is complete
    torch.accelerator.current_stream().synchronize()
- __init__(*args, **kwargs)#
- is_completed(self: torchcomms.TorchWork) bool #
Check if the work is completed
- wait(self: torchcomms.TorchWork) None #
Block the current stream until the work is completed.
See https://docs.pytorch.org/docs/stable/notes/cuda.html#cuda-streams for more details.
- class torchcomms.BatchP2POptions#
Bases:
pybind11_object
Options for batched P2P operations.
- __init__(self: torchcomms.BatchP2POptions) None #
Create default BatchP2POptions
- property hints#
Hints dictionary
- property timeout#
Timeout
- class torchcomms.BatchSendRecv#
Bases:
pybind11_object
BatchSendRecv allows you to run multiple send/recv operations concurrently, unlike the standard send/recv APIs, which only allow one operation in flight at a time.
- __init__(*args, **kwargs)#
- issue(self: torchcomms.BatchSendRecv, async_op: bool, options: torchcomms.BatchP2POptions = <torchcomms.BatchP2POptions object>) torchcomms.TorchWork #
Issues the batched operations
- property ops#
List of P2P operations
- recv(self: torchcomms.BatchSendRecv, tensor: torch.Tensor, src: SupportsInt) None #
Add recv operation to batch. Must be paired with a corresponding send operation on a different rank.
- Parameters:
tensor – the tensor to receive into
src – the source rank
- send(self: torchcomms.BatchSendRecv, tensor: torch.Tensor, dst: SupportsInt) None #
Add send operation to batch. Must be paired with a corresponding recv operation on a different rank.
- Parameters:
tensor – the tensor to send
dst – the destination rank
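A sketch of a ring exchange built from the batched calls above, in which every rank sends to its successor and receives from its predecessor in one batch. The helper names `ring_neighbors` and `ring_exchange` are hypothetical, and the ring pattern is just one possible use; the sketch assumes more than one rank so that each send pairs with a recv on a different rank.

```python
def ring_neighbors(rank, world_size):
    """Return (previous_rank, next_rank) on a ring of world_size ranks."""
    return (rank - 1) % world_size, (rank + 1) % world_size


def ring_exchange(comm, send_tensor, recv_tensor):
    """Send to the next rank and receive from the previous one in one batch."""
    prev_rank, next_rank = ring_neighbors(comm.get_rank(), comm.get_size())
    batch = comm.batch_op_create()
    batch.send(send_tensor, next_rank)   # paired with a recv on next_rank
    batch.recv(recv_tensor, prev_rank)   # paired with a send on prev_rank
    work = batch.issue(async_op=True)
    # Block the current stream until both operations have completed.
    work.wait()
    return recv_tensor
```

Issuing both operations in one batch lets them be in flight together, which the plain `send`/`recv` methods do not allow on the same stream.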
- class torchcomms.P2POp#
Bases:
pybind11_object
Represents a peer to peer operation as part of a batch.
- __init__(self: torchcomms.P2POp, type: torch::comms::BatchSendRecv::P2POp::OpType, tensor: torch.Tensor, peer: typing.SupportsInt) None #
Create P2POp.
- Parameters:
type – the type of the operation, i.e. send or recv
tensor – the tensor to operate on
peer – the rank of the peer
- property peer#
Peer rank
- property tensor#
Tensor
- property type#
Operation type
- class torchcomms.CommOptions#
Bases:
pybind11_object
Options for communicator creation.
- __init__(self: torchcomms.CommOptions) None #
Create default CommOptions
- property abort_process_on_timeout_or_error#
Whether to abort process on timeout or error
- property hints#
Dictionary of string hints for backend-specific options
- property store#
Store for communication between processes
- property timeout#
Timeout for operations (milliseconds)
- class torchcomms.TorchCommWindow#
Bases:
pybind11_object
- __init__(*args, **kwargs)#
- get_size(self: torchcomms.TorchCommWindow) int #
Get the size of the window
- get_tensor(self: torchcomms.TorchCommWindow, rank: SupportsInt, sizes: Tuple[int, ...], dtype: torch.dtype, offset: SupportsInt) torch.Tensor #
Get a tensor from the remote window.
- put(self: torchcomms.TorchCommWindow, tensor: torch.Tensor, dst_rank: SupportsInt, target_disp: SupportsInt, async_op: bool) torchcomms.TorchWork #
Put allows you to put a tensor into the previously allocated remote window.
- Parameters:
tensor – the tensor to put
dst_rank – the destination rank
target_disp – the target displacement
async_op – if true, the operation is asynchronous: it is enqueued on a background stream and a TorchWork object is returned.
Example usage:
    tensor = ...
    # create a window
    window = torchcomms.create_window(window_size, cpu_buf)
    # put a tensor into the window
    work = window.put(tensor, dst_rank, target_disp, async_op=True)
    work.wait()
    # on the remote side, get the tensor from the window after waiting on the remote signal
    tensor = window.get_tensor(rank, tensor_sizes, tensor_dtype, offset)
    # safely use the tensor after the collective completes
    tensor.sum()
- signal(self: torchcomms.TorchCommWindow, signal_disp: SupportsInt, signal_val: SupportsInt, dst_rank: SupportsInt, async_op: bool) torchcomms.TorchWork #
Atomic signal to notify remote peer of a change in state.
- Parameters:
signal_disp – the displacement in the signal buffer.
signal_val – the signal value to set for signalling.
dst_rank – the destination rank.
async_op – if true, the operation is asynchronous.
- wait_signal(self: torchcomms.TorchCommWindow, signal_disp: SupportsInt, signal_val: SupportsInt, cmp_op: torchcomms.SignalCmpOp, async_op: bool) torchcomms.TorchWork #
Wait for a signal from a remote peer.
- torchcomms.device_mesh.init_device_mesh(mesh_dim_comms: tuple[TorchComm, ...], mesh_dim_names: tuple[str, ...], _global_comm: TorchComm | None = None) DeviceMesh [source]#
Initializes a DeviceMesh from the list of provided TorchComm instances.
See DeviceMesh for more details.
- torchcomms.objcol.all_gather_object(comm: TorchComm, object_list: list[Any], obj: object, timeout: timedelta | None = None, weights_only: bool = True) None [source]#
Gathers picklable objects from the whole comm into a list.
Similar to all_gather(), but Python objects can be passed in. Note that the object must be picklable in order to be gathered.
- Parameters:
comm – The comm to work on.
object_list (list[object]) – Output list. It should be correctly sized as the size of the comm for this collective and will contain the output.
obj (object) – Picklable Python object to be broadcast from current process.
timeout (timedelta, optional) – Timeout for collective operations. If None, will use the default timeout for the backend.
weights_only (bool, optional) – If True, only safe objects such as weights are allowed to be deserialized. https://docs.pytorch.org/docs/stable/notes/serialization.html#weights-only
- Returns:
None. If the calling rank is part of this comm, the output of the collective will be populated into the input object_list. If the calling rank is not part of the comm, the passed-in object_list will be unmodified.
Note
Note that this API differs slightly from the all_gather() collective since it does not provide an async_op handle and thus will be a blocking call.
Note
For NCCL-based comms, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().
Warning
Object collectives have a number of serious performance and scalability limitations. See object_collectives for details.
Warning
all_gather_object() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.
Warning
Calling all_gather_object() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using all_gather() instead.
- Example::
    >>> # xdoctest: +SKIP("need comm init")
    >>> # Note: comm initialization omitted on each rank.
    >>> from torchcomms import objcol
    >>> # Assumes world_size of 3.
    >>> gather_objects = ["foo", 12, {1: 2}]  # any picklable object
    >>> output = [None for _ in gather_objects]
    >>> objcol.all_gather_object(comm, output, gather_objects[comm.get_rank()])
    >>> output
    ['foo', 12, {1: 2}]
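The picklability requirement can be checked locally before calling an object collective. The sketch below uses only the standard `pickle` module; `is_picklable` is a hypothetical helper, not part of torchcomms, and it should only be run on objects you trust.

```python
import pickle


def is_picklable(obj):
    """Return True if obj survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


payload = {"step": 3, "loss": [0.5, 0.25]}
ok = is_picklable(payload)                 # True: plain containers and numbers
bad = is_picklable(lambda x: x)            # False: lambdas cannot be pickled
round_trip = pickle.loads(pickle.dumps(payload))
```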
- torchcomms.objcol.gather_object(comm: TorchComm, obj: object, root: int, object_gather_list: list[Any] | None = None, timeout: timedelta | None = None, weights_only: bool = True) None [source]#
Gathers picklable objects from the whole comm in a single process.
Similar to gather(), but Python objects can be passed in. Note that the object must be picklable in order to be gathered.
- Parameters:
comm – The comm to work on.
obj (object) – Input object. Must be picklable.
object_gather_list (list[object], optional) – Output list. On the root rank, it should be correctly sized as the size of the comm for this collective and will contain the output. Must be None on non-root ranks. (default is None)
root (int) – Destination rank on comm.
timeout (timedelta, optional) – Timeout for collective operations. If None, will use the default timeout for the backend.
weights_only (bool, optional) – If True, only safe objects such as weights are allowed to be deserialized. https://docs.pytorch.org/docs/stable/notes/serialization.html#weights-only
- Returns:
None. On the root rank, object_gather_list will contain the output of the collective.
Note
Note that this API differs slightly from the gather collective since it does not provide an async_op handle and thus will be a blocking call.
Note
For NCCL-based comms, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().
Warning
Object collectives have a number of serious performance and scalability limitations. See object_collectives for details.
Warning
gather_object() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.
Warning
Calling gather_object() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using gather() instead.
- Example::
    >>> # xdoctest: +SKIP("need comm init")
    >>> # Note: comm initialization omitted on each rank.
    >>> from torchcomms import objcol
    >>> # Assumes world_size of 3.
    >>> gather_objects = ["foo", 12, {1: 2}]  # any picklable object
    >>> output = [None for _ in gather_objects]
    >>> objcol.gather_object(
    ...     comm,
    ...     gather_objects[comm.get_rank()],
    ...     root=0,
    ...     # The output list must be None on non-root ranks.
    ...     object_gather_list=output if comm.get_rank() == 0 else None,
    ... )
    >>> # On rank 0
    >>> output
    ['foo', 12, {1: 2}]
- torchcomms.objcol.send_object_list(comm: TorchComm, object_list: list[Any], dst: int, timeout: timedelta | None = None) None [source]#
Sends picklable objects in object_list synchronously.
Similar to send(), but Python objects can be passed in. Note that all objects in object_list must be picklable in order to be sent.
- Parameters:
comm – The comm to work on.
object_list (List[object]) – List of input objects to send. Each object must be picklable. The receiver must provide a list of equal size.
dst (int) – Destination rank to send object_list to.
timeout (timedelta, optional) – Timeout for collective operations. If None, will use the default timeout for the backend.
- Returns:
None.
Note
For NCCL-based comms, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().
Warning
Object collectives have a number of serious performance and scalability limitations. See object_collectives for details.
Warning
send_object_list() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.
Warning
Calling send_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using send() instead.
- Example::
    >>> # xdoctest: +SKIP("need comm init")
    >>> # Note: comm initialization omitted on each rank.
    >>> from torchcomms import objcol
    >>> # Assumes backend is not NCCL
    >>> if comm.get_rank() == 0:
    >>>     # Assumes world_size of 2.
    >>>     objects = ["foo", 12, {1: 2}]  # any picklable object
    >>>     objcol.send_object_list(comm, objects, dst=1)
    >>> else:
    >>>     objects = [None, None, None]
    >>>     objcol.recv_object_list(comm, objects, src=0)
    >>> objects
    ['foo', 12, {1: 2}]
- torchcomms.objcol.recv_object_list(comm: TorchComm, object_list: list[Any], src: int, timeout: timedelta | None = None, weights_only: bool = True) None [source]#
Receives picklable objects in object_list synchronously.
Similar to recv(), but can receive Python objects.
- Parameters:
comm – The comm to work on.
object_list (List[object]) – List of objects to receive into. Must be the same size as the list being sent.
src (int) – Source rank from which to recv object_list.
timeout (timedelta, optional) – Timeout for collective operations. If None, will use the default timeout for the backend.
weights_only (bool, optional) – If True, only safe objects such as weights are allowed to be deserialized. https://docs.pytorch.org/docs/stable/notes/serialization.html#weights-only
Returns: None
Note
For NCCL-based comms, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().
Warning
Object collectives have a number of serious performance and scalability limitations. See object_collectives for details.
Warning
recv_object_list() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.
Warning
Calling recv_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using recv() instead.
- Example::
    >>> # xdoctest: +SKIP("need comm init")
    >>> # Note: comm initialization omitted on each rank.
    >>> from torchcomms import objcol
    >>> # Assumes backend is not NCCL
    >>> if comm.get_rank() == 0:
    >>>     # Assumes world_size of 2.
    >>>     objects = ["foo", 12, {1: 2}]  # any picklable object
    >>>     objcol.send_object_list(comm, objects, dst=1)
    >>> else:
    >>>     objects = [None, None, None]
    >>>     objcol.recv_object_list(comm, objects, src=0)
    >>> objects
    ['foo', 12, {1: 2}]
- torchcomms.objcol.broadcast_object_list(comm: TorchComm, object_list: list[Any], root: int, timeout: timedelta | None = None, weights_only: bool = True)[source]#
Broadcasts picklable objects in object_list to the whole comm.
Similar to broadcast(), but Python objects can be passed in. Note that all objects in object_list must be picklable in order to be broadcast.
- Parameters:
comm – The comm to work on.
object_list (List[object]) – List of input objects to broadcast. Each object must be picklable. Only objects on the root rank will be broadcast, but each rank must provide lists of equal sizes.
root (int) – Source rank from which to broadcast object_list.
timeout (timedelta, optional) – Timeout for collective operations. If None, will use the default timeout for the backend.
weights_only (bool, optional) – If True, only safe objects such as weights are allowed to be deserialized. https://docs.pytorch.org/docs/stable/notes/serialization.html#weights-only
- Returns:
None. If rank is part of the comm, object_list will contain the broadcast objects from the root rank.
Note
For NCCL-based comms, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().
Note
Note that this API differs slightly from the broadcast() collective since it does not provide an async_op handle and thus will be a blocking call.
Warning
Object collectives have a number of serious performance and scalability limitations. See object_collectives for details.
Warning
broadcast_object_list() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.
Warning
Calling broadcast_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using broadcast() instead.
- Example::
    >>> # xdoctest: +SKIP("need comm init")
    >>> # Note: comm initialization omitted on each rank.
    >>> from torchcomms import objcol
    >>> if comm.get_rank() == 0:
    >>>     # Assumes world_size of 3.
    >>>     objects = ["foo", 12, {1: 2}]  # any picklable object
    >>> else:
    >>>     objects = [None, None, None]
    >>> # Assumes backend is not NCCL
    >>> objcol.broadcast_object_list(comm, objects, root=0)
    >>> objects
    ['foo', 12, {1: 2}]
- torchcomms.objcol.scatter_object_list(comm: TorchComm, root: int, scatter_object_output_list: list[Any], scatter_object_input_list: list[Any] | None = None, timeout: timedelta | None = None, weights_only: bool = True) None [source]#
Scatters picklable objects in scatter_object_input_list to the whole comm.
Similar to scatter(), but Python objects can be passed in. On each rank, the scattered object will be stored as the first element of scatter_object_output_list. Note that all objects in scatter_object_input_list must be picklable in order to be scattered.
- Parameters:
comm – The comm to work on.
scatter_object_output_list (List[object]) – Non-empty list whose first element will store the object scattered to this rank.
scatter_object_input_list (List[object], optional) – List of input objects to scatter. Each object must be picklable. Only objects on the root rank will be scattered, and the argument can be None for non-root ranks.
root (int) – Source rank from which to scatter scatter_object_input_list.
timeout (timedelta, optional) – Timeout for collective operations. If None, will use the default timeout for the backend.
weights_only (bool, optional) – If True, only safe objects such as weights are allowed to be deserialized. https://docs.pytorch.org/docs/stable/notes/serialization.html#weights-only
- Returns:
None. If rank is part of the comm, scatter_object_output_list will have its first element set to the scattered object for this rank.
Note
Note that this API differs slightly from the scatter collective since it does not provide an async_op handle and thus will be a blocking call.
Warning
Object collectives have a number of serious performance and scalability limitations. See object_collectives for details.
Warning
scatter_object_list() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.
Warning
Calling scatter_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using scatter() instead.
- Example::
    >>> # xdoctest: +SKIP("need comm init")
    >>> # Note: comm initialization omitted on each rank.
    >>> from torchcomms import objcol
    >>> if comm.get_rank() == 0:
    >>>     # Assumes world_size of 3.
    >>>     objects = ["foo", 12, {1: 2}]  # any picklable object
    >>> else:
    >>>     # Can be any list on non-root ranks, elements are not used.
    >>>     objects = [None, None, None]
    >>> output_list = [None]
    >>> objcol.scatter_object_list(
    ...     comm,
    ...     root=0,
    ...     scatter_object_output_list=output_list,
    ...     scatter_object_input_list=objects,
    ... )
    >>> # Rank i gets objects[i]. For example, on rank 2:
    >>> output_list
    [{1: 2}]
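The "rank i gets objects[i]" behaviour can be sketched without any communicator. The simulation below round-trips each object through `pickle` to mimic serialization and writes it into the first slot of a one-element output list. `simulate_scatter_object_list` is a hypothetical helper, and the requirement that the input list has one entry per rank is an assumption.

```python
import pickle


def simulate_scatter_object_list(input_list, world_size):
    """Return the scatter_object_output_list each rank would end up with."""
    if len(input_list) != world_size:
        raise ValueError("assumed: one input object per rank")
    outputs = []
    for rank in range(world_size):
        output_list = [None]
        # Serialize on the root, deserialize on the receiving rank.
        output_list[0] = pickle.loads(pickle.dumps(input_list[rank]))
        outputs.append(output_list)
    return outputs


per_rank = simulate_scatter_object_list(["foo", 12, {1: 2}], world_size=3)
# per_rank == [['foo'], [12], [{1: 2}]]
```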