快捷方式

CatTensors

class torchrl.envs.transforms.CatTensors(in_keys: Sequence[NestedKey] | None = None, out_key: NestedKey = 'observation_vector', dim: int = - 1, *, del_keys: bool = True, unsqueeze_if_oor: bool = False, sort: bool = True)[源代码]

将多个键连接成一个张量。

这对于多个键描述单个状态(例如,“observation_position”和“observation_velocity”)特别有用。

参数:
  • in_keys (NestedKey 序列) – 要连接的键。如果为 None(或未提供),则将在转换器首次使用时从父环境中检索键。此行为仅在设置了父项时有效。

  • out_key (NestedKey) – 结果张量的键。

  • dim (int, optional) – 连接发生的维度。默认为 -1

关键字参数:
  • del_keys (bool, optional) – 如果为 True,则输入值在连接后将被删除。默认为 True

  • unsqueeze_if_oor (bool, optional) – 如果为 True,CatTensor 将检查用于连接的张量是否存在指定的维度。如果不存在,张量将沿该维度进行 unsqueeze。默认为 False

  • sort (bool, optional) – 如果为 True,则会在转换器中对键进行排序。否则,将优先使用用户提供的顺序。默认为 True

示例

>>> transform = CatTensors(in_keys=["key1", "key2"])
>>> td = TensorDict({"key1": torch.zeros(1, 1),
...     "key2": torch.ones(1, 1)}, [1])
>>> _ = transform(td)
>>> print(td.get("observation_vector"))
tensor([[0., 1.]])
>>> transform = CatTensors(in_keys=["key1", "key2"], dim=-2, unsqueeze_if_oor=True)
>>> td = TensorDict({"key1": torch.zeros(1),
...     "key2": torch.ones(1)}, [])
>>> _ = transform(td)
>>> print(td.get("observation_vector").shape)
torch.Size([2, 1])
forward(next_tensordict: TensorDictBase) TensorDictBase

读取输入 tensordict,并对选定的键应用转换。

默认情况下,此方法

  • 直接调用 _apply_transform()

  • 不调用 _step()_call()

此方法不在 env.step 的任何点被调用。但是,它在 sample() 中被调用。

注意

forward 也使用 dispatch 与常规关键字参数一起工作,以将参数名称转换为键。

示例

>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0))  # Works offline too.
transform_observation_spec(observation_spec: TensorSpec) TensorSpec[源代码]

转换观察规范,使结果规范与转换映射匹配。

参数:

observation_spec (TensorSpec) – 转换前的规范

返回:

转换后的预期规范

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源