• 文档 >
  • 使用回放缓冲区
快捷方式

使用回放缓冲区

作者Vincent Moens

回放缓冲区是任何强化学习或控制算法的核心组成部分。监督学习方法通常的特点是训练循环,其中数据从静态数据集中随机抽取,并依次输入模型和损失函数。在强化学习中,情况通常略有不同:数据是通过模型收集的,然后临时存储在动态结构(经验回放缓冲区)中,该结构作为损失模块的数据集。

与往常一样,缓冲区使用的上下文会极大地影响其构建方式:有些人可能希望存储轨迹,而另一些人则希望存储单个转换。在某些情况下,特定的采样策略可能更可取:某些项可能比其他项具有更高的优先级,或者以替换或不替换的方式进行采样可能很重要。计算因素也可能发挥作用,例如缓冲区的大小可能超过可用的 RAM 存储。

因此,TorchRL 的回放缓冲区是完全可组合的:虽然它们“开箱即用”,需要最少的精力来构建,但它们也支持许多自定义,例如存储类型、采样策略或数据转换。

在本教程中,您将学习

基础:构建一个标准的replay buffer

TorchRL 的回放缓冲区设计旨在优先考虑模块化、可组合性、效率和简洁性。例如,创建一个基本的回放缓冲区是一个简单的过程,如下面的示例所示

import gc

import tempfile

from torchrl.data import ReplayBuffer

buffer = ReplayBuffer()

默认情况下,此回放缓冲区的尺寸为 1000。让我们通过使用 extend() 方法填充我们的缓冲区来检查这一点

print("length before adding elements:", len(buffer))

buffer.extend(range(2000))

print("length after adding elements:", len(buffer))

我们使用了 extend() 方法,该方法设计用于一次添加多个项。如果传递给 extend 的对象具有多个维度,则其第一个维度将被视为在缓冲区中拆分为单独的元素。

这本质上意味着,在将多维张量或 tensordict 添加到缓冲区时,缓冲区在计算其内存中的元素数量时只考虑第一个维度。如果传递的对象不可迭代,则会抛出异常。

要一次添加一个项目,应改用 add() 方法。

自定义存储

我们看到缓冲区被限制为我们传递给它的前 1000 个元素。要更改大小,我们需要自定义我们的存储。

TorchRL 提供三种类型的存储

  • ListStorage 将元素独立地存储在列表中。它支持任何数据类型,但这种灵活性以效率为代价;

  • LazyTensorStorage 将张量数据结构连续存储。它自然地与 TensorDict(或 tensorclass)对象配合使用。存储是按张量连续存储的,这意味着采样效率将高于使用列表时,但隐含的限制是传递给它的任何数据都必须具有与用于实例化缓冲区的第一个批次数据相同的基本属性(如形状和 dtype)。传递不符合此要求的数据将引发异常或导致某些未定义行为。

  • LazyMemmapStorage 的工作方式类似于 LazyTensorStorage,因为它也是惰性的(即,它期望第一个批次数据用于实例化),并且它要求每个存储的批次具有匹配的形状和 dtype 的数据。这种存储的独特之处在于它指向磁盘文件(或使用文件系统存储),这意味着它可以支持非常大的数据集,同时仍然以连续的方式访问数据。

让我们看看如何使用这些存储中的每一种

from torchrl.data import LazyMemmapStorage, LazyTensorStorage, ListStorage

# We define the maximum size of the buffer
size = 100

带有列表存储的缓冲区可以存储任何类型的数据(但我们必须更改 collate_fn,因为默认值需要数值数据)

buffer_list = ReplayBuffer(storage=ListStorage(size), collate_fn=lambda x: x)
buffer_list.extend(["a", 0, "b"])
print(buffer_list.sample(3))

由于 ListStorage 的假设最少,因此它是 TorchRL 中的默认存储。

一个 LazyTensorStorage 可以连续存储数据。在处理中等大小的复杂但不改变的数据结构时,这应该是首选选项。

buffer_lazytensor = ReplayBuffer(storage=LazyTensorStorage(size))

让我们创建一个包含 2 个存储张量的大小为 torch.Size([3]) 的数据批次。

import torch
from tensordict import TensorDict

data = TensorDict(
    {
        "a": torch.arange(12).view(3, 4),
        ("b", "c"): torch.arange(15).view(3, 5),
    },
    batch_size=[3],
)
print(data)

第一次调用 extend() 将实例化存储。数据的第一个维度被解绑为单独的数据点。

