torch.cuda.comm.gather#
- torch.cuda.comm.gather(tensors, dim=0, destination=None, *, out=None)[source]#
从多个 GPU 设备收集张量。
- 参数
tensors (Iterable[Tensor]) – 要收集的张量可迭代对象。除
dim
以外的所有维度的张量大小必须匹配。dim (int, optional) – 张量将沿此维度连接。默认值:
0
。destination (torch.device, str, or int, optional) – 输出设备。可以是 CPU 或 CUDA。默认值:当前 CUDA 设备。
out (Tensor, optional, keyword-only) – 用于存储收集结果的张量。其大小必须与
tensors
的大小匹配,除了dim
维度,该维度的大小必须等于sum(tensor.size(dim) for tensor in tensors)
。可以位于 CPU 或 CUDA 上。
注意
destination
不能与out
同时指定。- 返回
- 如果指定了
destination
,则 一个位于
destination
设备的张量,它是将tensors
沿dim
连接的结果。
- 如果指定了
- 如果指定了
out
, 包含
tensors
沿dim
连接结果的out
张量。
- 如果指定了