快捷方式

数据类型

TorchRec 包含用于表示嵌入(也称为稀疏特征)的数据类型。稀疏特征通常是用于输入嵌入表的索引。对于给定的批次,嵌入查找索引的数量是可变的。因此,需要一个**锯齿状**维度来表示批次中可变数量的嵌入查找索引。

本节介绍 TorchRec 用于表示稀疏特征的 3 种数据类型的类:**JaggedTensor**、**KeyedJaggedTensor** 和 **KeyedTensor**。

class torchrec.sparse.jagged_tensor.JaggedTensor(*args, **kwargs)

表示一个(可选加权的)锯齿状张量。

一个 JaggedTensor 是一个具有*锯齿状维度*的张量,该维度切片可能长度不同。有关完整示例,请参阅 KeyedJaggedTensor

实现是 torch.jit.script-able 的。

注意

我们不会进行输入验证,因为它很昂贵,您应该始终传入有效的长度、偏移量等。

参数:
  • values (torch.Tensor) – 密集表示的值张量。

  • weights (Optional[torch.Tensor]) – 如果值有权重。形状与 values 相同的张量。

  • lengths (Optional[torch.Tensor]) – 锯齿状切片,表示为长度。

  • offsets (Optional[torch.Tensor]) – 锯齿状切片,表示为累积偏移量。

device() device

获取 JaggedTensor 的设备。

返回:

值张量的设备。

返回类型:

torch.device

static empty(is_weighted: bool = False, device: Optional[device] = None, values_dtype: Optional[dtype] = None, weights_dtype: Optional[dtype] = None, lengths_dtype: dtype = torch.int32) JaggedTensor

构造一个空的 JaggedTensor。

参数:
  • is_weighted (bool) – JaggedTensor 是否带有权重。

  • device (Optional[torch.device]) – JaggedTensor 的设备。

  • values_dtype (Optional[torch.dtype]) – values 的 dtype。

  • weights_dtype (Optional[torch.dtype]) – weights 的 dtype。

  • lengths_dtype (torch.dtype) – lengths 的 dtype。

返回:

空的 JaggedTensor。

返回类型:

JaggedTensor

static from_dense(values: List[Tensor], weights: Optional[List[Tensor]] = None) JaggedTensor

从张量列表作为值(可选权重)构造 JaggedTensor。将计算 lengths,形状为 (B,),其中 B 是 len(values),表示批次大小。

参数:
  • values (List[torch.Tensor]) – 用于密集表示的张量列表

  • weights (Optional[List[torch.Tensor]]) – 如果值有权重,形状与 values 相同的张量。

返回:

从 2D 密集张量创建的 JaggedTensor。

返回类型:

JaggedTensor

示例

values = [
    torch.Tensor([1.0]),
    torch.Tensor(),
    torch.Tensor([7.0, 8.0]),
    torch.Tensor([10.0, 11.0, 12.0]),
]
weights = [
    torch.Tensor([1.0]),
    torch.Tensor(),
    torch.Tensor([7.0, 8.0]),
    torch.Tensor([10.0, 11.0, 12.0]),
]
j1 = JaggedTensor.from_dense(
    values=values,
    weights=weights,
)

# j1 = [[1.0], [], [7.0, 8.0], [10.0, 11.0, 12.0]]
static from_dense_lengths(values: Tensor, lengths: Tensor, weights: Optional[Tensor] = None) JaggedTensor

从 values 和 lengths 张量(可选权重)构造 JaggedTensor。请注意,lengths 的形状仍为 (B,),其中 B 是批次大小。

参数:
  • values (torch.Tensor) – values 的密集表示。

  • lengths (torch.Tensor) – 锯齿状切片,表示为长度。

  • weights (Optional[torch.Tensor]) – 如果值有权重,形状与 values 相同的张量。

返回:

从 2D 密集张量创建的 JaggedTensor。

返回类型:

JaggedTensor

lengths() Tensor

获取 JaggedTensor 的 lengths。如果未计算,则从 offsets 计算。

返回:

lengths 张量。

返回类型:

torch.Tensor

lengths_or_none() Optional[Tensor]

