评价此页

大规模部署功能#

创建于: 2019年7月24日 | 最后更新于: 2025年7月15日

本文档讨论了在 PyTorch 运行于更大的系统内或在大型组织中使用 PyTorch 运行多个系统时可能有用的一些扩展点和技巧。

本文档假设您要么从源代码构建 PyTorch,要么能够静态链接额外的代码,以便在 PyTorch 使用时加载。因此,许多钩子都作为 C++ API 公开,可以在中心化位置(例如,静态初始化代码)一次性触发。

全集群算子分析#

PyTorch 提供了 torch.autograd.profiler,能够按需测量单个算子所花费的时间。可以使用相同的机制对运行 PyTorch 的任何进程进行“始终开启”的测量。这可能有助于收集给定进程或整个机器集上 PyTorch 工作负载的信息。

可以使用 torch::addGlobalCallback 为任何算子调用添加新的回调。钩子将使用 torch::RecordFunction 结构体进行调用,该结构体描述了调用上下文(例如,name)。如果启用了输入日志记录,RecordFunction::inputs() 将包含表示为 torch::IValue 变体类型的函数参数。请注意,输入日志记录成本相对较高,因此必须显式启用。

算子回调还可以访问 c10::ThreadLocalDebugInfo::get() 接口,该接口返回一个指向保存调试信息的结构体的指针。此调试信息可以通过使用 at::DebugInfoGuard 对象提前设置。调试信息会通过前向(包括异步 fork 任务)和后向传播,并可用于将有关执行环境的额外信息(例如,模型 ID)从应用程序的高层传递到底层算子回调。

调用回调会增加一些开销,因此通常最好仅随机抽样算子调用。这可以通过在 torch::addGlobalCallback 中传递可选的采样率来按回调进行启用。

请注意,addGlobalCallback 不是线程安全的,只能在没有 PyTorch 算子运行时调用。通常,在初始化期间调用它们一次是个好主意。

示例如下:

// Called somewhere in the program beginning
void init() {
    // Sample one in a hundred operator runs randomly
    addGlobalCallback(
      RecordFunctionCallback(
        &onFunctionEnter,
        &onFunctionExit)
      .needsInputs(true)
      .samplingProb(0.01)
    );
    // Note, to enable observers in the model calling thread,
    // call enableRecordFunction() in the thread before running a model
}

void onFunctionEnter(const RecordFunction& fn) {
    std::cerr << "Before function " << fn.name()
              << " with " << fn.inputs().size() << " inputs" << std::endl;
}

void onFunctionExit(const RecordFunction& fn) {
    std::cerr << "After function " << fn.name();
}

API 使用日志记录#

在更广泛的生态系统中运行时,例如在托管作业调度器中,跟踪哪些二进制文件调用了特定的 PyTorch API 通常很有用。在几个重要的 API 点注入了一个简单的插桩,用于触发给定的回调。由于 PyTorch 通常在一次性 Python 脚本中调用,因此对于每个 API,回调在给定进程中最多只触发一次。

可以使用 c10::SetAPIUsageHandler 注册 API 使用插桩处理程序。传递的参数将是一个“api key”,用于标识使用的点,例如,用于 PyTorch 扩展导入的 python.import

SetAPIUsageLogger([](const std::string& event_name) {
    std::cerr << "API was used: " << event_name << std::endl;
});

给开发者的说明:可以使用 C++ 中的 C10_LOG_API_USAGE_ONCE("my_api") 或 Python 中的 torch._C._log_api_usage_once("my.api") 在代码中添加新的 API 触发点。

通用扩展点#

PyTorch API 通常是松散耦合的,很容易用专用版本替换某个组件。通用扩展点包括:

  • 用 C++ 实现的自定义算子 - 有关更多详细信息,请参阅教程

  • 自定义数据读取通常可以通过调用相应的 Python 库直接集成。通过扩展 DatasetIterableDataset,可以利用 torch.utils.data 的现有功能。