快捷方式

TanhModule

class torchrl.modules.tensordict_module.TanhModule(*args, **kwargs)[source]

一个用于具有有界动作空间的确定性策略的 Tanh 模块。

此转换用作 TensorDictModule 层,将网络输出映射到有界空间。

参数:
  • in_keys (list of str or tuples of str) – 模块的输入键。

  • out_keys (list of str or tuples of str, optional) – 模块的输出键。如果未提供,则假定键与 in_keys 相同。

关键字参数:
  • spec (TensorSpec, optional) – 如果提供,则为输出的 spec。如果提供了 Composite,其键必须与 out_keys 中的键匹配。否则,假定 out_keys 的键,并为所有输出使用相同的 spec。

  • low (float, np.ndarray or torch.Tensor) – 空间的下界。如果未提供且未提供 spec,则假定为 -1。如果提供了 spec,则将检索 spec 的最小值。

  • high (float, np.ndarray or torch.Tensor) – 空间的上界。如果未提供且未提供 spec,则假定为 1。如果提供了 spec,则将检索 spec 的最大值。

  • clamp (bool, optional) – 如果为 True,则输出将被限制在边界内,但与边界之间至少有一个最小分辨率。默认为 False

示例

>>> from tensordict import TensorDict
>>> # simplest use case: -1 - 1 boundaries
>>> torch.manual_seed(0)
>>> in_keys = ["action"]
>>> mod = TanhModule(
...     in_keys=in_keys,
... )
>>> data = TensorDict({"action": torch.randn(5) * 10}, [])
>>> data = mod(data)
>>> data['action']
tensor([ 1.0000, -0.9944, -1.0000,  1.0000, -1.0000])
>>> # low and high can be customized
>>> low = -2
>>> high = 1
>>> mod = TanhModule(
...     in_keys=in_keys,
...     low=low,
...     high=high,
... )
>>> data = TensorDict({"action": torch.randn(5) * 10}, [])
>>> data = mod(data)
>>> data['action']
tensor([-2.0000,  0.9991,  1.0000, -2.0000, -1.9991])
>>> # A spec can be provided
>>> from torchrl.data import Bounded
>>> spec = Bounded(low, high, shape=())
>>> mod = TanhModule(
...     in_keys=in_keys,
...     low=low,
...     high=high,
...     spec=spec,
...     clamp=False,
... )
>>> # One can also work with multiple keys
>>> in_keys = ['a', 'b']
>>> spec = Composite(
...     a=Bounded(-3, 0, shape=()),
...     b=Bounded(0, 3, shape=()))
>>> mod = TanhModule(
...     in_keys=in_keys,
...     spec=spec,
... )
>>> data = TensorDict(
...     {'a': torch.randn(10), 'b': torch.randn(10)}, batch_size=[])
>>> data = mod(data)
>>> data['a']
tensor([-2.3020, -1.2299, -2.5418, -0.2989, -2.6849, -1.3169, -2.2690, -0.9649,
        -2.5686, -2.8602])
>>> data['b']
tensor([2.0315, 2.8455, 2.6027, 2.4746, 1.7843, 2.7782, 0.2111, 0.5115, 1.4687,
        0.5760])
forward(tensordict=None)[source]

定义每次调用时执行的计算。

所有子类都应重写此方法。

注意

虽然前向传播的配方需要在此函数内定义,但应在之后调用 Module 实例而不是此函数,因为前者负责运行已注册的钩子,而后者会默默忽略它们。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源