评价此页
fullgraph=False">

在哪里应用 torch.compile?#

创建时间:2025 年 7 月 28 日 | 最后更新时间:2025 年 7 月 28 日

我们建议将 torch.compile 应用于不会导致过度问题的最高级别函数。通常情况下,它是

  • 您的 traineval 步骤,包含优化器但不包含循环,

  • 您的顶级 nn.Module

  • 或一些子 nn.Module

torch.compile 尤其不擅长处理 DDP 或 FSDP 等分布式包装器模块,因此请考虑将 torch.compile 应用于传递给包装器的内部模块。

# inference
model = ...
model.compile()

for _ in range(N_ITERS):
    inp = ...
    out = model(inp)
# training
model = ...
opt = torch.optim.Adam(model.parameters())

@torch.compile
def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

for _ in range(N_ITERS):
    inp = ...
    train(model, inp)
# DistributedDataParallel
model = ...
model.compile()
model_ddp = DistributedDataParallel(model, ...)

for _ in range(N_ITERS):
    inp = ...
    out = model_ddp(inp)

compile(model)model.compile()#

由于 torch.compilenn.Module 实例的交互方式存在细微差别,因此如果您希望将 nn.Module 实例作为顶级函数进行编译,我们建议使用 nn.Module 实例的 .compile() 方法。嵌套的模块调用将被正确跟踪 - 在这种情况下无需调用 .compile()

# DO NOT DO THIS
model = MyModel()
model = torch.compile(model)
model(inp)

# DO THIS
model = MyModel()
model.compile()
model(inp)

# this is also acceptable
@torch.compile
def fn(model, inp):
    return model(inp)
model = MyModel()
fn(model, inp)