TrajCounter¶
- class torchrl.envs.transforms.TrajCounter(out_key: NestedKey = 'traj_count', *, repeats: int | None = None)[source]¶
全局轨迹计数转换。
TrajCounter 可用于计算任何 TorchRL 环境中的轨迹数量(即调用 reset 的次数)。此转换将在单个节点内的多个进程中工作(请参阅下面的注释)。单个转换只能计算与单个 done 状态关联的轨迹,但只要其前缀与计数器键的前缀匹配,就可以接受嵌套的 done 状态。
- 参数:
out_key (NestedKey, optional) – 轨迹计数器的条目名称。默认为
"traj_count"
。
示例
>>> from torchrl.envs import GymEnv, StepCounter, TrajCounter >>> env = GymEnv("Pendulum-v1").append_transform(StepCounter(6)) >>> env = env.append_transform(TrajCounter()) >>> r = env.rollout(18, break_when_any_done=False) # 18 // 6 = 3 trajectories >>> r["next", "traj_count"] tensor([[0], [0], [0], [0], [0], [0], [1], [1], [1], [1], [1], [1], [2], [2], [2], [2], [2], [2]])
注意
可以通过多种方式在工作进程之间共享轨迹计数器,但这通常涉及将环境包装在
EnvCreator
中。否则,在序列化转换时可能会发生错误。计数器将在工作进程之间共享,这意味着在任何给定时间,保证不会有两个环境共享相同的轨迹计数(并且每个(步数-计数, 轨迹-计数)对都是唯一的)。以下是跨进程共享TrajCounter
对象的一些有效方法示例。>>> # Option 1: Create the trajectory counter outside the environment. >>> # This requires the counter to be cloned within the transformed env, as a single transform object cannot have two parents. >>> t = TrajCounter() >>> def make_env(max_steps=4, t=t): ... # See CountingEnv in torchrl.test.mocking_classes ... env = TransformedEnv(CountingEnv(max_steps=max_steps), t.clone()) ... env.transform.transform_observation_spec(env.base_env.observation_spec) ... return env >>> penv = ParallelEnv( ... 2, ... [EnvCreator(make_env, max_steps=4), EnvCreator(make_env, max_steps=5)], ... mp_start_method="spawn", ... ) >>> # Option 2: Create the transform within the constructor. >>> # In this scenario, we still need to tell each sub-env what kwarg has to be used. >>> # Both EnvCreator and ParallelEnv offer that possibility. >>> def make_env(max_steps=4): ... t = TrajCounter() ... env = TransformedEnv(CountingEnv(max_steps=max_steps), t) ... env.transform.transform_observation_spec(env.base_env.observation_spec) ... return env >>> make_env_c0 = EnvCreator(make_env) >>> # Create a variant of the env with different kwargs >>> make_env_c1 = make_env_c0.make_variant(max_steps=5) >>> penv = ParallelEnv( ... 2, ... [make_env_c0, make_env_c1], ... mp_start_method="spawn", ... ) >>> # Alternatively, pass the kwargs to the ParallelEnv >>> penv = ParallelEnv( ... 2, ... [make_env_c0, make_env_c0], ... create_env_kwargs=[{"max_steps": 5}, {"max_steps": 4}], ... mp_start_method="spawn", ... )
- forward(tensordict: TensorDictBase) TensorDictBase [source]¶
读取输入 tensordict,并对选定的键应用转换。
默认情况下,此方法
直接调用
_apply_transform()
。不调用
_step()
或_call()
。
此方法在 env.step 的任何时候都不会被调用。但是,它在
sample()
中被调用。注意
forward
也通过使用dispatch
将参数名称转换为键来处理常规关键字参数。示例
>>> class TransformThatMeasuresBytes(Transform): ... '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.''' ... def __init__(self): ... super().__init__(in_keys=[], out_keys=["bytes"]) ... ... def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ... bytes_in_td = tensordict.bytes() ... tensordict["bytes"] = bytes ... return tensordict >>> t = TransformThatMeasuresBytes() >>> env = env.append_transform(t) # works within envs >>> t(TensorDict(a=0)) # Works offline too.
- load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)[source]¶
将参数和缓冲区从
state_dict
复制到此模块及其子模块。如果
strict
为True
,则state_dict
的键必须与此模块的state_dict()
函数返回的键完全匹配。警告
如果
assign
为True
,则优化器必须在调用load_state_dict
之后创建,除非get_swap_module_params_on_conversion()
为True
。- 参数:
state_dict (dict) – 包含参数和持久 buffer 的字典。
strict (bool, optional) – 是否严格强制
state_dict
中的键与此模块的state_dict()
函数返回的键匹配。默认值:True
assign (bool, optional) – 当设置为
False
时,会保留当前模块中张量的属性,而当设置为True
时,会保留 state dict 中张量的属性。唯一例外是requires_grad
字段默认值: ``False`
- 返回:
- missing_keys 是一个包含任何预期键的 str 列表。
在提供的
state_dict
中缺失的任何键的字符串列表。
- unexpected_keys 是一个包含不匹配的键的 str 列表。
不期望但在提供的
state_dict
中存在的键。
- 返回类型:
NamedTuple
,具有missing_keys
和unexpected_keys
字段
注意
如果参数或缓冲区注册为
None
且其对应的键存在于state_dict
中,load_state_dict()
将引发RuntimeError
。
- state_dict(*args, destination=None, prefix='', keep_vars=False)[source]¶
返回一个字典,其中包含对模块整个状态的引用。
参数和持久缓冲区(例如,运行平均值)都包含在内。键是相应的参数和缓冲区名称。设置为
None
的参数和缓冲区不包含在内。注意
返回的对象是浅拷贝。它包含对模块参数和缓冲区的引用。
警告
目前
state_dict()
也按顺序接受destination
、prefix
和keep_vars
的位置参数。但是,这正在被弃用,并且将在未来的版本中强制使用关键字参数。警告
请避免使用参数
destination
,因为它不是为最终用户设计的。- 参数:
destination (dict, optional) – 如果提供,模块的状态将更新到字典中,并返回相同的对象。否则,将创建一个
OrderedDict
并返回。默认值:None
。prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default:
''
。keep_vars (bool, optional) – 默认情况下,state dict 中返回的
Tensor
已从 autograd 中分离。如果设置为True
,则不会执行分离。默认值:False
。
- 返回:
包含模块整体状态的字典
- 返回类型:
dict
示例
>>> # xdoctest: +SKIP("undefined vars") >>> module.state_dict().keys() ['bias', 'weight']
- transform_observation_spec(observation_spec: Composite) Composite [source]¶
转换观察规范,使结果规范与转换映射匹配。
- 参数:
observation_spec (TensorSpec) – 转换前的规范
- 返回:
转换后的预期规范