ConditionalPolicySwitch¶
- class torchrl.envs.transforms.ConditionalPolicySwitch(policy: Callable[[TensorDictBase], TensorDictBase], condition: Callable[[TensorDictBase], bool])[source]¶
A transform that conditionally switches between policies based on a specified condition.
This transform evaluates a condition on the data returned by the environment's step method. If the condition is met, it applies a specified policy to that data. Otherwise, the data is returned unaltered. This is useful for scenarios where different policies need to be applied based on certain criteria, such as alternating turns in a game.
- Parameters:
policy (Callable[[TensorDictBase], TensorDictBase]) – The policy to apply when the condition is met. It should be a callable that takes a TensorDictBase and returns a TensorDictBase.
condition (Callable[[TensorDictBase], bool]) – A callable that takes a TensorDictBase and returns a boolean or a tensor indicating whether the policy should be applied.
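For the turn-alternation use case mentioned above, the condition can read an entry written by the environment and the override policy can be any compatible callable. A minimal sketch (the "turn" entry and the opponent policy are illustrative assumptions, not part of the API):
>>> import torch
>>> from tensordict.nn import TensorDictModule as Mod
>>> from torchrl.envs import ConditionalPolicySwitch
>>> # Override policy: plays a random discrete action whenever the condition is met.
>>> opponent = Mod(lambda: torch.randint(2, ()), in_keys=[], out_keys=["action"])
>>> # Condition: a hypothetical "turn" entry written by the env marks the opponent's turn with 1.
>>> cond = lambda td: td.get("turn") == 1
>>> switch = ConditionalPolicySwitch(condition=cond, policy=opponent)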
Warning
This transform must have a parent environment.
Note
Ideally, this should be the last transform in the stack. If the policy requires transformed data (e.g., images) and this transform is applied before those transforms, the policy will not receive the transformed data.
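A minimal sketch of this ordering with a pixel-based environment (the condition, the pixel policy and the transform choices below are illustrative assumptions): because ConditionalPolicySwitch comes last in the Compose, the switched-in policy receives the resized float images rather than the raw frames.
>>> import torch
>>> from tensordict.nn import TensorDictModule as Mod
>>> from torchrl.envs import GymEnv, Compose, ToTensorImage, Resize, ConditionalPolicySwitch
>>> base_env = GymEnv("CartPole-v1", from_pixels=True, categorical_action_encoding=True)
>>> # The switch policy reads the transformed "pixels" entry (here it just emits a fixed action).
>>> pixel_policy = Mod(lambda pixels: torch.zeros((), dtype=torch.int64), in_keys=["pixels"], out_keys=["action"])
>>> cond = lambda td: td.get("pixels").mean() > 0.5  # illustrative condition on the transformed pixels
>>> env = base_env.append_transform(
...     Compose(
...         ToTensorImage(),   # uint8 HWC frames -> float CHW tensors
...         Resize(84, 84),
...         ConditionalPolicySwitch(condition=cond, policy=pixel_policy),  # last: sees transformed pixels
...     )
... )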
Examples
>>> import torch
>>> from tensordict.nn import TensorDictModule as Mod
>>>
>>> from torchrl.envs import GymEnv, ConditionalPolicySwitch, Compose, StepCounter
>>> # Create a CartPole environment. We'll be looking at the obs: if the first element of the obs is greater than
>>> # 0 (left position) we do a right action (action=0) using the switch policy. Otherwise, we use our main
>>> # policy which does a left action.
>>> base_env = GymEnv("CartPole-v1", categorical_action_encoding=True)
>>>
>>> policy = Mod(lambda: torch.ones((), dtype=torch.int64), in_keys=[], out_keys=["action"])
>>> policy_switch = Mod(lambda: torch.zeros((), dtype=torch.int64), in_keys=[], out_keys=["action"])
>>>
>>> cond = lambda td: td.get("observation")[..., 0] >= 0
>>>
>>> env = base_env.append_transform(
...     Compose(
...         # We use two step counters to show that one counts the global steps, whereas the other
...         # only counts the steps where the main policy is executed
...         StepCounter(step_count_key="step_count_total"),
...         ConditionalPolicySwitch(condition=cond, policy=policy_switch),
...         StepCounter(step_count_key="step_count_main"),
...     )
... )
>>>
>>> env.set_seed(0)
>>> torch.manual_seed(0)
>>>
>>> r = env.rollout(100, policy=policy)
>>> print("action", r["action"])
action tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
>>> print("obs", r["observation"])
obs tensor([[ 0.0322, -0.1540,  0.0111,  0.3190],
        [ 0.0299, -0.1544,  0.0181,  0.3280],
        [ 0.0276, -0.1550,  0.0255,  0.3414],
        [ 0.0253, -0.1558,  0.0334,  0.3596],
        [ 0.0230, -0.1569,  0.0422,  0.3828],
        [ 0.0206, -0.1582,  0.0519,  0.4117],
        [ 0.0181, -0.1598,  0.0629,  0.4469],
        [ 0.0156, -0.1617,  0.0753,  0.4891],
        [ 0.0130, -0.1639,  0.0895,  0.5394],
        [ 0.0104, -0.1665,  0.1058,  0.5987],
        [ 0.0076, -0.1696,  0.1246,  0.6685],
        [ 0.0047, -0.1732,  0.1463,  0.7504],
        [ 0.0016, -0.1774,  0.1715,  0.8459],
        [-0.0020,  0.0150,  0.1884,  0.6117],
        [-0.0017,  0.2071,  0.2006,  0.3838]])
>>> print("obs'", r["next", "observation"])
obs' tensor([[ 0.0299, -0.1544,  0.0181,  0.3280],
        [ 0.0276, -0.1550,  0.0255,  0.3414],
        [ 0.0253, -0.1558,  0.0334,  0.3596],
        [ 0.0230, -0.1569,  0.0422,  0.3828],
        [ 0.0206, -0.1582,  0.0519,  0.4117],
        [ 0.0181, -0.1598,  0.0629,  0.4469],
        [ 0.0156, -0.1617,  0.0753,  0.4891],
        [ 0.0130, -0.1639,  0.0895,  0.5394],
        [ 0.0104, -0.1665,  0.1058,  0.5987],
        [ 0.0076, -0.1696,  0.1246,  0.6685],
        [ 0.0047, -0.1732,  0.1463,  0.7504],
        [ 0.0016, -0.1774,  0.1715,  0.8459],
        [-0.0020,  0.0150,  0.1884,  0.6117],
        [-0.0017,  0.2071,  0.2006,  0.3838],
        [ 0.0105,  0.2015,  0.2115,  0.5110]])
>>> print("total step count", r["step_count_total"].squeeze())
total step count tensor([ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 21, 23, 25, 26, 27])
>>> print("total step with main policy", r["step_count_main"].squeeze())
total step with main policy tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])
- forward(tensordict: TensorDictBase) → Any[source]¶
Reads the input tensordict and, for the selected keys, applies the transform.
By default, this method
- calls _apply_transform() directly,
- does not call _step() or _call().
This method is never called from within env.step. It is, however, called within sample().
Note
forward also works with regular keyword arguments, using dispatch to cast the argument names to keys.
Examples
>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes_in_td
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t)  # works within envs
>>> t(TensorDict(a=0))  # works offline too
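As a complement, a minimal sketch of the default forward path routing through _apply_transform (the AddOne transform is an illustrative assumption; the keyword-argument call on the last line relies on the dispatch behaviour mentioned in the note above and should be treated as an assumption rather than a guarantee):
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import Transform
>>> class AddOne(Transform):
...     '''Adds one to the entry stored under "x" and writes the result under "y".'''
...     def __init__(self):
...         super().__init__(in_keys=["x"], out_keys=["y"])
...     def _apply_transform(self, x):
...         return x + 1
>>> t = AddOne()
>>> td = t(TensorDict(x=torch.zeros(3)))  # default forward: reads "x", applies _apply_transform, writes "y"
>>> td["y"]
tensor([1., 1., 1.])
>>> # Keyword-style call via dispatch (assumption: argument names map to in_keys, out_keys values are returned)
>>> t(x=torch.zeros(3))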