获取 JaggedTensor 的 lengths。如果未计算,则返回 None。

返回:

lengths 张量。

返回类型:

Optional[torch.Tensor]

offsets() Tensor

获取 JaggedTensor 的 offsets。如果未计算,则从 lengths 计算。

返回:

offsets 张量。

返回类型:

torch.Tensor

offsets_or_none() Optional[Tensor]

获取 JaggedTensor 的 offsets。如果未计算,则返回 None。

返回:

offsets 张量。

返回类型:

Optional[torch.Tensor]

record_stream(stream: Stream) None

参见 https://pytorch.ac.cn/docs/stable/generated/torch.Tensor.record_stream.html

to(device: device, non_blocking: bool = False) JaggedTensor

将 JaggedTensor 移动到指定的设备。

参数:
  • device (torch.device) – 要移动到的设备。

  • non_blocking (bool) – 是否异步执行复制。

返回:

移动后的 JaggedTensor。

返回类型:

JaggedTensor

to_dense() List[Tensor]

构造 JT 值的密集表示。

返回:

张量列表。

返回类型:

List[torch.Tensor]

示例

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, offsets=offsets)

values_list = jt.to_dense()

# values_list = [
#     torch.tensor([1.0, 2.0]),
#     torch.tensor([]),
#     torch.tensor([3.0]),
#     torch.tensor([4.0]),
#     torch.tensor([5.0]),
#     torch.tensor([6.0, 7.0, 8.0]),
# ]
to_dense_weights() Optional[List[Tensor]]

构造 JT 权重的密集表示。

返回:

张量列表,如果无权重则为 None

返回类型:

Optional[List[torch.Tensor]]

示例

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
weights = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, weights=weights, offsets=offsets)

weights_list = jt.to_dense_weights()

# weights_list = [
#     torch.tensor([0.1, 0.2]),
#     torch.tensor([]),
#     torch.tensor([0.3]),
#     torch.tensor([0.4]),
#     torch.tensor([0.5]),
#     torch.tensor([0.6, 0.7, 0.8]),
# ]
to_padded_dense(desired_length: Optional[int] = None, padding_value: float = 0.0) Tensor

从 JT 值的形状为 (B, N,) 的密集张量构造形状为 (B, N,) 的 2D 密集张量。

请注意,B 是 self.lengths() 的长度,而 N 是最长的特征长度或 desired_length

如果 desired_length > length,我们将用 padding_value 填充,否则我们将选择 desired_length 处的最后一个值。

参数:
  • desired_length (int) – 张量的长度。

  • padding_value (float) – 如果需要填充,则为填充值。

返回:

2d 密集张量。

返回类型:

torch.Tensor

示例

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, offsets=offsets)

dt = jt.to_padded_dense(
    desired_length=2,
    padding_value=10.0,
)

# dt = [
#     [1.0, 2.0],
#     [10.0, 10.0],
#     [3.0, 10.0],
#     [4.0, 10.0],
#     [5.0, 10.0],
#     [6.0, 7.0],
# ]
to_padded_dense_weights(desired_length: Optional[int] = None, padding_value: float = 0.0) Optional[Tensor]

从 JT 权重的形状为 (B, N,) 的 2D 密集张量构造形状为 (B, N,) 的 2D 密集张量。

请注意,B (批次大小) 是 self.lengths() 的长度,而 N 是最长的特征长度或 desired_length

如果 desired_length > length,我们将用 padding_value 填充,否则我们将选择 desired_length 处的最后一个值。

类似于 to_padded_dense,但用于 JT 的权重而不是值。

参数:
  • desired_length (int) – 张量的长度。

  • padding_value (float) – 如果需要填充,则为填充值。

返回:

2d 密集张量,如果无权重则为 None

返回类型:

Optional[torch.Tensor]

示例

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
weights = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, weights=weights, offsets=offsets)

d_wt = jt.to_padded_dense_weights(
    desired_length=2,
    padding_value=1.0,
)

# d_wt = [
#     [0.1, 0.2],
#     [1.0, 1.0],
#     [0.3, 1.0],
#     [0.4, 1.0],
#     [0.5, 1.0],
#     [0.6, 0.7],
# ]
values() Tensor

