快捷方式

TorchRL 配置系统

TorchRL 提供了一个强大的配置系统,该系统构建在 Hydra 之上,使您可以轻松配置和运行强化学习实验。该系统使用基于数据类的结构化配置,这些配置可以进行组合、覆盖和扩展。

使用配置系统的优点包括: - 快速轻松上手:提供您的任务,让系统处理其余部分 - 一次性概览可用选项及其默认值:python sota-implementations/ppo_trainer/train.py --help 将显示所有可用选项及其默认值 - 易于覆盖和扩展:您可以覆盖配置文件中的任何选项,也可以使用自己的自定义配置扩展配置文件 - 易于共享和复现:您可以与他人共享配置文件,他们只需运行相同的命令即可复现您的结果。 - 易于版本控制:您可以轻松地对配置文件进行版本控制。

快速入门示例

让我们从一个创建 Gym 环境的简单示例开始。这是一个最小的配置文件

# config.yaml
defaults:
  - env@training_env: gym

training_env:
  env_name: CartPole-v1

此配置有两个主要部分

1. defaults **部分**

defaults 部分告诉 Hydra 要包含哪些配置组。在这种情况下

  • env@training_env: gym 意味着“使用 'env' 组中的 'gym' 配置来作为 'training_env' 目标”

这等同于包含一个预定义的 Gym 环境配置,该配置设置了正确的 target 类和默认参数。

2. 配置覆盖

training_env 部分允许您覆盖或指定所选配置的参数

  • env_name: CartPole-v1 设置了特定的环境名称

配置类别和组

TorchRL 使用 @ 语法将配置组织成多个类别,以实现目标化配置

  • env@<target>:环境配置(Gym、DMControl、Brax 等)以及批处理环境

  • transform@<target>:转换配置(观察/奖励处理)

  • model@<target>:模型配置(策略和价值网络)

  • network@<target>:神经网络配置(MLP、ConvNet)

  • collector@<target>:数据收集配置

  • replay_buffer@<target>:回放缓冲区配置

  • storage@<target>:存储后端配置

  • sampler@<target>:采样策略配置

  • writer@<target>:写入器策略配置

  • trainer@<target>:训练循环配置

  • optimizer@<target>:优化器配置

  • loss@<target>:损失函数配置

  • logger@<target>:日志记录配置

@<target> 语法允许您将配置分配到配置结构中的特定位置。

更复杂的示例:带转换的并行环境

这是一个更复杂的示例,它创建了一个并行环境,并为每个工作进程应用多个转换

defaults:
  - env@training_env: batched_env
  - env@training_env.create_env_fn: transformed_env
  - env@training_env.create_env_fn.base_env: gym
  - transform@training_env.create_env_fn.transform: compose
  - transform@transform0: noop_reset
  - transform@transform1: step_counter

# Transform configurations
transform0:
  noops: 30
  random: true

transform1:
  max_steps: 200
  step_count_key: "step_count"

# Environment configuration
training_env:
  num_workers: 4
  create_env_fn:
    base_env:
      env_name: Pendulum-v1
    transform:
      transforms:
        - ${transform0}
        - ${transform1}
    _partial_: true

此配置创建的内容

此配置构建了一个**具有 4 个工作进程的并行环境**,其中每个工作进程运行一个**应用了两个转换的 Pendulum-v1 环境**

  1. 并行环境结构: - batched_env 创建一个运行多个环境实例的并行环境 - num_workers: 4 表示 4 个并行环境进程

  2. 单个环境构建(为 4 个工作进程中的每个进程重复): - **基本环境**:gym 配合 env_name: Pendulum-v1 创建一个 Pendulum 环境 - **转换层 1**:noop_reset 在每集开始时执行 30 次随机 no-op 动作 - **转换层 2**:step_counter 将每集限制为 200 步并跟踪步数 - **转换组合**:compose 将两个转换组合成一个单一的转换

  3. 最终结果:4 个并行的 Pendulum 环境,每个环境具有: - 随机 no-op 重置(开始时 0-30 次动作) - 最大每集 200 步 - 步数计数功能

关键配置概念

  1. 嵌套目标env@training_env.create_env_fn.base_env: gym 将 gym 配置深度放置在结构中

  2. 函数工厂_partial_: true 创建一个可以调用多次(每个工作进程一次)的函数

  3. 转换组合:多个转换被组合并应用于每个环境实例

  4. 变量插值${transform0}${transform1} 引用单独定义的转换配置

