评价此页

torch.set_default_dtype#

torch.set_default_dtype(d, /)[source]#

将默认浮点 dtype 设置为 d。支持浮点 dtype 作为输入。其他 dtype 将导致 torch 引发异常。

PyTorch 初始化时,其默认浮点 dtype 为 torch.float32,set_default_dtype(torch.float64) 的目的是为了方便 NumPy 类似的类型推断。默认浮点 dtype 用于

  1. 隐式确定默认复数 dtype。当默认浮点类型为 float16 时,默认复数 dtype 为 complex32。对于 float32,默认复数 dtype 为 complex64。对于 float64,它为 complex128。对于 bfloat16,将引发异常,因为 bfloat16 没有对应的复数类型。

  2. 推断使用 Python 浮点数或复数 Python 数字构造的张量的 dtype。参见下面的示例。

  3. 确定布尔型和整型张量与 Python 浮点数和复数 Python 数字之间类型提升的结果。

参数

d (torch.dtype) – 要设置为默认的浮点 dtype。

示例

>>> # initial default for floating point is torch.float32
>>> # Python floats are interpreted as float32
>>> torch.tensor([1.2, 3]).dtype
torch.float32
>>> # initial default for floating point is torch.complex64
>>> # Complex Python numbers are interpreted as complex64
>>> torch.tensor([1.2, 3j]).dtype
torch.complex64
>>> torch.set_default_dtype(torch.float64)
>>> # Python floats are now interpreted as float64
>>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
torch.float64
>>> # Complex Python numbers are now interpreted as complex128
>>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
torch.complex128
>>> torch.set_default_dtype(torch.float16)
>>> # Python floats are now interpreted as float16
>>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
torch.float16
>>> # Complex Python numbers are now interpreted as complex128
>>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
torch.complex32