评价此页

torch#

创建于:2016年12月23日 | 最后更新于:2025年3月10日

torch 包包含多维张量的数据结构,并定义了这些张量上的数学运算。此外,它还提供了许多用于高效序列化张量和任意类型,以及其他实用工具。

它有一个 CUDA 对应物,使您能够在计算能力 >= 3.0 的 NVIDIA GPU 上运行张量计算。

张量#

is_tensor

如果 obj 是 PyTorch 张量,则返回 True。

is_storage

如果 obj 是 PyTorch 存储对象,则返回 True。

is_complex

如果 input 的数据类型是复数数据类型,即 torch.complex64torch.complex128 之一,则返回 True。

is_conj

如果 input 是共轭张量,即其共轭位设置为 True,则返回 True。

is_floating_point

如果 input 的数据类型是浮点数据类型,即 torch.float64torch.float32torch.float16torch.bfloat16 之一,则返回 True。

is_nonzero

如果 input 是一个单元素张量,并且在类型转换后不等于零,则返回 True。

set_default_dtype

将默认浮点 dtype 设置为 d

get_default_dtype

获取当前默认浮点 torch.dtype

set_default_device

设置默认 torch.Tensordevice 上分配。

get_default_device

获取默认 torch.Tensordevice 上分配。

set_default_tensor_type

numel

返回 input 张量中的元素总数。

set_printoptions

设置打印选项。

set_flush_denormal

禁用 CPU 上的非正常浮点数。

创建操作#

注意

随机采样创建操作列在 随机采样 下,包括: torch.rand() torch.rand_like() torch.randn() torch.randn_like() torch.randint() torch.randint_like() torch.randperm() 您还可以使用 torch.empty()原地随机采样 方法来创建值从更广泛的分布中采样的 torch.Tensor

tensor

通过复制 data 构造一个没有自动求导历史的张量(也称为“叶子张量”,参见 自动求导机制)。

sparse_coo_tensor

构造一个 COO(坐标)格式的稀疏张量,在给定 indices 处指定值。

sparse_csr_tensor

构造一个 CSR(压缩稀疏行)格式的稀疏张量,在给定 crow_indicescol_indices 处指定值。

sparse_csc_tensor

构造一个 CSC(压缩稀疏列)格式的稀疏张量,在给定 ccol_indicesrow_indices 处指定值。

sparse_bsr_tensor

构造一个 BSR(块压缩稀疏行)格式的稀疏张量,在给定 crow_indicescol_indices 处指定 2 维块。

sparse_bsc_tensor

构造一个 BSC(块压缩稀疏列)格式的稀疏张量,在给定 ccol_indicesrow_indices 处指定 2 维块。

asarray

obj 转换为张量。

as_tensor

data 转换为张量,如果可能,共享数据并保留自动求导历史记录。

as_strided

创建现有 torch.Tensor input 的视图,具有指定的 sizestridestorage_offset

from_file

创建一个由内存映射文件支持的存储的 CPU 张量。

from_numpy

numpy.ndarray 创建 Tensor

from_dlpack

将外部库中的张量转换为 torch.Tensor

frombuffer

从实现 Python 缓冲区协议的对象创建 1 维 Tensor

zeros

返回一个用标量值 0 填充的张量,其形状由可变参数 size 定义。

zeros_like

返回一个用标量值 0 填充的张量,其大小与 input 相同。

ones

返回一个用标量值 1 填充的张量,其形状由可变参数 size 定义。

ones_like

返回一个用标量值 1 填充的张量,其大小与 input 相同。

arange

返回一个大小为 endstartstep\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil 的 1 维张量,其值从区间 [start, end) 中以 step 为公差从 start 开始取值。

range

返回一个大小为 endstartstep+1\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1 的 1 维张量,其值从 startend,步长为 step

linspace

创建一个一维张量,大小为 steps,其值从 startend 均匀分布,包括端点。

logspace

创建一个大小为steps的一维张量,其值在对数刻度上从basestart{{\text{{base}}}}^{{\text{{start}}}}baseend{{\text{{base}}}}^{{\text{{end}}}}(包含两端)均匀分布,以base为底。

eye

返回一个对角线为1,其他地方为0的二维张量。

empty

返回一个填充了未初始化数据的张量。

empty_like

返回一个未初始化的张量,其大小与input相同。

empty_strided

创建一个具有指定sizestride并填充未定义数据的张量。

full

创建一个大小为size,并用fill_value填充的张量。

full_like

返回一个与input大小相同,并用fill_value填充的张量。

quantize_per_tensor

将浮点张量转换为具有给定缩放和零点的量化张量。

quantize_per_channel

将浮点张量转换为具有给定缩放和零点的逐通道量化张量。

dequantize

通过对量化张量进行反量化,返回一个fp32张量。

complex

构建一个复数张量,其实部等于real,虚部等于imag

polar

构建一个复数张量,其元素是与具有绝对值abs和角度angle的极坐标相对应的笛卡尔坐标。

heaviside

计算input中每个元素的Heaviside阶跃函数。

索引、切片、连接、变异操作#

adjoint

