VIPRewardTransform¶
- class torchrl.envs.transforms.VIPRewardTransform(*args, **kwargs)[源代码]¶
一个 VIP 变换,用于根据嵌入相似度计算奖励。
此类将更新奖励计算
- forward(tensordict)[源代码]¶
读取输入 tensordict,并对选定的键应用转换。
默认情况下,此方法
直接调用
_apply_transform()
。不调用
_step()
或_call()
。
此方法不在任何时候在 env.step 内调用。但是,它在
sample()
内调用。注意
forward
也使用dispatch
与常规关键字参数一起使用,以将参数名称强制转换为键。示例
>>> class TransformThatMeasuresBytes(Transform): ... '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.''' ... def __init__(self): ... super().__init__(in_keys=[], out_keys=["bytes"]) ... ... def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ... bytes_in_td = tensordict.bytes() ... tensordict["bytes"] = bytes ... return tensordict >>> t = TransformThatMeasuresBytes() >>> env = env.append_transform(t) # works within envs >>> t(TensorDict(a=0)) # Works offline too.
- transform_input_spec(input_spec: TensorSpec) TensorSpec [源代码]¶
转换输入规范,使结果规范与转换映射匹配。
- 参数:
input_spec (TensorSpec) – 转换前的规范
- 返回:
转换后的预期规范