ConditionalSkip¶
- class torchrl.envs.transforms.ConditionalSkip(cond: Callable[[TensorDict], bool | torch.Tensor])[source]¶
一个在满足特定条件时跳过 env 步骤的转换。
此转换将 cond(tensordict) 的结果写入传递给 TransformedEnv.base_env._step 方法的 tensordict 的 “_step” 条目中。如果 base_env 不是批处理锁定的(一般而言,它是无状态的),则 tensordict 将被缩减到需要通过步骤的元素。如果它是批处理锁定的(一般而言,它是 stateful 的),则如果 “_step” 中的任何值不为
True
,则会完全跳过该步骤。否则,将信任环境会相应地处理 “_step” 信号。注意
此跳过也会影响修改环境输出的转换,即,如果满足条件,任何将在
step()
返回的 tensordict 上执行的转换都将被跳过。为了缓解这种不良影响,可以将转换后的 env 包装在另一个转换后的 env 中,因为跳过只会影响ConditionalSkip
转换的头等父项。请参见下面的示例。- 参数:
cond (Callable[[TensorDictBase], bool | torch.Tensor]) – 一个用于 tensordict 输入的可调用对象,用于检查是否必须跳过下一个 env 步骤(True = 跳过,False = 执行 env.step)。
示例
>>> import torch >>> >>> from torchrl.envs import GymEnv >>> from torchrl.envs.transforms.transforms import ConditionalSkip, StepCounter, TransformedEnv, Compose >>> >>> torch.manual_seed(0) >>> >>> base_env = TransformedEnv( ... GymEnv("Pendulum-v1"), ... StepCounter(step_count_key="inner_count"), ... ) >>> middle_env = TransformedEnv( ... base_env, ... Compose( ... StepCounter(step_count_key="middle_count"), ... ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1), ... ), ... auto_unwrap=False) # makes sure that transformed envs are properly wrapped >>> env = TransformedEnv( ... middle_env, ... StepCounter(step_count_key="step_count"), ... auto_unwrap=False) >>> env.set_seed(0) >>> >>> r = env.rollout(10) >>> print(r["observation"]) tensor([[-0.9670, -0.2546, -0.9669], [-0.9802, -0.1981, -1.1601], [-0.9802, -0.1981, -1.1601], [-0.9926, -0.1214, -1.5556], [-0.9926, -0.1214, -1.5556], [-0.9994, -0.0335, -1.7622], [-0.9994, -0.0335, -1.7622], [-0.9984, 0.0561, -1.7933], [-0.9984, 0.0561, -1.7933], [-0.9895, 0.1445, -1.7779]]) >>> print(r["inner_count"]) tensor([[0], [1], [1], [2], [2], [3], [3], [4], [4], [5]]) >>> print(r["middle_count"]) tensor([[0], [1], [1], [2], [2], [3], [3], [4], [4], [5]]) >>> print(r["step_count"]) tensor([[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]])
- forward(tensordict: TensorDictBase) TensorDictBase [source]¶
读取输入 tensordict,并对选定的键应用转换。
默认情况下,此方法
直接调用
_apply_transform()
。不调用
_step()
或_call()
。
此方法在任何时候都不会在 env.step 中调用。但是,它会在
sample()
中调用。注意
forward
还可以使用dispatch
通过将参数名称强制转换为键来处理常规关键字参数。示例
>>> class TransformThatMeasuresBytes(Transform): ... '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.''' ... def __init__(self): ... super().__init__(in_keys=[], out_keys=["bytes"]) ... ... def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ... bytes_in_td = tensordict.bytes() ... tensordict["bytes"] = bytes ... return tensordict >>> t = TransformThatMeasuresBytes() >>> env = env.append_transform(t) # works within envs >>> t(TensorDict(a=0)) # Works offline too.