VecNorm¶
- class torchrl.envs.transforms.VecNorm(*args, **kwargs)[source]¶
用于 torchrl 环境的移动平均归一化层。
警告
此类将被
VecNormV2
取代,并将在 v0.10 中被替换。您可以将 new_api 参数用于适配,或者从 torchrl.envs 导入 VecNormV2 类。VecNorm 跟踪数据集的汇总统计信息,以便在运行时对其进行标准化。如果转换处于“eval”模式,则不会更新运行统计信息。
如果多个进程运行相似的环境,可以传递一个位于共享内存中的 TensorDictBase 实例:如果这样做,每次查询归一化层时,它将更新共享相同引用的所有进程的值。
要在推理时使用 VecNorm 并避免使用新观察更新值,应将此层替换为
to_observation_norm()
。这将提供一个静态版本的 VecNorm,在源转换更新时不会被更新。要获取 VecNorm 层的冻结副本,请参见frozen_copy()
。- 参数:
in_keys (sequence of NestedKey, optional) – 要更新的键。默认值:[“observation”, “reward”]
out_keys (sequence of NestedKey, optional) – 目标键。默认为
in_keys
。shared_td (TensorDictBase, optional) – 一个共享的 tensordict,包含转换的键。
lock (mp.Lock) – 一个用于防止进程之间发生竞态条件的锁。默认为 None(在初始化期间创建锁)。
decay (number, optional) – 移动平均的衰减率。默认值:0.99
eps (number, optional) – 运行标准差的下界(用于数值下溢)。默认值为 1e-4。
shapes (List[torch.Size], optional) – 如果提供,表示每个 in_keys 的形状。其长度必须与
in_keys
的长度匹配。每个形状必须匹配相应条目的尾部维度。否则,条目的特征维度(即不属于 tensordict 批次大小的所有维度)将被视为特征维度。new_api (bool or None, optional) – 如果为
True
,将返回 VecNormV2 的一个实例。如果未传递,将引发警告。默认为False
。
示例
>>> from torchrl.envs.libs.gym import GymEnv >>> t = VecNorm(decay=0.9) >>> env = GymEnv("Pendulum-v0") >>> env = TransformedEnv(env, t) >>> tds = [] >>> for _ in range(1000): ... td = env.rand_step() ... if td.get("done"): ... _ = env.reset() ... tds += [td] >>> tds = torch.stack(tds, 0) >>> print((abs(tds.get(("next", "observation")).mean(0))<0.2).all()) tensor(True) >>> print((abs(tds.get(("next", "observation")).std(0)-1)<0.2).all()) tensor(True)
为跨进程的归一化创建共享的 tensordict。
- 参数:
env (EnvBase) – 用于创建 tensordict 的示例环境
keys (sequence of NestedKey, optional) – 需要归一化的键。默认为 [“next”, “reward”]
memmap (bool) – 如果为
True
,则结果的 tensordict 将被转换为内存映射(使用 memmap_())。否则,tensordict 将被放入共享内存。
- 返回:
一个要发送到每个进程的共享内存。
示例
>>> from torch import multiprocessing as mp >>> queue = mp.Queue() >>> env = make_env() >>> td_shared = VecNorm.build_td_for_shared_vecnorm(env, ... ["next", "reward"]) >>> assert td_shared.is_shared() >>> queue.put(td_shared) >>> # on workers >>> v = VecNorm(shared_td=queue.get()) >>> env = TransformedEnv(make_env(), v)
- forward(next_tensordict: TensorDictBase) TensorDictBase ¶
读取输入 tensordict,并对选定的键应用转换。
默认情况下,此方法
直接调用
_apply_transform()
。不调用
_step()
或_call()
。
此方法不会在任何时候在 env.step 内调用。但是,它会在
sample()
内被调用。注意
forward
也通过使用dispatch
将参数名称转换为键来使用常规关键字参数。示例
>>> class TransformThatMeasuresBytes(Transform): ... '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.''' ... def __init__(self): ... super().__init__(in_keys=[], out_keys=["bytes"]) ... ... def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ... bytes_in_td = tensordict.bytes() ... tensordict["bytes"] = bytes ... return tensordict >>> t = TransformThatMeasuresBytes() >>> env = env.append_transform(t) # works within envs >>> t(TensorDict(a=0)) # Works offline too.
- freeze() VecNorm [source]¶
冻结 VecNorm,避免在调用时更新统计信息。
参见
unfreeze()
。
- get_extra_state() OrderedDict [source]¶
返回要包含在模块 state_dict 中的任何额外状态。
如果需要存储额外的状态,请实现此方法和相应的
set_extra_state()
。在构建模块的 state_dict() 时会调用此函数。请注意,为了确保 state_dict 的序列化正常工作,额外的状态应该是可 Pickled 的。我们仅为序列化 Tensors 提供向后兼容性保证;其他对象的序列化 Pickled 形式如果发生变化,可能会破坏向后兼容性。
- 返回:
要存储在模块 state_dict 中的任何额外状态
- 返回类型:
对象
- property loc¶
返回一个用于仿射变换的 loc TensorDict。
- property scale¶
返回一个用于仿射变换的 scale TensorDict。
- set_extra_state(state: OrderedDict) None [source]¶
设置加载的 state_dict 中包含的额外状态。
此函数从
load_state_dict()
调用,用于处理 state_dict 中的任何额外状态。如果需要将额外状态存储在其 state_dict 中,请实现此函数和相应的get_extra_state()
。- 参数:
state (dict) – 来自 state_dict 的额外状态
- property standard_normal¶
由 loc 和 scale 给出的仿射变换是否遵循标准正态方程。
类似于
ObservationNorm
的 standard_normal 属性。始终返回
True
。
- to_observation_norm() Compose | ObservationNorm [source]¶
将 VecNorm 转换为可在推理时使用的 ObservationNorm 类。
可以使用
state_dict()
API 更新ObservationNorm
层。示例
>>> from torchrl.envs import GymEnv, VecNorm >>> vecnorm = VecNorm(in_keys=["observation"]) >>> train_env = GymEnv("CartPole-v1", device=None).append_transform( ... vecnorm) >>> >>> r = train_env.rollout(4) >>> >>> eval_env = GymEnv("CartPole-v1").append_transform( ... vecnorm.to_observation_norm()) >>> print(eval_env.transform.loc, eval_env.transform.scale) >>> >>> r = train_env.rollout(4) >>> # Update entries with state_dict >>> eval_env.transform.load_state_dict( ... vecnorm.to_observation_norm().state_dict()) >>> print(eval_env.transform.loc, eval_env.transform.scale)
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec [source]¶
转换观察规范,使结果规范与转换映射匹配。
- 参数:
observation_spec (TensorSpec) – 转换前的规范
- 返回:
转换后的预期规范