评价此页

torch.masked#

创建于: 2022年8月15日 | 最后更新于: 2025年6月17日

引言#

动机#

警告

Masked tensors 的 PyTorch API 处于原型阶段,未来可能会更改。

MaskedTensor 作为 torch.Tensor 的扩展,赋予用户能够

  • 使用任何掩码语义(例如,可变长度张量、NaN* 运算符等)

  • 区分 0 和 NaN 梯度

  • 各种稀疏应用(请参阅下面的教程)

“指定”和“未指定”在 PyTorch 中有着悠久的历史,但没有正式的语义,当然也没有一致性;事实上,MaskedTensor 的诞生源于 vanilla torch.Tensor 类无法妥善解决的一系列问题。因此,MaskedTensor 的主要目标是成为 PyTorch 中所述“指定”和“未指定”值的真实来源,其中它们是一等公民,而不是事后才考虑。反过来,这应该进一步释放稀疏性的潜力,实现更安全、更一致的运算符,并为用户和开发人员提供更顺畅、更直观的体验。

什么是 MaskedTensor?#

MaskedTensor 是一个张量子类,它包含 1) 一个输入(数据)和 2) 一个掩码。掩码指示应包含或忽略输入中的哪些条目。

举个例子,假设我们想掩盖所有等于 0 的值(用灰色表示)并取最大值

_images/tensor_comparison.jpg

上面是 vanilla 张量示例,下面是 MaskedTensor,其中所有 0 都被掩盖了。这显然会产生不同的结果,具体取决于我们是否有掩码,但这种灵活的结构允许用户在计算过程中系统地忽略任何他们想要的元素。

我们已经编写了许多现有教程来帮助用户上手,例如

支持的运算符#

一元运算符#

一元运算符是仅包含单个输入的运算符。将它们应用于 MaskedTensor 相对简单:如果给定索引的数据被掩盖,我们将应用运算符,否则将继续掩盖数据。

可用的 unary operators 是

abs

计算 input 中每个元素的绝对值。

absolute

torch.abs() 的别名torch.abs()

acos

计算input中每个元素的反正弦。

arccos

torch.acos() 的别名torch.acos()

acosh

返回一个新张量,其中包含input元素的双曲反余弦。

arccosh

torch.acosh() 的别名torch.acosh()

angle

计算给定input张量的每个元素的角度(以弧度表示)。

asin

返回一个新张量,其中包含input元素的反正弦。

arcsin

torch.asin() 的别名torch.asin()

asinh

返回一个新张量,其中包含input元素的双曲反正弦。

arcsinh

torch.asinh() 的别名torch.asinh()

atan

返回一个新张量,其中包含input元素的反正切。

arctan

torch.atan() 的别名torch.atan()

atanh

返回一个新张量,其中包含input元素的双曲反正切。

arctanh

torch.atanh() 的别名torch.atanh()

bitwise_not

计算给定输入张量的按位 NOT。

ceil

返回一个新张量,其中包含input元素的向上取整值,即大于或等于每个元素的最小整数。

clamp

input 中的所有元素限制在 [ min, max ] 范围内。

clip

torch.clamp() 的别名torch.clamp()

conj_physical

计算给定 input 张量的逐元素共轭。

cos

返回一个新张量,其中包含input元素的余弦。

cosh

返回一个新张量,其中包含input元素的双曲余弦。

deg2rad

返回一个新张量,其中input的每个元素都从角度(度)转换为弧度。

digamma

torch.special.digamma() 的别名torch.special.digamma()

erf

torch.special.erf() 的别名torch.special.erf()

erfc

torch.special.erfc() 的别名torch.special.erfc()

erfinv

torch.special.erfinv() 的别名torch.special.erfinv()

exp

返回一个新张量,其元素是输入张量input的指数。

exp2

torch.special.exp2() 的别名torch.special.exp2()

expm1

torch.special.expm1() 的别名torch.special.expm1()

fix

torch.trunc() 的别名torch.trunc()

floor

返回一个新张量,其中包含 input 元素的向下取整值,即小于或等于每个元素的最大整数。

frac

计算input中每个元素的小数部分。

lgamma

计算 input 上伽马函数绝对值的自然对数。

log

返回一个新张量,其中包含 input 元素对应的自然对数。

log10

返回一个新张量,其中包含input元素的以10为底的对数。

log1p

