VecGymEnvTransform¶
- class torchrl.envs.transforms.VecGymEnvTransform(final_name: str = 'final', missing_obs_value: Any = nan)[source]¶
一个用于 GymWrapper 子类的转换,可以以一致的方式处理自动重置。
Gym、gymnasium 和 SB3 提供向量化(读取、并行或批处理)环境,它们会自动重置。发生这种情况时,由动作产生的实际观察值会保存在 info 中的一个键下。类
torchrl.envs.libs.gym.terminal_obs_reader
读取该观察值,并将其存储在 output tensordict 中的一个名为"final"
的键下。接着,此转换会读取该 final 数据,并将其与由于实际重置而写入的观察值进行交换,然后将重置输出保存在私有容器中。生成的数据真实反映了 step 的输出。此类适用于从 gym 0.13 到最新的 gymnasium 版本。
注意
Gym 版本 < 0.22 未返回最终观察值。对于这些版本,我们仅用 NaN 填充下一个观察值(因为它是丢失的),并在下一步进行交换。
然后,在调用 env.reset 时,保存的数据将被写回其应有的位置(并且 reset 无效)。
当在创建 wrapper 时使用了异步环境时,此转换会自动附加到 gym 环境。
- 参数:
final_name (str, optional) – dict 中 final 观察值的名称。默认为 “final”。
missing_obs_value (Any, optional) – 用于填充缺失的最后一个观察值的默认占位符。默认为 np.nan。
注意
总的来说,此类不应被直接处理。当向量化环境放置在
GymWrapper
中时,它会被创建。- forward(tensordict: TensorDictBase) TensorDictBase [source]¶
读取输入 tensordict,并对选定的键应用转换。
默认情况下,此方法
直接调用
_apply_transform()
。不调用
_step()
或_call()
。
此方法不在任何 env.step 中被调用。但是,它在
sample()
中被调用。注意
forward
还可以使用dispatch
通过将参数名称强制转换为键来处理常规关键字参数。示例
>>> class TransformThatMeasuresBytes(Transform): ... '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.''' ... def __init__(self): ... super().__init__(in_keys=[], out_keys=["bytes"]) ... ... def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ... bytes_in_td = tensordict.bytes() ... tensordict["bytes"] = bytes ... return tensordict >>> t = TransformThatMeasuresBytes() >>> env = env.append_transform(t) # works within envs >>> t(TensorDict(a=0)) # Works offline too.
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec [source]¶
转换观察规范,使结果规范与转换映射匹配。
- 参数:
observation_spec (TensorSpec) – 转换前的规范
- 返回:
转换后的预期规范