convert_to_float8_training¶
- torchao.float8.convert_to_float8_training(module: Module, *, module_filter_fn: Optional[Callable[[Module, str], bool]] = None, config: Optional[Float8LinearConfig] = None) Module [源代码]¶
将 module 中的 torch.nn.Linear 替换为 Float8Linear。
- 参数:
module – 要修改的模块。
module_filter_fn – 如果指定,只有通过过滤函数的 torch.nn.Linear 子类才会被替换。过滤函数的输入是模块实例和 FQN。
config (Float8LinearConfig) – 转换为 float8 的配置
- 返回:
已修改的模块,其中线性层已被替换。
- 返回类型:
nn.Module