快捷方式

AddThinkingPrompt

class torchrl.envs.llm.transforms.AddThinkingPrompt(cond: Callable[[TensorDictBase], bool], prompt: str | None = None, random_prompt: bool = False, role: Literal['user', 'assistant'] = 'assistant', edit_last_turn: bool = True, zero_reward: bool | None = None, undo_done: bool = True, egocentric: bool | None = None)[源码]

A transform that adds thinking prompts to encourage the LLM to reconsider its response. (一个添加思考提示以鼓励 LLM 重新考虑其响应的转换。)

This transform can either add a new thinking prompt as a separate message or edit the last assistant response to include a thinking prompt before the final answer. This is useful for training LLMs to self-correct and think more carefully when their initial responses are incorrect or incomplete. (此转换可以添加一个新的思考提示作为单独的消息,也可以编辑最后一个助理解释以在最终答案之前包含思考提示。这对于训练 LLM 在初始响应不正确或不完整时进行自我纠正和更仔细地思考非常有用。)

参数:
  • cond (Callable[[TensorDictBase], bool], optional) – Condition function that determines when to add the thinking prompt. Takes a tensordict and returns True if the prompt should be added. (条件函数,用于确定何时添加思考提示。接收一个 tensordict 并返回 True,如果应添加提示。)

  • prompt (str, optional) – The thinking prompt to add. If None, a default prompt is used. Defaults to “But wait, let me think about this more carefully…”. (要添加的思考提示。如果为 None,则使用默认提示。默认为 “等等,让我更仔细地考虑一下……”。)

  • random_prompt (bool, optional) – Whether to randomly select from predefined prompts. Defaults to False. (是否从预定义提示中随机选择。默认为 False。)

  • role (Literal["user", "assistant"], optional) – The role for the thinking prompt. If “assistant”, the prompt is added to the assistant’s response. If “user”, it’s added as a separate user message. Defaults to “assistant”. (思考提示的角色。如果为 “assistant”,则提示将添加到助理解释中。如果为 “user”,则将提示添加为单独的用户消息。默认为 “assistant”。)

  • edit_last_turn (bool, optional) – Whether to edit the last assistant response instead of adding a new message. Only works with role=”assistant”. Defaults to True. (是否编辑最后一个助理解释而不是添加新消息。仅在 role=”assistant” 时有效。默认为 True。)

  • zero_reward (bool, optional) – Whether to zero out the reward when the thinking prompt is added. If None, defaults to the value of edit_last_turn. Defaults to the same value as edit_last_turn. (添加思考提示时是否将奖励清零。如果为 None,则默认为 edit_last_turn 的值。默认为与 edit_last_turn 相同的值。)

  • undo_done (bool, optional) – Whether to undo the done flag when the thinking prompt is added. Defaults to True. (添加思考提示时是否撤销 done 标志。默认为 True。)

  • egocentric (bool, optional) – Whether the thinking prompt is written from the perspective of the assistant. Defaults to None, which means that the prompt is written from the perspective of the user if role=”user” and from the perspective of the assistant if role=”assistant”. (思考提示是否从助手的角度编写。默认为 None,这意味着如果 role=”user”,则提示将从用户的角度编写;如果 role=”assistant”,则从助手的角度编写。)

示例

>>> from torchrl.envs.llm.transforms import AddThinkingPrompt
>>> from torchrl.envs.llm import GSM8KEnv
>>> from transformers import AutoTokenizer
>>> import torch
>>>
>>> # Create environment with thinking prompt transform
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
>>> env = GSM8KEnv(tokenizer=tokenizer, max_steps=10)
>>> env = env.append_transform(
...     AddThinkingPrompt(
...         cond=lambda td: td["reward"] < 50,
...         role="assistant",
...         edit_last_turn=True,
...         zero_reward=True,
...         undo_done=True
...     )
... )
>>>
>>> # Test with wrong answer (low reward)
>>> reset = env.reset()
>>> wrong_answer = (
...     "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. "
...     "Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
...     "To find the total, I need to add April and May: 48 + 24 = 72. "
...     "Therefore, Natalia sold 72 clips altogether in April and May.</think>"
...     "<answer>322 clips</answer><|im_end|>"
... )
>>> reset["text_response"] = [wrong_answer]
>>> s = env.step(reset)
>>> assert (s["next", "reward"] == 0).all()  # Reward zeroed
>>> assert (s["next", "done"] == 0).all()    # Done undone
>>> assert s["next", "history"].shape == (1, 3)  # History modified
>>>
>>> # Test with correct answer (high reward)
>>> reset = env.reset()
>>> correct_answer = (
...     "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. "
...     "Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
...     "To find the total, I need to add April and May: 48 + 24 = 72. "
...     "Therefore, Natalia sold 72 clips altogether in April and May.</think>"
...     "<answer>72</answer><|im_end|>"
... )
>>> reset["text_response"] = [correct_answer]
>>> s = env.step(reset)
>>> assert (s["next", "reward"] != 0).all()  # Reward not zeroed
>>> assert s["next", "done"].all()           # Done remains True
>>> assert s["next", "history"].shape == (1, 3)  # History unchanged
add_module(name: str, module: Optional[Module]) None

