torch.autograd.Function.vmap#
- static Function.vmap(info, in_dims, *args)[源代码]#
定义此 autograd.Function 在
torch.vmap()下的行为。要使
torch.autograd.Function()支持torch.vmap(),您必须覆盖此静态方法,或者将generate_vmap_rule设置为True(您不能同时执行这两项)。如果您选择重写此静态方法:它必须接受
第一个参数是一个
info对象。info.batch_size指定了要 vmap 的维度的大小,而info.randomness是传递给torch.vmap()的随机性选项。第二个参数是一个
in_dims元组。对于args中的每个 arg,in_dims有一个相应的Optional[int]。如果 arg 不是 Tensor 或 arg 不被 vmap,则为None,否则,它是一个指定 Tensor 的哪个维度被 vmap 的整数。*args,与forward()的 args 相同。
vmap 静态方法的返回值是一个元组
(output, out_dims)。与in_dims类似,out_dims的结构应与output相同,并且每个输出都包含一个out_dim,指定输出是否具有 vmap 的维度以及在该维度中的索引。有关更多详细信息,请参阅 使用 autograd.Function 扩展 torch.func。