Model
The forge.actors.reference_model.ReferenceModel provides a frozen
copy of the policy model used for computing advantages in reinforcement
learning. It performs inference on input sequences and returns logits or
log probabilities for computing KL divergence and other RL metrics.
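To make the link between logits, log probabilities, and a KL penalty concrete, here is a minimal pure-Python sketch. It is not the forge implementation: the helper names (`log_softmax`, `token_logprobs`, `kl_penalty`) are hypothetical, and the estimator shown, exp(d) - d - 1 with d = ref - policy, is just one common per-token choice in GRPO-style losses.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over one position's vocabulary logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def token_logprobs(logits_per_position, token_ids):
    """Select the log-probability of each realized token (hypothetical helper)."""
    return [log_softmax(row)[tok] for row, tok in zip(logits_per_position, token_ids)]

def kl_penalty(policy_logprobs, ref_logprobs):
    """Per-token estimate exp(d) - d - 1, where d = ref - policy; always >= 0."""
    penalties = []
    for lp, lr in zip(policy_logprobs, ref_logprobs):
        d = lr - lp
        penalties.append(math.exp(d) - d - 1.0)
    return penalties

# Toy data: 2 positions, vocabulary of 3 tokens (values are illustrative only).
ref_logits = [[2.0, 0.5, -1.0], [0.0, 1.0, 0.0]]
policy_logits = [[2.5, 0.0, -1.0], [0.0, 1.0, 0.0]]
tokens = [0, 1]

ref_lp = token_logprobs(ref_logits, tokens)        # from the frozen reference
policy_lp = token_logprobs(policy_logits, tokens)  # from the training policy
penalty = kl_penalty(policy_lp, ref_lp)
```

In this toy case the second position has identical logits under both models, so its penalty is zero; the first position differs, so its penalty is positive.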
ReferenceModel
- class forge.actors.reference_model.ReferenceModel(model=<factory>, parallelism=<factory>, checkpoint=<factory>, compile=<factory>, comm=<factory>, training=<factory>)
Bases: ForgeActor
A reference model actor for reinforcement learning (RL) training.
Based on TorchTitan’s engine architecture, this actor provides a frozen model that only runs forward passes without gradient computation. It is typically used to maintain algorithmic consistency in policy optimization methods such as GRPO (Group Relative Policy Optimization) or PPO (Proximal Policy Optimization), where it serves as a fixed reference point to compute KL divergence penalties against the training policy.
The reference model is loaded from a checkpoint and runs in evaluation mode with inference_mode enabled to optimize memory and compute efficiency.
- Variables:
model (Model) – Model configuration (architecture, vocab size, etc.)
parallelism (Parallelism) – Parallelism strategy configuration (TP, PP, CP, DP)
checkpoint (Checkpoint) – Checkpoint loading configuration
compile (Compile) – Torch compilation settings
comm (Comm) – Communication backend configuration
training (Training) – Training-related settings (dtype, garbage collection, etc.)
- checkpoint
- comm
- compile
- forward
- model
- parallelism
- setup
- training
The ReferenceModel uses a subset of TorchTitan’s configuration system:
- model: Model architecture settings (Model dataclass)
- parallelism: Parallelism configuration for distributed inference (Parallelism dataclass)
- checkpoint: Checkpoint loading settings (Checkpoint dataclass)
- compile: Model compilation settings (Compile dataclass)
- training: Training configuration for dtype and other settings (Training dataclass)
For detailed configuration options, refer to the TorchTitan documentation.
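The `<factory>` defaults in the class signature suggest dataclass-style fields that build a fresh config object per instance. The sketch below illustrates that pattern with stand-in dataclasses. They are not the real TorchTitan classes: `ReferenceModelConfig`, the `name` field, and all default values are hypothetical, and only two of the six fields are shown.

```python
from dataclasses import dataclass, field

@dataclass
class Model:
    # Hypothetical stand-in for the Model config dataclass.
    name: str = "example-model"

@dataclass
class Training:
    # Hypothetical stand-in; the source only says this holds dtype-like settings.
    dtype: str = "bfloat16"

@dataclass
class ReferenceModelConfig:
    # default_factory is what Sphinx renders as "<factory>" in a signature:
    # each instance gets its own freshly constructed sub-config.
    model: Model = field(default_factory=Model)
    training: Training = field(default_factory=Training)

default_cfg = ReferenceModelConfig()
custom_cfg = ReferenceModelConfig(training=Training(dtype="float32"))

# Mutating one instance's sub-config does not leak into another instance.
default_cfg.model.name = "changed"
```

Because each field is built by a factory, overriding one sub-config (here `training`) leaves the others at their defaults, and instances never share mutable config objects.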