将子模块添加到当前模块。

可以使用给定的名称作为属性访问该模块。

参数:
  • name (str) – 子模块的名称。子模块可以通过给定名称从此模块访问

  • module (Module) – 要添加到模块中的子模块。

apply(fn: Callable[[Module], None]) Self

fn 递归应用于每个子模块(由 .children() 返回)以及自身。

典型用法包括初始化模型参数(另请参阅 torch.nn.init)。

参数:

fn (Module -> None) – 要应用于每个子模块的函数

返回:

self

返回类型:

模块

示例

>>> @torch.no_grad()
>>> def init_weights(m):
>>>     print(m)
>>>     if type(m) == nn.Linear:
>>>         m.weight.fill_(1.0)
>>>         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
bfloat16() Self

将所有浮点参数和缓冲区转换为 bfloat16 数据类型。

注意

此方法就地修改模块。

返回:

self

返回类型:

模块

buffers(recurse: bool = True) Iterator[Tensor]

返回模块缓冲区的迭代器。

参数:

recurse (bool) – 如果为 True,则会产生此模块及其所有子模块的 buffer。否则,仅会产生此模块的直接成员 buffer。

产生:

torch.Tensor – 模块缓冲区

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for buf in model.buffers():
>>>     print(type(buf), buf.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
children() Iterator[Module]

返回直接子模块的迭代器。

产生:

Module – 子模块

close()

关闭转换。

property collector: DataCollectorBase | None

返回与容器关联的收集器(如果存在)。

每当变换需要了解收集器或与之关联的策略时,都可以使用此属性。请确保仅在未嵌套在子进程中的变换上调用此属性。收集器引用不会传递给 ParallelEnv 或类似的批处理环境的 worker。

请确保仅在未嵌套在子进程中的转换上调用此属性。 Collector 引用不会传递给 ParallelEnv 或类似批量环境的 worker。

compile(*args, **kwargs)

使用 torch.compile() 编译此 Module 的前向传播。

此 Module 的 __call__ 方法将被编译,并且所有参数将按原样传递给 torch.compile()

有关此函数的参数的详细信息,请参阅 torch.compile()

property container: EnvBase | None

返回包含该变换的环境。

示例

>>> from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), Compose(RewardSum(), StepCounter()))
>>> env.transform[0].container is env
True
cpu() Self

将所有模型参数和缓冲区移动到 CPU。

注意

此方法就地修改模块。

返回:

self

返回类型:

模块

cuda(device: Optional[Union[device, int]] = None) Self

将所有模型参数和缓冲区移动到 GPU。

这也会使相关的参数和缓冲区成为不同的对象。因此,如果模块在优化时将驻留在 GPU 上,则应在构建优化器之前调用此函数。

注意

此方法就地修改模块。

参数:

device (int, optional) – 如果指定,所有参数将复制到该设备

返回:

self

返回类型:

模块

double() Self

将所有浮点参数和缓冲区转换为 double 数据类型。

注意

此方法就地修改模块。

返回:

self

返回类型:

模块

eval() Self

将模块设置为评估模式。

这仅对某些模块有影响。有关模块在训练/评估模式下的行为,例如它们是否受影响(如 DropoutBatchNorm 等),请参阅具体模块的文档。

这等同于 self.train(False)

有关 .eval() 和几种可能与之混淆的类似机制之间的比较,请参阅 局部禁用梯度计算

返回:

self

返回类型:

模块

extra_repr() str

返回模块的额外表示。

要打印自定义额外信息,您应该在自己的模块中重新实现此方法。单行和多行字符串均可接受。

float() Self

将所有浮点参数和缓冲区转换为 float 数据类型。

注意

此方法就地修改模块。

返回:

self

返回类型:

模块

forward(tensordict: TensorDictBase = None) TensorDictBase

读取输入 tensordict,并对选定的键应用转换。

默认情况下,此方法

  • 直接调用 _apply_transform()

  • 不调用 _step()_call()

此方法不会在任何时候在 env.step 中调用。但是,它会在 sample() 中调用。

注意

forward 也可以使用 dispatch 将参数名称转换为键,并使用常规关键字参数。

示例

>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0))  # Works offline too.
get_buffer(target: str) Tensor

返回由 target 给定的缓冲区(如果存在),否则抛出错误。

有关此方法功能的更详细解释以及如何正确指定 target,请参阅 get_submodule 的文档字符串。

参数:

target – 要查找的 buffer 的完全限定字符串名称。(要指定完全限定字符串,请参阅 get_submodule。)

返回:

target 引用的缓冲区

返回类型:

torch.Tensor

抛出:

AttributeError – 如果目标字符串引用了无效路径或解析为非 buffer 对象。

