快捷方式

ConditionalSkip

class torchrl.envs.transforms.ConditionalSkip(cond: Callable[[TensorDict], bool | torch.Tensor])[source]

一个在满足特定条件时跳过 env 步骤的转换。

此转换将 cond(tensordict) 的结果写入传递给 TransformedEnv.base_env._step 方法的 tensordict 的 “_step” 条目中。如果 base_env 不是批处理锁定的(一般而言,它是无状态的),则 tensordict 将被缩减到需要通过步骤的元素。如果它是批处理锁定的(一般而言,它是 stateful 的),则如果 “_step” 中的任何值不为 True,则会完全跳过该步骤。否则,将信任环境会相应地处理 “_step” 信号。

注意

此跳过也会影响修改环境输出的转换,即,如果满足条件,任何将在 step() 返回的 tensordict 上执行的转换都将被跳过。为了缓解这种不良影响,可以将转换后的 env 包装在另一个转换后的 env 中,因为跳过只会影响 ConditionalSkip 转换的头等父项。请参见下面的示例。

参数:

cond (Callable[[TensorDictBase], bool | torch.Tensor]) – 一个用于 tensordict 输入的可调用对象,用于检查是否必须跳过下一个 env 步骤(True = 跳过,False = 执行 env.step)。

示例

>>> import torch
>>>
>>> from torchrl.envs import GymEnv
>>> from torchrl.envs.transforms.transforms import ConditionalSkip, StepCounter, TransformedEnv, Compose
>>>
>>> torch.manual_seed(0)
>>>
>>> base_env = TransformedEnv(
...     GymEnv("Pendulum-v1"),
...     StepCounter(step_count_key="inner_count"),
... )
>>> middle_env = TransformedEnv(
...     base_env,
...     Compose(
...         StepCounter(step_count_key="middle_count"),
...         ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1),
...     ),
...     auto_unwrap=False)  # makes sure that transformed envs are properly wrapped
>>> env = TransformedEnv(
...     middle_env,
...     StepCounter(step_count_key="step_count"),
...     auto_unwrap=False)
>>> env.set_seed(0)
>>>
>>> r = env.rollout(10)
>>> print(r["observation"])
tensor([[-0.9670, -0.2546, -0.9669],
        [-0.9802, -0.1981, -1.1601],
        [-0.9802, -0.1981, -1.1601],
        [-0.9926, -0.1214, -1.5556],
        [-0.9926, -0.1214, -1.5556],
        [-0.9994, -0.0335, -1.7622],
        [-0.9994, -0.0335, -1.7622],
        [-0.9984,  0.0561, -1.7933],
        [-0.9984,  0.0561, -1.7933],
        [-0.9895,  0.1445, -1.7779]])
>>> print(r["inner_count"])
tensor([[0],
        [1],
        [1],
        [2],
        [2],
        [3],
        [3],
        [4],
        [4],
        [5]])
>>> print(r["middle_count"])
tensor([[0],
        [1],
        [1],
        [2],
        [2],
        [3],
        [3],
        [4],
        [4],
        [5]])
>>> print(r["step_count"])
tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7],
        [8],
        [9]])
forward(tensordict: TensorDictBase) TensorDictBase[source]

读取输入 tensordict,并对选定的键应用转换。

默认情况下,此方法

  • 直接调用 _apply_transform()

  • 不调用 _step()_call()

此方法在任何时候都不会在 env.step 中调用。但是,它会在 sample() 中调用。

注意

forward 还可以使用 dispatch 通过将参数名称强制转换为键来处理常规关键字参数。

示例

>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0))  # Works offline too.

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源