返回一个新张量,其中包含(1 + input)的自然对数。

log2

返回一个新张量,其中包含input元素的以2为底的对数。

logit

torch.special.logit() 的别名torch.special.logit()

i0

torch.special.i0() 的别名torch.special.i0()

isnan

返回一个新张量,其中包含布尔元素,表示 input 中的每个元素是否为 NaN。

nan_to_num

nanposinfneginf 指定的值分别替换 input 中的 NaN、正无穷和负无穷。

neg

返回一个新张量,其中包含input元素的负值。

negative

torch.neg() 的别名torch.neg()

positive

返回 input

pow

计算 input 中每个元素以 exponent 为指数的幂,并返回结果张量。

rad2deg

返回一个新张量,其中input的每个元素都从角度(弧度)转换为度。

reciprocal

返回一个新张量,其中包含input元素的倒数。

round

input的元素四舍五入到最近的整数。

rsqrt

返回一个新张量,其中包含input每个元素的平方根的倒数。

sigmoid

torch.special.expit() 的别名torch.special.expit()

sign

返回一个新张量,其中包含input元素的符号。

sgn

此函数是 torch.sign() 对复数张量的扩展。

signbit

测试input的每个元素的符号位是否已设置。

sin

返回一个新张量,其中包含input元素的正弦。

sinc

torch.special.sinc() 的别名torch.special.sinc()

sinh

返回一个新张量,其中包含input元素的双曲正弦。

sqrt

返回一个新张量,其中包含 input 元素的平方根。

square

返回一个新张量,其中包含input元素的平方。

tan

返回一个新张量,其中包含input元素的正切。

tanh

返回一个新张量,其元素是 input 的双曲正切值。

trunc

返回一个新张量,其中包含input元素的截断整数值。

可用的 inplace unary operators 是上面所有运算符,**但**

angle

计算给定input张量的每个元素的角度(以弧度表示)。

positive

返回 input

signbit

测试input的每个元素的符号位是否已设置。

isnan

返回一个新张量,其中包含布尔元素,表示 input 中的每个元素是否为 NaN。

二元运算符#

正如您可能在教程中看到的,MaskedTensor 还实现了二元运算符,但有一个条件:两个 MaskedTensor 的掩码必须匹配,否则将引发错误。如错误中所述,如果您需要特定运算符的支持或对如何处理它们有建议的语义,请在 GitHub 上开一个 issue。目前,我们已决定采用最保守的实现方式,以确保用户确切地知道正在发生什么,并对他们使用掩码语义的决定保持谨慎。

可用的 binary operators 是

add

other(缩放 alpha)加到 input

atan2

考虑象限的 inputi/otheri\text{input}_{i} / \text{other}_{i} 的逐元素反正切。

arctan2

torch.atan2() 的别名torch.atan2()

bitwise_and

计算 inputother 的按位 AND。

bitwise_or

计算 inputother 的按位 OR。

bitwise_xor

计算 inputother 的按位 XOR。

bitwise_left_shift

计算 inputother 位左移。

bitwise_right_shift

计算 inputother 位右移。

div

将输入input的每个元素除以other的相应元素。

divide

torch.div() 的别名torch.div()

floor_divide

fmod

逐元素应用 C++ 的 std::fmod

logaddexp

输入指数和的对数。

logaddexp2

以2为底的输入指数和的对数。

mul

input乘以other

multiply

torch.mul() 的别名torch.mul()

nextafter

返回 input 之后,趋向于 other 的下一个浮点值,逐元素进行。

remainder

逐元素计算Python 的模运算

sub

input 中减去 other(缩放 alpha)。

subtract

torch.sub() 的别名torch.sub()

true_divide

torch.div() 带有 rounding_mode=None 的别名。

eq

计算逐元素相等

ne

逐元素计算 inputother\text{input} \neq \text{other}

le

逐元素计算 inputother\text{input} \leq \text{other}

ge

逐元素计算 inputother\text{input} \geq \text{other}

greater

torch.gt() 的别名torch.gt()

greater_equal

torch.ge() 的别名torch.ge()

gt

逐元素计算 input>other\text{input} > \text{other}

less_equal

torch.le() 的别名torch.le()

lt

逐元素计算 input<other\text{input} < \text{other}

less

torch.lt() 的别名torch.lt()

maximum

计算 inputother 的逐元素最大值。

