Skip to content

[JAX] MoEBlock tutorial#3084

Draft
jberchtold-nvidia wants to merge 30 commits into
NVIDIA:mainfrom
jberchtold-nvidia:jberchtold/moe-tutorial
Draft

[JAX] MoEBlock tutorial#3084
jberchtold-nvidia wants to merge 30 commits into
NVIDIA:mainfrom
jberchtold-nvidia:jberchtold/moe-tutorial

Conversation

@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

phu0ngng and others added 22 commits June 3, 2026 11:02
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…em_reloc gating

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…rce at dispatch

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… static layer registration

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…s, MoE example)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ache

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype
field. The C++ backend (ep_backend.cpp:349) enforces
    typeToSize(tok_dtype) <= typeToSize(max_token_dtype)
at every dispatch, and the field is also used at group create to size the
NCCL EP staging buffers (ep_backend.cpp:221-222).

PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written
before this field existed and never set it, so any JAX EP group landed with
the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from
JAX then failed immediately with:
    tokens dtype (6) wider than group max_token_dtype (0)

This commit threads max_token_dtype end-to-end:

  - transformer_engine/jax/csrc/extensions.h
    update SetEpBootstrapParams declaration to match the new arity.

  - transformer_engine/jax/csrc/extensions/ep.cpp
    add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams;
    forward it into NVTEEpGroupConfig in the EpResources ctor.

  - transformer_engine/jax/csrc/extensions/pybind.cpp
    add the matching pybind11::arg("max_token_dtype") = 0.

  - transformer_engine/jax/ep.py
    add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to
    NVTEDType int, forward to the C++ setter.

Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream.
See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
[JAX] MoE: soft re-pin inbound activations sharding at moe() entry
[JAX] MoE: scope gate_logits 2D reshape to topk primitive call
[JAX] MoE: add apply_topk_weights_early flag (TE EP backend only)
[JAX] MoE: stack wi_0/wi_1 on new axis (4D) instead of concat

Signed-off-by: tdophung <tdophung@nvidia.com>
…y step paths. change tests to collapse in 1 bigger one with different parameters instead of smaller meaningless dtypes/shapes/finite chhecks

Signed-off-by: tdophung <tdophung@nvidia.com>
…per-call

ep_bootstrap allgathers a NCCL UID via the JAX runtime, which traces under
jax.jit and fails with TracerArrayConversionError. Move the bootstrap to
the test fixture (matching the test_multi_process_ep.py pattern from the
TE EP JAX PR): caller invokes ep_bootstrap once per process, then calls
record_ep_bootstrap_signature_for_moe with the same params. _moe_fwd_rule
now only asserts that the recorded bootstrap signature is wide enough
(num_experts/hidden_dim/ep_size exact match; per-call max_tokens_per_rank
and recv_capacity_per_rank <= bootstrap values). Test mesh fixture
bootstraps with the worst-case recv_pr across _CONFIGS so every
parametrized config is compatible with a single per-process bootstrap.
The cpp_extensions/ep.py API (post the per-layer EpHandle refactor in
e927903) expects an EpHandle object plus a separate handle_mem buffer
for every dispatch/combine call. The MoE wrapper was still passing the
raw slots_per_expert int as the second positional and unpacking
ep_dispatch_fwd as a 3-tuple, which now blows up with
"AttributeError: 'int' object has no attribute 'handle_id'".

Changes:
- Cache one EpHandle per (top_k, alignment) at module scope so repeated
  jit traces don't burn the NVTE_EP_HANDLE_CACHE_SIZE pool.
- _moe_fwd_rule: mint/lookup the handle, call ep_prepare(topk_idx, handle)
  -> (token_counts, handle_mem), and pass (handle, handle_mem) into the
  fwd dispatch/combine calls. ep_dispatch_fwd now returns a 2-tuple.
- _Ctx: stash handle_mem alongside handle so the bwd can hand both back
  to ep_combine_bwd and ep_dispatch_bwd.
- _moe_bwd_rule: thread ctx.handle_mem into the bwd dispatch/combine
  calls.
te-ep-fixes plumbs NVTEEpGroupConfig.max_token_dtype through ep_bootstrap.
Tests dispatch bf16 tokens; without this arg the group lands with the
legacy kByte default (1 byte) and every dispatch aborts at the
ep_backend.cpp:349 dtype check.
…ill fix for real in later commits

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft June 4, 2026 21:38
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 4, 2026

Greptile Summary

This PR adds a new JAX tutorial for Mixture-of-Experts with TransformerEngine's experimental _MoEBlock, including a native JAX/Flax BF16 baseline (moe_native.py), the TE integration example (moe.py), RST documentation (moe.rst), pre-captured benchmark output (moe.out), and pytest tests (test_moe.py). The hub page (te_jax_integration.rst) is updated to link to the new tutorial.

  • The core tutorial demonstrates a 2×2 EP/FSDP mesh setup, correctness comparison between native and TE paths, and a forward+backward timing benchmark showing ~1.43× TE speedup on GB200 hardware.
  • A helper function _block_until_ready_tree is introduced for benchmark synchronization but only waits on leaves[0] rather than the whole tree; the same single-leaf wait pattern is repeated inline in setup_demo and test_moe.py, and the main() entrypoint silently skips the native-only benchmark when TE is unavailable.

