Short name describing what triggered the graph break:
Attempted to call function marked as skipped
Values or code snippet captured at the break point:
module: {module_name}, qualname: {qualname}, skip reason: {reason}
Explanation of why the graph break was triggered:
explanation
Hints on how to resolve the graph break:
No hints provided.
Example code that causes the graph break is:
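import torch
from tqdm import tqdm

def fn(x):
    for i in tqdm(range(5)):
        x += i
    return x

compiled_fn = torch.compile(fn, fullgraph=True)
# Fails: with fullgraph=True, the graph break on tqdm raises an error instead of splitting the graph
compiled_fn(torch.randn(3))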
The first workaround is to remove the skipped function:
import torch

def fn(x):
    for i in range(5):
        x += i
    return x

compiled_fn = torch.compile(fn, fullgraph=True)
compiled_fn(torch.randn(3))
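A quick way to confirm that the rewritten function no longer breaks the graph is to compile it with fullgraph=True, which raises on any graph break, and compare the result with eager execution. The sketch below assumes a recent PyTorch 2.x and uses backend="eager" (an assumption of this example, not part of the workaround) so that Dynamo still captures the full graph without needing Inductor code generation or a C++ toolchain.

```python
import torch

def fn(x):
    for i in range(5):
        x += i  # in-place update, as in the example above
    return x

# backend="eager" runs the captured graph as-is; fullgraph=True still makes
# Dynamo raise if it hits any graph break while tracing fn.
compiled_fn = torch.compile(fn, backend="eager", fullgraph=True)

x = torch.randn(3)
expected = fn(x.clone())          # eager reference (clone so x is not mutated)
actual = compiled_fn(x.clone())   # compiled run on an identical copy
print(torch.allclose(expected, actual))
```

Cloning the input matters here because fn mutates its argument in place; without it, the second call would start from an already-incremented tensor.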
You can use torch.compiler.is_compiling() if you only want to remove the function when torch.compile is active:
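import torch
from tqdm import tqdm

def fn(x):
    iterable = range(5)
    # Wrap with tqdm only when running eagerly; Dynamo never traces the call.
    if not torch.compiler.is_compiling():
        iterable = tqdm(iterable)
    for i in iterable:
        x += i
    return x

compiled_fn = torch.compile(fn, fullgraph=True)
compiled_fn(torch.randn(3))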
The second workaround is to keep the skipped function out of the compiled region by compiling only the code that does not need it:
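import torch
from tqdm import tqdm

# Compile only the loop body; tqdm stays in the eager outer loop.
@torch.compile(fullgraph=True)
def inner(x, i):
    x += i

def fn(x):
    for i in tqdm(range(5)):
        inner(x, i)
    return x

fn(torch.randn(3))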
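With this split, the outer loop and its progress wrapper run as ordinary Python, and inner() is the only code handed to the compiler. The sketch below assumes a hypothetical progress() generator in place of tqdm, returns the result from inner() rather than mutating x in place, and uses a small counting backend (a plain callable that receives the captured torch.fx.GraphModule and returns gm.forward) to show that inner() was actually compiled.

```python
import torch

graphs = []

def counting_backend(gm, example_inputs):
    # Custom torch.compile backend: record each captured FX graph, then run it unchanged.
    graphs.append(gm)
    return gm.forward

def progress(iterable):
    # Hypothetical stand-in for tqdm: yields items unchanged and is never compiled.
    for item in iterable:
        yield item

@torch.compile(fullgraph=True, backend=counting_backend)
def inner(x, i):
    # Only this function is traced by Dynamo.
    return x + i

def fn(x):
    # The outer loop (and progress()) runs eagerly in plain Python.
    for i in progress(range(5)):
        x = inner(x, i)
    return x

result = fn(torch.zeros(3))
```

Dynamo may recompile inner() when the Python int i changes between calls, so graphs can hold more than one entry.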
The third workaround is to override Dynamo’s default skipping behavior with the @torch._dynamo.dont_skip_tracing decorator. NOTE: This is an advanced feature; it may lead to further graph breaks if the function’s internals are also untraceable, so proceed with caution.
import torch
from tqdm import tqdm

@torch._dynamo.dont_skip_tracing
def fn(x):
    for i in tqdm(range(5)):
        x += i
    return x

compiled_fn = torch.compile(fn, fullgraph=True)
# Another graph break because we attempted to trace into `tqdm.__new__`
compiled_fn(torch.randn(3))
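If the skipped function is a logging function (e.g., _warnings.warn), you can try adding it to torch._dynamo.config.reorderable_logging_functions. Dynamo then reorders calls to that function to after the captured graph instead of breaking the graph on them: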
import warnings
import torch

torch._dynamo.config.reorderable_logging_functions.add(warnings.warn)

def fn(x):
    warnings.warn("warning")
    for i in range(5):
        x += i
    return x

compiled_fn = torch.compile(fn, fullgraph=True)
compiled_fn(torch.randn(3))