快捷方式

SafeProbabilisticTensorDictSequential

class torchrl.modules.tensordict_module.SafeProbabilisticTensorDictSequential(*args, **kwargs)[源代码]

tensordict.nn.ProbabilisticTensorDictSequential 的子类,它接受 TensorSpec 作为参数来控制输出域。

TensorDictSequential 类似,但强制要求序列中的最后一个模块是 ProbabilisticTensorDictModule,并且还公开了 get_dist 方法来从 ProbabilisticTensorDictModule 恢复分布对象。

参数:
  • modules (TensorDictModules 的可迭代对象) – 按顺序排列的 TensorDictModule 实例序列,以 ProbabilisticTensorDictModule 结尾,将按顺序运行。

  • partial_tolerant (bool, optional) – 如果为 True,则输入的 tensordict 可能缺少某些输入键。如果是这样,则只会执行可以根据存在的键执行的模块。此外,如果输入的 tensordict 是 tensordicts 的惰性堆栈,并且 partial_tolerant 为 True,并且堆栈不包含所需的键,那么 TensorDictSequential 将扫描子 tensordicts 查找具有所需键的 tensordicts(如果有)。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源