评价此页

torch.autograd.Function.vmap#

static Function.vmap(info, in_dims, *args)[源代码]#

定义此 autograd.Function 在 torch.vmap() 下的行为。

要使 torch.autograd.Function() 支持 torch.vmap(),您必须覆盖此静态方法,或者将 generate_vmap_rule 设置为 True(您不能同时执行这两项)。

如果您选择重写此静态方法:它必须接受

  • 第一个参数是一个 info 对象。info.batch_size 指定了要 vmap 的维度的大小,而 info.randomness 是传递给 torch.vmap() 的随机性选项。

  • 第二个参数是一个 in_dims 元组。对于 args 中的每个 arg,in_dims 有一个相应的 Optional[int]。如果 arg 不是 Tensor 或 arg 不被 vmap,则为 None,否则,它是一个指定 Tensor 的哪个维度被 vmap 的整数。

  • *args,与 forward() 的 args 相同。

vmap 静态方法的返回值是一个元组 (output, out_dims)。与 in_dims 类似,out_dims 的结构应与 output 相同,并且每个输出都包含一个 out_dim,指定输出是否具有 vmap 的维度以及在该维度中的索引。

有关更多详细信息,请参阅 使用 autograd.Function 扩展 torch.func