EnvCreator
- class torchrl.envs.EnvCreator(create_env_fn: Callable[..., EnvBase], create_env_kwargs: dict | None = None, share_memory: bool = True, **kwargs)
Environment creator class.
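EnvCreator is a generic environment creator class that can replace a lambda function when environments are created in a multiprocessing setting. If the environments created in subprocesses need to share information with the main process (for example, for the VecNorm transform), EnvCreator passes pointers to tensordicts placed in shared memory, so that all processes stay synchronized.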
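The synchronization idea can be illustrated with a small, single-process sketch. Everything in it (RunningCount, ToyEnv, SharedStateCreator) is hypothetical and not part of TorchRL. Ordinary Python object references stand in for the shared-memory tensordicts that the real class uses across processes. The state is built once by the creator, and every environment it produces receives that same object.

```python
from __future__ import annotations

from typing import Callable


class RunningCount:
    """Hypothetical stand-in for a stateful transform such as VecNorm.
    It only counts how many steps have been observed."""

    def __init__(self) -> None:
        self.count = 0


class ToyEnv:
    """Hypothetical toy environment that updates the RunningCount it is given."""

    def __init__(self, stats: RunningCount) -> None:
        self.stats = stats

    def step(self) -> None:
        self.stats.count += 1


class SharedStateCreator:
    """Simplified, single-process sketch of the EnvCreator idea: the state is
    created once, and every env built by __call__ receives that same object,
    so an update made through one env is visible from all the others."""

    def __init__(self, create_env_fn: Callable[[RunningCount], ToyEnv]) -> None:
        self.create_env_fn = create_env_fn
        self.shared_stats = RunningCount()  # built once, reused by every env

    def __call__(self) -> ToyEnv:
        return self.create_env_fn(self.shared_stats)


creator = SharedStateCreator(ToyEnv)
env1, env2 = creator(), creator()
for _ in range(10):
    env1.step()  # only env1 takes steps

# A plain factory call gives each env its own, independent state instead.
standalone = ToyEnv(RunningCount())

print("env 1:", env1.stats.count)              # 10
print("env 2:", env2.stats.count)              # 10, although env2 never stepped
print("standalone:", standalone.stats.count)   # 0
```

This mirrors the VecNorm example in the Examples section, where the second worker reports the same observation count as the first even though it never steps.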
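- Parameters:
create_env_fn (callable) – a callable that returns an EnvBase instance.
create_env_kwargs (dict, optional) – keyword arguments passed to the environment creator (create_env_fn).
share_memory (bool, optional) – if False, the tensordicts produced by the environment are not placed in shared memory.
**kwargs – additional keyword arguments passed to the environment during construction.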
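Both create_env_kwargs and **kwargs end up as keyword arguments for the environment factory. The sketch below assumes the two sources are simply merged into one dictionary before the call. The helper call_with_merged_kwargs and the factory make_env are hypothetical, and which source wins on a key clash is an arbitrary choice made here for illustration, not documented behaviour.

```python
from __future__ import annotations

from typing import Any, Callable


def call_with_merged_kwargs(
    create_env_fn: Callable[..., Any],
    create_env_kwargs: dict | None = None,
    **kwargs: Any,
) -> Any:
    """Hypothetical helper: merge the two kwarg sources and call the factory."""
    merged = dict(kwargs)
    if create_env_kwargs is not None:
        # Arbitrary choice for this sketch: create_env_kwargs wins on a key clash.
        merged.update(create_env_kwargs)
    return create_env_fn(**merged)


def make_env(env_name: str, frame_skip: int = 1) -> tuple[str, int]:
    """Hypothetical factory standing in for an environment constructor."""
    return env_name, frame_skip


# Three equivalent ways of passing the same arguments.
a = call_with_merged_kwargs(make_env, {"env_name": "Pendulum-v1", "frame_skip": 2})
b = call_with_merged_kwargs(make_env, env_name="Pendulum-v1", frame_skip=2)
c = call_with_merged_kwargs(make_env, {"frame_skip": 2}, env_name="Pendulum-v1")
print(a, b, c)  # ('Pendulum-v1', 2) three times
```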
Examples
>>> # We create the same environment on 2 processes using VecNorm
>>> # and check that the discounted count of observations matches on
>>> # both workers, even if one has not executed any step
>>> import time
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.transforms import VecNorm, TransformedEnv
>>> from torchrl.envs import EnvCreator
>>> from torch import multiprocessing as mp
>>> env_fn = lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())
>>> env_creator = EnvCreator(env_fn)
>>>
>>> def test_env1(env_creator):
...     env = env_creator()
...     tensordict = env.reset()
...     for _ in range(10):
...         env.rand_step(tensordict)
...         if tensordict.get(("next", "done")):
...             tensordict = env.reset(tensordict)
...     print("env 1: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> def test_env2(env_creator):
...     env = env_creator()
...     time.sleep(5)
...     print("env 2: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> if __name__ == "__main__":
...     ps = []
...     p1 = mp.Process(target=test_env1, args=(env_creator,))
...     p1.start()
...     ps.append(p1)
...     p2 = mp.Process(target=test_env2, args=(env_creator,))
...     p2.start()
...     ps.append(p2)  # append p2 (not p1 again) so that both processes are joined
...     for p in ps:
...         p.join()
env 1:  tensor([11.9934])
env 2:  tensor([11.9934])
- make_variant(**kwargs) → EnvCreator
Creates a variant of the EnvCreator that points to the same underlying metadata but uses different keyword arguments during construction.
This can be useful for transforms that share a state, such as TrajCounter.
Examples
>>> from torchrl.envs import EnvCreator, GymEnv
>>> env_creator_pendulum = EnvCreator(GymEnv, env_name="Pendulum-v1")
>>> env_creator_cartpole = env_creator_pendulum.make_variant(env_name="CartPole-v1")
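The effect of make_variant can be pictured with the following hypothetical sketch; VariantCreator and make_env are illustrative names, not TorchRL's implementation. Each variant keeps a reference to the same state dictionary while overriding some construction kwargs. A counter bumped on every construction is used here as a loose stand-in for state shared by a transform such as TrajCounter.

```python
from __future__ import annotations

from typing import Any, Callable


class VariantCreator:
    """Hypothetical sketch (not TorchRL's implementation): a creator whose
    variants share one state dict but build envs with different kwargs."""

    def __init__(
        self,
        create_env_fn: Callable[..., Any],
        shared_state: dict | None = None,
        **kwargs: Any,
    ) -> None:
        self.create_env_fn = create_env_fn
        self.kwargs = kwargs
        # Shared by reference with every variant made from this creator.
        self.shared_state = shared_state if shared_state is not None else {"envs_built": 0}

    def make_variant(self, **kwargs: Any) -> VariantCreator:
        # Start from the current kwargs and override the ones passed here.
        new_kwargs = {**self.kwargs, **kwargs}
        return VariantCreator(self.create_env_fn, self.shared_state, **new_kwargs)

    def __call__(self) -> Any:
        return self.create_env_fn(shared_state=self.shared_state, **self.kwargs)


def make_env(env_name: str, shared_state: dict) -> dict:
    """Hypothetical factory; bumps a shared counter on every construction."""
    shared_state["envs_built"] += 1
    return {"env_name": env_name, "state": shared_state}


creator_pendulum = VariantCreator(make_env, env_name="Pendulum-v1")
creator_cartpole = creator_pendulum.make_variant(env_name="CartPole-v1")
env_a, env_b = creator_pendulum(), creator_cartpole()
print(env_a["env_name"], env_b["env_name"])  # Pendulum-v1 CartPole-v1
print(env_a["state"]["envs_built"])          # 2: both variants updated one dict
```

Copying the kwargs rather than mutating them keeps the original creator unchanged, so it still builds the environment it was configured for.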