• 文档 >
  • 为数据集使用 TensorDict
快捷方式

使用 TensorDict 加载数据集

在本教程中,我们将演示 TensorDict 如何在训练管道中高效且透明地加载和管理数据。本教程大量参考了 PyTorch 快速入门教程,但进行了修改以演示 TensorDict 的用法。

import torch
import torch.nn as nn

from tensordict import MemoryMappedTensor, TensorDict
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Select the compute device once; every model/tensor below is moved to it.
# Falls back to CPU when no CUDA device is visible.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu

torchvision.datasets 模块包含许多方便的预准备数据集。在本教程中,我们将使用相对简单的 FashionMNIST 数据集。每张图像都是一件衣服,目标是对图像中的衣服类型进行分类(例如,“包”、“运动鞋”等)。

# Download both FashionMNIST splits under ./data. ToTensor() converts each
# PIL image into a float tensor in [0, 1]; only the `train` flag differs
# between the two splits.
training_data, test_data = (
    datasets.FashionMNIST(
        root="data",
        train=is_train,
        download=True,
        transform=ToTensor(),
    )
    for is_train in (True, False)
)

我们将创建两个 tensordicts,分别用于训练和测试数据。我们创建内存映射张量来保存数据。这将使我们能够有效地从磁盘加载转换后的数据批次,而不是反复加载和转换单个图像。

首先,我们创建 MemoryMappedTensor 容器。

def _empty_split_td(dataset):
    """Build a TensorDict of empty memory-mapped tensors sized to *dataset*.

    The "images" entry matches one squeezed sample image per row; "targets"
    holds one int64 label per sample. batch_size spans the whole dataset.
    """
    n = len(dataset)
    # dataset[0] is (image, label); squeeze drops the channel dim — assumes
    # single-channel images (true for FashionMNIST).
    sample_shape = dataset[0][0].squeeze().shape
    return TensorDict(
        {
            "images": MemoryMappedTensor.empty(
                (n, *sample_shape), dtype=torch.float32
            ),
            "targets": MemoryMappedTensor.empty((n,), dtype=torch.int64),
        },
        batch_size=[n],
        device=device,
    )


training_data_td = _empty_split_td(training_data)
test_data_td = _empty_split_td(test_data)

然后,我们可以迭代数据来填充内存映射张量。这需要一些时间,但提前执行转换将节省后续训练期间的重复工作。

# Populate both memory-mapped containers sample by sample: training split
# first, then the test split (same order as the original cells).
for source, destination in ((training_data, training_data_td), (test_data, test_data_td)):
    for idx, (image, target) in enumerate(source):
        destination[idx] = TensorDict({"images": image, "targets": target}, [])

DataLoaders

我们将从 torchvision 提供的 Datasets 以及我们的内存映射 TensorDicts 创建 DataLoaders。

由于 TensorDict 实现了 `__len__` 和 `__getitem__`(以及 `__getitems__`),我们可以像使用 map 风格 Dataset 一样直接从中创建 DataLoader。请注意,由于 TensorDict 已经能够处理批次索引,因此无需进行 collate,所以我们传递恒等函数作为 `collate_fn`。

batch_size = 64


def _passthrough_collate(batch):
    """TensorDict slices arrive pre-collated; hand them through unchanged."""
    return batch


train_dataloader = DataLoader(training_data, batch_size=batch_size)  # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size)  # noqa: TOR401

# The TensorDict loaders index the container directly, so no collation is
# needed — the identity collate function keeps DataLoader happy.
train_dataloader_td = DataLoader(  # noqa: TOR401
    training_data_td, batch_size=batch_size, collate_fn=_passthrough_collate
)
test_dataloader_td = DataLoader(  # noqa: TOR401
    test_data_td, batch_size=batch_size, collate_fn=_passthrough_collate
)

Model

我们使用与 快速入门教程 相同的模型。

class Net(nn.Module):
    """Small MLP classifier for 28x28 FashionMNIST images (10 classes)."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # 784 -> 512 -> 512 -> 10 fully connected stack with ReLU activations.
        layers = [
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each image and return per-class logits."""
        return self.linear_relu_stack(self.flatten(x))


# Two identical architectures: `model` trains from the standard Dataset
# pipeline, `model_td` from the TensorDict pipeline. The generator consumes
# the RNG in the same order as two sequential Net() calls would.
model, model_td = (Net().to(device) for _ in range(2))
model, model_td
(Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
), Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
))

