
Reward2GoTransform

class torchrl.envs.transforms.Reward2GoTransform(gamma: float | torch.Tensor | None = 1.0, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, done_key: NestedKey | None = 'done')[source]

Computes the reward-to-go of an episode based on the episode reward and a discount factor.

Since Reward2GoTransform is only an inverse transform, the in_keys are used directly as in_keys_inv. The reward-to-go can only be computed once the episode is over; therefore, this transform should be applied to a replay buffer, not within a collector or an environment.

Parameters:
  • gamma (float or torch.Tensor) – the discount factor. Defaults to 1.0.

  • in_keys (sequence of NestedKey) – the entries to rename. Defaults to ("next", "reward") if not provided.

  • out_keys (sequence of NestedKey) – the entries to rename. Defaults to the values of in_keys if not provided.

  • done_key (NestedKey) – the done entry. Defaults to "done".

  • truncated_key (NestedKey) – the truncated entry. Defaults to "truncated". If no truncated entry is found, only the "done" entry will be used.

Examples

>>> # Using this transform as part of a replay buffer
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage
>>> from torchrl.envs.transforms import Reward2GoTransform
>>> torch.manual_seed(0)
>>> r2g = Reward2GoTransform(gamma=0.99, out_keys=["reward_to_go"])
>>> rb = ReplayBuffer(storage=LazyTensorStorage(100), transform=r2g)
>>> batch, timesteps = 4, 5
>>> done = torch.zeros(batch, timesteps, 1, dtype=torch.bool)
>>> for i in range(batch):
...     while not done[i].any():
...         done[i] = done[i].bernoulli_(0.1)
>>> reward = torch.ones(batch, timesteps, 1)
>>> td = TensorDict(
...     {"next": {"done": done, "reward": reward}},
...     [batch, timesteps],
... )
>>> rb.extend(td)
>>> sample = rb.sample(1)
>>> print(sample["next", "reward"])
tensor([[[1.],
         [1.],
         [1.],
         [1.],
         [1.]]])
>>> print(sample["reward_to_go"])
tensor([[[4.9010],
         [3.9404],
         [2.9701],
         [1.9900],
         [1.0000]]])

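For reference, the reward-to-go values above can be checked by hand: the reward-to-go is the discounted suffix sum of the rewards, computed backwards from the end of the trajectory as R_t = r_t + gamma * R_{t+1}. A minimal plain-Python sketch of that recursion (independent of TorchRL):

>>> # Hand check of the sample above: five rewards of 1.0 with gamma=0.99
>>> gamma, rewards = 0.99, [1.0] * 5
>>> rtg, running = [], 0.0
>>> for r in reversed(rewards):
...     running = r + gamma * running   # R_t = r_t + gamma * R_{t+1}
...     rtg.append(running)
>>> print([round(x, 4) for x in reversed(rtg)])
[4.901, 3.9404, 2.9701, 1.99, 1.0]
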
This transform can also be used directly with a collector: make sure to append the inv method of the transform.

Examples

>>> from torchrl.envs.utils import RandomPolicy
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs.transforms import Reward2GoTransform
>>> from torchrl.envs.libs.gym import GymEnv
>>> t = Reward2GoTransform(gamma=0.99, out_keys=["reward_to_go"])
>>> env = GymEnv("Pendulum-v1")
>>> collector = SyncDataCollector(
...     env,
...     RandomPolicy(env.action_spec),
...     frames_per_batch=200,
...     total_frames=-1,
...     postproc=t.inv
... )
>>> for data in collector:
...     break
>>> print(data)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        reward_to_go: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([200]),
    device=cpu,
    is_shared=False)

Using this transform as part of an environment will raise an exception.

Examples

>>> from torchrl.envs import TransformedEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> t = Reward2GoTransform(gamma=0.99)
>>> TransformedEnv(GymEnv("Pendulum-v1"), t)  # crashes

Note

In settings where multiple done entries are present, a separate Reward2GoTransform should be built for each done-reward pair, as sketched below.
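
For instance, a minimal sketch of that setup, composing one Reward2GoTransform per pair and attaching the composition to a replay buffer (the nested per-agent keys below are hypothetical and used only for illustration):

>>> # Hypothetical multi-agent layout with per-agent done/reward entries
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage
>>> from torchrl.envs.transforms import Compose, Reward2GoTransform
>>> t = Compose(
...     Reward2GoTransform(
...         gamma=0.99,
...         in_keys=[("next", "agent0", "reward")],
...         out_keys=[("agent0", "reward_to_go")],
...         done_key=("agent0", "done"),
...     ),
...     Reward2GoTransform(
...         gamma=0.99,
...         in_keys=[("next", "agent1", "reward")],
...         out_keys=[("agent1", "reward_to_go")],
...         done_key=("agent1", "done"),
...     ),
... )
>>> rb = ReplayBuffer(storage=LazyTensorStorage(100), transform=t)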

forward(tensordict: TensorDictBase) → TensorDictBase[source]

Reads the input tensordict and, for the selected keys, applies the transform.

By default, this method:

  • calls _apply_transform() directly;

  • does not call _step() or _call().

This method is not called within env.step at any point. However, it is called within sample().

Note

forward also supports casting regular keyword-argument names to keys through dispatch.

Examples

>>> from tensordict import TensorDict, TensorDictBase
>>> from torchrl.envs.transforms import Transform
>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0))  # Works offline too.
