Trainer#
The Trainer manages model training in TorchForge, built on top of TorchTitan. It handles forward/backward passes, weight updates, and checkpoint management for reinforcement learning workflows.
TitanTrainer#
- class forge.actors.trainer.TitanTrainer(job=<factory>, model=<factory>, optimizer=<factory>, lr_scheduler=<factory>, training=<factory>, parallelism=<factory>, checkpoint=<factory>, activation_checkpoint=<factory>, compile=<factory>, quantize=<factory>, comm=<factory>, memory_estimation=<factory>, loss=<function TitanTrainer.<lambda>>, state_dict_key='model_state_dict', use_dcp=True, dcp_path='forge_dcp_tmp')[source]#
A generic trainer actor implementation built on top of TorchTitan.
This actor provides a complete reinforcement learning training loop on top of TorchTitan’s training engine: forward and backward passes with gradient computation, optimizer steps, and checkpoint management. Unlike the ReferenceModel actor, which only runs forward passes, TitanTrainer actively updates the policy model parameters through gradient descent.
The trainer supports the same distributed training strategies as TorchTitan, including, but not limited to, tensor parallelism, data parallelism, and FSDP (Fully Sharded Data Parallel). It is typically used in conjunction with ReferenceModel for policy optimization algorithms such as GRPO (Group Relative Policy Optimization), where it optimizes the policy against a loss that includes KL divergence penalties computed from the reference model.
The trainer handles:
- Forward and backward propagation with automatic mixed precision (AMP)
- Optimizer steps with learning rate scheduling
- activation_checkpoint#
- checkpoint#
- cleanup#
- comm#
- compile#
- dcp_path = 'forge_dcp_tmp'#
- job#
- loss(**targets)#
- lr_scheduler#
- memory_estimation#
- model#
- optimizer#
- parallelism#
- push_weights#
- quantize#
- setup#
- state_dict_key = 'model_state_dict'#
- train_step#
- training#
- use_dcp = True#
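The fields above mirror TorchTitan’s job-config sections documented below. As a rough, non-authoritative sketch (assuming the fields accept the TorchTitan dataclasses directly; the exact construction and actor-spawning API in a TorchForge app may differ by version), a trainer configuration could look like this:

```python
# Hypothetical configuration sketch for TitanTrainer. Field names follow the
# dataclass signature above; values are arbitrary, not recommendations.
from torchtitan.config.job_config import (
    Checkpoint, Job, Model, Optimizer, Parallelism, Training,
)

trainer_kwargs = dict(
    job=Job(dump_folder="./outputs/grpo_run"),
    model=Model(name="llama3", flavor="debugmodel"),
    optimizer=Optimizer(name="AdamW", lr=8e-4),
    training=Training(local_batch_size=8, seq_len=2048, steps=10_000),
    parallelism=Parallelism(data_parallel_shard_degree=-1),
    checkpoint=Checkpoint(enable=True, interval=500),
    use_dcp=True,                        # documented default
    state_dict_key="model_state_dict",   # documented default
)
# A TitanTrainer built from these kwargs would then be driven through its
# setup / train_step / push_weights / cleanup endpoints (signatures not shown here).
```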
Configuration#
The TitanTrainer uses TorchTitan’s configuration system with the following components:
Job Configuration#
- class torchtitan.config.job_config.Job(config_file=None, dump_folder='./torchtitan/outputs', description='default job', print_config=False, custom_config_module='')[source]#
- config_file = None#
Job config file
- custom_config_module = ''#
This option allows users to extend the existing JobConfig with a customized JobConfig dataclass. Users need to ensure that the path can be imported.
- description = 'default job'#
Description of the job
- dump_folder = './torchtitan/outputs'#
Folder to dump job outputs
- print_config = False#
Print the configs to terminal
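For illustration only (values are arbitrary), the Job section can be constructed directly as a dataclass, overriding only the fields you care about:

```python
from torchtitan.config.job_config import Job

# Minimal sketch: unspecified fields keep the documented defaults.
job = Job(
    dump_folder="./outputs/my_run",      # where job outputs land
    description="GRPO policy training",
    print_config=True,                   # echo the resolved config at startup
)
```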
Model Configuration#
- class torchtitan.config.job_config.Model(name='llama3', flavor='debugmodel', hf_assets_path='./tests/assets/tokenizer', tokenizer_path=None, converters=<factory>, print_after_conversion=False)[source]#
- converters#
Comma separated list of converters to apply to the model. For instance, the float8 converter swaps torch.nn.Linear with Float8Linear. This feature requires you to install ‘torchao’ which can be found here: pytorch/ao
- flavor = 'debugmodel'#
Which model config to train
- hf_assets_path = './tests/assets/tokenizer'#
Path to HF assets folder. This folder contains local copies of Hugging Face assets, including model weights in .safetensors format, the model.safetensors.index.json file (fqn to file mapping), the config.json file, generation_config.json, and tokenizer files.
- name = 'llama3'#
Which model to train
- print_after_conversion = False#
If true, model definition will be printed to stdout after all model converters have been applied.
- tokenizer_path = None#
DEPRECATED: Use hf_assets_path instead.
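As a quick sketch (the assets path below is hypothetical), the Model section selects the architecture, its size flavor, and where the local Hugging Face assets live:

```python
from torchtitan.config.job_config import Model

model = Model(
    name="llama3",
    flavor="debugmodel",
    hf_assets_path="./assets/llama3",  # hypothetical local folder with weights,
                                       # config.json, and tokenizer files
)
```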
Optimizer Configuration#
- class torchtitan.config.job_config.Optimizer(name='AdamW', lr=0.0008, beta1=0.9, beta2=0.95, eps=1e-08, weight_decay=0.1, implementation='fused', early_step_in_backward=False)[source]#
- beta1 = 0.9#
- beta2 = 0.95#
Exponential moving average hyperparameters to use
- early_step_in_backward = False#
Whether to apply the optimizer step in the backward pass. Caution: optimizer-in-backward is not compatible with gradient clipping, and users should not call register_post_accumulate_grad_hook after the optimizer is built.
- eps = 1e-08#
Epsilon value to use
- implementation = 'fused'#
Specify which optimizer implementation to use:
- ‘fused’: use the fused implementation (CUDA only) for best performance.
- ‘foreach’: use horizontal fusion of tensors for better performance.
- ‘for-loop’: use the default implementation for the optimizer (slowest).
More info: https://pytorch.org/docs/stable/optim.html
- lr = 0.0008#
Learning rate to use
- name = 'AdamW'#
Optimizer to use
- weight_decay = 0.1#
Weight decay to use
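For illustration, the Optimizer section constructed with the documented AdamW defaults (note that ‘fused’ requires CUDA; otherwise fall back to ‘foreach’ or ‘for-loop’):

```python
from torchtitan.config.job_config import Optimizer

optimizer = Optimizer(
    name="AdamW",
    lr=8e-4,
    beta1=0.9,
    beta2=0.95,
    weight_decay=0.1,
    implementation="fused",  # CUDA-only fast path
)
```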
Training Configuration#
- class torchtitan.config.job_config.Training(dataset='c4_test', dataset_path=None, local_batch_size=8, global_batch_size=-1, seq_len=2048, max_norm=1.0, steps=10000, enable_cpu_offload=False, dtype='float32', mixed_precision_param='bfloat16', mixed_precision_reduce='float32', gc_freq=50, gc_debug=False, seed=None, deterministic=False, debug_moe_force_load_balance=False)[source]#
- dataset = 'c4_test'#
Dataset to use
- dataset_path = None#
Path to the dataset in the file system. If provided, data will be loaded from this path instead of downloaded.
- debug_moe_force_load_balance = False#
If True, we force each expert to get the same number of tokens via round-robin. This option is intended for debugging only.
- deterministic = False#
Use deterministic algorithms wherever possible; this may be slower.
- dtype = 'float32'#
torch dtype for training. In contrast to mixed precision training, setting training_dtype=bfloat16 will put all parameters, gradients, and optimizer states in bfloat16, without an extra copy of fp32 weights. In the case of full bf16 training, RoPE calculations and logits will still be in fp32.
- enable_cpu_offload = False#
Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP
- gc_debug = False#
Enable GC debugging mode. This will perform gc.collect() at every step to detect if there is a reference cycle that includes a CUDA Tensor. Note that you may want to lower the training steps to avoid generating too many temporary files.
- gc_freq = 50#
Python garbage control scheduling interval, in steps
- global_batch_size = -1#
Global batch size (defaults to training.local_batch_size * data-parallel degree)
- local_batch_size = 8#
Local batch size (i.e., per-device batch size)
- max_norm = 1.0#
Max norm for gradient clipping
- mixed_precision_param = 'bfloat16'#
torch dtype to use for parameters when applying mixed precision via fully_shard or torch.autocast. This feature takes effect via fully_shard when data_parallel_shard_degree > 1 or context_parallel_degree > 1; it takes effect via torch.autocast when data_parallel_replicate_degree >= 1 and no other parallelism is enabled, i.e. under DDP or single-device training.
- mixed_precision_reduce = 'float32'#
torch dtype to use for reductions when applying mixed precision via FSDP. This feature only takes effect when data_parallel_shard_degree > 1
- seed = None#
Choose the base RNG seed used for training
- seq_len = 2048#
Sequence length
- steps = 10000#
How many train steps to run
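As a sketch of a typical mixed-precision setup (values are arbitrary): parameters stay in float32, compute runs in bfloat16, and reductions use float32. With global_batch_size=-1, the effective global batch is local_batch_size times the data-parallel degree.

```python
from torchtitan.config.job_config import Training

training = Training(
    local_batch_size=8,
    global_batch_size=-1,             # local_batch_size * data-parallel degree
    seq_len=2048,
    steps=10_000,
    max_norm=1.0,                     # gradient clipping
    mixed_precision_param="bfloat16",
    mixed_precision_reduce="float32",
)
```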
Parallelism Configuration#
- class torchtitan.config.job_config.Parallelism(data_parallel_replicate_degree=1, enable_compiled_autograd=False, data_parallel_shard_degree=-1, fsdp_reshard_after_forward='default', tensor_parallel_degree=1, disable_loss_parallel=False, enable_async_tensor_parallel=False, pipeline_parallel_degree=1, module_fqns_per_model_part=None, pipeline_parallel_first_stage_less_layers=1, pipeline_parallel_last_stage_less_layers=1, pipeline_parallel_layers_per_stage=None, pipeline_parallel_schedule='1F1B', pipeline_parallel_schedule_csv='', pipeline_parallel_microbatch_size=1, context_parallel_degree=1, context_parallel_rotate_method='allgather', expert_parallel_degree=1, expert_tensor_parallel_degree=1)[source]#
- context_parallel_degree = 1#
Context parallelism degree. 1 means disabled.
- context_parallel_rotate_method = 'allgather'#
The collective to use in context parallel SDPA for kv shard exchange:
- ‘allgather’: all-gather all kv shards on ranks after the first sub-SDPA computation.
- ‘alltoall’: all-to-all shuffle the kv shards.
The default value is ‘allgather’.
- data_parallel_replicate_degree = 1#
The data_parallel_replicate_degree argument specifies the degree of data parallelism for weight replication. When this value is greater than 1, weights will be replicated across data_parallel_replicate_degree ranks. If data_parallel_shard_degree is also greater than 1, the parallelism method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the parallelism method used is DDP (Distributed Data Parallelism). 1 means disabled.
- data_parallel_shard_degree = -1#
The data_parallel_shard_degree argument specifies the degree of data parallelism for weight sharding. When this value is greater than 1, weights will be sharded across data_parallel_shard_degree ranks. If data_parallel_replicate_degree is also greater than 1, the parallelism method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the parallelism method used is FSDP (Fully Sharded Data Parallelism). -1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that only data_parallel_shard_degree can be negative. 1 means disabled.
- disable_loss_parallel = False#
Whether to apply loss parallel when sequence parallel is enabled
- enable_async_tensor_parallel = False#
Whether to apply async tensor parallel (currently only effective when compile is enabled)
- enable_compiled_autograd = False#
Enable CompiledAutograd to compile the backward.
- expert_parallel_degree = 1#
Expert parallelism degree. 1 means disabled. No effect for non-MoE models.
Currently, it is supported with the following constraints:
- when etp = tp:
  - cp <= ep <= dp_shard * cp
  - ep % cp == 0
  - dp_shard * cp % ep == 0
- when etp = 1:
  - cp * tp <= ep <= dp_shard * cp * tp
  - ep % (cp * tp) == 0
  - dp_shard * cp * tp % ep == 0
Note that this is still an experimental feature. Some constraints will be relaxed soon when we have more flexible DeviceMesh support.
- expert_tensor_parallel_degree = 1#
Expert tensor parallelism degree. 1 means disabled. No effect for non-MoE models, or when ep = 1. With this option, the tensor parallel degree on routed experts can be different from that on other params. Currently, we only support either:
- [partial dp -> ep] etp = tp
- [partial dp + all tp -> ep] etp = 1
Note that this is still an experimental feature.
- fsdp_reshard_after_forward = 'default'#
reshard_after_forward specifies the policy for applying reshard_after_forward within an FSDP setup. reshard_after_forward controls parameter behavior after forward, trading off memory and communication. See torch’s fully_shard API for more documentation on reshard_after_forward.
The supported policies include “default”, “always” and “never”:
- “default” applies the default resharding behavior, implementing “smart defaults” for known optimal scenarios.
- “always” will enable reshard_after_forward for all forward passes.
- “never” will disable reshard_after_forward for all forward passes.
- module_fqns_per_model_part = None#
Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model chunk. Each inner list represents one model chunk and contains the module names that belong to that chunk. e.g. [[‘tok_embeddings’, ‘layers.0’], [‘layers.1’, ‘layers.2’], [‘layers.3’, ‘layers.4’]] will create 3 chunks: the first containing tok_embeddings and layers.0, the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4. This provides more explicit control over which modules belong to each chunk compared to split points.
- pipeline_parallel_degree = 1#
Pipeline Parallelism degree, or number of ranks. 1 means disabled. If using looped schedules, this still specifies the number of physical ranks, not the number of stages. Stages per rank are inferred from the split points, degree, and schedule.
- pipeline_parallel_first_stage_less_layers = 1#
The number of layers to reduce in the first stage of pipeline parallelism. This is because the first stage has the extra overhead of the embedding layer, which is not present in the other stages.
- pipeline_parallel_last_stage_less_layers = 1#
The number of layers to reduce in the last stage of pipeline parallelism. This is because the last stage has the extra overhead of the output layer, which is not present in the other stages.
- pipeline_parallel_layers_per_stage = None#
The number of layers per (virtual) pipeline stage. If specified, the module_fqns_per_model_part will be calculated from the number of layers and pipeline_parallel_degree. If not specified, the layers per stage will be inferred from the model, schedule, and pipeline_parallel_degree.
- pipeline_parallel_microbatch_size = 1#
The size of each pipeline parallel microbatch (default 1). This value is used to compute the total number of microbatches by dividing local_batch_size by pipeline_parallel_microbatch_size. The global training batch size must be evenly divisible by pipeline_parallel_microbatch_size.
- pipeline_parallel_schedule = '1F1B'#
Specify the Pipeline Parallel schedule to use. The supported schedules are those provided by PyTorch (pytorch/pytorch). The schedule must be compatible with the split points and stages_per_rank. Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks and split_points = number of stages - 1.
- pipeline_parallel_schedule_csv = ''#
Specify the path to the pipeline parallel schedule csv file to use. The pipeline_parallel_schedule argument must be either PipelineScheduleSingle, PipelineScheduleMulti, or _PipelineScheduleRuntime.
- tensor_parallel_degree = 1#
Tensor Parallelism degree. 1 means disabled.
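Two illustrative layouts (rank counts are arbitrary examples): plain FSDP, where -1 shards over all leftover ranks, and an HSDP-plus-TP layout whose degrees must multiply out to the world size.

```python
from torchtitan.config.job_config import Parallelism

# Sketch 1: plain FSDP — shard over all leftover ranks, no TP/PP/CP.
fsdp_only = Parallelism(data_parallel_shard_degree=-1)

# Sketch 2 (illustrative, for a 16-GPU job): HSDP with 2-way replication,
# 4-way sharding, and 2-way tensor parallelism (2 * 4 * 2 = 16 ranks).
hsdp_tp = Parallelism(
    data_parallel_replicate_degree=2,
    data_parallel_shard_degree=4,
    tensor_parallel_degree=2,
)
```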
Checkpoint Configuration#
- class torchtitan.config.job_config.Checkpoint(enable=False, enable_ft_dataloader_checkpoints=True, folder='checkpoint', interval=500, initial_load_path=None, initial_load_model_only=True, initial_load_in_hf=False, initial_load_in_hf_quantized=False, last_save_model_only=True, last_save_in_hf=False, export_dtype='float32', async_mode='disabled', keep_latest_k=10, load_step=-1, exclude_from_loading=<factory>, enable_first_step_checkpoint=False, create_seed_checkpoint=False, load_only=False)[source]#
- async_mode = 'disabled'#
Which async checkpoint mode to use. Currently there are 3 different modes:
- “disabled”: synchronized checkpointing will be used.
- “async”: torch.distributed.checkpoint.async_save will be used.
- “async_with_pinned_mem”: utilizes a dedicated pinned memory space and creates a separate process for faster GPU->CPU transfer performance and to eliminate GIL contention. The cost is increased CPU memory usage; if insufficient CPU memory is available, performance may degrade due to memory paging. For most users, “async” should suffice, as the performance overhead is typically small (on the order of tens of seconds) compared to checkpointing frequency. This mode can be employed to pursue near-zero checkpointing times (e.g., < 1 second) given appropriate hardware support such as ample CPU memory and fast PCIe.
“disabled” is the default mode.
- create_seed_checkpoint = False#
Initializes the full model without applying parallelisms, and then saves it as a seed checkpoint. Note: requires user to call train.py without specifying any parallelisms, e.g. NGPU=1. Could be implemented as a separate script, but this way shares more code.
- enable = False#
Whether to enable checkpoint
- enable_first_step_checkpoint = False#
Enable the checkpoint save at first step. This will save a checkpoint immediately after the first step to ensure checkpointing functions correctly. This is useful when running on a new cluster or storage to verify checkpointing without waiting for many steps or checkpointing too frequently. The default value is False.
- enable_ft_dataloader_checkpoints = True#
Used to enable checkpointing of the dataloader index for fault-tolerant training with torchft. Fault-tolerant training stores the data loader index in the checkpoints so that training can resume without going over the same batch twice. If enabled, the data loader state is checkpointed; otherwise, replicas will train over the same data multiple times, which can result in overfitting. A failed replica will still recover other state (e.g., model parameters) from the other replicas. Note that if regular checkpointing is enabled, the data loader state is checkpointed as well, but when not using fault tolerance, the entire training starts from scratch.
Warning: disabling this can have fault-tolerant replicas training over the same data multiple times. Use with caution, and only if training over the same data is acceptable.
- exclude_from_loading#
Exclude specific keys from being loaded from the checkpoint. Provide a comma-separated list of keys to exclude, e.g. ‘optimizer,lr_scheduler,dataloader’. This will load the model only, excluding the specified keys.
- export_dtype = 'float32'#
Converts to the specified precision when training completes and last_save_model_only=true.
- folder = 'checkpoint'#
The folder to store the checkpoints. When enable is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
- initial_load_in_hf = False#
Enable the use of HuggingFace’s safetensors format for checkpointing. The option is only used when initial_load_path is specified. This will load checkpoints in HF’s model definition and safetensors format instead of the default torchtitan model definition and DCP format, after necessary model state dict transformation. initial_load_model_only must be true because safetensors doesn’t support saving non-tensors. The default value is False.
- initial_load_in_hf_quantized = False#
Enable loading of HuggingFace’s safetensors format with quantized state dict keys. The option is only used when initial_load_path and initial_load_in_hf are specified. This will load checkpoints using HF’s model definition and dequantize model weights if necessary. To support this parameter, the model needs to define a proper HuggingFaceStorageReader to perform dequantization.
- initial_load_model_only = True#
This option specifies if only the model should be loaded during the initial checkpoint load. The option is only used when initial_load_path is specified. If False, the checkpoint at initial_load_path is treated as a standard training checkpoint, including optimizer, lr scheduler, training states, etc. The default setting for this option is True. Note that you will have to use --checkpoint.no_initial_load_model_only to override the default setting.
- initial_load_path = None#
This option specifies the path to the initial checkpoint to load, which is particularly useful for resuming training from a previous run with a different output path or when loading a checkpoint from a pre-trained model. If the checkpoint folder for the current run, located at {--job.dump_folder}/{--checkpoint.folder}, is not empty, this option will be ignored. This feature allows users to load an initial checkpoint from a different folder and continue training, saving new checkpoints to the specified folder without affecting the existing ones.
Note that the path should contain the full path to the checkpoint folder, including the step number, if any; for example, “//pre_train/checkpoints/llama3/llama3_8b/step_10000”.
- interval = 500#
Checkpointing interval in steps.
- keep_latest_k = 10#
Keeps only the latest k checkpoints, purging older ones. If 0, keep all checkpoints. k cannot be 1, because the latest checkpoint may still be in the process of being saved, in which case its metadata may not be ready yet. The default value is 10 to avoid filling up the disk.
- last_save_in_hf = False#
Enable the use of Hugging Face’s safetensors format for checkpointing. This will save the final checkpoints in safetensors format instead of the default DCP format, after necessary model state dict transformation. There will be a performance cost in using this as we need to consolidate the sharded tensors to full tensors as a separate step. last_save_model_only must be true because safetensors doesn’t support saving non-tensors. On load, this argument isn’t needed as we will detect whether the loaded checkpoint is in safetensors format or not. The default value is False.
- last_save_model_only = True#
When last_save_model_only=True, only the model will be saved at the end of training, the last save. With this, checkpoints can be loaded using torch.load(…, weights_only=True) after conversion. When last_save_model_only=False, the full checkpoint will be saved. A full checkpoint includes model, optimizer and train_state, which can be used to resume training. The default value is True.
- load_only = False#
In certain scenarios, you may only need to load checkpoints for verification or debugging purposes, without saving any new checkpoints. For example, you might use seed checkpoints to validate model correctness. Enabling this option allows checkpoints to be loaded without saving any during the training.
- load_step = -1#
Load the checkpoint at the specified step. If -1, load the latest checkpoint.
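Putting several of the fields above together, a sketch of a periodic async checkpointing setup (values are illustrative, not recommendations): checkpoints land under {--job.dump_folder}/{--checkpoint.folder}, the latest 10 are kept, and the final save is model-only in HF safetensors format (last_save_model_only must stay True for that).

```python
from torchtitan.config.job_config import Checkpoint

checkpoint = Checkpoint(
    enable=True,
    folder="checkpoint",
    interval=500,              # save every 500 steps
    async_mode="async",        # torch.distributed.checkpoint.async_save
    keep_latest_k=10,
    last_save_model_only=True, # required for safetensors export
    last_save_in_hf=True,
)
```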