buffer_lazytensor.extend(data)
print(f"The buffer has {len(buffer_lazytensor)} elements")

让我们从缓冲区采样并打印数据。

sample = buffer_lazytensor.sample(5)
print("samples", sample["a"], sample["b", "c"])

一个 LazyMemmapStorage 以相同的方式创建。我们也可以自定义磁盘上的存储位置。

with tempfile.TemporaryDirectory() as tempdir:
    buffer_lazymemmap = ReplayBuffer(
        storage=LazyMemmapStorage(size, scratch_dir=tempdir)
    )
    buffer_lazymemmap.extend(data)
    print(f"The buffer has {len(buffer_lazymemmap)} elements")
    print(
        "the 'a' tensor is stored in", buffer_lazymemmap._storage._storage["a"].filename
    )
    print(
        "the ('b', 'c') tensor is stored in",
        buffer_lazymemmap._storage._storage["b", "c"].filename,
    )
    sample = buffer_lazytensor.sample(5)
    print("samples: a=", sample["a"], "\n('b', 'c'):", sample["b", "c"])
    del buffer_lazymemmap

与 TensorDict 集成

张量位置遵循与包含它们的 TensorDict 相同的结构:这使得在训练期间轻松保存和加载缓冲区。

要充分发挥 TensorDict 作为数据载体的潜力,可以使用 TensorDictReplayBuffer 类。它的主要优点之一是它能够处理采样数据的组织,以及可能需要的任何附加信息(例如采样索引)。

它可以以与标准 ReplayBuffer 相同的方式构建,并且通常可以互换使用。

from torchrl.data import TensorDictReplayBuffer

with tempfile.TemporaryDirectory() as tempdir:
    buffer_lazymemmap = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12
    )
    buffer_lazymemmap.extend(data)
    print(f"The buffer has {len(buffer_lazymemmap)} elements")
    sample = buffer_lazymemmap.sample()
    print("sample:", sample)
    del buffer_lazymemmap

我们的采样现在有一个额外的 "index" 键,它表示采样了哪些索引。让我们看看这些索引。

print(sample["index"])

与 tensorclass 集成

ReplayBuffer 类及其子类也原生支持 tensorclass 类,这些类可以方便地用于更显式地编码数据集。

from tensordict import tensorclass


@tensorclass
class MyData:
    images: torch.Tensor
    labels: torch.Tensor


data = MyData(
    images=torch.randint(
        255,
        (10, 64, 64, 3),
    ),
    labels=torch.randint(100, (10,)),
    batch_size=[10],
)

buffer_lazy = ReplayBuffer(storage=LazyTensorStorage(size), batch_size=12)
buffer_lazy.extend(data)
print(f"The buffer has {len(buffer_lazy)} elements")
sample = buffer_lazy.sample()
print("sample:", sample)

正如预期的那样,数据具有正确的类和形状!

与其它张量结构(PyTrees)集成

TorchRL 的回放缓冲区也支持任何 pytree 数据结构。PyTree 是一个任意深度的嵌套结构,由字典、列表和/或元组构成,其叶子是张量。这意味着我们可以将任何此类树状结构存储在连续内存中!可以使用各种存储:TensorStorageLazyMemmapStorageLazyTensorStorage 都接受此类数据。

这是一个关于此功能如何工作的简短演示。

from torch.utils._pytree import tree_map

我们在 RAM 上构建我们的回放缓冲区。

rb = ReplayBuffer(storage=LazyTensorStorage(size))
data = {
    "a": torch.randn(3),
    "b": {"c": (torch.zeros(2), [torch.ones(1)])},
    30: -torch.ones(()),  # non-string keys also work
}
rb.add(data)

# The sample has a similar structure to the data (with a leading dimension of 10 for each tensor)
sample = rb.sample(10)

使用 pytrees,任何可调用对象都可以用作转换。

def transform(x):
    # Zeros all the data in the pytree
    return tree_map(lambda y: y * 0, x)


rb.append_transform(transform)
sample = rb.sample(batch_size=12)

让我们检查一下我们的转换是否起作用。

def assert0(x):
    assert (x == 0).all()


tree_map(assert0, sample)

从缓冲区采样和迭代

回放缓冲区支持多种采样策略。

  • 如果批次大小是固定的,并且可以在构造时定义,则可以将其作为关键字参数传递给缓冲区;

  • 在批次大小固定的情况下,可以迭代回放缓冲区来收集样本;

  • 如果批次大小是动态的,则可以在运行时将其传递给 sample 方法。

