Source code for forge.actors.vllm.v1.generator
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import asyncio
import base64
import logging
import os
import time
import uuid
from collections.abc import Mapping
from dataclasses import dataclass, field
from multiprocessing import resource_tracker
from typing import Any, Optional
import cloudpickle
import torch
import torchstore as ts
from forge.actors._torchstore_utils import (
extract_param_name,
get_param_key,
get_param_prefix,
)
from forge.controller import ForgeActor
from forge.controller.provisioner import _get_provisioner
from forge.data_models.completion import Completion
from forge.data_models.prompt import to_prompt
from forge.env import FORGE_DISABLE_METRICS
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
from forge.util._shared_tensor import SharedTensor, SharedTensorHandle
from monarch.actor import endpoint, ProcMesh, this_host
from torchstore.api import _controller as get_torchstore_controller
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.llm import UsageContext
from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Suppress noisy vLLM "Added request" logs
logging.getLogger("vllm.v1.engine.async_llm").setLevel(logging.WARNING)

@dataclass
class Generator(ForgeActor):
"""vLLM-based generator using AsyncLLM with Monarch distributed execution.
Wraps vLLM's AsyncLLM engine and uses MonarchExecutor for multi-GPU inference.
    See the MonarchExecutor docstring for an architecture diagram.
Args:
engine_args: vLLM EngineArgs for model configuration. Can be EngineArgs or dict.
sampling_params: Default SamplingParams for generation. Can be SamplingParams or dict.
prefetch_weights_to_shm: Whether to prefetch weights to shared memory for faster
weight updates. When enabled, weight fetchers download weights in parallel
to shared memory while generation is still running. Defaults to True.
n_fetcher_procs: Number of fetcher processes for parallel weight downloading.
Only used when prefetch_weights_to_shm is True. Defaults to 8.
Example:
>>> generator = await Generator.options(procs=1, with_gpus=True).as_service(
... engine_args={"model": "meta-llama/Llama-3-8B", "tensor_parallel_size": 2},
... sampling_params={"max_tokens": 128, "temperature": 0.7},
... )
>>> completions = await generator.generate("Tell me a joke")
>>> await generator.shutdown()
"""
engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
prefetch_weights_to_shm: bool = True
n_fetcher_procs: int = 8
def __post_init__(self):
super().__init__()
self.llm: Optional[AsyncLLM] = None
self.generator_version: int = 0
self.workers: Any = None # Workers ActorMesh, registered by MonarchExecutor
self.fetcher_procs: Optional[ProcMesh] = None # Fetcher proc mesh
self.weight_fetchers: Any = None # Weight fetcher ActorMesh
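        # Normalize dict-style configs into their vLLM dataclass equivalents.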
if isinstance(self.engine_args, Mapping):
self.engine_args = EngineArgs(**self.engine_args)
self.vllm_config = self.engine_args.create_engine_config(UsageContext.LLM_CLASS)
if isinstance(self.sampling_params, Mapping):
self.sampling_params = SamplingParams.from_optional(**self.sampling_params)
self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
@classmethod
async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
cls: type["Generator"],
*args,
**kwargs,
) -> "Generator":
"""Custom launch for Generator with proper GPU host provisioning.
Flow:
1. Get host mesh directly (via get_host_mesh for remote, this_host for local)
2. Spawn CPU proc on head host for Generator and WorkerRegistry
3. Allocate GPUs from provisioner
4. Pass host_mesh and GPU IDs to setup() - executor creates proc_mesh
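
        Example (illustrative; callers typically do not invoke launch() directly,
        but go through Generator.options(...).as_service(...) as in the class
        docstring):
            >>> generator = await Generator.options(procs=1, with_gpus=True).as_service(
            ...     engine_args={"model": "meta-llama/Llama-3-8B"},
            ... )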
"""
engine_args = kwargs.get("engine_args", {})
if isinstance(engine_args, Mapping):
engine_args = EngineArgs(**engine_args)
vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
num_gpus = vllm_config.parallel_config.world_size
num_hosts = cls.hosts if cls.hosts else None
gpus_per_host = num_gpus // (num_hosts or 1)
mesh_name = cls.mesh_name or "generator"
# Step 1: Get host mesh directly (no bootstrap procs)
provisioner = await _get_provisioner()
if provisioner is None:
raise RuntimeError(
"Provisioner not initialized. Call init_provisioner() first."
)
if num_hosts:
# Remote allocation - get pre-allocated mesh from launcher
host_mesh = await provisioner.get_host_mesh(mesh_name)
else:
# Local allocation
host_mesh = this_host()
# Step 1a: Mount wsfuse on all hosts for local model paths
# This must happen before any actors are spawned that might access the path
if num_hosts and provisioner.launcher:
# Spawn temporary procs on all hosts just for mounting
mount_procs = host_mesh.spawn_procs(
per_host={"procs": 1}, name="mount_setup"
)
await provisioner.launcher.remote_setup(mount_procs)
logger.info("Completed remote_setup (mounted wsfuse on all hosts)")
# Step 1b: Allocate GPUs from provisioner
# This ensures each Generator gets exclusive GPU IDs
gpu_ids = await provisioner.allocate_gpu_ids(host_mesh, num_gpus)
logger.info(f"[Generator.launch] Allocated GPUs: {gpu_ids}")
logger.info(
f"[Generator.launch] Provisioned host mesh with {num_hosts or 1} host(s), "
f"{gpus_per_host} GPUs per host"
)
# Step 2: Spawn CPU proc on head host for Generator and WorkerRegistry
if num_hosts:
singleton_slice = {k: slice(0, 1) for k in host_mesh.extent.keys()}
head_host = host_mesh.slice(**singleton_slice)
else:
head_host = host_mesh
generator_proc = head_host.spawn_procs(
per_host={"procs": 1}, name="generator_proc"
)
logger.info("[Generator.launch] Spawned generator_proc on head host")
# Register LocalFetcherActor for generator_proc to enable metrics collection
if not FORGE_DISABLE_METRICS.get_value():
await get_or_create_metric_logger(generator_proc, process_name=mesh_name)
# Import WorkerRegistry here to avoid circular import with monarch_executor
from forge.actors.vllm.v1.monarch_executor import WorkerRegistry
# Spawn WorkerRegistry on CPU proc (same as Generator)
worker_registry = generator_proc.spawn("worker_registry", WorkerRegistry)
logger.info("[Generator.launch] Spawned WorkerRegistry on generator_proc")
actor_name = kwargs.pop("name", cls.__name__)
generator = generator_proc.spawn(
actor_name,
cls,
*args,
**kwargs,
)
# Attach for cleanup in Generator.shutdown()
generator._generator_proc = generator_proc
generator._worker_registry = worker_registry
# Step 3: Pass host_mesh and gpu_ids to setup() - executor will create proc_mesh
await generator.setup.call(host_mesh, worker_registry, gpu_ids)
return generator
@endpoint
async def setup(self, host_mesh, worker_registry, gpu_ids: list[str]):
"""Initialize AsyncLLM with MonarchExecutor.
Receives a host_mesh from launch(). Serializes it for the EngineCore
subprocess. MonarchExecutor creates its own proc_mesh from host_mesh,
spawns workers, and registers them. After AsyncLLM initialization,
Generator queries the registry for workers.
Args:
host_mesh: HostMesh for GPU workers (executor will create proc_mesh from this)
worker_registry: WorkerRegistry ActorMesh for worker registration
gpu_ids: List of allocated GPU IDs (e.g., ["0", "1"])
"""
num_gpus = self.vllm_config.parallel_config.tensor_parallel_size
logger.info(f"Setting up AsyncLLM with {num_gpus} GPUs, allocated: {gpu_ids}")
# Set env var for MonarchExecutor subprocess (EngineCore)
os.environ["VLLM_MONARCH_GPU_IDS"] = ",".join(gpu_ids)
# Serialize host_mesh reference
serialized_host_mesh = base64.b64encode(cloudpickle.dumps(host_mesh)).decode(
"utf-8"
)
os.environ["VLLM_MONARCH_HOST_MESH"] = serialized_host_mesh
# Serialize WorkerRegistry reference
serialized_registry = base64.b64encode(
cloudpickle.dumps(worker_registry)
).decode("utf-8")
os.environ["VLLM_MONARCH_WORKER_REGISTRY"] = serialized_registry
# Serialize TorchStore Controller reference for workers to access torchstore
torchstore_controller = await get_torchstore_controller()
serialized_controller = base64.b64encode(
cloudpickle.dumps(torchstore_controller)
).decode("utf-8")
os.environ["VLLM_TORCHSTORE_CONTROLLER"] = serialized_controller
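        # Illustrative (not executed here): the EngineCore/executor side is expected
        # to recover these references with the inverse transform, e.g.
        #   cloudpickle.loads(base64.b64decode(os.environ["VLLM_MONARCH_HOST_MESH"]))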
# Force 'spawn' multiprocessing method for Monarch actors.
# This follows vLLM's Ray integration pattern.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
# Set the executor backend to ForgeMonarchExecutor via string path
# This avoids import deadlock when vLLM spawns EngineCore subprocess
self.vllm_config.parallel_config.distributed_executor_backend = (
"forge.actors.vllm.v1.forge_executor.ForgeMonarchExecutor"
)
# Set up prefetching configuration via additional_config
        # There does not seem to be a real difference between passing this via an env var or via self.vllm_config.
if self.prefetch_weights_to_shm:
self.vllm_config.additional_config["n_fetcher_procs"] = self.n_fetcher_procs
else:
self.vllm_config.additional_config["n_fetcher_procs"] = 0
from vllm.v1.executor.abstract import Executor
try:
self.llm = AsyncLLM(
vllm_config=self.vllm_config,
                # This resolves to the MonarchExecutor class configured above
executor_class=Executor.get_class(self.vllm_config),
log_stats=True,
)
logger.info(f"AsyncLLM initialized with {num_gpus} workers")
except Exception as e:
logger.error(f"AsyncLLM initialization failed: {e}")
raise
# Query the WorkerRegistry for workers that were registered by MonarchExecutor
# during _init_executor()
self.workers = await worker_registry.get_workers.call_one()
if self.workers is None:
raise RuntimeError(
"Workers not found in registry. "
"MonarchExecutor may have failed to register workers."
)
logger.info(f"Retrieved workers from registry: {self.workers}")
if self.prefetch_weights_to_shm:
self._spawn_fetchers()
def _spawn_fetchers(self):
"""Spawn weight fetchers that prefetch weights from torchstore to shared memory.
        This assumes the generator is on the same host as the workers and only works
        for single-host generators.
"""
fetcher_procs = this_host().spawn_procs(
per_host={"procs": self.n_fetcher_procs}
)
self.fetcher_procs = fetcher_procs
self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher)
logger.info(
f"[Generator] Spawned {self.n_fetcher_procs} weight fetchers: {self.weight_fetchers}"
)
async def _fetch_weights(
self,
version: int,
) -> dict[str, SharedTensorHandle]:
"""Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.
Coordinates parallel weight fetching across fetcher processes using round-robin
distribution of parameters.
Args:
version: Policy version to fetch
Returns:
Dict mapping param names to SharedTensorHandle objects
"""
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
hf_param_names = [extract_param_name(key) for key in matching_keys]
n_fetchers = self.weight_fetchers.size()
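        # Distribute parameter names round-robin across the fetchers, e.g. with
        # n_fetchers=3: ["a", "b", "c", "d", "e"] -> [["a", "d"], ["b", "e"], ["c"]]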
def split_keys(keys):
return [keys[i::n_fetchers] for i in range(n_fetchers)]
futures = []
for i, names in enumerate(split_keys(hf_param_names)):
fut = self.weight_fetchers.slice(procs=i).fetch.call_one(
version=version, param_names=names
)
futures.append(fut)
sub_state_dicts = [await fut for fut in futures]
state_dict = {}
for sd in sub_state_dicts:
state_dict.update(sd)
logger.info(
f"[Generator] Fetched {len(state_dict)} weights for v{version} to shared memory"
)
return state_dict
async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]):
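        # Release the POSIX shared-memory segments created by the weight fetchers
        # once the workers have consumed the prefetched weights (see update_weights()).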
for handle in state_dict.values():
handle.drop()
@endpoint
async def generate(
self,
prompt: str,
*,
priority: int = 0,
sampling_params: SamplingParams | None = None,
) -> list[Completion]:
"""Generate a response for the given prompt.
Args:
prompt (str): The prompt to generate a response for.
priority (int, optional): The priority of the request. Defaults to 0.
sampling_params (SamplingParams, optional): Sampling parameters to use for this request.
If not provided, uses self.sampling_params.
Returns:
list[Completion]: n completions from vLLM based on your prompt.
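
        Example (illustrative; assumes a running generator as in the class docstring
        and SamplingParams imported from vllm.sampling_params):
            >>> completions = await generator.generate(
            ...     "Tell me a joke",
            ...     sampling_params=SamplingParams(max_tokens=64, temperature=0.0),
            ... )
            >>> print(completions[0].text)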
"""
t = Tracer("generator_perf/generate", timer="gpu")
t.start()
record_metric("generator/generate/count_requests", 1, Reduce.SUM)
if self.llm is None:
raise RuntimeError("Generator not initialized. Call setup() first.")
params = sampling_params or self.sampling_params
if params.output_kind is None:
params.output_kind = RequestOutputKind.FINAL_ONLY
# Use AsyncLLM's generate() method - it returns an async generator
# that yields RequestOutput objects. We only want the final output.
request_output = None
async for output in self.llm.generate(
prompt=prompt,
sampling_params=params,
request_id=str(uuid.uuid4()),
):
request_output = output # Keep last output (final one)
completions = self._to_completions(request_output, prompt)
record_metric(
"generator/generate/count_sequences_completed",
len(completions),
Reduce.SUM,
)
t.stop()
return completions
@endpoint
async def stop(self):
"""Stop the generator and cleanup local resources.
This method is idempotent and can be called multiple times safely.
        Note: Remote worker cleanup happens in shutdown(), which has access to the proxy.
"""
if self.fetcher_procs is not None:
try:
await self.fetcher_procs.stop()
logger.info("Stopped fetcher procs")
except Exception as e:
logger.warning(f"Error stopping fetcher procs: {e}")
self.fetcher_procs = None
if self.llm is not None:
logger.info("Stopping AsyncLLM")
self.llm.shutdown()
logger.info("AsyncLLM.shutdown() returned")
self.llm = None
logger.info("stop() complete")
@classmethod
async def shutdown(cls, actor):
"""Shutdown the generator and cleanup all resources.
Cleanup order:
1. Stop AsyncLLM (triggers MonarchExecutor.shutdown() which destroys
process groups and stops proc_mesh)
2. Stop generator_proc
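
        Example (illustrative, given the actor handle returned by launch()):
            >>> await Generator.shutdown(generator)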
"""
try:
await actor.stop.call()
except Exception as e:
logger.warning(f"Error during actor.stop: {e}")
try:
if getattr(actor, "_generator_proc", None):
await actor._generator_proc.stop()
logger.info("Stopped generator proc")
except Exception as e:
logger.warning(f"Error during generator_proc stop: {e}")
logger.info("shutdown() complete")
@endpoint
async def update_weights(
self,
version: Optional[int] = None,
) -> None:
"""Update weights on the generator from torchstore.
This method:
1. Optionally starts prefetching weights to shared memory (overlaps with pause)
2. Pauses generation and waits for in-flight requests to complete
3. Updates weights on workers (from shared memory if prefetched, else from torchstore)
4. Resumes generation
When prefetch_weights_to_shm is enabled, weight fetching is started as an async task
BEFORE pause_generation(), overlapping I/O with in-flight request completion.
This reduces the time generation is paused.
Note: This is NOT the standard vLLM weight update approach. vLLM typically
uses `collective_rpc` on EngineClient, which internally routes calls to
workers via the executor. However, `collective_rpc` uses msgspec/msgpack
serialization which does not support arbitrary Python objects by default
(only with VLLM_ALLOW_INSECURE_SERIALIZATION=1). This makes it difficult to
pass complex objects like torchstore storage handles. Instead, we use a
monarch-native approach where the Generator actor directly calls the worker
mesh (`self.workers.update_weights`) via Monarch RPC, which uses cloudpickle
and natively supports Monarch actor references for torchstore integration.
Args:
version: Policy version to load from torchstore
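
        Example (a minimal sketch; assumes a trainer has already written weights for
        this version into torchstore under the key layout used by get_param_key):
            >>> # elsewhere: the trainer stores tensors keyed by get_param_key(version, name)
            >>> await generator.update_weights(version=1)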
"""
if self.llm is None:
raise RuntimeError("Generator not initialized. Call setup() first.")
logger.info(f"Starting weight update to v{version}")
# Start prefetching weights to shared memory (overlaps with pause)
fetch_task = None
if self.prefetch_weights_to_shm:
logger.info(f"[Generator] Starting prefetch for v{version}")
fetch_task = asyncio.create_task(self._fetch_weights(version or 0))
pause_start = time.perf_counter()
await self.llm.pause_generation(
wait_for_inflight_requests=True, clear_cache=True
)
pause_duration = time.perf_counter() - pause_start
record_metric(
"generator_perf/update_weights/pause_generation_duration_s",
pause_duration,
Reduce.MEAN,
)
try:
load_start = time.perf_counter()
if fetch_task is not None:
wait_fetch_start = time.perf_counter()
fetched_weights = await fetch_task
wait_fetch_duration = time.perf_counter() - wait_fetch_start
record_metric(
"generator_perf/update_weights/wait_fetch_weights_s",
wait_fetch_duration,
Reduce.MEAN,
)
await self.workers.apply_prefetched_weights.call(fetched_weights)
await self._drop_shared_memory(fetched_weights)
else:
# Direct fetch from torchstore
await self.workers.update_weights.call(version=version)
load_duration = time.perf_counter() - load_start
record_metric(
"generator_perf/update_weights/worker_load_weights_duration_s",
load_duration,
Reduce.MEAN,
)
self.generator_version = version
finally:
await self.llm.resume_generation()
logger.info(f"Weight update complete, now v{version}")
@endpoint
async def save_model_params(self):
"""Save model parameters before weight update, used for testing purposes only."""
        logger.info("Saving model parameters for testing.")
await self.workers.save_model_params.call()
@endpoint
async def validate_model_params(self, validate_fn):
"""Validate updated model params using validate_fn."""
        logger.info("Starting model parameter validation.")
return await self.workers.validate_model_params.call(validate_fn)
def _extract_logprobs(self, output) -> torch.Tensor | None:
"""Extract logprobs from vLLM output as a torch.Tensor.
Args:
output: vLLM CompletionOutput with optional logprobs.
Returns:
torch.Tensor of logprobs for each token, or None if not available.
"""
if output.logprobs is not None:
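            # vLLM returns logprobs as one dict per generated token (token id ->
            # Logprob); select each sampled token's own logprob so the result
            # lines up with output.token_ids.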
return torch.tensor(
[
top_k_dict[token].logprob
for token, top_k_dict in zip(output.token_ids, output.logprobs)
]
)
return None
def _to_completions(
self, request_output: RequestOutput, prompt: str
) -> list[Completion]:
"""Convert vLLM RequestOutput to forge Completion objects.
Args:
request_output: vLLM request output with completions.
prompt: Original prompt string.
Returns:
List of Completion objects.
"""
completions = []
for output in request_output.outputs:
completion = Completion(
prompt=to_prompt(prompt),
text=output.text,
prompt_ids=torch.tensor(
request_output.prompt_token_ids
if request_output.prompt_token_ids
else []
),
token_ids=torch.tensor(
output.token_ids if hasattr(output, "token_ids") else []
),
logprobs=self._extract_logprobs(output),
stop_reason=output.finish_reason,
generator_version=self.generator_version,
metadata={"num_cached_tokens": request_output.num_cached_tokens},
)
completions.append(completion)
return completions
@endpoint
async def _reset_prefix_cache(self):
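        # Thin wrapper around AsyncLLM.reset_prefix_cache(); drops the engine's
        # cached prefix KV blocks.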
await self.llm.reset_prefix_cache()
class _WeightFetcher(ForgeActor):
"""Fetches weights from torchstore and loads them into shared memory.
Spawned by Generator using this_host().spawn_procs() to ensure shared
memory IPC namespace is shared with workers. This is critical for POSIX
shared memory to be visible between fetchers and workers.
"""
@endpoint
async def fetch(
self,
*,
version: int,
param_names: list[str],
) -> dict[str, SharedTensorHandle]:
"""Fetch weights from torchstore and load them into shared memory.
Args:
version: Policy version
param_names: List of parameter names to fetch
Returns:
Dict mapping param names to SharedTensorHandle objects
"""
sd = {}
for name in param_names:
param_key = get_param_key(version, name)
param = await ts.get(param_key)
shared_tensor = SharedTensor(tensor=param)
handle = shared_tensor.get_handle()
# Unregister from resource tracker - Generator will handle cleanup via drop()
# Without this, the shared memory gets cleaned up when the fetcher process exits
resource_tracker.unregister(f"/{handle.shm_name}", "shared_memory")
sd[name] = handle
shared_tensor.close() # Close fd but don't unlink (workers will use it)
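            # Workers later attach to this segment (identified by handle.shm_name);
            # the Generator releases it afterwards via handle.drop() in
            # _drop_shared_memory().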
del param # Explicitly free the tensor after copying to shared memory
return sd