评价此页

CausalVariant#

class torch.nn.attention.bias.CausalVariant(value)[source]#

用于注意力机制的因果变体枚举。

定义两种因果偏差类型

UPPER_LEFT: 代表标准因果注意力的左上三角偏差。用于构建此偏差的等效 pytorch 代码为

torch.tril(torch.ones(size, dtype=torch.bool))

例如,当 shape=(3,4) 时,物化偏差张量将为

[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0]]

LOWER_RIGHT: 代表右下三角偏差,包含值对齐到矩阵的右下角。

构造此掩码的等效 PyTorch 代码为:

diagonal_offset = size[1] - size[0]
torch.tril(
    torch.ones(size, dtype=torch.bool),
    diagonal=diagonal_offset,
)

例如,当 shape=(3,4) 时,物化偏差张量将为

[[1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]]

请注意,当查询和键/值张量的序列长度相等时,这些变体是等效的,因为三角矩阵是正方形的。

警告

此枚举是一个原型,可能会发生更改。