TransformerDecoder
- class torchtune.modules.TransformerDecoder(*, tok_embeddings: Embedding, layers: Union[Module, List[Module], ModuleList], max_seq_len: int, num_heads: int, head_dim: int, norm: Module, output: Union[Linear, Callable], num_layers: Optional[int] = None, output_hidden_states: Optional[List[int]] = None)
Transformer Decoder derived from the Llama2 architecture.
- Parameters:
tok_embeddings (nn.Embedding) – PyTorch embedding layer, to be used to move tokens to an embedding space.
layers (Union[nn.Module, List[nn.Module], nn.ModuleList]) – A single transformer Decoder layer, an nn.ModuleList of layers or a list of layers. It is recommended to use an nn.ModuleList.
max_seq_len (int) – maximum sequence length the model will be run with, as used by KVCache().
num_heads (int) – number of query heads. For MHA this is also the number of heads for key and value. This is used to set up the KVCache().
head_dim (int) – embedding dimension for each head in self-attention. This is used to set up the KVCache().
norm (nn.Module) – Callable that applies normalization to the output of the decoder, before final MLP.
output (Union[nn.Linear, Callable]) – Callable that applies a linear transformation to the output of the decoder.
num_layers (Optional[int]) – Number of Transformer Decoder layers. Only define this when layers is not a list.
output_hidden_states (Optional[List[int]]) – List of layer indices whose outputs are included in the output.
- Raises:
AssertionError – If num_layers is set and layers is a list, or num_layers is not set and layers is an nn.Module.
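The layers / num_layers rule can be sketched in plain Python. This is a minimal illustration rather than torchtune's implementation: `resolve_layers` and `ToyLayer` are hypothetical names, a plain list stands in for an nn.ModuleList, and copying a single layer num_layers times is an assumption about how the single-layer case is expanded.

```python
import copy


class ToyLayer:
    """Hypothetical stand-in for a single decoder layer (not an nn.Module)."""


def resolve_layers(layers, num_layers=None):
    """Return a list of layers, enforcing the layers/num_layers rule."""
    if isinstance(layers, list):
        # A list already fixes the depth, so num_layers must not be given.
        assert num_layers is None, "num_layers is set, so layers must be a single module"
        return layers
    # A single layer needs num_layers to know how many copies to make.
    assert num_layers is not None, "num_layers is not set, so layers must be a list"
    # Assumption: each copy is independent, so no state is shared between layers.
    return [copy.deepcopy(layers) for _ in range(num_layers)]


stack = resolve_layers(ToyLayer(), num_layers=4)       # one layer, repeated
explicit = resolve_layers([ToyLayer(), ToyLayer()])    # depth taken from the list
```

Passing both a list and num_layers, or a single layer without num_layers, fails the assertion in this sketch, mirroring the AssertionError described above.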
Note
Arg values are checked for correctness (e.g. attn_dropout belongs to [0, 1]) in the module where they are used. This helps reduce the number of raise statements in the code and improves readability.
- caches_are_enabled() → bool
Checks if the key value caches are enabled. Once KV-caches have been set up, the relevant attention modules will be “enabled” and all forward passes will update the caches. This behaviour can be disabled without altering the state of the KV-caches by “disabling” the KV-caches using torchtune.modules.common_utils.disable_kv_cache(), upon which caches_are_enabled would return False.
- caches_are_setup() → bool
Checks if the key value caches are set up. This means setup_caches has been called, and the relevant attention modules in the model have created their KVCache.
- chunked_output(last_hidden_state: Tensor) → List[Tensor]
Apply output projection in chunks. This should be applied in conjunction with CEWithChunkedOutputLoss, as upcasting to fp32 is done there.
To use this method, you should first call set_num_output_chunks().
- Parameters:
last_hidden_state (torch.Tensor) – last hidden state of the decoder, having shape [b, seq_len, embed_dim].
- Returns:
List of num_chunks output tensors, each with shape [b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size.
- Return type:
List[torch.Tensor]
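The shape contract of chunked_output can be illustrated without torch. The sketch below is an assumption-laden toy, not the library code: nested lists stand in for tensors, `project` and `chunked_project` are hypothetical helpers, and the sequence length is assumed to divide evenly by the number of chunks.

```python
def project(rows, weight):
    """Apply a linear map to each row: [s, d] x [d, v] -> [s, v]."""
    return [
        [sum(x * w for x, w in zip(row, col)) for col in zip(*weight)]
        for row in rows
    ]


def chunked_project(hidden, weight, num_chunks):
    """Project a [b, s, d] nested list in num_chunks slices along s.

    Assumes s is divisible by num_chunks, so every chunk has
    shape [b, s / num_chunks, v].
    """
    seq_len = len(hidden[0])
    assert seq_len % num_chunks == 0, "sketch assumes an even split"
    step = seq_len // num_chunks
    return [
        [project(sample[start:start + step], weight) for sample in hidden]
        for start in range(0, seq_len, step)
    ]


# b=2, s=4, d=2, v=3
hidden = [[[float(b + s), float(s)] for s in range(4)] for b in range(2)]
weight = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]  # [d, v]

chunks = chunked_project(hidden, weight, num_chunks=2)

# Concatenating the chunks along the sequence axis recovers the full output.
full = [project(sample, weight) for sample in hidden]
rejoined = [sum((chunk[b] for chunk in chunks), []) for b in range(2)]
```

Here `chunks` holds two pieces of shape [2, 2, 3], and `rejoined` equals `full`. Chunking changes only how much of the output is materialized at once, not its values.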
- forward(tokens: Tensor, *, mask: Optional[Tensor] = None, encoder_input: Optional[Tensor] = None, encoder_mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) → Union[Tensor, List[Tensor]]
- Parameters:
tokens (torch.Tensor) – input tensor with shape [b x s]
mask (Optional[_MaskType]) – Used to mask the scores after the query-key multiplication and before the softmax. This parameter is required during inference if caches have been set up. Either:
- A boolean tensor with shape [b x s x s], or [b x s x self.encoder_max_cache_seq_len] if using KV-caching with encoder/decoder layers. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default.
- A BlockMask for document masking in a packed sequence created via create_block_mask. We use flex_attention() when computing attention with block masks.
Default is None.
encoder_input (Optional[torch.Tensor]) – Optional input embeds from the encoder. Shape [b x s_e x d_e]
encoder_mask (Optional[torch.Tensor]) – Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i, j means token i can attend to embedding j in the decoder. Mask has shape [b x s x s_e]. Default is None, but this is required during inference if the model has been set up with any layers which use encoder embeddings and caches have been set up.
input_pos (Optional[torch.Tensor]) – Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. This parameter is required during inference if caches have been set up. Default is None.
- Returns:
Output tensor with shape [b x s x v], or a list of layer output tensors defined by output_hidden_states with the final output tensor appended to the list.
- Return type:
Union[torch.Tensor, List[torch.Tensor]]
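To make the boolean mask convention concrete, the sketch below builds the kind of causal mask described for the mask parameter, where True in row i and column j means token i attends to token j. It is illustrative only: nested lists stand in for a boolean tensor, and `causal_mask` and `batched_causal_mask` are hypothetical helper names.

```python
def causal_mask(seq_len):
    """Boolean [s x s] mask: entry [i][j] is True when token i may attend to token j."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def batched_causal_mask(batch_size, seq_len):
    """Repeat the causal mask for every sample, giving shape [b x s x s]."""
    return [causal_mask(seq_len) for _ in range(batch_size)]


mask = batched_causal_mask(batch_size=2, seq_len=4)

# Print one sample's mask: "T" where attention is allowed, "." where it is not.
for row in mask[0]:
    print("".join("T" if allowed else "." for allowed in row))
```

Each row i allows columns 0 through i, so the printed pattern is lower triangular: a token sees itself and earlier tokens, never later ones.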
Note
At the very first step of inference, when the model is provided with a prompt, input_pos should contain the positions of all of the tokens in the prompt. For a single-batch prompt, or a batch of prompts with identical lengths, this will be torch.arange(prompt_length). For a batch of varying-length prompts, shorter prompts are left-padded and position ids are correspondingly right-shifted, thus positional ids should be of shape [b, padded_prompt_length]. This is because we will need to retrieve the positional embeddings for each input id. In the subsequent steps, if the model has been set up with KV-caches, input_pos will contain the position(s) of the current token(s), e.g. torch.tensor([padded_prompt_length]). Otherwise, input_pos will contain all the position ids up to the current token.
- Shape notation:
b: batch size
s: token sequence length
s_e: encoder sequence length
v: vocab size
d: token embed dim
d_e: encoder embed dim
m_s: max seq len
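The input_pos convention from the note can be sketched as follows. This is one plausible reading, not torchtune's generation code: `left_padded_positions` and `next_step_position` are hypothetical helpers, plain lists stand in for tensors, and giving pad slots position 0 is an assumption made for illustration.

```python
def left_padded_positions(prompt_lengths):
    """Position ids of shape [b, padded_prompt_length] for left-padded prompts.

    Real tokens are numbered 0, 1, 2, ... starting at the first non-pad token.
    Pad slots are given position 0 here; that choice is an assumption.
    """
    padded_len = max(prompt_lengths)
    return [
        [0] * (padded_len - n) + list(range(n))
        for n in prompt_lengths
    ]


def next_step_position(padded_prompt_length, step):
    """Position of the single new token at decoding step `step` (0-based), with KV-caches."""
    return [padded_prompt_length + step]


# Two prompts of lengths 3 and 5, left-padded to length 5.
prefill = left_padded_positions([3, 5])
# First decoding step after the prompt, with KV-caches set up.
first_decode = next_step_position(padded_prompt_length=5, step=0)
```

For the prefill step every prompt token gets a position, so `prefill` has shape [2, 5]. With KV-caches, each later step passes only the position of the new token, starting at padded_prompt_length.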
- reset_caches()
Resets KV-cache buffers on relevant attention modules to zero, and resets cache positions to zero, without deleting or reallocating cache tensors.
- Raises:
RuntimeError – if KV-caches are not set up. Use setup_caches() to set up caches first.
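The relationship between caches being set up, being enabled, and being reset can be modeled with a toy state tracker. `ToyCacheState` is a hypothetical class that tracks only flags and a position counter. It allocates no tensors and is not the torchtune implementation.

```python
class ToyCacheState:
    """Hypothetical stand-in that tracks only the cache flags, not real tensors."""

    def __init__(self):
        self._setup = False
        self._enabled = False
        self.cache_pos = 0

    def setup_caches(self):
        # Setting up caches also enables them.
        self._setup = True
        self._enabled = True

    def disable(self):
        # Disabling keeps the allocated caches and their contents.
        self._enabled = False

    def caches_are_setup(self):
        return self._setup

    def caches_are_enabled(self):
        return self._enabled

    def step(self, num_tokens):
        # Only an enabled cache is updated by a forward pass.
        if self._enabled:
            self.cache_pos += num_tokens

    def reset_caches(self):
        if not self._setup:
            raise RuntimeError("KV-caches are not set up. Call setup_caches() first.")
        self.cache_pos = 0  # position reset; nothing is reallocated


state = ToyCacheState()
try:
    state.reset_caches()  # resetting before setup is an error
    error = None
except RuntimeError as exc:
    error = str(exc)

state.setup_caches()
state.step(5)
state.disable()
state.step(1)  # ignored while disabled
flags = (state.caches_are_setup(), state.caches_are_enabled(), state.cache_pos)
state.reset_caches()
```

After disabling, the toy still reports the caches as set up but not enabled, and the stored position is unchanged until reset_caches() is called.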
- set_num_output_chunks(num_output_chunks: int) → None
Used to save memory in combination with CEWithChunkedOutputLoss. This should be called before the first forward pass, in the recipe.
- setup_caches(batch_size: int, dtype: dtype, *, encoder_max_seq_len: Optional[int] = None, decoder_max_seq_len: Optional[int] = None)
Sets up key-value attention caches for inference. For each layer in self.layers:
- TransformerSelfAttentionLayer will use decoder_max_seq_len.
- TransformerCrossAttentionLayer will use encoder_max_seq_len.
- FusionLayer will use decoder_max_seq_len and encoder_max_seq_len.
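The per-layer choice of cache length can be summarized as a small lookup. The sketch below is a hypothetical illustration: `cache_lengths` is not a torchtune function, layer types are represented by their class names as strings, and the returned dictionaries exist only to show which length applies to which layer type.

```python
def cache_lengths(layer_kind, *, encoder_max_seq_len, decoder_max_seq_len):
    """Return the max sequence length(s) a layer's caches would be sized with."""
    if layer_kind == "TransformerSelfAttentionLayer":
        return {"decoder": decoder_max_seq_len}
    if layer_kind == "TransformerCrossAttentionLayer":
        return {"encoder": encoder_max_seq_len}
    if layer_kind == "FusionLayer":
        # Fusion layers combine both kinds, so both lengths apply.
        return {"decoder": decoder_max_seq_len, "encoder": encoder_max_seq_len}
    raise ValueError(f"unknown layer kind: {layer_kind}")


layer_kinds = [
    "TransformerSelfAttentionLayer",
    "TransformerCrossAttentionLayer",
    "FusionLayer",
]
plan = [
    cache_lengths(kind, encoder_max_seq_len=1024, decoder_max_seq_len=2048)
    for kind in layer_kinds
]
```

The values 1024 and 2048 are arbitrary example lengths.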