torch.set_default_dtype#
- torch.set_default_dtype(d, /)[源代码]#
将默认浮点数据类型设置为
d
。支持浮点数据类型作为输入。其他数据类型将导致torch引发异常。当PyTorch初始化时,其默认浮点数据类型为torch.float32,而set_default_dtype(torch.float64)的目的是为了便于NumPy风格的类型推断。默认浮点数据类型用于
隐式确定默认复数数据类型。当默认浮点类型为float16时,默认复数数据类型为complex32。对于float32,默认复数数据类型为complex64。对于float64,它是complex128。对于bfloat16,将引发异常,因为bfloat16没有对应的复数类型。
推断使用Python浮点数或复数构建的张量的数据类型。请参阅下面的示例。
确定布尔值和整数张量与Python浮点数和复数之间的类型提升结果。
- 参数
d (
torch.dtype
) – 要设置为默认值的浮点数据类型。
示例
>>> # initial default for floating point is torch.float32 >>> # Python floats are interpreted as float32 >>> torch.tensor([1.2, 3]).dtype torch.float32 >>> # initial default for floating point is torch.complex64 >>> # Complex Python numbers are interpreted as complex64 >>> torch.tensor([1.2, 3j]).dtype torch.complex64
>>> torch.set_default_dtype(torch.float64) >>> # Python floats are now interpreted as float64 >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor torch.float64 >>> # Complex Python numbers are now interpreted as complex128 >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor torch.complex128
>>> torch.set_default_dtype(torch.float16) >>> # Python floats are now interpreted as float16 >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor torch.float16 >>> # Complex Python numbers are now interpreted as complex128 >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor torch.complex32