Non-strict Tracing Programming Model
Created: July 28, 2025 | Last updated: July 28, 2025
Summary
Non-strict Tracing is a way of tracing through Python code that is more permissive than Dynamo but can result in silent incorrectness.
Non-strict tracing runs a Python function and uses Python and PyTorch's operator-overloading capabilities to record the Tensor operations that occur during execution into a trace.
A function is non-strict traceable if it satisfies certain constraints: it is pure and does not directly manipulate Tensor.data_ptr().
Non-strict tracing may specialize on certain variables, treating them as constants and "baking" their values into the trace.
torch.compile's internals (make_fx, AOTDispatcher) use non-strict tracing. torch._dynamo.nonstrict_trace can also be used inside torch.compile'd code to mark sections of code that should be traced with non-strict tracing.
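As an illustrative sketch (the helper name and its body are made up, and this assumes torch._dynamo.nonstrict_trace can be applied as a decorator, as in recent PyTorch releases), marking a pure helper hands it off to non-strict tracing inside a compiled region:
import torch

# Hypothetical pure helper; the decorator asks Dynamo to trace through it
# with non-strict tracing instead of strict (Dynamo) tracing.
@torch._dynamo.nonstrict_trace
def helper(x):
    return x.sin() + 1

@torch.compile
def fn(x):
    return helper(x) * 2

print(fn(torch.randn(3)))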
Non-strict tracing runs a Python function and uses Python and PyTorch's operator-overloading capabilities to record the Tensor operations that occur during execution into a trace.
make_fx is the main entrypoint for non-strict tracing. For the following function, only the top branch executes at runtime for the given input, so the captured graph contains only that branch.
import torch
from torch.fx.experimental.proxy_tensor import make_fx
def f(x):
if x.shape[0] > 2:
return x ** 2 / 6
else:
return x * 3
x = torch.randn(3)
gm = make_fx(f, tracing_mode="fake")(x)
gm.print_readable()
class f(torch.nn.Module):
def forward(self, x_1: "f32[3]"):
# No stacktrace found for following nodes
pow_1: "f32[3]" = torch.ops.aten.pow.Tensor_Scalar(x_1, 2); x_1 = None
div: "f32[3]" = torch.ops.aten.div.Tensor(pow_1, 6); pow_1 = None
return div
Non-strict tracing differs from Dynamo (strict) tracing in that it is unsafe: for a given function, the graph of Tensor operations it captures may have different semantics from the original function. For a Python function, Dynamo tracing captures both a graph of Tensor operations and residual bytecode; their combination has the same semantics as the original Python function.
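To make the unsafety concrete, here is a continuation of the snippet above (using the same f and gm; the values in the comments are what I would expect). Calling the captured graph with an input that should take the else branch silently produces a different result from the original function, whereas Dynamo would install a guard on x.shape[0] and recompile.
x_small = torch.randn(2)
print(f(x_small))   # the original function takes the else branch: x * 3
print(gm(x_small))  # the captured graph still computes x ** 2 / 6 -- silently different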
Pure Functions
Non-strict tracing is only sound on pure functions, so only pure functions should be non-strict traced.
A pure function is a function with the following properties:
Determinism. Given the same inputs, a pure function always returns the same output.
No side effects. A pure function has no side effects, such as mutating external state or performing I/O.
Explicit inputs/outputs. All input data must be passed in through the function's arguments, and all results must be returned from the function.
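By contrast, here is a minimal sketch of a pure function (the name pure_fn is just for illustration) that non-strict tracing captures faithfully, since all of its inputs are arguments and it has no side effects:
def pure_fn(x, y):
    return (x + y).relu()

x = torch.randn(3)
y = torch.randn(3)
gm = make_fx(pure_fn, tracing_mode="fake")(x, y)
# The captured graph matches the original function on inputs with the
# same Tensor properties.
print(torch.allclose(pure_fn(x, y), gm(x, y)))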
Below are some examples of impure functions for which the captured graph behaves differently from the original function.
Example 1: No explicit inputs (e.g., accessing a global Tensor)
var = torch.tensor(1)
def function_with_global_access(y):
return y + var
x = torch.tensor([0, 1, 2])
# _allow_non_fake_inputs=True is needed to capture the global variable
# for demonstration purposes.
gm = make_fx(
function_with_global_access, tracing_mode="fake", _allow_non_fake_inputs=True
)(x)
# Non-strict Tracing captures the value of the global (1.)
print("1. call function", function_with_global_access(x))
print("1. call graph", gm(x))
# However, after changing the global, the captured graph
# produces a different result from the original function
var = torch.tensor(2)
print("2. call function", function_with_global_access(x))
print("2. call graph", gm(x))
# To capture a graph that can have a varying `var` tensor,
# it must be an explicit input:
def function_fixed(y, var):
return y + var
var = torch.tensor(3)
gm = make_fx(function_fixed, tracing_mode="fake")(x, var)
print("3. call function", function_fixed(x, var))
print("3. call graph", gm(x, var))
var = torch.tensor(4)
print("4. call function", function_fixed(x, var))
print("4. call graph", gm(x, var))
1. call function tensor([1, 2, 3])
1. call graph tensor([1, 2, 3])
2. call function tensor([2, 3, 4])
2. call graph tensor([1, 2, 3])
3. call function tensor([3, 4, 5])
3. call graph tensor([3, 4, 5])
4. call function tensor([4, 5, 6])
4. call graph tensor([4, 5, 6])
See Specialization and Constants for why this happens.
Example 2: Side effects (print)
def function_with_side_effect(y):
print(y)
x = torch.tensor([0, 1, 2])
_ = function_with_side_effect(x)
tensor([0, 1, 2])
Running function_with_side_effect in Python prints a Tensor as a side effect.
gm = make_fx(function_with_side_effect, tracing_mode="fake")(x)
FakeTensor(..., size=(3,), dtype=torch.int64)
Under non-strict tracing, this print happens during graph capture.
_ = gm(x)
The graph does not record the call to print, so executing the graph prints nothing.
Example 3: Side effects (list mutation)
lst = []
def function_with_input_list_mutation(lst):
val = lst.pop()
return val
x = torch.tensor([0, 1, 2])
y = torch.tensor([0, 1, 2])
# Each time the function is executed, the list shrinks in size
lst = [x, y]
function_with_input_list_mutation(lst)
print("len(lst) after one call", len(lst))
function_with_input_list_mutation(lst)
print("len(lst) after two calls", len(lst))
# With Non-strict Tracing, the length of the list shrinks during
# the graph capture but not in invocations of the graph.
lst = [x, y]
gm = make_fx(function_with_input_list_mutation, tracing_mode="fake")(lst)
print("len(lst) after graph capture", len(lst))
gm(lst)
print("len(lst) after one call to graph", len(lst))
gm(lst)
print("len(lst) after two calls to graph", len(lst))
len(lst) after one call 1
len(lst) after two calls 0
len(lst) after graph capture 2
len(lst) after one call to graph 2
len(lst) after two calls to graph 2
No direct data_ptr manipulation
Functions that directly manipulate Tensor.data_ptr are not non-strict traceable. The intuition is that PyTorch has no way of knowing how you are manipulating the data_ptr.
import ctypes
# Create a tensor with a single element
tensor = torch.tensor([42], dtype=torch.int32) # Using int32 for simplicity
def function_with_data_ptr(tensor):
# Get the data pointer
ptr = tensor.data_ptr()
# Cast the pointer to a ctypes pointer
ctypes_ptr = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_int32))
# Increment the value at the pointer
ctypes_ptr.contents.value += 1
return tensor
try:
make_fx(function_with_data_ptr, tracing_mode="fake")(tensor)
except Exception as e:
print(e)
Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.ac.cn/tutorials/advanced/custom_ops_landing_page.html
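As the error message suggests, one way around this is to wrap the pointer manipulation in an opaque custom op, so the trace records a single call to the op instead of tracing into the pointer arithmetic. Here is a minimal sketch continuing the snippet above (it reuses tensor and ctypes from there, assumes the torch.library.custom_op API from recent PyTorch releases, and the mylib::increment name is made up for illustration):
# Register an opaque custom op whose implementation does the data_ptr work.
@torch.library.custom_op("mylib::increment", mutates_args=())
def increment(tensor: torch.Tensor) -> torch.Tensor:
    out = tensor.clone()
    ctypes_ptr = ctypes.cast(out.data_ptr(), ctypes.POINTER(ctypes.c_int32))
    ctypes_ptr.contents.value += 1
    return out

@increment.register_fake
def _(tensor):
    # Fake-tensor implementation: only describes the output's metadata.
    return torch.empty_like(tensor)

def function_with_custom_op(tensor):
    return increment(tensor)

gm = make_fx(function_with_custom_op, tracing_mode="fake")(tensor)
gm.print_readable()  # the graph contains a single call to mylib::increment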
Specialization and Constants
A graph captured by non-strict tracing may be specialized on certain values, meaning the captured graph is only valid for those values. We say the graph treats those values as constants.
During non-strict tracing, all non-Tensor variables are treated as constants.
def f(x, y):
return x + y
x = torch.tensor([0, 1, 2])
y = 3.14
gm = make_fx(f, tracing_mode="fake")(x, y)
gm.print_readable()
class f(torch.nn.Module):
def forward(self, x_1: "i64[3]", y_1):
# No stacktrace found for following nodes
add: "f32[3]" = torch.ops.aten.add.Tensor(x_1, 3.14); x_1 = None
return add
3.14 is a constant in the graph.
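Continuing the snippet above, calling the captured graph with a different value of y has no effect, because 3.14 is baked into the graph (the values in the comments are what I would expect given x = torch.tensor([0, 1, 2])):
print(f(x, 2.0))   # tensor([2., 3., 4.]) -- uses the new value of y
print(gm(x, 2.0))  # tensor([3.1400, 4.1400, 5.1400]) -- still uses the baked-in 3.14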
Non-strict tracing also specializes on properties of the input Tensors.
def f(x):
if x.shape[0] > 2:
return x ** 2 / 6
else:
return x * 3
x = torch.randn(3)
gm = make_fx(f, tracing_mode="fake")(x)
gm.print_readable()
class f(torch.nn.Module):
def forward(self, x_1: "f32[3]"):
# No stacktrace found for following nodes
pow_1: "f32[3]" = torch.ops.aten.pow.Tensor_Scalar(x_1, 2); x_1 = None
div: "f32[3]" = torch.ops.aten.div.Tensor(pow_1, 6); pow_1 = None
return div
It also specializes on any variables that are not passed directly into the function.
y = 3.14  # a global used by f below; not passed as an argument
def f(x):
return x + y
x = torch.randn(3)
gm = make_fx(f, tracing_mode="fake")(x)
gm.print_readable()
class f(torch.nn.Module):
def forward(self, x_1: "f32[3]"):
# No stacktrace found for following nodes
add: "f32[3]" = torch.ops.aten.add.Tensor(x_1, 3.14); x_1 = None
return add