get_extra_state() Any

返回要包含在模块 state_dict 中的任何额外状态。

Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict(). (如果需要存储额外状态,请实现此函数和相应的 set_extra_state() 函数。在构建模块的 state_dict() 时会调用此函数。)

注意,为了保证 state_dict 的序列化工作正常,额外状态应该是可被 pickle 的。我们仅为 Tensors 的序列化提供向后兼容性保证;其他对象的序列化形式若发生变化,可能导致向后兼容性中断。

返回:

要存储在模块 state_dict 中的任何额外状态

返回类型:

对象

get_parameter(target: str) Parameter

如果存在,返回由 target 给定的参数,否则抛出错误。

有关此方法功能的更详细解释以及如何正确指定 target,请参阅 get_submodule 的文档字符串。

参数:

target – 要查找的 Parameter 的完全限定字符串名称。(要指定完全限定字符串,请参阅 get_submodule。)

返回:

target 引用的参数

返回类型:

torch.nn.Parameter

抛出:

AttributeError – 如果目标字符串引用了无效路径或解析为非 nn.Parameter 的对象。

get_submodule(target: str) Module

如果存在,返回由 target 给定的子模块,否则抛出错误。

例如,假设您有一个 nn.Module A,它看起来像这样

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(图示了一个 nn.Module AA 包含一个嵌套子模块 net_b,该子模块本身有两个子模块 net_clinearnet_c 随后又有一个子模块 conv。)

要检查是否存在 linear 子模块,可以调用 get_submodule("net_b.linear")。要检查是否存在 conv 子模块,可以调用 get_submodule("net_b.net_c.conv")

get_submodule 的运行时复杂度受 target 中模块嵌套深度的限制。与 named_modules 的查询相比,后者的复杂度是按传递模块数量计算的 O(N)。因此,对于简单地检查某个子模块是否存在,应始终使用 get_submodule

参数:

target – 要查找的子模块的完全限定字符串名称。(要指定完全限定字符串,请参阅上面的示例。)

返回:

target 引用的子模块

返回类型:

torch.nn.Module

抛出:

AttributeError – 如果在目标字符串解析的任何路径中,子路径解析为不存在的属性名或不是 nn.Module 实例的对象。

half() Self

将所有浮点参数和缓冲区转换为 half 数据类型。

注意

此方法就地修改模块。

返回:

self

返回类型:

模块

init(tensordict) None

运行转换的初始化步骤。

inv(tensordict: TensorDictBase = None) TensorDictBase

读取输入 tensordict,并对选定的键应用逆变换。

默认情况下,此方法

  • 直接调用 _inv_apply_transform()

  • 不调用 _inv_call()

注意

inv 也通过使用 dispatch 将参数名称强制转换为键来处理常规关键字参数。

注意

invextend() 调用。

ipu(device: Optional[Union[device, int]] = None) Self

将所有模型参数和缓冲区移动到 IPU。

这也会使关联的参数和缓冲区成为不同的对象。因此,如果模块在优化时将驻留在 IPU 上,则应在构建优化器之前调用它。

注意

此方法就地修改模块。

参数:

device (int, optional) – 如果指定,所有参数将复制到该设备

返回:

self

返回类型:

模块

load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)

Copy parameters and buffers from state_dict into this module and its descendants. (从 state_dict 复制参数和缓冲区到此模块及其子模块中。)

If strict is True, then the keys of state_dict must match the keys returned by this module’s state_dict() function exactly. (如果 strictTrue,则 state_dict 的键必须与此模块的 state_dict() 函数返回的键完全匹配。)

警告

If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True. (如果 assignTrue,则必须在调用 load_state_dict 之后创建优化器,除非 get_swap_module_params_on_conversion()True。)

参数:
  • state_dict (dict) – 包含参数和持久 buffer 的字典。

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True (是否严格强制 state_dict 中的键与此模块的 state_dict() 函数返回的键匹配。默认值:True。)

  • assign (bool, optional) – 当设置为 False 时,将保留当前模块中张量的属性;当设置为 True 时,将保留 state_dict 中张量的属性。唯一的例外是 Parameterrequires_grad 字段,此时将保留模块的值。默认值:False

返回:

  • missing_keys 是一个包含此模块期望但

    在提供的 state_dict 中缺失的任何键的字符串列表。

  • unexpected_keys 是一个字符串列表,包含此模块

    不期望但在提供的 state_dict 中存在的键。

返回类型:

NamedTuple,包含 missing_keysunexpected_keys 字段。

注意

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError. (如果参数或缓冲区注册为 None 并且其对应的键存在于 state_dict 中,load_state_dict() 将引发 RuntimeError。)

modules() Iterator[Module]

返回网络中所有模块的迭代器。

产生:

Module – 网络中的一个模块

注意

重复的模块只返回一次。在以下示例中,l 只返回一次。

示例

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
...     print(idx, '->', m)