minimum

计算 inputother 的逐元素最小值。

fmax

计算 inputother 的逐元素最大值。

fmin

计算 inputother 的逐元素最小值。

not_equal

torch.ne() 的别名torch.ne()

可用的 inplace binary operators 是上面所有运算符,**但**

logaddexp

输入指数和的对数。

logaddexp2

以2为底的输入指数和的对数。

equal

如果两个张量具有相同的大小和元素,则为True,否则为False

fmin

计算 inputother 的逐元素最小值。

minimum

计算 inputother 的逐元素最小值。

fmax

计算 inputother 的逐元素最大值。

归约#

以下归约可用(支持 autograd)。有关更多信息,概述教程详细介绍了一些归约示例,而高级语义教程对我们如何决定某些归约语义进行了更深入的讨论。

sum

返回 input 张量中所有元素的和。

mean

amin

在给定维度 dim 下,返回 input 张量每个切片的最小值。

amax

在给定维度 dim 下,返回 input 张量每个切片的最大值。

argmin

返回扁平张量或沿某一维度的最小值的索引

argmax

返回 input 张量中所有元素最大值的索引。

prod

返回input张量中所有元素的乘积。

all

测试 input 中的所有元素是否都评估为 True

norm

返回给定张量的矩阵范数或向量范数。

var

在由 dim 指定的维度上计算方差。

std

在由 dim 指定的维度上计算标准差。

视图和选择函数#

我们也包含了一些视图和选择函数;直观地说,这些运算符将同时应用于数据和掩码,然后将结果包装在 MaskedTensor 中。举个快速示例,请考虑 select()

    >>> data = torch.arange(12, dtype=torch.float).reshape(3, 4)
    >>> data
    tensor([[ 0.,  1.,  2.,  3.],
            [ 4.,  5.,  6.,  7.],
            [ 8.,  9., 10., 11.]])
    >>> mask = torch.tensor([[True, False, False, True], [False, True, False, False], [True, True, True, True]])
    >>> mt = masked_tensor(data, mask)
    >>> data.select(0, 1)
    tensor([4., 5., 6., 7.])
    >>> mask.select(0, 1)
    tensor([False,  True, False, False])
    >>> mt.select(0, 1)
    MaskedTensor(
      [      --,   5.0000,       --,       --]
    )

当前支持以下 ops

atleast_1d

返回每个输入张量的一维视图,其中零维度。

broadcast_tensors

根据 广播语义 广播给定的张量。

broadcast_to

input 广播到 shape 的形状。

cat

将给定的张量序列 tensors 在给定维度上连接起来。

chunk

尝试将张量分割成指定的块数。

column_stack

通过水平堆叠tensors中的张量来创建新张量。

dsplit

根据 indices_or_sections,将三维或更多维的张量 input 深度分割成多个张量。

flatten

通过将 input 重塑为一维张量来展平。

hsplit

根据 indices_or_sections,将具有一个或多个维度的张量 input 水平分割成多个张量。

hstack

按水平(列方向)顺序堆叠张量。

kron

计算 inputother 的克罗内克积,表示为 \otimes

meshgrid

创建由 attr:tensors 中一维输入指定的坐标网格。

narrow

返回一个新张量,它是 input 张量的缩小版本。

nn.functional.unfold

从批量输入张量中提取滑动局部块。

ravel

返回一个连续的展平张量。

select

沿选定维度在给定索引处对 input 张量进行切片。

split

将张量分割成块。

stack

沿新维度连接一系列张量。

t

期望input为小于等于2维的张量,并转置维度0和1。

转置

返回一个转置版本的 input 张量。

vsplit

根据 indices_or_sections,将二维或更多维的张量 input 垂直分割成多个张量。

vstack

按垂直(行方向)顺序堆叠张量。

Tensor.expand

返回 self 张量的新视图,其中单例维度已扩展到更大的大小。

Tensor.expand_as

将此张量扩展到与 other 相同的尺寸。

Tensor.reshape

返回一个具有与 self 相同的数据和相同数量的元素,但具有指定形状的张量。

Tensor.reshape_as

将此张量返回为与 other 相同的形状。

Tensor.unfold

返回原始张量的视图,该视图包含 self 张量在 dimension 维度上的所有大小为 size 的切片。

Tensor.view

返回一个新张量,它具有与 self 张量相同的数据,但形状不同。