LogValidationReward¶
- class torchrl.trainers.LogValidationReward(*, record_interval: int, record_frames: int, frame_skip: int = 1, policy_exploration: TensorDictModule, environment: EnvBase = None, exploration_type: ExplorationType = InteractionType.RANDOM, log_keys: list[str | tuple[str]] | None = None, out_keys: dict[str | tuple[str], str] | None = None, suffix: str | None = None, log_pbar: bool = False, recorder: EnvBase = None)[源代码]¶
Recorder hook for
Trainer
。- 参数:
record_interval (int) – testing 时两次调用 recorder 的总优化步数。
record_frames (int) – testing 时要记录的帧数。
frame_skip (int) – 环境中使用的 frame_skip。让 trainer 知道每次迭代跳过的帧数很重要,否则帧数可能会被低估。最后,为了比较具有不同 frame_skip 的不同运行,必须对帧数和奖励进行归一化。默认为
1
。policy_exploration (ProbabilisticTDModule) –
用于
更新探索噪声计划;
在 recorder 上测试策略的策略实例。
鉴于此实例应同时用于探索和渲染策略的性能,应通过调用 set_exploration_type(ExplorationType.DETERMINISTIC) 上下文管理器来关闭探索行为。
environment (EnvBase) – 用于 testing 的环境实例。
exploration_type (ExplorationType, optional) – 用于策略的探索模式。默认情况下,不使用探索,并且使用的值为
ExplorationType.DETERMINISTIC
。设置为ExplorationType.RANDOM
以启用探索log_keys (sequence of str or tuples or str, optional) – 要在 tensordict 中读取以进行记录的键。默认为
[("next", "reward")]
。out_keys (Dict[str, str], optional) – 一个将
log_keys
映射到其在日志中的名称的字典。默认为{("next", "reward"): "r_evaluation"}
。suffix (str, optional) – 要录制的视频的后缀。
log_pbar (bool, optional) – 如果为
True
,则奖励值将在进度条上记录。默认为 False。
- register(trainer: Trainer, name: str = 'recorder')[源代码]¶
Registers the hook in the trainer at a default location.
- 参数:
trainer (Trainer) – the trainer where the hook must be registered.
name (str) – the name of the hook.
注意
To register the hook at another location than the default, use
register_op()
.