快捷方式

tensordict.nn.add_custom_mapping

tensordict.nn.add_custom_mapping(name: str, mapping: Callable[[Tensor], Tensor])

在映射类中添加自定义映射。

参数:
  • name (str) – 映射的名称。

  • mapping (callable) – 一个可调用对象,它接收一个张量作为输入,并输出一个具有相同形状的张量。

示例

>>> from tensordict.nn import add_custom_mapping, NormalParamExtractor
>>> add_custom_mapping("my_mapping", lambda x: torch.zeros_like(x))
>>> npe = NormalParamExtractor(scale_mapping="my_mapping", scale_lb=0.0)
>>> assert (npe(torch.randn(10))[1] == torch.zeros(5)).all()

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源