Confidence Score: 4/5

The tutorial is safe to merge; all issues are non-blocking style and consistency concerns with no impact on correctness of the documented TE path.

The core benchmark helper _block_until_ready_tree only synchronizes the first leaf of the output tree, and the same single-leaf pattern is repeated in setup_demo and in the test. Functionally, JAX/XLA makes all outputs of a single JIT dispatch available together, so benchmarks are not mismeasured today, but the implementation contradicts its name and is misleading tutorial code. Separately, main() bails out before running the native baseline when TE is unsupported, leaving users on pre-Blackwell hardware with no benchmark output at all. The _te_moe_available function in test_moe.py duplicates the capability check from moe.py, creating a maintenance risk if the threshold changes.

docs/examples/jax/moe.py_block_until_ready_tree, setup_demo sync call, and main() early-return logic; docs/examples/jax/test_moe.py — duplicated capability check and sync pattern.

Important Files Changed

Filename Overview
docs/examples/jax/moe.py Main tutorial script; introduces _block_until_ready_tree that only waits on leaves[0], and main() exits early when TE is unsupported, skipping the native baseline entirely.
docs/examples/jax/moe_native.py Native JAX/Flax MoE baseline; ragged all-to-all + grouped GEMMs are implemented correctly; no significant issues found.
docs/examples/jax/test_moe.py Pytest tests with appropriate requires_4gpu / requires_te_moe skip guards; duplicates capability-check logic from moe.py and repeats the single-leaf block_until_ready pattern.
docs/examples/jax/moe.rst RST tutorial document; literalinclude markers align with moe.py; benchmark tables and correctness numbers are well documented.
docs/examples/jax/moe.out Pre-captured benchmark output used by literalinclude; no code issues.
docs/examples/te_jax_integration.rst Hub page updated to link to the new MoE tutorial; minimal change, correct placement.

Sequence Diagram

sequenceDiagram
    participant User
    participant moe.py
    participant NativeMoEBlock
    participant TEMoEBlock
    participant JAX/XLA

    User->>moe.py: python moe.py
    moe.py->>moe.py: build_ep_fsdp_mesh()
    moe.py->>NativeMoEBlock: native_model.init(k_init, x)
    NativeMoEBlock-->>moe.py: variables (gate_kernel, wi_0, wi_1, wo)
    moe.py->>moe.py: shard_inputs_and_variables()

    Note over moe.py,JAX/XLA: compare_forward()
    moe.py->>JAX/XLA: jit(native_model.apply)(variables, x)
    JAX/XLA-->>moe.py: native_out
    moe.py->>JAX/XLA: jit(te_apply)(variables, x)
    JAX/XLA-->>moe.py: te_out
    moe.py->>User: "max |native - TE| = 0.0604"

    Note over moe.py,JAX/XLA: run_benchmarks() — fwd+bwd timing
    loop warmup + timing iters
        moe.py->>JAX/XLA: jit(value_and_grad(loss_fn))(vars, x, dy)
        JAX/XLA-->>moe.py: (loss, (grad_vars, grad_x))
        moe.py->>moe.py: _block_until_ready_tree(result)
    end
    moe.py->>User: Mean time (native), Mean time (TE)
Loading

Comments Outside Diff (1)

  1. docs/examples/jax/test_moe.py, line 975-991 (link)

    P2 Duplicated capability-check logic

    _te_moe_available() here is an exact copy of te_moe_supported() in moe.py. If the Blackwell compute-capability threshold, the module path, or the attribute name ever changes, both copies need to be updated in sync. Consider importing te_moe_supported from moe directly (since moe is already imported in the TE-guarded tests) or moving the shared check into a small helper module.

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment thread docs/examples/jax/moe.py Outdated
Comment thread docs/examples/jax/moe.py Outdated
Comment thread docs/examples/jax/moe.py
Comment on lines +301 to +308
return

demo = setup_demo()
compare_forward(demo)
run_benchmarks(demo)


if __name__ == "__main__":
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 main() returns immediately when TE is unavailable (non-Blackwell hardware or missing build), so the native baseline never runs either. Users who just want to validate the native JAX path — or who are running on a pre-Blackwell GPU — see a silent skip with no output. Separating the TE guard from the native baseline makes the script more useful as a standalone tutorial.

Suggested change
return
demo = setup_demo()
compare_forward(demo)
run_benchmarks(demo)
if __name__ == "__main__":
te_supported, te_reason = te_moe_supported()
if not te_supported:
print(f"[skipped TE comparison: {te_reason}]")
demo = setup_demo()
if te_supported:
compare_forward(demo)
run_benchmarks(demo)

Comment thread docs/examples/jax/test_moe.py Outdated
jberchtold-nvidia and others added 4 commits June 4, 2026 16:09
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
jberchtold-nvidia and others added 3 commits June 5, 2026 10:24
…rmerEngine into jberchtold/moe-tutorial

# Conflicts:
#	build_tools/jax.py
#	tests/cpp_distributed/CMakeLists.txt
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants