From e7a4db99c68258d3a9b150311ce000a60afa5548 Mon Sep 17 00:00:00 2001 From: Min Yang Date: Wed, 20 May 2026 00:06:02 +0800 Subject: [PATCH] feat: add varlen k grouped gemm support Signed-off-by: Min Yang --- .../common/gemm/cublaslt_gemm.cu | 43 ++++ .../common/gemm/cutlass_grouped_gemm.cu | 102 +++++++++ .../common/gemm/cutlass_grouped_gemm.cuh | 214 ++++++++++++++++++ 3 files changed, 359 insertions(+) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index e59e9c00c9..93a3998faa 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1110,6 +1110,42 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor ((A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F)); }; + auto is_bf16_wgrad_dtype = [&]() -> bool { + auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); + auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); + auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); + auto A_type = get_cuda_dtype(inputA->data.dtype); + auto B_type = get_cuda_dtype(inputB->data.dtype); + auto D_type = get_cuda_dtype(OutputD->data.dtype); + + return (A_type == CUDA_R_16BF) && (B_type == CUDA_R_16BF) && + (D_type == CUDA_R_32F || D_type == CUDA_R_16BF); + }; + + // K-grouped BF16 wgrad shape eligibility: every group must be 2D NT with a matching + // (ragged) K and a uniform hidden/expert. Shapes outside this fall back to cuBLAS + // instead of hard-erroring inside the varlen-k kernel. + auto is_bf16_wgrad_shape = [&]() -> bool { + int64_t ref_hidden = -1, ref_expert = -1; + for (size_t i = 0; i < num_gemms; i++) { + const auto *inp = transformer_engine::convertNVTETensorCheck(A[i]); + const auto *grad = transformer_engine::convertNVTETensorCheck(B[i]); + if (inp->data.shape.size() != 2 || grad->data.shape.size() != 2) return false; + const int64_t k = inp->data.shape[0]; + const int64_t hidden = inp->data.shape[1]; + const int64_t expert = grad->data.shape[1]; + if (static_cast(grad->data.shape[0]) != k || hidden <= 0 || expert <= 0) + return false; + if (ref_hidden < 0) { + ref_hidden = hidden; + ref_expert = expert; + } else if (hidden != ref_hidden || expert != ref_expert) { + return false; + } + } + return true; + }; + // CUTLASS Grouped GEMM fast path (SM90/TMA) // Conditions: // - No fused epilogue: both bias and pre_gelu_out are empty. @@ -1123,6 +1159,13 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor all_groups_uniform_k128(B, transb)) { cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate, current_device, math_sm_count, stream); + } else if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_bf16_wgrad_dtype() && !transa && + transb && grad && is_bf16_wgrad_shape()) { + // Dedicated K-grouped (ragged-K) BF16-in / (FP32 or BF16)-out wgrad path: + // D_i = B_i.T @ A_i, K_i = routed-token dim. Shape eligibility is guarded above, so + // unsupported shapes fall back to cuBLAS rather than hard-erroring in the kernel. + cutlass_grouped_gemm_varlen_k(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate, + current_device, math_sm_count, stream); } else { if (warn_fallback) { NVTE_WARN("Fallback to cuBLAS grouped GEMM."); diff --git a/transformer_engine/common/gemm/cutlass_grouped_gemm.cu b/transformer_engine/common/gemm/cutlass_grouped_gemm.cu index ef720d1984..d04eac98e6 100644 --- a/transformer_engine/common/gemm/cutlass_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cutlass_grouped_gemm.cu @@ -4,6 +4,12 @@ * See LICENSE for license information. **************************************************************************************************/ +#include +#include + +#include +#include + #include "cutlass/bfloat16.h" #include "cutlass/cutlass.h" #include "cutlass_grouped_gemm.cuh" @@ -36,6 +42,18 @@ template void CutlassGroupedGemm(const NVTETen NVTETensor*, float, float, int, cudaStream_t, int, int); +// Explicit instantiation: BF16-in / FP32-out (default) wgrad path. +template void CutlassGroupedGemmWgrad(const NVTETensor*, const NVTETensor*, + NVTETensor*, NVTETensor*, float, float, + int, cudaStream_t, int, int); + +// Explicit instantiation: BF16-in / BF16-out wgrad path. +template void CutlassGroupedGemmWgrad(const NVTETensor*, + const NVTETensor*, + NVTETensor*, NVTETensor*, + float, float, int, + cudaStream_t, int, int); + } // namespace grouped_gemm } // namespace transformer_engine @@ -75,3 +93,87 @@ void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* NVTE_ERROR("Unsupported dtype: only BF16(FP16) are supported."); } } + +namespace { + +// Zero-initialize empty (K=0) groups (when not accumulating) and forward the non-empty groups to +// CUTLASS. Precondition: the dispatcher (nvte_multi_tensor_gemm) has already validated the BF16 NT +// wgrad contract -- 2D, matching ragged K, uniform hidden/expert, BF16-in / (FP32|BF16)-out -- via +// is_bf16_wgrad_dtype() + is_bf16_wgrad_shape(), so it is not re-checked here. +void collect_bf16_wgrad_nt_groups(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, + int num_gemms, bool accumulate, cudaStream_t stream, + std::vector* A_nz, std::vector* B_nz, + std::vector* D_nz, + transformer_engine::DType* out_dtype) { + using namespace transformer_engine; + // hidden/expert/output-dtype are uniform across groups; read them once from group 0. + const int64_t hidden = convertNVTETensorCheck(A[0])->data.shape[1]; + const int64_t expert = convertNVTETensorCheck(B[0])->data.shape[1]; + *out_dtype = convertNVTETensorCheck(D[0])->data.dtype; + const size_t elem = (*out_dtype == DType::kFloat32) ? sizeof(float) : sizeof(__nv_bfloat16); + + for (int i = 0; i < num_gemms; ++i) { + if (convertNVTETensorCheck(A[i])->data.shape[0] == 0) { + // Empty group: its null A/B pointers would crash TMA descriptor construction, so zero the + // output (when not accumulating) and exclude it from the launch. + auto* out = convertNVTETensorCheck(D[i]); + if (!accumulate && out->data.dptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemsetAsync(out->data.dptr, 0, + static_cast(expert) * hidden * elem, stream)); + } + } else { + A_nz->push_back(A[i]); + B_nz->push_back(B[i]); + D_nz->push_back(D[i]); + } + } +} + +} // namespace + +void cutlass_grouped_gemm_varlen_k(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, + int num_gemms, bool transa, bool transb, bool grad, + NVTETensor* workspace, bool accumulate, int device, + int math_sm_count, cudaStream_t stream) { + using namespace transformer_engine; + // The kernel hard-codes the NT layout, so assert it: a wrong-layout caller would otherwise + // mis-compute silently. (Arch / no-epilogue / group-0 dtype are already gated by the + // dispatcher, and a wrong arch would fail loudly inside the CUTLASS kernel anyway.) + NVTE_CHECK(!transa && transb && grad, + "cutlass_grouped_gemm_varlen_k requires NT wgrad layout " + "(transa=false, transb=true, grad=true)."); + NVTE_CHECK(workspace != nullptr, "cutlass_grouped_gemm_varlen_k requires a non-null workspace."); + + std::vector A_nz, B_nz, D_nz; + A_nz.reserve(num_gemms); + B_nz.reserve(num_gemms); + D_nz.reserve(num_gemms); + DType out_dtype = DType::kFloat32; + collect_bf16_wgrad_nt_groups(A, B, D, num_gemms, accumulate, stream, &A_nz, &B_nz, &D_nz, + &out_dtype); + + // All groups have K=0: outputs are already zero-initialized above, nothing to launch. + if (A_nz.empty()) return; + + const int n_nz = static_cast(A_nz.size()); + float one = 1.0; + float zero = 0.0; + float alpha = one; + float beta = (accumulate) ? one : zero; + + // NT wgrad: D_i = B_i^T @ A_i. Pass grad_output (outer B) as CUTLASS A (trans_a=true) + // and input (outer A) as CUTLASS B (trans_b=false). CutlassGroupedGemmWgrad validates + // the workspace size internally. + auto dispatch = [&](auto tag) { + using T = decltype(tag); + grouped_gemm::CutlassGroupedGemmWgrad(B_nz.data(), A_nz.data(), D_nz.data(), + workspace, alpha, beta, n_nz, stream, + device, math_sm_count); + }; + + if (out_dtype == DType::kFloat32) { + dispatch(float{}); + } else { + dispatch(cutlass::bfloat16_t{}); + } +} diff --git a/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh index aa2bde4203..c64e325d87 100644 --- a/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh +++ b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh @@ -340,9 +340,223 @@ void CutlassGroupedGemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, } } +template +struct GemmGivenScheduleWgrad; + +// Base config shared by both FP32 and BF16 output specialisations. +// Subclasses override TileShape / ClusterShape / KernelSchedule / EpilogueSchedule. +template +struct GemmGivenScheduleWgradBase { + using ElementA = cutlass::bfloat16_t; + using ElementB = cutlass::bfloat16_t; + using ElementC = ElementD; + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using LayoutA = GroupedGemmInputALayout; + using LayoutB = GroupedGemmInputBLayout; + using LayoutC = cutlass::layout::RowMajor; + // TMA minimum 16 B: 8×BF16 or 4×FP32. + static constexpr int AlignmentA = 8; + static constexpr int AlignmentB = 8; + static constexpr int AlignmentC = static_cast(16 / sizeof(ElementD)); +}; + +// FP32 output: Cooperative 128×128×64, ClusterShape 1×1×1. +// Two warpgroups keep both the MMA pipeline and the FP32 epilogue busy. +template +struct GemmGivenScheduleWgrad + : GemmGivenScheduleWgradBase { + using Base = GemmGivenScheduleWgradBase; + using ElementD = float; + using ElementC = float; + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using LayoutA = typename Base::LayoutA; + using LayoutB = typename Base::LayoutB; + using LayoutC = typename Base::LayoutC; + static constexpr int AlignmentA = Base::AlignmentA; + static constexpr int AlignmentB = Base::AlignmentB; + static constexpr int AlignmentC = Base::AlignmentC; + + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, + ElementC, LayoutC*, AlignmentC, ElementD, LayoutC*, AlignmentC, EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, typename Base::ElementA, LayoutA*, AlignmentA, + typename Base::ElementB, LayoutB*, AlignmentB, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +// BF16-output specialization: TileShape 128x128x128, ClusterShape 1x2x1, Ptr-Array TMA +// warp-specialized Pingpong schedule (SM90). The 8-element (kWgradMinAlign) alignment on the +// expert/hidden dims is validated before launch; any remaining tile/shape constraints are +// enforced by the kernel's can_implement check inside CutlassGroupedGemmWgrad. +template +struct GemmGivenScheduleWgrad + : GemmGivenScheduleWgradBase { + using Base = GemmGivenScheduleWgradBase; + using ElementD = cutlass::bfloat16_t; + using ElementC = cutlass::bfloat16_t; + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using LayoutA = typename Base::LayoutA; + using LayoutB = typename Base::LayoutB; + using LayoutC = typename Base::LayoutC; + static constexpr int AlignmentA = Base::AlignmentA; + static constexpr int AlignmentB = Base::AlignmentB; + static constexpr int AlignmentC = Base::AlignmentC; + + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, + ElementC, LayoutC*, AlignmentC, ElementD, LayoutC*, AlignmentC, EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, typename Base::ElementA, LayoutA*, AlignmentA, + typename Base::ElementB, LayoutB*, AlignmentB, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +using GemmGroupedWgrad = typename GemmGivenScheduleWgrad::Gemm; + +template +void CutlassGroupedGemmWgrad(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, + NVTETensor* workspace, float alpha, float beta, int num_gemms, + cudaStream_t stream, int device, int math_sm_count) { + using Config = GemmGivenScheduleWgrad; + using Gemm = GemmGroupedWgrad; + using LayoutA = typename Config::LayoutA; + using LayoutB = typename Config::LayoutB; + using LayoutC = typename Config::LayoutC; + using ElementA = typename Config::ElementA; + using ElementB = typename Config::ElementB; + using ElementC = typename Config::ElementC; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + + typename Gemm::Arguments arguments; + const size_t kernel_workspace_size = Gemm::get_workspace_size(arguments); + const auto gemm_coord_size = getGemmCoordSize(num_gemms); + const auto ptr_size = getPtrSize(num_gemms); + const auto ldd_size = getLddSize(num_gemms); + const auto param_workspace_size = 3 * ptr_size + 3 * ldd_size + gemm_coord_size; + + NVTE_CHECK(param_workspace_size < kCPUWorkSpaceSize, + "Insufficient kCPUWorkSpaceSize for wgrad grouped GEMM: required=", + static_cast(param_workspace_size)); + + const auto total_workspace_size = param_workspace_size + kernel_workspace_size; + transformer_engine::Tensor* wspace = transformer_engine::convertNVTETensor(workspace[0]); + + NVTE_CHECK(total_workspace_size < wspace->numel(), + "Insufficient workspace[0] for wgrad grouped GEMM: required=", + static_cast(total_workspace_size), + ", available=", static_cast(wspace->numel())); + + char* workspace_ptr = reinterpret_cast(wspace->data.dptr); + char* host_workspace = getHostWorkspace(); + + auto* problem_sizes_host = reinterpret_cast(host_workspace); + auto* ptr_A_host = reinterpret_cast(host_workspace + gemm_coord_size); + auto* ptr_B_host = reinterpret_cast(host_workspace + gemm_coord_size + ptr_size); + auto* ptr_C_host = reinterpret_cast(host_workspace + gemm_coord_size + 2 * ptr_size); + auto* lda_host = reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size); + auto* ldb_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + ldd_size); + auto* ldc_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + 2 * ldd_size); + + for (int i = 0; i < num_gemms; i++) { + const auto* inputA = transformer_engine::convertNVTETensorCheck(A[i]); + const auto* inputB = transformer_engine::convertNVTETensorCheck(B[i]); + auto* outputD = transformer_engine::convertNVTETensor(D[i]); + + const int m = + trans_a ? static_cast(inputA->data.shape[1]) : static_cast(inputA->data.shape[0]); + const int k = + trans_a ? static_cast(inputA->data.shape[0]) : static_cast(inputA->data.shape[1]); + const int n = + trans_b ? static_cast(inputB->data.shape[0]) : static_cast(inputB->data.shape[1]); + + problem_sizes_host[i] = ProblemShapeType(m, n, k); + ptr_A_host[i] = reinterpret_cast(inputA->data.dptr); + ptr_B_host[i] = reinterpret_cast(inputB->data.dptr); + ptr_C_host[i] = reinterpret_cast(outputD->data.dptr); + lda_host[i] = LayoutA::packed({m, k}).stride(0); + ldb_host[i] = LayoutB::packed({k, n}).stride(0); + ldc_host[i] = LayoutC::packed({m, n}).stride(0); + } + + cudaMemcpyAsync(workspace_ptr, host_workspace, param_workspace_size, cudaMemcpyHostToDevice, + stream); + + auto* problem_sizes_device = reinterpret_cast(workspace_ptr); + const ElementA** ptr_A = reinterpret_cast(workspace_ptr + gemm_coord_size); + const ElementB** ptr_B = + reinterpret_cast(workspace_ptr + gemm_coord_size + ptr_size); + ElementC** ptr_C = reinterpret_cast(workspace_ptr + gemm_coord_size + 2 * ptr_size); + auto* lda = reinterpret_cast(workspace_ptr + gemm_coord_size + 3 * ptr_size); + auto* ldb = reinterpret_cast(workspace_ptr + gemm_coord_size + 3 * ptr_size + ldd_size); + auto* ldc = + reinterpret_cast(workspace_ptr + gemm_coord_size + 3 * ptr_size + 2 * ldd_size); + + char* kernel_workspace_ptr = workspace_ptr + param_workspace_size; + + arguments = MakeArguments( + num_gemms, problem_sizes_host, problem_sizes_device, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, + alpha, beta, device, math_sm_count); + + Gemm gemm; + if (gemm.can_implement(arguments) != cutlass::Status::kSuccess) { + NVTE_ERROR("Wgrad grouped GEMM: can_implement check failed (", num_gemms, " groups)"); + } + if (gemm.initialize(arguments, kernel_workspace_ptr) != cutlass::Status::kSuccess) { + NVTE_ERROR("Wgrad grouped GEMM: initialize failed (", num_gemms, " groups)"); + } + if (gemm.run(stream) != cutlass::Status::kSuccess) { + NVTE_ERROR("Wgrad grouped GEMM: run failed (", num_gemms, " groups)"); + } +} + } // namespace grouped_gemm } // namespace transformer_engine void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int num_gemms, bool transa, bool transb, bool grad, NVTETensor* workspace, bool accumulate, int device, int math_sm_count, cudaStream_t stream); + +void cutlass_grouped_gemm_varlen_k(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, + int num_gemms, bool transa, bool transb, bool grad, + NVTETensor* workspace, bool accumulate, int device, + int math_sm_count, cudaStream_t stream);