快捷方式

ConditionalPolicySwitch

class torchrl.envs.transforms.ConditionalPolicySwitch(policy: Callable[[TensorDictBase], TensorDictBase], condition: Callable[[TensorDictBase], bool])[source]

一个根据指定条件有条件地在策略之间切换的转换。

此转换会评估环境中 step 方法返回的数据上的一个条件。如果满足条件,它会将指定的策略应用于数据。否则,数据将按原样返回。这对于需要根据特定标准应用不同策略的场景非常有用,例如在游戏中交替回合。

参数:
  • policy (Callable[[TensorDictBase], TensorDictBase]) – 满足条件时要应用的策略。它应该是一个可调用对象,接受一个 TensorDictBase 并返回一个 TensorDictBase

  • condition (Callable[[TensorDictBase], bool]) – 一个可调用对象,它接受一个 TensorDictBase 并返回一个布尔值或张量,指示是否应应用该策略。

警告

此转换必须有一个父环境。

注意

理想情况下,它应该是堆栈中的最后一个转换。如果策略需要转换后的数据(例如,图像),并且此转换应用于这些转换之前,那么策略将不会收到转换后的数据。

示例

>>> import torch
>>> from tensordict.nn import TensorDictModule as Mod
>>>
>>> from torchrl.envs import GymEnv, ConditionalPolicySwitch, Compose, StepCounter
>>> # Create a CartPole environment. We'll be looking at the obs: if the first element of the obs is greater than
>>> # 0 (left position) we do a right action (action=0) using the switch policy. Otherwise, we use our main
>>> # policy which does a left action.
>>> base_env = GymEnv("CartPole-v1", categorical_action_encoding=True)
>>>
>>> policy = Mod(lambda: torch.ones((), dtype=torch.int64), in_keys=[], out_keys=["action"])
>>> policy_switch = Mod(lambda: torch.zeros((), dtype=torch.int64), in_keys=[], out_keys=["action"])
>>>
>>> cond = lambda td: td.get("observation")[..., 0] >= 0
>>>
>>> env = base_env.append_transform(
...     Compose(
...         # We use two step counters to show that one counts the global steps, whereas the other
...         # only counts the steps where the main policy is executed
...         StepCounter(step_count_key="step_count_total"),
...         ConditionalPolicySwitch(condition=cond, policy=policy_switch),
...         StepCounter(step_count_key="step_count_main"),
...     )
... )
>>>
>>> env.set_seed(0)
>>> torch.manual_seed(0)
>>>
>>> r = env.rollout(100, policy=policy)
>>> print("action", r["action"])
action tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
>>> print("obs", r["observation"])
obs tensor([[ 0.0322, -0.1540,  0.0111,  0.3190],
        [ 0.0299, -0.1544,  0.0181,  0.3280],
        [ 0.0276, -0.1550,  0.0255,  0.3414],
        [ 0.0253, -0.1558,  0.0334,  0.3596],
        [ 0.0230, -0.1569,  0.0422,  0.3828],
        [ 0.0206, -0.1582,  0.0519,  0.4117],
        [ 0.0181, -0.1598,  0.0629,  0.4469],
        [ 0.0156, -0.1617,  0.0753,  0.4891],
        [ 0.0130, -0.1639,  0.0895,  0.5394],
        [ 0.0104, -0.1665,  0.1058,  0.5987],
        [ 0.0076, -0.1696,  0.1246,  0.6685],
        [ 0.0047, -0.1732,  0.1463,  0.7504],
        [ 0.0016, -0.1774,  0.1715,  0.8459],
        [-0.0020,  0.0150,  0.1884,  0.6117],
        [-0.0017,  0.2071,  0.2006,  0.3838]])
>>> print("obs'", r["next", "observation"])
obs' tensor([[ 0.0299, -0.1544,  0.0181,  0.3280],
        [ 0.0276, -0.1550,  0.0255,  0.3414],
        [ 0.0253, -0.1558,  0.0334,  0.3596],
        [ 0.0230, -0.1569,  0.0422,  0.3828],
        [ 0.0206, -0.1582,  0.0519,  0.4117],
        [ 0.0181, -0.1598,  0.0629,  0.4469],
        [ 0.0156, -0.1617,  0.0753,  0.4891],
        [ 0.0130, -0.1639,  0.0895,  0.5394],
        [ 0.0104, -0.1665,  0.1058,  0.5987],
        [ 0.0076, -0.1696,  0.1246,  0.6685],
        [ 0.0047, -0.1732,  0.1463,  0.7504],
        [ 0.0016, -0.1774,  0.1715,  0.8459],
        [-0.0020,  0.0150,  0.1884,  0.6117],
        [-0.0017,  0.2071,  0.2006,  0.3838],
        [ 0.0105,  0.2015,  0.2115,  0.5110]])
>>> print("total step count", r["step_count_total"].squeeze())
total step count tensor([ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 21, 23, 25, 26, 27])
>>> print("total step with main policy", r["step_count_main"].squeeze())
total step with main policy tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])
forward(tensordict: TensorDictBase) Any[source]

读取输入 tensordict,并对选定的键应用转换。

默认情况下,此方法

  • 直接调用 _apply_transform()

  • 不调用 _step()_call()

此方法不会在任何时候在 env.step 中调用。但是,它会在 sample() 中被调用。

注意

forward 还可以使用 dispatch 结合常规关键字参数来将参数名称转换为键。

示例

>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0))  # Works offline too.

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源