GB0007

Graph-Break Type

Attempted to call function marked as skipped

Context

module: {module_name}, qualname: {qualname}, skip reason: {reason}
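These fields are filled in at runtime: {module_name} and {qualname} identify the skipped function (for the tqdm example below, {qualname} would be a tqdm function such as tqdm.__new__), and {reason} records why Dynamo skips it. The exact strings depend on your PyTorch and library versions.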

Explanation

{explanation} (generated at runtime for the specific skipped function)

Hints

No hints provided.

Additional Information

Example code that causes the graph break is:

import torch
from tqdm import tqdm


def fn(x):
    # tqdm is on Dynamo's skip list, so tracing this call triggers the graph break
    for i in tqdm(range(5)):
        x += i
    return x


compiled_fn = torch.compile(fn, fullgraph=True)
compiled_fn(torch.randn(3))
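Running this raises torch._dynamo.exc.Unsupported, because fullgraph=True turns the graph break into a hard failure. As a minimal sketch of the fallback behavior, assuming a partial graph is acceptable: drop fullgraph=True so Dynamo splits the graph at the skipped call and runs it eagerly, and set the TORCH_LOGS="graph_breaks" environment variable to see where the break occurs:

import torch
from tqdm import tqdm


def fn(x):
    for i in tqdm(range(5)):
        x += i
    return x


# fullgraph defaults to False: Dynamo splits the graph at the skipped call
# and runs the tqdm iteration eagerly instead of raising.
compiled_fn = torch.compile(fn)
compiled_fn(torch.randn(3))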

The first workaround is to remove the call to the skipped function:

def fn(x):
    for i in range(5):
        x += i
    return x


compiled_fn = torch.compile(fn, fullgraph=True)
compiled_fn(torch.randn(3))

You can use torch.compiler.is_compiling() if you only want to drop the call when torch.compile is active:

def fn(x):
    it = range(5)
    if not torch.compiler.is_compiling():
        it = tqdm(it)  # only wrap with tqdm when running eagerly
    for i in it:
        x += i
    return x


compiled_fn = torch.compile(fn, fullgraph=True)
compiled_fn(torch.randn(3))
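With this pattern, the progress bar simply does not appear under compilation: torch.compiler.is_compiling() returns True while Dynamo traces the function, so the tqdm branch is never taken.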

The second workaround is to leave the call to the skipped function outside the compiled region and compile only the inner computation:

@torch.compile(fullgraph=True)
def inner(x, i):
    x += i  # in-place tensor update; the mutation is visible to the caller

def fn(x):
    # fn runs eagerly, so the tqdm call is never traced
    for i in tqdm(range(5)):
        inner(x, i)
    return x

fn(torch.randn(3))
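This works because inner updates x in place, so fn can ignore its return value. A hypothetical variant, if your inner function returns a fresh tensor instead of mutating its input, is to rebind the result in the eager loop:

import torch
from tqdm import tqdm


@torch.compile(fullgraph=True)
def inner(x, i):
    # pure variant: returns a new tensor instead of mutating the input
    return x + i

def fn(x):
    for i in tqdm(range(5)):
        x = inner(x, i)
    return x

fn(torch.randn(3))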

The third workaround is to override Dynamo’s default skipping behavior using the @torch._dynamo.dont_skip_tracing decorator. NOTE: This is an advanced feature and may lead to further graph breaks if the function’s internals are also untraceable, so proceed with caution.

@torch._dynamo.dont_skip_tracing
def fn(x):
    for i in tqdm(range(5)):
        x += i
    return x


compiled_fn = torch.compile(fn, fullgraph=True)
# This now hits a different graph break (an error under fullgraph=True)
# because Dynamo attempts to trace into `tqdm.__new__`
compiled_fn(torch.randn(3))
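If partial compilation is acceptable, one sketch is to combine the decorator with the default fullgraph=False, so that the nested break inside tqdm falls back to eager execution instead of erroring:

compiled_fn = torch.compile(fn)  # fullgraph=False: the nested break falls back to eager
compiled_fn(torch.randn(3))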

If you are attempting to call a logging function (e.g. _warnings.warn), you can try adding it to torch._dynamo.config.reorderable_logging_functions:

import warnings

import torch

torch._dynamo.config.reorderable_logging_functions.add(warnings.warn)

def fn(x):
    warnings.warn("warning")
    for i in range(5):
        x += i
    return x


compiled_fn = torch.compile(fn, fullgraph=True)
compiled_fn(torch.randn(3))
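With this configuration, Dynamo traces past the call and reorders warnings.warn("warning") to run after the compiled graph executes, so the warning still fires on each call without breaking the graph.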
