from_modules¶
- class tensordict.from_modules(*modules, as_module: bool = False, lock: bool = True, use_state_dict: bool = False, lazy_stack: bool = False, expand_identical: bool = False)¶
为 vmap 的 ensemable 学习/特征期望应用检索多个模块的参数。
- 参数:
modules (sequence of nn.Module) – 要从中获取参数的模块。如果模块的结构不同,则需要一个惰性堆栈(参见下面的 `lazy_stack` 参数)。
- 关键字参数:
as_module (bool, optional) – 如果为
True
,将返回一个TensorDictParams
实例,可用于将参数存储在torch.nn.Module
中。默认为False
。lock (bool, optional) – 如果为
True
,则结果 tensordict 将被锁定。默认为True
。use_state_dict (bool, optional) –
如果为
True
,将使用模块的状态字典,并将其解压成具有模型树结构的 TensorDict。默认为False
。注意
这在使用 state-dict hook 时尤其有用。
lazy_stack (bool, optional) –
是否密集堆叠或懒惰堆叠参数。默认为
False
(密集堆叠)。注意
lazy_stack
和as_module
是互斥的特性。警告
惰性输出和非惰性输出之间有一个关键区别:非惰性输出将使用所需的批次大小重新实例化参数,而
lazy_stack
将仅表示被惰性堆叠的参数。这意味着,当 `lazy_stack=True` 时,原始参数可以安全地传递给优化器,而在 `lazy_stack=True` 时,需要传递新参数。警告
虽然为了保留原始参数引用而使用惰性堆栈可能很诱人,但请记住,每次调用 `get()` 时,惰性堆栈都会执行堆叠操作。这会消耗内存(参数大小的 N 倍,如果构建了图则更多)和计算时间。它还意味着优化器将包含更多参数,并且像 `step()` 或 `zero_grad()` 这样的操作将花费更长的时间来执行。总的来说,`lazy_stack` 应仅限于极少数用例。
expand_identical (bool, optional) – 如果为
True
并且正在将相同的参数(相同的标识)堆叠到自身,则将返回该参数的扩展版本。当 `lazy_stack=True` 时,将忽略此参数。
示例
>>> from torch import nn >>> from tensordict import from_modules >>> torch.manual_seed(0) >>> empty_module = nn.Linear(3, 4, device="meta") >>> n_models = 2 >>> modules = [nn.Linear(3, 4) for _ in range(n_models)] >>> params = from_modules(*modules) >>> print(params) TensorDict( fields={ bias: Parameter(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Parameter(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2]), device=None, is_shared=False) >>> # example of batch execution >>> def exec_module(params, x): ... with params.to_module(empty_module): ... return empty_module(x) >>> x = torch.randn(3) >>> y = torch.vmap(exec_module, (0, None))(params, x) >>> assert y.shape == (n_models, 4) >>> # since lazy_stack = False, backprop leaves the original params untouched >>> y.sum().backward() >>> assert params["weight"].grad.norm() > 0 >>> assert modules[0].weight.grad is None
当
lazy_stack=True
时,情况略有不同>>> params = TensorDict.from_modules(*modules, lazy_stack=True) >>> print(params) LazyStackedTensorDict( fields={ bias: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Tensor(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, exclusive_fields={ }, batch_size=torch.Size([2]), device=None, is_shared=False, stack_dim=0) >>> # example of batch execution >>> y = torch.vmap(exec_module, (0, None))(params, x) >>> assert y.shape == (n_models, 4) >>> y.sum().backward() >>> assert modules[0].weight.grad is not None