update_state_dict_for_classifier
- torchtune.training.update_state_dict_for_classifier(state_dict: Dict[str, Tensor], model_named_parameters: Iterable[Tuple[str, Parameter]], force_override: bool = False)
Validates the state dict for checkpoint loading for a classifier model. To be used prior to a call to model.load_state_dict(state_dict). This function overwrites the output.weight in the state dict to be loaded with the output.weight in the model if the shapes of the two output.weight tensors do not match. You can also force this overwrite with force_override, for example, when num_classes is the same for your checkpoint and your model.

Concretely, when fine-tuning a classifier model from the checkpoint of a base language model whose output.weight has shape [vocab_dim, embed_dim], the output.weight in the state dict to be loaded is overwritten with the randomly initialized [num_classes, embed_dim] weight in the model. This is done in-place.
- Parameters:
state_dict (Dict[str, torch.Tensor]) – state dict to be loaded into the classifier model.
model_named_parameters (Iterable[Tuple[str, torch.nn.Parameter]]) – model named parameters from model.named_parameters().
force_override (bool) – Whether to replace output.weight in state_dict with the model’s output.weight, even if the shapes match.
Notes
- output.bias will be ignored if present in state_dict.
- This function will always replace the output.weight in state_dict if its shape does not match the shape of model.output.weight.
- Raises:
AssertionError – if state_dict does not contain output.weight.
AssertionError – if model_named_parameters does not contain output.weight.
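
The behavior described above can be illustrated with a small, PyTorch-free sketch. This is not torchtune's implementation: FakeTensor and update_state_dict_sketch are hypothetical names introduced only for illustration, tensors are reduced to their shapes, and "ignoring" output.bias is modeled by removing the key, which is an assumption about one way to ignore it.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; only the shape matters here."""
    shape: Tuple[int, ...]
    label: str = ""


def update_state_dict_sketch(
    state_dict: Dict[str, FakeTensor],
    model_named_parameters: Iterable[Tuple[str, FakeTensor]],
    force_override: bool = False,
) -> None:
    """Illustrative re-statement of the documented behavior (not torchtune's code)."""
    model_weight = dict(model_named_parameters).get("output.weight")
    # Both sides must provide output.weight, per the Raises section.
    if "output.weight" not in state_dict:
        raise AssertionError("state_dict does not contain output.weight")
    if model_weight is None:
        raise AssertionError("model_named_parameters does not contain output.weight")
    # output.bias is ignored if present (assumption: modeled by dropping the key).
    state_dict.pop("output.bias", None)
    # Replace in place on a shape mismatch, or unconditionally when forced.
    if force_override or state_dict["output.weight"].shape != model_weight.shape:
        state_dict["output.weight"] = model_weight


# Base language-model checkpoint: output.weight is [vocab_dim, embed_dim].
checkpoint = {
    "output.weight": FakeTensor((32000, 4096), "checkpoint"),
    "output.bias": FakeTensor((32000,), "checkpoint"),
}
# Classifier model: output.weight is [num_classes, embed_dim].
model_params = [("output.weight", FakeTensor((2, 4096), "model"))]

update_state_dict_sketch(checkpoint, model_params)
print(checkpoint["output.weight"].shape)  # (2, 4096)
print("output.bias" in checkpoint)        # False
```

With the real library, the corresponding call would pass model.named_parameters() as the second argument and be followed by model.load_state_dict(state_dict), as described at the top of this entry.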