[JAX] Hopper BF16 grouped GEMM v2 support #3083

Draft
jberchtold-nvidia wants to merge 3 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/hopper-bf16-gmm

@jberchtold-nvidia
Collaborator

Description

Adds V2 grouped GEMM support for BF16 on Hopper. Also adds a pure-JAX bias addition for the V2 path; a fused bias will be implemented later.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Hopper BF16 grouped GEMM

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft June 4, 2026 20:48
@jberchtold-nvidia
Collaborator Author

/te-ci L1 jax

@greptile-apps
Contributor

greptile-apps Bot commented Jun 4, 2026

Greptile Summary

This PR extends the V2 grouped GEMM path to support BF16 on SM90+ (Hopper), previously gated to SM100+ (Blackwell). It also adds a pure-JAX post-kernel bias addition (_add_grouped_gemm_bias) for V2 kernels, with fused bias deferred to a future PR.

  • SM90 BF16 V2 path: _is_v2_grouped_gemm_supported now returns True for NO_SCALING + bfloat16 on SM90+; MXFP8 retains its SM100+ gate. Alpha/beta tensors are shape (1,) on Hopper (scalar broadcast) and (num_gemms,) on Blackwell+, handled on both the Python side (_v2_grouped_gemm_supports_per_group_alpha_beta) and the C++ side (use element_count() instead of num_gemms).
  • Pure-JAX bias for V2: When use_v2_ffi=True and has_bias=True, the C++ kernel receives no bias (has_bias_for_ffi=False); bias is added after the kernel via jnp.repeat-based expansion for ragged-first-dim outputs or broadcasting reshape for static outputs.
  • Removed prior bias block: The old blanket if has_bias: return False guard in _is_v2_grouped_gemm_supported is dropped; the has_bias parameter is now only reflected in the terminal error message.
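The post-kernel bias step described in the bullets above can be sketched roughly as follows. This is a minimal illustration, not the PR's `_add_grouped_gemm_bias`: the function name `add_grouped_bias`, its arguments, and the assumed tensor layouts are hypothetical. It only shows the two ideas the summary names, a `jnp.repeat`-based expansion for ragged first dims and a broadcasting reshape for static outputs.

```python
import jax.numpy as jnp


def add_grouped_bias(out, bias, group_sizes=None):
    """Add a per-group bias to a grouped GEMM output after the kernel has run.

    Assumed layouts (hypothetical, for illustration only):
      out:         (total_rows, n) when group_sizes is given (ragged first dim),
                   otherwise (num_groups, rows_per_group, n).
      bias:        (num_groups, n)
      group_sizes: (num_groups,) integer row counts, or None for the static case.
    """
    if group_sizes is not None:
        # Ragged path: repeat each group's bias row once per output row of that
        # group. total_repeat_length fixes the result length to the row count.
        expanded = jnp.repeat(
            bias, group_sizes, axis=0, total_repeat_length=out.shape[0]
        )
        return out + expanded.astype(out.dtype)
    # Static path: insert a row axis so the bias broadcasts over each group's rows.
    return out + bias[:, None, :].astype(out.dtype)
```

Passing `total_repeat_length` keeps the expanded shape independent of the values in `group_sizes`, which is what allows this kind of expansion to be traced under `jax.jit`.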

Confidence Score: 4/5

Safe to merge for the BF16 Hopper V2 path. The pure-JAX bias addition is straightforward; the ragged-last-dims combination with bias will raise NotImplementedError at runtime rather than silently producing wrong data.

The core routing and alpha/beta scalar-broadcast changes look correct. The main gap is that the expensive V2 kernel will be dispatched before the NotImplementedError for bias + ragged-last-dims is raised, wasting GPU time in that edge case. The missing cache on the capability helper and the now-dead has_bias parameter are minor quality issues that don't affect correctness.

transformer_engine/jax/cpp_extensions/gemm.py — specifically the interaction between _is_v2_grouped_gemm_supported, _add_grouped_gemm_bias, and the ragged-last-dims + bias edge case.

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/jax/cpp_extensions/gemm.py | Extends V2 grouped GEMM to support BF16 on SM90+ (Hopper); adds pure-JAX post-kernel bias via _add_grouped_gemm_bias; relaxes alpha/beta to scalar (1,) on Hopper. Logic is correct but has minor gaps: missing cache on capability helper, no upstream gate for bias+ragged-last-dims before kernel dispatch, and a now-dead has_bias parameter in the support-check function. |
| transformer_engine/jax/csrc/extensions/gemm.cpp | Changes TensorWrapper construction for alpha and beta to use alpha.element_count() / beta.element_count() instead of the hardcoded num_gemms, allowing scalar (size-1) tensors on Hopper. Straightforward and correct. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[grouped_gemm called] --> B{_is_v2_grouped_gemm_supported?}
    B -- "NO_SCALING + BF16\nSM90+" --> C[use_v2_ffi = True]
    B -- "MXFP8\nSM100+" --> D{dim checks pass?}
    D -- Yes --> C
    D -- No --> E[use_v2_ffi = False\nV1 path]
    B -- "Other / SM < 90" --> E
    C --> F{SM >= 100?}
    F -- Yes --> G["alpha/beta shape (num_gemms,)"]
    F -- No --> H["alpha/beta shape (1,) — scalar"]
    G & H --> I[GroupedGemmPrimitive FFI\nhas_bias_for_ffi = False]
    E --> J[GroupedGemmPrimitive FFI\nhas_bias_for_ffi = has_bias]
    I --> K{has_bias?}
    K -- Yes --> L{out_last_dims\nis not None?}
    L -- Yes --> M[NotImplementedError]
    L -- No --> N{out_first_dims\nis not None?}
    N -- Yes --> O[jnp.repeat bias rows\nragged path]
    N -- No --> P[broadcast reshape\nstatic path]
    K -- No --> Q[return out]
    O & P --> Q
    J --> Q
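
The routing part of the flowchart above can be restated as a small pure-Python function. This is a sketch under assumptions: `route_grouped_gemm`, its string-valued arguments, and `mxfp8_dims_ok` are hypothetical stand-ins, and `sm` is taken to be the minimum compute capability in the same encoding the diff uses (90 for Hopper, 100 for Blackwell).

```python
def route_grouped_gemm(scaling, dtype, sm, num_gemms, mxfp8_dims_ok=True):
    """Return (use_v2_ffi, alpha_beta_shape) following the flowchart.

    All names are hypothetical; this is not the library's routing code.
    """
    if scaling == "NO_SCALING" and dtype == "bfloat16" and sm >= 90:
        use_v2_ffi = True
    elif scaling == "MXFP8" and sm >= 100 and mxfp8_dims_ok:
        use_v2_ffi = True
    else:
        use_v2_ffi = False

    if not use_v2_ffi:
        # V1 path: the alpha/beta layout is outside the scope of this sketch.
        return False, None

    # Per-group alpha/beta on SM100+, a single broadcast scalar below that.
    alpha_beta_shape = (num_gemms,) if sm >= 100 else (1,)
    return True, alpha_beta_shape
```

For example, `route_grouped_gemm("NO_SCALING", "bfloat16", 90, 4)` takes the V2 path with alpha/beta of shape `(1,)`, while the same call with `sm=100` yields shape `(4,)`.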

Comments Outside Diff (1)

  1. transformer_engine/jax/cpp_extensions/gemm.py, lines 2104-2112

    P2 has_bias parameter is now unused in conditional logic

    has_bias is accepted by _is_v2_grouped_gemm_supported (and forwarded from the public wrapper) but no longer gates any conditional branch — it only appears in the terminal error message. Callers reading the signature will expect this flag to influence the routing decision; leaving it as a dead parameter (rather than removing it or adding a comment) is likely to confuse future readers who trace why a has-bias call still returns True.

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines +2099 to +2101
def _v2_grouped_gemm_supports_per_group_alpha_beta() -> bool:
    """Whether nvte_grouped_gemm accepts per-group alpha/beta on all visible devices."""
    return get_min_device_compute_capability() >= 100
Contributor

P2 Missing cache on capability check

_v2_grouped_gemm_supports_per_group_alpha_beta() calls get_min_device_compute_capability() on every grouped_gemm invocation, but unlike _should_enforce_v2_grouped_gemm() (which is decorated with @cache) it has no memoization. Both functions encode a process-wide constant; querying CUDA device capability in a hot path can add unnecessary overhead. Adding @cache (or @functools.lru_cache(maxsize=None)) mirrors the pattern already used by the sibling helper.

Comment on lines +2421 to +2422
if out_last_dims is not None:
    raise NotImplementedError("V2 grouped GEMM bias is not supported for ragged last dims")
Contributor

P2 Runtime error surfaces after kernel execution for bias + ragged-last-dims

When has_bias=True and out_last_dims is not None, _is_v2_grouped_gemm_supported still returns True for BF16 on SM90+ (no gate for this combination), so the full V2 kernel is dispatched before _add_grouped_gemm_bias raises NotImplementedError. The check should be moved upstream — either into _is_v2_grouped_gemm_supported (returning False to fall back to V1) or as an early guard in grouped_gemm before the FFI bind.
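
One way the second option (an early guard before the FFI bind) might look is sketched below. The names `validate_v2_bias` and `grouped_gemm_sketch`, and the `launch_kernel` callable standing in for the FFI bind, are hypothetical; the point is only that the check runs before any kernel work is issued.

```python
def validate_v2_bias(use_v2_ffi, has_bias, out_last_dims):
    """Reject the unsupported combination before any kernel is dispatched."""
    if use_v2_ffi and has_bias and out_last_dims is not None:
        raise NotImplementedError(
            "V2 grouped GEMM bias is not supported for ragged last dims"
        )


def grouped_gemm_sketch(use_v2_ffi, has_bias, out_last_dims, launch_kernel):
    """Hypothetical wrapper: validate first, then launch."""
    validate_v2_bias(use_v2_ffi, has_bias, out_last_dims)
    # Only reached when the combination is supported.
    return launch_kernel()
```

The alternative mentioned in the comment, returning False from the support check so the call falls back to V1, would avoid raising at all; which of the two is preferable depends on whether V1 handles this combination.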
