distance_loss¶
- class torchrl.objectives.distance_loss(v1: TensorLike, v2: TensorLike, loss_function: str, strict_shape: bool = True)[source]¶
计算两个张量之间的距离损失。
- 参数:
v1 (Tensor | TensorDict) – 一个张量或张量字典,其形状与 v2 兼容。
v2 (Tensor | TensorDict) – 一个张量或张量字典,其形状与 v1 兼容。
loss_function (str) – “l2”、“l1” 或 “smooth_l1” 中的一个,表示要使用的损失函数。
strict_shape (bool) – 如果为 False,则允许 v1 和 v2 具有不同的形状。默认为
True
。
- 返回:
一个张量或张量字典,形状为 v1.view_as(v2) 或 v2.view_as(v1)
其值等于两个张量之间的距离损失。