在哪里应用 torch.compile?#
创建时间:2025 年 7 月 28 日 | 最后更新时间:2025 年 7 月 28 日
我们建议将 torch.compile
应用于不会导致过度问题的最高级别函数。通常情况下,它是
您的
train
或eval
步骤,包含优化器但不包含循环,您的顶级
nn.Module
或一些子
nn.Module
。
torch.compile
尤其不擅长处理 DDP 或 FSDP 等分布式包装器模块,因此请考虑将 torch.compile
应用于传递给包装器的内部模块。
# inference
model = ...
model.compile()
for _ in range(N_ITERS):
inp = ...
out = model(inp)
# training
model = ...
opt = torch.optim.Adam(model.parameters())
@torch.compile
def train(mod, data):
opt.zero_grad(True)
pred = mod(data[0])
loss = torch.nn.CrossEntropyLoss()(pred, data[1])
loss.backward()
opt.step()
for _ in range(N_ITERS):
inp = ...
train(model, inp)
# DistributedDataParallel
model = ...
model.compile()
model_ddp = DistributedDataParallel(model, ...)
for _ in range(N_ITERS):
inp = ...
out = model_ddp(inp)
compile(model)
与 model.compile()
#
由于 torch.compile
与 nn.Module
实例的交互方式存在细微差别,因此如果您希望将 nn.Module
实例作为顶级函数进行编译,我们建议使用 nn.Module
实例的 .compile()
方法。嵌套的模块调用将被正确跟踪 - 在这种情况下无需调用 .compile()
。
# DO NOT DO THIS
model = MyModel()
model = torch.compile(model)
model(inp)
# DO THIS
model = MyModel()
model.compile()
model(inp)
# this is also acceptable
@torch.compile
def fn(model, inp):
return model(inp)
model = MyModel()
fn(model, inp)