Short name describing what triggered the graph break
autograd.grad with output that requires grad
Values or code snippet captured at the break point
context
Explanation of why the graph break was triggered
The compiled function uses torch.autograd.grad() and returns a tensor that still requires gradients and is connected to the autograd.grad() computation. This would cause aot_autograd to attempt ‘backward through graph a second time’, which is not supported.
Hints on how to resolve the graph break
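If the returned tensor does not need gradients, detach it before returning: use return loss.detach(), grads instead of return loss, grads.
Alternatively, run the function eagerly, either by removing the torch.compile call or by using torch.compiler.set_stance("force_eager").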
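The detach hint can be sketched as follows. This is a minimal illustration, not the code that produced the graph break: the function name loss_and_grads, the toy objective, and the tensor values are all hypothetical, and backend="eager" is an assumption made only so the sketch runs without a C++ toolchain.

```python
import torch


def loss_and_grads(x, w):
    # Hypothetical toy objective, chosen only for illustration.
    loss = (x * w).sum()
    # Differentiate the loss with respect to w inside the compiled function.
    (grad_w,) = torch.autograd.grad(loss, w)
    # Detach the loss before returning it, so the returned tensor no longer
    # requires grad and is not connected to the autograd.grad() computation.
    return loss.detach(), grad_w


# backend="eager" is an assumption for portability; other backends may differ.
compiled = torch.compile(loss_and_grads, backend="eager")

x = torch.tensor([1.0, 2.0, 3.0])
w = torch.tensor([0.5, -1.0, 2.0], requires_grad=True)
loss, grad_w = compiled(x, w)
```

Detaching is only appropriate when the caller does not need to backpropagate through the returned loss. If it does, the eager-mode hint is likely the better fit.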