SafeProbabilisticTensorDictSequential¶
- class torchrl.modules.tensordict_module.SafeProbabilisticTensorDictSequential(*args, **kwargs)[来源]¶
tensordict.nn.ProbabilisticTensorDictSequential
的子类,它接受TensorSpec
作为参数来控制输出域。与
TensorDictSequential
类似,但强制要求序列中的最后一个模块是ProbabilisticTensorDictModule
,并且还公开了get_dist
方法,以从ProbabilisticTensorDictModule
中恢复分布对象。- 参数:
modules (TensorDictModule 的可迭代对象) – 按顺序运行的 TensorDictModule 实例的有序序列,最后一个必须是 ProbabilisticTensorDictModule。
partial_tolerant (bool, optional) – 如果为
True
,则输入的 tensordict 可以缺少某些键。如果是这样,将执行那些可以根据存在的键执行的模块。此外,如果输入的 tensordict 是 tensordict 的懒惰堆叠,并且 partial_tolerant 为True
,并且堆叠缺少必需的键,则 TensorDictSequential 将扫描子 tensordict,查找具有必需键的那些(如果存在)。