SliceSampler¶
- class torchrl.data.replay_buffers.SliceSampler(*, num_slices: int | None = None, slice_len: int | None = None, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, ends: torch.Tensor | None = None, trajectories: torch.Tensor | None = None, cache_values: bool = False, truncated_key: NestedKey | None = ('next', 'truncated'), strict_length: bool = True, compile: bool | dict = False, span: bool | int | tuple[bool | int, bool | int] = False, use_gpu: torch.device | bool = False)[源代码]¶
沿第一维度对数据切片进行采样,给定开始和停止信号。
此类有放回地采样子轨迹。无放回版本请参阅
SliceSamplerWithoutReplacement
。注意
SliceSampler 检索轨迹索引可能会很慢。为了加快其执行速度,请优先使用 end_key 而非 traj_key,并考虑以下关键字参数:
compile
、cache_values
和use_gpu
。- 关键字参数:
num_slices (int) – 要抽样的切片数量。批量大小必须大于或等于
num_slices
参数。与slice_len
互斥。slice_len (int) – 要抽样的切片的长度。批量大小必须大于或等于
slice_len
参数,并且可以被其整除。与num_slices
互斥。end_key (NestedKey, optional) – 指示轨迹(或回合)结束的键。默认为
("next", "done")
。traj_key (NestedKey, optional) – 指示轨迹的键。默认为
"episode"
(在 TorchRL 的数据集中常用)。ends (torch.Tensor, 可选) – 一个包含运行结束信号的一维布尔张量。当
end_key
或traj_key
获取成本高昂,或此信号易于获得时使用。必须与cache_values=True
一起使用,且不能与end_key
或traj_key
结合使用。如果提供,则假定存储已满,并且如果ends
张量的最后一个元素为False
,则相同的轨迹会跨越结束和开始。trajectories (torch.Tensor, 可选) – 一个包含运行 ID 的一维整数张量。当
end_key
或traj_key
获取成本高昂,或此信号易于获得时使用。必须与cache_values=True
一起使用,且不能与end_key
或traj_key
结合使用。如果提供,则假定存储已满,并且如果轨迹张量的最后一个元素与第一个元素相同,则相同的轨迹会跨越结束和开始。cache_values (bool, 可选) –
与静态数据集一起使用。将缓存轨迹的开始和结束信号。即使轨迹索引在调用
extend
时发生更改,也可以安全地使用此选项,因为此操作将清除缓存。警告
cache_values=True
在以下情况将无法正常工作:采样器与由另一个缓冲区扩展的存储一起使用。例如:>>> buffer0 = ReplayBuffer(storage=storage, ... sampler=SliceSampler(num_slices=8, cache_values=True), ... writer=ImmutableWriter()) >>> buffer1 = ReplayBuffer(storage=storage, ... sampler=other_sampler) >>> # Wrong! Does not erase the buffer from the sampler of buffer0 >>> buffer1.extend(data)
警告
cache_values=True
在以下情况将无法按预期工作:缓冲区在进程之间共享,一个进程负责写入,另一个进程负责采样,因为清除缓存只能在本地进行。truncated_key (NestedKey, optional) – 如果不为
None
,则此参数指示截断信号应写入输出数据的位置。这用于告知值估计器提供的轨迹在哪里中断。默认为("next", "truncated")
。此功能仅适用于TensorDictReplayBuffer
实例(否则,截断键将在sample()
方法返回的信息字典中)。strict_length (bool, optional) – 如果为
False
,则允许长度小于 slice_len(或 batch_size // num_slices)的轨迹出现在批次中。如果为True
,则将过滤掉长度不足的轨迹。请注意,这可能导致有效的 batch_size 短于请求的 batch_size!可以使用split_trajectories()
来拆分轨迹。默认为True
。compile (bool 或 dict of kwargs, optional) – 如果为
True
,则sample()
方法的瓶颈将使用compile()
进行编译。也可以通过此参数将关键字参数传递给 torch.compile。默认为False
。span (bool, int, Tuple[bool | int, bool | int], 可选) – 如果提供,则采样轨迹将跨越左侧和/或右侧。这意味着可能提供的元素少于所需的元素。布尔值表示每个轨迹至少会采样一个元素。整数 i 表示每个采样轨迹至少会收集 slice_len - i 个样本。使用元组可以精细控制跨越左侧(存储轨迹的开头)和右侧(存储轨迹的结尾)的跨度。
use_gpu (bool 或 torch.device) – 如果为
True
(或传递了设备),则将使用加速器来检索轨迹的起始索引。当缓冲区内容很大时,这可以显著加快采样速度。默认为False
。
注意
要恢复存储中的轨迹分割,
SliceSampler
将首先尝试在存储中查找traj_key
条目。如果找不到,将使用end_key
来重建剧集。注意
当使用 strict_length=False 时,建议使用
split_trajectories()
来分割采样轨迹。但是,如果来自同一剧集的两个样本并排放置,这可能会产生不正确的结果。为避免此问题,请考虑以下解决方案之一:使用带有切片采样器的
TensorDictReplayBuffer
实例>>> import torch >>> from tensordict import TensorDict >>> from torchrl.collectors.utils import split_trajectories >>> from torchrl.data import TensorDictReplayBuffer, ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement >>> >>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(max_size=1000), ... sampler=SliceSampler( ... slice_len=5, traj_key="episode",strict_length=False, ... )) ... >>> ep_1 = TensorDict( ... {"obs": torch.arange(100), ... "episode": torch.zeros(100),}, ... batch_size=[100] ... ) >>> ep_2 = TensorDict( ... {"obs": torch.arange(4), ... "episode": torch.ones(4),}, ... batch_size=[4] ... ) >>> rb.extend(ep_1) >>> rb.extend(ep_2) >>> >>> s = rb.sample(50) >>> print(s) TensorDict( fields={ episode: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.float32, is_shared=False), index: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.int64, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False), terminated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([46]), device=cpu, is_shared=False), obs: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([46]), device=cpu, is_shared=False) >>> t = split_trajectories(s, done_key="truncated") >>> print(t["obs"]) tensor([[73, 74, 75, 76, 77], [ 0, 1, 2, 3, 0], [ 0, 1, 2, 3, 0], [41, 42, 43, 44, 45], [ 0, 1, 2, 3, 0], [67, 68, 69, 70, 71], [27, 28, 29, 30, 31], [80, 81, 82, 83, 84], [17, 18, 19, 20, 21], [ 0, 1, 2, 3, 0]]) >>> print(t["episode"]) tensor([[0., 0., 0., 0., 0.], [1., 1., 1., 1., 0.], [1., 1., 1., 1., 0.], [0., 0., 0., 0., 0.], [1., 1., 1., 1., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [1., 1., 1., 1., 0.]])
使用
SliceSamplerWithoutReplacement
>>> import torch >>> from tensordict import TensorDict >>> from torchrl.collectors.utils import split_trajectories >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement >>> >>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000), ... sampler=SliceSamplerWithoutReplacement( ... slice_len=5, traj_key="episode",strict_length=False ... )) ... >>> ep_1 = TensorDict( ... {"obs": torch.arange(100), ... "episode": torch.zeros(100),}, ... batch_size=[100] ... ) >>> ep_2 = TensorDict( ... {"obs": torch.arange(4), ... "episode": torch.ones(4),}, ... batch_size=[4] ... ) >>> rb.extend(ep_1) >>> rb.extend(ep_2) >>> >>> s = rb.sample(50) >>> t = split_trajectories(s, trajectory_key="episode") >>> print(t["obs"]) tensor([[75, 76, 77, 78, 79], [ 0, 1, 2, 3, 0]]) >>> print(t["episode"]) tensor([[0., 0., 0., 0., 0.], [1., 1., 1., 1., 0.]])
示例
>>> import torch >>> from tensordict import TensorDict >>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer >>> from torchrl.data.replay_buffers.samplers import SliceSampler >>> torch.manual_seed(0) >>> rb = TensorDictReplayBuffer( ... storage=LazyMemmapStorage(1_000_000), ... sampler=SliceSampler(cache_values=True, num_slices=10), ... batch_size=320, ... ) >>> episode = torch.zeros(1000, dtype=torch.int) >>> episode[:300] = 1 >>> episode[300:550] = 2 >>> episode[550:700] = 3 >>> episode[700:] = 4 >>> data = TensorDict( ... { ... "episode": episode, ... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5), ... "act": torch.randn((20,)).expand(1000, 20), ... "other": torch.randn((20, 50)).expand(1000, 20, 50), ... }, [1000] ... ) >>> rb.extend(data) >>> sample = rb.sample() >>> print("sample:", sample) >>> print("episodes", sample.get("episode").unique()) episodes tensor([1, 2, 3, 4], dtype=torch.int32)
SliceSampler
与大多数 TorchRL 的数据集默认兼容示例
>>> import torch >>> >>> from torchrl.data.datasets import RobosetExperienceReplay >>> from torchrl.data import SliceSampler >>> >>> torch.manual_seed(0) >>> num_slices = 10 >>> dataid = list(RobosetExperienceReplay.available_datasets)[0] >>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices)) >>> for batch in data: ... batch = batch.reshape(num_slices, -1) ... break >>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1)) check that each batch only has one episode: tensor([[19], [14], [ 8], [10], [13], [ 4], [ 2], [ 3], [22], [ 8]])