优化参数

我们将使用随机梯度下降和交叉熵损失来优化模型的参数。

# Plain SGD with a fixed learning rate; a separate optimizer per model keeps
# the two training runs independent.
learning_rate = 1e-3
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
optimizer_td = torch.optim.SGD(model_td.parameters(), lr=learning_rate)


def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over *dataloader*, logging loss every 100 batches."""
    n_samples = len(dataloader.dataset)
    model.train()

    for batch_idx, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)

        predictions = model(inputs)
        batch_loss = loss_fn(predictions, labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            # `seen` approximates samples processed so far (batch index * batch size).
            seen = batch_idx * len(inputs)
            print(f"loss: {batch_loss.item():>7f} [{seen:>5d}/{n_samples:>5d}]")

我们基于 TensorDict 的 DataLoader 的训练循环非常相似,我们只需调整如何解包数据,以使用 TensorDict 提供的更明确的基于键的检索。.contiguous() 方法会加载存储在 memmap 张量中的数据。

def train_td(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over a TensorDict-backed *dataloader*.

    Each batch is a TensorDict; .contiguous() materialises the memory-mapped
    storage into regular in-memory tensors before the forward pass.
    """
    n_samples = len(dataloader.dataset)
    model.train()

    for batch_idx, td in enumerate(dataloader):
        inputs = td["images"].contiguous()
        labels = td["targets"].contiguous()

        predictions = model(inputs)
        batch_loss = loss_fn(predictions, labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            # `seen` approximates samples processed so far (batch index * batch size).
            seen = batch_idx * len(inputs)
            print(f"loss: {batch_loss.item():>7f} [{seen:>5d}/{n_samples:>5d}]")


def test(dataloader, model, loss_fn):
    """Evaluate *model* on *dataloader*, printing accuracy and mean batch loss."""
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)
    model.eval()
    total_loss = 0
    n_correct = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)

            logits = model(inputs)

            total_loss += loss_fn(logits, labels).item()
            n_correct += (logits.argmax(1) == labels).type(torch.float).sum().item()

    avg_loss = total_loss / n_batches
    accuracy = n_correct / n_samples

    print(
        f"Test Error: \n Accuracy: {(100 * accuracy):>0.1f}%, Avg loss: {avg_loss:>8f} \n"
    )


def test_td(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch["images"].contiguous(), batch["targets"].contiguous()

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


# Peek at the first TensorDict batch to inspect its fields and batch_size.
print(next(iter(train_dataloader_td)))

import time

# Same epoch budget for both pipelines so the wall-clock timings compare.
epochs = 5

# TensorDict pipeline first.
start = time.time()
for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}\n-------------------------")
    train_td(train_dataloader_td, model_td, loss_fn, optimizer_td)
    test_td(test_dataloader_td, model_td, loss_fn)
print(f"TensorDict training done! time: {time.time() - start: 4.4f} s")