获取可用选项

要探索所有可用的配置及其参数,可以使用 --help 标志与任何 TorchRL 脚本结合使用

python sota-implementations/ppo_trainer/train.py --help

这将显示所有配置组及其选项,方便您发现可用的内容。它应该会打印出类似如下的内容


完整的训练示例

这是一个用于 PPO 训练的完整配置

defaults:
  - env@training_env: batched_env
  - env@training_env.create_env_fn: gym
  - model@models.policy_model: tanh_normal
  - model@models.value_model: value
  - network@networks.policy_network: mlp
  - network@networks.value_network: mlp
  - collector: sync
  - replay_buffer: base
  - storage: tensor
  - sampler: without_replacement
  - writer: round_robin
  - trainer: ppo
  - optimizer: adam
  - loss: ppo
  - logger: wandb

# Network configurations
networks:
  policy_network:
    out_features: 2
    in_features: 4
    num_cells: [128, 128]

  value_network:
    out_features: 1
    in_features: 4
    num_calls: [128, 128]

# Model configurations
models:
  policy_model:
    network: ${networks.policy_network}
    in_keys: ["observation"]
    out_keys: ["action"]

  value_model:
    network: ${networks.value_network}
    in_keys: ["observation"]
    out_keys: ["state_value"]

# Environment
training_env:
  num_workers: 2
  create_env_fn:
    env_name: CartPole-v1
    _partial_: true

# Training components
trainer:
  collector: ${collector}
  optimizer: ${optimizer}
  loss_module: ${loss}
  logger: ${logger}
  total_frames: 100000

collector:
  create_env_fn: ${training_env}
  policy: ${models.policy_model}
  frames_per_batch: 1024

optimizer:
  lr: 0.001

loss:
  actor_network: ${models.policy_model}
  critic_network: ${models.value_model}

logger:
  exp_name: my_experiment

运行实验

基本用法

# Use default configuration
python sota-implementations/ppo_trainer/train.py

# Override specific parameters
python sota-implementations/ppo_trainer/train.py optimizer.lr=0.0001

# Change environment
python sota-implementations/ppo_trainer/train.py training_env.create_env_fn.env_name=Pendulum-v1

# Use different collector
python sota-implementations/ppo_trainer/train.py collector=async

超参数搜索

# Sweep over learning rates
python sota-implementations/ppo_trainer/train.py --multirun optimizer.lr=0.0001,0.001,0.01

# Multiple parameter sweep
python sota-implementations/ppo_trainer/train.py --multirun \
  optimizer.lr=0.0001,0.001 \
  training_env.num_workers=2,4,8

自定义配置文件

# Use custom config file
python sota-implementations/ppo_trainer/train.py --config-name my_custom_config

配置存储实现细节

在底层,TorchRL 使用 Hydra 的 ConfigStore 来注册所有配置类。这提供了类型安全、验证和 IDE 支持。当您导入 configs 模块时,注册会自动发生。

from hydra.core.config_store import ConfigStore
from torchrl.trainers.algorithms.configs import *

cs = ConfigStore.instance()

# Environments
cs.store(group="env", name="gym", node=GymEnvConfig)
cs.store(group="env", name="batched_env", node=BatchedEnvConfig)

# Models
cs.store(group="model", name="tanh_normal", node=TanhNormalModelConfig)
# ... and many more

可用的配置类

基类

ConfigBase()

所有配置类的抽象基类。

环境配置

EnvConfig([_partial_])

环境的基类配置。

BatchedEnvConfig(_partial_, create_env_fn, ...)

批处理环境的配置。

TransformedEnvConfig([_partial_, base_env, ...])

转换环境的配置。

环境库配置

EnvLibsConfig([_partial_])

环境库的基类配置。

GymEnvConfig([_partial_, env_name, ...])

GymEnv 环境的配置。

DMControlEnvConfig([_partial_, env_name, ...])

DMControlEnv 环境的配置。

BraxEnvConfig([_partial_, env_name, ...])

BraxEnv 环境的配置。

HabitatEnvConfig([_partial_, env_name, ...])

HabitatEnv 环境的配置。

IsaacGymEnvConfig([_partial_, env_name, ...])

IsaacGymEnv 环境的配置。

JumanjiEnvConfig([_partial_, env_name, ...])

JumanjiEnv 环境的配置。

