CatFrames¶
- class torchrl.envs.transforms.CatFrames(N: int, dim: int, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, padding='same', padding_value=0, as_inverse=False, reset_key: NestedKey | None = None, done_key: NestedKey | None = None)[source]¶
将连续的观测帧连接成单个张量。
此转换对于在观测特征中创建运动或速度感很有用。它也可以与需要访问过去观测值的模型一起使用,例如 Transformer 等。它最初在“Playing Atari with Deep Reinforcement Learning”中提出(https://arxiv.org/pdf/1312.5602.pdf)。
当在转换后的环境中使用时,
CatFrames
是一个有状态的类,可以通过调用reset()
方法将其重置为其初始状态。此方法接受包含 `"_reset"` 条目的 tensordict,该条目指示要重置的缓冲区。- 参数:
N (int) – 要连接的观测次数。
dim (int) – 用于连接观测的维度。应为负数,以确保其与不同 batch_size 的环境兼容。
in_keys (Sequence of NestedKey, optional) – 指向需要连接的帧的键。默认为 [“pixels”]。
out_keys (Sequence of NestedKey, optional) – 指向输出写入位置的键。默认为 in_keys 的值。
padding (str, optional) – 填充方法。可以是
"same"
或"constant"
。默认为"same"
,即第一个值用于填充。padding_value (
float
, optional) – 如果padding="constant"
,则用于填充的值。默认为 0。as_inverse (bool, optional) – 如果为
True
,则转换将作为逆转换应用。默认为False
。reset_key (NestedKey, optional) – 将用作部分重置指示器的重置键。必须是唯一的。如果未提供,则默认为父环境的唯一重置键(如果它只有一个),否则将引发异常。
done_key (NestedKey, optional) – 将用作部分完成指示器的完成键。必须是唯一的。如果未提供,则默认为
"done"
。
示例
>>> from torchrl.envs.libs.gym import GymEnv >>> env = TransformedEnv(GymEnv('Pendulum-v1'), ... Compose( ... UnsqueezeTransform(-1, in_keys=["observation"]), ... CatFrames(N=4, dim=-1, in_keys=["observation"]), ... ) ... ) >>> print(env.rollout(3))
CatFrames 转换也可以离线使用,以在不同规模下重现在线帧连接的效果(或为了限制内存消耗)。以下示例提供了完整的图景,以及
torchrl.data.ReplayBuffer
的用法示例
>>> from torchrl.envs.utils import RandomPolicy >>> from torchrl.envs import UnsqueezeTransform, CatFrames >>> from torchrl.collectors import SyncDataCollector >>> # Create a transformed environment with CatFrames: notice the usage of UnsqueezeTransform to create an extra dimension >>> env = TransformedEnv( ... GymEnv("CartPole-v1", from_pixels=True), ... Compose( ... ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]), ... Resize(in_keys=["pixels_trsf"], w=64, h=64), ... GrayScale(in_keys=["pixels_trsf"]), ... UnsqueezeTransform(-4, in_keys=["pixels_trsf"]), ... CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]), ... ) ... ) >>> # we design a collector >>> collector = SyncDataCollector( ... env, ... RandomPolicy(env.action_spec), ... frames_per_batch=10, ... total_frames=1000, ... ) >>> for data in collector: ... print(data) ... break >>> # now let's create a transform for the replay buffer. We don't need to unsqueeze the data here. >>> # however, we need to point to both the pixel entry at the root and at the next levels: >>> t = Compose( ... ToTensorImage(in_keys=["pixels", ("next", "pixels")], out_keys=["pixels_trsf", ("next", "pixels_trsf")]), ... Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64), ... GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]), ... CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]), ... ) >>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage >>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16) >>> data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf")) >>> rb.add(data_exclude) >>> s = rb.sample(1) # the buffer has only one element >>> # let's check that our sample is the same as the batch collected during inference >>> assert (data.exclude("collector")==s.squeeze(0).exclude("index", "collector")).all()
注意
CatFrames
目前仅支持根目录下的"done"
信号。嵌套的done
,例如在 MARL 设置中发现的,目前不受支持。如果需要此功能,请在 TorchRL 存储库上提出问题。注意
在回放缓冲区中存储帧堆栈会显著增加内存消耗(增加 N 倍)。为了缓解这种情况,您可以直接将轨迹存储在回放缓冲区中,并在采样时应用
CatFrames
。此方法涉及对存储的轨迹进行切片采样,然后应用帧堆叠转换。为了方便起见,CatFrames
提供了一个make_rb_transform_and_sampler()
方法,该方法创建一个修改后的转换版本,适用于回放缓冲区
一个相应的
SliceSampler
用于回放缓冲区
- forward(tensordict: TensorDictBase) TensorDictBase [source]¶
读取输入 tensordict,并对选定的键应用转换。
默认情况下,此方法
直接调用
_apply_transform()
。不调用
_step()
或_call()
。
此方法不会在任何时候在 `env.step` 中调用。但是,它会在 `sample()` 中调用。
注意
forward
也通过使用dispatch
将参数名称转换为键来处理常规关键字参数。示例
>>> class TransformThatMeasuresBytes(Transform): ... '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.''' ... def __init__(self): ... super().__init__(in_keys=[], out_keys=["bytes"]) ... ... def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ... bytes_in_td = tensordict.bytes() ... tensordict["bytes"] = bytes ... return tensordict >>> t = TransformThatMeasuresBytes() >>> env = env.append_transform(t) # works within envs >>> t(TensorDict(a=0)) # Works offline too.
- make_rb_transform_and_sampler(batch_size: int, **sampler_kwargs) tuple[Transform, torchrl.data.replay_buffers.SliceSampler] [source]¶
创建用于回放缓冲区在存储帧堆叠数据时使用的转换和采样器。
此方法通过避免将整个帧堆栈存储在缓冲区中来帮助减少存储数据中的冗余。相反,它创建一个转换,在采样期间动态堆叠帧,并创建一个采样器来确保维护正确的序列长度。
- 参数:
batch_size (int) – 采样器使用的批处理大小。
**sampler_kwargs – 传递给
SliceSampler
构造函数的其他关键字参数。
- 返回:
transform (Transform): 一个在采样时动态堆叠帧的转换。
sampler (SliceSampler): 一个确保维护正确序列长度的采样器。
- 返回类型:
一个包含
示例
>>> env = TransformedEnv(...) >>> catframes = CatFrames(N=4, ...) >>> transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32) >>> rb = ReplayBuffer(..., sampler=sampler, transform=transform)
注意
使用图像时,建议在前一个
ToTensorImage
转换中使用不同的in_keys
和out_keys
。这确保了存储在缓冲区中的张量与其处理后的对应物是分开的,而我们不希望存储后者。对于非图像数据,请考虑在CatFrames
之前插入一个RenameTransform
来创建一个将在缓冲区中存储的数据副本。注意
将转换添加到回放缓冲区时,应注意也要传递
CatFrames
之前的转换,例如ToTensorImage
或UnsqueezeTransform
,以便CatFrames
看到的格式数据与数据收集期间的格式相同。注意
有关更完整的示例,请参阅 torchrl 的 github 存储库 examples 文件夹:https://github.com/pytorch/rl/tree/main/examples/replay-buffers/catframes-in-buffer.py
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec [source]¶
转换观察规范,使结果规范与转换映射匹配。
- 参数:
observation_spec (TensorSpec) – 转换前的规范
- 返回:
转换后的预期规范