快捷方式

LogValidationReward

class torchrl.trainers.LogValidationReward(*, record_interval: int, record_frames: int, frame_skip: int = 1, policy_exploration: TensorDictModule, environment: EnvBase = None, exploration_type: ExplorationType = InteractionType.RANDOM, log_keys: list[str | tuple[str]] | None = None, out_keys: dict[str | tuple[str], str] | None = None, suffix: str | None = None, log_pbar: bool = False, recorder: EnvBase = None)[源代码]

Recorder hook for Trainer

参数:
  • record_interval (int) – testing 时两次调用 recorder 的总优化步数。

  • record_frames (int) – testing 时要记录的帧数。

  • frame_skip (int) – 环境中使用的 frame_skip。让 trainer 知道每次迭代跳过的帧数很重要,否则帧数可能会被低估。最后,为了比较具有不同 frame_skip 的不同运行,必须对帧数和奖励进行归一化。默认为 1

  • policy_exploration (ProbabilisticTDModule) –

    用于

    1. 更新探索噪声计划;

    2. 在 recorder 上测试策略的策略实例。

    鉴于此实例应同时用于探索和渲染策略的性能,应通过调用 set_exploration_type(ExplorationType.DETERMINISTIC) 上下文管理器来关闭探索行为。

  • environment (EnvBase) – 用于 testing 的环境实例。

  • exploration_type (ExplorationType, optional) – 用于策略的探索模式。默认情况下,不使用探索,并且使用的值为 ExplorationType.DETERMINISTIC。设置为 ExplorationType.RANDOM 以启用探索

  • log_keys (sequence of str or tuples or str, optional) – 要在 tensordict 中读取以进行记录的键。默认为 [("next", "reward")]

  • out_keys (Dict[str, str], optional) – 一个将 log_keys 映射到其在日志中的名称的字典。默认为 {("next", "reward"): "r_evaluation"}

  • suffix (str, optional) – 要录制的视频的后缀。

  • log_pbar (bool, optional) – 如果为 True,则奖励值将在进度条上记录。默认为 False

register(trainer: Trainer, name: str = 'recorder')[源代码]

Registers the hook in the trainer at a default location.

参数:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

注意

To register the hook at another location than the default, use register_op().

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源