注意
转到末尾 下载完整的示例代码。
简介 || 张量 || Autograd || 构建模型 || TensorBoard 支持 || 训练模型 || 模型理解
使用 PyTorch 进行训练#
创建于:2021 年 11 月 30 日 | 最后更新:2023 年 5 月 31 日 | 最后验证:2024 年 11 月 05 日
请跟随下面的视频或在 youtube 上观看。
简介#
在过去的视频中,我们讨论并演示了:
使用 torch.nn 模块的神经网络层和函数构建模型
自动梯度计算的机制,这是基于梯度的模型训练的核心
使用 TensorBoard 可视化训练进度和其他活动
在本视频中,我们将为您的工具箱添加一些新工具:
我们将熟悉 dataset 和 dataloader 的抽象概念,以及它们如何在训练循环中简化将数据馈送到模型的过程
我们将讨论特定的损失函数以及何时使用它们
我们将了解 PyTorch 优化器,它们实现了根据损失函数结果调整模型权重的算法
最后,我们将把所有这些内容整合在一起,并实际查看完整的 PyTorch 训练循环。
Dataset 和 DataLoader#
Dataset
和 DataLoader
类封装了将数据从存储中提取并在批次中暴露给训练循环的过程。
Dataset
负责访问和处理单个数据实例。
DataLoader
从 Dataset
中提取数据实例(自动或使用您定义的采样器),将它们收集成批次,并返回给训练循环使用。DataLoader
可以处理各种数据集,无论它们包含何种类型的数据。
在本教程中,我们将使用 TorchVision 提供的 Fashion-MNIST 数据集。我们使用 torchvision.transforms.Normalize()
来零均值化和标准化图像块内容的分布,并下载训练和验证数据拆分。
import torch
import torchvision
import torchvision.transforms as transforms
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)
# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)
# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
0%| | 0.00/26.4M [00:00<?, ?B/s]
0%| | 65.5k/26.4M [00:00<01:12, 364kB/s]
1%| | 197k/26.4M [00:00<00:45, 577kB/s]
3%|▎ | 852k/26.4M [00:00<00:13, 1.97MB/s]
13%|█▎ | 3.38M/26.4M [00:00<00:03, 6.69MB/s]
21%|██ | 5.51M/26.4M [00:00<00:02, 8.50MB/s]
39%|███▉ | 10.4M/26.4M [00:01<00:01, 14.8MB/s]
61%|██████ | 16.1M/26.4M [00:01<00:00, 23.7MB/s]
74%|███████▍ | 19.5M/26.4M [00:01<00:00, 22.4MB/s]
92%|█████████▏| 24.3M/26.4M [00:01<00:00, 28.0MB/s]
100%|██████████| 26.4M/26.4M [00:01<00:00, 17.1MB/s]
0%| | 0.00/29.5k [00:00<?, ?B/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 328kB/s]
0%| | 0.00/4.42M [00:00<?, ?B/s]
1%|▏ | 65.5k/4.42M [00:00<00:12, 363kB/s]
4%|▎ | 164k/4.42M [00:00<00:09, 470kB/s]
16%|█▋ | 721k/4.42M [00:00<00:02, 1.66MB/s]
65%|██████▌ | 2.88M/4.42M [00:00<00:00, 5.73MB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 6.09MB/s]
0%| | 0.00/5.15k [00:00<?, ?B/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 48.6MB/s]
Training set has 60000 instances
Validation set has 10000 instances
一如既往,让我们将数据可视化作为健全性检查。
import matplotlib.pyplot as plt
import numpy as np
# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
if one_channel:
img = img.mean(dim=0)
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
if one_channel:
plt.imshow(npimg, cmap="Greys")
else:
plt.imshow(np.transpose(npimg, (1, 2, 0)))
dataiter = iter(training_loader)
images, labels = next(dataiter)
# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(4)))

