DeviceCastTransform¶
- class torchrl.envs.transforms.DeviceCastTransform(device, orig_device=None, *, in_keys=None, out_keys=None, in_keys_inv=None, out_keys_inv=None)[源代码]¶
将数据从一个设备移动到另一个设备。
- 参数:
device (torch.device 或 等效项) – 目标设备(在环境或缓冲区外部)。
orig_device (torch.device 或 等效项) – 源设备(在环境或缓冲区内部)。如果未指定且存在父环境,则从父环境中检索。在所有其他情况下,它将保持未指定。
- 关键字参数:
in_keys (NestedKey 列表) – 要映射到不同设备的项目列表。默认为
None
。out_keys (NestedKey 列表) – 映射到设备的项目的输出名称。默认为
in_keys
的值。in_keys_inv (NestedKey 列表) – 要映射到不同设备的项目列表。
in_keys_inv
是基准环境期望的名称。默认为None
。out_keys_inv (NestedKey 列表) – 映射到设备的项目的输出名称。
out_keys_inv
是从转换后的环境外部看到的键的名称。默认为in_keys_inv
的值。
示例
>>> td = TensorDict( ... {'obs': torch.ones(1, dtype=torch.double), ... }, [], device="cpu:0") >>> transform = DeviceCastTransform(device=torch.device("cpu:2")) >>> td = transform(td) >>> print(td.device) cpu:2
- forward(tensordict: TensorDictBase = None) TensorDictBase [源代码]¶
读取输入 tensordict,并对选定的键应用转换。
默认情况下,此方法
直接调用
_apply_transform()
。不调用
_step()
或_call()
。
此方法不会在任何时候在 env.step 中调用。但是,它会在
sample()
中调用。注意
forward
方法也使用dispatch
正常处理关键字参数,将参数名称映射到键。示例
>>> class TransformThatMeasuresBytes(Transform): ... '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.''' ... def __init__(self): ... super().__init__(in_keys=[], out_keys=["bytes"]) ... ... def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ... bytes_in_td = tensordict.bytes() ... tensordict["bytes"] = bytes ... return tensordict >>> t = TransformThatMeasuresBytes() >>> env = env.append_transform(t) # works within envs >>> t(TensorDict(a=0)) # Works offline too.
- transform_action_spec(full_action_spec: Composite) Composite [源代码]¶
转换动作规范,使结果规范与变换映射匹配。
- 参数:
action_spec (TensorSpec) – 变换前的规范
- 返回:
转换后的预期规范
- transform_done_spec(full_done_spec: Composite) Composite [源代码]¶
变换 done spec,使结果 spec 与变换映射匹配。
- 参数:
done_spec (TensorSpec) – 变换前的 spec
- 返回:
转换后的预期规范
- transform_input_spec(input_spec: Composite) Composite [源代码]¶
转换输入规范,使结果规范与转换映射匹配。
- 参数:
input_spec (TensorSpec) – 转换前的规范
- 返回:
转换后的预期规范
- transform_observation_spec(observation_spec: Composite) Composite [源代码]¶
转换观察规范,使结果规范与转换映射匹配。
- 参数:
observation_spec (TensorSpec) – 转换前的规范
- 返回:
转换后的预期规范
- transform_output_spec(output_spec: Composite) Composite [源代码]¶
转换输出规范,使结果规范与转换映射匹配。
此方法通常应保持不变。更改应通过
transform_observation_spec()
、transform_reward_spec()
和transform_full_done_spec()
来实现。 :param output_spec: 转换前的 spec :type output_spec: TensorSpec- 返回:
转换后的预期规范
- transform_reward_spec(full_reward_spec: Composite) Composite [源代码]¶
转换奖励的 spec,使其与变换映射匹配。
- 参数:
reward_spec (TensorSpec) – 变换前的 spec
- 返回:
转换后的预期规范
- transform_state_spec(full_state_spec: Composite) Composite [源代码]¶
转换状态规范,使结果规范与变换映射匹配。
- 参数:
state_spec (TensorSpec) – 变换前的规范
- 返回:
转换后的预期规范