Hooks#
torchcomms provides a hooks mechanism that allows you to intercept and monitor collective operations. Hooks are useful for debugging, profiling, and implementing custom monitoring solutions.
Hook Types#
torchcomms supports three types of hooks:
- Pre-hooks
Called before each collective operation starts. Receives operation metadata including operation name, tensors, and a unique operation ID.
- Post-hooks
Called after each collective operation completes. Receives the operation name, work object, and the same operation ID for correlation.
- Abort-hooks
Called before the process aborts due to a collective timeout or failure. Useful for capturing debug information before termination.
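To make the pre/post relationship concrete, here is an illustrative sketch of how a communicator might invoke both hook types around a collective, showing how the shared operation ID correlates the two calls. The `MockComm` and `HookArgs` classes below are hypothetical stand-ins for demonstration only; the real hook argument types are `PreHookArgs` and `PostHookArgs`.

```python
# Illustrative mock only: mimics how a communicator calls pre- and
# post-hooks around a collective so the op_id correlation is visible.
from dataclasses import dataclass
from itertools import count

@dataclass
class HookArgs:  # hypothetical stand-in for PreHookArgs / PostHookArgs
    name: str
    op_id: int

class MockComm:  # hypothetical stand-in for a real communicator
    def __init__(self):
        self._op_ids = count()
        self.pre_hooks, self.post_hooks = [], []

    def _run_collective(self, name):
        op_id = next(self._op_ids)       # unique per operation
        for hook in self.pre_hooks:
            hook(HookArgs(name, op_id))  # called before the op starts
        # ... the actual collective would run here ...
        for hook in self.post_hooks:
            hook(HookArgs(name, op_id))  # same op_id, for correlation

events = []
comm = MockComm()
comm.pre_hooks.append(lambda a: events.append(("pre", a.name, a.op_id)))
comm.post_hooks.append(lambda a: events.append(("post", a.name, a.op_id)))
comm._run_collective("all_reduce")
print(events)  # the pre and post entries share the same op_id
```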
Custom Hooks#
You can register custom hook callbacks directly on a communicator using the
register_pre_hook, register_post_hook, and register_abort_hook
methods. This allows you to implement custom monitoring, logging, or debugging
logic.
Registering Hooks#
import torch
import torchcomms
from torchcomms._comms import OpName, PreHookArgs, PostHookArgs
# Create a communicator
device = torch.device("cuda:0")
comm = torchcomms.new_comm("ncclx", device)
# Define hook callbacks
def my_pre_hook(args: PreHookArgs) -> None:
    print(f"Starting operation: {args.name} (op_id={args.op_id})")

def my_post_hook(args: PostHookArgs) -> None:
    print(f"Completed operation: {args.name} (op_id={args.op_id})")

def my_abort_hook() -> None:
    print("Process is about to abort, saving debug info...")
# Register hooks - each returns a RemovableHandle
pre_handle = comm.register_pre_hook(my_pre_hook)
post_handle = comm.register_post_hook(my_post_hook)
abort_handle = comm.register_abort_hook(my_abort_hook)
# Run collective operations - hooks will be called
tensor = torch.ones(10, device=device)
comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, async_op=False)
# Unregister hooks when done
pre_handle.remove()
post_handle.remove()
abort_handle.remove()
comm.finalize()
Thread Safety Note#
Hooks are not thread-safe and must not be modified (registered or removed) while a collective operation is in progress. Register hooks before starting collective operations and remove them after all operations have completed.
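One way to follow this rule is to scope registration with a context manager, so hooks are registered before any collective runs and removed only after the block exits. This is a sketch, not part of the torchcomms API; it assumes only what is documented above, namely that each `register_*` method returns a handle with a `remove()` method. The `_FakeComm` and `_Handle` classes at the bottom are hypothetical stand-ins used to demonstrate the pattern.

```python
# Sketch of a registration pattern that respects the thread-safety rule:
# register before any collective, remove only after all complete.
from contextlib import contextmanager

@contextmanager
def registered_hooks(comm, pre=None, post=None, abort=None):
    handles = []
    try:
        if pre is not None:
            handles.append(comm.register_pre_hook(pre))
        if post is not None:
            handles.append(comm.register_post_hook(post))
        if abort is not None:
            handles.append(comm.register_abort_hook(abort))
        yield  # run all collective operations inside this block
    finally:
        for h in handles:  # removed only after the block exits
            h.remove()

# Hypothetical stand-ins to demonstrate the pattern without a real comm:
class _Handle:
    def __init__(self): self.removed = False
    def remove(self): self.removed = True

class _FakeComm:
    def __init__(self): self.handles = []
    def register_pre_hook(self, fn):
        h = _Handle(); self.handles.append(h); return h
    register_post_hook = register_pre_hook
    register_abort_hook = register_pre_hook

comm = _FakeComm()
with registered_hooks(comm, pre=print, post=print):
    pass  # collectives would run here
print(all(h.removed for h in comm.handles))
```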
FlightRecorderHook#
The FlightRecorderHook is a built-in hook implementation that tracks all
collective operations for debugging purposes. It records operation metadata,
timing information, and completion status in a ring buffer.
The output format matches the OSS FlightRecorder format from PyTorch’s
distributed module, so traces can be analyzed using the same fr_trace
analysis tools.
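The ring-buffer behavior mentioned above can be sketched with a bounded `collections.deque`: once the buffer is full, the oldest entries are overwritten. This is illustrative only; the real recorder stores richer per-operation entries and serializes them in the OSS FlightRecorder format.

```python
# Illustrative only: the overwrite behavior of a fixed-size ring buffer,
# which is how FlightRecorderHook bounds memory use.
from collections import deque

ring = deque(maxlen=4)             # analogous to max_entries=4
for op_id in range(6):             # record 6 operations
    ring.append({"op_id": op_id, "name": "all_reduce"})

# The two oldest entries (op_id 0 and 1) have been overwritten.
print([e["op_id"] for e in ring])  # [2, 3, 4, 5]
```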
FlightRecorderHook Example#
import torch
import torchcomms
from torchcomms.hooks import FlightRecorderHook
# Create a communicator
device = torch.device("cuda:0")
comm = torchcomms.new_comm("ncclx", device)
# Create and register a flight recorder hook
recorder = FlightRecorderHook(max_entries=1024)
recorder.register_with_comm(comm)
# Run some collective operations
tensor = torch.ones(10, device=device)
comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, async_op=False)
comm.barrier(async_op=False)
# Dump the recorded trace as JSON
json_trace = recorder.dump_json()
print(json_trace)
# Optionally dump to a file
recorder.dump_file(rank=comm.get_rank())
# Unregister when done
recorder.unregister()
# Finalize the communicator
comm.finalize()
Environment Variables#
The FlightRecorderHook uses the following environment variable:
TORCHCOMM_FR_DUMP_TEMP_FILE
Controls the output location for dump_file(). Files are written as <prefix><rank>, where the prefix is the value of this variable.
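The `<prefix><rank>` rule can be sketched as follows; the prefix value shown is a hypothetical example, not a default.

```python
# Sketch of how dump_file() derives its output path from the
# environment variable, per the <prefix><rank> rule above.
import os

os.environ["TORCHCOMM_FR_DUMP_TEMP_FILE"] = "/tmp/fr_trace_"  # example value

rank = 3
prefix = os.environ["TORCHCOMM_FR_DUMP_TEMP_FILE"]
path = f"{prefix}{rank}"  # the file dump_file(rank=3) would write
print(path)               # /tmp/fr_trace_3
```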
NanCheckHook#
The NanCheckHook detects NaN values in tensors before collective
operations run, using the c10d::check_for_nan dispatched op which works
on both CPU and CUDA tensors. It handles both single-tensor operations
(e.g., all_reduce) and multi-tensor operations (e.g., all_to_all).
This is particularly useful for debugging numerical corruption that may
originate on a single host. Without this hook, a NaN produced by one rank’s
local computation silently enters a collective (such as all_reduce) and
pollutes the results on every other rank, making the root cause extremely
difficult to track down. By checking tensors before the collective runs,
NanCheckHook identifies the offending rank and operation immediately,
before the corruption has a chance to propagate across the job.
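The fail-fast idea can be illustrated in plain Python; the sketch below uses `math.isnan` over a list rather than the real `c10d::check_for_nan` dispatched op on tensors, and `check_for_nan` here is a hypothetical helper, not the hook's actual implementation.

```python
# Illustrative only: the fail-fast check the hook performs, raising
# before the collective runs so corruption never propagates to peers.
import math

def check_for_nan(values, op_name):
    if any(math.isnan(v) for v in values):
        # Raised on the offending rank, before the collective runs.
        raise RuntimeError(f"NaN detected in input to {op_name}")

check_for_nan([1.0, 2.0], "all_reduce")  # clean input passes silently
try:
    check_for_nan([1.0, float("nan")], "all_reduce")
except RuntimeError as e:
    print(e)  # NaN detected in input to all_reduce
```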
NanCheckHook Example#
import torch
import torchcomms
from torchcomms.hooks import NanCheckHook
# Create a communicator
device = torch.device("cuda:0")
comm = torchcomms.new_comm("ncclx", device)
# Create and register the NaN check hook
nan_check = NanCheckHook()
nan_check.register_with_comm(comm)
# This will raise RuntimeError if tensor contains NaN
tensor = torch.ones(10, device=device)
comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, async_op=False)
# Unregister when done
nan_check.unregister()
comm.finalize()
API Reference#
Hooks Module#
TorchComm hooks module.
This module serves as a namespace for TorchComm hook types.
- class torchcomms.hooks.FlightRecorderHook#
Bases: pybind11_object
FlightRecorderHook tracks all collective operations in flight for TorchComm communicators.
The output format matches the OSS FlightRecorder format from PyTorch’s distributed module, so traces can be analyzed using the same fr_trace analysis tools.
Example
>>> from torchcomms.hooks import fr
>>> comm = torchcomms.new_comm("nccl", device, "world")
>>> recorder = fr.FlightRecorderHook(max_entries=1024)
>>> recorder.register_with_comm(comm)
>>> # ... run some collectives ...
>>> json_trace = recorder.dump_json()
For testing, use isolated=True to create a separate FlightRecorder instance that is not shared with other hooks:
>>> recorder = fr.FlightRecorderHook(max_entries=100, isolated=True)
- __init__(self: torchcomms.hooks.fr.FlightRecorderHook, max_entries: SupportsInt = 2048, isolated: bool = False) None#
Create a FlightRecorderHook with specified buffer size.
- Parameters:
max_entries – Maximum number of entries in the ring buffer. Older entries are overwritten when full.
isolated – If True, creates an isolated FlightRecorder instance for this hook instead of using the global singleton.
- dump_file(self: torchcomms.hooks.fr.FlightRecorderHook, rank: SupportsInt, include_completed: bool = True) None#
Dump the flight recorder trace and write it to a file.
The output location is controlled by the TORCHCOMM_FR_DUMP_TEMP_FILE environment variable. Files are written as <prefix><rank>.
- Parameters:
rank – The rank to use for the file name.
include_completed – If False, only dump entries that are not completed.
- dump_json(self: torchcomms.hooks.fr.FlightRecorderHook, include_completed: bool = True) str#
Dump all entries as a JSON string in the OSS FlightRecorder format.
This format is compatible with the fr_trace analyzer tools from torch.distributed.flight_recorder.
- Parameters:
include_completed – If False, only return entries that are not completed.
- Returns:
A JSON string containing the flight recorder trace.
- is_enabled(self: torchcomms.hooks.fr.FlightRecorderHook) bool#
Check if the hook has registered communicators.
- register_with_comm(self: torchcomms.hooks.fr.FlightRecorderHook, comm: torchcomms.TorchComm) None#
Register this hook with a TorchComm communicator.
- Parameters:
comm – The communicator to register with.
- reset(self: torchcomms.hooks.fr.FlightRecorderHook) None#
Clear all entries and reset sequence counters.
- size(self: torchcomms.hooks.fr.FlightRecorderHook) int#
Get the current number of entries.
- unregister(self: torchcomms.hooks.fr.FlightRecorderHook) None#
Unregister this hook from all communicators.
- class torchcomms.hooks.NanCheckHook(check_inputs: bool = True, check_outputs: bool = False)#
Bases: object
Hook that checks for NaN values in tensors before collective operations.
Registers a pre-hook on communicators that inspects input and/or output tensors for NaN values using the dispatched c10d::check_for_nan op (works on both CPU and CUDA). If detected, raises a RuntimeError with context about which operation and communicator triggered the check.
- Parameters:
check_inputs – Whether to check input tensors. Default: True.
check_outputs – Whether to check output tensors. Default: False.