torch.func.vmap#
- torch.func.vmap(func, in_dims=0, out_dims=0, randomness='error', *, chunk_size=None)[源代码]#
vmap 是矢量化映射;
vmap(func)
返回一个新函数,该函数将func
映射到输入的某些维度上。从语义上讲,vmap 将 map 推入func
调用的 PyTorch 操作中,从而有效地将这些操作向量化。vmap 对于处理批处理维度很有用:您可以编写一个在单个示例上运行的函数
func
,然后使用vmap(func)
将其提升为一个可以接受示例批次的函数。vmap 还可以与 autograd 组合以计算批处理梯度。注意
torch.vmap()
为了方便起见,别名为torch.func.vmap()
。你可以随意使用其中一个。- 参数
func (function) – 接受一个或多个参数的 Python 函数。必须返回一个或多个 Tensor。
in_dims (int 或 嵌套结构) – 指定输入应该在哪一个维度上进行映射。
in_dims
的结构应该与输入相匹配。如果某个输入的in_dim
为 None,则表示没有映射维度。默认为 0。out_dims (int 或 Tuple[int]) – 指定映射维度应该出现在输出的哪个位置。如果
out_dims
是一个 Tuple,那么它应该为每个输出包含一个元素。默认为 0。randomness (str) – 指定此 vmap 中的随机性在批次之间是相同还是不同。如果为 ‘different’,则每个批次的随机性将不同。如果为 ‘same’,则批次之间的随机性将相同。如果为 ‘error’,则对随机函数的任何调用都将报错。默认为 ‘error’。警告:此标志仅适用于 PyTorch 的随机操作,不适用于 Python 的 random 模块或 numpy 随机性。
chunk_size (None 或 int) – 如果为 None(默认),则在输入上应用单个 vmap。如果非 None,则一次计算
chunk_size
个样本的 vmap。请注意,chunk_size=1
等同于使用 for 循环计算 vmap。如果您在计算 vmap 时遇到内存问题,请尝试使用非 None 的 chunk_size。
- 返回
返回一个新的“批处理”函数。它接受与
func
相同的输入,只是每个输入的指定in_dims
索引处会多出一个维度。它返回与func
相同的输出,只是每个输出的指定out_dims
索引处会多出一个维度。- 返回类型
使用
vmap()
的一个例子是计算批量点积。PyTorch 不提供批量torch.dot
API;与其徒劳地翻阅文档,不如使用vmap()
来构造一个新函数。>>> torch.dot # [D], [D] -> [] >>> batched_dot = torch.func.vmap(torch.dot) # [N, D], [N, D] -> [N] >>> x, y = torch.randn(2, 5), torch.randn(2, 5) >>> batched_dot(x, y)
vmap()
有助于隐藏批次维度,从而提供更简单的模型编写体验。>>> batch_size, feature_size = 3, 5 >>> weights = torch.randn(feature_size, requires_grad=True) >>> >>> def model(feature_vec): >>> # Very simple linear model with activation >>> return feature_vec.dot(weights).relu() >>> >>> examples = torch.randn(batch_size, feature_size) >>> result = torch.vmap(model)(examples)
vmap()
还可以帮助向量化以前难以或不可能进行批处理的计算。一个例子是高阶梯度计算。PyTorch 的 autograd 引擎计算 vjps(向量-雅可比乘积)。通常需要 N 次调用autograd.grad
(每调用一次对应一个雅可比矩阵的行)才能计算某个函数 f: R^N -> R^N 的完整雅可比矩阵。使用vmap()
,我们可以向量化整个计算,在一次autograd.grad
调用中计算雅可比矩阵。>>> # Setup >>> N = 5 >>> f = lambda x: x**2 >>> x = torch.randn(N, requires_grad=True) >>> y = f(x) >>> I_N = torch.eye(N) >>> >>> # Sequential approach >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0] >>> for v in I_N.unbind()] >>> jacobian = torch.stack(jacobian_rows) >>> >>> # vectorized gradient computation >>> def get_vjp(v): >>> return torch.autograd.grad(y, x, v) >>> jacobian = torch.vmap(get_vjp)(I_N)
vmap()
也可以嵌套使用,产生一个具有多个批处理维度的输出。>>> torch.dot # [D], [D] -> [] >>> batched_dot = torch.vmap( ... torch.vmap(torch.dot) ... ) # [N1, N0, D], [N1, N0, D] -> [N1, N0] >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5) >>> batched_dot(x, y) # tensor of size [2, 3]
如果输入不在第一个维度上进行批处理,
in_dims
会指定每个输入的批处理维度为>>> torch.dot # [N], [N] -> [] >>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D] >>> x, y = torch.randn(2, 5), torch.randn(2, 5) >>> batched_dot( ... x, y ... ) # output is [5] instead of [2] if batched along the 0th dimension
如果存在多个输入,且每个输入在不同维度上进行批处理,
in_dims
必须是一个元组,其中包含每个输入的批处理维度为>>> torch.dot # [D], [D] -> [] >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N] >>> x, y = torch.randn(2, 5), torch.randn(5) >>> batched_dot( ... x, y ... ) # second arg doesn't have a batch dim because in_dim[1] was None
如果输入是 Python 结构,
in_dims
必须是一个元组,其中包含一个与输入形状匹配的结构。>>> f = lambda dict: torch.dot(dict["x"], dict["y"]) >>> x, y = torch.randn(2, 5), torch.randn(5) >>> input = {"x": x, "y": y} >>> batched_dot = torch.vmap(f, in_dims=({"x": 0, "y": None},)) >>> batched_dot(input)
默认情况下,输出在第一个维度上进行批处理。但是,可以使用
out_dims
在任何维度上进行批处理。>>> f = lambda x: x**2 >>> x = torch.randn(2, 5) >>> batched_pow = torch.vmap(f, out_dims=1) >>> batched_pow(x) # [5, 2]
对于任何使用 kwargs 的函数,返回的函数不会批处理 kwargs,但会接受 kwargs。
>>> x = torch.randn([2, 5]) >>> def fn(x, scale=4.): >>> return x * scale >>> >>> batched_pow = torch.vmap(fn) >>> assert torch.allclose(batched_pow(x), x * 4) >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]
注意
vmap 不提供开箱即用的通用自动批处理或处理可变长度序列。