.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_examples_apps_lightning_classy_vision_interpret.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_apps_lightning_classy_vision_interpret.py:


Model Interpretability App Example
=============================================

This is an example TorchX app that uses captum to analyze inputs for model
interpretability purposes. It consumes the trained model from the trainer app
example and the preprocessed examples from the datapreproc app example. The
output is a series of images with integrated gradient attributions overlaid on
them.

See https://captum.ai/tutorials/CIFAR_TorchVision_Interpret for more info on
using captum.

.. code-block:: default


    import argparse
    import itertools
    import os.path
    import sys
    import tempfile
    from typing import List

    import fsspec
    import torch

    # ensure data and module are on the path
    sys.path.append(".")

    from torchx.examples.apps.lightning_classy_vision.data import (
        TinyImageNetDataModule,
        download_data,
        create_random_data,
    )
    from torchx.examples.apps.lightning_classy_vision.model import TinyImageNetModel

    # FIXME: captum must be imported after torch otherwise it causes python to crash
    if True:
        import numpy as np
        from captum.attr import IntegratedGradients
        from captum.attr import visualization as viz


    def parse_args(argv: List[str]) -> argparse.Namespace:
        """
        Parses the command line arguments of the app.

        ``--load_path`` and ``--output_path`` are required; ``--data_path`` is
        optional.
        """
        parser = argparse.ArgumentParser(description="example TorchX captum app")
        parser.add_argument(
            "--load_path",
            type=str,
            help="checkpoint path to load model weights from",
            required=True,
        )
        parser.add_argument(
            "--data_path",
            type=str,
            help="path to load the training data from, if not provided, random dataset will be created",
        )
        parser.add_argument(
            "--output_path",
            type=str,
            help="path to place analysis results",
            required=True,
        )

        return parser.parse_args(argv)


    def convert_to_rgb(arr: torch.Tensor) -> np.ndarray:  # pyre-ignore[24]
        """
        This converts the image from a torch tensor with size (1, 3, 64, 64),
        i.e. (batch, channels, height, width), to a numpy array with size
        (64, 64, 3), i.e. (height, width, channels).
        """
        # Drop only the batch dimension and move channels last. permute keeps
        # the height/width order intact, whereas swapping axes 0 and 2 would
        # transpose the image.
        out = arr.squeeze(0).permute(1, 2, 0)
        assert out.shape == (64, 64, 3), "invalid shape produced"
        # detach/cpu make the conversion safe for tensors that track gradients
        # or live on an accelerator.
        return out.detach().cpu().numpy()


    def main(argv: List[str]) -> None:
        """
        Runs integrated gradients on the first 5 test examples and writes one
        attribution heatmap per example (``ig_<i>.png``) to ``--output_path``.
        """
        args = parse_args(argv)

        with tempfile.TemporaryDirectory() as tmpdir:
            print(f"loading checkpoint: {args.load_path}...")
            # load_from_checkpoint returns a *new* model carrying the
            # checkpoint weights instead of updating the instance it is called
            # on, so its result has to be bound for the weights to be used.
            model = TinyImageNetModel.load_from_checkpoint(
                checkpoint_path=args.load_path
            )

            # Download and setup the data module
            if not args.data_path:
                data_path = os.path.join(tmpdir, "data")
                os.makedirs(data_path)
                create_random_data(data_path)
            else:
                data_path = download_data(args.data_path, tmpdir)

            data = TinyImageNetDataModule(
                data_dir=data_path,
                batch_size=1,
            )

            ig = IntegratedGradients(model)

            data.setup("test")
            dataloader = data.test_dataloader()

            # process first 5 images
            for i, (image, label) in enumerate(itertools.islice(dataloader, 5)):
                print(f"analyzing example {i}")
                model.zero_grad()
                # An all-zero (black) image is used as the baseline.
                attr_ig, _delta = ig.attribute(
                    image,
                    target=label,
                    baselines=torch.zeros_like(image),
                    return_convergence_delta=True,
                )

                if attr_ig.count_nonzero() == 0:
                    # Our toy model sometimes has no IG results.
                    print("skipping due to zero gradients")
                    continue

                fig, axis = viz.visualize_image_attr(
                    convert_to_rgb(attr_ig),
                    convert_to_rgb(image),
                    method="blended_heat_map",
                    sign="all",
                    show_colorbar=True,
                    title="Overlayed Integrated Gradients",
                )
                out_path = os.path.join(args.output_path, f"ig_{i}.png")
                print(f"saving heatmap to {out_path}")
                with fsspec.open(out_path, "wb") as f:
                    fig.savefig(f)


    if __name__ == "__main__" and "NOTEBOOK" not in globals():
        main(sys.argv[1:])


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download_examples_apps_lightning_classy_vision_interpret.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: interpret.py <interpret.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: interpret.ipynb <interpret.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_