ActionDiscretizer
- class torchrl.envs.transforms.ActionDiscretizer(num_intervals: int | torch.Tensor, action_key: NestedKey = 'action', out_action_key: NestedKey = None, sampling=None, categorical: bool = True)
A transform to discretize a continuous action space.
This transform makes it possible to use algorithms designed for discrete action spaces, such as DQN, on environments with a continuous action space.
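- Parameters:
num_intervals (int or torch.Tensor) – the number of discrete values for each element of the action space. If a single integer is provided, all action items are sliced with the same number of elements. If a tensor is provided, it must have the same number of elements as the action space (i.e., the length of the num_intervals tensor must match the last dimension of the action space).
action_key (NestedKey, optional) – the action key to use. Points to the action of the parent environment (the floating-point action). Defaults to "action".
out_action_key (NestedKey, optional) – the key where the discrete action is written. If None is provided, it defaults to the value of action_key. If the two keys do not match, the continuous action_spec is moved from the full_action_spec environment attribute to the full_state_spec container, because only the discrete action is sampled when an action is taken. Providing out_action_key ensures that the floating-point action remains available for recording.
sampling (ActionDiscretizer.SamplingStrategy, optional) – an element of the ActionDiscretizer.SamplingStrategy IntEnum object (MEDIAN, LOW, HIGH or RANDOM). Indicates how the continuous action is sampled within the provided interval.
categorical (bool, optional) – if False, one-hot encoding is used. Defaults to True.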
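To make the sampling options concrete, here is a minimal pure-Python sketch of how a bin index could be mapped back to a continuous value under each strategy. It assumes equal-width bins over [low, high]; the helper index_to_action and its string strategy names are hypothetical and are not part of torchrl.

```python
import random


def index_to_action(index, low, high, num_intervals, strategy="median", rng=None):
    """Map a bin index to a continuous value in [low, high].

    Hypothetical helper: assumes num_intervals equal-width bins.
    """
    if not 0 <= index < num_intervals:
        raise ValueError(f"index must be in [0, {num_intervals}), got {index}")
    width = (high - low) / num_intervals
    left = low + index * width  # lower edge of the selected bin
    if strategy == "low":
        return left
    if strategy == "high":
        return left + width
    if strategy == "median":
        return left + width / 2
    if strategy == "random":
        # Uniform draw inside the bin; pass a seeded random.Random for reproducibility.
        rng = rng or random
        return left + rng.random() * width
    raise ValueError(f"unknown strategy: {strategy!r}")


# Four bins over [-1, 1]: each bin is 0.5 wide.
print(index_to_action(1, -1.0, 1.0, 4, "low"))     # lower edge of bin 1
print(index_to_action(1, -1.0, 1.0, 4, "median"))  # midpoint of bin 1
print(index_to_action(1, -1.0, 1.0, 4, "high"))    # upper edge of bin 1
```

With a tensor-valued num_intervals, the same mapping would apply independently to each action dimension, each with its own bin count.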
Examples
>>> from torchrl.envs import GymEnv, check_env_specs
>>> from torchrl.envs.transforms import ActionDiscretizer
>>> import torch
>>> base_env = GymEnv("HalfCheetah-v4")
>>> num_intervals = torch.arange(5, 11)
>>> categorical = True
>>> sampling = ActionDiscretizer.SamplingStrategy.MEDIAN
>>> t = ActionDiscretizer(
...     num_intervals=num_intervals,
...     categorical=categorical,
...     sampling=sampling,
...     out_action_key="action_disc",
... )
>>> env = base_env.append_transform(t)
TransformedEnv(
    env=GymEnv(env=HalfCheetah-v4, batch_size=torch.Size([]), device=cpu),
    transform=ActionDiscretizer(
        num_intervals=tensor([ 5,  6,  7,  8,  9, 10]),
        action_key=action,
        out_action_key=action_disc,,
        sampling=0,
        categorical=True))
>>> check_env_specs(env)
>>> # Produce a rollout
>>> r = env.rollout(4)
>>> print(r)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4, 6]), device=cpu, dtype=torch.float32, is_shared=False),
        action_disc: Tensor(shape=torch.Size([4, 6]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([4, 17]), device=cpu, dtype=torch.float64, is_shared=False),
                reward: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([4]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([4, 17]), device=cpu, dtype=torch.float64, is_shared=False),
        terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False)
>>> # The continuous action stays a float tensor within the base env's bounds
>>> assert r["action"].dtype == torch.float
>>> assert r["action_disc"].dtype == torch.int64
>>> assert (r["action"] < base_env.action_spec.high).all()
>>> assert (r["action"] > base_env.action_spec.low).all()
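The categorical flag only changes how the discrete action is encoded. The sketch below contrasts the two encodings for per-dimension bin counts like those used with HalfCheetah. It is plain Python with hypothetical helper names, and it assumes the one-hot form concatenates one block per action dimension, which may differ from the layout torchrl actually uses.

```python
def to_one_hot(indices, num_intervals):
    """Encode one bin index per action dimension as concatenated one-hot blocks."""
    if len(indices) != len(num_intervals):
        raise ValueError("need exactly one index per action dimension")
    encoded = []
    for idx, n in zip(indices, num_intervals):
        if not 0 <= idx < n:
            raise ValueError(f"index {idx} out of range for {n} intervals")
        block = [0] * n
        block[idx] = 1
        encoded.extend(block)
    return encoded


def from_one_hot(encoded, num_intervals):
    """Recover the per-dimension bin indices from the concatenated blocks."""
    indices = []
    start = 0
    for n in num_intervals:
        block = encoded[start:start + n]
        indices.append(block.index(1))
        start += n
    return indices


num_intervals = [5, 6, 7, 8, 9, 10]        # one bin count per action dimension
categorical_action = [0, 5, 3, 7, 2, 9]    # categorical form: one integer per dimension
one_hot_action = to_one_hot(categorical_action, num_intervals)

print(len(categorical_action))  # number of action dimensions
print(len(one_hot_action))      # sum of all bin counts
```

The categorical form stays compact (one integer per dimension), while the one-hot form grows with the total number of bins, which is one reason a categorical default is convenient for value-based methods.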
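- inv(tensordict)
Reads the input tensordict and, for the selected keys, applies the inverse transform.
By default, this method calls _inv_apply_transform() directly and does not call _inv_call().
Note
inv also works with regular keyword arguments, using dispatch to cast the argument names to the keys.
Note
inv is called by extend().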
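As a rough mental model of the inverse step for this transform, the sketch below uses a plain dict in place of a tensordict. It reads the discrete action stored under out_action_key, converts it to a float action (using bin midpoints, which is an assumption), and writes the result under action_key while keeping the discrete entry. The function inv_sketch is hypothetical and does not reproduce torchrl's internals.

```python
def inv_sketch(data, low, high, num_intervals,
               action_key="action", out_action_key="action_disc"):
    """Toy inverse: discrete bin indices -> continuous action, on a plain dict."""
    out = dict(data)  # shallow copy so the caller's dict is left untouched
    continuous = []
    for idx, lo, hi, n in zip(data[out_action_key], low, high, num_intervals):
        width = (hi - lo) / n
        continuous.append(lo + (idx + 0.5) * width)  # midpoint of bin idx
    out[action_key] = continuous
    return out


step = inv_sketch(
    {"action_disc": [0, 2]},
    low=[-1.0, -1.0],
    high=[1.0, 1.0],
    num_intervals=[2, 4],
)
print(step["action_disc"])  # discrete action, as chosen by the policy
print(step["action"])       # continuous action that would reach the base env
```

Keeping both entries mirrors the rollout in the example, where the float action and the integer action_disc appear side by side.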
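- transform_input_spec(input_spec)
Transforms the input spec such that the resulting spec matches the transform mapping.
- Parameters:
input_spec (TensorSpec) – the spec before the transform
- Returns:
the expected spec after the transform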