• 文档 >
  • 使用回放缓冲区
快捷方式

使用回放缓冲区

作者Vincent Moens

回放缓冲区是任何强化学习或控制算法的核心部分。监督学习方法通常的特点是有一个训练循环,其中数据从静态数据集中随机提取,并依次输入模型和损失函数。在强化学习中,情况通常略有不同:数据是使用模型收集的,然后临时存储在动态结构(经验回放缓冲区)中,该结构作为损失模块的数据集。

一如既往,缓冲区使用的上下文对其构建方式有着极大的影响:有些人可能希望存储轨迹,而另一些人则希望存储单个转换。在某些情况下,特定的采样策略可能更可取:某些项的优先级可能高于其他项,或者进行有放回或无放回采样可能很重要。计算因素也可能发挥作用,例如缓冲区的大小可能超出可用 RAM 存储。

出于这些原因,TorchRL 的回放缓冲区是完全可组合的:尽管它们附带“电池”,只需最少的努力即可构建,但它们也支持许多自定义,例如存储类型、采样策略或数据转换。

在本教程中,您将学习

基础:构建一个标准的(vanilla)回放缓冲区

TorchRL 的回放缓冲区旨在优先考虑模块化、可组合性、效率和简洁性。例如,创建一个基本的回放缓冲区是一个简单的过程,如下例所示

import gc

import tempfile

from torchrl.data import ReplayBuffer

buffer = ReplayBuffer()

默认情况下,此回放缓冲区的大小为 1000。让我们通过使用 `extend()` 方法填充缓冲区来检查这一点

print("length before adding elements:", len(buffer))

buffer.extend(range(2000))

print("length after adding elements:", len(buffer))

我们使用了 `extend()` 方法,该方法设计用于一次添加多个项目。如果传递给 `extend` 的对象具有多个维度,则其第一个维度将被视为在缓冲区中拆分为单独的元素。

这本质上意味着,在添加多维张量或张量字典(tensordict)到缓冲区时,缓冲区在计算其内存中的元素数量时将仅关注第一个维度。如果传递的对象不可迭代,则会引发异常。

要一次添加一个项目,应改用 `add()` 方法。

自定义存储

我们看到缓冲区被限制在我们传递给它的前 1000 个元素。要更改大小,我们需要自定义存储。

TorchRL 提供三种类型的存储

  • `ListStorage` 将元素独立地存储在列表中。它支持任何数据类型,但这种灵活性是以效率为代价的;

  • `LazyTensorStorage` 将张量数据结构连续地存储起来。它自然地与 `TensorDict`(或 `tensorclass`)对象一起工作。存储在每个张量基础上是连续的,这意味着采样将比使用列表时更有效,但隐含的限制是传递给它的任何数据都必须具有与用于实例化缓冲区的第一个数据批次相同的基本属性(例如形状和 dtype)。传递不符合此要求的将引发异常或导致某些未定义行为。

  • `LazyMemmapStorage` 的工作方式类似于 `LazyTensorStorage`,即它是惰性的(也就是说,它期望第一个数据批次被实例化),并且它需要每个存储批次的数据在形状和 dtype 上匹配。这种存储的独特之处在于它指向磁盘文件(或使用文件系统存储),这意味着它可以支持非常大的数据集,同时仍然以连续的方式访问数据。

让我们看看如何使用这些存储中的每一种

from torchrl.data import LazyMemmapStorage, LazyTensorStorage, ListStorage

# We define the maximum size of the buffer
size = 100

带有列表存储的缓冲区可以存储任何类型的数据(但我们必须更改 `collate_fn`,因为默认值期望数值数据)

buffer_list = ReplayBuffer(storage=ListStorage(size), collate_fn=lambda x: x)
buffer_list.extend(["a", 0, "b"])
print(buffer_list.sample(3))

由于它是假设最少的存储,因此 `ListStorage` 是 TorchRL 中的默认存储。

`LazyTensorStorage` 可以连续存储数据。在处理大小中等的复杂但不变的数据结构时,应优先选择此选项。

buffer_lazytensor = ReplayBuffer(storage=LazyTensorStorage(size))

让我们创建一个包含 2 个存储的张量的数据批次,其大小为 `torch.Size([3])`。

import torch
from tensordict import TensorDict

data = TensorDict(
    {
        "a": torch.arange(12).view(3, 4),
        ("b", "c"): torch.arange(15).view(3, 5),
    },
    batch_size=[3],
)
print(data)

第一次调用 `extend()` 将实例化存储。数据的第一个维度被解绑为单独的数据点。

