
Hooks

torchcomms provides a hooks mechanism that allows you to intercept and monitor collective operations. Hooks are useful for debugging, profiling, and implementing custom monitoring solutions.

Hook Types

torchcomms supports three types of hooks:

Pre-hooks

Called before each collective operation starts. Receives a PreHookArgs object with the operation metadata, including the operation name, its tensors, and a unique operation ID.

Post-hooks

Called after each collective operation completes. Receives the operation name, the work object, and the same operation ID as the pre-hook, so the two calls can be correlated.

Abort-hooks

Called before the process aborts due to a collective timeout or failure. Useful for capturing debug information before termination.
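Because the pre-hook and post-hook for one operation receive the same op_id, a pair of callbacks can match them up, for example to record the wall-clock time between the two calls. The sketch below assumes only the `name` and `op_id` attributes used in the registration example further down; `OpTimer` is a hypothetical helper, not part of torchcomms, and the `SimpleNamespace` objects merely stand in for real PreHookArgs/PostHookArgs instances.

```python
import time
from types import SimpleNamespace


class OpTimer:
    """Hypothetical helper: pairs pre/post hook calls by their shared op_id."""

    def __init__(self):
        self._start = {}     # op_id -> timestamp taken in the pre-hook
        self.durations = []  # (name, op_id, seconds) for each finished op

    def pre_hook(self, args) -> None:
        # Remember when this operation was reported as starting.
        self._start[args.op_id] = time.perf_counter()

    def post_hook(self, args) -> None:
        # Use the shared op_id to find the matching start timestamp.
        start = self._start.pop(args.op_id, None)
        if start is not None:
            elapsed = time.perf_counter() - start
            self.durations.append((args.name, args.op_id, elapsed))


# Simulated hook invocations. With torchcomms you would instead pass
# timer.pre_hook / timer.post_hook to comm.register_pre_hook / register_post_hook.
timer = OpTimer()
timer.pre_hook(SimpleNamespace(name="all_reduce", op_id=7))
timer.post_hook(SimpleNamespace(name="all_reduce", op_id=7, work=None))
print(timer.durations)
```

What the measured interval corresponds to on the device side depends on when torchcomms invokes the post-hook, which this page does not specify.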

Custom Hooks

You can register custom hook callbacks directly on a communicator using the register_pre_hook, register_post_hook, and register_abort_hook methods. This allows you to implement custom monitoring, logging, or debugging logic.

Registering Hooks

import torch
import torchcomms
from torchcomms._comms import OpName, PreHookArgs, PostHookArgs

# Create a communicator
device = torch.device("cuda:0")
comm = torchcomms.new_comm("ncclx", device)

# Define hook callbacks
def my_pre_hook(args: PreHookArgs) -> None:
    print(f"Starting operation: {args.name} (op_id={args.op_id})")

def my_post_hook(args: PostHookArgs) -> None:
    print(f"Completed operation: {args.name} (op_id={args.op_id})")

def my_abort_hook() -> None:
    print("Process is about to abort, saving debug info...")

# Register hooks - each returns a RemovableHandle
pre_handle = comm.register_pre_hook(my_pre_hook)
post_handle = comm.register_post_hook(my_post_hook)
abort_handle = comm.register_abort_hook(my_abort_hook)

# Run collective operations - hooks will be called
tensor = torch.ones(10, device=device)
comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, async_op=False)

# Unregister hooks when done
pre_handle.remove()
post_handle.remove()
abort_handle.remove()

comm.finalize()

Thread Safety Note

Hook registration is not thread-safe. Do not register or remove hooks while a collective operation is in progress: register hooks before starting collective operations, and remove them only after all operations, including asynchronous ones, have completed.
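One way to follow this rule consistently is to tie registration and removal to a scope. The sketch below is a hypothetical helper (`scoped_hooks` is not a torchcomms API); it assumes only that the communicator exposes register_pre_hook, register_post_hook and register_abort_hook, each returning a handle with a remove() method, as in the registration example. `_FakeComm` is a stand-in used only to exercise the helper without a GPU. If the body launches collectives with async_op=True, wait on them before the block exits.

```python
import contextlib


@contextlib.contextmanager
def scoped_hooks(comm, pre=None, post=None, abort=None):
    """Hypothetical helper: register hooks on entry and remove them on exit."""
    handles = []
    try:
        if pre is not None:
            handles.append(comm.register_pre_hook(pre))
        if post is not None:
            handles.append(comm.register_post_hook(post))
        if abort is not None:
            handles.append(comm.register_abort_hook(abort))
        yield
    finally:
        # Remove in reverse registration order, even if the body raised.
        for handle in reversed(handles):
            handle.remove()


class _FakeHandle:
    """Stand-in for a RemovableHandle."""

    def __init__(self, registry, fn):
        self._registry, self._fn = registry, fn

    def remove(self):
        self._registry.remove(self._fn)


class _FakeComm:
    """Stand-in for a torchcomms communicator, used only for illustration."""

    def __init__(self):
        self.pre, self.post, self.abort = [], [], []

    def _add(self, registry, fn):
        registry.append(fn)
        return _FakeHandle(registry, fn)

    def register_pre_hook(self, fn):
        return self._add(self.pre, fn)

    def register_post_hook(self, fn):
        return self._add(self.post, fn)

    def register_abort_hook(self, fn):
        return self._add(self.abort, fn)


comm = _FakeComm()
with scoped_hooks(comm, pre=print, post=print):
    # Collectives would run here, with the hooks already in place.
    inside = (len(comm.pre), len(comm.post))
after = (len(comm.pre), len(comm.post))
```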

FlightRecorderHook

The FlightRecorderHook is a built-in hook implementation that tracks all collective operations for debugging purposes. It records operation metadata, timing information, and completion status in a ring buffer.

The output format matches the OSS FlightRecorder format from PyTorch’s distributed module, so traces can be analyzed using the same fr_trace analysis tools.
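The ring-buffer behaviour, together with the include_completed filter accepted by dump_json() and dump_file(), can be illustrated with a toy model. This is not how FlightRecorderHook is implemented: `ToyRecorder`, its methods and its entry fields are hypothetical and only mirror the documented max_entries and include_completed semantics.

```python
from collections import deque


class ToyRecorder:
    """Hypothetical model of a bounded flight recorder (not the real class)."""

    def __init__(self, max_entries: int = 2048):
        # A deque with maxlen drops the oldest entry when a new one arrives.
        self._entries = deque(maxlen=max_entries)
        self._next_id = 0

    def record(self, name: str) -> int:
        op_id = self._next_id
        self._next_id += 1
        self._entries.append({"op_id": op_id, "name": name, "completed": False})
        return op_id

    def complete(self, op_id: int) -> None:
        for entry in self._entries:
            if entry["op_id"] == op_id:
                entry["completed"] = True

    def dump(self, include_completed: bool = True) -> list:
        # include_completed=False keeps only operations that never finished,
        # e.g. when looking for a stuck collective.
        return [dict(e) for e in self._entries
                if include_completed or not e["completed"]]

    def size(self) -> int:
        return len(self._entries)


rec = ToyRecorder(max_entries=3)
for name in ["all_reduce", "broadcast", "all_gather", "barrier"]:
    rec.record(name)  # the fourth record evicts the first (all_reduce)
rec.complete(2)       # mark all_gather as finished
```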

FlightRecorderHook Example

import torch
import torchcomms
from torchcomms.hooks import FlightRecorderHook

# Create a communicator
device = torch.device("cuda:0")
comm = torchcomms.new_comm("ncclx", device)

# Create and register a flight recorder hook
recorder = FlightRecorderHook(max_entries=1024)
recorder.register_with_comm(comm)

# Run some collective operations
tensor = torch.ones(10, device=device)
comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, async_op=False)
comm.barrier(async_op=False)

# Dump the recorded trace as JSON
json_trace = recorder.dump_json()
print(json_trace)

# Optionally dump to a file
recorder.dump_file(rank=comm.get_rank())

# Unregister when done
recorder.unregister()

# Finalize the communicator
comm.finalize()

Environment Variables

The FlightRecorderHook uses the following environment variable:

TORCHCOMM_FR_DUMP_TEMP_FILE

Controls the output location for dump_file(). Files are written as <prefix><rank>, where <prefix> is the value of this variable.
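For example, with the variable set to the example value /tmp/fr_trace_, dump_file(rank=3) writes /tmp/fr_trace_3. Note that no separator is inserted, so a prefix naming a directory needs its trailing slash. The hypothetical helper below only reproduces this naming rule, for instance so a post-mortem script can locate every rank's file; what dump_file() does when the variable is unset is not covered here, so the sketch refuses to guess.

```python
import os


def expected_dump_path(rank: int, env=os.environ) -> str:
    """Hypothetical helper: mirror the documented <prefix><rank> naming rule."""
    prefix = env.get("TORCHCOMM_FR_DUMP_TEMP_FILE")
    if prefix is None:
        # No default location is documented here, so do not guess one.
        raise KeyError("TORCHCOMM_FR_DUMP_TEMP_FILE is not set")
    # Plain concatenation: the rank is appended directly to the prefix.
    return f"{prefix}{rank}"


env = {"TORCHCOMM_FR_DUMP_TEMP_FILE": "/tmp/fr_trace_"}
paths = [expected_dump_path(r, env) for r in range(4)]
print(paths)
```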

NanCheckHook

The NanCheckHook detects NaN values in tensors before collective operations run, using the dispatched c10d::check_for_nan op, which works on both CPU and CUDA tensors. It handles both single-tensor operations (e.g., all_reduce) and multi-tensor operations (e.g., all_to_all).

This is particularly useful for debugging numerical corruption that may originate on a single host. Without this hook, a NaN produced by one rank’s local computation silently enters a collective (such as all_reduce) and pollutes the results on every other rank, making the root cause extremely difficult to track down. By checking tensors before the collective runs, NanCheckHook identifies the offending rank and operation immediately, before the corruption has a chance to propagate across the job.
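To make the mechanism concrete, here is a minimal pure-Python sketch of the same idea written as a custom pre-hook: it walks over one or several input buffers and raises before the collective would run. It is not NanCheckHook's implementation. Plain lists of floats replace tensors, math.isnan replaces c10d::check_for_nan (real code might use torch.isnan(t).any() instead), and the `tensors` attribute on the args object is a hypothetical field name, since the PreHookArgs fields beyond `name` and `op_id` are not listed on this page.

```python
import math
from types import SimpleNamespace


def nan_check_pre_hook(args) -> None:
    """Hypothetical pre-hook: raise before the collective if any input has NaN."""
    # args.tensors is an assumed field: a list of input buffers, one entry for
    # single-tensor ops such as all_reduce, several for ops such as all_to_all.
    for index, values in enumerate(args.tensors):
        # Plain floats stand in for tensor elements here.
        if any(math.isnan(v) for v in values):
            raise RuntimeError(
                f"NaN detected in input {index} of {args.name} (op_id={args.op_id})"
            )


ok = SimpleNamespace(name="all_reduce", op_id=1, tensors=[[1.0, 2.0]])
bad = SimpleNamespace(name="all_to_all", op_id=2, tensors=[[0.0], [float("nan")]])

nan_check_pre_hook(ok)  # clean input: returns without raising
message = None
try:
    nan_check_pre_hook(bad)
except RuntimeError as err:
    message = str(err)
print(message)
```

Since NanCheckHook itself is a pre-hook that raises RuntimeError, a custom hook that raises this way should likewise surface the error at the collective call; in practice the built-in hook is the better choice.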

NanCheckHook Example

import torch
import torchcomms
from torchcomms.hooks import NanCheckHook

# Create a communicator
device = torch.device("cuda:0")
comm = torchcomms.new_comm("ncclx", device)

# Create and register the NaN check hook
nan_check = NanCheckHook()
nan_check.register_with_comm(comm)

# This will raise RuntimeError if tensor contains NaN
tensor = torch.ones(10, device=device)
comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, async_op=False)

# Unregister when done
nan_check.unregister()

comm.finalize()

API Reference

Hooks Module

TorchComm hooks module.

This module serves as a namespace for TorchComm hook types.

class torchcomms.hooks.FlightRecorderHook

Bases: pybind11_object

FlightRecorderHook tracks all collective operations in flight for TorchComm communicators.

The output format matches the OSS FlightRecorder format from PyTorch’s distributed module, so traces can be analyzed using the same fr_trace analysis tools.

Example

>>> import torch
>>> import torchcomms
>>> from torchcomms.hooks import fr
>>> device = torch.device("cuda:0")
>>> comm = torchcomms.new_comm("nccl", device, "world")
>>> recorder = fr.FlightRecorderHook(max_entries=1024)
>>> recorder.register_with_comm(comm)
>>> # ... run some collectives ...
>>> json_trace = recorder.dump_json()

For testing, use isolated=True to create a separate FlightRecorder instance that is not shared with other hooks:

>>> recorder = fr.FlightRecorderHook(max_entries=100, isolated=True)

__init__(self: torchcomms.hooks.fr.FlightRecorderHook, max_entries: SupportsInt = 2048, isolated: bool = False) -> None

Create a FlightRecorderHook with the specified ring-buffer size.

Parameters:
  • max_entries – Maximum number of entries in the ring buffer. Older entries are overwritten when full.

  • isolated – If True, creates an isolated FlightRecorder instance for this hook instead of using the global singleton.
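The difference between the shared and isolated modes is why the class docstring recommends isolated=True for tests: hooks created with the default isolated=False presumably all read and write the same global recorder, so entries from one hook appear in another. The toy classes below illustrate only that sharing behaviour; `_Buffer`, `ToyHook` and `_GLOBAL` are hypothetical names and do not reflect the real implementation (buffer size is ignored entirely).

```python
class _Buffer:
    """Hypothetical stand-in for a FlightRecorder instance."""

    def __init__(self):
        self.entries = []


_GLOBAL = _Buffer()  # process-wide singleton shared by non-isolated hooks


class ToyHook:
    """Hypothetical hook: isolated=True gets a private buffer."""

    def __init__(self, isolated: bool = False):
        self._buffer = _Buffer() if isolated else _GLOBAL

    def record(self, name: str) -> None:
        self._buffer.entries.append(name)

    def size(self) -> int:
        return len(self._buffer.entries)


a, b = ToyHook(), ToyHook()
c = ToyHook(isolated=True)
a.record("all_reduce")  # visible through b as well, since they share _GLOBAL
c.record("barrier")     # visible only through c
print(a.size(), b.size(), c.size())
```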

dump_file(self: torchcomms.hooks.fr.FlightRecorderHook, rank: SupportsInt, include_completed: bool = True) -> None

Dump the flight recorder trace and write it to a file.

The output location is controlled by the TORCHCOMM_FR_DUMP_TEMP_FILE environment variable. Files are written as <prefix><rank>.

Parameters:
  • rank – The rank to use for the file name.

  • include_completed – If False, only dump entries that are not completed.

dump_json(self: torchcomms.hooks.fr.FlightRecorderHook, include_completed: bool = True) -> str

Dump all entries as a JSON string in the OSS FlightRecorder format.

This format is compatible with the fr_trace analyzer tools from torch.distributed.flight_recorder.

Parameters:

include_completed – If False, only return entries that are not completed.

Returns:

A JSON string containing the flight recorder trace.

is_enabled(self: torchcomms.hooks.fr.FlightRecorderHook) -> bool

Check if the hook has registered communicators.

register_with_comm(self: torchcomms.hooks.fr.FlightRecorderHook, comm: torchcomms.TorchComm) -> None

Register this hook with a TorchComm communicator.

Parameters:

comm – The communicator to register with.

reset(self: torchcomms.hooks.fr.FlightRecorderHook) -> None

Clear all entries and reset sequence counters.

size(self: torchcomms.hooks.fr.FlightRecorderHook) -> int

Get the current number of entries.

unregister(self: torchcomms.hooks.fr.FlightRecorderHook) -> None

Unregister this hook from all communicators.

class torchcomms.hooks.NanCheckHook(check_inputs: bool = True, check_outputs: bool = False)

Bases: object

Hook that checks for NaN values in tensors before collective operations.

Registers a pre-hook on communicators that inspects input and/or output tensors for NaN values using the dispatched c10d::check_for_nan op, which works on both CPU and CUDA tensors. If a NaN is detected, the hook raises a RuntimeError identifying the operation and communicator that triggered the check.

Parameters:
  • check_inputs – Whether to check input tensors. Default: True.

  • check_outputs – Whether to check output tensors. Default: False.

__init__(check_inputs: bool = True, check_outputs: bool = False) -> None

is_enabled() -> bool

Return whether any communicators are registered.

register_with_comm(comm: Any) -> None

Register the NaN check hook with a communicator.

Parameters:

comm – A TorchComm communicator instance.

unregister() -> None

Remove all registered hooks.