# Standard Dataset pipeline second.
start = time.time()
for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - start: 4.4f} s")
TensorDict(
    fields={
        images: Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
        targets: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([64]),
    device=cpu,
    is_shared=False)
Epoch 1
-------------------------
loss: 2.299234 [    0/60000]
loss: 2.295493 [ 6400/60000]
loss: 2.278620 [12800/60000]
loss: 2.271223 [19200/60000]
loss: 2.247178 [25600/60000]
loss: 2.218988 [32000/60000]
loss: 2.231889 [38400/60000]
loss: 2.200287 [44800/60000]
loss: 2.193725 [51200/60000]
loss: 2.164530 [57600/60000]
Test Error:
 Accuracy: 45.6%, Avg loss: 2.159090

Epoch 2
-------------------------
loss: 2.162431 [    0/60000]
loss: 2.154887 [ 6400/60000]
loss: 2.103462 [12800/60000]
loss: 2.123008 [19200/60000]
loss: 2.054369 [25600/60000]
loss: 2.003411 [32000/60000]
loss: 2.038424 [38400/60000]
loss: 1.957570 [44800/60000]
loss: 1.963327 [51200/60000]
loss: 1.897183 [57600/60000]
Test Error:
 Accuracy: 50.2%, Avg loss: 1.885813

Epoch 3
-------------------------
loss: 1.910536 [    0/60000]
loss: 1.879232 [ 6400/60000]
loss: 1.769016 [12800/60000]
loss: 1.822767 [19200/60000]
loss: 1.691695 [25600/60000]
loss: 1.651822 [32000/60000]
loss: 1.691600 [38400/60000]
loss: 1.586301 [44800/60000]
loss: 1.613930 [51200/60000]
loss: 1.524666 [57600/60000]
Test Error:
 Accuracy: 59.4%, Avg loss: 1.525683

Epoch 4
-------------------------
loss: 1.579991 [    0/60000]
loss: 1.549503 [ 6400/60000]
loss: 1.406767 [12800/60000]
loss: 1.491792 [19200/60000]
loss: 1.358662 [25600/60000]
loss: 1.361107 [32000/60000]
loss: 1.387846 [38400/60000]
loss: 1.308891 [44800/60000]
loss: 1.339235 [51200/60000]
loss: 1.256045 [57600/60000]
Test Error:
 Accuracy: 63.1%, Avg loss: 1.268251

Epoch 5
-------------------------
loss: 1.333215 [    0/60000]
loss: 1.319800 [ 6400/60000]
loss: 1.163411 [12800/60000]
loss: 1.274050 [19200/60000]
loss: 1.137881 [25600/60000]
loss: 1.170909 [32000/60000]
loss: 1.197151 [38400/60000]
loss: 1.133853 [44800/60000]
loss: 1.164905 [51200/60000]
loss: 1.095820 [57600/60000]
Test Error:
 Accuracy: 64.7%, Avg loss: 1.105006

TensorDict training done! time:  8.6405 s
Epoch 1
-------------------------
loss: 2.307087 [    0/60000]
loss: 2.300231 [ 6400/60000]
loss: 2.279642 [12800/60000]
loss: 2.272347 [19200/60000]
loss: 2.264385 [25600/60000]
loss: 2.231639 [32000/60000]
loss: 2.244270 [38400/60000]
loss: 2.214279 [44800/60000]
loss: 2.211898 [51200/60000]
loss: 2.185143 [57600/60000]
Test Error:
 Accuracy: 39.1%, Avg loss: 2.176746

Epoch 2
-------------------------
loss: 2.188481 [    0/60000]
loss: 2.182607 [ 6400/60000]
loss: 2.123818 [12800/60000]
loss: 2.141249 [19200/60000]
loss: 2.105676 [25600/60000]
loss: 2.046707 [32000/60000]
loss: 2.085586 [38400/60000]
loss: 2.010659 [44800/60000]
loss: 2.012547 [51200/60000]
loss: 1.961141 [57600/60000]
Test Error:
 Accuracy: 58.4%, Avg loss: 1.946590

Epoch 3
-------------------------
loss: 1.972843 [    0/60000]
loss: 1.953317 [ 6400/60000]
loss: 1.836205 [12800/60000]
loss: 1.878307 [19200/60000]
loss: 1.781441 [25600/60000]
loss: 1.727695 [32000/60000]
loss: 1.764617 [38400/60000]
loss: 1.653670 [44800/60000]
loss: 1.671720 [51200/60000]
loss: 1.587787 [57600/60000]
Test Error:
 Accuracy: 61.0%, Avg loss: 1.585833

Epoch 4
-------------------------
loss: 1.643045 [    0/60000]
loss: 1.611038 [ 6400/60000]
loss: 1.455160 [12800/60000]
loss: 1.528783 [19200/60000]
loss: 1.401300 [25600/60000]
loss: 1.396774 [32000/60000]
loss: 1.418147 [38400/60000]
loss: 1.327761 [44800/60000]
loss: 1.360300 [51200/60000]
loss: 1.266772 [57600/60000]
Test Error:
 Accuracy: 63.4%, Avg loss: 1.290386

Epoch 5
-------------------------
loss: 1.366724 [    0/60000]
loss: 1.343667 [ 6400/60000]
loss: 1.177910 [12800/60000]
loss: 1.279343 [19200/60000]
loss: 1.149942 [25600/60000]
loss: 1.179350 [32000/60000]
loss: 1.199123 [38400/60000]
loss: 1.127161 [44800/60000]
loss: 1.164455 [51200/60000]
loss: 1.082756 [57600/60000]
Test Error:
 Accuracy: 64.6%, Avg loss: 1.108402

Training done! time:  35.2981 s

脚本总运行时间: (0 分 57.427 秒)

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源