获取 JaggedTensor 的 values。

返回:

values 张量。

返回类型:

torch.Tensor

weights() Tensor

获取 JaggedTensor 的 weights。如果为 None,则抛出错误。

返回:

weights 张量。

返回类型:

torch.Tensor

weights_or_none() Optional[Tensor]

获取 JaggedTensor 的 weights。如果为 None,则返回 None。

返回:

weights 张量。

返回类型:

Optional[torch.Tensor]

class torchrec.sparse.jagged_tensor.KeyedJaggedTensor(*args, **kwargs)

表示一个(可选加权的)键控锯齿状张量。

一个 KeyedJaggedTensor 是一个具有*锯齿状维度*的张量,该维度切片可能长度不同。按第一维键控,按最后一维锯齿状。

实现是 torch.jit.script-able 的。

参数:
  • keys (List[str]) – 锯齿状张量的键。

  • values (torch.Tensor) – 密集表示的值张量。

  • weights (Optional[torch.Tensor]) – 如果值有权重。形状与 values 相同的张量。

  • lengths (Optional[torch.Tensor]) – 锯齿状切片,表示为长度。

  • offsets (Optional[torch.Tensor]) – 锯齿状切片,表示为累积偏移量。

  • stride (Optional[int]) – 每个批次的样本数。

  • stride_per_key_per_rank (Optional[Union[torch.IntTensor, List[List[int]]]]) – 每个键每个 rank 的批次大小(样本数),外层列表代表键,内层列表代表值。内层列表中的每个值表示在分布式上下文中来自其索引的 rank 的批次中的样本数。

  • length_per_key (Optional[List[int]]) – 每个键的起始长度。

  • offset_per_key (Optional[List[int]]) – 每个键的起始偏移量和最终偏移量。

  • index_per_key (Optional[Dict[str, int]]) – 每个键的索引。

  • jt_dict (Optional[Dict[str, JaggedTensor]]) – 键到 JaggedTensors 的字典。允许 `to_dict()` 懒加载/可缓存。

  • inverse_indices (Optional[Tuple[List[str], torch.Tensor]]) – 用于展开去重嵌入输出以进行每个键的可变 stride 的逆索引。

示例

#              0       1        2  <-- dim_1
# "Feature0"   [V0,V1] None    [V2]
# "Feature1"   [V3]    [V4]    [V5,V6,V7]
#   ^
#  dim_0

dim_0: keyed dimension (ie. `Feature0`, `Feature1`)
dim_1: optional second dimension (ie. batch size)
dim_2: The jagged dimension which has slice lengths between 0-3 in the above example

# We represent this data with following inputs:

values: torch.Tensor = [V0, V1, V2, V3, V4, V5, V6, V7]  # V == any tensor datatype
weights: torch.Tensor = [W0, W1, W2, W3, W4, W5, W6, W7]  # W == any tensor datatype
lengths: torch.Tensor = [2, 0, 1, 1, 1, 3]  # representing the jagged slice
offsets: torch.Tensor = [0, 2, 2, 3, 4, 5, 8]  # offsets from 0 for each jagged slice
keys: List[str] = ["Feature0", "Feature1"]  # correspond to each value of dim_0
index_per_key: Dict[str, int] = {"Feature0": 0, "Feature1": 1}  # index for each key
offset_per_key: List[int] = [0, 3, 8]  # start offset for each key and final offset
static concat(kjt_list: List[KeyedJaggedTensor]) KeyedJaggedTensor

将 KeyedJaggedTensors 列表连接成一个 KeyedJaggedTensor。

参数:

kjt_list (List[KeyedJaggedTensor]) – 要连接的 KeyedJaggedTensors 列表。

返回:

连接后的 KeyedJaggedTensor。

返回类型:

KeyedJaggedTensor

device() device

返回 KeyedJaggedTensor 的设备。

返回:

KeyedJaggedTensor 的设备。

返回类型:

torch.device

static empty(is_weighted: bool = False, device: Optional[device] = None, values_dtype: Optional[dtype] = None, weights_dtype: Optional[dtype] = None, lengths_dtype: dtype = torch.int32) KeyedJaggedTensor

构造一个空的 KeyedJaggedTensor。

参数:
  • is_weighted (bool) – KeyedJaggedTensor 是否加权。

  • device (Optional[torch.device]) – KeyedJaggedTensor 将放置的设备。

  • values_dtype (Optional[torch.dtype]) – values 张量的 dtype。

  • weights_dtype (Optional[torch.dtype]) – weights 张量的 dtype。

  • lengths_dtype (torch.dtype) – lengths 张量的 dtype。

返回:

空的 KeyedJaggedTensor。

返回类型:

KeyedJaggedTensor

static empty_like(kjt: KeyedJaggedTensor) KeyedJaggedTensor

构造一个与输入 KeyedJaggedTensor 具有相同设备和 dtype 的空 KeyedJaggedTensor。

参数:

kjt (KeyedJaggedTensor) – 输入的 KeyedJaggedTensor。

返回:

空的 KeyedJaggedTensor。

返回类型:

KeyedJaggedTensor

static from_jt_dict(jt_dict: Dict[str, JaggedTensor]) KeyedJaggedTensor

从 JaggedTensors 字典构造 KeyedJaggedTensor。会自动对新创建的 KJT 调用 kjt.sync()

注意

此函数**仅**在 JaggedTensors 具有相同的“隐式”批次大小维度时才有效。

基本上,我们可以将 JaggedTensors 可视化为 [batch_size x variable_feature_dim] 格式的 2D 张量。在某些批次没有特征值的情况下,输入的 JaggedTensor 甚至可以不包含任何值。

但是 KeyedJaggedTensor(默认情况下)通常会填充“None”,以便 KeyedJaggedTensor 中存储的所有 JaggedTensors 具有相同的批次大小维度。也就是说,在这种情况下,如果输入的 JaggedTensor 未能自动为非批次填充,则此函数将出错/无法正常工作。

考虑以下 KeyedJaggedTensor 的可视化: # 0 1 2 <– dim_1 # “Feature0” [V0,V1] None [V2] # “Feature1” [V3] [V4] [V5,V6,V7] # ^ # dim_0

现在,如果输入的 jt_dict = {

# “Feature0” [V0,V1] [V2] # “Feature1” [V3] [V4] [V5,V6,V7]

} 并且每个 JaggedTensor 都省略了“None”,那么此函数将失败,因为我们无法正确地填充“None”,因为它实际上不知道批次的大小/位置来填充 JaggedTensor。

本质上,此函数推断出的 lengths Tensor 为 [2, 1, 1, 1, 3],这表示 batch_size 的 dim_1 是可变的,这违反了 KeyedJaggedTensor 应该具有固定的 batch_size 维度的现有假设/前提条件。

参数:

jt_dict (Dict[str, JaggedTensor]) – JaggedTensors 的字典。

返回:

构造的 KeyedJaggedTensor。

返回类型:

KeyedJaggedTensor

static from_lengths_sync(keys: List[str], values: Tensor, lengths: Tensor, weights: Optional[Tensor] = None, stride: Optional[int] = None, stride_per_key_per_rank: Optional[List[List[int]]] = None, inverse_indices: Optional[Tuple[List[str], Tensor]] = None) KeyedJaggedTensor

从键列表、长度和偏移量构造 KeyedJaggedTensor。与 from_offsets_sync 相同,但使用长度而不是偏移量。

参数:
  • keys (List[str]) – 键的列表。

  • values (torch.Tensor) – 密集表示的值张量。

  • lengths (torch.Tensor) – 锯齿状切片,表示为长度。

  • weights (Optional[torch.Tensor]) – 如果值有权重。形状与 values 相同的张量。

  • stride (Optional[int]) – 每个批次的样本数。

  • stride_per_key_per_rank (Optional[List[List[int]]]) – 每个 rank 的每个键的 batch 大小(样本数量),外层列表代表键,内层列表代表值。

  • inverse_indices (Optional[Tuple[List[str], torch.Tensor]]) – 用于展开去重嵌入输出以进行每个键的可变 stride 的逆索引。

返回:

构造的 KeyedJaggedTensor。

返回类型:

KeyedJaggedTensor

static from_offsets_sync(keys: List[str], values: Tensor, offsets: Tensor, weights: Optional[Tensor] = None, stride: Optional[int] = None, stride_per_key_per_rank: Optional[List[List[int]]] = None, inverse_indices: Optional[Tuple[List[str], Tensor]] = None) KeyedJaggedTensor

从键列表、值和偏移量构造 KeyedJaggedTensor。

参数:
  • keys (List[str]) – 键的列表。

  • values (torch.Tensor) – 密集表示的值张量。

  • offsets (torch.Tensor) – 锯齿状切片,表示为累积偏移量。

  • weights (Optional[torch.Tensor]) – 如果值有权重。形状与 values 相同的张量。

  • stride (Optional[int]) – 每个批次的样本数。

  • stride_per_key_per_rank (Optional[List[List[int]]]) – 每个 rank 的每个键的 batch 大小(样本数量),外层列表代表键,内层列表代表值。

  • inverse_indices (Optional[Tuple[List[str], torch.Tensor]]) – 用于展开去重嵌入输出以进行每个键的可变 stride 的逆索引。

返回:

构造的 KeyedJaggedTensor。

返回类型:

KeyedJaggedTensor

index_per_key() Dict[str, int]

返回 KeyedJaggedTensor 的每个键的索引。

返回:

KeyedJaggedTensor 的每个键的索引。

返回类型:

Dict[str, int]

inverse_indices() Tuple[List[str], Tensor]

返回 KeyedJaggedTensor 的逆索引。如果逆索引为 None,则会引发错误。

返回:

KeyedJaggedTensor 的逆索引。

返回类型:

Tuple[List[str], torch.Tensor]

inverse_indices_or_none() Optional[Tuple[List[str], Tensor]]

返回 KeyedJaggedTensor 的逆索引,如果不存在则返回 None。

返回:

KeyedJaggedTensor 的逆索引。

返回类型:

Optional[Tuple[List[str], torch.Tensor]]

keys() List[str]

返回 KeyedJaggedTensor 的键。

返回:

KeyedJaggedTensor 的键。

返回类型:

List[str]

length_per_key() List[int]

返回 KeyedJaggedTensor 的每个键的长度。如果每个键的长度为 None,则会计算它。

返回:

KeyedJaggedTensor 的每个键的长度。

返回类型:

List[int]

length_per_key_or_none() Optional[List[int]]

返回 KeyedJaggedTensor 的每个键的长度,如果尚未计算则返回 None。

返回:

KeyedJaggedTensor 的每个键的长度。

返回类型:

List[int]

lengths() Tensor

返回 KeyedJaggedTensor 的长度。如果长度尚未计算,则会计算它。

返回:

KeyedJaggedTensor 的长度。

返回类型:

torch.Tensor

lengths_offset_per_key() List[int]

返回 KeyedJaggedTensor 的每个键的长度偏移量。如果每个键的长度偏移量为 None,则会计算它。

返回:

KeyedJaggedTensor 的每个键的长度偏移量。

返回类型:

List[int]

lengths_or_none() Optional[Tensor]

返回 KeyedJaggedTensor 的长度,如果尚未计算则返回 None。

返回:

KeyedJaggedTensor 的长度。

返回类型:

torch.Tensor

offset_per_key() List[int]

返回 KeyedJaggedTensor 的每个键的偏移量。如果每个键的偏移量为 None,则会计算它。

返回:

KeyedJaggedTensor 的每个键的偏移量。

返回类型:

List[int]

offset_per_key_or_none() Optional[List[int]]

返回 KeyedJaggedTensor 的每个键的偏移量,如果尚未计算则返回 None。

返回:

KeyedJaggedTensor 的每个键的偏移量。

返回类型:

List[int]

offsets() Tensor

返回 KeyedJaggedTensor 的偏移量。如果偏移量尚未计算,则会计算它。

返回:

KeyedJaggedTensor 的偏移量。

返回类型:

torch.Tensor

offsets_or_none() Optional[Tensor]

返回 KeyedJaggedTensor 的偏移量,如果尚未计算则返回 None。

返回:

KeyedJaggedTensor 的偏移量。

返回类型:

torch.Tensor

permute(indices: List[int], indices_tensor: Optional[Tensor] = None) KeyedJaggedTensor

置换 KeyedJaggedTensor。

参数:
  • indices (List[int]) – 索引列表。

  • indices_tensor (Optional[torch.Tensor]) – 索引张量。

返回:

置换后的 KeyedJaggedTensor。

返回类型:

KeyedJaggedTensor

record_stream(stream: Stream) None

参见 https://pytorch.ac.cn/docs/stable/generated/torch.Tensor.record_stream.html

split(segments: List[int]) List[KeyedJaggedTensor]

将 KeyedJaggedTensor 拆分为 KeyedJaggedTensor 列表。

参数:

segments (List[int]) – 分段列表。

返回:

KeyedJaggedTensor 列表。

返回类型:

List[KeyedJaggedTensor]

stride() int

返回 KeyedJaggedTensor 的步幅。如果步幅为 None,则会计算它。

返回:

KeyedJaggedTensor 的步幅。

返回类型:

int

stride_per_key() List[int]

返回 KeyedJaggedTensor 的每个键的步幅。如果每个键的步幅为 None,则会计算它。

返回:

KeyedJaggedTensor 的每个键的步幅。

返回类型:

List[int]

stride_per_key_per_rank() List[List[int]]

返回 KeyedJaggedTensor 的每个 rank 的每个键的步幅。

返回:

KeyedJaggedTensor 的每个 rank 的每个键的步幅。

返回类型:

List[List[int]]

sync() KeyedJaggedTensor

通过计算 offset_per_key 和 length_per_key 来同步 KeyedJaggedTensor。

返回:

同步后的 KeyedJaggedTensor。

返回类型:

KeyedJaggedTensor

to(device: device, non_blocking: bool = False, dtype: Optional[dtype] = None) KeyedJaggedTensor

返回指定设备和数据类型的 KeyedJaggedTensor 的副本。

参数:
  • device (torch.device) – 副本的期望设备。

  • non_blocking (bool) – 是否以非阻塞方式复制张量。

  • dtype (Optional[torch.dtype]) – 副本的期望数据类型。

返回:

复制的 KeyedJaggedTensor。

返回类型:

KeyedJaggedTensor

to_dict() Dict[str, JaggedTensor]

返回每个键的 JaggedTensor 字典。结果将缓存到 self._jt_dict 中。

返回:

每个键的 JaggedTensor 字典。

返回类型:

Dict[str, JaggedTensor]

unsync() KeyedJaggedTensor

通过清除 offset_per_key 和 length_per_key 来取消同步 KeyedJaggedTensor。

返回:

未同步的 KeyedJaggedTensor。

返回类型:

KeyedJaggedTensor

values() Tensor

返回 KeyedJaggedTensor 的值。

返回:

KeyedJaggedTensor 的值。

返回类型:

torch.Tensor

variable_stride_per_key() bool

返回 KeyedJaggedTensor 是否具有每个键的可变步幅。注意:当 self._stride_per_key_per_rank 不为 None 时,self._variable_stride_per_key 可能为 False。它可能被外部/有意设置为 False,通常 self._stride_per_key_per_rank 是微不足道的。

返回:

KeyedJaggedTensor 是否具有每个键的可变步幅。

返回类型:

布尔值

weights() Tensor

返回 KeyedJaggedTensor 的权重。如果权重为 None,则会引发错误。

返回:

KeyedJaggedTensor 的权重。

返回类型:

torch.Tensor

weights_or_none() Optional[Tensor]

返回 KeyedJaggedTensor 的权重,如果不存在则返回 None。

返回:

KeyedJaggedTensor 的权重。

返回类型:

torch.Tensor

class torchrec.sparse.jagged_tensor.KeyedTensor(*args, **kwargs)

KeyedTensor 保存一个连接的密集张量列表,每个张量都可以通过键访问。

键维度可以是可变长度的 (length_per_key)。常见用例包括存储不同维度的池化嵌入。