0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
mtia(device: Optional[Union[device, int]] = None) Self

将所有模型参数和缓冲区移动到 MTIA。

这也会使关联的参数和缓冲区成为不同的对象。因此,如果模块在优化时将驻留在 MTIA 上,则应在构建优化器之前调用它。

注意

此方法就地修改模块。

参数:

device (int, optional) – 如果指定,所有参数将复制到该设备

返回:

self

返回类型:

模块

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[tuple[str, torch.Tensor]]

返回模块缓冲区上的迭代器,同时生成缓冲区的名称和缓冲区本身。

参数:
  • prefix (str) – 为所有 buffer 名称添加前缀。

  • recurse (bool, optional) – 如果为 True,则会生成此模块及其所有子模块的 buffers。否则,仅生成此模块直接成员的 buffers。默认为 True。

  • remove_duplicate (bool, optional) – 是否在结果中删除重复的 buffers。默认为 True。

产生:

(str, torch.Tensor) – 包含名称和缓冲区的元组

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_children() Iterator[tuple[str, 'Module']]

返回对直接子模块的迭代器,生成模块的名称和模块本身。

产生:

(str, Module) – 包含名称和子模块的元组

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, module in model.named_children():
>>>     if name in ['conv4', 'conv5']:
>>>         print(module)
named_modules(memo: Optional[set['Module']] = None, prefix: str = '', remove_duplicate: bool = True)

返回网络中所有模块的迭代器,同时生成模块的名称和模块本身。

参数:
  • memo – 用于存储已添加到结果中的模块集合的 memo

  • prefix – 将添加到模块名称的名称前缀

  • remove_duplicate – 是否从结果中删除重复的模块实例

产生:

(str, Module) – 名称和模块的元组

注意

重复的模块只返回一次。在以下示例中,l 只返回一次。

示例

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
...     print(idx, '->', m)

0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[tuple[str, torch.nn.parameter.Parameter]]

返回模块参数的迭代器,同时生成参数的名称和参数本身。

参数:
  • prefix (str) – 为所有参数名称添加前缀。

  • recurse (bool) – 如果为 True,则会生成此模块及其所有子模块的参数。否则,仅生成此模块直接成员的参数。

  • remove_duplicate (bool, optional) – 是否在结果中删除重复的参数。默认为 True。

产生:

(str, Parameter) – 包含名称和参数的元组

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
parameters(recurse: bool = True) Iterator[Parameter]

返回模块参数的迭代器。

这通常传递给优化器。

参数:

recurse (bool) – 如果为 True,则会生成此模块及其所有子模块的参数。否则,仅生成此模块直接成员的参数。

产生:

Parameter – 模块参数

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
property parent: TransformedEnv | None

返回变换的父环境。

父环境是包含直到当前变换的所有变换的环境。

示例

>>> from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), Compose(RewardSum(), StepCounter()))
>>> env.transform[1].parent
TransformedEnv(
    env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=cpu),
    transform=Compose(
            RewardSum(keys=['reward'])))
