注意
转到页面底部 下载完整示例代码。
理解 requires_grad、retain_grad、叶子张量和非叶子张量#
作者: Justin Silver
本教程通过一个简单的示例,解释了 requires_grad、retain_grad、叶子张量和非叶子张量之间的细微差别。
在开始之前,请确保你了解张量及其操作方法。对自动求导(autograd)工作原理的基本了解也会有所帮助。
设置#
首先,确保已安装 PyTorch,然后导入必要的库。
import torch
import torch.nn.functional as F
接下来,我们实例化一个简单的网络以专注于梯度。这将是一个仿射层(affine layer),后接一个 ReLU 激活函数,最后计算预测张量和标签张量之间的 MSE 损失。
注意,参数(W 和 b)必须设置 requires_grad=True,这样 PyTorch 才能跟踪涉及这些张量的操作。我们将在之后的章节中详细讨论这一点。
# tensor setup
x = torch.ones(1, 3) # input with shape: (1, 3)
W = torch.ones(3, 2, requires_grad=True) # weights with shape: (3, 2)
b = torch.ones(1, 2, requires_grad=True) # bias with shape: (1, 2)
y = torch.ones(1, 2) # output with shape: (1, 2)
# forward pass
z = (x @ W) + b # pre-activation with shape: (1, 2)
y_pred = F.relu(z) # activation with shape: (1, 2)
loss = F.mse_loss(y_pred, y) # scalar loss
叶子张量 vs. 非叶子张量#
运行前向传播后,PyTorch 的 autograd 构建了一个动态计算图,如下所示。这是一个有向无环图 (DAG),它记录了输入张量(叶子节点)、这些张量的所有后续操作以及中间/输出张量(非叶子节点)。该图使用微积分中的链式法则,从图的根节点(输出)到叶子节点(输入)来计算每个张量的梯度。
graph TD
x["x<br/>is_leaf=True<br/>requires_grad=False<br/>retains_grad=False<br/>grad=None"]
W["W<br/>is_leaf=True<br/>requires_grad=True<br/>retains_grad=False<br/>grad=None"]
b["b<br/>is_leaf=True<br/>requires_grad=True<br/>retains_grad=False<br/>grad=None"]
matmul["x @ W"]
z["z = x @ W + b<br/>is_leaf=False<br/>requires_grad=True<br/>retains_grad=False<br/>grad=None"]
relu["y_pred = relu(z)<br/>is_leaf=False<br/>requires_grad=True<br/>retains_grad=False<br/>grad=None"]
y["y<br/>is_leaf=True<br/>requires_grad=False<br/>retains_grad=False<br/>grad=None"]
loss["loss = mse(y_pred, y)<br/>is_leaf=False<br/>requires_grad=True<br/>retains_grad=False<br/>grad=None"]
x --> matmul
W --> matmul
matmul --> z
b --> z
z --> relu
relu --> loss
y --> loss
如果一个节点不是由至少一个输入且 requires_grad=True 的张量运算产生的(例如 x, W, b 和 y),PyTorch 将其视为叶子(leaf),其他所有节点均视为非叶子(non-leaf)(例如 z, y_pred 和 loss)。你可以通过探测张量的 is_leaf 属性来以编程方式验证这一点。
x.is_leaf=True
z.is_leaf=False
叶子张量和非叶子张量之间的区别决定了张量的梯度在反向传播后是否会存储在 grad 属性中,从而能否用于梯度下降。我们将在下一节中对此进行更多介绍。
现在让我们研究 PyTorch 如何为其计算图中的张量计算和存储梯度。
requires_grad#
为了构建可用于梯度计算的计算图,我们需要在张量构造函数中传入 requires_grad=True 参数。默认情况下,该值为 False,因此 PyTorch 不会跟踪任何已创建张量的梯度。要验证这一点,尝试不设置 requires_grad,重新运行前向传播,然后执行反向传播。你将会看到:
>>> loss.backward()
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
此错误意味着 autograd 无法反向传播到任何叶子张量,因为 loss 没有在跟踪梯度。如果你需要更改此属性,可以在张量上调用 requires_grad_()(注意末尾的下划线)。
我们可以像上面使用 is_leaf 属性一样,对哪些节点需要进行梯度计算进行完整性检查。
x.requires_grad=False
W.requires_grad=True
z.requires_grad=True
记住这一点很有用:非叶子张量默认具有 requires_grad=True,否则反向传播将会失败。如果张量是叶子张量,则只有在用户明确设置的情况下,它才会具有 requires_grad=True。另一种表达方式是:如果张量的至少一个输入需要梯度,那么该张量也将需要梯度。
此规则有两个例外情况:
总之,requires_grad 告诉 autograd 哪些张量需要计算梯度,以便反向传播能够正常工作。这与哪些张量的 grad 字段已被填充是不同的,后者是下一节的主题。
retain_grad#
为了实际执行优化(例如 SGD、Adam 等),我们需要运行反向传播以便提取梯度。
调用 backward() 会填充所有 requires_grad=True 的叶子张量的 grad 字段。grad 是损失相对于我们正在探测的张量的梯度。在运行 backward() 之前,此属性设置为 None。
W.grad=tensor([[3., 3.],
[3., 3.],
[3., 3.]])
b.grad=tensor([[3., 3.]])
你可能对网络中的其他张量感到好奇。让我们检查剩余的叶子节点:
x.grad=None
y.grad=None
这些张量的梯度尚未填充,因为我们没有明确告知 PyTorch 计算它们的梯度(requires_grad=False)。
现在让我们查看一个中间非叶子节点:
print(f"{z.grad=}")
/var/lib/workspace/beginner_source/understanding_leaf_vs_nonleaf_tutorial.py:230: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more information. (Triggered internally at /pytorch/build/aten/src/ATen/core/TensorBody.h:494.)
print(f"{z.grad=}")
z.grad=None
PyTorch 为梯度返回 None,并警告我们正在访问非叶子节点的 grad 属性。虽然 autograd 必须计算中间梯度以使反向传播正常工作,但它假定你之后不需要访问这些值。要改变这种行为,我们可以在张量上使用 retain_grad() 函数。这会告诉 autograd 引擎在调用 backward() 后填充该张量的 grad。
# we have to re-run the forward pass
z = (x @ W) + b
y_pred = F.relu(z)
loss = F.mse_loss(y_pred, y)
# tell PyTorch to store the gradients after backward()
z.retain_grad()
y_pred.retain_grad()
loss.retain_grad()
# have to zero out gradients otherwise they would accumulate
W.grad = None
b.grad = None
# backpropagation
loss.backward()
# print gradients for all tensors that have requires_grad=True
print(f"{W.grad=}")
print(f"{b.grad=}")
print(f"{z.grad=}")
print(f"{y_pred.grad=}")
print(f"{loss.grad=}")
W.grad=tensor([[3., 3.],
[3., 3.],
[3., 3.]])
b.grad=tensor([[3., 3.]])
z.grad=tensor([[3., 3.]])
y_pred.grad=tensor([[3., 3.]])
loss.grad=tensor(1.)
我们得到的 W.grad 与之前相同。还要注意,由于损失是标量,损失相对于其自身的梯度简单地为 1.0。
如果我们现在查看计算图的状态,可以看到中间张量的 retains_grad 属性已经改变。按照惯例,对于任何叶子节点,即使它需要梯度,此属性也会显示为 False。
graph TD
x["x<br/>is_leaf=True<br/>requires_grad=False<br/>retains_grad=False<br/>grad=None"]
W["W<br/>is_leaf=True<br/>requires_grad=True<br/>retains_grad=False<br/>grad=torch.Tensor"]
b["b<br/>is_leaf=True<br/>requires_grad=True<br/>retains_grad=False<br/>grad=torch.Tensor"]
matmul["x @ W"]
z["z = x @ W + b<br/>is_leaf=False<br/>requires_grad=True<br/>retains_grad=True<br/>grad=torch.Tensor"]
relu["y_pred = relu(z)<br/>is_leaf=False<br/>requires_grad=True<br/>retains_grad=True<br/>grad=torch.Tensor"]
y["y<br/>is_leaf=True<br/>requires_grad=True<br/>retains_grad=False<br/>grad=None"]
loss["loss = mse(y_pred, y)<br/>is_leaf=False<br/>requires_grad=True<br/>retains_grad=True<br/>grad=torch.Tensor"]
x --> matmul
W --> matmul
matmul --> z
b --> z
z --> relu
relu --> loss
y --> loss
如果在叶子张量上调用 retain_grad(),它不会产生任何操作,因为叶子张量默认情况下已经保留了它们的梯度(当 requires_grad=True 时)。如果我们对一个 requires_grad=False 的张量调用 retain_grad(),PyTorch 实际上会抛出一个错误,因为它无法存储从未计算过的梯度。
>>> x.retain_grad()
RuntimeError: can't retain_grad on Tensor that has requires_grad=False
总结表#
使用 retain_grad() 和 retains_grad 仅对非叶子节点有意义,因为对于 requires_grad=True 的叶子张量,其 grad 属性已经被填充。默认情况下,这些非叶子节点在反向传播后不会保留(存储)它们的梯度。我们可以通过重新运行前向传播、告知 PyTorch 存储梯度,然后执行反向传播来改变这一点。
下表可作为总结上述讨论的参考。以下场景是 PyTorch 张量唯一有效的组合情况。
|
|
|
|
|
|---|---|---|---|---|
|
|
|
将 |
抛出错误 |
|
|
|
将 |
无操作(已保留) |
|
|
|
无操作 |
将 |
|
|
|
无操作 |
无操作(已保留) |
结论#
在本教程中,我们介绍了 PyTorch 何时以及如何为叶子张量和非叶子张量计算梯度。通过使用 retain_grad,我们可以访问 autograd 计算图中中间张量的梯度。
如果你想了解更多关于 PyTorch 自动求导系统的工作原理,请访问下方的参考资料。如果你对本教程有任何反馈(改进建议、勘误等),请使用 PyTorch 论坛 和/或 问题跟踪器 联系我们。
参考文献#
脚本总运行时间: (0 分钟 0.324 秒)