快捷方式

VecNorm

class torchrl.envs.transforms.VecNorm(*args, **kwargs)[source]

用于 torchrl 环境的移动平均归一化层。

警告

此类将弃用,取而代之的是 VecNormV2,并将在 v0.10 中被该类替换。您可以通过使用 new_api 参数或从 torchrl.envs 导入 VecNormV2 类来适应这些更改。

VecNorm 跟踪数据集的摘要统计信息,以便对其进行即时标准化。如果该转换处于“eval”模式,则不会更新运行统计信息。

如果多个进程运行相似的环境,则可以传递一个放置在共享内存中的 TensorDictBase 实例:如果是这样,每次查询归一化层时,它都会更新共享相同引用的所有进程的值。

要在推理时使用 VecNorm 并避免使用新观察值更新值,应将此层替换为 to_observation_norm()。这将提供一个静态版本的 VecNorm,在源转换更新时不会更新。要获取 VecNorm 层的冻结副本,请参阅 frozen_copy()

参数:
  • in_keys (sequence of NestedKey, optional) – 要更新的键。默认为 [“observation”, “reward”]

  • out_keys (sequence of NestedKey, optional) – 目标键。默认为 in_keys

  • shared_td (TensorDictBase, optional) – 一个共享的 tensordict,包含转换的键。

  • lock (mp.Lock) – 一个锁,用于防止进程之间发生竞争条件。默认为 None(在初始化期间创建锁)。

  • decay (number, optional) – 移动平均的衰减率。默认为 0.99

  • eps (number, optional) – 运行标准差的下界(用于数值下溢)。默认为 1e-4。

  • shapes (List[torch.Size], optional) – 如果提供,则表示每个 in_keys 的形状。其长度必须与 in_keys 的长度匹配。每个形状都必须匹配相应条目的最后一个维度。如果不是,则将条目的特征维度(即不属于 tensordict 批次大小的所有维度)视为特征维度。

  • new_api (bool or None, optional) – 如果 True,则将返回 VecNormV2 的实例。如果未传递,将引发警告。默认为 False

示例

>>> from torchrl.envs.libs.gym import GymEnv
>>> t = VecNorm(decay=0.9)
>>> env = GymEnv("Pendulum-v0")
>>> env = TransformedEnv(env, t)
>>> tds = []
>>> for _ in range(1000):
...     td = env.rand_step()
...     if td.get("done"):
...         _ = env.reset()
...     tds += [td]
>>> tds = torch.stack(tds, 0)
>>> print((abs(tds.get(("next", "observation")).mean(0))<0.2).all())
tensor(True)
>>> print((abs(tds.get(("next", "observation")).std(0)-1)<0.2).all())
tensor(True)
static build_td_for_shared_vecnorm(env: EnvBase, keys: Sequence[str] | None = None, memmap: bool = False) TensorDictBase[source]

为跨进程归一化创建共享的 tensordict。

参数:
  • env (EnvBase) – 用于创建 tensordict 的示例环境

  • keys (sequence of NestedKey, optional) – 需要归一化的键。默认为 [“next”, “reward”]

  • memmap (bool) – 如果 True,则生成的 tensordict 将被转换为内存映射(使用 memmap_())。否则,tensordict 将被放置在共享内存中。

返回:

一个共享内存,用于发送到每个进程。

示例

>>> from torch import multiprocessing as mp
>>> queue = mp.Queue()
>>> env = make_env()
>>> td_shared = VecNorm.build_td_for_shared_vecnorm(env,
...     ["next", "reward"])
>>> assert td_shared.is_shared()
>>> queue.put(td_shared)
>>> # on workers
>>> v = VecNorm(shared_td=queue.get())
>>> env = TransformedEnv(make_env(), v)
forward(next_tensordict: TensorDictBase) TensorDictBase

读取输入 tensordict,并对选定的键应用转换。

默认情况下,此方法

  • 直接调用 _apply_transform()

  • 不调用 _step()_call()

此方法不会在任何时候在 env.step 中调用。但是,它会在 sample() 中调用。

注意

forward 也可以使用 dispatch 将参数名称转换为键,并使用常规关键字参数。

示例

>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0))  # Works offline too.
freeze() VecNorm[source]

冻结 VecNorm,避免在调用时更新统计信息。

请参阅 unfreeze()

frozen_copy() VecNorm[source]

返回一个 Transform 的副本,该副本会跟踪统计信息但不会更新它们。

get_extra_state() OrderedDict[source]

返回要包含在模块 state_dict 中的任何额外状态。

如果您的模块需要存储额外状态,请实现此函数以及相应的 set_extra_state()。在构建模块的 state_dict() 时调用此函数。

注意,为了保证 state_dict 的序列化工作正常,额外状态应该是可被 pickle 的。我们仅为 Tensors 的序列化提供向后兼容性保证;其他对象的序列化形式若发生变化,可能导致向后兼容性中断。

返回:

要存储在模块 state_dict 中的任何额外状态

返回类型:

对象

property loc

返回一个用于仿射变换的 loc 的 TensorDict。

property scale

返回一个用于仿射变换的 scale 的 TensorDict。

set_extra_state(state: OrderedDict) None[source]

设置加载的 state_dict 中包含的额外状态。

此函数从 load_state_dict() 调用,以处理 state_dict 中的任何额外状态。如果您的模块需要在其 state_dict 中存储额外状态,请实现此函数以及相应的 get_extra_state()

参数:

state (dict) – 来自 state_dict 的额外状态

property standard_normal: bool

locscale 提供的仿射变换是否遵循标准正态方程。

类似于 ObservationNorm 的 standard_normal 属性。

始终返回 True

to_observation_norm() Compose | ObservationNorm[source]

将 VecNorm 转换为一个可以在推理时使用的 ObservationNorm 类。

可以使用 state_dict() API 更新 ObservationNorm 层。

示例

>>> from torchrl.envs import GymEnv, VecNorm
>>> vecnorm = VecNorm(in_keys=["observation"])
>>> train_env = GymEnv("CartPole-v1", device=None).append_transform(
...     vecnorm)
>>>
>>> r = train_env.rollout(4)
>>>
>>> eval_env = GymEnv("CartPole-v1").append_transform(
...     vecnorm.to_observation_norm())
>>> print(eval_env.transform.loc, eval_env.transform.scale)
>>>
>>> r = train_env.rollout(4)
>>> # Update entries with state_dict
>>> eval_env.transform.load_state_dict(
...     vecnorm.to_observation_norm().state_dict())
>>> print(eval_env.transform.loc, eval_env.transform.scale)
transform_observation_spec(observation_spec: TensorSpec) TensorSpec[source]

转换观察规范,使结果规范与转换映射匹配。

参数:

observation_spec (TensorSpec) – 转换前的规范

返回:

转换后的预期规范

unfreeze() VecNorm[source]

解冻 VecNorm。

请参阅 freeze()

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源