register_backward_hook(hook: Callable[[Module, Union[tuple[torch.Tensor, ...], Tensor], Union[tuple[torch.Tensor, ...], Tensor]]) RemovableHandle

在模块上注册一个反向传播钩子。

此函数已弃用,建议使用 register_full_backward_hook(),并且此函数在未来版本中的行为将发生变化。

返回:

一个句柄,可用于通过调用 handle.remove() 来移除添加的钩子

返回类型:

torch.utils.hooks.RemovableHandle

register_buffer(name: str, tensor: Optional[Tensor], persistent: bool = True) None

向模块添加一个缓冲区。

This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm’s running_mean is not a parameter, but is part of the module’s state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting persistent to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module’s state_dict. (这通常用于注册不应被视为模型参数的缓冲区。例如,BatchNorm 的 running_mean 不是参数,但它是模块状态的一部分。默认情况下,缓冲区是持久的,并且将与参数一起保存。通过将 persistent 设置为 False 可以更改此行为。持久缓冲区和非持久缓冲区之间的唯一区别是后者将不包含在此模块的 state_dict 中。)

可以使用给定名称作为属性访问缓冲区。

参数:
  • name (str) – buffer 的名称。可以使用给定的名称从此模块访问 buffer

  • tensor (Tensor or None) – buffer to be registered. If None, then operations that run on buffers, such as cuda, are ignored. If None, the buffer is not included in the module’s state_dict. (要注册的缓冲区。如果为 None,则在缓冲区上运行的操作,例如 cuda,将被忽略。如果为 None,则该缓冲区包含在模块的 state_dict 中。)

  • persistent (bool) – whether the buffer is part of this module’s state_dict. (缓冲区是否是此模块 state_dict 的一部分。)

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> self.register_buffer('running_mean', torch.zeros(num_features))
register_forward_hook(hook: Union[Callable[[T, tuple[Any, ...]], Any], Callable[[T, tuple[Any, ...], dict[str, Any]], Any]], *, prepend: bool = False, with_kwargs: bool = False, always_call: bool = False) RemovableHandle

在模块上注册一个前向钩子。

The hook will be called every time after forward() has computed an output. (每次在 forward() 计算完输出后都会调用此钩子。)

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to the forward. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after forward() is called. The hook should have the following signature (如果 with_kwargsFalse 或未指定,则输入仅包含传递给模块的位置参数。关键字参数不会传递给钩子,只会传递给 forward。钩子可以修改输出。它可以原地修改输入,但这不会影响 forward,因为此操作是在调用 forward() 之后调用的。钩子应具有以下签名)

hook(module, args, output) -> None or modified output

如果 with_kwargsTrue,则前向钩子将接收传递给 forward 函数的 kwargs,并需要返回可能已修改的输出。钩子应该具有以下签名

hook(module, args, kwargs, output) -> None or modified output
参数:
  • hook (Callable) – 用户定义的待注册钩子。

  • prepend (bool) – If True, the provided hook will be fired before all existing forward hooks on this torch.nn.Module. Otherwise, the provided hook will be fired after all existing forward hooks on this torch.nn.Module. Note that global forward hooks registered with register_module_forward_hook() will fire before all hooks registered by this method. Default: False (如果为 True,则提供的 hook 将在对此 torch.nn.Module 的所有现有 forward 钩子之前触发。否则,提供的 hook 将在此 torch.nn.Module 的所有现有 forward 钩子之后触发。请注意,使用 register_module_forward_hook() 注册的全局 forward 钩子将在通过此方法注册的所有钩子之前触发。默认值:False。)

  • with_kwargs (bool) – 如果为 True,则 hook 将接收传递给 forward 函数的 kwargs。默认为 False

  • always_call (bool) – 如果为 True,则无论在调用 Module 时是否引发异常,都会运行 hook。默认为 False

返回:

一个句柄,可用于通过调用 handle.remove() 来移除添加的钩子

返回类型:

torch.utils.hooks.RemovableHandle

register_forward_pre_hook(hook: Union[Callable[[T, tuple[Any, ...]], Optional[Any]], Callable[[T, tuple[Any, ...], dict[str, Any]], Optional[tuple[Any, dict[str, Any]]]]], *, prepend: bool = False, with_kwargs: bool = False) RemovableHandle

在模块上注册一个前向预钩子。

The hook will be called every time before forward() is invoked. (每次在调用 forward() 之前都会调用此钩子。)

如果 with_kwargs 为 false 或未指定,则输入仅包含传递给模块的位置参数。关键字参数不会传递给钩子,而只会传递给 forward。钩子可以修改输入。用户可以返回一个元组或单个修改后的值。我们将把值包装成一个元组,如果返回的是单个值(除非该值本身就是元组)。钩子应该具有以下签名

hook(module, args) -> None or modified input

如果 with_kwargs 为 true,则前向预钩子将接收传递给 forward 函数的 kwargs。如果钩子修改了输入,则应该返回 args 和 kwargs。钩子应该具有以下签名

hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
参数:
  • hook (Callable) – 用户定义的待注册钩子。

  • prepend (bool) – If true, the provided hook will be fired before all existing forward_pre hooks on this torch.nn.Module. Otherwise, the provided hook will be fired after all existing forward_pre hooks on this torch.nn.Module. Note that global forward_pre hooks registered with register_module_forward_pre_hook() will fire before all hooks registered by this method. Default: False (如果为 True,则提供的 hook 将在对此 torch.nn.Module 的所有现有 forward_pre 钩子之前触发。否则,提供的 hook 将在此 torch.nn.Module 的所有现有 forward_pre 钩子之后触发。请注意,使用 register_module_forward_pre_hook() 注册的全局 forward_pre 钩子将在通过此方法注册的所有钩子之前触发。默认值:False。)

  • with_kwargs (bool) – 如果为 True,则 hook 将接收传递给 forward 函数的 kwargs。默认为 False

返回:

一个句柄,可用于通过调用 handle.remove() 来移除添加的钩子

返回类型:

torch.utils.hooks.RemovableHandle

register_full_backward_hook(hook: Callable[[Module, Union[tuple[torch.Tensor, ...], Tensor], Union[tuple[torch.Tensor, ...], Tensor]], prepend: bool = False) RemovableHandle

在模块上注册一个反向传播钩子。

每次计算相对于模块的梯度时,将调用此钩子,其触发规则如下:

  1. 通常,钩子在计算相对于模块输入的梯度时触发。

  2. 如果模块输入都不需要梯度,则在计算相对于模块输出的梯度时触发钩子。

  3. 如果模块输出都不需要梯度,则钩子将不触发。

钩子应具有以下签名

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

The grad_input and grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations. grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in grad_input and grad_output will be None for all non-Tensor arguments. ( grad_inputgrad_output 是包含相对于输入和输出的梯度的元组。钩子不应修改其参数,但它可以选择返回一个相对于输入的新的梯度,该梯度将在后续计算中替代 grad_inputgrad_input 将仅对应于作为位置参数给出的输入,并且所有关键字参数都将被忽略。对于所有非 Tensor 参数,grad_inputgrad_output 中的条目将为 None。)

由于技术原因,当此钩子应用于模块时,其前向函数将接收传递给模块的每个张量的视图。类似地,调用者将接收模块前向函数返回的每个张量的视图。

警告

使用反向传播钩子时不允许就地修改输入或输出,否则将引发错误。

参数:
  • hook (Callable) – 要注册的用户定义钩子。

  • prepend (bool) – If true, the provided hook will be fired before all existing backward hooks on this torch.nn.Module. Otherwise, the provided hook will be fired after all existing backward hooks on this torch.nn.Module. Note that global backward hooks registered with register_module_full_backward_hook() will fire before all hooks registered by this method. (如果为 True,则提供的 hook 将在对此 torch.nn.Module 的所有现有 backward 钩子之前触发。否则,提供的 hook 将在此 torch.nn.Module 的所有现有 backward 钩子之后触发。请注意,使用 register_module_full_backward_hook() 注册的全局 backward 钩子将在通过此方法注册的所有钩子之前触发。)

返回:

一个句柄,可用于通过调用 handle.remove() 来移除添加的钩子

返回类型:

torch.utils.hooks.RemovableHandle

register_full_backward_pre_hook(hook: Callable[[Module, Union[tuple[torch.Tensor, ...], Tensor]], Union[None, tuple[torch.Tensor, ...], Tensor]], prepend: bool = False) RemovableHandle

在模块上注册一个反向预钩子。

每次计算模块的梯度时,将调用此钩子。钩子应具有以下签名

hook(module, grad_output) -> tuple[Tensor] or None

grad_output 是一个元组。钩子不应修改其参数,但可以选择返回一个新的输出梯度,该梯度将取代 grad_output 用于后续计算。对于所有非 Tensor 参数,grad_output 中的条目将为 None

由于技术原因,当此钩子应用于模块时,其前向函数将接收传递给模块的每个张量的视图。类似地,调用者将接收模块前向函数返回的每个张量的视图。

警告

使用反向传播钩子时不允许就地修改输入,否则将引发错误。

参数:
  • hook (Callable) – 要注册的用户定义钩子。

  • prepend (bool) – If true, the provided hook will be fired before all existing backward_pre hooks on this torch.nn.Module. Otherwise, the provided hook will be fired after all existing backward_pre hooks on this torch.nn.Module. Note that global backward_pre hooks registered with register_module_full_backward_pre_hook() will fire before all hooks registered by this method. (如果为 True,则提供的 hook 将在对此 torch.nn.Module 的所有现有 backward_pre 钩子之前触发。否则,提供的 hook 将在此 torch.nn.Module 的所有现有 backward_pre 钩子之后触发。请注意,使用 register_module_full_backward_pre_hook() 注册的全局 backward_pre 钩子将在通过此方法注册的所有钩子之前触发。)

返回:

一个句柄,可用于通过调用 handle.remove() 来移除添加的钩子

返回类型:

torch.utils.hooks.RemovableHandle

register_load_state_dict_post_hook(hook)

注册一个后钩子,用于在模块的 load_state_dict() 被调用后运行。

它应该具有以下签名:

hook(module, incompatible_keys) -> None

The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys. missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. ( module 参数是当前注册此钩子的模块,而 incompatible_keys 参数是包含 missing_keysunexpected_keys 属性的 NamedTuplemissing_keys 是一个包含缺失键的 str 列表,而 unexpected_keys 是一个包含意外键的 str 列表。)

如果需要,可以就地修改给定的 incompatible_keys。

Note that the checks performed when calling load_state_dict() with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys, as expected. Additions to either set of keys will result in an error being thrown when strict=True, and clearing out both missing and unexpected keys will avoid an error. (请注意,当以 strict=True 调用 load_state_dict() 时执行的检查会受到钩子对 missing_keysunexpected_keys 的修改的影响,正如预期的那样。向任一键集添加内容将导致在 strict=True 时抛出错误,而清空缺失和意外键将避免错误。)

返回:

一个句柄,可用于通过调用 handle.remove() 来移除添加的钩子

返回类型:

torch.utils.hooks.RemovableHandle

register_load_state_dict_pre_hook(hook)

注册一个预钩子,用于在模块的 load_state_dict() 被调用之前运行。

它应该具有以下签名:

hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950

参数:

hook (Callable) – 在加载状态字典之前将调用的可调用钩子。

register_module(name: str, module: Optional[Module]) None

Alias for add_module(). ( add_module() 的别名。)

register_parameter(name: str, param: Optional[Parameter]) None

向模块添加一个参数。

可以使用给定名称作为属性访问该参数。

参数:
  • name (str) – 参数的名称。可以通过给定名称从该模块访问该参数。

  • param (Parameter or None) – parameter to be added to the module. If None, then operations that run on parameters, such as cuda, are ignored. If None, the parameter is not included in the module’s state_dict. (要添加到模块的参数。如果为 None,则在参数上运行的操作,例如 cuda,将被忽略。如果为 None,则该参数包含在模块的 state_dict 中。)

register_state_dict_post_hook(hook)

注册 state_dict() 方法的后置钩子。

它应该具有以下签名:

hook(module, state_dict, prefix, local_metadata) -> None

注册的钩子可以就地修改 state_dict

register_state_dict_pre_hook(hook)

注册 state_dict() 方法的前置钩子。

它应该具有以下签名:

hook(module, prefix, keep_vars) -> None

注册的钩子可用于在进行 state_dict 调用之前执行预处理。

requires_grad_(requires_grad: bool = True) Self

更改自动梯度是否应记录此模块中参数的操作。

此方法就地设置参数的 requires_grad 属性。

此方法有助于冻结模块的一部分以进行微调或单独训练模型的一部分(例如,GAN 训练)。

请参阅 本地禁用梯度计算 以比较 .requires_grad_() 和几种可能与之混淆的类似机制。

参数:

requires_grad (bool) – 自动求导是否应记录此模块上的参数操作。默认为 True

返回:

self

返回类型:

模块

set_extra_state(state: Any) None

设置加载的 state_dict 中包含的额外状态。

This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict. (此函数由 load_state_dict() 调用,用于处理 state_dict 中的任何额外状态。如果需要在此模块的 state_dict 中存储额外状态,请实现此函数和相应的 get_extra_state() 函数。)

参数:

state (dict) – 来自 state_dict 的额外状态

set_submodule(target: str, module: Module, strict: bool = False) None

如果存在,设置由 target 给定的子模块,否则抛出错误。

注意

如果 strict 设置为 False(默认),该方法将替换现有子模块或在父模块存在的情况下创建新子模块。如果 strict 设置为 True,该方法将仅尝试替换现有子模块,并在子模块不存在时引发错误。

例如,假设您有一个 nn.Module A,它看起来像这样

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(3, 3, 3)
        )
        (linear): Linear(3, 3)
    )
)

