快捷方式

ReplayBufferTrainer

class torchrl.trainers.ReplayBufferTrainer(replay_buffer: TensorDictReplayBuffer, batch_size: int | None = None, memmap: bool = False, device: DEVICE_TYPING | None = None, flatten_tensordicts: bool = False, max_dims: Sequence[int] | None = None, iterate: bool = False)[源代码]

回放缓冲区钩子提供程序。

参数:
  • replay_buffer (TensorDictReplayBuffer) – 要使用的回放缓冲区。

  • batch_size (int, optional) – 从最新收集或从回放缓冲区采样数据时的批次大小。如果未提供,则将使用回放缓冲区的批次大小(对于未更改的批次大小,这是首选选项)。

  • memmap (bool, optional) – 如果为 True,则创建 memmap tensordict。默认为 False

  • device (device, optional) – 必须放置样本的设备。默认为 None

  • flatten_tensordicts (bool, optional) – 如果为 True,则 tensordicts 将被展平(或等效地使用从收集器获得的有效掩码进行掩码),然后传递给回放缓冲区。否则,除了填充外,不会进行其他转换(请参阅下面的 max_dims 参数)。默认为 False

  • max_dims (sequence of int, optional) – 如果 flatten_tensordicts 设置为 False,这将是一个列表,其长度为提供的 tensordicts 的 batch_size,表示每个 tensordict 的最大大小。如果提供,此大小列表将用于填充 tensordict 并使其形状匹配,然后将它们传递给回放缓冲区。如果没有最大值,应提供 -1 值。

  • iterate (bool, optional) – 如果为 True,则回放缓冲区将循环迭代。默认为 False(将调用 sample())。

示例

>>> rb_trainer = ReplayBufferTrainer(replay_buffer=replay_buffer, batch_size=N)
>>> trainer.register_op("batch_process", rb_trainer.extend)
>>> trainer.register_op("process_optim_batch", rb_trainer.sample)
>>> trainer.register_op("post_loss", rb_trainer.update_priority)
register(trainer: Trainer, name: str = 'replay_buffer')[源代码]

Registers the hook in the trainer at a default location.

参数:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

注意

To register the hook at another location than the default, use register_op().

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源