Distributed Tensors in Monarch#
import monarch
import torch
import torch.nn as nn
from monarch.actor import this_host
torch.set_default_device("cuda")
Meshes#
All computation is done on a ‘mesh’ of devices. Here we create a mesh of eight GPU processes on the machine running this notebook:
mesh = this_host().spawn_procs({"gpu": 8})
print(mesh.to_table())
Without a mesh active, torch runs locally.
torch.rand(3, 4)
Once active, torch runs on every device in the mesh.
with mesh.activate():
    t = torch.rand(3, 4, device="cuda")
t
monarch.inspect fetches rank 0’s copy of t to the notebook for debugging; monarch.show displays it.
monarch.inspect(t)
monarch.show(t)
Providing coordinates lets us inspect other ranks’ copies.
monarch.show(t, gpu=1)
Tensor Commands#
Any command issued on the controller, such as multiplying these tensors, performs that action on every tensor in the collection.
with mesh.activate():
    obj = t @ t.T
monarch.show(obj)
If a command fails, the workers stay alive and can execute future commands:
try:
    with mesh.activate():
        # too big
        big_w = torch.rand(4, 1024 * 1024 * 1024 * 1024 * 8, device="cuda")
        v = t @ big_w
        monarch.show(v)
except Exception:
    import traceback

    traceback.print_exc()
    del big_w
print("RECOVERED!")
Since monarch recovers from errors, you can search for what works:
N = 1
while True:
    try:
        with mesh.activate():
            batch = torch.rand(N, 1024 * 1024 * 1024, device="cuda")
            # inspect forces a fetch, which surfaces any worker error
            monarch.inspect(batch.sum())
        print(f"a batch of {N} * 2**30 elements works")
        N = 2 * N
    except Exception:
        print(f"a batch of {N} * 2**30 elements is too big")
        break
Collectives#
Each process in the mesh has its own copy of the tensor, as in torch.distributed.
To compute across tensors in the mesh, we use special communication operators, analogous to collectives.
with mesh.activate():
    a = torch.rand(3, 4, device="cuda")
    r = a.reduce("gpu", "sum")
monarch.show(a, gpu=0)  # per-rank values differ
monarch.show(a, gpu=1)
monarch.show(r, gpu=0)  # the reduced value is the same on every rank
monarch.show(r, gpu=1)
Remote GPUs#
We can also connect to remote GPUs reserved from a scheduler.
# NYI: schedule public API based on config, just fake it locally
remote_mesh = this_host().spawn_procs({"host": 4, "gpu": 4})
print(remote_mesh.to_table())
with remote_mesh.activate():
    eg = torch.rand(3, 4, device="cuda")
    rgpu = eg.reduce("gpu", "sum")  # reduce across the 4 GPUs of each host
    rhost = eg.reduce("host", "sum")  # reduce across the 4 hosts
Device Mesh Dimensions#
Meshes can be renamed and reshaped to fit the desired parallelism.
mesh_2d_parallel = remote_mesh.rename(host="dp", gpu="tp")
print(mesh_2d_parallel.to_table())
mesh_3d_parallel = remote_mesh.split(host=("dp", "pp"), gpu=("tp",), pp=2)
print(mesh_3d_parallel.to_table())
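We can also slice a reshaped mesh to address just part of it, using the same slice API the pipelining section below relies on. A small sketch:
# A sketch: select the pp=0 stage of the 3D mesh, leaving a dp x tp sub-mesh.
stage0 = mesh_3d_parallel.slice(pp=0)
print(stage0.to_table())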
Pipelining#
Pipelining is accomplished by slicing the mesh and copying tensors from one sub-mesh to another.
pipeline_mesh = remote_mesh.rename(host="pp")
meshes = [pipeline_mesh.slice(pp=i) for i in range(pipeline_mesh.size("pp"))]
print(meshes[0].to_table())
Initialize a model across multiple meshes:
layers_per_stage = 2
stages = []
for stage_mesh in meshes:
    with stage_mesh.activate():
        layers = []
        for _ in range(layers_per_stage):
            layers.extend([nn.Linear(4, 4), nn.ReLU()])
        stages.append(nn.Sequential(*layers))
def forward_pipeline(x):
    with torch.no_grad():
        for stage_mesh, stage in zip(meshes, stages):
            # move the activation to the stage's mesh, then run the stage there
            x = x.to_mesh(stage_mesh)
            with stage_mesh.activate():
                x = stage(x)
    return x
with meshes[0].activate():
    input = torch.rand(3, 4, device="cuda")

output = forward_pipeline(input)
monarch.show(output)
print(output.mesh.to_table())
DDP Example#
The next sections will use an example of writing DDP to illustrate a typical way to develop code in monarch.
Let’s interleave the backward pass with the gradient reductions and parameter updates.
We use monarch.grad_generator to run the backward pass incrementally. It returns an iterator that yields the gradient of one parameter at a time.
def train(model, input, target):
    loss = model(input, target)
    # reversed: backward produces gradients from the last layer to the first
    rparameters = list(reversed(list(model.parameters())))
    grads = monarch.grad_generator(loss, rparameters)
    with torch.no_grad():
        it = iter(zip(rparameters, grads))
        todo = next(it, None)
        while todo is not None:
            param, grad = todo
            grad.reduce_("dp", "sum")
            # request the next gradient before updating this parameter,
            # so the next backward step is already queued
            todo = next(it, None)
            param -= 0.01 * grad
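For contrast, here is a hypothetical non-incremental variant (not part of the example) using plain torch.autograd.grad: every gradient is materialized before any reduction or update can begin, so there is nothing to interleave.
def train_unoverlapped(model, input, target):
    # Hypothetical contrast to the incremental version above.
    loss = model(input, target)
    params = list(model.parameters())
    # torch.autograd.grad runs the whole backward pass at once
    grads = torch.autograd.grad(loss, params)
    with torch.no_grad():
        for param, grad in zip(params, grads):
            grad.reduce_("dp", "sum")
            param -= 0.01 * grad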
Simulation of DDP#
We can use a simulator to check the expected behavior of code before running it for real. The simulator is another kind of mesh: it records and simulates commands rather than actually executing them.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        layers = []
        for _ in range(8):
            layers.append(nn.Linear(4, 4))
            layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)

    def forward(self, input, target):
        output = self.layers(input)
        return torch.nn.functional.cross_entropy(output, target)
def simulate():
    simulator = monarch.Simulator(hosts=1, gpus=4, trace_mode="stream_only")
    mesh = simulator.mesh.rename(gpu="dp")
    with mesh.activate():
        model = Net()
        train(model, torch.rand(3, 4), torch.full((3,), 1, dtype=torch.int64))
    simulator.display()


simulate()
Overlapping Comms/Compute#
Commands on different devices run in parallel, but by default commands on a single device run sequentially.
We introduce parallelism on a device via stream objects.
main = monarch.get_active_stream()
comms = monarch.Stream("comms")
The main stream runs computation sequentially, while the comms stream
runs communication (e.g. allreduce) in parallel on the same device.
To use a tensor from one stream on another, we borrow it. A borrow transfers the tensor to the borrowing stream; the original stream cannot use it again until borrow.drop() is called. This ensures deterministic memory usage and eliminates the race conditions of the torch.cuda.stream API.
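A minimal sketch of the borrow lifecycle, assuming comms.borrow returns the borrowed tensor together with a borrow handle (matching how it is unpacked in the training loop below):
with mesh.activate():
    w = torch.rand(3, 4, device="cuda")
    # hand w to the comms stream; main cannot mutate it while borrowed
    w_on_comms, borrow = comms.borrow(w, mutable=True)
    with comms.activate():
        w_on_comms.reduce_("gpu", "sum")
    borrow.drop()  # return ownership of w to the main stream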
The DDP example again, but using multiple streams.
def train(model, input, target):
    loss = model(input, target)
    rparameters = list(reversed(list(model.parameters())))
    grads = monarch.grad_generator(loss, rparameters)
    with torch.no_grad():
        # NEW: the iterator also produces the tensor borrowed
        # to the comms stream
        it = iter(
            (param, grad, *comms.borrow(grad, mutable=True))
            for param, grad in zip(rparameters, grads)
        )
        todo = next(it, None)
        while todo is not None:
            param, grad, comm_grad, borrow = todo
            # NEW: compute the reduce on the comms stream
            with comms.activate():
                comm_grad.reduce_("dp", "sum")
            borrow.drop()
            todo = next(it, None)
            param -= 0.01 * grad
simulate()
The simulation shows little overlap because of where borrow.drop() was placed: the reduce on the comms stream had to complete before the next backward step could start on the main stream.
The goal is to overlap the reduce (comms stream) with the next backward step (main stream). We can achieve this by ending the borrow after the next gradient has been requested but before we update the param.
def train(model, input, target):
    loss = model(input, target)
    rparameters = list(reversed(list(model.parameters())))
    grads = monarch.grad_generator(loss, rparameters)
    with torch.no_grad():
        it = iter(
            (param, grad, *comms.borrow(grad, mutable=True))
            for param, grad in zip(rparameters, grads)
        )
        todo = next(it, None)
        while todo is not None:
            param, grad, comm_grad, borrow = todo
            with comms.activate():
                comm_grad.reduce_("dp", "sum")
            todo = next(it, None)
            # NEW: drop the borrow as late as possible, after the next
            # backward step has been queued on the main stream
            borrow.drop()
            param -= 0.01 * grad
simulate()