Source code for forge.actors.vllm.v1.generator

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import asyncio
import base64
import logging
import os
import time
import uuid
from collections.abc import Mapping
from dataclasses import dataclass, field
from multiprocessing import resource_tracker
from typing import Any, Optional

import cloudpickle
import torch
import torchstore as ts
from forge.actors._torchstore_utils import (
    extract_param_name,
    get_param_key,
    get_param_prefix,
)
from forge.controller import ForgeActor
from forge.controller.provisioner import _get_provisioner
from forge.data_models.completion import Completion
from forge.data_models.prompt import to_prompt
from forge.env import FORGE_DISABLE_METRICS
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
from forge.util._shared_tensor import SharedTensor, SharedTensorHandle
from monarch.actor import endpoint, ProcMesh, this_host
from torchstore.api import _controller as get_torchstore_controller
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.llm import UsageContext
from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Suppress noisy vLLM "Added request" logs
logging.getLogger("vllm.v1.engine.async_llm").setLevel(logging.WARNING)


@dataclass
class Generator(ForgeActor):
    """vLLM-based generator using AsyncLLM with Monarch distributed execution.

    Wraps vLLM's AsyncLLM engine and uses MonarchExecutor for multi-GPU inference.
    See MonarchExecutor docstring for architecture diagram.

    Args:
        engine_args: vLLM EngineArgs for model configuration. Can be EngineArgs or dict.
        sampling_params: Default SamplingParams for generation. Can be SamplingParams or dict.
        prefetch_weights_to_shm: Whether to prefetch weights to shared memory for faster
            weight updates. When enabled, weight fetchers download weights in parallel to
            shared memory while generation is still running. Defaults to True.
        n_fetcher_procs: Number of fetcher processes for parallel weight downloading.
            Only used when prefetch_weights_to_shm is True. Defaults to 8.

    Example:
        >>> generator = await Generator.options(procs=1, with_gpus=True).as_service(
        ...     engine_args={"model": "meta-llama/Llama-3-8B", "tensor_parallel_size": 2},
        ...     sampling_params={"max_tokens": 128, "temperature": 0.7},
        ... )
        >>> completions = await generator.generate("Tell me a joke")
        >>> await generator.shutdown()
    """

    engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
    sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
    prefetch_weights_to_shm: bool = True
    n_fetcher_procs: int = 8

    def __post_init__(self):
        super().__init__()
        self.llm: Optional[AsyncLLM] = None
        self.generator_version: int = 0
        self.workers: Any = None  # Workers ActorMesh, registered by MonarchExecutor
        self.fetcher_procs: Optional[ProcMesh] = None  # Fetcher proc mesh
        self.weight_fetchers: Any = None  # Weight fetcher ActorMesh

        if isinstance(self.engine_args, Mapping):
            self.engine_args = EngineArgs(**self.engine_args)
        self.vllm_config = self.engine_args.create_engine_config(UsageContext.LLM_CLASS)

        if isinstance(self.sampling_params, Mapping):
            self.sampling_params = SamplingParams.from_optional(**self.sampling_params)
        self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY

    @classmethod
    async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
        cls: type["Generator"],
        *args,
        **kwargs,
    ) -> "Generator":
        """Custom launch for Generator with proper GPU host provisioning.

        Flow:
        1. Get host mesh directly (via get_host_mesh for remote, this_host for local)
        2. Spawn CPU proc on head host for Generator and WorkerRegistry
        3. Allocate GPUs from provisioner
        4. Pass host_mesh and GPU IDs to setup() - executor creates proc_mesh
        """
        engine_args = kwargs.get("engine_args", {})
        if isinstance(engine_args, Mapping):
            engine_args = EngineArgs(**engine_args)
        vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
        num_gpus = vllm_config.parallel_config.world_size
        num_hosts = cls.hosts if cls.hosts else None
        gpus_per_host = num_gpus // (num_hosts or 1)
        mesh_name = cls.mesh_name or "generator"

        # Step 1: Get host mesh directly (no bootstrap procs)
        provisioner = await _get_provisioner()
        if provisioner is None:
            raise RuntimeError(
                "Provisioner not initialized. Call init_provisioner() first."
            )
        if num_hosts:
            # Remote allocation - get pre-allocated mesh from launcher
            host_mesh = await provisioner.get_host_mesh(mesh_name)
        else:
            # Local allocation
            host_mesh = this_host()

        # Step 1a: Mount wsfuse on all hosts for local model paths
        # This must happen before any actors are spawned that might access the path
        if num_hosts and provisioner.launcher:
            # Spawn temporary procs on all hosts just for mounting
            mount_procs = host_mesh.spawn_procs(
                per_host={"procs": 1}, name="mount_setup"
            )
            await provisioner.launcher.remote_setup(mount_procs)
            logger.info("Completed remote_setup (mounted wsfuse on all hosts)")

        # Step 1b: Allocate GPUs from provisioner
        # This ensures each Generator gets exclusive GPU IDs
        gpu_ids = await provisioner.allocate_gpu_ids(host_mesh, num_gpus)
        logger.info(f"[Generator.launch] Allocated GPUs: {gpu_ids}")
        logger.info(
            f"[Generator.launch] Provisioned host mesh with {num_hosts or 1} host(s), "
            f"{gpus_per_host} GPUs per host"
        )

        # Step 2: Spawn CPU proc on head host for Generator and WorkerRegistry
        if num_hosts:
            singleton_slice = {k: slice(0, 1) for k in host_mesh.extent.keys()}
            head_host = host_mesh.slice(**singleton_slice)
        else:
            head_host = host_mesh
        generator_proc = head_host.spawn_procs(
            per_host={"procs": 1}, name="generator_proc"
        )
        logger.info("[Generator.launch] Spawned generator_proc on head host")

        # Register LocalFetcherActor for generator_proc to enable metrics collection
        if not FORGE_DISABLE_METRICS.get_value():
            await get_or_create_metric_logger(generator_proc, process_name=mesh_name)

        # Import WorkerRegistry here to avoid circular import with monarch_executor
        from forge.actors.vllm.v1.monarch_executor import WorkerRegistry

        # Spawn WorkerRegistry on CPU proc (same as Generator)
        worker_registry = generator_proc.spawn("worker_registry", WorkerRegistry)
        logger.info("[Generator.launch] Spawned WorkerRegistry on generator_proc")

        actor_name = kwargs.pop("name", cls.__name__)
        generator = generator_proc.spawn(
            actor_name,
            cls,
            *args,
            **kwargs,
        )

        # Attach for cleanup in Generator.shutdown()
        generator._generator_proc = generator_proc
        generator._worker_registry = worker_registry

        # Step 3: Pass host_mesh and gpu_ids to setup() - executor will create proc_mesh
        await generator.setup.call(host_mesh, worker_registry, gpu_ids)

        return generator

    @endpoint
    async def setup(self, host_mesh, worker_registry, gpu_ids: list[str]):
        """Initialize AsyncLLM with MonarchExecutor.

        Receives a host_mesh from launch(). Serializes it for the EngineCore subprocess.
        MonarchExecutor creates its own proc_mesh from host_mesh, spawns workers, and
        registers them. After AsyncLLM initialization, Generator queries the registry
        for workers.

        Args:
            host_mesh: HostMesh for GPU workers (executor will create proc_mesh from this)
            worker_registry: WorkerRegistry ActorMesh for worker registration
            gpu_ids: List of allocated GPU IDs (e.g., ["0", "1"])
        """
        num_gpus = self.vllm_config.parallel_config.tensor_parallel_size
        logger.info(f"Setting up AsyncLLM with {num_gpus} GPUs, allocated: {gpu_ids}")

        # Set env var for MonarchExecutor subprocess (EngineCore)
        os.environ["VLLM_MONARCH_GPU_IDS"] = ",".join(gpu_ids)

        # Serialize host_mesh reference
        serialized_host_mesh = base64.b64encode(cloudpickle.dumps(host_mesh)).decode(
            "utf-8"
        )
        os.environ["VLLM_MONARCH_HOST_MESH"] = serialized_host_mesh

        # Serialize WorkerRegistry reference
        serialized_registry = base64.b64encode(
            cloudpickle.dumps(worker_registry)
        ).decode("utf-8")
        os.environ["VLLM_MONARCH_WORKER_REGISTRY"] = serialized_registry

        # Serialize TorchStore Controller reference for workers to access torchstore
        torchstore_controller = await get_torchstore_controller()
        serialized_controller = base64.b64encode(
            cloudpickle.dumps(torchstore_controller)
        ).decode("utf-8")
        os.environ["VLLM_TORCHSTORE_CONTROLLER"] = serialized_controller

        # Force 'spawn' multiprocessing method for Monarch actors.
        # This follows vLLM's Ray integration pattern.
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        # Set the executor backend to ForgeMonarchExecutor via string path
        # This avoids import deadlock when vLLM spawns EngineCore subprocess
        self.vllm_config.parallel_config.distributed_executor_backend = (
            "forge.actors.vllm.v1.forge_executor.ForgeMonarchExecutor"
        )

        # Set up prefetching configuration via additional_config
        # There does not seem to be a real difference between passing this via env var
        # or via self.vllm_config
        if self.prefetch_weights_to_shm:
            self.vllm_config.additional_config["n_fetcher_procs"] = self.n_fetcher_procs
        else:
            self.vllm_config.additional_config["n_fetcher_procs"] = 0

        from vllm.v1.executor.abstract import Executor

        try:
            self.llm = AsyncLLM(
                vllm_config=self.vllm_config,
                # this resolves to the MonarchExecutor class
                executor_class=Executor.get_class(self.vllm_config),
                log_stats=True,
            )
            logger.info(f"AsyncLLM initialized with {num_gpus} workers")
        except Exception as e:
            logger.error(f"AsyncLLM initialization failed: {e}")
            raise

        # Query the WorkerRegistry for workers that were registered by MonarchExecutor
        # during _init_executor()
        self.workers = await worker_registry.get_workers.call_one()
        if self.workers is None:
            raise RuntimeError(
                "Workers not found in registry. "
                "MonarchExecutor may have failed to register workers."
            )
        logger.info(f"Retrieved workers from registry: {self.workers}")

        if self.prefetch_weights_to_shm:
            self._spawn_fetchers()

    def _spawn_fetchers(self):
        """Spawn weight fetchers that prefetch weights from torchstore to shared memory.

        This assumes the generator is on the same host as the worker and only works
        for single-host generators.
        """
        fetcher_procs = this_host().spawn_procs(
            per_host={"procs": self.n_fetcher_procs}
        )
        self.fetcher_procs = fetcher_procs
        self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher)
        logger.info(
            f"[Generator] Spawned {self.n_fetcher_procs} weight fetchers: {self.weight_fetchers}"
        )

    async def _fetch_weights(
        self,
        version: int,
    ) -> dict[str, SharedTensorHandle]:
        """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.

        Coordinates parallel weight fetching across fetcher processes using round-robin
        distribution of parameters.

        Args:
            version: Policy version to fetch

        Returns:
            Dict mapping param names to SharedTensorHandle objects
        """
        prefix = get_param_prefix(version)
        matching_keys = await ts.keys(prefix)
        hf_param_names = [extract_param_name(key) for key in matching_keys]

        n_fetchers = self.weight_fetchers.size()

        def split_keys(keys):
            return [keys[i::n_fetchers] for i in range(n_fetchers)]

        futures = []
        for i, names in enumerate(split_keys(hf_param_names)):
            fut = self.weight_fetchers.slice(procs=i).fetch.call_one(
                version=version, param_names=names
            )
            futures.append(fut)

        sub_state_dicts = [await fut for fut in futures]
        state_dict = {}
        for sd in sub_state_dicts:
            state_dict.update(sd)
        logger.info(
            f"[Generator] Fetched {len(state_dict)} weights for v{version} to shared memory"
        )
        return state_dict

    async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]):
        for handle in state_dict.values():
            handle.drop()

    @endpoint
    async def generate(
        self,
        prompt: str,
        *,
        priority: int = 0,
        sampling_params: SamplingParams | None = None,
    ) -> list[Completion]:
        """Generate a response for the given prompt.

        Args:
            prompt (str): The prompt to generate a response for.
            priority (int, optional): The priority of the request. Defaults to 0.
            sampling_params (SamplingParams, optional): Sampling parameters to use for
                this request. If not provided, uses self.sampling_params.

        Returns:
            list[Completion]: n completions from vLLM based on the prompt.
        """
        t = Tracer("generator_perf/generate", timer="gpu")
        t.start()

        record_metric("generator/generate/count_requests", 1, Reduce.SUM)

        if self.llm is None:
            raise RuntimeError("Generator not initialized. Call setup() first.")

        params = sampling_params or self.sampling_params
        if params.output_kind is None:
            params.output_kind = RequestOutputKind.FINAL_ONLY

        # Use AsyncLLM's generate() method - it returns an async generator
        # that yields RequestOutput objects. We only want the final output.
        request_output = None
        async for output in self.llm.generate(
            prompt=prompt,
            sampling_params=params,
            request_id=str(uuid.uuid4()),
        ):
            request_output = output  # Keep last output (final one)

        completions = self._to_completions(request_output, prompt)
        record_metric(
            "generator/generate/count_sequences_completed",
            len(completions),
            Reduce.SUM,
        )
        t.stop()
        return completions

    @endpoint
    async def stop(self):
        """Stop the generator and clean up local resources.

        This method is idempotent and can be called multiple times safely.
        Note: Remote worker cleanup happens in shutdown() which has access to the proxy.
        """
        if self.fetcher_procs is not None:
            try:
                await self.fetcher_procs.stop()
                logger.info("Stopped fetcher procs")
            except Exception as e:
                logger.warning(f"Error stopping fetcher procs: {e}")
            self.fetcher_procs = None

        if self.llm is not None:
            logger.info("Stopping AsyncLLM")
            self.llm.shutdown()
            logger.info("AsyncLLM.shutdown() returned")
            self.llm = None

        logger.info("stop() complete")

    @classmethod
    async def shutdown(cls, actor):
        """Shut down the generator and clean up all resources.

        Cleanup order:
        1. Stop AsyncLLM (triggers MonarchExecutor.shutdown() which destroys process
           groups and stops proc_mesh)
        2. Stop generator_proc
        """
        try:
            await actor.stop.call()
        except Exception as e:
            logger.warning(f"Error during actor.stop: {e}")

        try:
            if getattr(actor, "_generator_proc", None):
                await actor._generator_proc.stop()
                logger.info("Stopped generator proc")
        except Exception as e:
            logger.warning(f"Error during generator_proc stop: {e}")

        logger.info("shutdown() complete")

    @endpoint
    async def update_weights(
        self,
        version: Optional[int] = None,
    ) -> None:
        """Update weights on the generator from torchstore.

        This method:
        1. Optionally starts prefetching weights to shared memory (overlaps with pause)
        2. Pauses generation and waits for in-flight requests to complete
        3. Updates weights on workers (from shared memory if prefetched, else from torchstore)
        4. Resumes generation

        When prefetch_weights_to_shm is enabled, weight fetching is started as an async
        task BEFORE pause_generation(), overlapping I/O with in-flight request completion.
        This reduces the time generation is paused.

        Note: This is NOT the standard vLLM weight update approach. vLLM typically uses
        `collective_rpc` on EngineClient, which internally routes calls to workers via
        the executor. However, `collective_rpc` uses msgspec/msgpack serialization which
        does not support arbitrary Python objects by default (only with
        VLLM_ALLOW_INSECURE_SERIALIZATION=1). This makes it difficult to pass complex
        objects like torchstore storage handles.

        Instead, we use a monarch-native approach where the Generator actor directly
        calls the worker mesh (`self.workers.update_weights`) via Monarch RPC, which
        uses cloudpickle and natively supports Monarch actor references for torchstore
        integration.

        Args:
            version: Policy version to load from torchstore
        """
        if self.llm is None:
            raise RuntimeError("Generator not initialized. Call setup() first.")

        logger.info(f"Starting weight update to v{version}")

        # Start prefetching weights to shared memory (overlaps with pause)
        fetch_task = None
        if self.prefetch_weights_to_shm:
            logger.info(f"[Generator] Starting prefetch for v{version}")
            fetch_task = asyncio.create_task(self._fetch_weights(version or 0))

        pause_start = time.perf_counter()
        await self.llm.pause_generation(
            wait_for_inflight_requests=True, clear_cache=True
        )
        pause_duration = time.perf_counter() - pause_start
        record_metric(
            "generator_perf/update_weights/pause_generation_duration_s",
            pause_duration,
            Reduce.MEAN,
        )

        try:
            load_start = time.perf_counter()
            if fetch_task is not None:
                wait_fetch_start = time.perf_counter()
                fetched_weights = await fetch_task
                wait_fetch_duration = time.perf_counter() - wait_fetch_start
                record_metric(
                    "generator_perf/update_weights/wait_fetch_weights_s",
                    wait_fetch_duration,
                    Reduce.MEAN,
                )
                await self.workers.apply_prefetched_weights.call(fetched_weights)
                await self._drop_shared_memory(fetched_weights)
            else:
                # Direct fetch from torchstore
                await self.workers.update_weights.call(version=version)
            load_duration = time.perf_counter() - load_start
            record_metric(
                "generator_perf/update_weights/worker_load_weights_duration_s",
                load_duration,
                Reduce.MEAN,
            )
            self.generator_version = version
        finally:
            await self.llm.resume_generation()

        logger.info(f"Weight update complete, now v{version}")

    @endpoint
    async def save_model_params(self):
        """Save model parameters before weight update, used for testing purposes only."""
        logger.info("save model parameters for testing.")
        await self.workers.save_model_params.call()

    @endpoint
    async def validate_model_params(self, validate_fn):
        """Validate updated model params using validate_fn."""
        logger.info("start validating model parameters.")
        return await self.workers.validate_model_params.call(validate_fn)

    def _extract_logprobs(self, output) -> torch.Tensor | None:
        """Extract logprobs from vLLM output as a torch.Tensor.

        Args:
            output: vLLM CompletionOutput with optional logprobs.

        Returns:
            torch.Tensor of logprobs for each token, or None if not available.
""" if output.logprobs is not None: return torch.tensor( [ top_k_dict[token].logprob for token, top_k_dict in zip(output.token_ids, output.logprobs) ] ) return None def _to_completions( self, request_output: RequestOutput, prompt: str ) -> list[Completion]: """Convert vLLM RequestOutput to forge Completion objects. Args: request_output: vLLM request output with completions. prompt: Original prompt string. Returns: List of Completion objects. """ completions = [] for output in request_output.outputs: completion = Completion( prompt=to_prompt(prompt), text=output.text, prompt_ids=torch.tensor( request_output.prompt_token_ids if request_output.prompt_token_ids else [] ), token_ids=torch.tensor( output.token_ids if hasattr(output, "token_ids") else [] ), logprobs=self._extract_logprobs(output), stop_reason=output.finish_reason, generator_version=self.generator_version, metadata={"num_cached_tokens": request_output.num_cached_tokens}, ) completions.append(completion) return completions @endpoint async def _reset_prefix_cache(self): await self.llm.reset_prefix_cache()
class _WeightFetcher(ForgeActor):
    """Fetches weights from torchstore and loads them into shared memory.

    Spawned by Generator using this_host().spawn_procs() to ensure shared memory IPC
    namespace is shared with workers. This is critical for POSIX shared memory to be
    visible between fetchers and workers.
    """

    @endpoint
    async def fetch(
        self,
        *,
        version: int,
        param_names: list[str],
    ) -> dict[str, SharedTensorHandle]:
        """Fetch weights from torchstore and load them into shared memory.

        Args:
            version: Policy version
            param_names: List of parameter names to fetch

        Returns:
            Dict mapping param names to SharedTensorHandle objects
        """
        sd = {}
        for name in param_names:
            param_key = get_param_key(version, name)
            param = await ts.get(param_key)
            shared_tensor = SharedTensor(tensor=param)
            handle = shared_tensor.get_handle()
            # Unregister from resource tracker - Generator will handle cleanup via drop()
            # Without this, the shared memory gets cleaned up when the fetcher process exits
            resource_tracker.unregister(f"/{handle.shm_name}", "shared_memory")
            sd[name] = handle
            shared_tensor.close()  # Close fd but don't unlink (workers will use it)
            del param  # Explicitly free the tensor after copying to shared memory
        return sd
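
# Illustrative sketch, not part of the original module: a minimal view of the
# shared-memory handoff _WeightFetcher relies on, using only the SharedTensor and
# SharedTensorHandle calls that appear in this file. The fetcher copies a tensor into
# POSIX shared memory, hands out a picklable handle, and closes its own fd; the
# Generator later calls handle.drop() once workers have consumed the weights. The
# worker-side mapping step is elided because its API is not shown here.
def _example_shared_tensor_lifecycle(param: torch.Tensor) -> SharedTensorHandle:
    shared = SharedTensor(tensor=param)  # copy the tensor into a shared-memory segment
    handle = shared.get_handle()  # lightweight reference that can cross process boundaries
    shared.close()  # close this process's fd; the segment itself stays alive
    return handle  # the final owner is responsible for calling handle.drop()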