Float8LinearConfig
- class torchao.float8.Float8LinearConfig(cast_config_input: ~torchao.float8.config.CastConfig = CastConfig(scaling_type=<ScalingType.DYNAMIC: 'dynamic'>, scaling_granularity=<ScalingGranularity.TENSORWISE: 'tensorwise'>, target_dtype=None), cast_config_input_for_grad_weight: ~typing.Optional[~torchao.float8.config.CastConfig] = None, cast_config_weight: ~torchao.float8.config.CastConfig = CastConfig(scaling_type=<ScalingType.DYNAMIC: 'dynamic'>, scaling_granularity=<ScalingGranularity.TENSORWISE: 'tensorwise'>, target_dtype=None), cast_config_weight_for_grad_input: ~typing.Optional[~torchao.float8.config.CastConfig] = None, cast_config_grad_output: ~torchao.float8.config.CastConfig = CastConfig(scaling_type=<ScalingType.DYNAMIC: 'dynamic'>, scaling_granularity=<ScalingGranularity.TENSORWISE: 'tensorwise'>, target_dtype=None), cast_config_grad_output_for_grad_weight: ~typing.Optional[~torchao.float8.config.CastConfig] = None, gemm_config_output: ~torchao.float8.config.Float8GemmConfig = Float8GemmConfig(use_fast_accum=True), gemm_config_grad_input: ~torchao.float8.config.Float8GemmConfig = Float8GemmConfig(use_fast_accum=False), gemm_config_grad_weight: ~torchao.float8.config.Float8GemmConfig = Float8GemmConfig(use_fast_accum=False), enable_fsdp_float8_all_gather: bool = False, pad_inner_dim: bool = False, emulate: bool = False, force_recompute_fp8_weight_in_bwd: bool = False, round_scales_to_power_of_2: bool = False)[source]
Configuration for converting torch.nn.Linear modules to float8, for training.
- static from_recipe_name(recipe_name: Union[Float8LinearRecipeName, str]) Float8LinearConfig [source]
Input: a Float8LinearRecipeName value, or a string representing a Float8LinearRecipeName value
Output: a Float8LinearConfig configured to implement the specified recipe