采样可以使用多线程完成,但这与最后一个选项不兼容(因为它需要缓冲区提前知道下一个批次的大小)。

让我们看几个例子。

固定批次大小

如果在构造过程中传递了批次大小,则在采样时应省略它。

data = MyData(
    images=torch.randint(
        255,
        (200, 64, 64, 3),
    ),
    labels=torch.randint(100, (200,)),
    batch_size=[200],
)

buffer_lazy = ReplayBuffer(storage=LazyTensorStorage(size), batch_size=128)
buffer_lazy.extend(data)
buffer_lazy.sample()

此数据批次的大小是我们想要的(128)。

要启用多线程采样,只需在构造过程中将一个正整数传递给 prefetch 关键字参数。这应该会大大加快采样速度,尤其是在采样耗时时(例如,在使用优先采样器时)。

buffer_lazy = ReplayBuffer(
    storage=LazyTensorStorage(size), batch_size=128, prefetch=10
)  # creates a queue of 10 elements to be prefetched in the background
buffer_lazy.extend(data)
print(buffer_lazy.sample())

迭代固定批次大小的缓冲区

只要预定义了批次大小,我们也可以像使用常规数据加载器一样迭代缓冲区。

for i, data in enumerate(buffer_lazy):
    if i == 3:
        print(data)
        break

del buffer_lazy

由于我们的采样技术是完全随机的并且允许替换,因此所讨论的迭代器是无限的。但是,我们可以改用 SamplerWithoutReplacement,它会将我们的缓冲区转换为有限迭代器。

from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

buffer_lazy = ReplayBuffer(
    storage=LazyTensorStorage(size), batch_size=32, sampler=SamplerWithoutReplacement()
)

我们创建一个足够大的数据以获取几个样本。

data = TensorDict(
    {
        "a": torch.arange(64).view(16, 4),
        ("b", "c"): torch.arange(128).view(16, 8),
    },
    batch_size=[16],
)

buffer_lazy.extend(data)
for _i, _ in enumerate(buffer_lazy):
    continue
print(f"A total of {_i+1} batches have been collected")

del buffer_lazy

动态批次大小

与我们之前看到的相反,可以省略 batch_size 关键字参数,并直接将其传递给 sample 方法。

buffer_lazy = ReplayBuffer(
    storage=LazyTensorStorage(size), sampler=SamplerWithoutReplacement()
)
buffer_lazy.extend(data)
print("sampling 3 elements:", buffer_lazy.sample(3))
print("sampling 5 elements:", buffer_lazy.sample(5))

del buffer_lazy

优先回放缓冲区

TorchRL 还提供了 优先回放缓冲区 的接口。这个缓冲区类根据通过数据传递的优先级信号对数据进行采样。

虽然此工具兼容非 tensordict 数据,但我们鼓励使用 TensorDict,因为它能够轻松地在缓冲区内外传递元数据。

让我们首先看看在通用情况下如何构建一个优先回放缓冲区。必须手动设置 \(\alpha\)\(\beta\) 超参数。

from torchrl.data.replay_buffers.samplers import PrioritizedSampler

size = 100

rb = ReplayBuffer(
    storage=ListStorage(size),
    sampler=PrioritizedSampler(max_capacity=size, alpha=0.8, beta=1.1),
    collate_fn=lambda x: x,
)

扩展回放缓冲区会返回项的索引,我们稍后需要它们来更新优先级。

indices = rb.extend([1, "foo", None])

采样器期望每个元素都有一个优先级。添加到缓冲区时,优先级设置为默认值 1。一旦计算出优先级(通常通过损失),就必须在缓冲区中更新它。

这是通过 update_priority() 方法完成的,该方法需要索引和优先级。我们将数据集中的第二个样本的人为高优先级分配给它,以观察它对采样效果的影响。

rb.update_priority(index=indices, priority=torch.tensor([0, 1_000, 0.1]))

我们观察到从缓冲区采样主要返回第二个样本("foo")。

sample, info = rb.sample(10, return_info=True)
print(sample)

info 包含项的相对权重以及索引。

print(info)

我们看到使用优先回放缓冲区比使用常规缓冲区需要一系列额外的训练循环步骤。

  • 在收集数据并扩展缓冲区后,必须更新项的优先级;

  • 在计算损失并从损失中获得“优先级信号”后,我们必须再次更新缓冲区中项的优先级。这需要我们跟踪索引。

