评价此页

torch.cuda.comm.gather#

torch.cuda.comm.gather(tensors, dim=0, destination=None, *, out=None)[source]#

从多个 GPU 设备收集张量。

参数
  • tensors (Iterable[Tensor]) – 要收集的张量可迭代对象。除 dim 以外的所有维度的张量大小必须匹配。

  • dim (int, optional) – 张量将沿此维度连接。默认值:0

  • destination (torch.device, str, or int, optional) – 输出设备。可以是 CPU 或 CUDA。默认值:当前 CUDA 设备。

  • out (Tensor, optional, keyword-only) – 用于存储收集结果的张量。其大小必须与 tensors 的大小匹配,除了 dim 维度,该维度的大小必须等于 sum(tensor.size(dim) for tensor in tensors)。可以位于 CPU 或 CUDA 上。

注意

destination 不能与 out 同时指定。

返回

  • 如果指定了 destination,则

    一个位于 destination 设备的张量,它是将 tensors 沿 dim 连接的结果。

  • 如果指定了 out

    包含 tensors 沿 dim 连接结果的 out 张量。