torch.Tensor.to#
- Tensor.to(*args, **kwargs) Tensor #
执行 Tensor 的 dtype 和/或设备转换。
self.to(*args, **kwargs)
的参数将推断出torch.dtype
和torch.device
。注意
如果
self
Tensor 已经具有正确的torch.dtype
和torch.device
,则返回self
。否则,返回的 Tensor 是self
的副本,并具有所需的torch.dtype
和torch.device
。注意
如果
self
需要梯度 (requires_grad=True
) 但指定的dtype
是整数类型,则返回的 Tensor 将隐式地将requires_grad
设置为False
。这是因为只有具有浮点数或复数 dtype 的 Tensor 才能要求梯度。以下是调用
to
的方法- to(dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format) Tensor
返回具有指定
dtype
的 Tensor- 参数 (Args)
memory_format (
torch.memory_format
, optional): 返回 Tensor 的所需内存格式。默认为torch.preserve_format
。
注意
根据 C++ 类型转换规则,将浮点数值转换为整数类型会截断小数部分。如果截断后的值无法适应目标类型(例如,将
torch.inf
强制转换为torch.long
),则行为是未定义的,并且结果可能因平台而异。- torch.to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format) Tensor
返回具有指定
device
和(可选)dtype
的 Tensor。如果dtype
为None
,则将其推断为self.dtype
。当non_blocking
设置为True
时,该函数会尝试以异步方式(相对于主机)执行转换,如果可能的话。这种异步行为同时适用于固定内存和页面内存。但是,在使用此功能时应谨慎。有关更多信息,请参阅 有关正确使用 non_blocking 和 pin_memory 的教程。当copy
设置为True
时,即使 Tensor 已匹配所需的转换,也会创建一个新的 Tensor。- 参数 (Args)
memory_format (
torch.memory_format
, optional): 返回 Tensor 的所需内存格式。默认为torch.preserve_format
。
- torch.to(other, non_blocking=False, copy=False) Tensor
返回与 Tensor
other
具有相同torch.dtype
和torch.device
的 Tensor。当non_blocking
设置为True
时,该函数会尝试以异步方式(相对于主机)执行转换,如果可能的话。这种异步行为同时适用于固定内存和页面内存。但是,在使用此功能时应谨慎。有关更多信息,请参阅 有关正确使用 non_blocking 和 pin_memory 的教程。当copy
设置为True
时,即使 Tensor 已匹配所需的转换,也会创建一个新的 Tensor。
示例
>>> tensor = torch.randn(2, 2) # Initially dtype=float32, device=cpu >>> tensor.to(torch.float64) tensor([[-0.5044, 0.0005], [ 0.3310, -0.0584]], dtype=torch.float64) >>> cuda0 = torch.device('cuda:0') >>> tensor.to(cuda0) tensor([[-0.5044, 0.0005], [ 0.3310, -0.0584]], device='cuda:0') >>> tensor.to(cuda0, dtype=torch.float64) tensor([[-0.5044, 0.0005], [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') >>> other = torch.randn((), dtype=torch.float64, device=cuda0) >>> tensor.to(other, non_blocking=True) tensor([[-0.5044, 0.0005], [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0')