Mock actors and training loop for TorchForge GRPO simulation.
This module provides lightweight mock implementations of TorchForge actors that can be used with Monarch’s patch_actor mechanism to test GRPO training without requiring GPUs, vLLM, or TorchTitan.
See grpo_forge_sim.py for usage examples.
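The actual patching lives in grpo_forge_sim.py. As a rough illustration only (the patch_actor signature below is assumed, not taken from Monarch's API docs), the swap might look like:

# Illustration only -- patch_actor's real signature may differ; see
# grpo_forge_sim.py for the canonical usage.
# from monarch.actor import patch_actor
#
# @patch_actor(ForgeGenerator, MockGenerator)
# @patch_actor(ForgeTitanTrainer, MockTitanTrainer)
# async def run_simulation() -> None:
#     await main()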
from __future__ import annotations
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import torch
# TorchForge actor imports (requires conda environment with Forge installed)
from forge.actors.generator import Generator as ForgeGenerator
from forge.actors.reference_model import ReferenceModel as ForgeReferenceModel
from forge.actors.replay_buffer import ReplayBuffer as ForgeReplayBuffer
from forge.actors.trainer import TitanTrainer as ForgeTitanTrainer
from forge.rl.advantage import ComputeAdvantages as ForgeComputeAdvantages
from forge.rl.grading import RewardActor as ForgeRewardActor
from monarch.actor import Actor, endpoint, this_host
# ==============================================================================
# Mock Actor Implementations
# ==============================================================================
class MockGenerator(Actor):
    """Mock generator that returns dummy completions without vLLM."""

    def __init__(self, **kwargs: Any) -> None:
        self._step = 0

    @endpoint
    async def setup(self) -> None:
        print("[MockGenerator] setup (mocked)")

    @endpoint
    async def generate(self, prompt: str) -> List[Any]:
        """Return mock completions."""

        @dataclass
        class MockCompletion:
            text: str
            prompt_ids: torch.Tensor
            token_ids: torch.Tensor
            stop_reason: str = "eos"
            generator_version: int = 0

        return [
            MockCompletion(
                text=f"Mock answer {i}: 42",
                prompt_ids=torch.zeros(10, dtype=torch.long),
                token_ids=torch.zeros(20, dtype=torch.long),
                generator_version=self._step,
            )
            for i in range(4)
        ]

    @endpoint
    async def update_weights(self, step: int) -> None:
        print(f"[MockGenerator] update_weights -> step {step}")
        self._step = step
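Each mock can be exercised in isolation. Below is a minimal smoke test for the generator, reusing the same Monarch spawning calls that main() uses later in this script (a sketch; the helper name is ours):

async def _smoke_test_generator() -> None:
    # Spawn the mock on one local process and request completions.
    procs = this_host().spawn_procs(per_host={"procs": 1})
    gen = procs.spawn("mock_generator", MockGenerator)
    await gen.setup.call()
    completions = await gen.generate.call_one("What is 2 + 2?")
    assert len(completions) == 4  # MockGenerator always returns 4 samples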
class MockTitanTrainer(Actor):
    """Mock trainer that simulates training steps."""

    def __init__(self, **kwargs: Any) -> None:
        self._step = 0

    @endpoint
    async def setup(self) -> None:
        print("[MockTitanTrainer] setup (mocked)")

    @endpoint
    async def train_step(self, inputs: torch.Tensor, targets: Dict[str, Any]) -> None:
        self._step += 1
        print(f"[MockTitanTrainer] train_step #{self._step}")

    @endpoint
    async def push_weights(self, step: int) -> None:
        print(f"[MockTitanTrainer] push_weights step={step}")
class MockReferenceModel(Actor):
    """Mock reference model that returns dummy logprobs."""

    def __init__(self, **kwargs: Any) -> None:
        pass

    @endpoint
    async def setup(self) -> None:
        print("[MockReferenceModel] setup (mocked)")

    @endpoint
    async def forward(
        self,
        input_ids: torch.Tensor,
        max_req_tokens: int,
        return_logprobs: bool = True,
    ) -> torch.Tensor:
        # Zero logprobs, one per response token past the prompt prefix.
        batch_size = input_ids.shape[0]
        response_len = max(input_ids.shape[1] - max_req_tokens, 1)
        return torch.zeros(batch_size, response_len)
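The zeros returned here stand in for real reference logprobs, which in full GRPO anchor a per-token KL penalty against the current policy. For context, a sketch of the unbiased k3 estimator commonly used for that penalty (a detail of the real trainer, not of this mock):

def kl_penalty(policy_logprobs: torch.Tensor, ref_logprobs: torch.Tensor) -> torch.Tensor:
    # k3 estimator: exp(ref - pi) - (ref - pi) - 1; non-negative per token.
    diff = ref_logprobs - policy_logprobs
    return torch.exp(diff) - diff - 1.0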
class MockRewardActor(Actor):
    """Mock reward actor that returns varied rewards."""

    def __init__(self, **kwargs: Any) -> None:
        self._call_count = 0

    @endpoint
    async def setup(self) -> None:
        print("[MockRewardActor] setup (mocked)")

    @endpoint
    async def evaluate_response(
        self, prompt: str, response: str, target: str
    ) -> tuple[Dict[str, float], float]:
        self._call_count += 1
        reward = random.uniform(0.5, 1.0)
        return ({"mock_reward": reward}, reward)
class MockReplayBuffer(Actor):
    """Mock replay buffer."""

    def __init__(self, **kwargs: Any) -> None:
        self._buffer: List[Any] = []

    @endpoint
    async def setup(self) -> None:
        print("[MockReplayBuffer] setup (mocked)")

    @endpoint
    async def add(self, episode: Any) -> None:
        self._buffer.append(episode)

    @endpoint
    async def sample(self, curr_policy_version: int) -> Optional[Any]:
        # The real buffer filters episodes by policy version; the mock just
        # waits for a minimum batch size and returns fixed-shape tensors.
        if len(self._buffer) < 4:
            return None
        return (torch.zeros(4, 32), {"advantages": torch.ones(4)})
class MockComputeAdvantages(Actor):
    """Mock advantage computation."""

    def __init__(self, **kwargs: Any) -> None:
        pass

    @endpoint
    async def setup(self) -> None:
        print("[MockComputeAdvantages] setup (mocked)")

    @endpoint
    async def compute(self, episodes: List[Any]) -> List[float]:
        return [1.0] * len(episodes)
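The constant 1.0 stands in for GRPO's group-relative advantage, which normalizes each reward against the statistics of its sampling group. A sketch of that computation (the real ForgeComputeAdvantages may differ in details):

def group_relative_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    # GRPO: advantage_i = (r_i - mean(group)) / (std(group) + eps).
    t = torch.tensor(rewards, dtype=torch.float32)
    return ((t - t.mean()) / (t.std() + eps)).tolist()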
# ==============================================================================
# GRPO Training Loop
# ==============================================================================
async def grpo_training_loop(
    generator: Any,
    trainer: Any,
    ref_model: Any,
    reward_actor: Any,
    replay_buffer: Any,
    compute_advantages: Any,
    num_steps: int = 3,
) -> None:
    """Simplified GRPO training loop mirroring TorchForge's apps/grpo/main.py."""
    print("\n" + "-" * 60)
    print("GRPO Training Loop")
    print("-" * 60 + "\n")
    for step in range(num_steps):
        # 1. Rollout: sample a group of responses for one prompt.
        prompt = f"What is {step + 1} + {step + 1}?"
        responses = await generator.generate.call_one(prompt)
        print(f"[Step {step}] Generated {len(responses)} responses")

        # 2. Grade each response, keeping the rewards for episode construction.
        rewards: List[float] = []
        for i, response in enumerate(responses):
            reward_breakdown, reward = await reward_actor.evaluate_response.call_one(
                prompt=prompt, response=response.text, target=str((step + 1) * 2)
            )
            rewards.append(reward)
            print(f"  Response {i}: reward={reward:.3f}")

        # 3. Reference logprobs (all zeros from the mock).
        input_ids = torch.zeros(len(responses), 32, dtype=torch.long)
        await ref_model.forward.call_one(
            input_ids, max_req_tokens=10, return_logprobs=True
        )

        # 4. Build episodes, compute advantages (constant 1.0 from the mock),
        #    and push everything into the replay buffer.
        episodes = [
            {"response": r, "reward": rw} for r, rw in zip(responses, rewards)
        ]
        await compute_advantages.compute.call_one(episodes)
        for episode in episodes:
            await replay_buffer.add.call_one(episode)

        # 5. Train once a full batch is buffered, then propagate new weights.
        batch = await replay_buffer.sample.call_one(curr_policy_version=step)
        if batch is not None:
            inputs, targets = batch
            await trainer.train_step.call(inputs, targets)
            await trainer.push_weights.call(step)
            await generator.update_weights.call_one(step)
        print(f"[Step {step}] Complete\n")

    print("-" * 60)
    print("Training Complete!")
    print("-" * 60)
async def main() -> None:
    """
    Main entry point that spawns actors and runs the training loop.

    This function spawns real TorchForge actors. To run with mocks,
    use patch_actor decorators as shown in grpo_forge_sim.py.
    """
    host = this_host()
    proc_mesh = host.spawn_procs(per_host={"procs": 1})

    print("Spawning actors...")
    generator = proc_mesh.spawn("generator", ForgeGenerator)
    trainer = proc_mesh.spawn("trainer", ForgeTitanTrainer)
    ref_model = proc_mesh.spawn("ref_model", ForgeReferenceModel)
    reward_actor = proc_mesh.spawn("reward_actor", ForgeRewardActor)
    replay_buffer = proc_mesh.spawn("replay_buffer", ForgeReplayBuffer)
    compute_advantages = proc_mesh.spawn("compute_advantages", ForgeComputeAdvantages)

    print("Setting up actors...")
    await generator.setup.call()
    await trainer.setup.call()
    await ref_model.setup.call()
    await reward_actor.setup.call()
    await replay_buffer.setup.call()
    await compute_advantages.setup.call()

    await grpo_training_loop(
        generator=generator,
        trainer=trainer,
        ref_model=ref_model,
        reward_actor=reward_actor,
        replay_buffer=replay_buffer,
        compute_advantages=compute_advantages,
        num_steps=3,
    )
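To run the script directly, a standard asyncio entry point suffices (assuming, as in other Monarch examples, that the actor endpoints are awaitable under asyncio.run):

if __name__ == "__main__":
    import asyncio

    asyncio.run(main())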