快捷方式

GAILLoss

class torchrl.objectives.GAILLoss(*args, **kwargs)[source]

TorchRL 对生成对抗性模仿学习 (GAIL) 损失的实现。

发表于 “Generative Adversarial Imitation Learning” <https://arxiv.org/pdf/1606.03476>

参数:

discriminator_network (TensorDictModule) – 随机策略

关键字参数:
  • use_grad_penalty (bool, optional) – 是否使用梯度惩罚。默认为 False

  • gp_lambda (float, optional) – 梯度惩罚 lambda。默认为 10

  • reduction (str, optional) – 指定应用于输出的约简:"none" | "mean" | "sum""none":不应用约简,"mean":输出的总和除以输出中的元素数量,"sum":输出将求和。默认为 "mean"

default_keys

别名:_AcceptedKeys

forward(tensordict: TensorDictBase = None) TensorDictBase[source]

forward 方法。

计算判别器损失和梯度惩罚(如果 use_grad_penalty 设置为 True)。如果 use_grad_penalty 设置为 True,还会返回分离的梯度惩罚损失用于日志记录。要查看输入 tensordict 中预期的键以及输出中预期的键,请查看类的 “in_keys”“out_keys” 属性。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源