Stacked
- class torchrl.data.Stacked(*specs: tuple[T, ...], dim: int)
A lazy representation of a stack of tensor specs.
Stacks tensor specs together along one dimension. When random samples are drawn, a stack of samples is returned if possible; if not, an error is raised.
Indexing is allowed, but only along the stack dimension.
This class is intended for multi-task and multi-agent settings, where heterogeneous specs may occur (same semantics but different shapes).
- assert_is_in(value: Tensor) → None
Asserts whether a tensor belongs to the box, and raises an exception otherwise.
- Parameters:
value (torch.Tensor) – the value to be checked.
- clear_device_()
Clears the device of the Composite.
- clone() → T
Creates a copy of the TensorSpec.
- contains(item: torch.Tensor | TensorDictBase) → bool
If item could have been generated by the TensorSpec, returns True, otherwise False. See is_in() for more information.
- cpu()
Casts the TensorSpec to the "cpu" device.
- cuda(device=None)
Casts the TensorSpec to the "cuda" device.
- property device: Union[device, str, int]
The device of the spec.
Only Composite specs can have a None device. All leaves must have a non-null device.
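- encode(val: np.ndarray | list | torch.Tensor | TensorDictBase, *, ignore_device: bool = False) → torch.Tensor | TensorDictBase
Encodes a value given the specified spec and returns the corresponding tensor.
This method is to be used in environments that return a value (e.g., a numpy array) that can easily be mapped to the domain required by TorchRL. If the value is already a tensor, the spec does not change its value and returns it as-is.
- Parameters:
val (np.ndarray or torch.Tensor) – the value to be encoded as a tensor.
- Keyword Arguments:
ignore_device (bool, optional) – if True, the spec device is ignored. This is used to group tensor casting within a call to TensorDict(..., device="cuda"), which is faster.
- Returns:
a torch.Tensor that matches the required tensor spec.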
- enumerate(use_mask: bool = False) → torch.Tensor | TensorDictBase
Returns all the samples that can be obtained from the TensorSpec.
The samples are stacked along the first dimension.
This method is only implemented for discrete specs.
- Parameters:
use_mask (bool, optional) – if True and the spec has a mask, masked samples are excluded. Default is False.
- erase_memoize_cache() → None
Clears the cache used by the memoized encode execution.
See also: memoize_encode().
- expand(*shape)
Returns a new Spec with the expanded shape.
- Parameters:
*shape (tuple or iterable of int) – the new shape of the Spec. Must be broadcastable with the current shape: its length must be at least as long as the current shape length, and its last values must be compliant too; i.e., they can only differ from it if the current dimension is a singleton.
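The compatibility rule above can be written as a small dependency-free check. can_expand is a hypothetical helper, not part of TorchRL. It only mirrors the rule as stated: extra leading dimensions are allowed, and aligned trailing dimensions must match unless the current one is a singleton.

```python
def can_expand(current: tuple[int, ...], new: tuple[int, ...]) -> bool:
    """Hypothetical helper mirroring the rule stated for expand()."""
    # The new shape may add leading dimensions but never drop any.
    if len(new) < len(current):
        return False
    # Align the current shape with the trailing part of the new shape.
    trailing = new[len(new) - len(current):]
    # Each aligned pair must match unless the current dim is a singleton.
    return all(c == n or c == 1 for c, n in zip(current, trailing))


print(can_expand((1, 3), (4, 3)))   # True: singleton dim broadcast to 4
print(can_expand((3,), (5, 2, 3)))  # True: new leading dims are allowed
print(can_expand((2, 3), (4, 3)))   # False: 2 is not a singleton
print(can_expand((2, 3), (3,)))     # False: target shape is shorter
```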
- flatten(start_dim: int, end_dim: int) → T
Flattens a TensorSpec. Check torch.flatten() for more information on this method.
- classmethod implements_for_spec(torch_function: Callable) → Callable
Registers a torch function override for TensorSpec.
- index(index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase) → torch.Tensor | TensorDictBase
Indexes the input tensor.
This method is to be used with specs that encode one or more categorical variables (e.g., OneHot or Categorical), such that a tensor can be indexed with a sample without caring about the actual representation of the index.
- Parameters:
index (int, torch.Tensor, slice or list) – index of the tensor
tensor_to_index – the tensor to be indexed
- Returns:
the indexed tensor
- Examples:
>>> from torchrl.data import OneHot
>>> import torch
>>>
>>> one_hot = OneHot(n=100)
>>> categ = one_hot.to_categorical_spec()
>>> idx_one_hot = torch.zeros((100,), dtype=torch.bool)
>>> idx_one_hot[50] = 1
>>> print(one_hot.index(idx_one_hot, torch.arange(100)))
tensor(50)
>>> idx_categ = one_hot.to_categorical(idx_one_hot)
>>> print(categ.index(idx_categ, torch.arange(100)))
tensor(50)
- is_in(value) → bool
If value could have been generated by the TensorSpec, returns True, otherwise False.
More precisely, the is_in method checks that value is within the limits defined by the space attribute (the box), and that the dtype, device, shape and potentially other metadata match those of the spec. If any of these checks fails, is_in returns False.
- Parameters:
value (torch.Tensor) – the value to be checked.
- Returns:
a boolean indicating whether the value belongs to the TensorSpec box.
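To make the combination of box and metadata checks concrete, here is a dependency-free sketch. BoxSpec and StackedBoxSpec are hypothetical names, not TorchRL classes. Plain Python lists stand in for tensors, and list length stands in for shape. The assumption that a lazy stack checks each entry against its own sub-spec is an illustration of why heterogeneous shapes can coexist, not a description of Stacked's internals.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BoxSpec:
    """Hypothetical 1-D bounded leaf spec: `size` values in [low, high]."""
    low: float
    high: float
    size: int
    dtype: type = float

    def is_in(self, value: list) -> bool:
        # Metadata checks: shape (here just the length) and dtype.
        if len(value) != self.size:
            return False
        if not all(type(v) is self.dtype for v in value):
            return False
        # Box check: every entry must lie within [low, high].
        return all(self.low <= v <= self.high for v in value)


@dataclass(frozen=True)
class StackedBoxSpec:
    """Hypothetical lazy stack: one sub-spec per entry of the stack dim."""
    specs: tuple

    def is_in(self, values: list) -> bool:
        # One value per sub-spec, each checked against its own spec, so
        # entries may have different lengths (heterogeneous shapes).
        if len(values) != len(self.specs):
            return False
        return all(spec.is_in(v) for spec, v in zip(self.specs, values))


stack = StackedBoxSpec((BoxSpec(-1.0, 1.0, 3), BoxSpec(-1.0, 1.0, 4)))
print(stack.is_in([[0.0, 0.5, -0.5], [1.0, 0.0, 0.0, -1.0]]))  # True
print(stack.is_in([[0.0, 0.5, -0.5], [2.0, 0.0, 0.0, -1.0]]))  # False: out of box
print(stack.is_in([[0.0, 0.5], [1.0, 0.0, 0.0, -1.0]]))        # False: wrong shape
```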
- make_neg_dim(dim: int)
Converts a specific dimension to -1.
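- memoize_encode(mode: bool = True) → None
Creates a cached sequence of callables for the encode method to speed up its execution.
This should only be used when the input type, shape, etc. are expected to be consistent across calls for a given spec.
- Parameters:
mode (bool, optional) – whether the cache should be used. Defaults to True.
See also: the cache can be erased via erase_memoize_cache().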
- property ndim
Number of dimensions of the spec shape.
Shortcut for len(spec.shape).
- one(shape: Optional[Size] = None) → TensorDictBase
Returns a one-filled tensor in the box.
Note: even though there is no guarantee that 1 belongs to the spec domain, this method will not raise an exception when this condition is violated. The primary use case of one is to generate empty data buffers, not meaningful data.
- Parameters:
shape (torch.Size) – the shape of the one-tensor
- Returns:
a one-filled tensor sampled in the TensorSpec box.
- ones(shape: torch.Size = None) → torch.Tensor | TensorDictBase
Proxy to one().
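- project(val: torch.Tensor | TensorDictBase) → torch.Tensor | TensorDictBase
If the input tensor is not in the TensorSpec box, maps it back into the box using a defined heuristic.
- Parameters:
val (torch.Tensor) – the tensor to be mapped to the box.
- Returns:
a torch.Tensor belonging to the TensorSpec box.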
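- rand(shape: Optional[Size] = None) → TensorDictBase
Returns a random tensor in the space defined by the spec.
Sampling is done uniformly over the space, unless the box is unbounded, in which case normal values are drawn.
- Parameters:
shape (torch.Size) – the shape of the random tensor
- Returns:
a random tensor sampled in the TensorSpec box.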
- sample(shape: torch.Size = None) → torch.Tensor | TensorDictBase
Returns a random tensor in the space defined by the spec.
See rand() for details.
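- squeeze(dim: int | None = None)
Returns a new Spec with all the dimensions of size 1 removed. When dim is given, the squeeze operation is only applied to that dimension.
- Parameters:
dim (int or None) – the dimension to apply the squeeze operation to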
- to(dest: torch.dtype | DEVICE_TYPING) → T
Casts a TensorSpec to a device or a dtype.
Returns the same spec if no change is made.
- to_numpy(val: torch.Tensor, safe: bool | None = None) → dict
Returns the np.ndarray counterpart of an input tensor. This is intended to be the inverse operation of encode().
- Parameters:
val (torch.Tensor) – the tensor to be converted to numpy.
safe (bool) – whether the value should be checked against the domain of the spec. Defaults to the value of the CHECK_SPEC_ENCODE environment variable.
- Returns:
a np.ndarray.
- type_check(value: torch.Tensor, key: NestedKey | None = None) → None
Checks the dtype of the input value against the TensorSpec dtype and raises an exception if they don't match.
- Parameters:
value (torch.Tensor) – the tensor whose dtype has to be checked.
key (str, optional) – if the TensorSpec has keys, the value dtype is checked against the spec pointed to by the indicated key.
- unflatten(dim: int, sizes: tuple[int]) → T
Unflattens a TensorSpec. Check torch.unflatten() for more information on this method.
- unsqueeze(dim: int)
Returns a new Spec with one more singleton dimension, at the position indicated by dim.
- Parameters:
dim (int) – the dimension to apply the unsqueeze operation to.
- zero(shape: Optional[Size] = None) → TensorDictBase
Returns a zero-filled tensor in the box.
Note: even though there is no guarantee that 0 belongs to the spec domain, this method will not raise an exception when this condition is violated. The primary use case of zero is to generate empty data buffers, not meaningful data.
- Parameters:
shape (torch.Size) – the shape of the zero-tensor
- Returns:
a zero-filled tensor sampled in the TensorSpec box.
- zeros(shape: torch.Size = None) → torch.Tensor | TensorDictBase
Proxy to zero().