(图示了一个 nn.Module AA 包含一个嵌套子模块 net_b,该子模块本身有两个子模块 net_clinearnet_c 随后又有一个子模块 conv。)

要用一个新的 Linear 子模块覆盖 Conv2d,可以调用 set_submodule("net_b.net_c.conv", nn.Linear(1, 1)),其中 strict 可以是 TrueFalse

要将一个新的 Conv2d 子模块添加到现有的 net_b 模块中,可以调用 set_submodule("net_b.conv", nn.Conv2d(1, 1, 1))

在上面,如果设置 strict=True 并调用 set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True),则会引发 AttributeError,因为 net_b 中不存在名为 conv 的子模块。

参数:
  • target – 要查找的子模块的完全限定字符串名称。(要指定完全限定字符串,请参阅上面的示例。)

  • module – 要设置子模块的对象。

  • strict – 如果为 False,该方法将替换现有子模块或创建新子模块(如果父模块存在)。如果为 True,则该方法只会尝试替换现有子模块,如果子模块不存在则抛出错误。

抛出:
  • ValueError – 如果 target 字符串为空或 module 不是 nn.Module 的实例。

  • AttributeError – 如果 target 字符串路径中的任何一点解析为一个不存在的属性名或不是 nn.Module 实例的对象。

share_memory() Self