返回共轭张量的视图,并转置最后两个维度。

argwhere

返回一个包含input中所有非零元素索引的张量。

cat

在给定维度中,连接tensors中给定的张量序列。

concat

torch.cat()的别名。

concatenate

torch.cat()的别名。

conj

返回input的视图,其共轭位已翻转。

chunk

尝试将张量分割成指定数量的块。

dsplit

根据indices_or_sections,将input(一个三维或更多维的张量)在深度方向上分割成多个张量。

column_stack

通过水平堆叠tensors中的张量来创建新张量。

dstack

按深度(沿第三轴)顺序堆叠张量。

gather

沿由 dim 指定的轴收集值。

hsplit

根据indices_or_sections,将input(一个一维或更多维的张量)水平分割成多个张量。

hstack

按水平(列方向)顺序堆叠张量。

index_add

请参阅index_add_()以获取函数说明。

index_copy

请参阅index_add_()以获取函数说明。

index_reduce

请参阅index_reduce_()以获取函数说明。

index_select

返回一个新张量,该张量使用LongTensor中的index条目沿维度diminput张量进行索引。

masked_select

返回一个新的一维张量,该张量根据BoolTensor类型的布尔掩码maskinput张量进行索引。

movedim

inputsource中的维度移动到destination中的位置。

moveaxis

torch.movedim()的别名。

narrow

返回一个新张量,它是input张量的窄化版本。

narrow_copy

Tensor.narrow()相同,只是它返回一个副本而不是共享存储。

nonzero

permute

返回原始张量input的视图,其维度已置换。

reshape

返回一个与input具有相同数据和元素数量,但具有指定形状的张量。

row_stack

torch.vstack()的别名。

select

沿选定维度在给定索引处切片input张量。

scatter

torch.Tensor.scatter_()的非就地版本

diagonal_scatter

相对于dim1dim2,将src张量的值嵌入到input的对角线元素中。

select_scatter

在给定索引处,将src张量的值嵌入到input中。

slice_scatter

在给定维度处,将src张量的值嵌入到input中。

scatter_add

torch.Tensor.scatter_add_()的非就地版本

scatter_reduce

torch.Tensor.scatter_reduce_()的非就地版本

split

将张量分割成块。

squeeze

返回一个张量,其中input的所有指定维度(大小为1)都被移除。

stack

沿新维度连接一系列张量。

swapaxes

torch.transpose()的别名。

swapdims

torch.transpose()的别名。

t

期望input为小于等于2维的张量,并转置维度0和1。

take

返回一个新张量,其中包含input在给定索引处的元素。

take_along_dim

沿给定diminput中选择1维索引indices处的值。

tensor_split

根据indices_or_sections指定的索引或段数,将张量沿维度dim分割成多个子张量,所有子张量都是input的视图。

tile

通过重复input的元素来构造一个张量。

转置

返回一个张量,它是input的转置版本。

unbind

移除一个张量维度。

unravel_index

将扁平索引张量转换为坐标张量元组,这些坐标张量可以索引到指定形状的任意张量中。

unsqueeze

返回一个新张量,在指定位置插入大小为1的维度。

vsplit

根据indices_or_sections,将input(一个二维或更多维的张量)垂直分割成多个张量。

vstack

按垂直(行方向)顺序堆叠张量。

其中

根据condition,返回一个从inputother中选择元素的张量。

加速器#

在 PyTorch 仓库中,我们将“加速器”定义为与 CPU 一起使用以加速计算的torch.device。这些设备使用异步执行方案,以torch.Streamtorch.Event作为其主要的同步方式。我们还假设在给定主机上一次只能有一个此类加速器可用。这使得我们可以将当前加速器用作相关概念(如固定内存、Stream 设备类型、FSDP 等)的默认设备。

目前,加速器设备包括(不分先后顺序)“CUDA”“MTIA”“XPU”“MPS”、“HPU”和 PrivateUse1(PyTorch 仓库本身没有的许多设备)。

PyTorch 生态系统中的许多工具使用 fork 创建子进程(例如数据加载或操作内并行),因此尽可能延迟任何会阻止进一步 fork 的操作非常重要。这在这里尤其重要,因为大多数加速器的初始化都有这种效果。实际上,您应该记住,默认情况下检查torch.accelerator.current_accelerator()是编译时检查,因此它始终是 fork-safe 的。相反,将check_available=True标志传递给此函数或调用torch.accelerator.is_available()通常会阻止后续的 fork。

一些后端提供了实验性的可选功能,使运行时可用性检查成为 fork 安全的。例如,使用 CUDA 设备时,可以使用PYTORCH_NVML_BASED_CUDA_CHECK=1

一个按先进先出(FIFO)顺序异步执行相应任务的有序队列。

事件

查询和记录流状态以识别或控制流之间的依赖关系并测量时间。

生成器#

生成器

创建并返回一个生成器对象,该对象管理生成伪随机数的算法状态。

随机采样#

seed

将所有设备上生成随机数的种子设置为非确定性随机数。

manual_seed

设置所有设备上生成随机数的种子。

initial_seed