buffer_lazytensor.extend(data)
print(f"The buffer has {len(buffer_lazytensor)} elements")

让我们从缓冲区采样,并打印数据。

sample = buffer_lazytensor.sample(5)
print("samples", sample["a"], sample["b", "c"])

`LazyMemmapStorage` 的创建方式相同。我们还可以自定义磁盘上的存储位置。

with tempfile.TemporaryDirectory() as tempdir:
    buffer_lazymemmap = ReplayBuffer(
        storage=LazyMemmapStorage(size, scratch_dir=tempdir)
    )
    buffer_lazymemmap.extend(data)
    print(f"The buffer has {len(buffer_lazymemmap)} elements")
    print(
        "the 'a' tensor is stored in", buffer_lazymemmap._storage._storage["a"].filename
    )
    print(
        "the ('b', 'c') tensor is stored in",
        buffer_lazymemmap._storage._storage["b", "c"].filename,
    )
    sample = buffer_lazytensor.sample(5)
    print("samples: a=", sample["a"], "\n('b', 'c'):", sample["b", "c"])
    del buffer_lazymemmap

与 TensorDict 集成

张量位置遵循与包含它们的 TensorDict 相同的结构:这使得在训练过程中保存和加载缓冲区变得容易。

要充分利用 `TensorDict` 作为数据载体,可以使用 `TensorDictReplayBuffer` 类。它的一个关键优点是能够处理采样数据的组织,以及可能需要的任何附加信息(例如采样索引)。

它的构建方式与标准 `ReplayBuffer` 相同,并且通常可以互换使用。

from torchrl.data import TensorDictReplayBuffer

with tempfile.TemporaryDirectory() as tempdir:
    buffer_lazymemmap = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12
    )
    buffer_lazymemmap.extend(data)
    print(f"The buffer has {len(buffer_lazymemmap)} elements")
    sample = buffer_lazymemmap.sample()
    print("sample:", sample)
    del buffer_lazymemmap

我们的样本现在有一个额外的 `"index"` 键,它指示了采样了哪些索引。让我们看看这些索引。

print(sample["index"])

与 tensorclass 集成

ReplayBuffer 类及其关联的子类也支持与 `tensorclass` 类原生工作,`tensorclass` 类可以方便地用于以更显式的方式编码数据集。

from tensordict import tensorclass


@tensorclass
class MyData:
    images: torch.Tensor
    labels: torch.Tensor


data = MyData(
    images=torch.randint(
        255,
        (10, 64, 64, 3),
    ),
    labels=torch.randint(100, (10,)),
    batch_size=[10],
)

buffer_lazy = ReplayBuffer(storage=LazyTensorStorage(size), batch_size=12)
buffer_lazy.extend(data)
print(f"The buffer has {len(buffer_lazy)} elements")
sample = buffer_lazy.sample()
print("sample:", sample)

正如预期的那样,数据具有正确的类和形状!

与其他张量结构(PyTrees)集成

TorchRL 的回放缓冲区也支持任何 pytree 数据结构。PyTree 是由字典、列表和/或元组组成的任意深度的嵌套结构,其叶子是张量。这意味着可以以连续内存存储任何此类树状结构!可以使用各种存储:`TensorStorage`、`LazyMemmapStorage` 或 `LazyTensorStorage` 都接受此类数据。

这是一个展示此功能外观的简短演示。

from torch.utils._pytree import tree_map

让我们在 RAM 上构建我们的回放缓冲区。

rb = ReplayBuffer(storage=LazyTensorStorage(size))
data = {
    "a": torch.randn(3),
    "b": {"c": (torch.zeros(2), [torch.ones(1)])},
    30: -torch.ones(()),  # non-string keys also work
}
rb.add(data)

# The sample has a similar structure to the data (with a leading dimension of 10 for each tensor)
sample = rb.sample(10)

使用 pytrees,任何可调用对象都可以用作转换。

def transform(x):
    # Zeros all the data in the pytree
    return tree_map(lambda y: y * 0, x)


rb.append_transform(transform)
sample = rb.sample(batch_size=12)

让我们检查一下我们的转换是否已完成工作。

def assert0(x):
    assert (x == 0).all()


tree_map(assert0, sample)

采样和迭代缓冲区

回放缓冲区支持多种采样策略。

  • 如果 batch-size 是固定的,并且可以在构造时定义,则可以将其作为关键字参数传递给缓冲区。

  • 使用固定的 batch-size,可以迭代回放缓冲区以收集样本。

  • 如果 batch-size 是动态的,则可以在采样方法中即时传递它。

