Source code for forge.actors.reference_model
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
import os
from collections.abc import Mapping
from dataclasses import dataclass, field, fields
import torch
from forge.controller import ForgeActor
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
from forge.rl.loss import compute_logprobs, create_shifted_targets
from monarch.actor import current_rank, current_size, endpoint
from torch.distributed.tensor import DTensor
from torchtitan.config.job_config import (
    Checkpoint,
    Comm,
    Compile,
    Model,
    Parallelism,
    Training,
)
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

@dataclass
class ReferenceModel(ForgeActor):
    """
    A reference model actor for reinforcement learning (RL) training.

    Based on TorchTitan's engine architecture, this actor provides a frozen model
    that only runs forward passes without gradient computation. It is typically
    used to maintain algorithmic consistency in policy optimization methods such
    as GRPO (Group Relative Policy Optimization) or PPO (Proximal Policy
    Optimization), where it serves as a fixed reference point to compute KL
    divergence penalties against the training policy.

    The reference model is loaded from a checkpoint and runs in evaluation mode
    with inference_mode enabled to optimize memory and compute efficiency.

    Attributes:
        model (Model): Model configuration (architecture, vocab size, etc.)
        parallelism (Parallelism): Parallelism strategy configuration (TP, PP, CP, DP)
        checkpoint (Checkpoint): Checkpoint loading configuration
        compile (Compile): Torch compilation settings
        comm (Comm): Communication backend configuration
        training (Training): Training-related settings (dtype, garbage collection, etc.)
    """
    # Refer to titan JobConfig for enabling more ForgeEngine configuration
    model: Model = field(default_factory=Model)
    parallelism: Parallelism = field(default_factory=Parallelism)
    checkpoint: Checkpoint = field(default_factory=Checkpoint)
    compile: Compile = field(default_factory=Compile)
    comm: Comm = field(default_factory=Comm)
    training: Training = field(
        default_factory=Training
    )  # Needed in order to set attrs like dtype, garbage collection freq, etc.

    # Populated in setup
    # TODO: Commented out since engine_config parsing extracts from class members
    # engine: ForgeEngine | None = None
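    # Example (illustrative): these fields are usually populated from a config
    # mapping before the actor is spawned. The keys below are assumptions and may
    # not match every torchtitan version; only dtype and initial_load_path are
    # referenced elsewhere in this file:
    #
    #     ReferenceModel(
    #         model={"name": "llama3", "flavor": "8B"},
    #         parallelism={"tensor_parallel_degree": 1},
    #         checkpoint={"enable": True, "initial_load_path": "/path/to/ckpt"},
    #         training={"dtype": "bfloat16"},
    #     )
    #
    # __post_init__ converts any dict-like values into the corresponding
    # dataclass instances.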

    def __post_init__(self):
        """Initializes config types and env variables."""
        super().__init__()

        # Instantiate dict fields
        for f in fields(self):
            attr = getattr(self, f.name)
            if isinstance(attr, Mapping):
                setattr(self, f.name, f.type(**attr))
            elif not isinstance(attr, f.type):
                raise TypeError(
                    f"{f.name} should be a {f.type} type or a dict like object"
                )

        self.step = 0
        self.rank = current_rank().rank
        self.size = math.prod(current_size().values())

        self.compute_logprobs = compute_logprobs
        if self.compile.enable:
            self.compute_logprobs = torch.compile(self.compute_logprobs)

        env = {
            "RANK": str(self.rank),
            "LOCAL_RANK": str(self.rank),
            "LOCAL_WORLD_SIZE": str(self.size),
            "GROUP_RANK": str(self.size),
            "GROUP_WORLD_SIZE": str(self.size),
            "ROLE_RANK": str(self.rank),
            "ROLE_WORLD_SIZE": str(self.size),
            "ROLE_NAME": "rank",
            "WORLD_SIZE": str(self.size),
            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        }
        os.environ.update(env)

    @endpoint
    async def setup(self):
        engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
        engine_config = ForgeJobConfig(**engine_config)
        engine_config.checkpoint.folder = (
            ""  # hardcode to empty to force load from initial_load_path
        )
        self.engine = ForgeEngine(engine_config)
        self.engine.checkpointer.load()
        self.model = self.engine.model_parts[0]  # No pipeline parallelism yet
        self.model.eval()

    @endpoint
    async def forward(
        self, input_ids: torch.Tensor, return_logprobs: bool = True
    ) -> torch.Tensor:
        """
        Args:
            input_ids (torch.Tensor): Input token ids with shape [group_size, seq_len].
            return_logprobs (bool): Whether to return log probabilities instead of raw logits.

        The return_logprobs flag significantly impacts the amount of data transferred
        to the caller:
        - When False: returns logits with shape [group_size, seq_len, vocab_size].
          This includes the full vocabulary distribution for each token position.
        - When True: returns log probabilities with shape [group_size, seq_len].
          Prompt positions will have logprobs = 0.
        """
        record_metric("reference_perf/forward/count_forward_passes", 1, Reduce.SUM)
        t = Tracer("reference_perf/forward", timer="gpu", track_memory=True)
        t.start()
        self.engine.gc_handler.run(self.step)
        model_parts = self.engine.model_parts
        parallel_dims = self.engine.parallel_dims
        input_ids = input_ids.to("cuda")
        # optional_context_parallel_ctx = (
        #     dist_utils.create_context_parallel_ctx(
        #         cp_mesh=parallel_dims.world_mesh["cp"],
        #         cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
        #         cp_seq_dims=[1, 1] + [0 for _ in model_parts],
        #         cp_no_restore_buffers={inputs, labels},
        #         cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
        #     )
        #     if parallel_dims.cp_enabled
        #     else None
        # )
        optional_context_parallel_ctx = None
        if self.engine.parallel_dims.pp_enabled:
            raise NotImplementedError("PP not implemented yet")
        else:
            # (jackkhuu) Not sure if either context is needed for inference here
            with self.engine.train_context(optional_context_parallel_ctx):
                with self.engine.maybe_enable_amp:
                    with torch.inference_mode():
                        logits = self.model(input_ids)
                        if return_logprobs:
                            target_ids = create_shifted_targets(input_ids)
                            logprobs, _ = self.compute_logprobs(logits, target_ids)

        out = logprobs if return_logprobs else logits
        if isinstance(out, DTensor):
            out = out.full_tensor()
        self.step += 1
        t.stop()
        return out
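
A minimal consumption sketch (illustrative, not part of the module above): it assumes the actor has already been spawned through Forge's controller APIs, and the endpoint-invocation style (forward.route(...)) and kl_coeff value are assumptions that may differ by version; the tensor math itself relies only on standard PyTorch.

import torch

async def reference_kl_penalty(
    ref_model,                      # a spawned ReferenceModel actor/service (assumed)
    input_ids: torch.Tensor,        # [group_size, seq_len]
    policy_logprobs: torch.Tensor,  # [group_size, seq_len], from the training policy
    kl_coeff: float = 0.1,          # assumed penalty weight
) -> torch.Tensor:
    # Hypothetical endpoint call; the routing method (.route/.call_one/...) depends
    # on how the actor was spawned.
    ref_logprobs = await ref_model.forward.route(input_ids, return_logprobs=True)
    # Simple per-token k1 KL estimate; prompt positions contribute 0 as long as the
    # policy logprobs are masked the same way as the reference logprobs.
    kl = policy_logprobs - ref_logprobs  # [group_size, seq_len]
    return kl_coeff * kl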