返回生成随机数的初始种子,作为Python long类型。

get_rng_state

将随机数生成器状态作为torch.ByteTensor返回。

set_rng_state

设置随机数生成器状态。

torch.default_generator 返回默认的CPU torch.Generator#

bernoulli

从伯努利分布中抽取二元随机数(0或1)。

multinomial

返回一个张量,其中每一行包含从位于张量input对应行中的多项式(更严格的定义是多元,有关更多详细信息,请参阅torch.distributions.multinomial.Multinomial)概率分布中采样的num_samples个索引。

normal

返回一个张量,其中包含从给定均值和标准差的独立正态分布中抽取的随机数。

poisson

返回一个与input大小相同的张量,其中每个元素都从泊松分布中采样,其速率参数由input中的相应元素给出,即

rand

返回一个张量,其中填充了从区间[0,1)[0, 1)上的均匀分布中随机生成的数字。

rand_like

返回一个与input大小相同的张量,其中填充了从区间[0,1)[0, 1)上的均匀分布中随机生成的数字。

randint

返回一个张量,其中填充了在low(包含)和high(不包含)之间均匀生成的随机整数。

randint_like

返回一个与张量input形状相同的张量,其中填充了在low(包含)和high(不包含)之间均匀生成的随机整数。

randn

返回一个张量,其中填充了均值为0且方差为1(也称为标准正态分布)的正态分布中的随机数。

randn_like

返回一个与input大小相同的张量,其中填充了均值为0且方差为1的正态分布中的随机数。

randperm

返回从0n - 1的随机排列。

就地随机采样#

张量上还定义了一些就地随机采样函数。点击查看其文档

准随机采样#

quasirandom.SobolEngine

torch.quasirandom.SobolEngine是生成(扰动)Sobol序列的引擎。

序列化#

保存

将对象保存到磁盘文件。

加载

从文件中加载使用torch.save()保存的对象。

并行#

get_num_threads

返回用于并行化CPU操作的线程数

set_num_threads

设置用于CPU内部并行操作的线程数。

get_num_interop_threads

