tensordict.nn.EnsembleModule
- class tensordict.nn.EnsembleModule(*args, **kwargs)
Module that wraps a module and repeats it to form an ensemble.
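- Parameters:
module (nn.Module) – The nn.Module to duplicate and wrap.
num_copies (int) – The number of copies of the module to make.
parameter_init_function (Callable) – A function that takes a module copy and initializes its parameters.
expand_input (bool) – Whether to expand the input TensorDict to match the number of copies. This should be True unless you are chaining ensemble modules together, e.g. EnsembleModule(cnn) -> EnsembleModule(mlp). If False, EnsembleModule(mlp) expects the preceding module(s) to have already expanded the input.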
Examples
>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictModule, EnsembleModule
>>> from tensordict import TensorDict
>>> net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
>>> mod = TensorDictModule(net, in_keys=['a'], out_keys=['b'])
>>> ensemble = EnsembleModule(mod, num_copies=3)
>>> data = TensorDict({'a': torch.randn(10, 4)}, batch_size=[10])
>>> ensemble(data)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 10, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 10]),
    device=None,
    is_shared=False)
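The leading dimension of size 3 in the output above comes from evaluating every copy on the same input. As a rough illustration of that idea, the following sketch builds an ensemble with plain PyTorch, using torch.func.stack_module_state and torch.vmap. It does not use EnsembleModule, and it is not claimed to be how EnsembleModule is implemented; the names models, template and call_single are hypothetical.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state

torch.manual_seed(0)
num_copies = 3

# Independently initialised copies of the same architecture (illustrative only).
models = [
    nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
    for _ in range(num_copies)
]

# Stack every parameter along a new leading dimension of size num_copies.
params, buffers = stack_module_state(models)

# A template module whose parameters are swapped in at call time.
template = copy.deepcopy(models[0])


def call_single(p, b, x):
    # Run the template with one copy's parameters and buffers.
    return functional_call(template, (p, b), (x,))


x = torch.randn(10, 4)

# in_dims=(0, 0, None): map over the stacked parameters, share the input.
out = torch.vmap(call_single, in_dims=(0, 0, None))(params, buffers, x)
print(out.shape)
```

Each slice out[i] matches what models[i](x) would return on its own, so the result has one leading entry per copy followed by the original batch dimension.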
To chain EnsembleModules one after another, set expand_input to False from the second module onward, because the first module has already expanded the input.
Examples
>>> import torch
>>> from tensordict.nn import TensorDictModule, TensorDictSequential, EnsembleModule
>>> from tensordict import TensorDict
>>> module = TensorDictModule(torch.nn.Linear(2,3), in_keys=['bork'], out_keys=['dork'])
>>> next_module = TensorDictModule(torch.nn.Linear(3,1), in_keys=['dork'], out_keys=['spork'])
>>> e0 = EnsembleModule(module, num_copies=4, expand_input=True)
>>> e1 = EnsembleModule(next_module, num_copies=4, expand_input=False)
>>> seq = TensorDictSequential(e0, e1)
>>> data = TensorDict({'bork': torch.randn(5,2)}, batch_size=[5])
>>> seq(data)
TensorDict(
    fields={
        bork: Tensor(shape=torch.Size([4, 5, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        dork: Tensor(shape=torch.Size([4, 5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        spork: Tensor(shape=torch.Size([4, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 5]),
    device=None,
    is_shared=False)
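The difference between the two expand_input settings can be pictured as the difference between sharing an input across copies and consuming an input that already has an ensemble dimension. The sketch below mimics the shapes of the example above with plain torch.vmap. The weight tensors w0 and w1 and the helper linear are hypothetical stand-ins, and the mapping to expand_input is an assumption made for illustration rather than a statement about the library's internals.

```python
import torch

torch.manual_seed(0)
num_copies = 4

# Hypothetical stacked weights for two bias-free linear layers (2 -> 3 -> 1),
# one weight matrix per ensemble copy.
w0 = torch.randn(num_copies, 3, 2)
w1 = torch.randn(num_copies, 1, 3)


def linear(w, x):
    # Per-copy linear map: x has shape [batch, in], w has shape [out, in].
    return x @ w.t()


x = torch.randn(5, 2)  # no ensemble dimension yet

# First stage: the input is shared by all copies (analogous to expand_input=True).
h = torch.vmap(linear, in_dims=(0, None))(w0, x)

# Second stage: the input already has a leading ensemble dimension, so it is
# mapped alongside the weights (analogous to expand_input=False).
y = torch.vmap(linear, in_dims=(0, 0))(w1, h)

print(h.shape, y.shape)
```

If the second stage also shared its input across copies, every copy would process all four slices of h and the output would gain an extra ensemble dimension, which is the situation expand_input=False is meant to avoid.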