快捷方式

TrajCounter

class torchrl.envs.transforms.TrajCounter(out_key: NestedKey = 'traj_count', *, repeats: int | None = None)[source]

全局轨迹计数转换。

TrajCounter 可用于计算任何 TorchRL 环境中的轨迹数量(即调用 reset 的次数)。此转换将在单个节点内的多个进程中工作(请参阅下面的注释)。单个转换只能计算与单个 done 状态关联的轨迹,但只要其前缀与计数器键的前缀匹配,就可以接受嵌套的 done 状态。

参数:

out_key (NestedKey, optional) – 轨迹计数器的条目名称。默认为 "traj_count"

示例

>>> from torchrl.envs import GymEnv, StepCounter, TrajCounter
>>> env = GymEnv("Pendulum-v1").append_transform(StepCounter(6))
>>> env = env.append_transform(TrajCounter())
>>> r = env.rollout(18, break_when_any_done=False)  # 18 // 6 = 3 trajectories
>>> r["next", "traj_count"]
tensor([[0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [2],
        [2],
        [2],
        [2],
        [2],
        [2]])

注意

可以通过多种方式在工作进程之间共享轨迹计数器,但这通常涉及将环境包装在 EnvCreator 中。否则,在序列化转换时可能会发生错误。计数器将在工作进程之间共享,这意味着在任何给定时间,保证不会有两个环境共享相同的轨迹计数(并且每个(步数-计数, 轨迹-计数)对都是唯一的)。以下是跨进程共享 TrajCounter 对象的一些有效方法示例。

>>> # Option 1: Create the trajectory counter outside the environment.
>>> #  This requires the counter to be cloned within the transformed env, as a single transform object cannot have two parents.
>>> t = TrajCounter()
>>> def make_env(max_steps=4, t=t):
...     # See CountingEnv in torchrl.test.mocking_classes
...     env = TransformedEnv(CountingEnv(max_steps=max_steps), t.clone())
...     env.transform.transform_observation_spec(env.base_env.observation_spec)
...     return env
>>> penv = ParallelEnv(
...     2,
...     [EnvCreator(make_env, max_steps=4), EnvCreator(make_env, max_steps=5)],
...     mp_start_method="spawn",
... )
>>> # Option 2: Create the transform within the constructor.
>>> #  In this scenario, we still need to tell each sub-env what kwarg has to be used.
>>> #  Both EnvCreator and ParallelEnv offer that possibility.
>>> def make_env(max_steps=4):
...     t = TrajCounter()
...     env = TransformedEnv(CountingEnv(max_steps=max_steps), t)
...     env.transform.transform_observation_spec(env.base_env.observation_spec)
...     return env
>>> make_env_c0 = EnvCreator(make_env)
>>> # Create a variant of the env with different kwargs
>>> make_env_c1 = make_env_c0.make_variant(max_steps=5)
>>> penv = ParallelEnv(
...     2,
...     [make_env_c0, make_env_c1],
...     mp_start_method="spawn",
... )
>>> # Alternatively, pass the kwargs to the ParallelEnv
>>> penv = ParallelEnv(
...     2,
...     [make_env_c0, make_env_c0],
...     create_env_kwargs=[{"max_steps": 5}, {"max_steps": 4}],
...     mp_start_method="spawn",
... )
forward(tensordict: TensorDictBase) TensorDictBase[source]

读取输入 tensordict,并对选定的键应用转换。

默认情况下,此方法

  • 直接调用 _apply_transform()

  • 不调用 _step()_call()

此方法在 env.step 的任何时候都不会被调用。但是,它在 sample() 中被调用。

注意

forward 也通过使用 dispatch 将参数名称转换为键来处理常规关键字参数。

示例

>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0))  # Works offline too.
load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)[source]

将参数和缓冲区从 state_dict 复制到此模块及其子模块。

如果 strictTrue,则 state_dict 的键必须与此模块的 state_dict() 函数返回的键完全匹配。

警告

如果 assignTrue,则优化器必须在调用 load_state_dict 之后创建,除非 get_swap_module_params_on_conversion()True

参数:
  • state_dict (dict) – 包含参数和持久 buffer 的字典。

  • strict (bool, optional) – 是否严格强制 state_dict 中的键与此模块的 state_dict() 函数返回的键匹配。默认值: True

  • assign (bool, optional) – 当设置为 False 时,会保留当前模块中张量的属性,而当设置为 True 时,会保留 state dict 中张量的属性。唯一例外是 requires_grad 字段 默认值: ``False`

返回:

  • missing_keys 是一个包含任何预期键的 str 列表。

    在提供的 state_dict 中缺失的任何键的字符串列表。

  • unexpected_keys 是一个包含不匹配的键的 str 列表。

    不期望但在提供的 state_dict 中存在的键。

返回类型:

NamedTuple,具有 missing_keysunexpected_keys 字段

注意

如果参数或缓冲区注册为 None 且其对应的键存在于 state_dict 中,load_state_dict() 将引发 RuntimeError

state_dict(*args, destination=None, prefix='', keep_vars=False)[source]

返回一个字典,其中包含对模块整个状态的引用。

参数和持久缓冲区(例如,运行平均值)都包含在内。键是相应的参数和缓冲区名称。设置为 None 的参数和缓冲区不包含在内。

注意

返回的对象是浅拷贝。它包含对模块参数和缓冲区的引用。

警告

目前 state_dict() 也按顺序接受 destinationprefixkeep_vars 的位置参数。但是,这正在被弃用,并且将在未来的版本中强制使用关键字参数。

警告

请避免使用参数 destination,因为它不是为最终用户设计的。

参数:
  • destination (dict, optional) – 如果提供,模块的状态将更新到字典中,并返回相同的对象。否则,将创建一个 OrderedDict 并返回。默认值: None

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''

  • keep_vars (bool, optional) – 默认情况下,state dict 中返回的 Tensor 已从 autograd 中分离。如果设置为 True,则不会执行分离。默认值: False

返回:

包含模块整体状态的字典

返回类型:

dict

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
transform_observation_spec(observation_spec: Composite) Composite[source]

转换观察规范,使结果规范与转换映射匹配。

参数:

observation_spec (TensorSpec) – 转换前的规范

返回:

转换后的预期规范

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源