T-shirt/top Dress Bag Trouser
模型#
我们在示例中使用模型是 LeNet-5 的一个变体——如果您观看了本系列之前的视频,应该会很熟悉。
import torch.nn as nn
import torch.nn.functional as F
# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
def __init__(self):
super(GarmentClassifier, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = GarmentClassifier()
损失函数#
在本例中,我们将使用交叉熵损失。为了演示目的,我们将创建模拟输出和标签值的批次,通过损失函数运行它们,并检查结果。
loss_fn = torch.nn.CrossEntropyLoss()
# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's confidence in each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])
print(dummy_outputs)
print(dummy_labels)
loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.1764, 0.1619, 0.8875, 0.1805, 0.7036, 0.3358, 0.6931, 0.0364, 0.8430,
0.0767],
[0.0182, 0.0407, 0.6103, 0.2460, 0.9907, 0.0978, 0.3805, 0.2288, 0.7804,
0.3488],
[0.5060, 0.9863, 0.2487, 0.8829, 0.4201, 0.4798, 0.3157, 0.4678, 0.8579,
0.3824],
[0.0289, 0.1472, 0.3540, 0.7207, 0.7277, 0.7830, 0.7991, 0.6037, 0.3332,
0.2823]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.362215042114258
优化器#
在本例中,我们将使用带有动量的简单 随机梯度下降。
尝试对此优化方案进行一些修改可能会很有启发性。
学习率决定了优化器采取的步长大小。不同的学习率对您的训练结果(在准确性和收敛时间方面)有什么影响?
动量在多个步骤中将优化器推向最强的梯度方向。改变这个值对您的结果有什么影响?
尝试一些不同的优化算法,例如平均 SGD、Adagrad 或 Adam。您的结果有何不同?
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
训练循环#
下面是我们执行一个训练 epoch 的函数。它枚举来自 DataLoader 的数据,并在每次循环中执行以下操作:
从 DataLoader 获取一个训练数据批次
将优化器的梯度清零
执行一次推理——即,获取模型对输入批次的预测
计算该组预测与数据集上的标签之间的损失
计算学习权重的反向梯度
告诉优化器执行一个学习步骤——即,根据该批次的观察梯度,根据我们选择的优化算法调整模型的学习权重
它每 1000 个批次报告一次损失。
最后,它报告最后 1000 个批次的平均每批次损失,以便与验证运行进行比较。
def train_one_epoch(epoch_index, tb_writer):
running_loss = 0.
last_loss = 0.
# Here, we use enumerate(training_loader) instead of
# iter(training_loader) so that we can track the batch
# index and do some intra-epoch reporting
for i, data in enumerate(training_loader):
# Every data instance is an input + label pair
inputs, labels = data
# Zero your gradients for every batch!
optimizer.zero_grad()
# Make predictions for this batch
outputs = model(inputs)
# Compute the loss and its gradients
loss = loss_fn(outputs, labels)
loss.backward()
# Adjust learning weights
optimizer.step()
# Gather data and report
running_loss += loss.item()
if i % 1000 == 999:
last_loss = running_loss / 1000 # loss per batch
print(' batch {} loss: {}'.format(i + 1, last_loss))
tb_x = epoch_index * len(training_loader) + i + 1
tb_writer.add_scalar('Loss/train', last_loss, tb_x)
running_loss = 0.
return last_loss
每个 Epoch 的活动#
每个 epoch 有几件事是我们想做的:
进行验证,检查我们在未用于训练的数据集上的相对损失,并报告此信息
保存模型的副本
在这里,我们将在 TensorBoard 中进行报告。这需要您转到命令行启动 TensorBoard,并在另一个浏览器标签页中打开它。
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0
EPOCHS = 5
best_vloss = 1_000_000.
for epoch in range(EPOCHS):
print('EPOCH {}:'.format(epoch_number + 1))
# Make sure gradient tracking is on, and do a pass over the data
model.train(True)
avg_loss = train_one_epoch(epoch_number, writer)
running_vloss = 0.0
# Set the model to evaluation mode, disabling dropout and using population
# statistics for batch normalization.
model.eval()
# Disable gradient computation and reduce memory consumption.
with torch.no_grad():
for i, vdata in enumerate(validation_loader):
vinputs, vlabels = vdata
voutputs = model(vinputs)
vloss = loss_fn(voutputs, vlabels)
running_vloss += vloss
avg_vloss = running_vloss / (i + 1)
print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
# Log the running loss averaged per batch
# for both training and validation
writer.add_scalars('Training vs. Validation Loss',
{ 'Training' : avg_loss, 'Validation' : avg_vloss },
epoch_number + 1)
writer.flush()
# Track best performance, and save the model's state
if avg_vloss < best_vloss:
best_vloss = avg_vloss
model_path = 'model_{}_{}'.format(timestamp, epoch_number)
torch.save(model.state_dict(), model_path)
epoch_number += 1
EPOCH 1:
batch 1000 loss: 1.9988314672708511
batch 2000 loss: 0.9193895362205804
batch 3000 loss: 0.7223440953306853
batch 4000 loss: 0.6701288196891546
batch 5000 loss: 0.581597259292379
batch 6000 loss: 0.5741387488790788
batch 7000 loss: 0.5606894336201949
batch 8000 loss: 0.5268562002686085
batch 9000 loss: 0.4830455190619687
batch 10000 loss: 0.4596923828284489
batch 11000 loss: 0.49248371993168255
batch 12000 loss: 0.44985741829127074
batch 13000 loss: 0.46217175873881206
batch 14000 loss: 0.43543626031989696
batch 15000 loss: 0.43752258202363736
LOSS train 0.43752258202363736 valid 0.430812269449234
EPOCH 2:
batch 1000 loss: 0.4037419849329745
batch 2000 loss: 0.39954532159899825
batch 3000 loss: 0.40559865027916386
batch 4000 loss: 0.3833038282005873
batch 5000 loss: 0.3906617384759011
batch 6000 loss: 0.37513707852845984
batch 7000 loss: 0.37699090542938213
batch 8000 loss: 0.3761951707657863
batch 9000 loss: 0.36876384827977743
batch 10000 loss: 0.3674912270232453
batch 11000 loss: 0.3550953941930784
batch 12000 loss: 0.3561198842830054
batch 13000 loss: 0.35298089532964516
batch 14000 loss: 0.36295080438960575
batch 15000 loss: 0.35744707003689835
LOSS train 0.35744707003689835 valid 0.36409085988998413
EPOCH 3:
batch 1000 loss: 0.33181712135783165
batch 2000 loss: 0.3163912760570238
batch 3000 loss: 0.34046520160492216
batch 4000 loss: 0.33481437444103357
batch 5000 loss: 0.3392046003730993
batch 6000 loss: 0.3134670601064281
batch 7000 loss: 0.3283420227258321
batch 8000 loss: 0.31831632618168076
batch 9000 loss: 0.3080377133266884
batch 10000 loss: 0.3223431621780182
batch 11000 loss: 0.316907833893878
batch 12000 loss: 0.3190856933850737
batch 13000 loss: 0.3103358497906811
batch 14000 loss: 0.3308342638386093
batch 15000 loss: 0.3177569979402469
LOSS train 0.3177569979402469 valid 0.35152819752693176
EPOCH 4:
batch 1000 loss: 0.27613517147956007
batch 2000 loss: 0.3150471937338443
batch 3000 loss: 0.28435196808750696
batch 4000 loss: 0.2808161023599605
batch 5000 loss: 0.30603527145737824
batch 6000 loss: 0.30359787086395956
batch 7000 loss: 0.3039901812107855
batch 8000 loss: 0.3200699506045785
batch 9000 loss: 0.3013748074879259
batch 10000 loss: 0.29429625780676544
batch 11000 loss: 0.2937111030754313
batch 12000 loss: 0.28333777568516233
batch 13000 loss: 0.2914983189167979
batch 14000 loss: 0.28703211334972
batch 15000 loss: 0.30034291974519145
LOSS train 0.30034291974519145 valid 0.30364754796028137
EPOCH 5:
batch 1000 loss: 0.27076393893184647
batch 2000 loss: 0.27389251034590417
batch 3000 loss: 0.2679886224323709
batch 4000 loss: 0.27351233246324047
batch 5000 loss: 0.2770088437672639
batch 6000 loss: 0.2842327052370747
batch 7000 loss: 0.26012170134465123
batch 8000 loss: 0.26852404565333793
batch 9000 loss: 0.27785209107911213
batch 10000 loss: 0.2856649016340425
batch 11000 loss: 0.27832065098140446
batch 12000 loss: 0.29101420920167587
batch 13000 loss: 0.27228796556679846
batch 14000 loss: 0.29488918731830743
batch 15000 loss: 0.26893940189028853
LOSS train 0.26893940189028853 valid 0.3395138680934906
加载模型的已保存版本
saved_model = GarmentClassifier()
saved_model.load_state_dict(torch.load(PATH))
加载模型后,它就准备好满足您的任何需求了——进一步的训练、推理或分析。
请注意,如果您的模型具有影响模型结构的构造函数参数,您需要提供这些参数,并将模型配置为与保存时的状态相同。
其他资源#
在 pytorch.org 上关于数据实用程序(包括 Dataset 和 DataLoader)的文档 数据实用程序。
关于 GPU 训练使用固定内存的 固定内存使用说明。
关于 TorchVision、TorchText 和 TorchAudio 中可用的数据集的文档。
PyTorch 中可用的损失函数的文档 损失函数。
torch.optim
包的文档,其中包含优化器和相关工具,例如学习率调度torch.optim
包。关于保存和加载模型的详细 教程。
pytorch.org 的 教程 部分包含关于各种训练任务的教程,包括不同领域的分类、生成对抗网络、强化学习等。
脚本总运行时间: (3 分钟 0.804 秒)