Short name describing what triggered the graph break
Attempted to use torch.nn.Parameter() constructor with Dynamo
Values or code snippet captured at the break point
Explanation of why the graph break was triggered
By default, Dynamo does not support calling the torch.nn.Parameter() constructor inside a compiled region.
Hints on how to resolve the graph break
Try to construct torch.nn.Parameter() outside the compiled region.
If this is not possible, turn torch._dynamo.config.graph_break_on_nn_param_ctor off.
Example code that causes this graph break:
import torch
@torch.compile(fullgraph=True)
def fn(x):
    param = torch.nn.Parameter(torch.randn_like(x))
    return param, param + x
fn(torch.randn(3, 3))
Try to construct torch.nn.Parameter() outside the compiled region and pass it in, for example as a function argument:
@torch.compile(fullgraph=True)
def fn(x, param):
    return param + x
fn(torch.randn(3, 3), torch.nn.Parameter(torch.randn(3, 3)))
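In module code, the same advice usually means creating the parameter in `__init__` and compiling the module, so that `forward` only reads it. The sketch below uses a hypothetical `AddParam` module and passes `backend="eager"` only so it runs without a C++ compiler; the graph break itself happens during Dynamo tracing, before any backend is involved.

```python
import torch
import torch.nn as nn


class AddParam(nn.Module):  # hypothetical example module
    def __init__(self, shape):
        super().__init__()
        # The Parameter is constructed eagerly here, outside the compiled region.
        self.param = nn.Parameter(torch.randn(shape))

    def forward(self, x):
        # forward only reads the existing Parameter, so Dynamo never traces the constructor.
        return self.param + x


model = AddParam((3, 3))
# backend="eager" skips code generation; Dynamo still traces the whole forward with fullgraph=True.
compiled = torch.compile(model, backend="eager", fullgraph=True)
out = compiled(torch.randn(3, 3))
```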
If this is not possible, turn torch._dynamo.config.graph_break_on_nn_param_ctor off (NOT RECOMMENDED):
torch._dynamo.config.graph_break_on_nn_param_ctor = False
@torch.compile(fullgraph=True)
def fn(x):
    param = torch.nn.Parameter(torch.randn_like(x))
    return param, param + x
fn(torch.randn(3, 3))