Shortcuts

Source code for torchtune.modules.transforms._transforms

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Mapping, Optional, Protocol

import torch


class Transform(Protocol):
    """
    Structural interface shared by data and model transforms.

    A transform is any callable that takes a single sample, represented as a
    mapping from string keys to arbitrary values, and returns a mapping of the
    same kind. Implementations work on one sample at a time.

    :class:`~torchtune.modules.transforms.VisionCrossAttentionMask` is an
    example of a class following this protocol.
    """

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        # Protocol stub: concrete transforms supply the actual behaviour.
        ...
class VisionCrossAttentionMask(Transform):
    """
    Build per-image cross-attention masks for interleaved text + image samples.

    A text position is marked True for an image when it should participate in
    cross-attention with that image. The layout follows the interleaved scheme
    shown in Fig. 7 of the Flamingo paper (https://arxiv.org/pdf/2204.14198):

    (1) An image is attended by the text that follows its image token, up to
        (but excluding) the next image token.
    (2) A run of consecutive image tokens all share the text that follows the
        run.

    ::

             ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
        img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
             └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
             ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
        img2 │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
             └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
             ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
        img3 │   │ │   │ │   │ │   │ │   │ │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │
             └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
            <img1> <img2>These  are   two  dogs. <img3> This   is     a    cat.

    One mask is produced per image, shaped (text_seq_len, image_seq_len). A
    True entry means the text token at that row attends to the image-encoder
    token at that column. The masks are returned as a list whose length equals
    the number of images in the sample.

    Args:
        tile_size (int): Size of the image tiles produced by the image transform.
        patch_size (int): Size of each patch within a tile. For example, with
            patch_size = 40 a (400, 400) tile is split into a 10x10 grid of
            (40, 40) patches.
        image_token_id (int): Token ID of the image special token.
        max_num_tiles (Optional[int]): Maximum number of tiles in an image. When
            set, masks are padded to this many tiles during inference.
            Defaults to None.
    """

    def __init__(
        self,
        tile_size: int,
        patch_size: int,
        image_token_id: int,
        max_num_tiles: Optional[int] = None,
    ):
        # Patches form a square grid inside each tile, so the per-tile patch
        # count is the square of the number of patches along one side.
        self.patches_per_tile = (tile_size // patch_size) ** 2
        self.image_token_id = image_token_id
        self.max_num_tiles = max_num_tiles

    def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int]]:
        """
        Locate, for every image token, the half-open text range [start, end)
        that should attend to that image.

        ``start`` is the index of the image token itself. ``end`` is the index
        of the next image token, or ``len(tokens)`` for the final image. When
        image tokens are directly adjacent, the earlier one inherits the
        ``end`` of the one right after it, so the whole run shares the same
        trailing text.

        Args:
            tokens (List[int]): Token IDs of the text sequence.

        Returns:
            List[List[int]]: One ``[start, end]`` pair per image token, in order
            of appearance. Empty when the sequence has no image token.

        Example:
            >>> text = "<img1><img2>These are two dogs. <img3>This is a cat."
            >>> image_token_id = 1
            >>> tokens = [1, 1, 9673, 527, 1403, 12875, 13, 1, 1115, 374, 264, 8415]
            >>> transform = VisionCrossAttentionMask(tile_size=400, patch_size=40, image_token_id=1)
            >>> intervals = transform._get_image_attention_intervals(tokens)
            >>> print(intervals)
            [[0, 7], [1, 7], [7, 12]]
        """
        seq_len = len(tokens)
        image_positions = [
            idx for idx, tok in enumerate(tokens) if tok == self.image_token_id
        ]

        # Walk the image tokens back to front so that each one can see the
        # already-resolved interval of the image that follows it.
        intervals: List[List[int]] = []
        next_start = seq_len  # position of the following image (or sequence end)
        next_stop = seq_len  # resolved end of the following image's interval
        for pos in reversed(image_positions):
            if pos + 1 == next_start:
                # Directly adjacent to the next image token (or to the end of
                # the sequence): share that interval's end.
                stop = next_stop
            else:
                # Text lies in between: attend only up to the next image token.
                stop = next_start
            intervals.append([pos, stop])
            next_start, next_stop = pos, stop

        # Restore order of appearance in the text.
        intervals.reverse()
        return intervals

    def __call__(
        self, sample: Mapping[str, Any], inference: bool = False
    ) -> Mapping[str, Any]:
        """
        Compute the vision cross-attention masks for one sample from the
        positions of the image tokens interleaved in its text.

        Args:
            sample (Mapping[str, Any]): Sample dict. The following entries are read:

                - ``sample["tokens"]`` (List[int]): Token IDs of the text
                  sequence. The number of image token IDs must equal the number
                  of images.
                - ``sample["encoder_input"]["images"]`` (List[torch.Tensor]):
                  Image tensors post-tiling, each of shape (n_tiles, c, h, w).
                  Only the leading (tile) dimension is used here.
            inference (bool): Whether the transform is being used for inference.
                Masks are padded to ``max_num_tiles`` only when this is True and
                ``max_num_tiles`` was provided.

        Returns:
            Mapping[str, Any]: The same ``sample`` object, updated in place with
            a new key ``encoder_mask`` holding one boolean mask per image. Each
            mask has ``len(tokens)`` rows. Its width is
            ``n_tiles * (patches_per_tile + 1)`` for that image, or
            ``max_num_tiles * (patches_per_tile + 1)`` when padding applies.
            The masks get padded and concatenated in the batch collator.

        Raises:
            RuntimeError: if the number of images does not match the number of
                image tokens in the text sequence.
        """
        tokens = sample["tokens"]
        images = sample["encoder_input"]["images"]

        # A sample may hold several images; each needs exactly one image token.
        n_img = len(images)
        intervals = self._get_image_attention_intervals(tokens)
        if len(intervals) != n_img:
            raise RuntimeError(
                f"The number of image tokens ({len(intervals)}) does not match the number of images ({n_img})."
            )

        text_seq_len = len(tokens)
        # Every tile contributes its patches plus one CLS token.
        tokens_per_tile = self.patches_per_tile + 1

        # Fixed mask width is used only for inference with a known tile budget;
        # otherwise each mask is as wide as its own image requires.
        pad_to_max = inference and self.max_num_tiles is not None
        padded_width = self.max_num_tiles * tokens_per_tile if pad_to_max else None

        masks = []
        for image, (start, end) in zip(images, intervals):
            # Images are not tile-padded yet, so the length varies per image.
            image_seq_len = image.shape[0] * tokens_per_tile
            width = padded_width if padded_width else image_seq_len

            # Non-causal block: every text token in [start, end) sees all of
            # this image's tokens, since they all belong to a single image.
            mask = torch.zeros(text_seq_len, width, dtype=torch.bool)
            mask[start:end, :image_seq_len] = True
            masks.append(mask)

        sample.update(encoder_mask=masks)
        return sample

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources