pad¶
- class tensordict.pad(tensordict: T, pad_size: Sequence[int], value: float = 0.0)¶
使用常量值填充 tensordict 中的所有张量(沿批处理维度),并返回一个新的 tensordict。
- 参数:
tensordict (TensorDict) – 需要填充的 tensordict
pad_size (Sequence[int]) – 用于填充 tensordict 的某些批处理维度的填充大小,从第一个维度开始向前移动。[pad_size 的长度 / 2] 个批处理大小维度将被填充。例如,要仅填充第一个维度,pad 的形式为(左填充,右填充)。要填充两个维度,则为(上左填充,上右填充,下左填充,下右填充)等等。pad_size 必须是偶数,并且小于或等于批处理维度的两倍。
value (float, optional) – 用于填充的值,默认为 0.0
- 返回:
沿批处理维度填充后的新 TensorDict
示例
>>> from tensordict import TensorDict, pad >>> import torch >>> td = TensorDict({'a': torch.ones(3, 4, 1), ... 'b': torch.ones(3, 4, 1, 1)}, batch_size=[3, 4]) >>> dim0_left, dim0_right, dim1_left, dim1_right = [0, 1, 0, 2] >>> padded_td = pad(td, [dim0_left, dim0_right, dim1_left, dim1_right], value=0.0) >>> print(padded_td.batch_size) torch.Size([4, 6]) >>> print(padded_td.get("a").shape) torch.Size([4, 6, 1]) >>> print(padded_td.get("b").shape) torch.Size([4, 6, 1, 1])