From f151dd6ededd6fefd6fd5e245a434e563048fd61 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 31 May 2026 23:20:31 -0700 Subject: [PATCH 01/11] Add FP16 NVFP4 4over6 error modes Signed-off-by: Ziang Li --- docs/envvars.rst | 4 +- .../cpp/operator/test_cast_nvfp4_transpose.cu | 14 +- .../nvfp4/test_nvfp4_quantize_exact.py | 24 ++- tests/pytorch/test_recipe.py | 6 +- .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 149 ++++++++++++++---- .../transformer_engine/transformer_engine.h | 8 +- transformer_engine/common/recipe/__init__.py | 6 +- .../common/transformer_engine.cpp | 4 +- transformer_engine/pytorch/csrc/quantizer.cpp | 4 + .../custom_recipes/quantization_ref_nvfp4.py | 137 ++++++++++++---- .../pytorch/tensor/nvfp4_tensor.py | 6 +- 11 files changed, 281 insertions(+), 81 deletions(-) diff --git a/docs/envvars.rst b/docs/envvars.rst index bd62ccac46..1c0acb61ff 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -301,9 +301,9 @@ Kernel Configuration .. envvar:: NVTE_NVFP4_4OVER6_ERR_MODE - :Type: ``str`` (``MAE`` or ``MSE``) + :Type: ``str`` (``MAE``, ``MSE``, ``MAE_FP16``, or ``MSE_FP16``) :Default: ``MAE`` - :Description: Select the input-domain error metric used by NVFP4 4over6 map-to-4 versus map-to-6 candidate selection in the ``NVFP4BlockScaling`` recipe. + :Description: Select the error metric used by NVFP4 4over6 map-to-4 versus map-to-6 candidate selection in the ``NVFP4BlockScaling`` recipe. ``MAE`` and ``MSE`` compare dequantized candidates in the original input domain. ``MAE_FP16`` and ``MSE_FP16`` compare candidates in the E4M3-scaled domain after the E2M1 x E4M3 product is rounded to FP16. .. envvar:: NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index d6ab4b6740..7a9df47823 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -1165,6 +1165,10 @@ std::string test_name(const FusedCastTransposeNVFP4TestSuite::ParamType& param) name += "XMSE"; } else if (config.mode == kNVTENVFP44Over6MinMAE) { name += "XMAE"; + } else if (config.mode == kNVTENVFP44Over6MinMSEFP16) { + name += "XMSE_FP16"; + } else if (config.mode == kNVTENVFP44Over6MinMAEFP16) { + name += "XMAE_FP16"; } else { name += "XINVALID_MODE"; } @@ -1219,10 +1223,18 @@ INSTANTIATE_TEST_SUITE_P( NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAE, 448, true}, NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSE, 448, false}, NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSE, 448, true}, + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAEFP16, 448, false}, + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAEFP16, 448, true}, + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSEFP16, 448, false}, + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSEFP16, 448, true}, NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAE, 256, false}, NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAE, 256, true}, NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSE, 256, false}, - NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSE, 256, true})), // four_over_six_config + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSE, 256, true}, + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAEFP16, 256, false}, + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAEFP16, 256, true}, + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSEFP16, 256, false}, + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSEFP16, 256, true})), // four_over_six_config [](const testing::TestParamInfo& info) { return test_name(info.param); }); diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 5bb92f70dc..5a6f0d104b 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -199,7 +199,11 @@ def check_quantization_nvfp4_versus_reference( @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) @pytest.mark.parametrize("nvfp4_e4m3_max", [448, 256], ids=["e4m3_448", "e4m3_256"]) -@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +@pytest.mark.parametrize( + "nvfp4_4over6_err_mode", + ["MAE", "MSE", "MAE_FP16", "MSE_FP16"], + ids=["mae_err", "mse_err", "mae_fp16_err", "mse_fp16_err"], +) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -243,7 +247,11 @@ def test_quantization_block_tiling_versus_reference( ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) -@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +@pytest.mark.parametrize( + "nvfp4_4over6_err_mode", + ["MAE", "MSE", "MAE_FP16", "MSE_FP16"], + ids=["mae_err", "mse_err", "mae_fp16_err", "mse_fp16_err"], +) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -360,7 +368,11 @@ def test_nvfp4_quantization_extrema_versus_reference( ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) -@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +@pytest.mark.parametrize( + "nvfp4_4over6_err_mode", + ["MAE", "MSE", "MAE_FP16", "MSE_FP16"], + ids=["mae_err", "mse_err", "mae_fp16_err", "mse_fp16_err"], +) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, @@ -490,7 +502,11 @@ def test_nvfp4_quantization_boundary_values( ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) -@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +@pytest.mark.parametrize( + "nvfp4_4over6_err_mode", + ["MAE", "MSE", "MAE_FP16", "MSE_FP16"], + ids=["mae_err", "mse_err", "mae_fp16_err", "mse_fp16_err"], +) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 9a14cee7fd..ecc5294f36 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -525,7 +525,11 @@ def test_quantizer_update(self, module_class): ["none", "weights", "activations", "all"], ids=["e4m3_448", "e4m3_256_weights", "e4m3_256_activations", "e4m3_256_all"], ) -@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +@pytest.mark.parametrize( + "nvfp4_4over6_err_mode", + ["MAE", "MSE", "MAE_FP16", "MSE_FP16"], + ids=["mae_err", "mse_err", "mae_fp16_err", "mse_fp16_err"], +) def test_nvfp4_row_scaled_quantizer_roles( nvfp4_4over6, nvfp4_4over6_e4m3_use_256, nvfp4_4over6_err_mode ): diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index b6057370dc..e00f43b696 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -39,19 +39,27 @@ namespace nvfp4 { #if FP4_TYPE_SUPPORTED -#define TRANSFORMER_ENGINE_NVFP4_4OVER6_MODE_SWITCH(MODE, MODE_CONST, ...) \ - switch (MODE) { \ - case kNVTENVFP44Over6MinMAE: { \ - constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMAE; \ - { __VA_ARGS__ } \ - } break; \ - case kNVTENVFP44Over6MinMSE: { \ - constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMSE; \ - { __VA_ARGS__ } \ - } break; \ - default: { \ - NVTE_ERROR("Unsupported NVFP4 4over6 mode."); \ - } \ +#define TRANSFORMER_ENGINE_NVFP4_4OVER6_MODE_SWITCH(MODE, MODE_CONST, ...) \ + switch (MODE) { \ + case kNVTENVFP44Over6MinMAE: { \ + constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMAE; \ + { __VA_ARGS__ } \ + } break; \ + case kNVTENVFP44Over6MinMSE: { \ + constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMSE; \ + { __VA_ARGS__ } \ + } break; \ + case kNVTENVFP44Over6MinMAEFP16: { \ + constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMAEFP16; \ + { __VA_ARGS__ } \ + } break; \ + case kNVTENVFP44Over6MinMSEFP16: { \ + constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMSEFP16; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported NVFP4 4over6 mode."); \ + } \ } #define TRANSFORMER_ENGINE_NVFP4_4OVER6_E4M3_MAX_SWITCH(E4M3_MAX_VALUE, E4M3_MAX_CONST, ...) \ @@ -85,6 +93,8 @@ template struct Config { static constexpr NVTENVFP44Over6Mode mode = kMode; static constexpr bool err_use_fast_math = kErrUseFastMath; + static constexpr bool fp16_error = + kMode == kNVTENVFP44Over6MinMAEFP16 || kMode == kNVTENVFP44Over6MinMSEFP16; }; struct Candidate { @@ -102,13 +112,14 @@ struct ScalePair { nvfp4_scale_t map6; float inv_map4; float inv_map6; + float global_encode_scale; }; template __device__ __forceinline__ float compute_error_rn(const float diff) { - if constexpr (kMode == kNVTENVFP44Over6MinMSE) { + if constexpr (kMode == kNVTENVFP44Over6MinMSE || kMode == kNVTENVFP44Over6MinMSEFP16) { return __fmul_rn(diff, diff); - } else if constexpr (kMode == kNVTENVFP44Over6MinMAE) { + } else if constexpr (kMode == kNVTENVFP44Over6MinMAE || kMode == kNVTENVFP44Over6MinMAEFP16) { return fabsf(diff); } else { NVTE_DEVICE_ERROR("Unsupported NVFP4 4over6 mode."); @@ -118,9 +129,9 @@ __device__ __forceinline__ float compute_error_rn(const float diff) { template __device__ __forceinline__ float compute_error(const float diff) { - if constexpr (kMode == kNVTENVFP44Over6MinMSE) { + if constexpr (kMode == kNVTENVFP44Over6MinMSE || kMode == kNVTENVFP44Over6MinMSEFP16) { return diff * diff; - } else if constexpr (kMode == kNVTENVFP44Over6MinMAE) { + } else if constexpr (kMode == kNVTENVFP44Over6MinMAE || kMode == kNVTENVFP44Over6MinMAEFP16) { return fabsf(diff); } else { NVTE_DEVICE_ERROR("Unsupported NVFP4 4over6 mode."); @@ -147,6 +158,7 @@ __device__ __forceinline__ ScalePair compute_scale_pair(const float block_amax, fminf(1.0f / (static_cast(scales.map4) * S_dec), detail::TypeExtrema::max); scales.inv_map6 = fminf(1.0f / (static_cast(scales.map6) * S_dec), detail::TypeExtrema::max); + scales.global_encode_scale = S_enc; return scales; } @@ -214,12 +226,67 @@ __device__ __forceinline__ void accumulate_dequant_error(const uint32_t dequant_ } } -template -__device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error(const float (&x)[8], - const float block_scale_inverse, +__device__ __forceinline__ uint8_t fp8_bits(const nvfp4_scale_t sf) { + return *reinterpret_cast(&sf); +} + +__device__ __forceinline__ float2 e2m1x2_scaled_e4m3_to_float2(const uint32_t e2m1_byte, + const nvfp4_scale_t sf) { + float2 result; + const uint32_t sf_byte = static_cast(fp8_bits(sf)); + asm volatile( + "{\n" + ".reg .b8 byte0, byte1, byte2, byte3;\n" + ".reg .b16 fp8_pair;\n" + ".reg .b16 scale_h, unused_h;\n" + ".reg .b16 lo, hi;\n" + ".reg .b32 q_h2;\n" + ".reg .b32 scale_h2;\n" + ".reg .b32 prod_h2;\n" + "mov.b32 {byte0, byte1, byte2, byte3}, %2;\n" + "cvt.rn.f16x2.e2m1x2 q_h2, byte0;\n" + "cvt.u16.u32 fp8_pair, %3;\n" + "cvt.rn.f16x2.e4m3x2 scale_h2, fp8_pair;\n" + "mov.b32 {scale_h, unused_h}, scale_h2;\n" + "mov.b32 scale_h2, {scale_h, scale_h};\n" + "mul.rn.f16x2 prod_h2, q_h2, scale_h2;\n" + "mov.b32 {lo, hi}, prod_h2;\n" + "cvt.f32.f16 %0, lo;\n" + "cvt.f32.f16 %1, hi;\n" + "}" + : "=f"(result.x), "=f"(result.y) + : "r"(e2m1_byte), "r"(sf_byte)); + return result; +} + +template +__device__ __forceinline__ void accumulate_fp16_scaled_error_pair(const uint32_t e2m1_byte, + const float x0, const float x1, const nvfp4_scale_t sf, - const float global_amax, + const float global_encode_scale, float *err) { + const float2 candidate = e2m1x2_scaled_e4m3_to_float2(e2m1_byte, sf); + if constexpr (Cfg::err_use_fast_math) { + const float original0 = x0 * global_encode_scale; + const float original1 = x1 * global_encode_scale; + const float diff0 = candidate.x - original0; + const float diff1 = candidate.y - original1; + *err += compute_error(diff0); + *err += compute_error(diff1); + } else { + const float original0 = __fmul_rn(x0, global_encode_scale); + const float original1 = __fmul_rn(x1, global_encode_scale); + const float diff0 = __fsub_rn(candidate.x, original0); + const float diff1 = __fsub_rn(candidate.y, original1); + *err = __fadd_rn(*err, compute_error_rn(diff0)); + *err = __fadd_rn(*err, compute_error_rn(diff1)); + } +} + +template +__device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error( + const float (&x)[8], const float block_scale_inverse, const nvfp4_scale_t sf, + const float global_amax, const float global_encode_scale, float *err) { uint32_t out = 0; uint32_t out_dequant_1 = 0; uint32_t out_dequant_2 = 0; @@ -253,15 +320,25 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error(const float (& "Try recompiling with sm_XXXa instead of sm_XXX."); } - const float sf_float = static_cast(sf); - accumulate_dequant_error(out_dequant_1, x[0], sf_float, global_amax, err); - accumulate_dequant_error(out_dequant_1, x[1], sf_float, global_amax, err); - accumulate_dequant_error(out_dequant_2, x[2], sf_float, global_amax, err); - accumulate_dequant_error(out_dequant_2, x[3], sf_float, global_amax, err); - accumulate_dequant_error(out_dequant_3, x[4], sf_float, global_amax, err); - accumulate_dequant_error(out_dequant_3, x[5], sf_float, global_amax, err); - accumulate_dequant_error(out_dequant_4, x[6], sf_float, global_amax, err); - accumulate_dequant_error(out_dequant_4, x[7], sf_float, global_amax, err); + if constexpr (Cfg::fp16_error) { + accumulate_fp16_scaled_error_pair(out & 0xFFu, x[0], x[1], sf, global_encode_scale, err); + accumulate_fp16_scaled_error_pair((out >> 8) & 0xFFu, x[2], x[3], sf, global_encode_scale, + err); + accumulate_fp16_scaled_error_pair((out >> 16) & 0xFFu, x[4], x[5], sf, global_encode_scale, + err); + accumulate_fp16_scaled_error_pair((out >> 24) & 0xFFu, x[6], x[7], sf, global_encode_scale, + err); + } else { + const float sf_float = static_cast(sf); + accumulate_dequant_error(out_dequant_1, x[0], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_1, x[1], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_2, x[2], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_2, x[3], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_3, x[4], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_3, x[5], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_4, x[6], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_4, x[7], sf_float, global_amax, err); + } return out; } @@ -273,13 +350,17 @@ __device__ __forceinline__ CandidatePair make_candidates(const float (&x0)[8], c candidates.map4.err = 0.0f; candidates.map6.err = 0.0f; candidates.map4.packed[0] = cvt_fp32_to_fp4_8x_with_error( - x0, scales.inv_map4, scales.map4, global_amax, &candidates.map4.err); + x0, scales.inv_map4, scales.map4, global_amax, scales.global_encode_scale, + &candidates.map4.err); candidates.map6.packed[0] = cvt_fp32_to_fp4_8x_with_error( - x0, scales.inv_map6, scales.map6, global_amax, &candidates.map6.err); + x0, scales.inv_map6, scales.map6, global_amax, scales.global_encode_scale, + &candidates.map6.err); candidates.map4.packed[1] = cvt_fp32_to_fp4_8x_with_error( - x1, scales.inv_map4, scales.map4, global_amax, &candidates.map4.err); + x1, scales.inv_map4, scales.map4, global_amax, scales.global_encode_scale, + &candidates.map4.err); candidates.map6.packed[1] = cvt_fp32_to_fp4_8x_with_error( - x1, scales.inv_map6, scales.map6, global_amax, &candidates.map6.err); + x1, scales.inv_map6, scales.map6, global_amax, scales.global_encode_scale, + &candidates.map6.err); return candidates; } diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index f675b2f535..491a1978be 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -122,9 +122,11 @@ enum NVTEScalingMode { * \brief Method for NVFP4 4over6 quantization. */ enum NVTENVFP44Over6Mode { - kNVTENVFP44Over6Disabled = 0, /*!< 4over6 is not applied */ - kNVTENVFP44Over6MinMAE = 1, /*!< Select the candidate with lower mean absolute error */ - kNVTENVFP44Over6MinMSE = 2, /*!< Select the candidate with lower mean squared error */ + kNVTENVFP44Over6Disabled = 0, /*!< 4over6 is not applied */ + kNVTENVFP44Over6MinMAE = 1, /*!< Select the candidate with lower mean absolute error */ + kNVTENVFP44Over6MinMSE = 2, /*!< Select the candidate with lower mean squared error */ + kNVTENVFP44Over6MinMAEFP16 = 3, /*!< Select with lower absolute error in FP16 domain */ + kNVTENVFP44Over6MinMSEFP16 = 4, /*!< Select with lower squared error in FP16 domain */ }; /*! \brief TE Tensor type diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 8a03f2f51a..ab206320b0 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -14,7 +14,7 @@ _BACKWARD_OVERRIDES = (None, "high_precision", "dequantized") _NVFP4_4OVER6_SCOPES = ("none", "weights", "activations", "all") -_NVFP4_4OVER6_ERR_MODES = ("MAE", "MSE") +_NVFP4_4OVER6_ERR_MODES = ("MAE", "MSE", "MAE_FP16", "MSE_FP16") class _FormatHelper(NamedTuple): @@ -535,7 +535,7 @@ class NVFP4BlockScaling(Recipe): Select 4over6 tensors that use 256 as the global E4M3 scale bound. By default, all 4over6 tensors use 256. Use ``'none'`` to keep the standard NVFP4 448 bound for 4over6 tensors. - nvfp4_4over6_err_mode : {'MAE', 'MSE'}, default = 'MAE' + nvfp4_4over6_err_mode : {'MAE', 'MSE', 'MAE_FP16', 'MSE_FP16'}, default = 'MAE' Error metric used by NVFP4 4over6 candidate selection. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. None does not modify backward behavior, @@ -577,7 +577,7 @@ def __post_init__(self) -> None: ), "NVTE_NVFP4_4OVER6_E4M3_USE_256 must be one of: 'none', 'weights', 'activations', 'all'." assert ( self.nvfp4_4over6_err_mode in _NVFP4_4OVER6_ERR_MODES - ), "NVTE_NVFP4_4OVER6_ERR_MODE must be one of: 'MAE', 'MSE'." + ), "NVTE_NVFP4_4OVER6_ERR_MODE must be one of: 'MAE', 'MSE', 'MAE_FP16', 'MSE_FP16'." # Quantization params # Note: RHT is currently only applied to column-wise usage so that diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index b3179d38fd..0035772bf2 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1109,7 +1109,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, const auto val = *reinterpret_cast(buf); NVTE_CHECK(val == static_cast(kNVTENVFP44Over6Disabled) || val == static_cast(kNVTENVFP44Over6MinMAE) || - val == static_cast(kNVTENVFP44Over6MinMSE), + val == static_cast(kNVTENVFP44Over6MinMSE) || + val == static_cast(kNVTENVFP44Over6MinMAEFP16) || + val == static_cast(kNVTENVFP44Over6MinMSEFP16), "Invalid NVFP4 4over6 mode (got ", static_cast(val), ")"); config_.nvfp4_4over6_mode = static_cast(val); break; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index bc87b54ba8..d81fc345f8 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1740,6 +1740,10 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->nvfp4_4over6_mode = kNVTENVFP44Over6MinMAE; } else if (nvfp4_4over6_err_mode == "MSE") { this->nvfp4_4over6_mode = kNVTENVFP44Over6MinMSE; + } else if (nvfp4_4over6_err_mode == "MAE_FP16") { + this->nvfp4_4over6_mode = kNVTENVFP44Over6MinMAEFP16; + } else if (nvfp4_4over6_err_mode == "MSE_FP16") { + this->nvfp4_4over6_mode = kNVTENVFP44Over6MinMSEFP16; } else { NVTE_ERROR("Unsupported NVFP4 4over6 error mode: ", nvfp4_4over6_err_mode); } diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index 5c23c76703..d724b41183 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -359,8 +359,10 @@ def __init__( with_random_sign_mask: bool = True, ): nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() - if nvfp4_4over6_err_mode not in ("MAE", "MSE"): - raise ValueError("nvfp4_4over6_err_mode must be 'MAE' or 'MSE'.") + if nvfp4_4over6_err_mode not in ("MAE", "MSE", "MAE_FP16", "MSE_FP16"): + raise ValueError( + "nvfp4_4over6_err_mode must be one of: 'MAE', 'MSE', 'MAE_FP16', 'MSE_FP16'." + ) if row_scaled_nvfp4: if not rowwise: raise ValueError("Row-scaled NVFP4 reference quantization requires rowwise usage.") @@ -464,6 +466,63 @@ def _recover_swizzled_scales( result = torch.reshape(tmp, (rounded_m, rounded_n)) return result[:m, :scale_n] + @staticmethod + def _ref_nvfp4_4over6_fp16_candidate(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Decode E2M1 x E4M3 with the kernel's FP16 product semantics.""" + q_float = q.to(torch.float32) + q_sign = (q_float < 0).to(torch.int32) + q_sig = (torch.abs(q_float) * 2).to(torch.int32) + + scale_code = scale.contiguous().view(torch.uint8).to(torch.int32) + scale_sign = scale_code >> 7 + scale_exp_field = (scale_code >> 3) & 0xF + scale_mantissa = scale_code & 0x7 + scale_sig = torch.where(scale_exp_field == 0, scale_mantissa, scale_mantissa + 8) + scale_exp2 = torch.where(scale_exp_field == 0, scale_exp_field - 9, scale_exp_field - 10) + + product_sign = q_sign ^ scale_sign + product_sig = q_sig * scale_sig + product_exp2 = scale_exp2 - 1 + + log2_sig = torch.zeros_like(product_sig) + for threshold in (2, 4, 8, 16, 32, 64, 128, 256): + log2_sig = log2_sig + (product_sig >= threshold).to(torch.int32) + + floor_exp = log2_sig + product_exp2 + normal_bits = ((floor_exp + 15) << 10) | ( + torch.bitwise_left_shift(product_sig, 10 - log2_sig) - 1024 + ) + subnormal_bits = torch.bitwise_left_shift(product_sig, product_exp2 + 24) + magnitude_bits = torch.where(floor_exp < -14, subnormal_bits, normal_bits) + prod_bits = (product_sign << 15) | magnitude_bits + prod_bits = torch.where(product_sig == 0, product_sign << 15, prod_bits) + prod_bits = torch.where( + (scale_code & 0x7F) == 0x7F, + torch.full_like(prod_bits, 0x7E00), + prod_bits, + ) + + sign_f32 = torch.where( + (prod_bits & 0x8000) != 0, + torch.tensor(-1.0, device=prod_bits.device, dtype=torch.float32), + torch.tensor(1.0, device=prod_bits.device, dtype=torch.float32), + ) + fp16_exp = (prod_bits >> 10) & 0x1F + fp16_frac = prod_bits & 0x3FF + normal_f32 = torch.ldexp((fp16_frac + 1024).to(torch.float32), fp16_exp - 25) + subnormal_f32 = torch.ldexp(fp16_frac.to(torch.float32), fp16_exp - 24) + return sign_f32 * torch.where(fp16_exp == 0, subnormal_f32, normal_f32) + + @staticmethod + def _sum_4over6_2d_error(err: torch.Tensor, tile_len_y: int) -> torch.Tensor: + """Reduce 16 row errors in the same tree order as the CUDA warp reduction.""" + rows = err.view(err.shape[0] // tile_len_y, tile_len_y, err.shape[1], 1) + rows = rows.squeeze(-1) + rows = rows[:, 0:8, :] + rows[:, 8:16, :] + rows = rows[:, 0:4, :] + rows[:, 4:8, :] + rows = rows[:, 0:2, :] + rows[:, 2:4, :] + return (rows[:, 0, :] + rows[:, 1, :]).unsqueeze(-1) + @staticmethod def _quantize_blockwise_4over6_reference( x: torch.Tensor, @@ -527,41 +586,59 @@ def _quantize_blockwise_4over6_reference( qx_map4 = cast_to_fp4x2(clipped_x_map4) qx_map6 = cast_to_fp4x2(clipped_x_map6) + err_map4 = torch.zeros_like(vec_max) + err_map6 = torch.zeros_like(vec_max) fp4_map4 = cast_from_fp4x2(qx_map4, torch.float32).view(m, num_blocks, tile_len_x) fp4_map6 = cast_from_fp4x2(qx_map6, torch.float32).view(m, num_blocks, tile_len_x) - denom = FLOAT4_E2M1_MAX * GLOBAL_SCALE_E4M3_MAX - sf_map4 = decode_scale_map4.to(torch.float32).squeeze(-1) - sf_map6 = decode_scale_map6.to(torch.float32).squeeze(-1) - if row_scaled_nvfp4: - error_global_amax = global_amax.squeeze(-1) - else: - error_global_amax = global_amax x_float = x.to(torch.float32) - err_map4 = torch.zeros_like(vec_max) - err_map6 = torch.zeros_like(vec_max) - for idx in range(tile_len_x): - val_map4 = fp4_map4[:, :, idx] * sf_map4 - val_map4 = val_map4 * error_global_amax - val_map4 = val_map4 / denom - diff_map4 = val_map4 - x_float[:, :, idx] - if nvfp4_4over6_err_mode == "MSE": - err_map4 = err_map4 + (diff_map4 * diff_map4).unsqueeze(-1) - else: - err_map4 = err_map4 + torch.abs(diff_map4).unsqueeze(-1) - - val_map6 = fp4_map6[:, :, idx] * sf_map6 - val_map6 = val_map6 * error_global_amax - val_map6 = val_map6 / denom - diff_map6 = val_map6 - x_float[:, :, idx] - if nvfp4_4over6_err_mode == "MSE": - err_map6 = err_map6 + (diff_map6 * diff_map6).unsqueeze(-1) + if nvfp4_4over6_err_mode in ("MAE_FP16", "MSE_FP16"): + original_scaled = x_float * global_encode_scale + candidate_map4 = NVFP4QuantizerRef._ref_nvfp4_4over6_fp16_candidate( + fp4_map4, decode_scale_map4 + ) + candidate_map6 = NVFP4QuantizerRef._ref_nvfp4_4over6_fp16_candidate( + fp4_map6, decode_scale_map6 + ) + for idx in range(tile_len_x): + diff_map4 = candidate_map4[:, :, idx] - original_scaled[:, :, idx] + diff_map6 = candidate_map6[:, :, idx] - original_scaled[:, :, idx] + if nvfp4_4over6_err_mode == "MSE_FP16": + err_map4 = err_map4 + (diff_map4 * diff_map4).unsqueeze(-1) + err_map6 = err_map6 + (diff_map6 * diff_map6).unsqueeze(-1) + else: + err_map4 = err_map4 + torch.abs(diff_map4).unsqueeze(-1) + err_map6 = err_map6 + torch.abs(diff_map6).unsqueeze(-1) + else: + denom = FLOAT4_E2M1_MAX * GLOBAL_SCALE_E4M3_MAX + sf_map4 = decode_scale_map4.to(torch.float32).squeeze(-1) + sf_map6 = decode_scale_map6.to(torch.float32).squeeze(-1) + if row_scaled_nvfp4: + error_global_amax = global_amax.squeeze(-1) else: - err_map6 = err_map6 + torch.abs(diff_map6).unsqueeze(-1) + error_global_amax = global_amax + for idx in range(tile_len_x): + val_map4 = fp4_map4[:, :, idx] * sf_map4 + val_map4 = val_map4 * error_global_amax + val_map4 = val_map4 / denom + diff_map4 = val_map4 - x_float[:, :, idx] + if nvfp4_4over6_err_mode == "MSE": + err_map4 = err_map4 + (diff_map4 * diff_map4).unsqueeze(-1) + else: + err_map4 = err_map4 + torch.abs(diff_map4).unsqueeze(-1) + + val_map6 = fp4_map6[:, :, idx] * sf_map6 + val_map6 = val_map6 * error_global_amax + val_map6 = val_map6 / denom + diff_map6 = val_map6 - x_float[:, :, idx] + if nvfp4_4over6_err_mode == "MSE": + err_map6 = err_map6 + (diff_map6 * diff_map6).unsqueeze(-1) + else: + err_map6 = err_map6 + torch.abs(diff_map6).unsqueeze(-1) if tile_len_y == 1: pick_map4 = err_map4 < err_map6 else: - err_map4_blocks = err_map4.view(m // tile_len_y, tile_len_y, num_blocks, 1).sum(dim=1) - err_map6_blocks = err_map6.view(m // tile_len_y, tile_len_y, num_blocks, 1).sum(dim=1) + err_map4_blocks = NVFP4QuantizerRef._sum_4over6_2d_error(err_map4, tile_len_y) + err_map6_blocks = NVFP4QuantizerRef._sum_4over6_2d_error(err_map6, tile_len_y) pick_map4 = (err_map4_blocks < err_map6_blocks).repeat_interleave(tile_len_y, dim=0) qx = torch.where( pick_map4.expand(-1, -1, tile_len_x // 2), diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 24962d67f2..aa249706d0 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -172,8 +172,10 @@ def __init__( if self.nvfp4_e4m3_max not in (448, 256): raise ValueError("nvfp4_e4m3_max must be 448 or 256.") self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() - if self.nvfp4_4over6_err_mode not in ("MAE", "MSE"): - raise ValueError("nvfp4_4over6_err_mode must be 'MAE' or 'MSE'.") + if self.nvfp4_4over6_err_mode not in ("MAE", "MSE", "MAE_FP16", "MSE_FP16"): + raise ValueError( + "nvfp4_4over6_err_mode must be one of: 'MAE', 'MSE', 'MAE_FP16', 'MSE_FP16'." + ) self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) From 71ba85e293f95bccb0af3ece607d79c10509153b Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 31 May 2026 23:37:11 -0700 Subject: [PATCH 02/11] Clean up NVFP4 4over6 FP16 mode changes Signed-off-by: Ziang Li --- tests/cpp/operator/test_cast_nvfp4_transpose.cu | 14 +------------- .../common/cast/nvfp4/quantize_4over6_nvfp4.cuh | 5 ++--- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 7a9df47823..d6ab4b6740 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -1165,10 +1165,6 @@ std::string test_name(const FusedCastTransposeNVFP4TestSuite::ParamType& param) name += "XMSE"; } else if (config.mode == kNVTENVFP44Over6MinMAE) { name += "XMAE"; - } else if (config.mode == kNVTENVFP44Over6MinMSEFP16) { - name += "XMSE_FP16"; - } else if (config.mode == kNVTENVFP44Over6MinMAEFP16) { - name += "XMAE_FP16"; } else { name += "XINVALID_MODE"; } @@ -1223,18 +1219,10 @@ INSTANTIATE_TEST_SUITE_P( NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAE, 448, true}, NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSE, 448, false}, NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSE, 448, true}, - NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAEFP16, 448, false}, - NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAEFP16, 448, true}, - NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSEFP16, 448, false}, - NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSEFP16, 448, true}, NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAE, 256, false}, NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAE, 256, true}, NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSE, 256, false}, - NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSE, 256, true}, - NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAEFP16, 256, false}, - NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAEFP16, 256, true}, - NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSEFP16, 256, false}, - NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSEFP16, 256, true})), // four_over_six_config + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSE, 256, true})), // four_over_six_config [](const testing::TestParamInfo& info) { return test_name(info.param); }); diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index e00f43b696..3f402e0163 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -93,8 +93,6 @@ template struct Config { static constexpr NVTENVFP44Over6Mode mode = kMode; static constexpr bool err_use_fast_math = kErrUseFastMath; - static constexpr bool fp16_error = - kMode == kNVTENVFP44Over6MinMAEFP16 || kMode == kNVTENVFP44Over6MinMSEFP16; }; struct Candidate { @@ -320,7 +318,8 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error( "Try recompiling with sm_XXXa instead of sm_XXX."); } - if constexpr (Cfg::fp16_error) { + if constexpr (Cfg::mode == kNVTENVFP44Over6MinMAEFP16 || + Cfg::mode == kNVTENVFP44Over6MinMSEFP16) { accumulate_fp16_scaled_error_pair(out & 0xFFu, x[0], x[1], sf, global_encode_scale, err); accumulate_fp16_scaled_error_pair((out >> 8) & 0xFFu, x[2], x[3], sf, global_encode_scale, err); From 5136b58671141a130990d3f2a7314a2cd6a6b3bf Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 1 Jun 2026 00:49:21 -0700 Subject: [PATCH 03/11] Clarify NVFP4 4over6 reference error modes Signed-off-by: Ziang Li --- .../pytorch/custom_recipes/quantization_ref_nvfp4.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index d724b41183..5eb12431fb 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -516,6 +516,7 @@ def _ref_nvfp4_4over6_fp16_candidate(q: torch.Tensor, scale: torch.Tensor) -> to @staticmethod def _sum_4over6_2d_error(err: torch.Tensor, tile_len_y: int) -> torch.Tensor: """Reduce 16 row errors in the same tree order as the CUDA warp reduction.""" + assert tile_len_y == 16, "NVFP4 4over6 2D error reduction expects 16 rows." rows = err.view(err.shape[0] // tile_len_y, tile_len_y, err.shape[1], 1) rows = rows.squeeze(-1) rows = rows[:, 0:8, :] + rows[:, 8:16, :] @@ -538,8 +539,9 @@ def _quantize_blockwise_4over6_reference( """Quantize NVFP4 with 4over6 candidate selection. This mirrors the CUDA path: map-to-4 uses a 1.5x expanded E4M3 block scale, - the configured error is computed in the original input domain with the - selected global E4M3 denominator, and ties choose map-to-6. + MAE/MSE compute error in the original input domain, MAE_FP16/MSE_FP16 + compute error in the E4M3-scaled FP16 product domain, and ties choose + map-to-6. """ m, num_blocks, tile_len_x = x.shape n = num_blocks * tile_len_x @@ -730,7 +732,7 @@ def _quantize_blockwise_reference( global_decode_scale = torch.div(1.0, global_encode_scale) if nvfp4_use_4over6: # FourOverSix compares map-to-4 and map-to-6 candidates using - # the configured original input-domain error, while keeping TE-style FP4 + # the configured error mode, while keeping TE-style FP4 # quantization for each candidate. return cls._quantize_blockwise_4over6_reference( x, From 0a6e0d6a3eda92fc9506401fd4f9e9b512bc709c Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 1 Jun 2026 22:08:17 -0700 Subject: [PATCH 04/11] Refactor NVFP4 4over6 error mode interface Signed-off-by: Ziang Li --- docs/envvars.rst | 6 +- .../cpp/operator/test_cast_nvfp4_transpose.cu | 2 +- .../nvfp4/test_nvfp4_quantize_exact.py | 160 ++++++++++++------ tests/pytorch/test_recipe.py | 4 +- .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 90 +++------- .../transformer_engine/transformer_engine.h | 8 +- transformer_engine/common/recipe/__init__.py | 6 +- .../common/transformer_engine.cpp | 4 +- .../pytorch/csrc/extensions/cast.cpp | 8 +- transformer_engine/pytorch/csrc/quantizer.cpp | 8 +- .../custom_recipes/quantization_ref_nvfp4.py | 23 ++- .../pytorch/tensor/nvfp4_tensor.py | 6 +- 12 files changed, 172 insertions(+), 153 deletions(-) diff --git a/docs/envvars.rst b/docs/envvars.rst index 1c0acb61ff..044a7f6a0d 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -301,15 +301,15 @@ Kernel Configuration .. envvar:: NVTE_NVFP4_4OVER6_ERR_MODE - :Type: ``str`` (``MAE``, ``MSE``, ``MAE_FP16``, or ``MSE_FP16``) + :Type: ``str`` (``MAE`` or ``MSE``) :Default: ``MAE`` - :Description: Select the error metric used by NVFP4 4over6 map-to-4 versus map-to-6 candidate selection in the ``NVFP4BlockScaling`` recipe. ``MAE`` and ``MSE`` compare dequantized candidates in the original input domain. ``MAE_FP16`` and ``MSE_FP16`` compare candidates in the E4M3-scaled domain after the E2M1 x E4M3 product is rounded to FP16. + :Description: Select the error metric used by NVFP4 4over6 map-to-4 versus map-to-6 candidate selection in the ``NVFP4BlockScaling`` recipe. .. envvar:: NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH :Type: ``int`` (0 or 1) :Default: ``0`` - :Description: Allow the NVFP4 4over6 candidate error computation to use faster non-strict floating-point expressions. By default, 4over6 error comparison uses strict expressions; ``NVTE_USE_FAST_MATH`` does not control this error-comparison path. + :Description: Use the faster NVFP4 4over6 candidate error path that compares candidates in the E4M3-scaled domain after the E2M1 x E4M3 product is rounded to FP16. Error differences and accumulation remain FP32. By default, 4over6 error comparison uses the original input-domain path; ``NVTE_USE_FAST_MATH`` does not control this error-comparison path. Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index d6ab4b6740..18b96fa6df 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -785,7 +785,7 @@ void performTest(float (*OP)(const float), if (use_4over6 && use_fast_math) { std::cout << "WARNING: Plain NVFP4 fast math is ignored for 4over6. " "Use use_4over6_err_use_fast_math to test the 4over6 candidate " - "error fast-math path." + "FP16 product-domain error path." << std::endl; } diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 5a6f0d104b..cada5a43c8 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -2,6 +2,9 @@ # # See LICENSE for license information. +import os +from contextlib import contextmanager + import pytest import torch import transformer_engine.pytorch as te @@ -16,6 +19,19 @@ recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) +@contextmanager +def nvfp4_4over6_err_fast_math(enabled: bool): + old_value = os.environ.get("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH") + os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = "1" if enabled else "0" + try: + yield + finally: + if old_value is None: + os.environ.pop("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH", None) + else: + os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = old_value + + def maybe_skip_row_scaled_unsupported_quantization( row_scaled_nvfp4: bool, return_transpose: bool, @@ -55,6 +71,7 @@ def check_quantization_nvfp4_versus_reference( use_4over6: bool = False, nvfp4_e4m3_max: int = 448, nvfp4_4over6_err_mode: str = "MAE", + nvfp4_4over6_err_use_fast_math: bool = False, ) -> None: if nvfp4_e4m3_max != 448 and not use_4over6: pytest.skip("E4M3 max 256 is only meaningful for 4over6") @@ -73,27 +90,28 @@ def check_quantization_nvfp4_versus_reference( x = torch.randn((M, N), dtype=x_dtype, device=device) # Quantize - nvfp4_quantizer = NVFP4Quantizer( - fp4_dtype=te_dtype, - rowwise=True, - columnwise=return_transpose, - with_amax_reduction=False, - amax_reduction_group=None, - with_rht=False, - with_post_rht_amax=False, - with_2d_quantization=with_2d_quantization, - row_scaled_nvfp4=row_scaled_nvfp4, - nvfp4_use_4over6=use_4over6, - nvfp4_e4m3_max=nvfp4_e4m3_max, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, - ) - if use_cpp_allocator: - x_nvfp4_sut = nvfp4_quantizer(x) - else: - x_nvfp4_sut = nvfp4_quantizer.make_empty( - (M, N), dtype=x_dtype, device=device, requires_grad=False + with nvfp4_4over6_err_fast_math(nvfp4_4over6_err_use_fast_math): + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=with_2d_quantization, + row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) - x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) # Extract data from NVFP4Tensor assert x_nvfp4_sut._rowwise_data is not None @@ -122,6 +140,7 @@ def check_quantization_nvfp4_versus_reference( nvfp4_use_4over6=use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + nvfp4_4over6_err_use_fast_math=nvfp4_4over6_err_use_fast_math, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -201,8 +220,8 @@ def check_quantization_nvfp4_versus_reference( @pytest.mark.parametrize("nvfp4_e4m3_max", [448, 256], ids=["e4m3_448", "e4m3_256"]) @pytest.mark.parametrize( "nvfp4_4over6_err_mode", - ["MAE", "MSE", "MAE_FP16", "MSE_FP16"], - ids=["mae_err", "mse_err", "mae_fp16_err", "mse_fp16_err"], + ["MAE", "MSE"], + ids=["mae_err", "mse_err"], ) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, @@ -232,6 +251,46 @@ def test_quantization_block_tiling_versus_reference( ) +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("M, N", [(128, 128), (256, 256)]) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("return_transpose", [True, False], ids=["both_directions", "rowwise_only"]) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize("nvfp4_e4m3_max", [448, 256], ids=["e4m3_448", "e4m3_256"]) +@pytest.mark.parametrize( + "nvfp4_4over6_err_mode", + ["MAE", "MSE"], + ids=["mae_err", "mse_err"], +) +def test_nvfp4_4over6_fp16_error_scoring_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, + row_scaled_nvfp4: bool, + nvfp4_e4m3_max: int, + nvfp4_4over6_err_mode: str, +) -> None: + check_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + return_transpose=return_transpose, + swizzled_scale=False, + use_cpp_allocator=use_cpp_allocator, + with_2d_quantization=False, + row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=True, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + nvfp4_4over6_err_use_fast_math=True, + ) + + @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.parametrize( "M, N", @@ -249,8 +308,8 @@ def test_quantization_block_tiling_versus_reference( @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) @pytest.mark.parametrize( "nvfp4_4over6_err_mode", - ["MAE", "MSE", "MAE_FP16", "MSE_FP16"], - ids=["mae_err", "mse_err", "mae_fp16_err", "mse_fp16_err"], + ["MAE", "MSE"], + ids=["mae_err", "mse_err"], ) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, @@ -292,13 +351,14 @@ def test_nvfp4_quantization_extrema_versus_reference( nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) - if use_cpp_allocator: - x_nvfp4_sut = nvfp4_quantizer(x) - else: - x_nvfp4_sut = nvfp4_quantizer.make_empty( - (M, N), dtype=x_dtype, device=device, requires_grad=False - ) - x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + with nvfp4_4over6_err_fast_math(False): + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) assert x_nvfp4_sut._rowwise_data is not None qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) @@ -370,8 +430,8 @@ def test_nvfp4_quantization_extrema_versus_reference( @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) @pytest.mark.parametrize( "nvfp4_4over6_err_mode", - ["MAE", "MSE", "MAE_FP16", "MSE_FP16"], - ids=["mae_err", "mse_err", "mae_fp16_err", "mse_fp16_err"], + ["MAE", "MSE"], + ids=["mae_err", "mse_err"], ) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, @@ -426,13 +486,14 @@ def test_nvfp4_quantization_boundary_values( nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) - if use_cpp_allocator: - x_nvfp4_sut = nvfp4_quantizer(x) - else: - x_nvfp4_sut = nvfp4_quantizer.make_empty( - (M, N), dtype=x_dtype, device=device, requires_grad=False - ) - x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + with nvfp4_4over6_err_fast_math(False): + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) assert x_nvfp4_sut._rowwise_data is not None qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) @@ -504,8 +565,8 @@ def test_nvfp4_quantization_boundary_values( @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) @pytest.mark.parametrize( "nvfp4_4over6_err_mode", - ["MAE", "MSE", "MAE_FP16", "MSE_FP16"], - ids=["mae_err", "mse_err", "mae_fp16_err", "mse_fp16_err"], + ["MAE", "MSE"], + ids=["mae_err", "mse_err"], ) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, @@ -546,13 +607,14 @@ def test_nvfp4_quantization_noncontiguous_inputs( nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) - if use_cpp_allocator: - x_nvfp4_sut = nvfp4_quantizer(x_nc) - else: - x_nvfp4_sut = nvfp4_quantizer.make_empty( - x_nc.shape, dtype=x_dtype, device=device, requires_grad=False - ) - x_nvfp4_sut = nvfp4_quantizer.update_quantized(x_nc, x_nvfp4_sut) + with nvfp4_4over6_err_fast_math(False): + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x_nc) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + x_nc.shape, dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x_nc, x_nvfp4_sut) assert x_nvfp4_sut._rowwise_data is not None qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index ecc5294f36..a2050d43d8 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -527,8 +527,8 @@ def test_quantizer_update(self, module_class): ) @pytest.mark.parametrize( "nvfp4_4over6_err_mode", - ["MAE", "MSE", "MAE_FP16", "MSE_FP16"], - ids=["mae_err", "mse_err", "mae_fp16_err", "mse_fp16_err"], + ["MAE", "MSE"], + ids=["mae_err", "mse_err"], ) def test_nvfp4_row_scaled_quantizer_roles( nvfp4_4over6, nvfp4_4over6_e4m3_use_256, nvfp4_4over6_err_mode diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index 3f402e0163..e5b53207f1 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -39,27 +39,19 @@ namespace nvfp4 { #if FP4_TYPE_SUPPORTED -#define TRANSFORMER_ENGINE_NVFP4_4OVER6_MODE_SWITCH(MODE, MODE_CONST, ...) \ - switch (MODE) { \ - case kNVTENVFP44Over6MinMAE: { \ - constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMAE; \ - { __VA_ARGS__ } \ - } break; \ - case kNVTENVFP44Over6MinMSE: { \ - constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMSE; \ - { __VA_ARGS__ } \ - } break; \ - case kNVTENVFP44Over6MinMAEFP16: { \ - constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMAEFP16; \ - { __VA_ARGS__ } \ - } break; \ - case kNVTENVFP44Over6MinMSEFP16: { \ - constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMSEFP16; \ - { __VA_ARGS__ } \ - } break; \ - default: { \ - NVTE_ERROR("Unsupported NVFP4 4over6 mode."); \ - } \ +#define TRANSFORMER_ENGINE_NVFP4_4OVER6_MODE_SWITCH(MODE, MODE_CONST, ...) \ + switch (MODE) { \ + case kNVTENVFP44Over6MinMAE: { \ + constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMAE; \ + { __VA_ARGS__ } \ + } break; \ + case kNVTENVFP44Over6MinMSE: { \ + constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMSE; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported NVFP4 4over6 mode."); \ + } \ } #define TRANSFORMER_ENGINE_NVFP4_4OVER6_E4M3_MAX_SWITCH(E4M3_MAX_VALUE, E4M3_MAX_CONST, ...) \ @@ -115,21 +107,9 @@ struct ScalePair { template __device__ __forceinline__ float compute_error_rn(const float diff) { - if constexpr (kMode == kNVTENVFP44Over6MinMSE || kMode == kNVTENVFP44Over6MinMSEFP16) { + if constexpr (kMode == kNVTENVFP44Over6MinMSE) { return __fmul_rn(diff, diff); - } else if constexpr (kMode == kNVTENVFP44Over6MinMAE || kMode == kNVTENVFP44Over6MinMAEFP16) { - return fabsf(diff); - } else { - NVTE_DEVICE_ERROR("Unsupported NVFP4 4over6 mode."); - return fabsf(diff); - } -} - -template -__device__ __forceinline__ float compute_error(const float diff) { - if constexpr (kMode == kNVTENVFP44Over6MinMSE || kMode == kNVTENVFP44Over6MinMSEFP16) { - return diff * diff; - } else if constexpr (kMode == kNVTENVFP44Over6MinMAE || kMode == kNVTENVFP44Over6MinMAEFP16) { + } else if constexpr (kMode == kNVTENVFP44Over6MinMAE) { return fabsf(diff); } else { NVTE_DEVICE_ERROR("Unsupported NVFP4 4over6 mode."); @@ -210,18 +190,10 @@ __device__ __forceinline__ void accumulate_dequant_error(const uint32_t dequant_ constexpr float fp8_max = static_cast(E4M3_MAX); constexpr float err_denom = fp4_max * fp8_max; const uint16_t half_bits = (dequant_bits >> SHIFT) & 0xFFFF; - - if constexpr (Cfg::err_use_fast_math) { - const float dequant = __half2float(__ushort_as_half(half_bits)); - const float val = dequant * sf * global_amax / err_denom; - const float diff = val - x; - *err += compute_error(diff); - } else { - const float dequant = __half2float(__ushort_as_half(half_bits)); - const float val = __fdiv_rn(__fmul_rn(__fmul_rn(dequant, sf), global_amax), err_denom); - const float diff = __fsub_rn(val, x); - *err = __fadd_rn(*err, compute_error_rn(diff)); - } + const float dequant = __half2float(__ushort_as_half(half_bits)); + const float val = __fdiv_rn(__fmul_rn(__fmul_rn(dequant, sf), global_amax), err_denom); + const float diff = __fsub_rn(val, x); + *err = __fadd_rn(*err, compute_error_rn(diff)); } __device__ __forceinline__ uint8_t fp8_bits(const nvfp4_scale_t sf) { @@ -264,21 +236,12 @@ __device__ __forceinline__ void accumulate_fp16_scaled_error_pair(const uint32_t const float global_encode_scale, float *err) { const float2 candidate = e2m1x2_scaled_e4m3_to_float2(e2m1_byte, sf); - if constexpr (Cfg::err_use_fast_math) { - const float original0 = x0 * global_encode_scale; - const float original1 = x1 * global_encode_scale; - const float diff0 = candidate.x - original0; - const float diff1 = candidate.y - original1; - *err += compute_error(diff0); - *err += compute_error(diff1); - } else { - const float original0 = __fmul_rn(x0, global_encode_scale); - const float original1 = __fmul_rn(x1, global_encode_scale); - const float diff0 = __fsub_rn(candidate.x, original0); - const float diff1 = __fsub_rn(candidate.y, original1); - *err = __fadd_rn(*err, compute_error_rn(diff0)); - *err = __fadd_rn(*err, compute_error_rn(diff1)); - } + const float original0 = __fmul_rn(x0, global_encode_scale); + const float original1 = __fmul_rn(x1, global_encode_scale); + const float diff0 = __fsub_rn(candidate.x, original0); + const float diff1 = __fsub_rn(candidate.y, original1); + *err = __fadd_rn(*err, compute_error_rn(diff0)); + *err = __fadd_rn(*err, compute_error_rn(diff1)); } template @@ -318,8 +281,7 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error( "Try recompiling with sm_XXXa instead of sm_XXX."); } - if constexpr (Cfg::mode == kNVTENVFP44Over6MinMAEFP16 || - Cfg::mode == kNVTENVFP44Over6MinMSEFP16) { + if constexpr (Cfg::err_use_fast_math) { accumulate_fp16_scaled_error_pair(out & 0xFFu, x[0], x[1], sf, global_encode_scale, err); accumulate_fp16_scaled_error_pair((out >> 8) & 0xFFu, x[2], x[3], sf, global_encode_scale, err); diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 491a1978be..f675b2f535 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -122,11 +122,9 @@ enum NVTEScalingMode { * \brief Method for NVFP4 4over6 quantization. */ enum NVTENVFP44Over6Mode { - kNVTENVFP44Over6Disabled = 0, /*!< 4over6 is not applied */ - kNVTENVFP44Over6MinMAE = 1, /*!< Select the candidate with lower mean absolute error */ - kNVTENVFP44Over6MinMSE = 2, /*!< Select the candidate with lower mean squared error */ - kNVTENVFP44Over6MinMAEFP16 = 3, /*!< Select with lower absolute error in FP16 domain */ - kNVTENVFP44Over6MinMSEFP16 = 4, /*!< Select with lower squared error in FP16 domain */ + kNVTENVFP44Over6Disabled = 0, /*!< 4over6 is not applied */ + kNVTENVFP44Over6MinMAE = 1, /*!< Select the candidate with lower mean absolute error */ + kNVTENVFP44Over6MinMSE = 2, /*!< Select the candidate with lower mean squared error */ }; /*! \brief TE Tensor type diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index ab206320b0..8a03f2f51a 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -14,7 +14,7 @@ _BACKWARD_OVERRIDES = (None, "high_precision", "dequantized") _NVFP4_4OVER6_SCOPES = ("none", "weights", "activations", "all") -_NVFP4_4OVER6_ERR_MODES = ("MAE", "MSE", "MAE_FP16", "MSE_FP16") +_NVFP4_4OVER6_ERR_MODES = ("MAE", "MSE") class _FormatHelper(NamedTuple): @@ -535,7 +535,7 @@ class NVFP4BlockScaling(Recipe): Select 4over6 tensors that use 256 as the global E4M3 scale bound. By default, all 4over6 tensors use 256. Use ``'none'`` to keep the standard NVFP4 448 bound for 4over6 tensors. - nvfp4_4over6_err_mode : {'MAE', 'MSE', 'MAE_FP16', 'MSE_FP16'}, default = 'MAE' + nvfp4_4over6_err_mode : {'MAE', 'MSE'}, default = 'MAE' Error metric used by NVFP4 4over6 candidate selection. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. None does not modify backward behavior, @@ -577,7 +577,7 @@ def __post_init__(self) -> None: ), "NVTE_NVFP4_4OVER6_E4M3_USE_256 must be one of: 'none', 'weights', 'activations', 'all'." assert ( self.nvfp4_4over6_err_mode in _NVFP4_4OVER6_ERR_MODES - ), "NVTE_NVFP4_4OVER6_ERR_MODE must be one of: 'MAE', 'MSE', 'MAE_FP16', 'MSE_FP16'." + ), "NVTE_NVFP4_4OVER6_ERR_MODE must be one of: 'MAE', 'MSE'." # Quantization params # Note: RHT is currently only applied to column-wise usage so that diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 0035772bf2..b3179d38fd 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1109,9 +1109,7 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, const auto val = *reinterpret_cast(buf); NVTE_CHECK(val == static_cast(kNVTENVFP44Over6Disabled) || val == static_cast(kNVTENVFP44Over6MinMAE) || - val == static_cast(kNVTENVFP44Over6MinMSE) || - val == static_cast(kNVTENVFP44Over6MinMAEFP16) || - val == static_cast(kNVTENVFP44Over6MinMSEFP16), + val == static_cast(kNVTENVFP44Over6MinMSE), "Invalid NVFP4 4over6 mode (got ", static_cast(val), ")"); config_.nvfp4_4over6_mode = static_cast(val); break; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index d1a9cd8587..221af734ae 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1057,8 +1057,8 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, // 1. replace 1 / x by reciprocal_approximate_ftz(x) // 2. when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, // this will essentially remove a round trip between FP32 to BF16 then FP32 - // NVFP4 4over6 candidate error math is controlled separately by - // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. + // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH selects the NVFP4 4over6 + // FP16 product-domain candidate error path. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); if (use_fast_math && !nvfp4_use_4over6) { for (auto &config : quant_config_list) { @@ -1227,8 +1227,8 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, config.set_nvfp4_4over6_mode(quantizer.nvfp4_4over6_mode); } - // NVFP4 4over6 candidate error math is controlled separately by - // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. + // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH selects the NVFP4 4over6 + // FP16 product-domain candidate error path. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); if (use_fast_math && !nvfp4_use_4over6) { for (auto &config : quant_config_list) { diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index d81fc345f8..cd6d058ddd 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1740,10 +1740,6 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->nvfp4_4over6_mode = kNVTENVFP44Over6MinMAE; } else if (nvfp4_4over6_err_mode == "MSE") { this->nvfp4_4over6_mode = kNVTENVFP44Over6MinMSE; - } else if (nvfp4_4over6_err_mode == "MAE_FP16") { - this->nvfp4_4over6_mode = kNVTENVFP44Over6MinMAEFP16; - } else if (nvfp4_4over6_err_mode == "MSE_FP16") { - this->nvfp4_4over6_mode = kNVTENVFP44Over6MinMSEFP16; } else { NVTE_ERROR("Unsupported NVFP4 4over6 error mode: ", nvfp4_4over6_err_mode); } @@ -2467,8 +2463,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // 1. replace 1 / x by reciprocal_approximate_ftz(x) // 2. when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, // this will essentially remove a round trip between FP32 to BF16 then FP32 - // NVFP4 4over6 candidate error math is controlled separately by - // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. + // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH selects the NVFP4 4over6 + // FP16 product-domain candidate error path. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); if (use_fast_math && this->nvfp4_4over6_mode == kNVTENVFP44Over6Disabled) { quant_config.set_use_fast_math(true); diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index 5eb12431fb..b1efb09acf 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -355,14 +355,13 @@ def __init__( nvfp4_use_4over6: bool = False, nvfp4_e4m3_max: int = 448, nvfp4_4over6_err_mode: str = "MAE", + nvfp4_4over6_err_use_fast_math: bool = False, with_rht: bool = False, with_random_sign_mask: bool = True, ): nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() - if nvfp4_4over6_err_mode not in ("MAE", "MSE", "MAE_FP16", "MSE_FP16"): - raise ValueError( - "nvfp4_4over6_err_mode must be one of: 'MAE', 'MSE', 'MAE_FP16', 'MSE_FP16'." - ) + if nvfp4_4over6_err_mode not in ("MAE", "MSE"): + raise ValueError("nvfp4_4over6_err_mode must be one of: 'MAE', 'MSE'.") if row_scaled_nvfp4: if not rowwise: raise ValueError("Row-scaled NVFP4 reference quantization requires rowwise usage.") @@ -388,6 +387,7 @@ def __init__( if self.nvfp4_e4m3_max not in (448, 256): raise ValueError("nvfp4_e4m3_max must be 448 or 256.") self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode + self.nvfp4_4over6_err_use_fast_math = nvfp4_4over6_err_use_fast_math self.with_rht = with_rht self.with_random_sign_mask = with_random_sign_mask @@ -534,14 +534,15 @@ def _quantize_blockwise_4over6_reference( row_scaled_nvfp4: bool, tile_len_y: int, nvfp4_4over6_err_mode: str, + nvfp4_4over6_err_use_fast_math: bool, nvfp4_e4m3_max: int, ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize NVFP4 with 4over6 candidate selection. This mirrors the CUDA path: map-to-4 uses a 1.5x expanded E4M3 block scale, - MAE/MSE compute error in the original input domain, MAE_FP16/MSE_FP16 - compute error in the E4M3-scaled FP16 product domain, and ties choose - map-to-6. + MAE/MSE compute error in the original input domain by default, the + fast-math error path computes error in the E4M3-scaled FP16 product + domain, and ties choose map-to-6. """ m, num_blocks, tile_len_x = x.shape n = num_blocks * tile_len_x @@ -593,7 +594,7 @@ def _quantize_blockwise_4over6_reference( fp4_map4 = cast_from_fp4x2(qx_map4, torch.float32).view(m, num_blocks, tile_len_x) fp4_map6 = cast_from_fp4x2(qx_map6, torch.float32).view(m, num_blocks, tile_len_x) x_float = x.to(torch.float32) - if nvfp4_4over6_err_mode in ("MAE_FP16", "MSE_FP16"): + if nvfp4_4over6_err_use_fast_math: original_scaled = x_float * global_encode_scale candidate_map4 = NVFP4QuantizerRef._ref_nvfp4_4over6_fp16_candidate( fp4_map4, decode_scale_map4 @@ -604,7 +605,7 @@ def _quantize_blockwise_4over6_reference( for idx in range(tile_len_x): diff_map4 = candidate_map4[:, :, idx] - original_scaled[:, :, idx] diff_map6 = candidate_map6[:, :, idx] - original_scaled[:, :, idx] - if nvfp4_4over6_err_mode == "MSE_FP16": + if nvfp4_4over6_err_mode == "MSE": err_map4 = err_map4 + (diff_map4 * diff_map4).unsqueeze(-1) err_map6 = err_map6 + (diff_map6 * diff_map6).unsqueeze(-1) else: @@ -663,6 +664,7 @@ def _quantize_blockwise_reference( nvfp4_use_4over6: bool = False, nvfp4_e4m3_max: int = 448, nvfp4_4over6_err_mode: str = "MAE", + nvfp4_4over6_err_use_fast_math: bool = False, eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -743,6 +745,7 @@ def _quantize_blockwise_reference( row_scaled_nvfp4, tile_len_y, nvfp4_4over6_err_mode, + nvfp4_4over6_err_use_fast_math, nvfp4_e4m3_max, ) @@ -909,6 +912,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ nvfp4_use_4over6=self.nvfp4_use_4over6, nvfp4_e4m3_max=self.nvfp4_e4m3_max, nvfp4_4over6_err_mode=self.nvfp4_4over6_err_mode, + nvfp4_4over6_err_use_fast_math=self.nvfp4_4over6_err_use_fast_math, eps=self.eps, ) if transpose_scales: @@ -935,6 +939,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ nvfp4_use_4over6=self.nvfp4_use_4over6, nvfp4_e4m3_max=self.nvfp4_e4m3_max, nvfp4_4over6_err_mode=self.nvfp4_4over6_err_mode, + nvfp4_4over6_err_use_fast_math=self.nvfp4_4over6_err_use_fast_math, eps=self.eps, ) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index aa249706d0..cc3915783a 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -172,10 +172,8 @@ def __init__( if self.nvfp4_e4m3_max not in (448, 256): raise ValueError("nvfp4_e4m3_max must be 448 or 256.") self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() - if self.nvfp4_4over6_err_mode not in ("MAE", "MSE", "MAE_FP16", "MSE_FP16"): - raise ValueError( - "nvfp4_4over6_err_mode must be one of: 'MAE', 'MSE', 'MAE_FP16', 'MSE_FP16'." - ) + if self.nvfp4_4over6_err_mode not in ("MAE", "MSE"): + raise ValueError("nvfp4_4over6_err_mode must be one of: 'MAE', 'MSE'.") self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) From 25861a15f6124410be7e82681c8bd4739c54fe23 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 1 Jun 2026 22:23:01 -0700 Subject: [PATCH 05/11] Clean up NVFP4 4over6 exact test configs Signed-off-by: Ziang Li --- .../nvfp4/test_nvfp4_quantize_exact.py | 227 +++++++++++------- .../custom_recipes/quantization_ref_nvfp4.py | 4 +- 2 files changed, 141 insertions(+), 90 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index cada5a43c8..7a21db2d64 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -4,6 +4,7 @@ import os from contextlib import contextmanager +from dataclasses import dataclass import pytest import torch @@ -19,6 +20,46 @@ recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) +@dataclass(frozen=True) +class NVFP44Over6TestConfig: + id: str + use_4over6: bool = True + e4m3_max: int = 448 + err_mode: str = "MAE" + err_use_fast_math: bool = False + + +NVFP4_4OVER6_CONFIGS = [ + NVFP44Over6TestConfig(id="nvfp4", use_4over6=False), + NVFP44Over6TestConfig(id="4over6-mae-e4m3-448-exact", err_mode="MAE"), + NVFP44Over6TestConfig( + id="4over6-mae-e4m3-448-err-fast", + err_mode="MAE", + err_use_fast_math=True, + ), + NVFP44Over6TestConfig(id="4over6-mae-e4m3-256-exact", e4m3_max=256, err_mode="MAE"), + NVFP44Over6TestConfig( + id="4over6-mae-e4m3-256-err-fast", + e4m3_max=256, + err_mode="MAE", + err_use_fast_math=True, + ), + NVFP44Over6TestConfig(id="4over6-mse-e4m3-448-exact", err_mode="MSE"), + NVFP44Over6TestConfig( + id="4over6-mse-e4m3-448-err-fast", + err_mode="MSE", + err_use_fast_math=True, + ), + NVFP44Over6TestConfig(id="4over6-mse-e4m3-256-exact", e4m3_max=256, err_mode="MSE"), + NVFP44Over6TestConfig( + id="4over6-mse-e4m3-256-err-fast", + e4m3_max=256, + err_mode="MSE", + err_use_fast_math=True, + ), +] + + @contextmanager def nvfp4_4over6_err_fast_math(enabled: bool): old_value = os.environ.get("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH") @@ -90,7 +131,30 @@ def check_quantization_nvfp4_versus_reference( x = torch.randn((M, N), dtype=x_dtype, device=device) # Quantize - with nvfp4_4over6_err_fast_math(nvfp4_4over6_err_use_fast_math): + if use_4over6: + with nvfp4_4over6_err_fast_math(nvfp4_4over6_err_use_fast_math): + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=with_2d_quantization, + row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + ) + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + else: nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, @@ -216,12 +280,10 @@ def check_quantization_nvfp4_versus_reference( "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) -@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) -@pytest.mark.parametrize("nvfp4_e4m3_max", [448, 256], ids=["e4m3_448", "e4m3_256"]) @pytest.mark.parametrize( - "nvfp4_4over6_err_mode", - ["MAE", "MSE"], - ids=["mae_err", "mse_err"], + "nvfp4_4over6_config", + NVFP4_4OVER6_CONFIGS, + ids=lambda config: config.id, ) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, @@ -232,9 +294,7 @@ def test_quantization_block_tiling_versus_reference( use_cpp_allocator: bool, with_2d_quantization: bool, row_scaled_nvfp4: bool, - use_4over6: bool, - nvfp4_e4m3_max: int, - nvfp4_4over6_err_mode: str, + nvfp4_4over6_config: NVFP44Over6TestConfig, ) -> None: check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, @@ -245,49 +305,10 @@ def test_quantization_block_tiling_versus_reference( use_cpp_allocator=use_cpp_allocator, with_2d_quantization=with_2d_quantization, row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, - nvfp4_e4m3_max=nvfp4_e4m3_max, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, - ) - - -@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) -@pytest.mark.parametrize("M, N", [(128, 128), (256, 256)]) -@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) -@pytest.mark.parametrize("return_transpose", [True, False], ids=["both_directions", "rowwise_only"]) -@pytest.mark.parametrize( - "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] -) -@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) -@pytest.mark.parametrize("nvfp4_e4m3_max", [448, 256], ids=["e4m3_448", "e4m3_256"]) -@pytest.mark.parametrize( - "nvfp4_4over6_err_mode", - ["MAE", "MSE"], - ids=["mae_err", "mse_err"], -) -def test_nvfp4_4over6_fp16_error_scoring_versus_reference( - x_dtype: torch.dtype, - M: int, - N: int, - return_transpose: bool, - use_cpp_allocator: bool, - row_scaled_nvfp4: bool, - nvfp4_e4m3_max: int, - nvfp4_4over6_err_mode: str, -) -> None: - check_quantization_nvfp4_versus_reference( - x_dtype=x_dtype, - M=M, - N=N, - return_transpose=return_transpose, - swizzled_scale=False, - use_cpp_allocator=use_cpp_allocator, - with_2d_quantization=False, - row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=True, - nvfp4_e4m3_max=nvfp4_e4m3_max, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, - nvfp4_4over6_err_use_fast_math=True, + use_4over6=nvfp4_4over6_config.use_4over6, + nvfp4_e4m3_max=nvfp4_4over6_config.e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_config.err_mode, + nvfp4_4over6_err_use_fast_math=nvfp4_4over6_config.err_use_fast_math, ) @@ -305,11 +326,10 @@ def test_nvfp4_4over6_fp16_error_scoring_versus_reference( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) -@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) @pytest.mark.parametrize( - "nvfp4_4over6_err_mode", - ["MAE", "MSE"], - ids=["mae_err", "mse_err"], + "nvfp4_4over6_config", + NVFP4_4OVER6_CONFIGS, + ids=lambda config: config.id, ) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, @@ -319,11 +339,10 @@ def test_nvfp4_quantization_extrema_versus_reference( return_transpose: bool, use_cpp_allocator: bool, row_scaled_nvfp4: bool, - use_4over6: bool, - nvfp4_4over6_err_mode: str, + nvfp4_4over6_config: NVFP44Over6TestConfig, ): maybe_skip_row_scaled_unsupported_quantization( - row_scaled_nvfp4, return_transpose, use_4over6=use_4over6 + row_scaled_nvfp4, return_transpose, use_4over6=nvfp4_4over6_config.use_4over6 ) te_dtype = tex.DType.kFloat4E2M1 @@ -347,11 +366,21 @@ def test_nvfp4_quantization_extrema_versus_reference( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, - nvfp4_use_4over6=use_4over6, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + nvfp4_use_4over6=nvfp4_4over6_config.use_4over6, + nvfp4_e4m3_max=nvfp4_4over6_config.e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_config.err_mode, ) - with nvfp4_4over6_err_fast_math(False): + if nvfp4_4over6_config.use_4over6: + with nvfp4_4over6_err_fast_math(nvfp4_4over6_config.err_use_fast_math): + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + else: if use_cpp_allocator: x_nvfp4_sut = nvfp4_quantizer(x) else: @@ -381,8 +410,10 @@ def test_nvfp4_quantization_extrema_versus_reference( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, - nvfp4_use_4over6=use_4over6, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + nvfp4_use_4over6=nvfp4_4over6_config.use_4over6, + nvfp4_e4m3_max=nvfp4_4over6_config.e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_config.err_mode, + nvfp4_4over6_err_use_fast_math=nvfp4_4over6_config.err_use_fast_math, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -427,11 +458,10 @@ def test_nvfp4_quantization_extrema_versus_reference( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) -@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) @pytest.mark.parametrize( - "nvfp4_4over6_err_mode", - ["MAE", "MSE"], - ids=["mae_err", "mse_err"], + "nvfp4_4over6_config", + NVFP4_4OVER6_CONFIGS, + ids=lambda config: config.id, ) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, @@ -440,8 +470,7 @@ def test_nvfp4_quantization_boundary_values( return_transpose: bool, use_cpp_allocator: bool, row_scaled_nvfp4: bool, - use_4over6: bool, - nvfp4_4over6_err_mode: str, + nvfp4_4over6_config: NVFP44Over6TestConfig, ): """ Stress rounding/threshold behavior by placing values just below/above @@ -449,7 +478,7 @@ def test_nvfp4_quantization_boundary_values( Validates native vs reference byte-for-byte and scale parity. """ maybe_skip_row_scaled_unsupported_quantization( - row_scaled_nvfp4, return_transpose, use_4over6=use_4over6 + row_scaled_nvfp4, return_transpose, use_4over6=nvfp4_4over6_config.use_4over6 ) te_dtype = tex.DType.kFloat4E2M1 @@ -482,11 +511,21 @@ def test_nvfp4_quantization_boundary_values( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, - nvfp4_use_4over6=use_4over6, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + nvfp4_use_4over6=nvfp4_4over6_config.use_4over6, + nvfp4_e4m3_max=nvfp4_4over6_config.e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_config.err_mode, ) - with nvfp4_4over6_err_fast_math(False): + if nvfp4_4over6_config.use_4over6: + with nvfp4_4over6_err_fast_math(nvfp4_4over6_config.err_use_fast_math): + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + else: if use_cpp_allocator: x_nvfp4_sut = nvfp4_quantizer(x) else: @@ -516,8 +555,10 @@ def test_nvfp4_quantization_boundary_values( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, - nvfp4_use_4over6=use_4over6, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + nvfp4_use_4over6=nvfp4_4over6_config.use_4over6, + nvfp4_e4m3_max=nvfp4_4over6_config.e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_config.err_mode, + nvfp4_4over6_err_use_fast_math=nvfp4_4over6_config.err_use_fast_math, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -562,11 +603,10 @@ def test_nvfp4_quantization_boundary_values( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) -@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) @pytest.mark.parametrize( - "nvfp4_4over6_err_mode", - ["MAE", "MSE"], - ids=["mae_err", "mse_err"], + "nvfp4_4over6_config", + NVFP4_4OVER6_CONFIGS, + ids=lambda config: config.id, ) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, @@ -575,11 +615,10 @@ def test_nvfp4_quantization_noncontiguous_inputs( return_transpose: bool, use_cpp_allocator: bool, row_scaled_nvfp4: bool, - use_4over6: bool, - nvfp4_4over6_err_mode: str, + nvfp4_4over6_config: NVFP44Over6TestConfig, ): maybe_skip_row_scaled_unsupported_quantization( - row_scaled_nvfp4, return_transpose, use_4over6=use_4over6 + row_scaled_nvfp4, return_transpose, use_4over6=nvfp4_4over6_config.use_4over6 ) te_dtype = tex.DType.kFloat4E2M1 @@ -603,11 +642,21 @@ def test_nvfp4_quantization_noncontiguous_inputs( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, - nvfp4_use_4over6=use_4over6, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + nvfp4_use_4over6=nvfp4_4over6_config.use_4over6, + nvfp4_e4m3_max=nvfp4_4over6_config.e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_config.err_mode, ) - with nvfp4_4over6_err_fast_math(False): + if nvfp4_4over6_config.use_4over6: + with nvfp4_4over6_err_fast_math(nvfp4_4over6_config.err_use_fast_math): + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x_nc) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + x_nc.shape, dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x_nc, x_nvfp4_sut) + else: if use_cpp_allocator: x_nvfp4_sut = nvfp4_quantizer(x_nc) else: @@ -637,8 +686,10 @@ def test_nvfp4_quantization_noncontiguous_inputs( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, - nvfp4_use_4over6=use_4over6, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + nvfp4_use_4over6=nvfp4_4over6_config.use_4over6, + nvfp4_e4m3_max=nvfp4_4over6_config.e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_config.err_mode, + nvfp4_4over6_err_use_fast_math=nvfp4_4over6_config.err_use_fast_math, ) x_nvfp4_ref = ref_quantizer.quantize(x_nc) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index b1efb09acf..d09b95ace3 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -360,8 +360,6 @@ def __init__( with_random_sign_mask: bool = True, ): nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() - if nvfp4_4over6_err_mode not in ("MAE", "MSE"): - raise ValueError("nvfp4_4over6_err_mode must be one of: 'MAE', 'MSE'.") if row_scaled_nvfp4: if not rowwise: raise ValueError("Row-scaled NVFP4 reference quantization requires rowwise usage.") @@ -370,6 +368,8 @@ def __init__( "Row-scaled NVFP4 reference quantization does not support columnwise usage." ) if nvfp4_use_4over6: + if nvfp4_4over6_err_mode not in ("MAE", "MSE"): + raise ValueError(f"Unsupported NVFP4 4over6 error mode: {nvfp4_4over6_err_mode}.") if pow_2_scales: raise ValueError("4over6 is only supported for NVFP4 (non-pow2) mode.") if quant_tile_shape not in ((1, 16), (16, 16)): From 6fce6b9417ec3e69e17818e44b4da4f6b7ed331e Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 1 Jun 2026 22:25:53 -0700 Subject: [PATCH 06/11] Lift NVFP4 exact test quantizer construction Signed-off-by: Ziang Li --- .../nvfp4/test_nvfp4_quantize_exact.py | 43 +++++++------------ 1 file changed, 15 insertions(+), 28 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 7a21db2d64..bff6e046ae 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -130,23 +130,24 @@ def check_quantization_nvfp4_versus_reference( # Input x = torch.randn((M, N), dtype=x_dtype, device=device) + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=with_2d_quantization, + row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + ) + # Quantize if use_4over6: with nvfp4_4over6_err_fast_math(nvfp4_4over6_err_use_fast_math): - nvfp4_quantizer = NVFP4Quantizer( - fp4_dtype=te_dtype, - rowwise=True, - columnwise=return_transpose, - with_amax_reduction=False, - amax_reduction_group=None, - with_rht=False, - with_post_rht_amax=False, - with_2d_quantization=with_2d_quantization, - row_scaled_nvfp4=row_scaled_nvfp4, - nvfp4_use_4over6=use_4over6, - nvfp4_e4m3_max=nvfp4_e4m3_max, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, - ) if use_cpp_allocator: x_nvfp4_sut = nvfp4_quantizer(x) else: @@ -155,20 +156,6 @@ def check_quantization_nvfp4_versus_reference( ) x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) else: - nvfp4_quantizer = NVFP4Quantizer( - fp4_dtype=te_dtype, - rowwise=True, - columnwise=return_transpose, - with_amax_reduction=False, - amax_reduction_group=None, - with_rht=False, - with_post_rht_amax=False, - with_2d_quantization=with_2d_quantization, - row_scaled_nvfp4=row_scaled_nvfp4, - nvfp4_use_4over6=use_4over6, - nvfp4_e4m3_max=nvfp4_e4m3_max, - nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, - ) if use_cpp_allocator: x_nvfp4_sut = nvfp4_quantizer(x) else: From f47fb7b8d8dc1aad0ffa6347d04d6939ed7004a9 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 1 Jun 2026 22:28:59 -0700 Subject: [PATCH 07/11] Clean up doc changes Signed-off-by: Ziang Li --- tests/cpp/operator/test_cast_nvfp4_transpose.cu | 2 +- tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py | 2 +- tests/pytorch/test_recipe.py | 6 +----- transformer_engine/pytorch/csrc/extensions/cast.cpp | 8 ++++---- transformer_engine/pytorch/csrc/quantizer.cpp | 4 ++-- transformer_engine/pytorch/tensor/nvfp4_tensor.py | 2 +- 6 files changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 18b96fa6df..d6ab4b6740 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -785,7 +785,7 @@ void performTest(float (*OP)(const float), if (use_4over6 && use_fast_math) { std::cout << "WARNING: Plain NVFP4 fast math is ignored for 4over6. " "Use use_4over6_err_use_fast_math to test the 4over6 candidate " - "FP16 product-domain error path." + "error fast-math path." << std::endl; } diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index bff6e046ae..ea60bd3837 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -130,6 +130,7 @@ def check_quantization_nvfp4_versus_reference( # Input x = torch.randn((M, N), dtype=x_dtype, device=device) + # Quantize nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, @@ -145,7 +146,6 @@ def check_quantization_nvfp4_versus_reference( nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) - # Quantize if use_4over6: with nvfp4_4over6_err_fast_math(nvfp4_4over6_err_use_fast_math): if use_cpp_allocator: diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index a2050d43d8..9a14cee7fd 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -525,11 +525,7 @@ def test_quantizer_update(self, module_class): ["none", "weights", "activations", "all"], ids=["e4m3_448", "e4m3_256_weights", "e4m3_256_activations", "e4m3_256_all"], ) -@pytest.mark.parametrize( - "nvfp4_4over6_err_mode", - ["MAE", "MSE"], - ids=["mae_err", "mse_err"], -) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_row_scaled_quantizer_roles( nvfp4_4over6, nvfp4_4over6_e4m3_use_256, nvfp4_4over6_err_mode ): diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 221af734ae..d1a9cd8587 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1057,8 +1057,8 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, // 1. replace 1 / x by reciprocal_approximate_ftz(x) // 2. when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, // this will essentially remove a round trip between FP32 to BF16 then FP32 - // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH selects the NVFP4 4over6 - // FP16 product-domain candidate error path. + // NVFP4 4over6 candidate error math is controlled separately by + // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); if (use_fast_math && !nvfp4_use_4over6) { for (auto &config : quant_config_list) { @@ -1227,8 +1227,8 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, config.set_nvfp4_4over6_mode(quantizer.nvfp4_4over6_mode); } - // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH selects the NVFP4 4over6 - // FP16 product-domain candidate error path. + // NVFP4 4over6 candidate error math is controlled separately by + // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); if (use_fast_math && !nvfp4_use_4over6) { for (auto &config : quant_config_list) { diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index cd6d058ddd..bc87b54ba8 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -2463,8 +2463,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // 1. replace 1 / x by reciprocal_approximate_ftz(x) // 2. when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, // this will essentially remove a round trip between FP32 to BF16 then FP32 - // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH selects the NVFP4 4over6 - // FP16 product-domain candidate error path. + // NVFP4 4over6 candidate error math is controlled separately by + // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); if (use_fast_math && this->nvfp4_4over6_mode == kNVTENVFP44Over6Disabled) { quant_config.set_use_fast_math(true); diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index cc3915783a..24962d67f2 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -173,7 +173,7 @@ def __init__( raise ValueError("nvfp4_e4m3_max must be 448 or 256.") self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() if self.nvfp4_4over6_err_mode not in ("MAE", "MSE"): - raise ValueError("nvfp4_4over6_err_mode must be one of: 'MAE', 'MSE'.") + raise ValueError("nvfp4_4over6_err_mode must be 'MAE' or 'MSE'.") self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) From 54797b31de8886572fc478ee35899f56322c63bd Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 2 Jun 2026 16:36:41 -0700 Subject: [PATCH 08/11] Add script and experimental refactor Signed-off-by: Ziang Li --- .../linear/compare_nvfp4_4over6_selection.py | 111 ++++++++++++++++++ .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 89 ++++++++------ 2 files changed, 165 insertions(+), 35 deletions(-) create mode 100644 benchmarks/linear/compare_nvfp4_4over6_selection.py diff --git a/benchmarks/linear/compare_nvfp4_4over6_selection.py b/benchmarks/linear/compare_nvfp4_4over6_selection.py new file mode 100644 index 0000000000..7f73dfce9e --- /dev/null +++ b/benchmarks/linear/compare_nvfp4_4over6_selection.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Compare NVFP4 4over6 E4M3 scales with and without error fast math.""" + +import os +from contextlib import contextmanager + +import torch +from transformer_engine.pytorch import NVFP4Quantizer +import transformer_engine_torch as tex + + +M, K = 98304, 7168 + + +@contextmanager +def _error_fast_math(enabled: bool): + old_value = os.environ.get("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH") + os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = "1" if enabled else "0" + try: + yield + finally: + if old_value is None: + os.environ.pop("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH", None) + else: + os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = old_value + + +def _quantize_scale_bytes( + x: torch.Tensor, + err_mode: str, + err_fast_math: bool, + row_scaled: bool, + with_2d_quantization: bool, + nvfp4_e4m3_max: int, +) -> torch.Tensor: + quantizer = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=False, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=with_2d_quantization, + row_scaled_nvfp4=row_scaled, + nvfp4_use_4over6=True, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=err_mode, + ) + with _error_fast_math(err_fast_math): + quantized = quantizer(x) + assert quantized._rowwise_scale_inv is not None + return quantized._rowwise_scale_inv.contiguous().view(torch.uint8) + + +def _compare_e4m3( + x: torch.Tensor, + dtype_name: str, + scale_mode: str, + row_scaled: bool, + quant_mode: str, + with_2d_quantization: bool, + nvfp4_e4m3_max: int, +) -> None: + for err_mode in ("MAE", "MSE"): + regular = _quantize_scale_bytes( + x, err_mode, False, row_scaled, with_2d_quantization, nvfp4_e4m3_max + ) + fast = _quantize_scale_bytes( + x, err_mode, True, row_scaled, with_2d_quantization, nvfp4_e4m3_max + ) + same = torch.count_nonzero(regular == fast).item() + total = regular.numel() + print( + f"{scale_mode:>6} {quant_mode:>5} {nvfp4_e4m3_max:8d} " + f"{dtype_name:>5} {err_mode:>3} " + f"{100.0 * same / total:12.6f} {total - same:15d} {total}" + ) + + +def main(): + torch.set_grad_enabled(False) + print(f"shape=({M}, {K}), 1d_e4m3_values={M * K // 16}") + print("scale quant e4m3_max dtype mode same_e4m3_pct different_e4m3 total_e4m3") + for scale_mode, row_scaled, quant_mode, with_2d_quantization in ( + ("tensor", False, "1d", False), + ("tensor", False, "2d", True), + ("row", True, "1d", False), + ): + for nvfp4_e4m3_max in (256, 448): + for dtype, dtype_name in ((torch.bfloat16, "bf16"), (torch.float16, "fp16")): + torch.manual_seed(1234) + x = torch.randn((M, K), dtype=dtype, device="cuda") + _compare_e4m3( + x, + dtype_name, + scale_mode, + row_scaled, + quant_mode, + with_2d_quantization, + nvfp4_e4m3_max, + ) + del x + torch.cuda.empty_cache() + + +if __name__ == "__main__": + main() diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index e5b53207f1..50776a3ed6 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -105,6 +105,11 @@ struct ScalePair { float global_encode_scale; }; +struct FP16ErrorScalePair { + uint32_t map4; + uint32_t map6; +}; + template __device__ __forceinline__ float compute_error_rn(const float diff) { if constexpr (kMode == kNVTENVFP44Over6MinMSE) { @@ -200,42 +205,50 @@ __device__ __forceinline__ uint8_t fp8_bits(const nvfp4_scale_t sf) { return *reinterpret_cast(&sf); } -__device__ __forceinline__ float2 e2m1x2_scaled_e4m3_to_float2(const uint32_t e2m1_byte, - const nvfp4_scale_t sf) { - float2 result; - const uint32_t sf_byte = static_cast(fp8_bits(sf)); +__device__ __forceinline__ FP16ErrorScalePair compute_fp16_error_scales(const ScalePair &scales) { + FP16ErrorScalePair result; + const uint32_t packed_scales = static_cast(fp8_bits(scales.map4)) | + (static_cast(fp8_bits(scales.map6)) << 8); asm volatile( "{\n" - ".reg .b8 byte0, byte1, byte2, byte3;\n" ".reg .b16 fp8_pair;\n" - ".reg .b16 scale_h, unused_h;\n" - ".reg .b16 lo, hi;\n" - ".reg .b32 q_h2;\n" + ".reg .b16 map4_h, map6_h;\n" ".reg .b32 scale_h2;\n" - ".reg .b32 prod_h2;\n" - "mov.b32 {byte0, byte1, byte2, byte3}, %2;\n" - "cvt.rn.f16x2.e2m1x2 q_h2, byte0;\n" - "cvt.u16.u32 fp8_pair, %3;\n" + "cvt.u16.u32 fp8_pair, %2;\n" "cvt.rn.f16x2.e4m3x2 scale_h2, fp8_pair;\n" - "mov.b32 {scale_h, unused_h}, scale_h2;\n" - "mov.b32 scale_h2, {scale_h, scale_h};\n" - "mul.rn.f16x2 prod_h2, q_h2, scale_h2;\n" + "mov.b32 {map4_h, map6_h}, scale_h2;\n" + "mov.b32 %0, {map4_h, map4_h};\n" + "mov.b32 %1, {map6_h, map6_h};\n" + "}" + : "=r"(result.map4), "=r"(result.map6) + : "r"(packed_scales)); + return result; +} + +__device__ __forceinline__ float2 f16x2_scaled_to_float2(const uint32_t q_h2, + const uint32_t scale_h2) { + float2 result; + asm volatile( + "{\n" + ".reg .b16 lo, hi;\n" + ".reg .b32 prod_h2;\n" + "mul.rn.f16x2 prod_h2, %2, %3;\n" "mov.b32 {lo, hi}, prod_h2;\n" "cvt.f32.f16 %0, lo;\n" "cvt.f32.f16 %1, hi;\n" "}" : "=f"(result.x), "=f"(result.y) - : "r"(e2m1_byte), "r"(sf_byte)); + : "r"(q_h2), "r"(scale_h2)); return result; } template -__device__ __forceinline__ void accumulate_fp16_scaled_error_pair(const uint32_t e2m1_byte, +__device__ __forceinline__ void accumulate_fp16_scaled_error_pair(const uint32_t q_h2, const float x0, const float x1, - const nvfp4_scale_t sf, + const uint32_t scale_h2, const float global_encode_scale, float *err) { - const float2 candidate = e2m1x2_scaled_e4m3_to_float2(e2m1_byte, sf); + const float2 candidate = f16x2_scaled_to_float2(q_h2, scale_h2); const float original0 = __fmul_rn(x0, global_encode_scale); const float original1 = __fmul_rn(x1, global_encode_scale); const float diff0 = __fsub_rn(candidate.x, original0); @@ -247,7 +260,8 @@ __device__ __forceinline__ void accumulate_fp16_scaled_error_pair(const uint32_t template __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error( const float (&x)[8], const float block_scale_inverse, const nvfp4_scale_t sf, - const float global_amax, const float global_encode_scale, float *err) { + const uint32_t fp16_error_scale, const float global_amax, const float global_encode_scale, + float *err) { uint32_t out = 0; uint32_t out_dequant_1 = 0; uint32_t out_dequant_2 = 0; @@ -282,13 +296,14 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error( } if constexpr (Cfg::err_use_fast_math) { - accumulate_fp16_scaled_error_pair(out & 0xFFu, x[0], x[1], sf, global_encode_scale, err); - accumulate_fp16_scaled_error_pair((out >> 8) & 0xFFu, x[2], x[3], sf, global_encode_scale, - err); - accumulate_fp16_scaled_error_pair((out >> 16) & 0xFFu, x[4], x[5], sf, global_encode_scale, - err); - accumulate_fp16_scaled_error_pair((out >> 24) & 0xFFu, x[6], x[7], sf, global_encode_scale, - err); + accumulate_fp16_scaled_error_pair(out_dequant_1, x[0], x[1], fp16_error_scale, + global_encode_scale, err); + accumulate_fp16_scaled_error_pair(out_dequant_2, x[2], x[3], fp16_error_scale, + global_encode_scale, err); + accumulate_fp16_scaled_error_pair(out_dequant_3, x[4], x[5], fp16_error_scale, + global_encode_scale, err); + accumulate_fp16_scaled_error_pair(out_dequant_4, x[6], x[7], fp16_error_scale, + global_encode_scale, err); } else { const float sf_float = static_cast(sf); accumulate_dequant_error(out_dequant_1, x[0], sf_float, global_amax, err); @@ -310,18 +325,22 @@ __device__ __forceinline__ CandidatePair make_candidates(const float (&x0)[8], c CandidatePair candidates; candidates.map4.err = 0.0f; candidates.map6.err = 0.0f; + FP16ErrorScalePair fp16_error_scales{}; + if constexpr (Cfg::err_use_fast_math) { + fp16_error_scales = compute_fp16_error_scales(scales); + } candidates.map4.packed[0] = cvt_fp32_to_fp4_8x_with_error( - x0, scales.inv_map4, scales.map4, global_amax, scales.global_encode_scale, - &candidates.map4.err); + x0, scales.inv_map4, scales.map4, fp16_error_scales.map4, global_amax, + scales.global_encode_scale, &candidates.map4.err); candidates.map6.packed[0] = cvt_fp32_to_fp4_8x_with_error( - x0, scales.inv_map6, scales.map6, global_amax, scales.global_encode_scale, - &candidates.map6.err); + x0, scales.inv_map6, scales.map6, fp16_error_scales.map6, global_amax, + scales.global_encode_scale, &candidates.map6.err); candidates.map4.packed[1] = cvt_fp32_to_fp4_8x_with_error( - x1, scales.inv_map4, scales.map4, global_amax, scales.global_encode_scale, - &candidates.map4.err); + x1, scales.inv_map4, scales.map4, fp16_error_scales.map4, global_amax, + scales.global_encode_scale, &candidates.map4.err); candidates.map6.packed[1] = cvt_fp32_to_fp4_8x_with_error( - x1, scales.inv_map6, scales.map6, global_amax, scales.global_encode_scale, - &candidates.map6.err); + x1, scales.inv_map6, scales.map6, fp16_error_scales.map6, global_amax, + scales.global_encode_scale, &candidates.map6.err); return candidates; } From 8091a597c01c4298933f5e374eccb64eb6ef2704 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 2 Jun 2026 16:37:49 -0700 Subject: [PATCH 09/11] Revert "Add script and experimental refactor" This reverts commit 54797b31de8886572fc478ee35899f56322c63bd. Signed-off-by: Ziang Li --- .../linear/compare_nvfp4_4over6_selection.py | 111 ------------------ .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 89 ++++++-------- 2 files changed, 35 insertions(+), 165 deletions(-) delete mode 100644 benchmarks/linear/compare_nvfp4_4over6_selection.py diff --git a/benchmarks/linear/compare_nvfp4_4over6_selection.py b/benchmarks/linear/compare_nvfp4_4over6_selection.py deleted file mode 100644 index 7f73dfce9e..0000000000 --- a/benchmarks/linear/compare_nvfp4_4over6_selection.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Compare NVFP4 4over6 E4M3 scales with and without error fast math.""" - -import os -from contextlib import contextmanager - -import torch -from transformer_engine.pytorch import NVFP4Quantizer -import transformer_engine_torch as tex - - -M, K = 98304, 7168 - - -@contextmanager -def _error_fast_math(enabled: bool): - old_value = os.environ.get("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH") - os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = "1" if enabled else "0" - try: - yield - finally: - if old_value is None: - os.environ.pop("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH", None) - else: - os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = old_value - - -def _quantize_scale_bytes( - x: torch.Tensor, - err_mode: str, - err_fast_math: bool, - row_scaled: bool, - with_2d_quantization: bool, - nvfp4_e4m3_max: int, -) -> torch.Tensor: - quantizer = NVFP4Quantizer( - fp4_dtype=tex.DType.kFloat4E2M1, - rowwise=True, - columnwise=False, - with_amax_reduction=False, - amax_reduction_group=None, - with_rht=False, - with_post_rht_amax=False, - with_2d_quantization=with_2d_quantization, - row_scaled_nvfp4=row_scaled, - nvfp4_use_4over6=True, - nvfp4_e4m3_max=nvfp4_e4m3_max, - nvfp4_4over6_err_mode=err_mode, - ) - with _error_fast_math(err_fast_math): - quantized = quantizer(x) - assert quantized._rowwise_scale_inv is not None - return quantized._rowwise_scale_inv.contiguous().view(torch.uint8) - - -def _compare_e4m3( - x: torch.Tensor, - dtype_name: str, - scale_mode: str, - row_scaled: bool, - quant_mode: str, - with_2d_quantization: bool, - nvfp4_e4m3_max: int, -) -> None: - for err_mode in ("MAE", "MSE"): - regular = _quantize_scale_bytes( - x, err_mode, False, row_scaled, with_2d_quantization, nvfp4_e4m3_max - ) - fast = _quantize_scale_bytes( - x, err_mode, True, row_scaled, with_2d_quantization, nvfp4_e4m3_max - ) - same = torch.count_nonzero(regular == fast).item() - total = regular.numel() - print( - f"{scale_mode:>6} {quant_mode:>5} {nvfp4_e4m3_max:8d} " - f"{dtype_name:>5} {err_mode:>3} " - f"{100.0 * same / total:12.6f} {total - same:15d} {total}" - ) - - -def main(): - torch.set_grad_enabled(False) - print(f"shape=({M}, {K}), 1d_e4m3_values={M * K // 16}") - print("scale quant e4m3_max dtype mode same_e4m3_pct different_e4m3 total_e4m3") - for scale_mode, row_scaled, quant_mode, with_2d_quantization in ( - ("tensor", False, "1d", False), - ("tensor", False, "2d", True), - ("row", True, "1d", False), - ): - for nvfp4_e4m3_max in (256, 448): - for dtype, dtype_name in ((torch.bfloat16, "bf16"), (torch.float16, "fp16")): - torch.manual_seed(1234) - x = torch.randn((M, K), dtype=dtype, device="cuda") - _compare_e4m3( - x, - dtype_name, - scale_mode, - row_scaled, - quant_mode, - with_2d_quantization, - nvfp4_e4m3_max, - ) - del x - torch.cuda.empty_cache() - - -if __name__ == "__main__": - main() diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index 50776a3ed6..e5b53207f1 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -105,11 +105,6 @@ struct ScalePair { float global_encode_scale; }; -struct FP16ErrorScalePair { - uint32_t map4; - uint32_t map6; -}; - template __device__ __forceinline__ float compute_error_rn(const float diff) { if constexpr (kMode == kNVTENVFP44Over6MinMSE) { @@ -205,50 +200,42 @@ __device__ __forceinline__ uint8_t fp8_bits(const nvfp4_scale_t sf) { return *reinterpret_cast(&sf); } -__device__ __forceinline__ FP16ErrorScalePair compute_fp16_error_scales(const ScalePair &scales) { - FP16ErrorScalePair result; - const uint32_t packed_scales = static_cast(fp8_bits(scales.map4)) | - (static_cast(fp8_bits(scales.map6)) << 8); - asm volatile( - "{\n" - ".reg .b16 fp8_pair;\n" - ".reg .b16 map4_h, map6_h;\n" - ".reg .b32 scale_h2;\n" - "cvt.u16.u32 fp8_pair, %2;\n" - "cvt.rn.f16x2.e4m3x2 scale_h2, fp8_pair;\n" - "mov.b32 {map4_h, map6_h}, scale_h2;\n" - "mov.b32 %0, {map4_h, map4_h};\n" - "mov.b32 %1, {map6_h, map6_h};\n" - "}" - : "=r"(result.map4), "=r"(result.map6) - : "r"(packed_scales)); - return result; -} - -__device__ __forceinline__ float2 f16x2_scaled_to_float2(const uint32_t q_h2, - const uint32_t scale_h2) { +__device__ __forceinline__ float2 e2m1x2_scaled_e4m3_to_float2(const uint32_t e2m1_byte, + const nvfp4_scale_t sf) { float2 result; + const uint32_t sf_byte = static_cast(fp8_bits(sf)); asm volatile( "{\n" + ".reg .b8 byte0, byte1, byte2, byte3;\n" + ".reg .b16 fp8_pair;\n" + ".reg .b16 scale_h, unused_h;\n" ".reg .b16 lo, hi;\n" + ".reg .b32 q_h2;\n" + ".reg .b32 scale_h2;\n" ".reg .b32 prod_h2;\n" - "mul.rn.f16x2 prod_h2, %2, %3;\n" + "mov.b32 {byte0, byte1, byte2, byte3}, %2;\n" + "cvt.rn.f16x2.e2m1x2 q_h2, byte0;\n" + "cvt.u16.u32 fp8_pair, %3;\n" + "cvt.rn.f16x2.e4m3x2 scale_h2, fp8_pair;\n" + "mov.b32 {scale_h, unused_h}, scale_h2;\n" + "mov.b32 scale_h2, {scale_h, scale_h};\n" + "mul.rn.f16x2 prod_h2, q_h2, scale_h2;\n" "mov.b32 {lo, hi}, prod_h2;\n" "cvt.f32.f16 %0, lo;\n" "cvt.f32.f16 %1, hi;\n" "}" : "=f"(result.x), "=f"(result.y) - : "r"(q_h2), "r"(scale_h2)); + : "r"(e2m1_byte), "r"(sf_byte)); return result; } template -__device__ __forceinline__ void accumulate_fp16_scaled_error_pair(const uint32_t q_h2, +__device__ __forceinline__ void accumulate_fp16_scaled_error_pair(const uint32_t e2m1_byte, const float x0, const float x1, - const uint32_t scale_h2, + const nvfp4_scale_t sf, const float global_encode_scale, float *err) { - const float2 candidate = f16x2_scaled_to_float2(q_h2, scale_h2); + const float2 candidate = e2m1x2_scaled_e4m3_to_float2(e2m1_byte, sf); const float original0 = __fmul_rn(x0, global_encode_scale); const float original1 = __fmul_rn(x1, global_encode_scale); const float diff0 = __fsub_rn(candidate.x, original0); @@ -260,8 +247,7 @@ __device__ __forceinline__ void accumulate_fp16_scaled_error_pair(const uint32_t template __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error( const float (&x)[8], const float block_scale_inverse, const nvfp4_scale_t sf, - const uint32_t fp16_error_scale, const float global_amax, const float global_encode_scale, - float *err) { + const float global_amax, const float global_encode_scale, float *err) { uint32_t out = 0; uint32_t out_dequant_1 = 0; uint32_t out_dequant_2 = 0; @@ -296,14 +282,13 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error( } if constexpr (Cfg::err_use_fast_math) { - accumulate_fp16_scaled_error_pair(out_dequant_1, x[0], x[1], fp16_error_scale, - global_encode_scale, err); - accumulate_fp16_scaled_error_pair(out_dequant_2, x[2], x[3], fp16_error_scale, - global_encode_scale, err); - accumulate_fp16_scaled_error_pair(out_dequant_3, x[4], x[5], fp16_error_scale, - global_encode_scale, err); - accumulate_fp16_scaled_error_pair(out_dequant_4, x[6], x[7], fp16_error_scale, - global_encode_scale, err); + accumulate_fp16_scaled_error_pair(out & 0xFFu, x[0], x[1], sf, global_encode_scale, err); + accumulate_fp16_scaled_error_pair((out >> 8) & 0xFFu, x[2], x[3], sf, global_encode_scale, + err); + accumulate_fp16_scaled_error_pair((out >> 16) & 0xFFu, x[4], x[5], sf, global_encode_scale, + err); + accumulate_fp16_scaled_error_pair((out >> 24) & 0xFFu, x[6], x[7], sf, global_encode_scale, + err); } else { const float sf_float = static_cast(sf); accumulate_dequant_error(out_dequant_1, x[0], sf_float, global_amax, err); @@ -325,22 +310,18 @@ __device__ __forceinline__ CandidatePair make_candidates(const float (&x0)[8], c CandidatePair candidates; candidates.map4.err = 0.0f; candidates.map6.err = 0.0f; - FP16ErrorScalePair fp16_error_scales{}; - if constexpr (Cfg::err_use_fast_math) { - fp16_error_scales = compute_fp16_error_scales(scales); - } candidates.map4.packed[0] = cvt_fp32_to_fp4_8x_with_error( - x0, scales.inv_map4, scales.map4, fp16_error_scales.map4, global_amax, - scales.global_encode_scale, &candidates.map4.err); + x0, scales.inv_map4, scales.map4, global_amax, scales.global_encode_scale, + &candidates.map4.err); candidates.map6.packed[0] = cvt_fp32_to_fp4_8x_with_error( - x0, scales.inv_map6, scales.map6, fp16_error_scales.map6, global_amax, - scales.global_encode_scale, &candidates.map6.err); + x0, scales.inv_map6, scales.map6, global_amax, scales.global_encode_scale, + &candidates.map6.err); candidates.map4.packed[1] = cvt_fp32_to_fp4_8x_with_error( - x1, scales.inv_map4, scales.map4, fp16_error_scales.map4, global_amax, - scales.global_encode_scale, &candidates.map4.err); + x1, scales.inv_map4, scales.map4, global_amax, scales.global_encode_scale, + &candidates.map4.err); candidates.map6.packed[1] = cvt_fp32_to_fp4_8x_with_error( - x1, scales.inv_map6, scales.map6, fp16_error_scales.map6, global_amax, - scales.global_encode_scale, &candidates.map6.err); + x1, scales.inv_map6, scales.map6, global_amax, scales.global_encode_scale, + &candidates.map6.err); return candidates; } From beaed676dd0a3742a4b7a119f466001f9c0a1ba3 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 2 Jun 2026 17:17:56 -0700 Subject: [PATCH 10/11] Preserve explicit common instruction lifting Signed-off-by: Ziang Li --- benchmarks/benchmark_4over6.py | 237 ++++++++++++++++++ .../linear/compare_nvfp4_4over6_selection.py | 111 ++++++++ .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 89 ++++--- 3 files changed, 402 insertions(+), 35 deletions(-) create mode 100644 benchmarks/benchmark_4over6.py create mode 100644 benchmarks/linear/compare_nvfp4_4over6_selection.py diff --git a/benchmarks/benchmark_4over6.py b/benchmarks/benchmark_4over6.py new file mode 100644 index 0000000000..48ba5cd403 --- /dev/null +++ b/benchmarks/benchmark_4over6.py @@ -0,0 +1,237 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Benchmark direct NVFP4 4over6 quantization kernel paths.""" + +import argparse +import os + +import torch +from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.quantization import check_fp8_block_scaling_support +import transformer_engine_torch as tex + + +BENCHMARK_SHAPES = [ + (8192, 5120), + (8192, 10240), + (8192, 2560), + (8192, 11328), + (8192, 512), + (8192, 3584), + (5120, 8192), + (10240, 8192), + (2560, 8192), + (11328, 8192), + (512, 8192), + (3584, 8192), + (4096, 16384), + (14336, 16384), +] +PROFILE_SHAPES = [(16384, 6144)] + + +# Nsight Compute profiling command: +# ncu -f -o nvfp4_4over6 --set=full --profile-from-start off --target-processes all \ +# --kernel-name "quantize_4over6_kernel" \ +# python3 benchmarks/benchmark_4over6.py --profile --profile-repeats 10 + + +def make_quantizer(use_2d_quantization: bool, use_4over6: bool, err_mode: str) -> NVFP4Quantizer: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=use_2d_quantization, + stochastic_rounding=False, + row_scaled_nvfp4=False, + nvfp4_use_4over6=use_4over6, + nvfp4_e4m3_max=448, + nvfp4_4over6_err_mode=err_mode, + with_random_sign_mask=True, + ) + + +def set_err_fast_math(enabled: bool) -> None: + os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = "1" if enabled else "0" + + +def benchmark_quantize( + shape: tuple[int, int], + use_2d_quantization: bool, + use_4over6: bool, + err_mode: str, + err_fast_math: bool, + warmup: int, + iters: int, +) -> float: + set_err_fast_math(err_fast_math) + x = torch.randn(shape, dtype=torch.bfloat16, device="cuda") + quantizer = make_quantizer(use_2d_quantization, use_4over6, err_mode) + out = quantizer.make_empty(shape, dtype=x.dtype, device=x.device, requires_grad=False) + + for _ in range(warmup): + quantizer.update_quantized(x, out) + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + quantizer.update_quantized(x, out) + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) * 1000.0 / iters + + +def iter_cases(shapes): + for shape in shapes: + for mode_name, use_2d_quantization in (("1d", False), ("2d", True)): + yield shape, mode_name, "nvfp4", "MAE", False, use_2d_quantization, False + for err_mode in ("MAE", "MSE"): + for err_fast_math in (False, True): + yield ( + shape, + mode_name, + "4over6", + err_mode, + err_fast_math, + use_2d_quantization, + True, + ) + + +def prepare_profile_case(case): + shape, mode_name, kernel, err_mode, err_fast_math, use_2d_quantization, use_4over6 = case + set_err_fast_math(err_fast_math) + x = torch.randn(shape, dtype=torch.bfloat16, device="cuda") + quantizer = make_quantizer(use_2d_quantization, use_4over6, err_mode) + out = quantizer.make_empty(shape, dtype=x.dtype, device=x.device, requires_grad=False) + quantizer.update_quantized(x, out) + torch.cuda.synchronize() + return { + "shape": shape, + "mode_name": mode_name, + "kernel": kernel, + "err_mode": err_mode, + "err_fast_math": err_fast_math, + "quantizer": quantizer, + "x": x, + "out": out, + } + + +def run_profile(profile_repeats: int) -> None: + cases = [prepare_profile_case(case) for case in iter_cases(PROFILE_SHAPES)] + torch.cuda.synchronize() + torch.cuda.cudart().cudaProfilerStart() + for case in cases: + set_err_fast_math(case["err_fast_math"]) + label = ( + f"shape={case['shape']} mode={case['mode_name']} kernel={case['kernel']} " + f"err={case['err_mode']} err_fast={case['err_fast_math']}" + ) + print(f"PROFILE {label}", flush=True) + torch.cuda.nvtx.range_push(label) + for _ in range(profile_repeats): + case["quantizer"].update_quantized(case["x"], case["out"]) + torch.cuda.nvtx.range_pop() + torch.cuda.synchronize() + torch.cuda.cudart().cudaProfilerStop() + + +def run_benchmark(shapes, warmup: int, iters: int) -> None: + rows = [] + for shape in shapes: + for mode_name, use_2d_quantization in (("1d", False), ("2d", True)): + baseline_us = benchmark_quantize( + shape=shape, + use_2d_quantization=use_2d_quantization, + use_4over6=False, + err_mode="MAE", + err_fast_math=False, + warmup=warmup, + iters=iters, + ) + rows.append((shape, mode_name, "nvfp4", "-", baseline_us, 1.0, None, None)) + + for err_mode in ("MAE", "MSE"): + strict_us = benchmark_quantize( + shape=shape, + use_2d_quantization=use_2d_quantization, + use_4over6=True, + err_mode=err_mode, + err_fast_math=False, + warmup=warmup, + iters=iters, + ) + fast_us = benchmark_quantize( + shape=shape, + use_2d_quantization=use_2d_quantization, + use_4over6=True, + err_mode=err_mode, + err_fast_math=True, + warmup=warmup, + iters=iters, + ) + rows.append( + ( + shape, + mode_name, + "4over6", + err_mode, + strict_us, + strict_us / baseline_us, + fast_us, + fast_us / baseline_us, + ) + ) + + print( + f"{'shape':>18} {'mode':>4} {'kernel':>7} {'err':>3} " + f"{'strict_us':>10} {'strict':>8} {'fast_us':>10} {'fast':>8}" + ) + for ( + shape, + mode_name, + kernel, + err_mode, + strict_us, + strict_slowdown, + fast_us, + fast_slowdown, + ) in rows: + fast_us_str = "-" if fast_us is None else f"{fast_us:10.3f}" + fast_slowdown_str = "-" if fast_slowdown is None else f"{fast_slowdown:8.3f}x" + print( + f"{str(shape):>18} {mode_name:>4} {kernel:>7} {err_mode:>3} " + f"{strict_us:10.3f} {strict_slowdown:8.3f}x " + f"{fast_us_str:>10} {fast_slowdown_str:>8}" + ) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--profile", action="store_true", help="Enable Nsight Compute profile mode") + parser.add_argument("--profile-repeats", default=1, type=int) + parser.add_argument("--shapes", choices=("profile", "all"), default="profile") + parser.add_argument("--warmup", default=20, type=int) + parser.add_argument("--iters", default=1000, type=int) + args = parser.parse_args() + + supported, reason = check_fp8_block_scaling_support() + assert supported, reason + shapes = PROFILE_SHAPES if args.shapes == "profile" else BENCHMARK_SHAPES + if args.profile: + run_profile(args.profile_repeats) + else: + run_benchmark(shapes, args.warmup, args.iters) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/linear/compare_nvfp4_4over6_selection.py b/benchmarks/linear/compare_nvfp4_4over6_selection.py new file mode 100644 index 0000000000..7f73dfce9e --- /dev/null +++ b/benchmarks/linear/compare_nvfp4_4over6_selection.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Compare NVFP4 4over6 E4M3 scales with and without error fast math.""" + +import os +from contextlib import contextmanager + +import torch +from transformer_engine.pytorch import NVFP4Quantizer +import transformer_engine_torch as tex + + +M, K = 98304, 7168 + + +@contextmanager +def _error_fast_math(enabled: bool): + old_value = os.environ.get("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH") + os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = "1" if enabled else "0" + try: + yield + finally: + if old_value is None: + os.environ.pop("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH", None) + else: + os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = old_value + + +def _quantize_scale_bytes( + x: torch.Tensor, + err_mode: str, + err_fast_math: bool, + row_scaled: bool, + with_2d_quantization: bool, + nvfp4_e4m3_max: int, +) -> torch.Tensor: + quantizer = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=False, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=with_2d_quantization, + row_scaled_nvfp4=row_scaled, + nvfp4_use_4over6=True, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=err_mode, + ) + with _error_fast_math(err_fast_math): + quantized = quantizer(x) + assert quantized._rowwise_scale_inv is not None + return quantized._rowwise_scale_inv.contiguous().view(torch.uint8) + + +def _compare_e4m3( + x: torch.Tensor, + dtype_name: str, + scale_mode: str, + row_scaled: bool, + quant_mode: str, + with_2d_quantization: bool, + nvfp4_e4m3_max: int, +) -> None: + for err_mode in ("MAE", "MSE"): + regular = _quantize_scale_bytes( + x, err_mode, False, row_scaled, with_2d_quantization, nvfp4_e4m3_max + ) + fast = _quantize_scale_bytes( + x, err_mode, True, row_scaled, with_2d_quantization, nvfp4_e4m3_max + ) + same = torch.count_nonzero(regular == fast).item() + total = regular.numel() + print( + f"{scale_mode:>6} {quant_mode:>5} {nvfp4_e4m3_max:8d} " + f"{dtype_name:>5} {err_mode:>3} " + f"{100.0 * same / total:12.6f} {total - same:15d} {total}" + ) + + +def main(): + torch.set_grad_enabled(False) + print(f"shape=({M}, {K}), 1d_e4m3_values={M * K // 16}") + print("scale quant e4m3_max dtype mode same_e4m3_pct different_e4m3 total_e4m3") + for scale_mode, row_scaled, quant_mode, with_2d_quantization in ( + ("tensor", False, "1d", False), + ("tensor", False, "2d", True), + ("row", True, "1d", False), + ): + for nvfp4_e4m3_max in (256, 448): + for dtype, dtype_name in ((torch.bfloat16, "bf16"), (torch.float16, "fp16")): + torch.manual_seed(1234) + x = torch.randn((M, K), dtype=dtype, device="cuda") + _compare_e4m3( + x, + dtype_name, + scale_mode, + row_scaled, + quant_mode, + with_2d_quantization, + nvfp4_e4m3_max, + ) + del x + torch.cuda.empty_cache() + + +if __name__ == "__main__": + main() diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index e5b53207f1..50776a3ed6 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -105,6 +105,11 @@ struct ScalePair { float global_encode_scale; }; +struct FP16ErrorScalePair { + uint32_t map4; + uint32_t map6; +}; + template __device__ __forceinline__ float compute_error_rn(const float diff) { if constexpr (kMode == kNVTENVFP44Over6MinMSE) { @@ -200,42 +205,50 @@ __device__ __forceinline__ uint8_t fp8_bits(const nvfp4_scale_t sf) { return *reinterpret_cast(&sf); } -__device__ __forceinline__ float2 e2m1x2_scaled_e4m3_to_float2(const uint32_t e2m1_byte, - const nvfp4_scale_t sf) { - float2 result; - const uint32_t sf_byte = static_cast(fp8_bits(sf)); +__device__ __forceinline__ FP16ErrorScalePair compute_fp16_error_scales(const ScalePair &scales) { + FP16ErrorScalePair result; + const uint32_t packed_scales = static_cast(fp8_bits(scales.map4)) | + (static_cast(fp8_bits(scales.map6)) << 8); asm volatile( "{\n" - ".reg .b8 byte0, byte1, byte2, byte3;\n" ".reg .b16 fp8_pair;\n" - ".reg .b16 scale_h, unused_h;\n" - ".reg .b16 lo, hi;\n" - ".reg .b32 q_h2;\n" + ".reg .b16 map4_h, map6_h;\n" ".reg .b32 scale_h2;\n" - ".reg .b32 prod_h2;\n" - "mov.b32 {byte0, byte1, byte2, byte3}, %2;\n" - "cvt.rn.f16x2.e2m1x2 q_h2, byte0;\n" - "cvt.u16.u32 fp8_pair, %3;\n" + "cvt.u16.u32 fp8_pair, %2;\n" "cvt.rn.f16x2.e4m3x2 scale_h2, fp8_pair;\n" - "mov.b32 {scale_h, unused_h}, scale_h2;\n" - "mov.b32 scale_h2, {scale_h, scale_h};\n" - "mul.rn.f16x2 prod_h2, q_h2, scale_h2;\n" + "mov.b32 {map4_h, map6_h}, scale_h2;\n" + "mov.b32 %0, {map4_h, map4_h};\n" + "mov.b32 %1, {map6_h, map6_h};\n" + "}" + : "=r"(result.map4), "=r"(result.map6) + : "r"(packed_scales)); + return result; +} + +__device__ __forceinline__ float2 f16x2_scaled_to_float2(const uint32_t q_h2, + const uint32_t scale_h2) { + float2 result; + asm volatile( + "{\n" + ".reg .b16 lo, hi;\n" + ".reg .b32 prod_h2;\n" + "mul.rn.f16x2 prod_h2, %2, %3;\n" "mov.b32 {lo, hi}, prod_h2;\n" "cvt.f32.f16 %0, lo;\n" "cvt.f32.f16 %1, hi;\n" "}" : "=f"(result.x), "=f"(result.y) - : "r"(e2m1_byte), "r"(sf_byte)); + : "r"(q_h2), "r"(scale_h2)); return result; } template -__device__ __forceinline__ void accumulate_fp16_scaled_error_pair(const uint32_t e2m1_byte, +__device__ __forceinline__ void accumulate_fp16_scaled_error_pair(const uint32_t q_h2, const float x0, const float x1, - const nvfp4_scale_t sf, + const uint32_t scale_h2, const float global_encode_scale, float *err) { - const float2 candidate = e2m1x2_scaled_e4m3_to_float2(e2m1_byte, sf); + const float2 candidate = f16x2_scaled_to_float2(q_h2, scale_h2); const float original0 = __fmul_rn(x0, global_encode_scale); const float original1 = __fmul_rn(x1, global_encode_scale); const float diff0 = __fsub_rn(candidate.x, original0); @@ -247,7 +260,8 @@ __device__ __forceinline__ void accumulate_fp16_scaled_error_pair(const uint32_t template __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error( const float (&x)[8], const float block_scale_inverse, const nvfp4_scale_t sf, - const float global_amax, const float global_encode_scale, float *err) { + const uint32_t fp16_error_scale, const float global_amax, const float global_encode_scale, + float *err) { uint32_t out = 0; uint32_t out_dequant_1 = 0; uint32_t out_dequant_2 = 0; @@ -282,13 +296,14 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error( } if constexpr (Cfg::err_use_fast_math) { - accumulate_fp16_scaled_error_pair(out & 0xFFu, x[0], x[1], sf, global_encode_scale, err); - accumulate_fp16_scaled_error_pair((out >> 8) & 0xFFu, x[2], x[3], sf, global_encode_scale, - err); - accumulate_fp16_scaled_error_pair((out >> 16) & 0xFFu, x[4], x[5], sf, global_encode_scale, - err); - accumulate_fp16_scaled_error_pair((out >> 24) & 0xFFu, x[6], x[7], sf, global_encode_scale, - err); + accumulate_fp16_scaled_error_pair(out_dequant_1, x[0], x[1], fp16_error_scale, + global_encode_scale, err); + accumulate_fp16_scaled_error_pair(out_dequant_2, x[2], x[3], fp16_error_scale, + global_encode_scale, err); + accumulate_fp16_scaled_error_pair(out_dequant_3, x[4], x[5], fp16_error_scale, + global_encode_scale, err); + accumulate_fp16_scaled_error_pair(out_dequant_4, x[6], x[7], fp16_error_scale, + global_encode_scale, err); } else { const float sf_float = static_cast(sf); accumulate_dequant_error(out_dequant_1, x[0], sf_float, global_amax, err); @@ -310,18 +325,22 @@ __device__ __forceinline__ CandidatePair make_candidates(const float (&x0)[8], c CandidatePair candidates; candidates.map4.err = 0.0f; candidates.map6.err = 0.0f; + FP16ErrorScalePair fp16_error_scales{}; + if constexpr (Cfg::err_use_fast_math) { + fp16_error_scales = compute_fp16_error_scales(scales); + } candidates.map4.packed[0] = cvt_fp32_to_fp4_8x_with_error( - x0, scales.inv_map4, scales.map4, global_amax, scales.global_encode_scale, - &candidates.map4.err); + x0, scales.inv_map4, scales.map4, fp16_error_scales.map4, global_amax, + scales.global_encode_scale, &candidates.map4.err); candidates.map6.packed[0] = cvt_fp32_to_fp4_8x_with_error( - x0, scales.inv_map6, scales.map6, global_amax, scales.global_encode_scale, - &candidates.map6.err); + x0, scales.inv_map6, scales.map6, fp16_error_scales.map6, global_amax, + scales.global_encode_scale, &candidates.map6.err); candidates.map4.packed[1] = cvt_fp32_to_fp4_8x_with_error( - x1, scales.inv_map4, scales.map4, global_amax, scales.global_encode_scale, - &candidates.map4.err); + x1, scales.inv_map4, scales.map4, fp16_error_scales.map4, global_amax, + scales.global_encode_scale, &candidates.map4.err); candidates.map6.packed[1] = cvt_fp32_to_fp4_8x_with_error( - x1, scales.inv_map6, scales.map6, global_amax, scales.global_encode_scale, - &candidates.map6.err); + x1, scales.inv_map6, scales.map6, fp16_error_scales.map6, global_amax, + scales.global_encode_scale, &candidates.map6.err); return candidates; } From a7d1a6ce3a915e11f82abae1b4cd8854097b1161 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 2 Jun 2026 20:18:58 -0700 Subject: [PATCH 11/11] Drop scripts Signed-off-by: Ziang Li --- benchmarks/benchmark_4over6.py | 237 ------------------ .../linear/compare_nvfp4_4over6_selection.py | 111 -------- 2 files changed, 348 deletions(-) delete mode 100644 benchmarks/benchmark_4over6.py delete mode 100644 benchmarks/linear/compare_nvfp4_4over6_selection.py diff --git a/benchmarks/benchmark_4over6.py b/benchmarks/benchmark_4over6.py deleted file mode 100644 index 48ba5cd403..0000000000 --- a/benchmarks/benchmark_4over6.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Benchmark direct NVFP4 4over6 quantization kernel paths.""" - -import argparse -import os - -import torch -from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.quantization import check_fp8_block_scaling_support -import transformer_engine_torch as tex - - -BENCHMARK_SHAPES = [ - (8192, 5120), - (8192, 10240), - (8192, 2560), - (8192, 11328), - (8192, 512), - (8192, 3584), - (5120, 8192), - (10240, 8192), - (2560, 8192), - (11328, 8192), - (512, 8192), - (3584, 8192), - (4096, 16384), - (14336, 16384), -] -PROFILE_SHAPES = [(16384, 6144)] - - -# Nsight Compute profiling command: -# ncu -f -o nvfp4_4over6 --set=full --profile-from-start off --target-processes all \ -# --kernel-name "quantize_4over6_kernel" \ -# python3 benchmarks/benchmark_4over6.py --profile --profile-repeats 10 - - -def make_quantizer(use_2d_quantization: bool, use_4over6: bool, err_mode: str) -> NVFP4Quantizer: - return NVFP4Quantizer( - fp4_dtype=tex.DType.kFloat4E2M1, - rowwise=True, - columnwise=True, - with_amax_reduction=False, - amax_reduction_group=None, - with_rht=False, - with_post_rht_amax=False, - with_2d_quantization=use_2d_quantization, - stochastic_rounding=False, - row_scaled_nvfp4=False, - nvfp4_use_4over6=use_4over6, - nvfp4_e4m3_max=448, - nvfp4_4over6_err_mode=err_mode, - with_random_sign_mask=True, - ) - - -def set_err_fast_math(enabled: bool) -> None: - os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = "1" if enabled else "0" - - -def benchmark_quantize( - shape: tuple[int, int], - use_2d_quantization: bool, - use_4over6: bool, - err_mode: str, - err_fast_math: bool, - warmup: int, - iters: int, -) -> float: - set_err_fast_math(err_fast_math) - x = torch.randn(shape, dtype=torch.bfloat16, device="cuda") - quantizer = make_quantizer(use_2d_quantization, use_4over6, err_mode) - out = quantizer.make_empty(shape, dtype=x.dtype, device=x.device, requires_grad=False) - - for _ in range(warmup): - quantizer.update_quantized(x, out) - torch.cuda.synchronize() - - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - for _ in range(iters): - quantizer.update_quantized(x, out) - end.record() - torch.cuda.synchronize() - return start.elapsed_time(end) * 1000.0 / iters - - -def iter_cases(shapes): - for shape in shapes: - for mode_name, use_2d_quantization in (("1d", False), ("2d", True)): - yield shape, mode_name, "nvfp4", "MAE", False, use_2d_quantization, False - for err_mode in ("MAE", "MSE"): - for err_fast_math in (False, True): - yield ( - shape, - mode_name, - "4over6", - err_mode, - err_fast_math, - use_2d_quantization, - True, - ) - - -def prepare_profile_case(case): - shape, mode_name, kernel, err_mode, err_fast_math, use_2d_quantization, use_4over6 = case - set_err_fast_math(err_fast_math) - x = torch.randn(shape, dtype=torch.bfloat16, device="cuda") - quantizer = make_quantizer(use_2d_quantization, use_4over6, err_mode) - out = quantizer.make_empty(shape, dtype=x.dtype, device=x.device, requires_grad=False) - quantizer.update_quantized(x, out) - torch.cuda.synchronize() - return { - "shape": shape, - "mode_name": mode_name, - "kernel": kernel, - "err_mode": err_mode, - "err_fast_math": err_fast_math, - "quantizer": quantizer, - "x": x, - "out": out, - } - - -def run_profile(profile_repeats: int) -> None: - cases = [prepare_profile_case(case) for case in iter_cases(PROFILE_SHAPES)] - torch.cuda.synchronize() - torch.cuda.cudart().cudaProfilerStart() - for case in cases: - set_err_fast_math(case["err_fast_math"]) - label = ( - f"shape={case['shape']} mode={case['mode_name']} kernel={case['kernel']} " - f"err={case['err_mode']} err_fast={case['err_fast_math']}" - ) - print(f"PROFILE {label}", flush=True) - torch.cuda.nvtx.range_push(label) - for _ in range(profile_repeats): - case["quantizer"].update_quantized(case["x"], case["out"]) - torch.cuda.nvtx.range_pop() - torch.cuda.synchronize() - torch.cuda.cudart().cudaProfilerStop() - - -def run_benchmark(shapes, warmup: int, iters: int) -> None: - rows = [] - for shape in shapes: - for mode_name, use_2d_quantization in (("1d", False), ("2d", True)): - baseline_us = benchmark_quantize( - shape=shape, - use_2d_quantization=use_2d_quantization, - use_4over6=False, - err_mode="MAE", - err_fast_math=False, - warmup=warmup, - iters=iters, - ) - rows.append((shape, mode_name, "nvfp4", "-", baseline_us, 1.0, None, None)) - - for err_mode in ("MAE", "MSE"): - strict_us = benchmark_quantize( - shape=shape, - use_2d_quantization=use_2d_quantization, - use_4over6=True, - err_mode=err_mode, - err_fast_math=False, - warmup=warmup, - iters=iters, - ) - fast_us = benchmark_quantize( - shape=shape, - use_2d_quantization=use_2d_quantization, - use_4over6=True, - err_mode=err_mode, - err_fast_math=True, - warmup=warmup, - iters=iters, - ) - rows.append( - ( - shape, - mode_name, - "4over6", - err_mode, - strict_us, - strict_us / baseline_us, - fast_us, - fast_us / baseline_us, - ) - ) - - print( - f"{'shape':>18} {'mode':>4} {'kernel':>7} {'err':>3} " - f"{'strict_us':>10} {'strict':>8} {'fast_us':>10} {'fast':>8}" - ) - for ( - shape, - mode_name, - kernel, - err_mode, - strict_us, - strict_slowdown, - fast_us, - fast_slowdown, - ) in rows: - fast_us_str = "-" if fast_us is None else f"{fast_us:10.3f}" - fast_slowdown_str = "-" if fast_slowdown is None else f"{fast_slowdown:8.3f}x" - print( - f"{str(shape):>18} {mode_name:>4} {kernel:>7} {err_mode:>3} " - f"{strict_us:10.3f} {strict_slowdown:8.3f}x " - f"{fast_us_str:>10} {fast_slowdown_str:>8}" - ) - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument("--profile", action="store_true", help="Enable Nsight Compute profile mode") - parser.add_argument("--profile-repeats", default=1, type=int) - parser.add_argument("--shapes", choices=("profile", "all"), default="profile") - parser.add_argument("--warmup", default=20, type=int) - parser.add_argument("--iters", default=1000, type=int) - args = parser.parse_args() - - supported, reason = check_fp8_block_scaling_support() - assert supported, reason - shapes = PROFILE_SHAPES if args.shapes == "profile" else BENCHMARK_SHAPES - if args.profile: - run_profile(args.profile_repeats) - else: - run_benchmark(shapes, args.warmup, args.iters) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/linear/compare_nvfp4_4over6_selection.py b/benchmarks/linear/compare_nvfp4_4over6_selection.py deleted file mode 100644 index 7f73dfce9e..0000000000 --- a/benchmarks/linear/compare_nvfp4_4over6_selection.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Compare NVFP4 4over6 E4M3 scales with and without error fast math.""" - -import os -from contextlib import contextmanager - -import torch -from transformer_engine.pytorch import NVFP4Quantizer -import transformer_engine_torch as tex - - -M, K = 98304, 7168 - - -@contextmanager -def _error_fast_math(enabled: bool): - old_value = os.environ.get("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH") - os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = "1" if enabled else "0" - try: - yield - finally: - if old_value is None: - os.environ.pop("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH", None) - else: - os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = old_value - - -def _quantize_scale_bytes( - x: torch.Tensor, - err_mode: str, - err_fast_math: bool, - row_scaled: bool, - with_2d_quantization: bool, - nvfp4_e4m3_max: int, -) -> torch.Tensor: - quantizer = NVFP4Quantizer( - fp4_dtype=tex.DType.kFloat4E2M1, - rowwise=True, - columnwise=False, - with_amax_reduction=False, - amax_reduction_group=None, - with_rht=False, - with_post_rht_amax=False, - with_2d_quantization=with_2d_quantization, - row_scaled_nvfp4=row_scaled, - nvfp4_use_4over6=True, - nvfp4_e4m3_max=nvfp4_e4m3_max, - nvfp4_4over6_err_mode=err_mode, - ) - with _error_fast_math(err_fast_math): - quantized = quantizer(x) - assert quantized._rowwise_scale_inv is not None - return quantized._rowwise_scale_inv.contiguous().view(torch.uint8) - - -def _compare_e4m3( - x: torch.Tensor, - dtype_name: str, - scale_mode: str, - row_scaled: bool, - quant_mode: str, - with_2d_quantization: bool, - nvfp4_e4m3_max: int, -) -> None: - for err_mode in ("MAE", "MSE"): - regular = _quantize_scale_bytes( - x, err_mode, False, row_scaled, with_2d_quantization, nvfp4_e4m3_max - ) - fast = _quantize_scale_bytes( - x, err_mode, True, row_scaled, with_2d_quantization, nvfp4_e4m3_max - ) - same = torch.count_nonzero(regular == fast).item() - total = regular.numel() - print( - f"{scale_mode:>6} {quant_mode:>5} {nvfp4_e4m3_max:8d} " - f"{dtype_name:>5} {err_mode:>3} " - f"{100.0 * same / total:12.6f} {total - same:15d} {total}" - ) - - -def main(): - torch.set_grad_enabled(False) - print(f"shape=({M}, {K}), 1d_e4m3_values={M * K // 16}") - print("scale quant e4m3_max dtype mode same_e4m3_pct different_e4m3 total_e4m3") - for scale_mode, row_scaled, quant_mode, with_2d_quantization in ( - ("tensor", False, "1d", False), - ("tensor", False, "2d", True), - ("row", True, "1d", False), - ): - for nvfp4_e4m3_max in (256, 448): - for dtype, dtype_name in ((torch.bfloat16, "bf16"), (torch.float16, "fp16")): - torch.manual_seed(1234) - x = torch.randn((M, K), dtype=dtype, device="cuda") - _compare_e4m3( - x, - dtype_name, - scale_mode, - row_scaled, - quant_mode, - with_2d_quantization, - nvfp4_e4m3_max, - ) - del x - torch.cuda.empty_cache() - - -if __name__ == "__main__": - main()