返回用于CPU之间操作并行化的线程数(例如

set_num_interop_threads

设置用于操作间并行化的线程数(例如

局部禁用梯度计算#

上下文管理器torch.no_grad()torch.enable_grad()torch.set_grad_enabled()有助于局部禁用和启用梯度计算。有关其用法的更多详细信息,请参阅局部禁用梯度计算。这些上下文管理器是线程局部,因此如果您使用threading模块等将工作发送到另一个线程,它们将不起作用。

示例

>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
...     y = x * 2
>>> y.requires_grad
False

>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False

>>> torch.set_grad_enabled(True)  # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True

>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False

no_grad

禁用梯度计算的上下文管理器。

enable_grad

启用梯度计算的上下文管理器。

autograd.grad_mode.set_grad_enabled

设置梯度计算开关的上下文管理器。

is_grad_enabled

如果当前启用了梯度模式,则返回True。

autograd.grad_mode.inference_mode

启用或禁用推理模式的上下文管理器。

is_inference_mode_enabled

如果当前启用了推理模式,则返回True。

数学运算#

常数#

inf

浮点正无穷大。math.inf的别名。

nan

浮点数“非数字”值。此值不是合法数字。math.nan的别名。

逐点操作#

abs

计算 input 中每个元素的绝对值。

absolute

torch.abs()的别名

acos

计算input中每个元素的反正弦。

arccos

torch.acos()的别名。

acosh

返回一个新张量,其中包含input元素的双曲反余弦。

arccosh

torch.acosh()的别名。

add

other(按alpha缩放)添加到input

addcdiv

tensor1除以tensor2的每个元素进行运算,将结果乘以标量value,并将其添加到input中。

addcmul

tensor1乘以tensor2的每个元素进行运算,将结果乘以标量value,并将其添加到input中。

angle

计算给定input张量的每个元素的角度(以弧度表示)。

asin

返回一个新张量,其中包含input元素的反正弦。

arcsin

torch.asin()的别名。

asinh

返回一个新张量,其中包含input元素的双曲反正弦。

arcsinh

torch.asinh()的别名。

atan

返回一个新张量,其中包含input元素的反正切。

arctan

torch.atan()的别名。

atanh

返回一个新张量,其中包含input元素的双曲反正切。

arctanh

torch.atanh()的别名。

atan2

元素级反正切inputi/otheri\text{input}_{i} / \text{other}_{i},考虑到象限。

arctan2

torch.atan2()的别名。

bitwise_not

计算给定输入张量的按位非。

bitwise_and

计算inputother的按位与。

bitwise_or

计算inputother的按位或。

bitwise_xor

计算inputother的按位异或。

bitwise_left_shift

计算input左移other位的算术移位。

bitwise_right_shift

计算input右移other位的算术移位。

ceil

返回一个新张量,其中包含input元素的向上取整值,即大于或等于每个元素的最小整数。

clamp

input中的所有元素钳制到范围[ min, max ]内。

clip

torch.clamp()的别名。

conj_physical

计算给定input张量的每个元素的共轭。

copysign

逐元素地创建一个新的浮点张量,其大小与input相同,符号与other相同。

cos

返回一个新张量,其中包含input元素的余弦。

cosh

返回一个新张量,其中包含input元素的双曲余弦。

deg2rad

返回一个新张量,其中input的每个元素都从角度(度)转换为弧度。

div

将输入input的每个元素除以other的相应元素。

divide

torch.div()的别名。

digamma

torch.special.digamma()的别名。

erf

torch.special.erf()的别名。

erfc

torch.special.erfc()的别名。

erfinv

torch.special.erfinv()的别名。

exp

返回一个新张量,其元素是输入张量input的指数。

exp2

torch.special.exp2()的别名。

expm1

torch.special.expm1()的别名。

fake_quantize_per_channel_affine

返回一个新张量,其中input中的数据通过scalezero_pointquant_minquant_max沿axis指定的通道进行假量化。

fake_quantize_per_tensor_affine

返回一个新张量,其中input中的数据通过scalezero_pointquant_minquant_max进行假量化。

fix

torch.trunc()的别名

float_power

以双精度逐元素计算inputexponent次方。

floor

返回一个新张量,其中包含 input 元素的向下取整值,即小于或等于每个元素的最大整数。

floor_divide

fmod

逐条应用C++的std::fmod

frac

计算input中每个元素的小数部分。

frexp

input分解为尾数和指数张量,使得input=mantissa×2exponent\text{input} = \text{mantissa} \times 2^{\text{exponent}}

gradient

使用二阶中心差分法以及边界处的一阶或二阶估计,估算一维或多维函数g:RnRg : \mathbb{R}^n \rightarrow \mathbb{R}的梯度。

imag

返回一个新张量,其中包含self张量的虚部值。

ldexp

input乘以2 ** other

lerp

根据标量或张量weight对两个张量start(由input给出)和end进行线性插值,并返回结果out张量。

lgamma

计算 input 上伽马函数绝对值的自然对数。

log

返回一个新张量,其中包含 input 元素对应的自然对数。

log10

返回一个新张量,其中包含input元素的以10为底的对数。

log1p

返回一个新张量,其中包含(1 + input)的自然对数。

log2

返回一个新张量,其中包含input元素的以2为底的对数。

logaddexp

输入指数和的对数。

logaddexp2

以2为底的输入指数和的对数。

logical_and

计算给定输入张量的逐元素逻辑与。

logical_not

计算给定输入张量的逐元素逻辑非。

logical_or

计算给定输入张量的逐元素逻辑或。

logical_xor

计算给定输入张量的逐元素逻辑异或。

logit

torch.special.logit()的别名。

hypot

给定直角三角形的两个直角边,返回其斜边。

i0

torch.special.i0()的别名。

igamma

torch.special.gammainc()的别名。

igammac

torch.special.gammaincc()的别名。

mul

input乘以other

multiply

torch.mul()的别名。

mvlgamma

torch.special.multigammaln()的别名。

nan_to_num

nanposinfneginf指定的值分别替换input中的NaN、正无穷和负无穷值。

neg

返回一个新张量,其中包含input元素的负值。

negative

torch.neg()的别名

nextafter

逐元素地返回input朝向other方向的下一个浮点值。

polygamma

torch.special.polygamma()的别名。

positive

返回input

pow

input中的每个元素进行exponent次方运算,并返回结果张量。

quantized_batch_norm

对4D(NCHW)量化张量应用批量归一化。

quantized_max_pool1d

对由多个输入平面组成的输入量化张量应用1D最大池化。

quantized_max_pool2d

对由多个输入平面组成的输入量化张量应用2D最大池化。

rad2deg

返回一个新张量,其中input的每个元素都从角度(弧度)转换为度。

real

返回一个新张量,其中包含self张量的实数值。

reciprocal

返回一个新张量,其中包含input元素的倒数。

remainder

逐条计算Python的模运算

round

input的元素四舍五入到最近的整数。

rsqrt

返回一个新张量,其中包含input每个元素的平方根的倒数。

sigmoid

torch.special.expit()的别名。

sign

返回一个新张量,其中包含input元素的符号。

sgn

此函数是torch.sign()到复数张量的扩展。

signbit

测试input的每个元素的符号位是否已设置。

sin

返回一个新张量,其中包含input元素的正弦。

sinc

torch.special.sinc()的别名。

sinh

返回一个新张量,其中包含input元素的双曲正弦。

softmax

torch.nn.functional.softmax()的别名。

sqrt

返回一个新张量,其中包含 input 元素的平方根。

square

返回一个新张量,其中包含input元素的平方。

sub

input中减去other(按alpha缩放)。

subtract

torch.sub()的别名。

tan

返回一个新张量,其中包含input元素的正切。

tanh

返回一个新张量,其元素是 input 的双曲正切值。

true_divide

torch.div()的别名,其中rounding_mode=None

trunc

返回一个新张量,其中包含input元素的截断整数值。

xlogy

torch.special.xlogy()的别名。

归约操作#

argmax

返回 input 张量中所有元素最大值的索引。

argmin

返回扁平张量或沿某一维度的最小值的索引

amax

返回input张量在给定维度dim中每个切片的最大值。

amin

返回input张量在给定维度dim中每个切片的最小值。

aminmax

计算input张量的最小值和最大值。

all

测试 input 中的所有元素是否都评估为 True

any

测试input中是否有任何元素求值为True

max

返回input张量中所有元素的最大值。

min

返回input张量中所有元素的最小值。

dist

返回(input - other)的p范数

logsumexp

返回input张量在给定维度dim中每一行指数和的对数。

mean

nanmean

计算沿指定维度的所有非NaN元素的均值。

median

返回input中的中位数。

nanmedian

返回input中的中位数,忽略NaN值。

mode

返回一个命名元组(values, indices),其中valuesinput张量在给定维度dim中每一行的众数,即该行中出现次数最多的值,而indices是找到的每个众数值的索引位置。

norm

返回给定张量的矩阵范数或向量范数。

nansum

返回所有元素的总和,将非数字(NaN)视为零。

prod

返回input张量中所有元素的乘积。

quantile

计算input张量沿维度dim每行的q分位数。

nanquantile

这是torch.quantile()的一个变体,它“忽略”NaN值,计算分位数q,就像input中不存在NaN值一样。

std

计算由dim指定的维度的标准差。

std_mean

计算由dim指定的维度的标准差和均值。

sum

返回 input 张量中所有元素的和。

unique

返回输入张量的唯一元素。

unique_consecutive

从每个连续的等效元素组中删除除第一个元素之外的所有元素。

var

计算由dim指定的维度的方差。

var_mean

计算由dim指定的维度的方差和均值。

count_nonzero

计算张量input沿给定dim的非零值的数量。

比较操作#

allclose

此函数检查inputother是否满足条件

argsort

返回按值升序对张量沿给定维度进行排序的索引。

eq

计算逐元素相等

equal

如果两个张量具有相同的大小和元素,则为True,否则为False

ge

逐元素计算inputother\text{input} \geq \text{other}

greater_equal

torch.ge()的别名。

gt

逐元素计算input>other\text{input} > \text{other}

greater

torch.gt()的别名。

isclose

返回一个新张量,其中包含布尔元素,表示input的每个元素是否“接近”other的相应元素。

isfinite

返回一个新张量,其中包含布尔元素,表示每个元素是否为有限

isin

测试elements的每个元素是否在test_elements中。

isinf

测试input的每个元素是否为无穷大(正无穷大或负无穷大)。

isposinf

测试input的每个元素是否为正无穷大。

isneginf

测试input的每个元素是否为负无穷大。

isnan

返回一个新张量,其中包含布尔元素,表示input的每个元素是否为NaN。

isreal

返回一个新张量,其中包含布尔元素,表示input的每个元素是否为实数值。

kthvalue

返回一个命名元组(values, indices),其中valuesinput张量在给定维度dim中每行的第k个最小元素。

le

逐元素计算inputother\text{input} \leq \text{other}

less_equal

torch.le() 的别名。

lt

逐元素计算 input<other\text{input} < \text{other}

less

torch.lt() 的别名。

maximum

计算 inputother 的逐元素最大值。

minimum

计算 inputother 的逐元素最小值。

fmax

计算 inputother 的逐元素最大值。

fmin

计算 inputother 的逐元素最小值。

ne

逐元素计算 inputother\text{input} \neq \text{other}

not_equal

torch.ne() 的别名。

sort

沿给定维度按值升序对 input 张量的元素进行排序。

topk

返回给定 input 张量沿给定维度的 k 个最大元素。

msort

沿 input 张量的第一维按值升序对其元素进行排序。

频谱操作#

stft

短时傅里叶变换 (STFT)。

istft

逆短时傅里叶变换。

bartlett_window

Bartlett 窗函数。

blackman_window

Blackman 窗函数。

hamming_window

Hamming 窗函数。

hann_window

Hann 窗函数。

kaiser_window

计算窗口长度为 window_length 且形状参数为 beta 的 Kaiser 窗。

其他操作#

atleast_1d

返回每个零维度输入张量的一维视图。

atleast_2d

返回每个零维度输入张量的二维视图。

atleast_3d

返回每个零维度输入张量的三维视图。

bincount

计算非负整数数组中每个值的频率。

block_diag

从提供的张量创建块对角矩阵。

broadcast_tensors

根据广播语义广播给定张量。

broadcast_to

input 广播到 shape 形状。

broadcast_shapes

broadcast_tensors() 类似,但用于形状。

bucketize

返回 input 中每个值所属桶的索引,其中桶的边界由 boundaries 设置。

cartesian_prod

对给定张量序列进行笛卡尔积。

cdist

计算两组行向量之间逐批的 p-范数距离。

clone

返回 input 的副本。

combinations

计算给定张量的长度为 rr 的组合。

corrcoef

估计由 input 矩阵给出的变量的皮尔逊积矩相关系数矩阵,其中行是变量,列是观测值。

cov

估计由 input 矩阵给出的变量的协方差矩阵,其中行是变量,列是观测值。

cross

返回 inputother 在维度 dim 上的向量叉积。

cummax

返回命名元组 (values, indices),其中 valuesinput 在维度 dim 上的累积最大值。

cummin

返回命名元组 (values, indices),其中 valuesinput 在维度 dim 上的累积最小值。

cumprod

返回 input 在维度 dim 上的累积乘积。

cumsum

返回 input 在维度 dim 上的累积和。

diag

  • 如果 input 是向量(一维张量),则返回二维方阵。

diag_embed

创建一个张量,其某些二维平面(由 dim1dim2 指定)的对角线由 input 填充。

diagflat

  • 如果 input 是向量(一维张量),则返回二维方阵。

diagonal

返回 input 的部分视图,其中其相对于 dim1dim2 的对角线元素作为维度附加到形状的末尾。

diff

计算沿给定维度的 n 阶前向差分。

einsum

根据爱因斯坦求和约定中指定的符号,对输入 operands 的元素乘积沿维度求和。

flatten

通过将其重塑为一维张量来展平 input

flip

沿 dims 中的给定轴反转 n 维张量的顺序。

fliplr

沿左右方向翻转张量,返回新张量。

flipud

沿上下方向翻转张量,返回新张量。

kron

计算 inputother 的 Kronecker 积,记为 \otimes

rot90

在 dims 轴指定的平面中将 n 维张量旋转 90 度。

gcd

计算 inputother 的逐元素最大公约数 (GCD)。

histc

计算张量的直方图。

histogram

计算张量中值的直方图。

histogramdd

计算张量中值多维直方图。

meshgrid

创建由 attr:tensors 中一维输入指定的坐标网格。

lcm

计算 inputother 的逐元素最小公倍数 (LCM)。

logcumsumexp

返回 input 在维度 dim 上元素指数的累积和的对数。

ravel

返回一个连续的扁平张量。

renorm

返回一个张量,其中 input 沿维度 dim 的每个子张量都经过归一化,使得子张量的 p-范数低于 maxnorm 的值。

repeat_interleave

重复张量的元素。

roll

沿给定维度滚动张量 input

searchsorted

sorted_sequence 的*最内层*维度中找到索引,使得如果将 values 中的相应值插入到这些索引之前,则在排序时,sorted_sequence 中相应*最内层*维度的顺序将保持不变。

tensordot

返回 a 和 b 在多个维度上的收缩。

trace

返回输入二维矩阵对角线元素的和。

tril

返回矩阵(二维张量)或矩阵批次 input 的下三角部分,结果张量 out 的其他元素设置为 0。

tril_indices

返回 row 乘以 col 矩阵的下三角部分的索引,结果是一个 2 乘以 N 的张量,其中第一行包含所有索引的行坐标,第二行包含列坐标。

triu

返回矩阵(二维张量)或矩阵批次 input 的上三角部分,结果张量 out 的其他元素设置为 0。

triu_indices

返回 row 乘以 col 矩阵的上三角部分的索引,结果是一个 2 乘以 N 的张量,其中第一行包含所有索引的行坐标,第二行包含列坐标。

unflatten

将输入张量的一个维度展开为多个维度。

vander

生成 Vandermonde 矩阵。

view_as_real

返回 input 作为实张量的视图。

view_as_complex

返回 input 作为复张量的视图。

resolve_conj

如果 input 的共轭位设置为 True,则返回具有具体化共轭的新张量;否则返回 input

resolve_neg

如果 input 的负位设置为 True,则返回具有具体化求反的新张量;否则返回 input

BLAS 和 LAPACK 操作#

addbmm

执行存储在 batch1batch2 中的矩阵的批量矩阵-矩阵乘积,并减少加法步骤(所有矩阵乘法沿第一维累加)。

addmm

执行矩阵 mat1mat2 的矩阵乘法。

addmv

执行矩阵 mat 和向量 vec 的矩阵-向量乘积。

addr

执行向量 vec1vec2 的外积,并将其添加到矩阵 input 中。

baddbmm

执行 batch1batch2 中矩阵的批量矩阵-矩阵乘积。

bmm

执行存储在 inputmat2 中的矩阵的批量矩阵-矩阵乘积。

chain_matmul

返回 NN 个二维张量的矩阵乘积。

cholesky

计算对称正定矩阵 AA 或对称正定矩阵批次的 Cholesky 分解。

cholesky_inverse

计算给定其 Cholesky 分解的复 Hermitian 或实对称正定矩阵的逆。

cholesky_solve

计算给定其 Cholesky 分解的复 Hermitian 或实对称正定左侧矩阵的线性方程组的解。

dot

计算两个一维张量的点积。

geqrf

这是一个直接调用 LAPACK 的 geqrf 的低级函数。

ger

torch.outer() 的别名。

inner

计算一维张量的点积。

inverse

torch.linalg.inv() 的别名。

det

torch.linalg.det() 的别名。

logdet

计算方阵或方阵批次的对数行列式。

slogdet

torch.linalg.slogdet() 的别名。

lu

计算矩阵或矩阵批次 A 的 LU 分解。

lu_solve

使用 lu_factor() 的 A 的部分主元 LU 分解,返回线性系统 Ax=bAx = b 的 LU 解。

lu_unpack

lu_factor() 返回的 LU 分解解包为 P, L, U 矩阵。

matmul

两个张量的矩阵乘积。

matrix_power

torch.linalg.matrix_power() 的别名。

matrix_exp

torch.linalg.matrix_exp() 的别名。

mm

执行矩阵 inputmat2 的矩阵乘法。

mv

执行矩阵 input 和向量 vec 的矩阵-向量乘积。

orgqr

torch.linalg.householder_product() 的别名。

ormqr

计算 Householder 矩阵乘积与一般矩阵的矩阵-矩阵乘法。

outer

inputvec2 的外积。

pinverse

torch.linalg.pinv() 的别名。

qr

计算矩阵或矩阵批次 input 的 QR 分解,并返回张量的命名元组 (Q, R),使得 input=QR\text{input} = Q R,其中 QQ 是正交矩阵或正交矩阵批次,RR 是上三角矩阵或上三角矩阵批次。

svd

计算矩阵或矩阵批次 input 的奇异值分解。

svd_lowrank

返回矩阵、矩阵批次或稀疏矩阵 AA 的奇异值分解 (U, S, V),使得 AUdiag(S)VHA \approx U \operatorname{diag}(S) V^{\text{H}}

pca_lowrank

对低秩矩阵、这类矩阵的批次或稀疏矩阵执行线性主成分分析 (PCA)。

lobpcg

使用无矩阵 LOBPCG 方法,找到对称正定广义特征值问题的 k 个最大(或最小)特征值和相应的特征向量。

trapz

torch.trapezoid() 的别名。

trapezoid

沿 dim 计算 梯形法则

cumulative_trapezoid

沿 dim 累积计算 梯形法则

triangular_solve

求解具有方阵上三角或下三角可逆矩阵 AA 和多个右侧 bb 的方程组。

vdot

计算两个一维向量沿维度的点积。

Foreach 操作#

警告

此 API 处于 Beta 阶段,未来可能会有兼容性破坏性更改。不支持前向模式 AD。

_foreach_abs

对输入列表的每个张量应用 torch.abs()

_foreach_abs_

对输入列表的每个张量应用 torch.abs()

_foreach_acos

对输入列表的每个张量应用 torch.acos()

_foreach_acos_

对输入列表的每个张量应用 torch.acos()

_foreach_asin

对输入列表的每个张量应用 torch.asin()

_foreach_asin_

对输入列表的每个张量应用 torch.asin()

_foreach_atan

对输入列表的每个张量应用 torch.atan()

_foreach_atan_

对输入列表的每个张量应用 torch.atan()

_foreach_ceil

对输入列表的每个张量应用 torch.ceil()

_foreach_ceil_

对输入列表的每个张量应用 torch.ceil()

_foreach_cos

对输入列表的每个张量应用 torch.cos()

_foreach_cos_

对输入列表的每个张量应用 torch.cos()

_foreach_cosh

对输入列表的每个张量应用 torch.cosh()

_foreach_cosh_

对输入列表的每个张量应用 torch.cosh()

_foreach_erf

对输入列表的每个张量应用 torch.erf()

_foreach_erf_

对输入列表的每个张量应用 torch.erf()

_foreach_erfc

对输入列表的每个张量应用 torch.erfc()

_foreach_erfc_

对输入列表的每个张量应用 torch.erfc()

_foreach_exp

对输入列表的每个张量应用 torch.exp()

_foreach_exp_

对输入列表的每个张量应用 torch.exp()

_foreach_expm1

对输入列表的每个张量应用 torch.expm1()

_foreach_expm1_

对输入列表的每个张量应用 torch.expm1()

_foreach_floor

对输入列表的每个张量应用 torch.floor()

_foreach_floor_

对输入列表的每个张量应用 torch.floor()

_foreach_log

对输入列表的每个张量应用 torch.log()

_foreach_log_

对输入列表的每个张量应用 torch.log()

_foreach_log10

对输入列表的每个张量应用 torch.log10()

_foreach_log10_

对输入列表的每个张量应用 torch.log10()

_foreach_log1p

对输入列表的每个张量应用 torch.log1p()

_foreach_log1p_

对输入列表的每个张量应用 torch.log1p()

_foreach_log2

对输入列表的每个张量应用 torch.log2()

_foreach_log2_

对输入列表的每个张量应用 torch.log2()

_foreach_neg

对输入列表的每个张量应用 torch.neg()

_foreach_neg_

对输入列表的每个张量应用 torch.neg()

_foreach_tan

对输入列表的每个张量应用 torch.tan()

_foreach_tan_

对输入列表的每个张量应用 torch.tan()

_foreach_sin

对输入列表的每个张量应用 torch.sin()

_foreach_sin_

对输入列表的每个张量应用 torch.sin()

_foreach_sinh

对输入列表的每个张量应用 torch.sinh()

_foreach_sinh_

对输入列表的每个张量应用 torch.sinh()

_foreach_round

对输入列表的每个张量应用 torch.round()

_foreach_round_

对输入列表的每个张量应用 torch.round()

_foreach_sqrt

对输入列表的每个张量应用 torch.sqrt()

_foreach_sqrt_

对输入列表的每个张量应用 torch.sqrt()

_foreach_lgamma

对输入列表的每个张量应用 torch.lgamma()

_foreach_lgamma_

对输入列表的每个张量应用 torch.lgamma()

_foreach_frac

对输入列表的每个张量应用 torch.frac()

_foreach_frac_

对输入列表的每个张量应用 torch.frac()

_foreach_reciprocal

对输入列表的每个张量应用 torch.reciprocal()

_foreach_reciprocal_

对输入列表的每个张量应用 torch.reciprocal()

_foreach_sigmoid

对输入列表的每个张量应用 torch.sigmoid()

_foreach_sigmoid_

对输入列表的每个张量应用 torch.sigmoid()

_foreach_trunc

对输入列表的每个张量应用 torch.trunc()

_foreach_trunc_

对输入列表的每个张量应用 torch.trunc()

_foreach_zero_

对输入列表的每个张量应用 torch.zero()

实用工具#

compiled_with_cxx11_abi

返回 PyTorch 是否使用 _GLIBCXX_USE_CXX11_ABI=1 构建。

result_type

返回对所提供的输入张量执行算术运算后将产生的 torch.dtype

can_cast

确定根据类型提升文档中描述的 PyTorch 类型转换规则,是否允许类型转换。

promote_types

返回 torch.dtype,其大小和标量类型不小于 type1type2,且为两者中最小的。

use_deterministic_algorithms

设置 PyTorch 操作是否必须使用“确定性”算法。

are_deterministic_algorithms_enabled

如果全局确定性标志已启用,则返回 True。

is_deterministic_algorithms_warn_only_enabled

如果全局确定性标志设置为仅警告模式,则返回 True。

set_deterministic_debug_mode

设置确定性操作的调试模式。

get_deterministic_debug_mode

返回确定性操作调试模式的当前值。

set_float32_matmul_precision

设置 float32 矩阵乘法的内部精度。

get_float32_matmul_precision

返回 float32 矩阵乘法精度的当前值。

set_warn_always

当此标志为 False(默认)时,某些 PyTorch 警告可能在每个进程中只出现一次。

get_device_module

返回与给定设备相关的模块(例如,torch.device('cuda')、"mtia:0"、"xpu" 等)。

is_warn_always_enabled

如果全局 warn_always 标志已启用,则返回 True。

vmap

vmap 是矢量化映射;vmap(func) 返回一个新函数,该函数将 func 映射到输入的一些维度上。

_assert

Python 的 assert 的包装器,可进行符号跟踪。

符号数#

class torch.SymInt(node)[source]#

像一个 int(包括魔法方法),但将所有操作重定向到包装的节点。这尤其用于在符号形状工作流程中符号式记录操作。

as_integer_ratio()[source]#

将此整数表示为精确整数比

返回类型

tuple[‘SymInt’, int]

class torch.SymFloat(node)[source]#

像一个浮点数(包括魔法方法),但将所有操作重定向到包装的节点。这尤其用于在符号形状工作流程中符号式记录操作。

as_integer_ratio()[source]#

将此浮点数表示为精确整数比

返回类型

tuple[int, int]

conjugate()[source]#

返回浮点数的复共轭。

返回类型

SymFloat

hex()[source]#

返回浮点数的十六进制表示。

返回类型

str

is_integer()[source]#

如果浮点数是整数,则返回 True。

class torch.SymBool(node)[source]#

像一个布尔值(包括魔法方法),但将所有操作重定向到包装的节点。这尤其用于在符号形状工作流程中符号式记录操作。

与常规布尔值不同,常规布尔运算符将强制额外的防护,而不是符号化评估。请改用按位运算符来处理此问题。

sym_float

SymInt 感知的浮点转换实用工具。

sym_fresh_size

sym_int

SymInt 感知的整数转换实用工具。

sym_max

SymInt 感知的最大值实用工具,可避免 a < b 的分支。

sym_min

SymInt 感知的最小值实用工具。

sym_not

SymInt 感知的逻辑非实用工具。

sym_ite

SymInt 感知的条件运算符实用工具(t if b else f)。

sym_sum

N 元加法,对于长列表比迭代二进制加法计算更快。

导出路径#

警告

此功能是原型,未来可能会有兼容性破坏性更改。

export generated/exportdb/index

控制流#

警告

此功能是原型,未来可能会有兼容性破坏性更改。

cond

有条件地应用 true_fnfalse_fn

优化#

compile

使用 TorchDynamo 和指定的后端优化给定模型/函数。

torch.compile 文档

运算符标签#

class torch.Tag#

成员

核心

cudagraph_unsafe

data_dependent_output

dynamic_output_shape

flexible_layout

generated

inplace_view

maybe_aliasing_or_mutating

needs_contiguous_strides

needs_exact_strides

needs_fixed_stride_order

nondeterministic_bitwise

nondeterministic_seeded

pointwise

pt2_compliant_tag

view_copy

property name#