采样可以使用多线程完成,但这与最后一个选项不兼容(因为它要求缓冲区提前知道下一个 batch 的大小)。

让我们看几个例子。

固定 batch-size

如果在构造时传递了 batch-size,则在采样时应省略它。

data = MyData(
    images=torch.randint(
        255,
        (200, 64, 64, 3),
    ),
    labels=torch.randint(100, (200,)),
    batch_size=[200],
)

buffer_lazy = ReplayBuffer(storage=LazyTensorStorage(size), batch_size=128)
buffer_lazy.extend(data)
buffer_lazy.sample()

此数据批次的尺寸是我们想要的尺寸(128)。

要启用多线程采样,只需在构造时将一个正整数传递给 `prefetch` 关键字参数。当采样耗时时(例如,在使用优先采样器时),这应该会大大加快采样速度。

buffer_lazy = ReplayBuffer(
    storage=LazyTensorStorage(size), batch_size=128, prefetch=10
)  # creates a queue of 10 elements to be prefetched in the background
buffer_lazy.extend(data)
print(buffer_lazy.sample())

使用固定 batch-size 迭代缓冲区

只要 batch-size 是预定义的,我们也可以像使用常规数据加载器一样迭代缓冲区。

for i, data in enumerate(buffer_lazy):
    if i == 3:
        print(data)
        break

del buffer_lazy

由于我们的采样技术是完全随机且允许替换的,因此该迭代器是无限的。但是,我们可以改用 `SamplerWithoutReplacement`,它将把我们的缓冲区转换为有限迭代器。

from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

buffer_lazy = ReplayBuffer(
    storage=LazyTensorStorage(size), batch_size=32, sampler=SamplerWithoutReplacement()
)

我们创建足够大的数据以获得几个样本。

data = TensorDict(
    {
        "a": torch.arange(64).view(16, 4),
        ("b", "c"): torch.arange(128).view(16, 8),
    },
    batch_size=[16],
)

buffer_lazy.extend(data)
for _i, _ in enumerate(buffer_lazy):
    continue
print(f"A total of {_i+1} batches have been collected")

del buffer_lazy

动态 batch-size

与我们之前看到的相反,可以省略 `batch_size` 关键字参数,并将其直接传递给 `sample` 方法。

buffer_lazy = ReplayBuffer(
    storage=LazyTensorStorage(size), sampler=SamplerWithoutReplacement()
)
buffer_lazy.extend(data)
print("sampling 3 elements:", buffer_lazy.sample(3))
print("sampling 5 elements:", buffer_lazy.sample(5))

del buffer_lazy

优先回放缓冲区

TorchRL 还提供了优先回放缓冲区的接口。此缓冲区类根据通过数据传递的优先级信号对数据进行采样。

尽管此工具与非 tensordict 数据兼容,但我们鼓励使用 TensorDict,因为它能够轻松地在缓冲区内外携带元数据。

让我们首先看看如何在通用情况下构建优先回放缓冲区。必须手动设置 \(\alpha\)\(\beta\) 超参数。

from torchrl.data.replay_buffers.samplers import PrioritizedSampler

size = 100

rb = ReplayBuffer(
    storage=ListStorage(size),
    sampler=PrioritizedSampler(max_capacity=size, alpha=0.8, beta=1.1),
    collate_fn=lambda x: x,
)

扩展回放缓冲区将返回项目索引,我们稍后将需要它们来更新优先级。

indices = rb.extend([1, "foo", None])

采样器期望每个元素都有一个优先级。添加到缓冲区时,优先级设置为默认值 1。一旦计算出优先级(通常通过损失),就必须在缓冲区中更新它。

这是通过 `update_priority()` 方法完成的,该方法需要索引和优先级。我们为数据集中的第二个样本分配了人为的高优先级,以观察它对采样的影响。

rb.update_priority(index=indices, priority=torch.tensor([0, 1_000, 0.1]))

我们观察到从缓冲区采样主要返回第二个样本(`"foo"`)。

sample, info = rb.sample(10, return_info=True)
print(sample)

信息包含项目的相对权重以及索引。

print(info)

我们看到,与常规缓冲区相比,使用优先回放缓冲区需要比训练循环中多几个步骤。

  • 收集数据并扩展缓冲区后,必须更新项目的优先级。

  • 在计算损失并从其获得“优先级信号”后,我们必须再次更新缓冲区中项目的优先级。这要求我们跟踪索引。

