torch.Tensor.to#
- Tensor.to(*args, **kwargs) Tensor#
执行 Tensor 的 dtype 和/或 device 转换。
torch.dtype和torch.device是从self.to(*args, **kwargs)的参数推断出来的。注意
如果
selfTensor 已经具有正确的torch.dtype和torch.device,则返回self。否则,返回的 Tensor 是self的副本,具有所需的torch.dtype和torch.device。注意
如果
self需要梯度(requires_grad=True),但指定的dtype是整数类型,则返回的 Tensor 将隐式设置requires_grad=False。这是因为只有具有浮点数或复数 dtype 的 Tensor 才能需要梯度。以下是调用
to的方式:- to(dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format) Tensor
返回一个具有指定
dtype的 Tensor。- 参数 (Args)
memory_format (
torch.memory_format, optional): 返回的 Tensor 的期望内存格式。默认值:torch.preserve_format。
注意
根据 C++ 类型转换规则,将浮点值转换为整数类型时会截断小数部分。如果截断后的值无法装入目标类型(例如,将
torch.inf转换为torch.long),则行为未定义,结果可能因平台而异。- torch.to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format) Tensor
返回一个具有指定
device和(可选的)dtype的 Tensor。如果dtype为None,则推断为self.dtype。当non_blocking设置为True时,该函数会尝试与主机进行异步转换(如果可能)。这种异步行为适用于固定内存和分页内存。但使用此功能时需谨慎。有关更多信息,请参阅 关于 non_blocking 和 pin_memory 的最佳实践教程。当copy设置为True时,即使 Tensor 已满足期望的转换,也会创建一个新 Tensor。- 参数 (Args)
memory_format (
torch.memory_format, optional): 返回的 Tensor 的期望内存格式。默认值:torch.preserve_format。
- torch.to(other, non_blocking=False, copy=False) Tensor
返回一个与 Tensor
other具有相同torch.dtype和torch.device的 Tensor。当non_blocking设置为True时,该函数会尝试与主机进行异步转换(如果可能)。这种异步行为适用于固定内存和分页内存。但使用此功能时需谨慎。有关更多信息,请参阅 关于 non_blocking 和 pin_memory 的最佳实践教程。当copy设置为True时,即使 Tensor 已满足期望的转换,也会创建一个新 Tensor。
示例
>>> tensor = torch.randn(2, 2) # Initially dtype=float32, device=cpu >>> tensor.to(torch.float64) tensor([[-0.5044, 0.0005], [ 0.3310, -0.0584]], dtype=torch.float64) >>> cuda0 = torch.device('cuda:0') >>> tensor.to(cuda0) tensor([[-0.5044, 0.0005], [ 0.3310, -0.0584]], device='cuda:0') >>> tensor.to(cuda0, dtype=torch.float64) tensor([[-0.5044, 0.0005], [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') >>> other = torch.randn((), dtype=torch.float64, device=cuda0) >>> tensor.to(other, non_blocking=True) tensor([[-0.5044, 0.0005], [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0')