
Using tensorclasses for datasets

In this tutorial we demonstrate how tensorclasses can be used to efficiently and transparently load and manage data inside a training pipeline. The tutorial is based heavily on the PyTorch Quickstart tutorial, but modified to demonstrate the use of tensorclasses. See the related tutorial that uses TensorDict instead.

import torch
import torch.nn as nn

from tensordict import MemoryMappedTensor, tensorclass
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu

The torchvision.datasets module contains a number of convenient pre-prepared datasets. In this tutorial we'll use the relatively simple FashionMNIST dataset. Each image is an item of clothing, and the goal is to classify the type of clothing in the image (e.g. "Bag", "Sneaker" etc.).

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

Tensorclasses are dataclasses that, like TensorDict, expose dedicated tensor methods on their contents. They are a good choice when the structure of the data you want to store is fixed and predictable.
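To build intuition for what @tensorclass automates, here is a plain-Python analogue (hypothetical names, no tensordict or torch required): a dataclass with a fixed set of fields, where indexing the container indexes every field at once, just as batch indexing does on a tensorclass.

```python
from dataclasses import dataclass


@dataclass
class Sample:
    """Hypothetical stand-in for a tensorclass: fixed fields, joint indexing."""

    images: list  # stand-in for a torch.Tensor of shape [N, 28, 28]
    targets: list  # stand-in for a torch.Tensor of shape [N]

    def __getitem__(self, idx):
        # Index every field together, like tensorclass batch indexing
        return Sample(images=self.images[idx], targets=self.targets[idx])

    def __len__(self):
        return len(self.targets)


data = Sample(images=[[0.0] * 4, [1.0] * 4, [2.0] * 4], targets=[0, 1, 2])
batch = data[0:2]  # slicing slices both fields in lock-step
print(len(batch))  # 2
print(batch.targets)  # [0, 1]
```

A real tensorclass adds tensor-aware behaviour on top of this (device placement, reshaping, memory mapping), but the fixed-structure, index-together idea is the same.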

As well as specifying the contents, we can encapsulate related logic as custom methods when defining the class. In this case we'll write a from_dataset classmethod that takes a dataset as input and creates a tensorclass containing the data from the dataset. We create memory-mapped tensors to hold the data. This will allow us to efficiently load batches of transformed data from disk rather than repeatedly loading and transforming individual images.

@tensorclass
class FashionMNISTData:
    images: torch.Tensor
    targets: torch.Tensor

    @classmethod
    def from_dataset(cls, dataset, device=None):
        data = cls(
            images=MemoryMappedTensor.empty(
                (len(dataset), *dataset[0][0].squeeze().shape), dtype=torch.float32
            ),
            targets=MemoryMappedTensor.empty((len(dataset),), dtype=torch.int64),
            batch_size=[len(dataset)],
            device=device,
        )
        for i, (image, target) in enumerate(dataset):
            data[i] = cls(images=image, targets=torch.tensor(target), batch_size=[])
        return data

We will create two tensorclasses, one each for the training and test data. Note that we incur some overhead here, as we are looping over the entire dataset, transforming it, and saving it to disk.

training_data_tc = FashionMNISTData.from_dataset(training_data, device=device)
test_data_tc = FashionMNISTData.from_dataset(test_data, device=device)

DataLoaders

We'll create DataLoaders from the torchvision-provided Datasets, as well as from our memory-mapped TensorDicts.

Since TensorDict implements __len__ and __getitem__ (as well as __getitems__), we can use it like a map-style Dataset and create a DataLoader directly from it. Note that because TensorDict can already handle batched indices, there is no need for collation, so we pass the identity function as collate_fn.

batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)  # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size)  # noqa: TOR401

train_dataloader_tc = DataLoader(  # noqa: TOR401
    training_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_tc = DataLoader(  # noqa: TOR401
    test_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
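The no-collate pattern above can be sketched in plain Python (hypothetical class and helper names): a map-style dataset whose batched getter returns an already-assembled batch, so the loader's collate step degenerates to the identity function.

```python
class SelfCollatingDataset:
    """Hypothetical map-style dataset that collates its own batches,
    mimicking what TensorDict's batched indexing provides for free."""

    def __init__(self, xs, ys):
        self.xs, self.ys = xs, ys

    def __len__(self):
        return len(self.ys)

    def __getitems__(self, indices):
        # One call returns the whole batch, already grouped by field
        return ([self.xs[i] for i in indices], [self.ys[i] for i in indices])


def identity_collate(batch):
    return batch  # nothing left to stack: the batch arrives pre-collated


ds = SelfCollatingDataset(list(range(10)), [i % 2 for i in range(10)])
xs, ys = identity_collate(ds.__getitems__([3, 4, 5]))
print(xs, ys)  # [3, 4, 5] [1, 0, 1]
```

With a per-item dataset, the default collate_fn would have to stack 64 individual samples into a batch; here the batch is produced in one indexing operation, which is what makes the identity collate_fn sufficient.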

Model

We use the same model as in the Quickstart tutorial.

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = Net().to(device)
model_tc = Net().to(device)
model, model_tc
(Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
), Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
))

Optimizing the parameters

We'll optimize the parameters of the model using stochastic gradient descent and cross-entropy loss.

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_tc = torch.optim.SGD(model_tc.parameters(), lr=1e-3)
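For reference, cross-entropy on raw logits reduces to log-sum-exp of the logits minus the target logit. A minimal torch-free check (hypothetical helper name), using the max-subtraction trick for numerical stability:

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of raw logits: logsumexp(logits) - logits[target]."""
    m = max(logits)  # subtract the max before exponentiating for stability
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target]


loss = cross_entropy([2.0, 1.0, 0.1], target=0)
print(loss)  # ≈ 0.4170
```

This matches what nn.CrossEntropyLoss computes per sample on unnormalized logits (it fuses the softmax and the negative log-likelihood, which is why the model above outputs raw logits).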


def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

The training loop for our tensorclass-based DataLoader is very similar; we just adjust how we unpack the data to the more explicit attribute-based retrieval offered by the tensorclass. The .contiguous() method loads the data stored in the memmap tensor.

def train_tc(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()

    for batch, data in enumerate(dataloader):
        X, y = data.images.contiguous(), data.targets.contiguous()

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


def test_tc(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch.images.contiguous(), batch.targets.contiguous()

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


for d in train_dataloader_tc:
    print(d)
    break

import time

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train_tc(train_dataloader_tc, model_tc, loss_fn, optimizer_tc)
    test_tc(test_dataloader_tc, model_tc, loss_fn)
print(f"Tensorclass training done! time: {time.time() - t0: 4.4f} s")

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - t0: 4.4f} s")
FashionMNISTData(
    images=Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
    targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([64]),
    device=cpu,
    is_shared=False)
Epoch 1
-------------------------
loss: 2.279966 [    0/60000]
loss: 2.272364 [ 6400/60000]
loss: 2.248586 [12800/60000]
loss: 2.249137 [19200/60000]
loss: 2.230657 [25600/60000]
loss: 2.205297 [32000/60000]
loss: 2.207200 [38400/60000]
loss: 2.174660 [44800/60000]
loss: 2.160898 [51200/60000]
loss: 2.144460 [57600/60000]
Test Error:
 Accuracy: 55.0%, Avg loss: 2.124971

Epoch 2
-------------------------
loss: 2.127797 [    0/60000]
loss: 2.122406 [ 6400/60000]
loss: 2.055028 [12800/60000]
loss: 2.078694 [19200/60000]
loss: 2.029268 [25600/60000]
loss: 1.970141 [32000/60000]
loss: 1.995516 [38400/60000]
loss: 1.917498 [44800/60000]
loss: 1.913008 [51200/60000]
loss: 1.858231 [57600/60000]
Test Error:
 Accuracy: 57.6%, Avg loss: 1.841842

Epoch 3
-------------------------
loss: 1.875065 [    0/60000]
loss: 1.849101 [ 6400/60000]
loss: 1.719949 [12800/60000]
loss: 1.766835 [19200/60000]
loss: 1.659159 [25600/60000]
loss: 1.620128 [32000/60000]
loss: 1.639907 [38400/60000]
loss: 1.549244 [44800/60000]
loss: 1.569196 [51200/60000]
loss: 1.473024 [57600/60000]
Test Error:
 Accuracy: 60.1%, Avg loss: 1.482247

Epoch 4
-------------------------
loss: 1.553915 [    0/60000]
loss: 1.520971 [ 6400/60000]
loss: 1.364636 [12800/60000]
loss: 1.438522 [19200/60000]
loss: 1.320970 [25600/60000]
loss: 1.327443 [32000/60000]
loss: 1.337312 [38400/60000]
loss: 1.272601 [44800/60000]
loss: 1.303945 [51200/60000]
loss: 1.208621 [57600/60000]
Test Error:
 Accuracy: 63.5%, Avg loss: 1.230491

Epoch 5
-------------------------
loss: 1.310143 [    0/60000]
loss: 1.292831 [ 6400/60000]
loss: 1.124210 [12800/60000]
loss: 1.228162 [19200/60000]
loss: 1.105486 [25600/60000]
loss: 1.137606 [32000/60000]
loss: 1.155097 [38400/60000]
loss: 1.101383 [44800/60000]
loss: 1.138751 [51200/60000]
loss: 1.056086 [57600/60000]
Test Error:
 Accuracy: 65.0%, Avg loss: 1.074194

Tensorclass training done! time:  8.5721 s
Epoch 1
-------------------------
loss: 2.300333 [    0/60000]
loss: 2.291154 [ 6400/60000]
loss: 2.277311 [12800/60000]
loss: 2.270576 [19200/60000]
loss: 2.257171 [25600/60000]
loss: 2.219032 [32000/60000]
loss: 2.229099 [38400/60000]
loss: 2.198443 [44800/60000]
loss: 2.195091 [51200/60000]
loss: 2.163343 [57600/60000]
Test Error:
 Accuracy: 44.4%, Avg loss: 2.157291

Epoch 2
-------------------------
loss: 2.164766 [    0/60000]
loss: 2.154399 [ 6400/60000]
loss: 2.101165 [12800/60000]
loss: 2.117347 [19200/60000]
loss: 2.076215 [25600/60000]
loss: 2.006450 [32000/60000]
loss: 2.039347 [38400/60000]
loss: 1.968061 [44800/60000]
loss: 1.971254 [51200/60000]
loss: 1.899145 [57600/60000]
Test Error:
 Accuracy: 58.4%, Avg loss: 1.896424

Epoch 3
-------------------------
loss: 1.926218 [    0/60000]
loss: 1.895335 [ 6400/60000]
loss: 1.784933 [12800/60000]
loss: 1.823640 [19200/60000]
loss: 1.730518 [25600/60000]
loss: 1.671613 [32000/60000]
loss: 1.696103 [38400/60000]
loss: 1.606779 [44800/60000]
loss: 1.629869 [51200/60000]
loss: 1.523274 [57600/60000]
Test Error:
 Accuracy: 60.9%, Avg loss: 1.540868

Epoch 4
-------------------------
loss: 1.603400 [    0/60000]
loss: 1.569938 [ 6400/60000]
loss: 1.428288 [12800/60000]
loss: 1.490420 [19200/60000]
loss: 1.386604 [25600/60000]
loss: 1.371165 [32000/60000]
loss: 1.379619 [38400/60000]
loss: 1.318368 [44800/60000]
loss: 1.351309 [51200/60000]
loss: 1.244383 [57600/60000]
Test Error:
 Accuracy: 62.9%, Avg loss: 1.275016

Epoch 5
-------------------------
loss: 1.348914 [    0/60000]
loss: 1.332156 [ 6400/60000]
loss: 1.176685 [12800/60000]
loss: 1.265823 [19200/60000]
loss: 1.154234 [25600/60000]
loss: 1.170851 [32000/60000]
loss: 1.179623 [38400/60000]
loss: 1.134891 [44800/60000]
loss: 1.173710 [51200/60000]
loss: 1.079602 [57600/60000]
Test Error:
 Accuracy: 64.4%, Avg loss: 1.106302

Training done! time:  36.0195 s

Total running time of the script: (1 minutes 2.924 seconds)

Gallery generated by Sphinx-Gallery