MeltingpotEnvConfig([_partial_, env_name, ...])

MeltingpotEnv 环境的配置。

MOGymEnvConfig([_partial_, env_name, ...])

MOGymEnv 环境的配置。

MultiThreadedEnvConfig([_partial_, ...])

MultiThreadedEnv 环境的配置。

OpenMLEnvConfig([_partial_, env_name, ...])

OpenMLEnv 环境的配置。

OpenSpielEnvConfig([_partial_, env_name, ...])

OpenSpielEnv 环境的配置。

PettingZooEnvConfig([_partial_, env_name, ...])

PettingZooEnv 环境的配置。

RoboHiveEnvConfig([_partial_, env_name, ...])

RoboHiveEnv 环境的配置。

SMACv2EnvConfig([_partial_, env_name, ...])

SMACv2Env 环境的配置。

UnityMLAgentsEnvConfig([_partial_, ...])

UnityMLAgentsEnv 环境的配置。

VmasEnvConfig([_partial_, env_name, ...])

VmasEnv 环境的配置。

模型和网络配置

ModelConfig([_partial_, in_keys, out_keys])

配置模型的父类。

NetworkConfig([_partial_])

配置网络的父类。

MLPConfig(_partial_, in_features, ...)

配置多层感知机的类。

ConvNetConfig(_partial_, in_features, depth, ...)

配置卷积网络的类。

TensorDictModuleConfig([_partial_, in_keys, ...])

配置 TensorDictModule 的类。

TanhNormalModelConfig([_partial_, in_keys, ...])

配置 TanhNormal 模型的类。

ValueModelConfig([_partial_, in_keys, ...])

配置价值模型的类。

转换配置

TransformConfig()

转换的基类配置。

ComposeConfig([transforms, _target_])

Compose 转换的配置。

NoopResetEnvConfig([noops, random, _target_])

NoopResetEnv 转换的配置。

StepCounterConfig([max_steps, ...])

StepCounter 转换的配置。

DoubleToFloatConfig([in_keys, out_keys, ...])

DoubleToFloat 转换的配置。

ToTensorImageConfig([from_int, unsqueeze, ...])

ToTensorImage 转换的配置。

ClipTransformConfig([in_keys, out_keys, ...])

ClipTransform 的配置。

ResizeConfig([w, h, interpolation, in_keys, ...])

Resize 转换的配置。

CenterCropConfig([height, width, in_keys, ...])

CenterCrop 转换的配置。

CropConfig([top, left, height, width, ...])

Crop 转换的配置。

FlattenObservationConfig([in_keys, ...])

FlattenObservation 转换的配置。

GrayScaleConfig([in_keys, out_keys, _target_])

GrayScale 转换的配置。

ObservationNormConfig([loc, scale, in_keys, ...])

ObservationNorm 转换的配置。

CatFramesConfig([N, dim, in_keys, out_keys, ...])

CatFrames 转换的配置。

RewardClippingConfig([clamp_min, clamp_max, ...])

RewardClipping 转换的配置。

RewardScalingConfig([loc, scale, in_keys, ...])

RewardScaling 转换的配置。

BinarizeRewardConfig([in_keys, out_keys, ...])

BinarizeReward 转换的配置。

TargetReturnConfig([target_return, mode, ...])

TargetReturn 转换的配置。

VecNormConfig([in_keys, out_keys, decay, ...])

VecNorm 转换的配置。

FrameSkipTransformConfig([frame_skip, ...])

FrameSkipTransform 的配置。

DeviceCastTransformConfig([device, in_keys, ...])

DeviceCastTransform 的配置。

DTypeCastTransformConfig([dtype, in_keys, ...])

DTypeCastTransform 的配置。

UnsqueezeTransformConfig([dim, in_keys, ...])

UnsqueezeTransform 的配置。

SqueezeTransformConfig([dim, in_keys, ...])

SqueezeTransform 的配置。

PermuteTransformConfig([dims, in_keys, ...])

PermuteTransform 的配置。

CatTensorsConfig([dim, in_keys, out_keys, ...])

CatTensors 转换的配置。

StackConfig([dim, in_keys, out_keys, _target_])

Stack 转换的配置。

DiscreteActionProjectionConfig([...])

DiscreteActionProjection 转换的配置。

TensorDictPrimerConfig([primer_spec, ...])

TensorDictPrimer 转换的配置。

