TimeMaxPool¶
- class torchrl.envs.transforms.TimeMaxPool(in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, T: int = 1, reset_key: NestedKey | None = None)[源码]¶
取最后 T 个观测值在每个位置上的最大值。
此转换会在最后一个 T 时间步内,对所有 in_keys 张量中的每个位置取最大值。
- 参数:
in_keys (NestedKey 序列, 可选) – 将应用 max pool 的输入键。如果为空,则默认为“observation”。
out_keys (NestedKey 序列, 可选) – 将写入输出的输出键。如果为空,则默认为 in_keys。
T (int, 可选) – 应用 max pooling 的时间步数。
reset_key (NestedKey, 可选) – 要用作部分重置指示器的重置键。必须是唯一的。如果未提供,则默认为父环境的唯一重置键(如果只有一个),否则引发异常。
示例
>>> from torchrl.envs import GymEnv >>> base_env = GymEnv("Pendulum-v1") >>> env = TransformedEnv(base_env, TimeMaxPool(in_keys=["observation"], T=10)) >>> torch.manual_seed(0) >>> env.set_seed(0) >>> rollout = env.rollout(10) >>> print(rollout["observation"]) # values should be increasing up until the 10th step tensor([[ 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0216, 0.0000], [ 0.0000, 0.1149, 0.0000], [ 0.0000, 0.1990, 0.0000], [ 0.0000, 0.2749, 0.0000], [ 0.0000, 0.3281, 0.0000], [-0.9290, 0.3702, -0.8978]])
注意
TimeMaxPool
目前仅支持根目录下的done
信号。嵌套的done
(如 MARL 设置中发现的)目前不受支持。如果需要此功能,请在 TorchRL 仓库中提出一个 issue。- forward(tensordict: TensorDictBase) TensorDictBase [源码]¶
读取输入 tensordict,并对选定的键应用转换。
默认情况下,此方法
直接调用
_apply_transform()
。不调用
_step()
或_call()
。
此方法不会在任何时候在 env.step 中调用。但是,它会在
sample()
中调用。注意
forward
也可以使用dispatch
将参数名称转换为键,并使用常规关键字参数。示例
>>> class TransformThatMeasuresBytes(Transform): ... '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.''' ... def __init__(self): ... super().__init__(in_keys=[], out_keys=["bytes"]) ... ... def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ... bytes_in_td = tensordict.bytes() ... tensordict["bytes"] = bytes ... return tensordict >>> t = TransformThatMeasuresBytes() >>> env = env.append_transform(t) # works within envs >>> t(TensorDict(a=0)) # Works offline too.
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec [源码]¶
转换观察规范,使结果规范与转换映射匹配。
- 参数:
observation_spec (TensorSpec) – 转换前的规范
- 返回:
转换后的预期规范