请参阅 torch.Tensor.share_memory_()

state_dict(*args, destination=None, prefix='', keep_vars=False)

返回一个字典,其中包含对模块整个状态的引用。

参数和持久缓冲区(例如,运行平均值)都包含在内。键是相应的参数和缓冲区名称。设置为 None 的参数和缓冲区不包含在内。

注意

返回的对象是浅拷贝。它包含对模块参数和缓冲区的引用。

警告

当前 state_dict() 还接受 destinationprefixkeep_vars 的位置参数,顺序为。但是,这正在被弃用,并且在未来的版本中将强制使用关键字参数。

警告

请避免使用参数 destination,因为它不是为最终用户设计的。

参数:
  • destination (dict, optional) – 如果提供,模块的状态将更新到 dict 中,并返回相同的对象。否则,将创建一个 OrderedDict 并返回。默认为 None

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''

  • keep_vars (bool, optional) – 默认情况下,state dict 中返回的 Tensors 会从 autograd 中分离。如果设置为 True,则不会执行分离。默认为 False

返回:

包含模块整体状态的字典

返回类型:

dict

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
to(*args, **kwargs) Transform

移动和/或转换参数和缓冲区。

这可以这样调用

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)
to(memory_format=torch.channels_last)

Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. (其签名与 torch.Tensor.to() 类似,但只接受浮点或复数 dtype。此外,此方法只会将浮点或复数参数和缓冲区转换为(如果给定)dtype。如果给定了 device,整数参数和缓冲区将被移动到 device,但 dtype 不变。当设置 non_blocking 时,它会尝试尽可能异步地(相对于主机)转换/移动,例如,将具有固定内存的 CPU Tensor 移动到 CUDA 设备。)

