快捷方式

convert_to_float8_training

torchao.float8.convert_to_float8_training(module: Module, *, module_filter_fn: Optional[Callable[[Module, str], bool]] = None, config: Optional[Float8LinearConfig] = None) Module[源代码]

module 中的 torch.nn.Linear 替换为 Float8Linear

参数:
  • module – 要修改的模块。

  • module_filter_fn – 如果指定,只有通过过滤函数的 torch.nn.Linear 子类才会被替换。过滤函数的输入是模块实例和 FQN。

  • config (Float8LinearConfig) – 转换为 float8 的配置

返回:

已修改的模块,其中线性层已被替换。

返回类型:

nn.Module

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源