快捷方式

BatchSizeTransform

class torchrl.envs.transforms.BatchSizeTransform(*, batch_size: torch.Size | None = None, reshape_fn: Callable[[TensorDictBase], TensorDictBase] | None = None, reset_func: Callable[[TensorDictBase, TensorDictBase], TensorDictBase] | None = None, env_kwarg: bool = False)[源代码]

一个用于修改环境批次大小的转换器。

此转换器有两种不同的用法:可以将其用于为非批次锁定(例如无状态)的环境设置批次大小,以便使用数据收集器进行数据收集。它还可以用于修改环境的批次大小(例如,挤压、解挤压或重塑)。

此转换器将环境批次大小修改为与提供的批次大小匹配。它期望父环境的批次大小可以扩展到提供的批次大小。

关键字参数:
  • batch_size (torch.Size等价物, 可选) – 环境的新批次大小。与 reshape_fn 互斥。

  • reshape_fn (callable, optional) –

    一个用于修改环境批次大小的可调用对象。与 batch_size 互斥。

    注意

    目前支持涉及 reshapeflattenunflattensqueezeunsqueeze 的转换。如果需要其他重塑操作,请在 TorchRL GitHub 上提交功能请求。

  • reset_func (callable, optional) – 一个生成重置 tensordict 的函数。签名必须匹配 Callable[[TensorDictBase, TensorDictBase], TensorDictBase],其中第一个输入参数是调用 reset() 时传递给环境的可选 tensordict,第二个参数是 TransformedEnv.base_env.reset 的输出。如果 env_kwarg=True,它还可以支持可选的 env 关键字参数。

  • env_kwarg (bool, optional) – 如果为 True,则 reset_func 必须支持 env 关键字参数。默认为 False。传递的 env 将是伴随其转换的 env。

示例

>>> # Changing the batch-size with a function
>>> from torchrl.envs import GymEnv
>>> base_env = GymEnv("CartPole-v1")
>>> env = TransformedEnv(base_env, BatchSizeTransform(reshape_fn=lambda data: data.reshape(1, 1)))
>>> env.rollout(4)
>>> # Setting the shape of a stateless environment
>>> class MyEnv(EnvBase):
...     batch_locked = False
...     def __init__(self):
...         super().__init__()
...         self.observation_spec = Composite(observation=Unbounded(3))
...         self.reward_spec = Unbounded(1)
...         self.action_spec = Unbounded(1)
...
...     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
...         tensordict_batch_size = tensordict.batch_size if tensordict is not None else torch.Size([])
...         result = self.observation_spec.rand(tensordict_batch_size)
...         result.update(self.full_done_spec.zero(tensordict_batch_size))
...         return result
...
...     def _step(
...         self,
...         tensordict: TensorDictBase,
...     ) -> TensorDictBase:
...         result = self.observation_spec.rand(tensordict.batch_size)
...         result.update(self.full_done_spec.zero(tensordict.batch_size))
...         result.update(self.full_reward_spec.zero(tensordict.batch_size))
...         return result
...
...     def _set_seed(self, seed: Optional[int]) -> None:
...         pass
...
>>> env = TransformedEnv(MyEnv(), BatchSizeTransform([5]))
>>> assert env.batch_size == torch.Size([5])
>>> assert env.rollout(10).shape == torch.Size([5, 10])

reset_func 可以创建具有所需批次大小的 tensordict,从而实现精细的 reset 调用。

>>> def reset_func(tensordict, tensordict_reset, env):
...     result = env.observation_spec.rand()
...     result.update(env.full_done_spec.zero())
...     assert result.batch_size != torch.Size([])
...     return result
>>> env = TransformedEnv(MyEnv(), BatchSizeTransform([5], reset_func=reset_func, env_kwarg=True))
>>> print(env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([5, 2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5, 2]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([5, 2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([5, 2]),
    device=None,
    is_shared=False)

此转换器可用于将非批次锁定的环境部署到数据收集器中。

>>> from torchrl.collectors import SyncDataCollector
>>> collector = SyncDataCollector(env, lambda td: env.rand_action(td), frames_per_batch=10, total_frames=-1)
>>> for data in collector:
...     print(data)
...     break
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([5, 2]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([5, 2]),
            device=None,
            is_shared=False),
        done: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([5, 2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5, 2]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([5, 2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([5, 2]),
    device=None,
    is_shared=False)
>>> collector.shutdown()
forward(next_tensordict: TensorDictBase) TensorDictBase

读取输入 tensordict,并对选定的键应用转换。

默认情况下,此方法

  • 直接调用 _apply_transform()

  • 不调用 _step()_call()

此方法在 env.step 的任何点都不会被调用。但是,它会在 sample() 中被调用。

注意

forward 还支持使用 dispatch 将常规关键字参数用于将参数名称转换为键。

示例

>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0))  # Works offline too.
transform_env_batch_size(batch_size: Size)[源代码]

转换父环境的 batch-size。

transform_input_spec(input_spec: Composite) Composite[源代码]

转换输入规范,使结果规范与转换映射匹配。

参数:

input_spec (TensorSpec) – 转换前的规范

返回:

转换后的预期规范

transform_output_spec(output_spec: Composite) Composite[源代码]

转换输出规范,使结果规范与转换映射匹配。

此方法通常应保持不变。更改应通过 transform_observation_spec()transform_reward_spec()transform_full_done_spec() 来实现。 :param output_spec: 转换前的 spec :type output_spec: TensorSpec

返回:

转换后的预期规范

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源