快捷方式

RemoveEmptySpecs

class torchrl.envs.transforms.RemoveEmptySpecs(in_keys: Sequence[NestedKey] = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None)[源代码]

移除环境中的空 Spec 和内容。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import Categorical, Composite, UnboundedContinuous
>>> from torchrl.envs import EnvBase, TransformedEnv, RemoveEmptySpecs
>>> from torchrl.envs.utils import check_env_specs
>>>
>>>
>>> class DummyEnv(EnvBase):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...         self.observation_spec = Composite(
...             observation=UnboundedContinuous((*self.batch_size, 3)),
...             other=Composite(
...                 another_other=Composite(shape=self.batch_size),
...                 shape=self.batch_size,
...             ),
...             shape=self.batch_size,
...         )
...         self.action_spec = UnboundedContinuous((*self.batch_size, 3))
...         self.done_spec = Categorical(
...             2, (*self.batch_size, 1), dtype=torch.bool
...         )
...         self.full_done_spec["truncated"] = self.full_done_spec[
...             "terminated"].clone()
...         self.reward_spec = Composite(
...             reward=UnboundedContinuous((*self.batch_size, 1)),
...             other_reward=Composite(shape=self.batch_size),
...             shape=self.batch_size
...             )
...
...     def _reset(self, tensordict):
...         return self.observation_spec.rand().update(self.full_done_spec.zero())
...
...     def _step(self, tensordict):
...         return TensorDict(
...             {},
...             batch_size=[]
...         ).update(self.observation_spec.rand()).update(
...             self.full_done_spec.zero()
...             ).update(self.full_reward_spec.rand())
...
...     def _set_seed(self, seed) -> None:
...         pass
>>>
>>>
>>> base_env = DummyEnv()
>>> print(base_env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                other: TensorDict(
                    fields={
                        another_other: TensorDict(
                            fields={
                            },
                            batch_size=torch.Size([2]),
                            device=cpu,
                            is_shared=False)},
                    batch_size=torch.Size([2]),
                    device=cpu,
                    is_shared=False),
                other_reward: TensorDict(
                    fields={
                    },
                    batch_size=torch.Size([2]),
                    device=cpu,
                    is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> check_env_specs(base_env)
>>> env = TransformedEnv(base_env, RemoveEmptySpecs())
>>> print(env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> check_env_specs(env)
forward(next_tensordict: TensorDictBase) → TensorDictBase

读取输入 tensordict,并对选定的键应用转换。

默认情况下,此方法

  • 直接调用 _apply_transform()

  • 不调用 _step() 或 _call()

此方法不会在任何时候被 env.step 调用。但是,它会被 sample() 调用。

注意

forward 也可以借助 dispatch 将参数名映射到键,从而处理常规关键字参数。

示例

>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes_in_td
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0))  # Works offline too.
transform_input_spec(input_spec: TensorSpec) → TensorSpec[源代码]

转换输入规范,使结果规范与转换映射匹配。

参数:

input_spec (TensorSpec) – 转换前的规范

返回:

转换后的预期规范

transform_output_spec(output_spec: Composite) → Composite[源代码]

转换输出规范,使结果规范与转换映射匹配。

此方法通常应保持不变。更改应通过 transform_observation_spec()、transform_reward_spec() 和 transform_full_done_spec() 实现。

参数:

output_spec (TensorSpec) – 转换前的规范

返回:

转换后的预期规范

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源