实现是 torch.jit.script-able 的。

参数:
  • keys (List[str]) – 键的列表。

  • length_per_key (List[int]) – 键维度上每个键的长度。

  • values (torch.Tensor) – 密集张量,通常沿键维度连接。

  • key_dim (int) – 键维度,从零开始索引 - 默认为 1(通常 B 是 0 维度)。

示例

# kt is KeyedTensor holding

#                         0           1           2
#     "Embedding A"    [1,1]       [1,1]        [1,1]
#     "Embedding B"    [2,1,2]     [2,1,2]      [2,1,2]
#     "Embedding C"    [3,1,2,3]   [3,1,2,3]    [3,1,2,3]

tensor_list = [
    torch.tensor([[1,1]] * 3),
    torch.tensor([[2,1,2]] * 3),
    torch.tensor([[3,1,2,3]] * 3),
]

keys = ["Embedding A", "Embedding B", "Embedding C"]

kt = KeyedTensor.from_tensor_list(keys, tensor_list)

kt.values()
# torch.Tensor(
#     [
#         [1, 1, 2, 1, 2, 3, 1, 2, 3],
#         [1, 1, 2, 1, 2, 3, 1, 2, 3],
#         [1, 1, 2, 1, 2, 3, 1, 2, 3],
#     ]
# )

kt["Embedding B"]
# torch.Tensor([[2, 1, 2], [2, 1, 2], [2, 1, 2]])
device() device
返回:

值张量的设备。

返回类型:

torch.device

static from_tensor_list(keys: List[str], tensors: List[Tensor], key_dim: int = 1, cat_dim: int = 1) KeyedTensor

从张量列表创建 KeyedTensor。张量沿 cat_dim 连接。键用于索引张量。

参数:
  • keys (List[str]) – 键的列表。

  • tensors (List[torch.Tensor]) – 张量列表。

  • key_dim (int) – 键维度,从零开始索引 - 默认为 1(通常 B 是 0 维度)。

  • cat_dim (int) – 连接张量的维度 - 默认为

返回:

键控张量。

返回类型:

KeyedTensor

key_dim() int
返回:

键维度,从零开始索引 - 通常 B 是 0 维度。

返回类型:

int

keys() List[str]
返回:

键的列表。

返回类型:

List[str]

length_per_key() List[int]
返回:

键维度上每个键的长度。

返回类型:

List[int]

offset_per_key() List[int]

获取键维度上每个键的偏移量。如果尚未计算,则计算并缓存。

返回:

键维度上每个键的偏移量。

返回类型:

List[int]

record_stream(stream: Stream) None

参见 https://pytorch.ac.cn/docs/stable/generated/torch.Tensor.record_stream.html

static regroup(keyed_tensors: List[KeyedTensor], groups: List[List[str]]) List[Tensor]

将 KeyedTensors 列表重新组合为张量列表。

参数:
  • keyed_tensors (List[KeyedTensor]) – KeyedTensors 列表。

  • groups (List[List[str]]) – 键的组列表。

返回:

张量列表。

返回类型:

List[torch.Tensor]

static regroup_as_dict(keyed_tensors: List[KeyedTensor], groups: List[List[str]], keys: List[str]) Dict[str, Tensor]

将 KeyedTensors 列表重新组合为张量字典。

参数:
  • keyed_tensors (List[KeyedTensor]) – KeyedTensors 列表。

  • groups (List[List[str]]) – 键的组列表。

  • keys (List[str]) – 键的列表。

返回:

张量字典。

返回类型:

Dict[str, torch.Tensor]

to(device: device, non_blocking: bool = False) KeyedTensor

将值张量移动到指定的设备。

参数:
  • device (torch.device) – 将值张量移动到的设备。

  • non_blocking (bool) – 是否异步执行操作(默认:False)。

返回:

值张量已移动到指定设备的键控张量。

返回类型:

KeyedTensor

to_dict() Dict[str, Tensor]
返回:

由键键控的张量字典。

返回类型:

Dict[str, torch.Tensor]

values() Tensor

获取值张量。

返回:

密集张量,通常沿键维度连接。

返回类型:

torch.Tensor

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源