Distributed PPO-like Reinforcement Learning with Monarch Actors
This example demonstrates a distributed PPO-like reinforcement learning algorithm implemented with the Monarch actor framework.

The implementation features:

- Distributed actor architecture with Generator, Scorer, and Learner components
- Asynchronous communication via queues
- RDMA-based weight synchronization
- Event-driven architecture for efficient processing

The example shows how to:

- Set up distributed actors on separate GPU meshes
- Implement policy gradient methods in a distributed setting
- Use RDMA buffers for efficient parameter sharing
- Create an asynchronous training loop with multiple components
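The full script below is built from a small set of Monarch primitives: an Actor subclass exposes methods marked with the endpoint decorator, actors are spawned onto a proc_mesh, and endpoints are invoked through handles such as call_one (a single instance) or call (every instance in the mesh). The following minimal sketch shows that pattern in isolation, using only calls that also appear later on this page; the Counter actor is purely illustrative and is not part of the PPO example.

import asyncio

from monarch.actor import Actor, endpoint, proc_mesh


class Counter(Actor):
    """Toy actor used only to illustrate the endpoint/spawn/call pattern."""

    def __init__(self, start: int):
        self.value = start

    @endpoint
    async def increment(self) -> int:
        self.value += 1
        return self.value


async def demo():
    mesh = await proc_mesh(gpus=1)  # allocate a one-GPU process mesh
    counter = await mesh.spawn("counter", Counter, 0)  # spawn the actor on the mesh
    print(await counter.increment.call_one())  # invoke the endpoint on one instance


asyncio.run(demo())

With that pattern in mind, the full example follows.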
import asyncio
import copy
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.optim as optim
from monarch.actor import Actor, endpoint, proc_mesh
from monarch.rdma import RDMABuffer
from torch.distributions import Categorical, kl_divergence
"""
Online reinforcement learning (RL) training loop using the Monarch actor framework.
This example implements a distributed PPO-like algorithm with three main components:
1. Generator: Produces actions using the current policy and sends them for scoring
2. Scorer: Evaluates actions and assigns rewards
3. Learner: Updates policy based on collected experiences
Key features demonstrated:
- Distributed actors on separate GPU meshes
- Asynchronous communication via queues
- RDMA-based weight synchronization
- Event-driven architecture
"""
G = 8 # group size
STATE_DIM = 4
ACTION_DIM = 4 # vocab size


@dataclass
class TrajectorySlice:
    """Single trajectory from one generator call.

    Attributes:
        policy_version: Version of policy that produced this slice
        state: Input state tensor [STATE_DIM]
        actions: Generated actions [G]
        old_logps: Log probabilities of actions under generation policy [G]
        rewards: Rewards for each action (initially zeros, filled by Scorer) [G]
    """

    policy_version: int
    state: torch.Tensor
    actions: torch.Tensor
    old_logps: torch.Tensor
    rewards: torch.Tensor


@dataclass
class TrainingBatch:
    """Batch of trajectories for training.

    Attributes:
        states: Batched states [batch_size, STATE_DIM]
        actions: Batched actions [batch_size * G]
        old_logps: Batched log probabilities [batch_size * G]
        rewards: Batched rewards [batch_size * G]
        policy_versions: List of policy versions for each slice
    """

    states: torch.Tensor
    actions: torch.Tensor
    old_logps: torch.Tensor
    rewards: torch.Tensor
    policy_versions: List[int]


class TrajectoryQueue(Actor):
    """Queue for trajectory slices between Generator and Scorer."""

    def __init__(self):
        """Initialize an empty queue."""
        self.queue: asyncio.Queue[TrajectorySlice] = asyncio.Queue()

    @endpoint
    async def put(self, slice: TrajectorySlice) -> None:
        """Add a trajectory slice to the queue.

        Args:
            slice: The trajectory slice to add
        """
        await self.queue.put(slice)

    @endpoint
    async def get(self) -> TrajectorySlice:
        """Remove and return a trajectory slice from the queue.

        Returns:
            The next trajectory slice in the queue
        """
        return await self.queue.get()


class ReplayBuffer(Actor):
    """Storage for scored trajectory slices with weighted sampling."""

    def __init__(self):
        """Initialize an empty buffer."""
        self.storage: List[Tuple[int, TrajectorySlice]] = []  # (version, slice)
        self.storage_event = asyncio.Event()

    @endpoint
    async def put(self, slice: TrajectorySlice) -> None:
        """Add a trajectory slice to the buffer.

        Args:
            slice: The trajectory slice to add
        """
        self.storage.append((slice.policy_version, slice))
        self.storage_event.set()

    async def _wait_for_storage(self):
        if not self.storage:
            await self.storage_event.wait()

    @endpoint
    async def sample_from(self, k: int) -> List[TrajectorySlice]:
        """Sample k trajectory slices using weighted sampling.

        Items from newer policy versions have a higher probability of being selected.
        If the buffer is empty, waits for it to be populated with a timeout.

        Args:
            k: Number of slices to sample

        Returns:
            List of sampled trajectory slices

        Raises:
            RuntimeError: If buffer is empty after timeout
        """
        try:
            await asyncio.wait_for(self._wait_for_storage(), timeout=10.0)
        except asyncio.TimeoutError:
            raise RuntimeError("Timeout waiting for ReplayBuffer to be populated")
        # Extract policy versions and add 1 to ensure all weights are positive
        policy_versions = [version + 1 for version, _ in self.storage]
        # Use policy versions as weights for sampling
        total = sum(policy_versions)
        probs = [v / total for v in policy_versions]
        # Sample indices based on policy version weights
        indices = list(range(len(self.storage)))
        chosen_indices = random.choices(indices, weights=probs, k=k)
        return [self.storage[i][1] for i in chosen_indices]
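
# Example of the version weighting above: with stored policy versions [0, 2, 4],
# the sampling weights become [1, 3, 5], so the newest slice is five times more
# likely to be drawn than the oldest one.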


class Scorer(Actor):
    """Evaluates actions and assigns rewards to trajectory slices."""

    def __init__(self, trajectory_queue: Any, replay_buffer: Any):
        """Initialize the scorer.

        Args:
            trajectory_queue: Queue to pull trajectory slices from
            replay_buffer: Buffer to store scored slices in
        """
        self.trajectory_queue = trajectory_queue
        self.replay_buffer = replay_buffer
        self.net = nn.Sequential(
            nn.Linear(STATE_DIM + 1, 8),
            nn.Tanh(),
            nn.Linear(8, 1),
        ).to("cuda")
        self.running = False

    async def _score_slice(self, slice: TrajectorySlice) -> None:
        """Score a trajectory slice and store it in the replay buffer.

        Args:
            slice: The trajectory slice to score
        """
        s = slice.state.to("cuda").unsqueeze(0).repeat(G, 1)
        a = slice.actions.to("cuda").float().unsqueeze(-1)
        rewards = self.net(torch.cat([s, a], dim=-1)).squeeze(-1).cpu()
        scored = TrajectorySlice(
            policy_version=slice.policy_version,
            state=slice.state,
            actions=slice.actions,
            old_logps=slice.old_logps,
            rewards=rewards,
        )
        await self.replay_buffer.put.call(scored)

    @endpoint
    async def run(self) -> None:
        """Start the scoring event loop.

        Continuously pulls slices from the queue, scores them,
        and puts them in the replay buffer until stopped.
        """
        if self.running:
            return
        self.running = True
        try:
            while self.running:
                try:
                    slice_ = await asyncio.wait_for(
                        self.trajectory_queue.get.call_one(),
                        timeout=1.0,
                    )
                    await self._score_slice(slice_)
                except asyncio.TimeoutError:
                    continue
        except Exception as e:
            print(f"Scorer event loop error: {e}")
        finally:
            self.running = False

    @endpoint
    async def stop(self) -> None:
        """Stop the scoring event loop."""
        self.running = False


class Learner(Actor):
    """Updates policy based on collected experiences using the PPO algorithm."""

    def __init__(self, replay_buffer: Any):
        """Initialize the learner.

        Args:
            replay_buffer: Buffer to sample experiences from
        """
        # Policy network and reference network for KL divergence
        self.model = nn.Sequential(
            nn.Linear(STATE_DIM, 16), nn.Tanh(), nn.Linear(16, ACTION_DIM)
        ).to("cuda")
        self.ref_model = copy.deepcopy(self.model)
        for p in self.ref_model.parameters():
            p.requires_grad = False
        self.ref_model.eval()
        # Optimization parameters
        self.optim = optim.Adam(self.model.parameters(), lr=1e-3, eps=1e-5)
        self.eps = 0.2  # PPO clipping parameter
        self.kl_coeff = 0.1  # KL divergence coefficient
        self.policy_version = 0
        self.replay_buffer = replay_buffer
        self.batch_size = 2
        self.generators: Optional[Any] = None

    @endpoint
    async def init_generators(self, generators: Any) -> None:
        """Set the generators service for weight updates.

        Args:
            generators: Service to notify of policy updates
        """
        self.generators = generators

    @endpoint
    async def weights_handle(self) -> Dict[str, RDMABuffer]:
        """Create RDMA buffers for model weights.

        Returns:
            Dictionary mapping parameter names to RDMA buffers
        """
        return {
            k: RDMABuffer(v.view(torch.uint8).flatten())
            for k, v in self.model.state_dict().items()
        }

    def _compute_advantages(self, rewards: torch.Tensor) -> torch.Tensor:
        """Compute advantages from rewards.

        In PPO, advantages represent how much better an action is compared to the average.
        Here we compute advantages by subtracting a baseline (mean reward) from the rewards
        and then normalizing to stabilize training.

        Args:
            rewards: Raw rewards tensor [batch_size * G]

        Returns:
            Advantages tensor [batch_size * G]
        """
        # First, reshape rewards to [batch_size, G] to compute per-state baseline
        batch_size = rewards.shape[0] // G
        rewards_reshaped = rewards.view(batch_size, G)
        # Compute baseline (mean reward) for each state
        baselines = rewards_reshaped.mean(dim=1, keepdim=True)  # [batch_size, 1]
        # Subtract baseline from rewards to get advantages
        advantages = rewards_reshaped - baselines  # [batch_size, G]
        # Reshape back to original shape
        advantages = advantages.reshape(-1)  # [batch_size * G]
        # Normalize advantages for training stability
        if advantages.numel() > 1:  # Check if we have more than one element
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        return advantages
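
    # Example: with G = 4 and one state's rewards [1.0, 2.0, 3.0, 4.0], the baseline
    # would be 2.5 and the raw advantages [-1.5, -0.5, 0.5, 1.5] before normalization
    # (this script uses G = 8; the numbers are purely illustrative).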
    def _apply_policy_update(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        old_logps: torch.Tensor,
        advantages: torch.Tensor,
    ) -> torch.Tensor:
        """Apply PPO update to policy network.

        Args:
            states: Batch of states
            actions: Batch of actions
            old_logps: Log probabilities from old policy
            advantages: Normalized advantages

        Returns:
            Loss value
        """
        # Compute new policy distribution and log probabilities
        dist_new = Categorical(logits=self.model(states))
        new_logps = dist_new.log_prob(actions)
        # PPO clipped objective
        ratio = (new_logps - old_logps).exp()
        unclipped = ratio * advantages
        clipped = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantages
        ppo_loss = -torch.min(unclipped, clipped).mean()
        # KL penalty to prevent large policy updates
        with torch.no_grad():
            ref_logits = self.ref_model(states)
        kl = kl_divergence(Categorical(logits=ref_logits), dist_new).mean()
        # Update policy
        loss = ppo_loss + self.kl_coeff * kl
        self.optim.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optim.step()
        self.policy_version += 1
        return loss.detach()

    @endpoint
    async def step(self) -> torch.Tensor:
        """Perform one training step.

        Returns:
            Loss value from the update
        """
        # Notify generators of current policy version
        if self.generators:
            await self.generators.update.call(self.policy_version)
        # Sample and process trajectory slices
        slices = await self.replay_buffer.sample_from.call_one(self.batch_size)
        raw_states = torch.stack([s.state for s in slices])
        actions = torch.cat([s.actions for s in slices])
        old_logps = torch.cat([s.old_logps for s in slices])
        rewards = torch.cat([s.rewards for s in slices])
        # Prepare tensors for update
        states = raw_states.repeat_interleave(G, 0).to("cuda")
        actions, old_logps, rewards = [
            x.to("cuda") for x in (actions, old_logps, rewards)
        ]
        # Compute advantages and update policy
        advs = self._compute_advantages(rewards)
        return self._apply_policy_update(states, actions, old_logps, advs)


class GeneratorState:
    """States for the Generator's state machine."""

    READY_TO_GENERATE = "READY_TO_GENERATE"
    READY_TO_UPDATE = "READY_TO_UPDATE"
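
# The Generator cycles between the two states above: generate() waits for
# READY_TO_GENERATE, produces a trajectory slice, then flips to READY_TO_UPDATE;
# update() copies fresh weights from the Learner's RDMA buffers and flips back
# to READY_TO_GENERATE.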


class Generator(Actor):
    """Generates actions using the current policy.

    Maintains a copy of the policy network that is synchronized with the Learner
    via RDMA buffers. Generates actions for given states and sends them to the
    trajectory queue for scoring.
    """

    def __init__(self, weight_buffers, trajectory_queue):
        """Initialize the generator.

        Args:
            weight_buffers: RDMA buffers for policy weights
            trajectory_queue: Queue to put generated trajectories in
        """
        self.model = nn.Sequential(
            nn.Linear(STATE_DIM, 16), nn.Tanh(), nn.Linear(16, ACTION_DIM)
        ).to("cuda")
        self.weight_buffers = weight_buffers
        self.trajectory_queue = trajectory_queue
        self.state = GeneratorState.READY_TO_GENERATE
        self.cond = asyncio.Condition()
        self.policy_version = 0

    @endpoint
    async def generate(self, state: torch.Tensor) -> None:
        """Generate actions for a given state.

        Args:
            state: Input state tensor [STATE_DIM]
        """
        async with self.cond:
            # Wait until ready to generate
            await self.cond.wait_for(
                lambda: self.state == GeneratorState.READY_TO_GENERATE
            )
            # Generate actions using current policy
            x = state.to("cuda").unsqueeze(0).repeat(G, 1)
            dist = Categorical(logits=self.model(x))
            acts = dist.sample()
            logps = dist.log_prob(acts)
            # Create trajectory slice
            slice_ = TrajectorySlice(
                self.policy_version,
                state,
                acts,
                logps,
                torch.zeros(G),
            )
            # Send to trajectory queue for scoring
            await self.trajectory_queue.put.call(slice_)

        async with self.cond:
            # Signal ready for update
            self.state = GeneratorState.READY_TO_UPDATE
            self.cond.notify_all()

    @endpoint
    async def update(self, version: int) -> None:
        """Update policy weights from RDMA buffers.

        Args:
            version: New policy version number
        """
        async with self.cond:
            # Copy weights from RDMA buffers
            sd = self.model.state_dict()
            for n, b in self.weight_buffers.items():
                await b.read_into(sd[n].view(torch.uint8).flatten())
            self.model.load_state_dict(sd)
            # Update version and state
            self.policy_version = version
            self.state = GeneratorState.READY_TO_GENERATE
            self.cond.notify_all()


async def main():
    """Run the distributed reinforcement learning training loop."""
    # Create process meshes for different components
    learner_mesh = await proc_mesh(gpus=1)
    gen_mesh = await proc_mesh(gpus=2)

    # Spawn actors on the learner mesh
    traj_q = await learner_mesh.spawn("traj", TrajectoryQueue)
    replay_buf = await learner_mesh.spawn("rb", ReplayBuffer)
    learner = await learner_mesh.spawn("learner", Learner, replay_buf)
    scorer = await learner_mesh.spawn("scorer", Scorer, traj_q, replay_buf)

    # Get weight buffers and spawn generators on the generator mesh
    wb = await learner.weights_handle.call_one()
    generators = await gen_mesh.spawn(
        "generator",
        Generator,
        wb,
        traj_q,
    )
    await learner.init_generators.call(generators)

    # Start the scorer event loop in the background
    scorer_run_future = scorer.run.call_one()

    # Training loop
    for step in range(5):
        state = torch.randn(STATE_DIM)
        # Generate actions and update policy in parallel
        _, loss = await asyncio.gather(
            generators.generate.call(state),
            learner.step.call_one(),
        )
        print(f"[Step {step:02d}] loss={loss:.3f}")

    # Clean up
    await scorer.stop.call_one()
    await scorer_run_future
    print("✅ done")


if __name__ == "__main__":
    asyncio.run(main())