快捷方式

distance_loss

class torchrl.objectives.distance_loss(v1: TensorLike, v2: TensorLike, loss_function: str, strict_shape: bool = True)[source]

计算两个张量之间的距离损失。

参数:
  • v1 (Tensor | TensorDict) – 一个张量或张量字典,其形状与 v2 兼容。

  • v2 (Tensor | TensorDict) – 一个张量或张量字典,其形状与 v1 兼容。

  • loss_function (str) – “l2”、“l1” 或 “smooth_l1” 中的一个,表示要使用的损失函数。

  • strict_shape (bool) – 如果为 False,则允许 v1 和 v2 具有不同的形状。默认为 True

返回:

一个张量或张量字典,形状为 v1.view_as(v2) 或 v2.view_as(v1)

其值等于两个张量之间的距离损失。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源