PinMemoryTransformConfig([in_keys, ...])

PinMemoryTransform 的配置。

RewardSumConfig([in_keys, out_keys, _target_])

RewardSum 转换的配置。

ExcludeTransformConfig([exclude_keys, _target_])

ExcludeTransform 的配置。

SelectTransformConfig([include_keys, _target_])

SelectTransform 的配置。

TimeMaxPoolConfig([dim, in_keys, out_keys, ...])

TimeMaxPool 转换的配置。

RandomCropTensorDictConfig([crop_size, ...])

RandomCropTensorDict 转换的配置。

InitTrackerConfig([in_keys, out_keys, _target_])

InitTracker 转换的配置。

RenameTransformConfig([key_mapping, _target_])

RenameTransform 的配置。

Reward2GoTransformConfig([gamma, in_keys, ...])

Reward2GoTransform 的配置。

ActionMaskConfig([mask_key, in_keys, ...])

ActionMask 转换的配置。

VecGymEnvTransformConfig([in_keys, ...])

VecGymEnvTransform 的配置。

BurnInTransformConfig([burn_in, in_keys, ...])

BurnInTransform 的配置。

SignTransformConfig([in_keys, out_keys, ...])

SignTransform 的配置。

RemoveEmptySpecsConfig([_target_])

RemoveEmptySpecs 转换的配置。

BatchSizeTransformConfig([batch_size, ...])

BatchSizeTransform 的配置。

AutoResetTransformConfig([replace, ...])

AutoResetTransform 的配置。

ActionDiscretizerConfig([num_intervals, ...])

ActionDiscretizer 转换的配置。

TrajCounterConfig([out_key, repeats, _target_])

TrajCounter 转换的配置。

LineariseRewardsConfig([in_keys, out_keys, ...])

LineariseRewards 转换的配置。

ConditionalSkipConfig([cond, _target_])

ConditionalSkip 转换的配置。

MultiActionConfig([dim, stack_rewards, ...])

MultiAction 转换的配置。

TimerConfig([out_keys, time_key, _target_])

Timer 转换的配置。

ConditionalPolicySwitchConfig([policy, ...])

ConditionalPolicySwitch 转换的配置。

FiniteTensorDictCheckConfig([in_keys, ...])

FiniteTensorDictCheck 转换的配置。

UnaryTransformConfig([fn, in_keys, ...])

UnaryTransform 的配置。

HashConfig([in_keys, out_keys, _target_])

Hash 转换的配置。

TokenizerConfig([vocab_size, in_keys, ...])

Tokenizer 转换的配置。

EndOfLifeTransformConfig([eol_key, ...])

EndOfLifeTransform 的配置。

MultiStepTransformConfig([n_steps, gamma, ...])

MultiStepTransform 的配置。

KLRewardTransformConfig([in_keys, out_keys, ...])

KLRewardTransform 的配置。

R3MTransformConfig([in_keys, out_keys, ...])

R3MTransform 的配置。

VC1TransformConfig([in_keys, out_keys, ...])

VC1Transform 的配置。

VIPTransformConfig([in_keys, out_keys, ...])

VIPTransform 的配置。

VIPRewardTransformConfig([in_keys, ...])

VIPRewardTransform 的配置。

VecNormV2Config([in_keys, out_keys, decay, ...])

VecNormV2 转换的配置。

数据收集配置

DataCollectorConfig()

配置数据收集器的父类。

SyncDataCollectorConfig([create_env_fn, ...])

配置同步数据收集器的类。

AsyncDataCollectorConfig(create_env_fn, ...)

异步数据收集器的配置。

MultiSyncDataCollectorConfig([...])

多同步数据收集器的配置。

MultiaSyncDataCollectorConfig([...])

多异步数据收集器的配置。

回放缓冲区和存储配置

ReplayBufferConfig([_partial_, _target_, ...])

通用回放缓冲区的配置。

TensorDictReplayBufferConfig([_partial_, ...])

基于 TensorDict 的回放缓冲区的配置。

RandomSamplerConfig([_target_])

从回放缓冲区进行随机采样的配置。

SamplerWithoutReplacementConfig([_target_, ...])

无替换采样配置。

PrioritizedSamplerConfig([_target_, ...])

从回放缓冲区进行优先采样的配置。

SliceSamplerConfig([_target_, num_slices, ...])

从回放缓冲区进行切片采样的配置。