有关示例,请参阅下文。

注意

此方法就地修改模块。

参数:
  • device (torch.device) – the desired device of the parameters and buffers in this module – 此模块中参数和缓冲区的目标设备。

  • dtype (torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module – 此模块中参数和缓冲区的目标浮点数或复数 dtype。

  • tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module – 其 dtype 和 device 是此模块中所有参数和缓冲区的目标 dtype 和 device 的 Tensor。

  • memory_format (torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument) – 此模块中 4D 参数和缓冲区的目标内存格式(仅关键字参数)。

返回:

self

返回类型:

模块

示例

>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
to_empty(*, device: Optional[Union[int, str, device]], recurse: bool = True) Self

将参数和缓冲区移动到指定设备,而不复制存储。

参数:
  • device (torch.device) – The desired device of the parameters and buffers in this module. – 此模块中参数和缓冲区的目标设备。

  • recurse (bool) – 是否递归地将子模块的参数和缓冲区移动到指定设备。

返回:

self

返回类型:

模块

train(mode: bool = True) Self

将模块设置为训练模式。

This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, i.e., whether they are affected, e.g. Dropout, BatchNorm, etc. – 这只对某些模块有影响。有关其在训练/评估模式下的行为的详细信息,例如它们是否受影响,请参阅特定模块的文档,例如 DropoutBatchNorm 等。

参数:

mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True. – 设置训练模式(True)或评估模式(False)。默认值:True

返回:

self

返回类型:

模块

transform_action_spec(action_spec: TensorSpec) TensorSpec

转换动作规范,使结果规范与变换映射匹配。

参数:

action_spec (TensorSpec) – 变换前的规范

返回:

转换后的预期规范

transform_done_spec(done_spec: TensorSpec) TensorSpec

变换 done spec,使结果 spec 与变换映射匹配。

参数:

done_spec (TensorSpec) – 变换前的 spec

返回:

转换后的预期规范

transform_env_batch_size(batch_size: Size) Size

转换父环境的 batch-size。

transform_env_device(device: device) device

转换父环境的 device。

transform_input_spec(input_spec: TensorSpec) TensorSpec

转换输入规范,使结果规范与转换映射匹配。

参数:

input_spec (TensorSpec) – 转换前的规范

返回:

转换后的预期规范

transform_observation_spec(observation_spec: TensorSpec) TensorSpec

转换观察规范,使结果规范与转换映射匹配。

参数:

observation_spec (TensorSpec) – 转换前的规范

返回:

转换后的预期规范

transform_output_spec(output_spec: Composite) Composite

转换输出规范,使结果规范与转换映射匹配。

This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_full_done_spec(). :param output_spec: spec before the transform :type output_spec: TensorSpec (此方法通常应保持不变。更改应使用 transform_observation_spec()transform_reward_spec()transform_full_done_spec() 实现。 :param output_spec: 转换前的 spec :type output_spec: TensorSpec)

返回:

转换后的预期规范

transform_reward_spec(reward_spec: TensorSpec) TensorSpec

转换奖励的 spec,使其与变换映射匹配。

参数:

reward_spec (TensorSpec) – 变换前的 spec

返回:

转换后的预期规范

transform_state_spec(state_spec: TensorSpec) TensorSpec

转换状态规范,使结果规范与变换映射匹配。

参数:

state_spec (TensorSpec) – 变换前的规范

返回:

转换后的预期规范

type(dst_type: Union[dtype, str]) Self

将所有参数和缓冲区转换为 dst_type

注意

此方法就地修改模块。

参数:

dst_type (type or string) – 目标类型

返回:

self

返回类型:

模块

xpu(device: Optional[Union[device, int]] = None) Self

将所有模型参数和缓冲区移动到 XPU。

这也会使关联的参数和缓冲区成为不同的对象。因此,如果模块在优化时将驻留在 XPU 上,则应在构建优化器之前调用它。

注意

此方法就地修改模块。

参数:

device (int, optional) – 如果指定,所有参数将复制到该设备

返回:

self

返回类型:

模块

zero_grad(set_to_none: bool = True) None

重置所有模型参数的梯度。

See similar function under torch.optim.Optimizer for more context. – 有关更多背景信息,请参阅 torch.optim.Optimizer 下的类似函数。

参数:

set_to_none (bool) – instead of setting to zero, set the grads to None. See torch.optim.Optimizer.zero_grad() for details. – 与其设置为零,不如将 grad 设置为 None。有关详细信息,请参阅 torch.optim.Optimizer.zero_grad()

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源