这大大阻碍了缓冲区的可重用性:如果有人要编写一个可以创建优先缓冲区和常规缓冲区的训练脚本,她必须添加大量的控制流,以确保在仅使用优先缓冲区的情况下,在适当的位置调用适当的方法。

让我们看看如何使用 `TensorDict` 来改进这一点。我们看到 `TensorDictReplayBuffer` 返回增强了其相对存储索引的数据。我们还没有提到的一项功能是,当存在优先采样器时,此类还可以确保将优先级信号自动解析到优先采样器。

这些功能的结合以多种方式简化了事情: - 扩展缓冲区时,如果存在优先级信号,它将被自动

解析,并且优先级将被准确分配;

  • 索引将存储在采样的张量字典中,使得在计算损失后轻松更新优先级。

  • 计算损失时,优先级信号将在传递给损失模块的张量字典中注册,从而可以轻松更新权重。

    ..code - block::Python

    >>> data = replay_buffer.sample()
    >>> loss_val = loss_module(data)
    >>> replay_buffer.update_tensordict_priority(data)
    

以下代码说明了这些概念。我们构建了一个带有优先采样器的回放缓冲区,并在构造函数中指示了应从中获取优先级信号的条目。

rb = TensorDictReplayBuffer(
    storage=ListStorage(size),
    sampler=PrioritizedSampler(size, alpha=0.8, beta=1.1),
    priority_key="td_error",
    batch_size=1024,
)

让我们选择一个与存储索引成比例的优先级信号。

data["td_error"] = torch.arange(data.numel())

rb.extend(data)

sample = rb.sample()

更高的索引应该更频繁地出现。

from matplotlib import pyplot as plt

fig = plt.hist(sample["index"].numpy())
plt.show()

在处理完样本后,我们使用 `torchrl.data.TensorDictReplayBuffer.update_tensordict_priority()` 方法更新优先级键。为了演示其工作原理,让我们恢复采样项的优先级。

sample = rb.sample()
sample["td_error"] = data.numel() - sample["index"]
rb.update_tensordict_priority(sample)

现在,更高的索引应该更频繁地出现。

sample = rb.sample()

fig = plt.hist(sample["index"].numpy())
plt.show()

使用转换

存储在回放缓冲区中的数据可能尚未准备好呈现给损失模块。在某些情况下,收集器生成的数据可能太重,无法按原样保存。例如,将图像从 `uint8` 转换为浮点张量,或在使用决策转换器时连接连续帧。

只需将适当的转换附加到缓冲区,即可在缓冲区内部和外部处理数据。以下是一些示例。

保存原始图像

类型为 `uint8` 的张量比我们通常输入模型的浮点张量在内存占用上要小得多。因此,保存原始图像可能很有用。以下脚本展示了如何构建一个仅返回原始图像但使用转换后的图像进行推理的收集器,以及如何在回放缓冲区中回收这些转换。

from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
    Compose,
    GrayScale,
    Resize,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.utils import RandomPolicy

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True),
    Compose(
        ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
        Resize(in_keys=["pixels_trsf"], w=64, h=64),
        GrayScale(in_keys=["pixels_trsf"]),
    ),
)

让我们看看一个 rollout。

print(env.rollout(3))

我们刚刚创建了一个产生像素的环境。这些图像经过处理后输入策略。我们希望存储原始图像,而不是它们的转换。为此,我们将一个转换附加到收集器,以选择我们希望出现的键。

from torchrl.envs.transforms import ExcludeTransform

collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    frames_per_batch=10,
    total_frames=1000,
    postproc=ExcludeTransform("pixels_trsf", ("next", "pixels_trsf"), "collector"),
)

让我们看看数据批次,并控制 `"pixels_trsf"` 键已被丢弃。

for data in collector:
    print(data)
    break

collector.shutdown()

我们创建一个与环境具有相同转换的回放缓冲区。但是,有一个细节需要解决:单独使用环境的转换是无法识别数据结构的。当将转换附加到环境时,`"next"` 嵌套张量字典中的数据首先被转换,然后在 rollout 执行期间复制到根目录。使用静态数据时,情况并非如此。但是,我们的数据带有一个嵌套的 `"next"` 张量字典,如果我们不明确指示它来处理它,我们的转换将忽略它。我们手动将这些键添加到转换中。

t = Compose(
    ToTensorImage(
        in_keys=["pixels", ("next", "pixels")],
        out_keys=["pixels_trsf", ("next", "pixels_trsf")],
    ),
    Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
    GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(1000), transform=t, batch_size=16)
