ConsistentDropoutModule¶
- class torchrl.modules.ConsistentDropoutModule(*args, **kwargs)[源代码]¶
用于
ConsistentDropout
的 TensorDictModule 包装器。- 参数:
p (
float
, optional) – Dropout 概率。默认为0.5
。in_keys (NestedKey 或 list of NestedKeys) – 将从输入 tensordict 读取并传递给此模块的键。
out_keys (NestedKey 或 iterable of NestedKeys) – 将写入输入 tensordict 的键。默认为
in_keys
值。
- 关键字参数:
input_shape (tuple, optional) – 输入(非批处理)的形状,用于通过
make_tensordict_primer()
生成 tensordict primer。input_dtype (torch.dtype, optional) – primer 的输入数据类型。如果未提供,则假定为
torch.get_default_dtype
。
注意
要在策略中使用此类,需要在重置时重置掩码。这可以通过
TensorDictPrimer
转换来实现,该转换可以通过make_tensordict_primer()
获取。有关更多信息,请参阅该方法。示例
>>> from tensordict import TensorDict >>> module = ConsistentDropoutModule(p = 0.1) >>> td = TensorDict({"x": torch.randn(3, 4)}, [3]) >>> module(td) TensorDict( fields={ mask_6127171760: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.bool, is_shared=False), x: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)
- forward(tensordict)[源代码]¶
定义每次调用时执行的计算。
所有子类都应重写此方法。
注意
尽管前向传播的实现需要在此函数中定义,但您应该在之后调用
Module
实例而不是此函数,因为前者会处理注册的钩子,而后者则会静默忽略它们。
- make_tensordict_primer()[源代码]¶
创建一个 tensordict primer,供环境在重置调用期间生成随机掩码。
另请参阅
torchrl.modules.utils.get_primers_from_module()
用于生成给定模块的所有 primer 的方法。模块。
示例
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> from torchrl.envs import GymEnv, StepCounter, SerialEnv >>> m = Seq( ... Mod(torch.nn.Linear(7, 4), in_keys=["observation"], out_keys=["intermediate"]), ... ConsistentDropoutModule( ... p=0.5, ... input_shape=(2, 4), ... in_keys="intermediate", ... ), ... Mod(torch.nn.Linear(4, 7), in_keys=["intermediate"], out_keys=["action"]), ... ) >>> primer = get_primers_from_module(m) >>> env0 = GymEnv("Pendulum-v1").append_transform(StepCounter(5)) >>> env1 = GymEnv("Pendulum-v1").append_transform(StepCounter(6)) >>> env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env]) >>> env = env.append_transform(primer) >>> r = env.rollout(10, m, break_when_any_done=False) >>> mask = [k for k in r.keys() if k.startswith("mask")][0] >>> assert (r[mask][0, :5] != r[mask][0, 5:6]).any() >>> assert (r[mask][0, :4] == r[mask][0, 4:5]).all()