[JAX] Hopper BF16 grouped GEMM v2 support #3083
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
/te-ci L1 jax
Greptile Summary
This PR extends the V2 grouped GEMM path to support BF16 on SM90+ (Hopper), previously gated to SM100+ (Blackwell). It also adds a pure-JAX post-kernel bias addition (
Confidence Score: 4/5
Safe to merge for the BF16 Hopper V2 path. The pure-JAX bias addition is straightforward; the ragged-last-dims combination with bias will raise NotImplementedError at runtime rather than silently producing wrong data.
The core routing and alpha/beta scalar-broadcast changes look correct. The main gap is that the expensive V2 kernel will be dispatched before the NotImplementedError for bias + ragged-last-dims is raised, wasting GPU time in that edge case. The missing cache on the capability helper and the now-dead has_bias parameter are minor quality issues that don't affect correctness.
transformer_engine/jax/cpp_extensions/gemm.py: specifically the interaction between _is_v2_grouped_gemm_supported, _add_grouped_gemm_bias, and the ragged-last-dims + bias edge case.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[grouped_gemm called] --> B{_is_v2_grouped_gemm_supported?}
B -- "NO_SCALING + BF16\nSM90+" --> C[use_v2_ffi = True]
B -- "MXFP8\nSM100+" --> D{dim checks pass?}
D -- Yes --> C
D -- No --> E[use_v2_ffi = False\nV1 path]
B -- "Other / SM < 90" --> E
C --> F{SM >= 100?}
F -- Yes --> G["alpha/beta shape (num_gemms,)"]
F -- No --> H["alpha/beta shape (1,) — scalar"]
G & H --> I[GroupedGemmPrimitive FFI\nhas_bias_for_ffi = False]
E --> J[GroupedGemmPrimitive FFI\nhas_bias_for_ffi = has_bias]
I --> K{has_bias?}
K -- Yes --> L{out_last_dims\nis not None?}
L -- Yes --> M[NotImplementedError]
L -- No --> N{out_first_dims\nis not None?}
N -- Yes --> O[jnp.repeat bias rows\nragged path]
N -- No --> P[broadcast reshape\nstatic path]
K -- No --> Q[return out]
O & P --> Q
J --> Q
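The routing in the flowchart can be restated as a small pure-Python sketch. This is only an illustration of the decision logic: `route_grouped_gemm`, its string arguments, and the `mxfp8_dims_ok` flag are hypothetical names, not functions from `gemm.py`, and the flowchart does not show the alpha/beta shape on the V1 path, so `None` is used there as a placeholder.

```python
from typing import Optional, Tuple


def route_grouped_gemm(
    scaling_mode: str,
    dtype: str,
    sm: int,
    num_gemms: int,
    mxfp8_dims_ok: bool = True,
) -> Tuple[bool, Optional[Tuple[int, ...]]]:
    """Return (use_v2_ffi, alpha_beta_shape) following the flowchart.

    All names here are hypothetical; `sm` is the minimum compute
    capability across visible devices (90 = Hopper, 100 = Blackwell).
    """
    if scaling_mode == "NO_SCALING" and dtype == "bfloat16" and sm >= 90:
        use_v2 = True
    elif scaling_mode == "MXFP8" and sm >= 100:
        # MXFP8 additionally requires the dimension checks to pass.
        use_v2 = mxfp8_dims_ok
    else:
        use_v2 = False

    if not use_v2:
        # V1 path: alpha/beta handling is not shown in the flowchart.
        return False, None

    # SM100+ takes per-group alpha/beta; below that a single scalar is used.
    alpha_beta_shape = (num_gemms,) if sm >= 100 else (1,)
    return True, alpha_beta_shape
```

Under these assumptions, BF16 on SM90 selects the V2 path with a scalar-shaped alpha/beta, while the same call on SM100 uses one entry per GEMM.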
def _v2_grouped_gemm_supports_per_group_alpha_beta() -> bool:
    """Whether nvte_grouped_gemm accepts per-group alpha/beta on all visible devices."""
    return get_min_device_compute_capability() >= 100
Missing cache on capability check
_v2_grouped_gemm_supports_per_group_alpha_beta() calls get_min_device_compute_capability() on every grouped_gemm invocation, but unlike _should_enforce_v2_grouped_gemm() (which is decorated with @cache) it has no memoization. Both functions encode a process-wide constant; querying CUDA device capability in a hot path can add unnecessary overhead. Adding @cache (or @functools.lru_cache(maxsize=None)) mirrors the pattern already used by the sibling helper.
if out_last_dims is not None:
    raise NotImplementedError("V2 grouped GEMM bias is not supported for ragged last dims")
Runtime error surfaces after kernel execution for bias + ragged-last-dims
When has_bias=True and out_last_dims is not None, _is_v2_grouped_gemm_supported still returns True for BF16 on SM90+ (no gate for this combination), so the full V2 kernel is dispatched before _add_grouped_gemm_bias raises NotImplementedError. The check should be moved upstream — either into _is_v2_grouped_gemm_supported (returning False to fall back to V1) or as an early guard in grouped_gemm before the FFI bind.
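One way to picture the second option (an early guard before the FFI bind) is the sketch below. It is not the code in `gemm.py`: `check_v2_bias_supported`, `grouped_gemm_sketch`, and the `launch_kernel` callable are hypothetical names used only to show the ordering of the check relative to the kernel launch.

```python
from typing import Any, Callable, Optional


def check_v2_bias_supported(has_bias: bool, out_last_dims: Optional[Any]) -> None:
    """Hypothetical guard: reject bias + ragged last dims up front."""
    if has_bias and out_last_dims is not None:
        raise NotImplementedError(
            "V2 grouped GEMM bias is not supported for ragged last dims"
        )


def grouped_gemm_sketch(
    has_bias: bool,
    out_last_dims: Optional[Any],
    launch_kernel: Callable[[], Any],
) -> Any:
    """Hypothetical wrapper: validate first, then launch the kernel."""
    # The guard runs before any kernel work is dispatched.
    check_v2_bias_supported(has_bias, out_last_dims)
    out = launch_kernel()
    # Post-kernel bias addition would follow here on the V2 path.
    return out
```

With this ordering the unsupported combination fails before `launch_kernel` is called, so no GPU time is spent on a result that would be discarded. The alternative from the comment, returning False from `_is_v2_grouped_gemm_supported`, would instead fall back to the V1 path rather than raising.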
Description
Adds BF16 grouped GEMM support on Hopper (SM90) through the V2 path. Also adds a pure-JAX bias addition applied after the kernel; a fused bias is to be implemented later.
Type of change
Changes
Checklist: