TransformerDecoder¶
- class torchtune.modules.TransformerDecoder(*, tok_embeddings: Embedding, layers: Union[Module, List[Module], ModuleList], max_seq_len: int, num_heads: int, head_dim: int, norm: Module, output: Union[Linear, Callable], num_layers: Optional[int] = None, output_hidden_states: Optional[List[int]] = None)[source]¶
Transformer Decoder derived from the Llama2 architecture.
- Parameters:
tok_embeddings (nn.Embedding) – PyTorch embedding layer, to be used to move tokens to an embedding space.
layers (Union[nn.Module, List[nn.Module], nn.ModuleList]) – A single transformer Decoder layer, an nn.ModuleList of layers or a list of layers. It is recommended to use an nn.ModuleList.
max_seq_len (int) – maximum sequence length the model will be run with, as used by KVCache()
num_heads (int) – number of query heads. For MHA this is also the number of heads for key and value. This is used to set up the KVCache()
head_dim (int) – embedding dimension for each head in self-attention. This is used to set up the KVCache()
norm (nn.Module) – Callable that applies normalization to the output of the decoder, before final MLP.
output (Union[nn.Linear, Callable]) – Callable that applies a linear transformation to the output of the decoder.
num_layers (Optional[int]) – Number of Transformer Decoder layers, only define when layers is not a list.
output_hidden_states (Optional[List[int]]) – List of layers (indices) to include in the output
- Raises:
AssertionError – if num_layers is set and layers is a list
AssertionError – if num_layers is not set and layers is an nn.Module
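The following is a minimal, hedged construction sketch (not from the official docs) showing how the parameters above fit together. All dimensions are arbitrary example values, and it assumes torchtune's MultiHeadAttention, TransformerSelfAttentionLayer, FeedForward, RMSNorm, and RotaryPositionalEmbeddings with their usual keyword-only signatures.

import torch
from torch import nn
from torchtune.modules import (
    FeedForward,
    MultiHeadAttention,
    RMSNorm,
    RotaryPositionalEmbeddings,
    TransformerDecoder,
    TransformerSelfAttentionLayer,
)

# Example dimensions only; pick values matching your model.
vocab_size, embed_dim, num_heads, num_layers, max_seq_len = 32_000, 256, 8, 2, 2048
head_dim = embed_dim // num_heads

def build_layer() -> TransformerSelfAttentionLayer:
    attn = MultiHeadAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_heads,  # MHA: kv heads == query heads
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        v_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        pos_embeddings=RotaryPositionalEmbeddings(head_dim, max_seq_len=max_seq_len),
        max_seq_len=max_seq_len,
    )
    mlp = FeedForward(
        gate_proj=nn.Linear(embed_dim, 4 * embed_dim, bias=False),
        down_proj=nn.Linear(4 * embed_dim, embed_dim, bias=False),
        up_proj=nn.Linear(embed_dim, 4 * embed_dim, bias=False),
    )
    return TransformerSelfAttentionLayer(
        attn=attn,
        mlp=mlp,
        sa_norm=RMSNorm(embed_dim),
        mlp_norm=RMSNorm(embed_dim),
    )

decoder = TransformerDecoder(
    tok_embeddings=nn.Embedding(vocab_size, embed_dim),
    layers=nn.ModuleList([build_layer() for _ in range(num_layers)]),  # nn.ModuleList is recommended
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    head_dim=head_dim,
    norm=RMSNorm(embed_dim),
    output=nn.Linear(embed_dim, vocab_size, bias=False),
)

tokens = torch.randint(0, vocab_size, (2, 16))  # [b x s]
logits = decoder(tokens)                        # [b x s x v]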
Note
Arg values are checked for correctness (e.g. attn_dropout belongs to [0, 1]) in the module where they are used. This helps reduce the number of raise statements in code and improves readability.
- caches_are_enabled() bool[source]¶
Check if the key value caches are set up. This is useful for efficient inference.
- chunked_output(last_hidden_state: Tensor) List[Tensor][source]¶
Apply output projection in chunks. This should be applied in conjunction with CEWithChunkedOutputLoss, as upcasting to fp32 is done there. To use this method, you should first call set_num_output_chunks().
- Parameters:
last_hidden_state (torch.Tensor) – last hidden state of the decoder, having shape [b, seq_len, embed_dim].
- Returns:
- List of num_chunks output tensors, each with shape
[b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size.
- Return type:
List[torch.Tensor]
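A hedged sketch of pairing the chunked output path with the chunked loss via set_num_output_chunks(). It assumes a freshly built decoder like the one sketched above, and that CEWithChunkedOutputLoss is importable from torchtune.modules.loss with a num_output_chunks argument.

from torchtune.modules.loss import CEWithChunkedOutputLoss

loss_fn = CEWithChunkedOutputLoss(num_output_chunks=8)
decoder.set_num_output_chunks(loss_fn.num_output_chunks)  # call before the first forward pass
tokens = torch.randint(0, vocab_size, (2, 16))   # [b x s], s divisible by num_output_chunks
labels = torch.randint(0, vocab_size, (2, 16))   # [b x s]
logits_chunks = decoder(tokens)                  # list of 8 tensors, each [b, s/8, v]
loss = loss_fn(logits_chunks, labels)            # upcasting to fp32 happens inside the loss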
- forward(tokens: Tensor, *, mask: Optional[Tensor] = None, encoder_input: Optional[Tensor] = None, encoder_mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) Union[Tensor, List[Tensor]][source]¶
- Parameters:
tokens (torch.Tensor) – input tensor with shape [b x s]
mask (Optional[_MaskType]) – Used to mask the scores after the query-key multiplication and before the softmax. This parameter is required during inference if caches have been set up. Either:
A boolean tensor with shape [b x s x s], [b x s x self.decoder_max_cache_seq_len] if using KV-caching, or [b x s x self.encoder_max_cache_seq_len] if using KV-caching with encoder/decoder layers. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default.
A BlockMask for document masking in a packed sequence created via create_block_mask. We use flex_attention() when computing attention with block masks. Default is None.
encoder_input (Optional[torch.Tensor]) – Optional input embeds from the encoder. Shape [b x s_e x d_e]
encoder_mask (Optional[torch.Tensor]) – Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i,j means token i can attend to embedding j in the decoder. Mask has shape [b x s x s_e]. Default is None, but this is required during inference if the model has been set up with any layers which use encoder embeddings and caches have been set up.
input_pos (Optional[torch.Tensor]) – Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. This parameter is required during inference if caches have been set up. Default is None.
- Returns:
- output tensor with shape [b x s x v] or a list of layer output tensors defined by output_hidden_states with the final output tensor appended to the list.
- Return type:
Union[torch.Tensor, List[torch.Tensor]]
Note
At the very first step of inference, when the model is provided with a prompt, input_pos should contain the positions of all of the tokens in the prompt. For a single-batch prompt, or a batch of prompts with identical lengths, this will be torch.arange(prompt_length). For a batch of varying-length prompts, shorter prompts are left-padded and position ids are correspondingly right-shifted, thus positional ids should be of shape [b, padded_prompt_length]. This is because we will need to retrieve the positional embeddings for each input id. In the subsequent steps, if the model has been set up with KV-caches, input_pos will contain the position(s) of the current token(s), torch.tensor([padded_prompt_length]). Otherwise, input_pos will contain all the position ids up to the current token.
- Shape notation:
b: batch size
s: token sequence length
s_e: encoder sequence length
v: vocab size
d: token embed dim
d_e: encoder embed dim
m_s: max seq len
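The note above can be made concrete with a hedged inference sketch (single prompt, greedy decoding), reusing the decoder from the construction example earlier. The mask shapes follow the forward() description, with decoder_max_seq_len defaulting to max_seq_len; this is a sketch of the pattern, not the library's generation utilities.

decoder.eval()
decoder.setup_caches(batch_size=1, dtype=torch.float32)
cache_len = max_seq_len                            # decoder_max_seq_len defaults to max_seq_len
prompt = torch.randint(0, vocab_size, (1, 8))      # [b x s]
prompt_len = prompt.size(1)

with torch.no_grad():
    # Prefill: positions of every prompt token and a causal mask over the cache
    input_pos = torch.arange(prompt_len).unsqueeze(0)                                    # [1 x 8]
    mask = torch.tril(torch.ones(prompt_len, cache_len, dtype=torch.bool)).unsqueeze(0)  # [1 x 8 x cache_len]
    logits = decoder(prompt, mask=mask, input_pos=input_pos)                             # [1 x 8 x v]
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)                              # greedy pick, [1 x 1]

    # One decode step: a single position and a single-row mask over the cache
    input_pos = torch.tensor([[prompt_len]])
    mask = torch.zeros(1, 1, cache_len, dtype=torch.bool)
    mask[..., : prompt_len + 1] = True
    logits = decoder(next_token, mask=mask, input_pos=input_pos)                         # [1 x 1 x v]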
- set_num_output_chunks(num_output_chunks: int) None[source]¶
Used to save memory in combination with
CEWithChunkedOutputLoss. This should be called before the first forward pass, in the recipe.
- setup_caches(batch_size: int, dtype: dtype, *, encoder_max_seq_len: Optional[int] = None, decoder_max_seq_len: Optional[int] = None)[source]¶
Sets up key-value attention caches for inference. For each layer in self.layers:
- TransformerSelfAttentionLayer will use decoder_max_seq_len.
- TransformerCrossAttentionLayer will use encoder_max_seq_len.
- FusionLayer will use both decoder_max_seq_len and encoder_max_seq_len.
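A short hedged sketch of setup_caches() on a freshly built decoder with only self-attention layers (such as the one in the first example), capping the decoder cache length and verifying the caches exist:

decoder.setup_caches(batch_size=1, dtype=torch.float32, decoder_max_seq_len=512)
assert decoder.caches_are_enabled()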