这极大地阻碍了缓冲区的可重用性:如果有人要编写一个可以创建优先缓冲区和常规缓冲区的训练脚本,她必须添加大量的控制流,以确保在仅使用优先缓冲区的情况下,在适当的位置调用适当的方法。

让我们看看如何使用 TensorDict 来改进这一点。我们看到 TensorDictReplayBuffer 返回的数据中添加了其相对存储索引。我们没有提到的一项功能是,此类还确保在扩展时自动将优先级信号解析到优先采样器。

这些功能的结合在几个方面简化了事情:- 扩展缓冲区时,优先级信号将自动

如果存在,则会解析,并且优先级将被准确分配;

  • 索引将存储在采样到的 tensordicts 中,便于在计算损失后更新优先级。

  • 在计算损失时,优先级信号将注册到传递给损失模块的 tensordict 中,从而可以轻松地更新权重。

    ..code - block::Python

    >>> data = replay_buffer.sample()
    >>> loss_val = loss_module(data)
    >>> replay_buffer.update_tensordict_priority(data)
    

以下代码说明了这些概念。我们构建一个带有优先采样器的回放缓冲区,并在构造函数中指示应从中获取优先级信号的条目。

rb = TensorDictReplayBuffer(
    storage=ListStorage(size),
    sampler=PrioritizedSampler(size, alpha=0.8, beta=1.1),
    priority_key="td_error",
    batch_size=1024,
)

让我们选择一个与存储索引成比例的优先级信号。

data["td_error"] = torch.arange(data.numel())

rb.extend(data)

sample = rb.sample()

较高的索引应该更频繁地出现。

from matplotlib import pyplot as plt

fig = plt.hist(sample["index"].numpy())
plt.show()

在处理完我们的样本后,我们使用 torchrl.data.TensorDictReplayBuffer.update_tensordict_priority() 方法更新优先级键。为了展示其工作原理,让我们反转采样项的优先级。

sample = rb.sample()
sample["td_error"] = data.numel() - sample["index"]
rb.update_tensordict_priority(sample)

现在,较高的索引应该不太频繁地出现。

sample = rb.sample()

fig = plt.hist(sample["index"].numpy())
plt.show()

使用转换

存储在回放缓冲区中的数据可能未准备好呈现给损失模块。在某些情况下,收集器生成的数据可能过于庞大,无法按原样保存。例如,将图像从 uint8 转换为浮点张量,或在使用决策转换器时连接连续帧。

通过将适当的转换附加到缓冲区,可以进出缓冲区处理数据。以下是一些示例。

保存原始图像

uint8 类型张量在内存占用方面远小于我们通常输入模型的浮点张量。因此,保存原始图像可能很有用。以下脚本展示了如何构建一个只返回原始图像但使用转换后的图像进行推理的收集器,以及如何在回放缓冲区中回收这些转换。

from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
    Compose,
    GrayScale,
    Resize,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.utils import RandomPolicy

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True),
    Compose(
        ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
        Resize(in_keys=["pixels_trsf"], w=64, h=64),
        GrayScale(in_keys=["pixels_trsf"]),
    ),
)

让我们看看一个 rollouts。

print(env.rollout(3))

我们刚刚创建了一个产生像素的环境。这些图像经过处理以输入策略。我们希望存储原始图像,而不是它们的转换。为此,我们将向收集器附加一个转换,以选择我们希望出现的键。

from torchrl.envs.transforms import ExcludeTransform

collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    frames_per_batch=10,
    total_frames=1000,
    postproc=ExcludeTransform("pixels_trsf", ("next", "pixels_trsf"), "collector"),
)

让我们看看一个数据批次,并控制 "pixels_trsf" 键已被丢弃。

for data in collector:
    print(data)
    break

collector.shutdown()

我们使用与环境相同的转换创建了一个回放缓冲区。但是,有一个细节需要解决:在没有环境的情况下使用的转换对数据结构是不可知的。将转换附加到环境时,"next" 嵌套 tensordict 中的数据首先被转换,然后在 rollouts 执行期间复制到根目录。使用静态数据时,情况并非如此。但是,我们的数据带有嵌套的“next”tensordict,如果不是明确指示它来处理,我们的转换将会忽略它。我们手动将这些键添加到转换中。

t = Compose(
    ToTensorImage(
        in_keys=["pixels", ("next", "pixels")],
        out_keys=["pixels_trsf", ("next", "pixels_trsf")],
    ),
    Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
    GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(1000), transform=t, batch_size=16)
