tensordict.nn.add_custom_mapping¶
- tensordict.nn.add_custom_mapping(name: str, mapping: Callable[[Tensor], Tensor])¶
在映射类中添加自定义映射。
- 参数:
name (str) – 映射的名称。
mapping (callable) – 一个可调用对象,它接收一个张量作为输入,并输出一个具有相同形状的张量。
示例
>>> from tensordict.nn import add_custom_mapping, NormalParamExtractor >>> add_custom_mapping("my_mapping", lambda x: torch.zeros_like(x)) >>> npe = NormalParamExtractor(scale_mapping="my_mapping", scale_lb=0.0) >>> assert (npe(torch.randn(10))[1] == torch.zeros(5)).all()