Rate this Page

Source code for forge.actors.reference_model

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import logging
import math
import os
from collections.abc import Mapping
from dataclasses import dataclass, field, fields

import torch
from forge.controller import ForgeActor
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
from forge.rl.loss import compute_logprobs, create_shifted_targets
from monarch.actor import current_rank, current_size, endpoint
from torch.distributed.tensor import DTensor
from torchtitan.config.job_config import (
    Checkpoint,
    Comm,
    Compile,
    Model,
    Parallelism,
    Training,
)
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


[docs] @dataclass class ReferenceModel(ForgeActor): """ A reference model actor for reinforcement learning (RL) training. Based on TorchTitan's engine architecture, this actor provides a frozen model that only runs forward passes without gradient computation. It is typically used to maintain algorithmic consistency in policy optimization methods such as GRPO (Group Relative Policy Optimization) or PPO (Proximal Policy Optimization), where it serves as a fixed reference point to compute KL divergence penalties against the training policy. The reference model is loaded from a checkpoint and runs in evaluation mode with inference_mode enabled to optimize memory and compute efficiency. Attributes: model (Model): Model configuration (architecture, vocab size, etc.) parallelism (Parallelism): Parallelism strategy configuration (TP, PP, CP, DP) checkpoint (Checkpoint): Checkpoint loading configuration compile (Compile): Torch compilation settings comm (Comm): Communication backend configuration training (Training): Training-related settings (dtype, garbage collection, etc.) """ # Refer to titan JobConfig for enabling more ForgeEngine configuration model: Model = field(default_factory=Model) parallelism: Parallelism = field(default_factory=Parallelism) checkpoint: Checkpoint = field(default_factory=Checkpoint) compile: Compile = field(default_factory=Compile) comm: Comm = field(default_factory=Comm) training: Training = field( default_factory=Training ) # Needed in order to set attrs like dtype, garbage collection freq, etc. # Populated in setup # TODO: Commented out since engine_config parsing extracts from class members # engine: ForgeEngine | None = None def __post_init__(self): """Initializes config types and env variables.""" super().__init__() # Instantiate dict fields for f in fields(self): attr = getattr(self, f.name) if isinstance(attr, Mapping): setattr(self, f.name, f.type(**attr)) elif not isinstance(attr, f.type): raise TypeError( f"{f.name} should be a {f.type} type or a dict like object" ) self.step = 0 self.rank = current_rank().rank self.size = math.prod(current_size().values()) self.compute_logprobs = compute_logprobs if self.compile.enable: self.compute_logprobs = torch.compile(self.compute_logprobs) env = { "RANK": str(self.rank), "LOCAL_RANK": str(self.rank), "LOCAL_WORLD_SIZE": str(self.size), "GROUP_RANK": str(self.size), "GROUP_WORLD_SIZE": str(self.size), "ROLE_RANK": str(self.rank), "ROLE_WORLD_SIZE": str(self.size), "ROLE_NAME": "rank", "WORLD_SIZE": str(self.size), "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", } os.environ.update(env) @endpoint async def setup(self): engine_config = {f.name: getattr(self, f.name) for f in fields(self)} engine_config = ForgeJobConfig(**engine_config) engine_config.checkpoint.folder = ( "" # hardcode to empty to force load from initial_load_path ) self.engine = ForgeEngine(engine_config) self.engine.checkpointer.load() self.model = self.engine.model_parts[0] # No pipeline parallelism yet self.model.eval() @endpoint async def forward( self, input_ids: torch.Tensor, return_logprobs: bool = True ) -> torch.Tensor: """ Args: input_ids (torch.Tensor): input token ids with shape [group_size, seq_len]. return_logprobs (bool): whether to return log probabilities instead of raw logits. return_logprobs flag significantly impacts the amount of data transferred to the caller: - When False: Returns logits with shape [group_size, seq_len, vocab_size]. This includes the full vocabulary distribution for each token position. - When True: Returns log probabilities with shape [group_size, seq_len]. Prompt positions will have logprobs = 0. """ record_metric("reference_perf/forward/count_forward_passes", 1, Reduce.SUM) t = Tracer("reference_perf/forward", timer="gpu", track_memory=True) t.start() self.engine.gc_handler.run(self.step) model_parts = self.engine.model_parts parallel_dims = self.engine.parallel_dims input_ids = input_ids.to("cuda") # optional_context_parallel_ctx = ( # dist_utils.create_context_parallel_ctx( # cp_mesh=parallel_dims.world_mesh["cp"], # cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], # cp_seq_dims=[1, 1] + [0 for _ in model_parts], # cp_no_restore_buffers={inputs, labels}, # cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, # ) # if parallel_dims.cp_enabled # else None # ) optional_context_parallel_ctx = None if self.engine.parallel_dims.pp_enabled: raise NotImplementedError("PP not implemented yet") else: # (jackkhuu) Not sure if either context are needed for inference here with self.engine.train_context(optional_context_parallel_ctx): with self.engine.maybe_enable_amp: with torch.inference_mode(): logits = self.model(input_ids) if return_logprobs: target_ids = create_shifted_targets(input_ids) logprobs, _ = self.compute_logprobs(logits, target_ids) out = logprobs if return_logprobs else logits if isinstance(out, DTensor): out = out.full_tensor() self.step += 1 t.stop() return out