rb.extend(data)

我们可以检查 `sample` 方法是否看到了转换后的图像重新出现。

print(rb.sample())

一个更复杂的例子:使用 CatFrames

`CatFrames` 转换器通过时间展开观察,创建 n 步记忆,允许模型考虑过去事件(在 POMDP 或使用决策转换器等循环策略的情况下)。存储这些连接的帧会占用大量内存。当 n 步窗口在训练和推理期间需要不同(通常更长)时,这也会有问题。我们通过在两个阶段分别执行 `CatFrames` 转换来解决此问题。

from torchrl.envs import CatFrames, UnsqueezeTransform

我们为返回基于像素的观察的环境创建了一个标准的转换列表。

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True),
    Compose(
        ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
        Resize(in_keys=["pixels_trsf"], w=64, h=64),
        GrayScale(in_keys=["pixels_trsf"]),
        UnsqueezeTransform(-4, in_keys=["pixels_trsf"]),
        CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]),
    ),
)
collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    frames_per_batch=10,
    total_frames=1000,
)
for data in collector:
    print(data)
    break

collector.shutdown()

缓冲区转换看起来与环境转换非常相似,但具有额外的 `("next", ...)` 键,如前所述。

t = Compose(
    ToTensorImage(
        in_keys=["pixels", ("next", "pixels")],
        out_keys=["pixels_trsf", ("next", "pixels_trsf")],
    ),
    Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
    GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
    UnsqueezeTransform(-4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
    CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(size), transform=t, batch_size=16)
data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf"))
rb.add(data_exclude)

让我们从缓冲区采样一个批次。转换后的像素键的形状应在从末尾开始的第四个维度上为 4。

s = rb.sample(1)  # the buffer has only one element
print(s)

经过一些处理(排除未使用的键等)后,我们看到在线和离线生成的数据是匹配的!

assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all()

存储轨迹

在许多情况下,期望从缓冲区访问轨迹而不是简单的转换。TorchRL 提供了多种实现此目的的方法。

首选方法是沿着缓冲区的第一个维度存储轨迹,并使用 `SliceSampler` 对这些数据批次进行采样。此类只需要关于数据结构的少量信息即可完成其工作(请注意,目前它仅与 tensordict 结构化数据兼容):切片数或其长度,以及有关如何区分各个片段的信息(例如,回想一下,使用 `DataCollector`,轨迹 ID 存储在 `("collector", "traj_ids")` 中)。在这个简单的示例中,我们构建了一个包含 4 个连续短轨迹的数据,并从中采样了 4 个切片,每个切片长度为 2(因为 batch size 为 8,而 8 个项目 // 4 个切片 = 2 个时间步)。我们还标记了这些步骤。

from torchrl.data import SliceSampler

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(size),
    sampler=SliceSampler(traj_key="episode", num_slices=4),
    batch_size=8,
)
episode = torch.zeros(10, dtype=torch.int)
episode[:3] = 1
episode[3:5] = 2
episode[5:7] = 3
episode[7:] = 4
steps = torch.cat([torch.arange(3), torch.arange(2), torch.arange(2), torch.arange(3)])
data = TensorDict(
    {
        "episode": episode,
        "obs": torch.randn((3, 4, 5)).expand(10, 3, 4, 5),
        "act": torch.randn((20,)).expand(10, 20),
        "other": torch.randn((20, 50)).expand(10, 20, 50),
        "steps": steps,
    },
    [10],
)
rb.extend(data)
sample = rb.sample()
print("episode are grouped", sample["episode"])
print("steps are successive", sample["steps"])

gc.collect()

结论

我们已经了解了如何在 TorchRL 中使用回放缓冲区,从最简单的用法到更高级的用法,其中需要转换数据或以特定方式存储数据。您现在应该能够

  • 创建回放缓冲区,自定义其存储、采样器和转换。

  • 为您的应用程序选择最佳的存储类型(列表、内存或磁盘);

  • 最小化缓冲区的内存占用。

后续步骤

  • 查看数据 API 参考,了解 TorchRL 中的离线数据集,这些数据集基于我们的回放缓冲区 API。

  • 查看其他采样器,例如 `SamplerWithoutReplacement`、`PrioritizedSliceSampler` 和 `SliceSamplerWithoutReplacement`,或其他写入器,例如 `TensorDictMaxValueWriter`。

  • 文档中查看如何检查回放缓冲区的断点。

由 Sphinx-Gallery 生成的画廊

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源