rb.extend(data)

我们可以检查 sample 方法是否会看到转换后的图像重新出现。

print(rb.sample())

更复杂的示例:使用 CatFrames

CatFrames 转换器通过时间展开观测值,创建 n 步过去的事件记忆,使模型能够考虑过去的事件(在 POMDP 或循环策略(如决策转换器)的情况下)。存储这些连接的帧可能会消耗大量内存。当 n 步窗口在训练和推理期间需要不同(通常更长)时,这也会有问题。我们通过在两个阶段分别执行 CatFrames 转换来解决这个问题。

from torchrl.envs import CatFrames, UnsqueezeTransform

我们为返回基于像素的观测值的环境创建了一组标准的转换。

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True),
    Compose(
        ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
        Resize(in_keys=["pixels_trsf"], w=64, h=64),
        GrayScale(in_keys=["pixels_trsf"]),
        UnsqueezeTransform(-4, in_keys=["pixels_trsf"]),
        CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]),
    ),
)
collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    frames_per_batch=10,
    total_frames=1000,
)
for data in collector:
    print(data)
    break

collector.shutdown()

缓冲区转换看起来与环境转换非常相似,但具有额外的 ("next", ...) 键,如前所述。

t = Compose(
    ToTensorImage(
        in_keys=["pixels", ("next", "pixels")],
        out_keys=["pixels_trsf", ("next", "pixels_trsf")],
    ),
    Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
    GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
    UnsqueezeTransform(-4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
    CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(size), transform=t, batch_size=16)
data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf"))
rb.add(data_exclude)

让我们从缓冲区采样一个批次。转换后的像素键的形状应该从最后一个维度开始的第四个维度长度为 4。

s = rb.sample(1)  # the buffer has only one element
print(s)

经过一些处理(排除未使用的键等)后,我们看到在线生成的数据和离线生成的数据匹配!

assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all()

存储轨迹

在许多情况下,从缓冲区访问轨迹而不是简单转换是可取的。TorchRL 提供了多种实现此目的的方法。

首选方法是沿着缓冲区的第一个维度存储轨迹,并使用 SliceSampler 对这些数据批次进行采样。此类只需了解您数据结构的一些信息即可完成其工作(请注意,目前它仅与 tensordict 结构化数据兼容):切片数量或其长度以及有关情节分隔位置的信息(例如,回想一下,使用 DataCollector 时,轨迹 ID 存储在 ("collector", "traj_ids") 中)。在此简单示例中,我们构建了一个包含 4 个连续短轨迹的数据,并从中采样了 4 个切片,每个切片长度为 2(因为批次大小为 8,8 个项 // 4 个切片 = 2 个时间步)。我们还标记了这些步骤。

from torchrl.data import SliceSampler

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(size),
    sampler=SliceSampler(traj_key="episode", num_slices=4),
    batch_size=8,
)
episode = torch.zeros(10, dtype=torch.int)
episode[:3] = 1
episode[3:5] = 2
episode[5:7] = 3
episode[7:] = 4
steps = torch.cat([torch.arange(3), torch.arange(2), torch.arange(2), torch.arange(3)])
data = TensorDict(
    {
        "episode": episode,
        "obs": torch.randn((3, 4, 5)).expand(10, 3, 4, 5),
        "act": torch.randn((20,)).expand(10, 20),
        "other": torch.randn((20, 50)).expand(10, 20, 50),
        "steps": steps,
    },
    [10],
)
rb.extend(data)
sample = rb.sample()
print("episode are grouped", sample["episode"])
print("steps are successive", sample["steps"])

gc.collect()

结论

我们已经了解了如何在 TorchRL 中使用回放缓冲区,从最简单的用法到更高级的用法,其中数据需要以特定方式进行转换或存储。您现在应该能够

  • 创建回放缓冲区,自定义其存储、采样器和转换;

  • 为您的问题选择最佳的存储类型(列表、内存或基于磁盘);

  • 最小化缓冲区的内存占用。

后续步骤

  • 查看数据 API 参考,了解 TorchRL 中的离线数据集,这些数据集基于我们的回放缓冲区 API;

  • 查看其他采样器,例如 SamplerWithoutReplacementPrioritizedSliceSamplerSliceSamplerWithoutReplacement,或其他写入器,例如 TensorDictMaxValueWriter

  • 文档 中查看如何检查回放缓冲区点。

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源