SliceSamplerWithoutReplacementConfig([...])

无替换切片采样的配置。

ListStorageConfig([_partial_, _target_, ...])

回放缓冲区中基于列表的存储配置。

TensorStorageConfig([_partial_, _target_, ...])

回放缓冲区中基于张量的存储配置。

LazyTensorStorageConfig([_partial_, ...])

延迟张量存储配置。

LazyMemmapStorageConfig([_partial_, ...])

延迟内存映射存储配置。

LazyStackStorageConfig([_partial_, ...])

延迟堆叠存储配置。

StorageEnsembleConfig([_partial_, _target_, ...])

存储集合的配置。

RoundRobinWriterConfig([_target_, compilable])

循环写入器的配置,它将数据分发到多个存储中。

StorageEnsembleWriterConfig([_partial_, ...])

存储集合写入器的配置。

训练和优化配置

TrainerConfig()

训练器的基类配置。

PPOTrainerConfig(collector, total_frames, ...)

PPO(近端策略优化)训练器的配置类。

LossConfig([_partial_])

配置损失的类。

PPOLossConfig([_partial_, actor_network, ...])

配置 PPO 损失的类。

AdamConfig([lr, betas, eps, weight_decay, ...])

Adam 优化器的配置。

AdamWConfig([lr, betas, eps, weight_decay, ...])

AdamW 优化器的配置。

AdamaxConfig([lr, betas, eps, weight_decay, ...])

Adamax 优化器的配置。

AdadeltaConfig([lr, rho, eps, weight_decay, ...])

Adadelta 优化器的配置。

AdagradConfig([lr, lr_decay, weight_decay, ...])

Adagrad 优化器的配置。

ASGDConfig([lr, lambd, alpha, t0, ...])

ASGD 优化器的配置。

LBFGSConfig([lr, max_iter, max_eval, ...])

LBFGS 优化器的配置。

LionConfig([lr, betas, weight_decay, ...])

Lion 优化器的配置。

NAdamConfig([lr, betas, eps, weight_decay, ...])

NAdam 优化器的配置。

RAdamConfig([lr, betas, eps, weight_decay, ...])

RAdam 优化器的配置。

RMSpropConfig([lr, alpha, eps, ...])

RMSprop 优化器的配置。

RpropConfig([lr, etas, step_sizes, foreach, ...])

Rprop 优化器的配置。

SGDConfig([lr, momentum, dampening, ...])

SGD 优化器的配置。

SparseAdamConfig([lr, betas, eps, _target_, ...])

SparseAdam 优化器的配置。

日志记录配置

LoggerConfig()

配置日志记录器的类。

WandbLoggerConfig(exp_name[, offline, ...])

配置 Wandb 日志记录器的类。

TensorboardLoggerConfig(exp_name[, log_dir, ...])

配置 Tensorboard 日志记录器的类。

CSVLoggerConfig(exp_name[, log_dir, ...])

配置 CSV 日志记录器的类。

创建自定义配置

您可以通过继承相应的基类来创建自定义配置类

from dataclasses import dataclass
from torchrl.trainers.algorithms.configs.envs_libs import EnvLibsConfig

@dataclass
class MyCustomEnvConfig(EnvLibsConfig):
    _target_: str = "my_module.MyCustomEnv"
    env_name: str = "MyEnv-v1"
    custom_param: float = 1.0

    def __post_init__(self):
        super().__post_init__()

# Register with ConfigStore
from hydra.core.config_store import ConfigStore
cs = ConfigStore.instance()
cs.store(group="env", name="my_custom", node=MyCustomEnvConfig)

最佳实践

  1. 从小处着手:从基本配置开始,然后逐渐增加复杂性

  2. 使用默认值:利用 defaults 部分来组合配置

  3. 谨慎覆盖:只覆盖您需要更改的部分

  4. 验证配置:测试您的配置是否能正确实例化

  5. 版本控制:将您的配置文件保留在版本控制下

  6. 使用变量插值:使用 ${variable} 语法来避免重复

未来扩展

随着 TorchRL 添加更多算法(如 SAC、TD3、DQN),配置系统将扩展,包含

  • 新的训练器配置(例如,SACTrainerConfigTD3TrainerConfig

  • 特定于算法的损失配置

  • 针对不同算法的专用收集器配置

  • 附加的环境和模型配置

模块化设计可确保轻松集成,同时保持向后兼容性。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源