Using TensorDict for datasets
In this tutorial we demonstrate how TensorDict can be used to efficiently and transparently load and manage data inside a training pipeline. The tutorial is based heavily on the PyTorch Quickstart tutorial, modified to demonstrate the use of TensorDict.
import torch
import torch.nn as nn
from tensordict import MemoryMappedTensor, TensorDict
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu
The torchvision.datasets module contains a number of convenient pre-prepared datasets. In this tutorial we use the relatively simple FashionMNIST dataset. Each image is an item of clothing, and the objective is to classify the type of clothing in the image (for example, "Bag", "Sneaker", etc.).
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
We will create two tensordicts, one each for the training and test data. We create memory-mapped tensors to hold the data. This will allow us to efficiently load batches of transformed data from disk rather than repeatedly loading and transforming individual images.

First we create the MemoryMappedTensor containers.
training_data_td = TensorDict(
    {
        "images": MemoryMappedTensor.empty(
            (len(training_data), *training_data[0][0].squeeze().shape),
            dtype=torch.float32,
        ),
        "targets": MemoryMappedTensor.empty((len(training_data),), dtype=torch.int64),
    },
    batch_size=[len(training_data)],
    device=device,
)
test_data_td = TensorDict(
    {
        "images": MemoryMappedTensor.empty(
            (len(test_data), *test_data[0][0].squeeze().shape), dtype=torch.float32
        ),
        "targets": MemoryMappedTensor.empty((len(test_data),), dtype=torch.int64),
    },
    batch_size=[len(test_data)],
    device=device,
)
We can then iterate over the data to populate the memory-mapped tensors. This takes a little time, but performing the transforms up front saves repeated work during training later.
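The population loop itself is not reproduced here; a minimal sketch, assuming the datasets and TensorDict containers defined above, could look like this:

# Sketch: fill the pre-allocated memory-mapped tensors by iterating once over
# the torchvision datasets and writing each transformed sample into its slot.
for i, (img, target) in enumerate(training_data):
    training_data_td[i] = TensorDict({"images": img.squeeze(), "targets": target}, [])

for i, (img, target) in enumerate(test_data):
    test_data_td[i] = TensorDict({"images": img.squeeze(), "targets": target}, [])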
DataLoaders
We will create DataLoaders from the torchvision-provided Datasets, as well as from our memory-mapped TensorDicts.

Since TensorDict implements __len__ and __getitem__ (as well as __getitems__), we can use it like a map-style Dataset and create a DataLoader directly from it. Note that because TensorDict can already handle batched indices, there is no need for collation, so we pass the identity function as collate_fn.
batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size) # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size) # noqa: TOR401
train_dataloader_td = DataLoader(  # noqa: TOR401
    training_data_td, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_td = DataLoader(  # noqa: TOR401
    test_data_td, batch_size=batch_size, collate_fn=lambda x: x
)
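As a quick illustration (not part of the original script) of why the identity collate_fn is sufficient: indexing a TensorDict with a batch of indices already returns a sub-TensorDict whose entries are stacked along the batch dimension, so no extra collation is needed.

# Illustration only: a list of indices yields an already-batched TensorDict.
sample = training_data_td[[0, 1, 2, 3]]
print(sample.batch_size)       # torch.Size([4])
print(sample["images"].shape)  # torch.Size([4, 28, 28])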
Model
We use the same model as in the Quickstart tutorial.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
model = Net().to(device)
model_td = Net().to(device)
model, model_td
(Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
), Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
))
Optimizing the parameters
We will optimize the parameters of the model using stochastic gradient descent and cross-entropy loss.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_td = torch.optim.SGD(model_td.parameters(), lr=1e-3)
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
The training loop for the TensorDict-based DataLoader is very similar; we simply adjust how the data is unpacked, using the more explicit key-based retrieval that TensorDict offers. The .contiguous() method loads the data stored in the memmap tensors.
def train_td(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, data in enumerate(dataloader):
        # Key-based retrieval; .contiguous() materialises the memmap storage
        X, y = data["images"].contiguous(), data["targets"].contiguous()

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
def test_td(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch["images"].contiguous(), batch["targets"].contiguous()
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
# Inspect a single batch from the TensorDict-based DataLoader
for d in train_dataloader_td:
    print(d)
    break
import time

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train_td(train_dataloader_td, model_td, loss_fn, optimizer_td)
    test_td(test_dataloader_td, model_td, loss_fn)
print(f"TensorDict training done! time: {time.time() - t0: 4.4f} s")
t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - t0: 4.4f} s")
TensorDict(
    fields={
        images: Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
        targets: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([64]),
    device=cpu,
    is_shared=False)
Epoch 1
-------------------------
loss: 2.295151 [ 0/60000]
loss: 2.287965 [ 6400/60000]
loss: 2.263263 [12800/60000]
loss: 2.266189 [19200/60000]
loss: 2.250629 [25600/60000]
loss: 2.211262 [32000/60000]
loss: 2.238146 [38400/60000]
loss: 2.191725 [44800/60000]
loss: 2.188123 [51200/60000]
loss: 2.170130 [57600/60000]
Test Error:
Accuracy: 39.0%, Avg loss: 2.150614
Epoch 2
-------------------------
loss: 2.159945 [ 0/60000]
loss: 2.154417 [ 6400/60000]
loss: 2.091824 [12800/60000]
loss: 2.112408 [19200/60000]
loss: 2.068090 [25600/60000]
loss: 1.999538 [32000/60000]
loss: 2.045922 [38400/60000]
loss: 1.959862 [44800/60000]
loss: 1.966230 [51200/60000]
loss: 1.905551 [57600/60000]
Test Error:
Accuracy: 56.3%, Avg loss: 1.887337
Epoch 3
-------------------------
loss: 1.927120 [ 0/60000]
loss: 1.896457 [ 6400/60000]
loss: 1.773539 [12800/60000]
loss: 1.814664 [19200/60000]
loss: 1.712839 [25600/60000]
loss: 1.662633 [32000/60000]
loss: 1.699288 [38400/60000]
loss: 1.595371 [44800/60000]
loss: 1.624663 [51200/60000]
loss: 1.525723 [57600/60000]
Test Error:
Accuracy: 60.1%, Avg loss: 1.524303
Epoch 4
-------------------------
loss: 1.603213 [ 0/60000]
loss: 1.561387 [ 6400/60000]
loss: 1.406400 [12800/60000]
loss: 1.476314 [19200/60000]
loss: 1.357771 [25600/60000]
loss: 1.358757 [32000/60000]
loss: 1.380468 [38400/60000]
loss: 1.304820 [44800/60000]
loss: 1.337895 [51200/60000]
loss: 1.244839 [57600/60000]
Test Error:
Accuracy: 63.3%, Avg loss: 1.256509
Epoch 5
-------------------------
loss: 1.343973 [ 0/60000]
loss: 1.319552 [ 6400/60000]
loss: 1.153267 [12800/60000]
loss: 1.254595 [19200/60000]
loss: 1.123201 [25600/60000]
loss: 1.161435 [32000/60000]
loss: 1.184613 [38400/60000]
loss: 1.125289 [44800/60000]
loss: 1.156805 [51200/60000]
loss: 1.084467 [57600/60000]
Test Error:
Accuracy: 64.9%, Avg loss: 1.091612
TensorDict training done! time: 8.4945 s
Epoch 1
-------------------------
loss: 2.299966 [ 0/60000]
loss: 2.291062 [ 6400/60000]
loss: 2.265493 [12800/60000]
loss: 2.273356 [19200/60000]
loss: 2.247992 [25600/60000]
loss: 2.214662 [32000/60000]
loss: 2.228931 [38400/60000]
loss: 2.185137 [44800/60000]
loss: 2.188732 [51200/60000]
loss: 2.170628 [57600/60000]
Test Error:
Accuracy: 42.6%, Avg loss: 2.149621
Epoch 2
-------------------------
loss: 2.152856 [ 0/60000]
loss: 2.150230 [ 6400/60000]
loss: 2.082802 [12800/60000]
loss: 2.113469 [19200/60000]
loss: 2.062010 [25600/60000]
loss: 1.995835 [32000/60000]
loss: 2.027980 [38400/60000]
loss: 1.938653 [44800/60000]
loss: 1.948907 [51200/60000]
loss: 1.899682 [57600/60000]
Test Error:
Accuracy: 54.9%, Avg loss: 1.875343
Epoch 3
-------------------------
loss: 1.899561 [ 0/60000]
loss: 1.883063 [ 6400/60000]
loss: 1.748965 [12800/60000]
loss: 1.804443 [19200/60000]
loss: 1.698108 [25600/60000]
loss: 1.639669 [32000/60000]
loss: 1.662723 [38400/60000]
loss: 1.552907 [44800/60000]
loss: 1.583105 [51200/60000]
loss: 1.501382 [57600/60000]
Test Error:
Accuracy: 59.5%, Avg loss: 1.500250
Epoch 4
-------------------------
loss: 1.557939 [ 0/60000]
loss: 1.537942 [ 6400/60000]
loss: 1.371437 [12800/60000]
loss: 1.465200 [19200/60000]
loss: 1.346894 [25600/60000]
loss: 1.331310 [32000/60000]
loss: 1.351684 [38400/60000]
loss: 1.263354 [44800/60000]
loss: 1.307001 [51200/60000]
loss: 1.231721 [57600/60000]
Test Error:
Accuracy: 62.6%, Avg loss: 1.241147
Epoch 5
-------------------------
loss: 1.311746 [ 0/60000]
loss: 1.304771 [ 6400/60000]
loss: 1.124054 [12800/60000]
loss: 1.252720 [19200/60000]
loss: 1.126270 [25600/60000]
loss: 1.140746 [32000/60000]
loss: 1.170960 [38400/60000]
loss: 1.091258 [44800/60000]
loss: 1.138590 [51200/60000]
loss: 1.081223 [57600/60000]
Test Error:
Accuracy: 64.4%, Avg loss: 1.083943
Training done! time: 35.4081 s
Total running time of the script: (0 minutes 57.057 seconds)