PixelRenderTransform¶
- torchrl.record.PixelRenderTransform(out_keys: list[NestedKey] = None, preproc: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor] = None, as_non_tensor: bool | None = None, render_method: str = 'render', pass_tensordict: bool = False, **kwargs) None [来源]¶
一个调用父环境的 render 方法并将像素观察注册到 tensordict 中的转换。
此转换提供了一种替代方法,用于在实例化提供渲染的环境时,当渲染很昂贵,或者当
from_pixels
未实现时,用于from_pixels
语法糖。它可以用于单个环境或批处理环境。- 参数:
out_keys (List[NestedKey] 或 Nested) – 用于注册像素观测值的键列表。
preproc (Callable, optional) – 一个预处理函数。可用于重塑观测值,或应用任何其他使其能够注册到输出数据中的转换。
as_non_tensor (bool, optional) – 如果为
True
,则数据将作为NonTensorData
写入,从而放宽形状要求。如果未提供,则会根据输入数据类型和形状自动推断。render_method (str, optional) – 渲染方法的名称。默认为
"render"
。pass_tensordict (bool, optional) – 如果为
True
,则输入 tensordict 将传递给渲染方法。这使得无状态环境也能进行渲染。默认为False
。**kwargs – 传递给渲染函数的其他关键字参数(例如
mode="rgb_array"
)。
示例
>>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator >>> from torchrl.record.loggers import CSVLogger >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder >>> >>> def make_env(): >>> env = GymEnv("CartPole-v1", render_mode="rgb_array") >>> env = env.append_transform(PixelRenderTransform()) >>> return env >>> >>> if __name__ == "__main__": ... logger = CSVLogger("dummy", video_format="mp4") ... ... env = ParallelEnv(4, EnvCreator(make_env)) ... ... env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record")) ... env.rollout(3) ... ... check_env_specs(env) ... ... r = env.rollout(30) ... print(env) ... env.transform.dump() ... env.close()
当批处理环境
render()
返回单个图像时,也可以使用此转换示例
>>> from torchrl.envs import check_env_specs >>> from torchrl.envs.libs.vmas import VmasEnv >>> from torchrl.record.loggers import CSVLogger >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder >>> >>> env = VmasEnv( ... scenario="flocking", ... num_envs=32, ... continuous_actions=True, ... max_steps=200, ... device="cpu", ... seed=None, ... # Scenario kwargs ... n_agents=5, ... ) >>> >>> logger = CSVLogger("dummy", video_format="mp4") >>> >>> env = env.append_transform(PixelRenderTransform(mode="rgb_array", preproc=lambda x: x.copy())) >>> env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record")) >>> >>> check_env_specs(env) >>> >>> r = env.rollout(30) >>> env.transform[-1].dump()
可以使用
switch()
方法禁用该转换,该方法将打开渲染(如果其关闭)或关闭渲染(如果其打开)(也可以传递一个参数来控制此行为)。由于转换是Module
实例,因此可以使用apply()
来控制此行为。>>> def switch(module): ... if isinstance(module, PixelRenderTransform): ... module.switch() >>> env.apply(switch)