SliceSamplerWithoutReplacement¶
- class torchrl.data.replay_buffers.SliceSamplerWithoutReplacement(*, num_slices: int | None = None, slice_len: int | None = None, drop_last: bool = False, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, ends: torch.Tensor | None = None, trajectories: torch.Tensor | None = None, truncated_key: NestedKey | None = ('next', 'truncated'), strict_length: bool = True, shuffle: bool = True, compile: bool | dict = False, use_gpu: bool | torch.device = False)[源代码]¶
在不重复抽样的情况下,给定开始和停止信号,沿着第一个维度对数据切片进行抽样。
在此上下文中,
不重复抽样
意味着在计数器自动重置之前,同一个元素(不是轨迹)不会被抽取两次。然而,在单个样本中,只会出现一个给定轨迹的切片(请参阅下面的示例)。此类应与静态回放缓冲区一起使用,或在两个回放缓冲区扩展之间使用。扩展回放缓冲区将重置采样器,目前不允许连续不重复抽样。
注意
SliceSamplerWithoutReplacement 检索轨迹索引可能会很慢。为了加速其执行,请优先使用 end_key 而不是 traj_key,并考虑以下关键字参数:
compile
、cache_values
和use_gpu
。- 关键字参数:
drop_last (bool, optional) – 如果为
True
,则将删除最后一个不完整的样本(如果有)。如果为False
,则将保留最后一个样本。默认为False
。num_slices (int) – 要抽样的切片数量。批量大小必须大于或等于
num_slices
参数。与slice_len
互斥。slice_len (int) – 要抽样的切片的长度。批量大小必须大于或等于
slice_len
参数,并且可以被其整除。与num_slices
互斥。end_key (NestedKey, optional) – 指示轨迹(或回合)结束的键。默认为
("next", "done")
。traj_key (NestedKey, optional) – 指示轨迹的键。默认为
"episode"
(在 TorchRL 的数据集中常用)。ends (torch.Tensor, optional) – 一个包含运行结束信号的一维布尔张量。当
end_key
或traj_key
获取成本高昂,或该信号易于获得时使用。必须与cache_values=True
一起使用,并且不能与end_key
或traj_key
结合使用。trajectories (torch.Tensor, optional) – 一个包含运行 ID 的一维整数张量。当
end_key
或traj_key
获取成本高昂,或该信号易于获得时使用。必须与cache_values=True
一起使用,并且不能与end_key
或traj_key
结合使用。truncated_key (NestedKey, optional) – 如果不为
None
,则此参数指示截断信号应写入输出数据的位置。这用于告知值估计器提供的轨迹在哪里中断。默认为("next", "truncated")
。此功能仅适用于TensorDictReplayBuffer
实例(否则,截断键将在sample()
方法返回的信息字典中)。strict_length (bool, optional) – 如果为
False
,则允许长度小于 slice_len(或 batch_size // num_slices)的轨迹出现在批次中。如果为True
,则将过滤掉长度不足的轨迹。请注意,这可能导致有效的 batch_size 短于请求的 batch_size!可以使用split_trajectories()
来拆分轨迹。默认为True
。shuffle (bool, optional) – 如果为
False
,则不打乱轨迹的顺序。默认为True
。compile (bool 或 dict of kwargs, optional) – 如果为
True
,则sample()
方法的瓶颈将使用compile()
进行编译。也可以通过此参数将关键字参数传递给 torch.compile。默认为False
。use_gpu (bool 或 torch.device) – 如果为
True
(或传递了设备),则将使用加速器来检索轨迹起始点的索引。当缓冲区内容很大时,这可以显著加速采样。默认为False
。
注意
要恢复存储中的轨迹分割,
SliceSamplerWithoutReplacement
将首先尝试在存储中找到traj_key
条目。如果找不到,将使用end_key
来重建回合。示例
>>> import torch >>> from tensordict import TensorDict >>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer >>> from torchrl.data.replay_buffers.samplers import SliceSamplerWithoutReplacement >>> >>> rb = TensorDictReplayBuffer( ... storage=LazyMemmapStorage(1000), ... # asking for 10 slices for a total of 320 elements, ie, 10 trajectories of 32 transitions each ... sampler=SliceSamplerWithoutReplacement(num_slices=10), ... batch_size=320, ... ) >>> episode = torch.zeros(1000, dtype=torch.int) >>> episode[:300] = 1 >>> episode[300:550] = 2 >>> episode[550:700] = 3 >>> episode[700:] = 4 >>> data = TensorDict( ... { ... "episode": episode, ... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5), ... "act": torch.randn((20,)).expand(1000, 20), ... "other": torch.randn((20, 50)).expand(1000, 20, 50), ... }, [1000] ... ) >>> rb.extend(data) >>> sample = rb.sample() >>> # since we want trajectories of 32 transitions but there are only 4 episodes to >>> # sample from, we only get 4 x 32 = 128 transitions in this batch >>> print("sample:", sample) >>> print("trajectories in sample", sample.get("episode").unique())
SliceSamplerWithoutReplacement
与大多数 TorchRL 的数据集默认兼容,并允许用户以类似数据加载器的方式使用数据集。示例
>>> import torch >>> >>> from torchrl.data.datasets import RobosetExperienceReplay >>> from torchrl.data import SliceSamplerWithoutReplacement >>> >>> torch.manual_seed(0) >>> num_slices = 10 >>> dataid = list(RobosetExperienceReplay.available_datasets)[0] >>> data = RobosetExperienceReplay(dataid, batch_size=320, ... sampler=SliceSamplerWithoutReplacement(num_slices=num_slices)) >>> # the last sample is kept, since drop_last=False by default >>> for i, batch in enumerate(data): ... print(batch.get("episode").unique()) tensor([ 5, 6, 8, 11, 12, 14, 16, 17, 19, 24]) tensor([ 1, 2, 7, 9, 10, 13, 15, 18, 21, 22]) tensor([ 0, 3, 4, 20, 23])
当请求大量总样本,但轨迹数量很少且跨度很小时,批次最多只会包含每个轨迹的一个样本。
示例
>>> import torch >>> from tensordict import TensorDict >>> from torchrl.collectors.utils import split_trajectories >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement >>> >>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000), ... sampler=SliceSamplerWithoutReplacement( ... slice_len=5, traj_key="episode",strict_length=False ... )) ... >>> ep_1 = TensorDict( ... {"obs": torch.arange(100), ... "episode": torch.zeros(100),}, ... batch_size=[100] ... ) >>> ep_2 = TensorDict( ... {"obs": torch.arange(51), ... "episode": torch.ones(51),}, ... batch_size=[51] ... ) >>> rb.extend(ep_1) >>> rb.extend(ep_2) >>> >>> s = rb.sample(50) >>> t = split_trajectories(s, trajectory_key="episode") >>> print(t["obs"]) tensor([[14, 15, 16, 17, 18], [ 3, 4, 5, 6, 7]]) >>> print(t["episode"]) tensor([[0., 0., 0., 0., 0.], [1., 1., 1., 1., 1.]]) >>> >>> s = rb.sample(50) >>> t = split_trajectories(s, trajectory_key="episode") >>> print(t["obs"]) tensor([[ 4, 5, 6, 7, 8], [26, 27, 28, 29, 30]]) >>> print(t["episode"]) tensor([[0., 0., 0., 0., 0.], [1., 1., 1., 1., 1.]])