[JAX] MoEBlock tutorial #3084
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…em_reloc gating Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
…rce at dispatch Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… static layer registration Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…s, MoE example) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ache Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype field. The C++ backend (ep_backend.cpp:349) enforces typeToSize(tok_dtype) <= typeToSize(max_token_dtype) at every dispatch, and the field is also used at group create to size the NCCL EP staging buffers (ep_backend.cpp:221-222).

PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written before this field existed and never set it, so any JAX EP group landed with the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from JAX then failed immediately with:

    tokens dtype (6) wider than group max_token_dtype (0)

This commit threads max_token_dtype end-to-end:
- transformer_engine/jax/csrc/extensions.h: update SetEpBootstrapParams declaration to match the new arity.
- transformer_engine/jax/csrc/extensions/ep.cpp: add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams; forward it into NVTEEpGroupConfig in the EpResources ctor.
- transformer_engine/jax/csrc/extensions/pybind.cpp: add the matching pybind11::arg("max_token_dtype") = 0.
- transformer_engine/jax/ep.py: add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to NVTEDType int, forward to the C++ setter.

Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream. See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
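To make the failure mode above concrete, here is a minimal pure-Python sketch of the width check described in the commit message. The dtype names, the `DTYPE_SIZE_BYTES` table, and both helper functions are hypothetical stand-ins for illustration; they are not the NVTEDType enum or the actual C++ code in ep_backend.cpp.

```python
# Byte widths for a few token dtypes. These string names are illustrative only
# and do not correspond to NVTEDType enum values.
DTYPE_SIZE_BYTES = {"byte": 1, "float16": 2, "bfloat16": 2, "float32": 4}


def check_dispatch_dtype(tok_dtype: str, max_token_dtype: str) -> None:
    """Raise if the token dtype is wider than the group's max_token_dtype.

    Mirrors the rule typeToSize(tok_dtype) <= typeToSize(max_token_dtype).
    """
    if DTYPE_SIZE_BYTES[tok_dtype] > DTYPE_SIZE_BYTES[max_token_dtype]:
        raise ValueError(
            f"tokens dtype ({tok_dtype}) wider than group max_token_dtype ({max_token_dtype})"
        )


def dispatch_allowed(tok_dtype: str, max_token_dtype: str) -> bool:
    """Convenience wrapper that reports the check result as a boolean."""
    try:
        check_dispatch_dtype(tok_dtype, max_token_dtype)
    except ValueError:
        return False
    return True


# A group left at the 1-byte default rejects 2-byte tokens ...
rejected = not dispatch_allowed("bfloat16", "byte")
# ... while a group bootstrapped with a wide enough dtype accepts them.
accepted = dispatch_allowed("bfloat16", "bfloat16")
```

Under these assumptions, a group whose max_token_dtype was never set behaves like the `"byte"` case, which is why every bf16/fp16 dispatch failed until the field was threaded through the bootstrap.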
[JAX] MoE: soft re-pin inbound activations sharding at moe() entry
[JAX] MoE: scope gate_logits 2D reshape to topk primitive call
[JAX] MoE: add apply_topk_weights_early flag (TE EP backend only)
[JAX] MoE: stack wi_0/wi_1 on new axis (4D) instead of concat
Signed-off-by: tdophung <tdophung@nvidia.com>
…y step paths. Change tests to collapse into one bigger test with different parameters instead of smaller, meaningless dtype/shape/finite checks Signed-off-by: tdophung <tdophung@nvidia.com>
…per-call ep_bootstrap allgathers a NCCL UID via the JAX runtime, which traces under jax.jit and fails with TracerArrayConversionError. Move the bootstrap to the test fixture (matching the test_multi_process_ep.py pattern from the TE EP JAX PR): caller invokes ep_bootstrap once per process, then calls record_ep_bootstrap_signature_for_moe with the same params. _moe_fwd_rule now only asserts that the recorded bootstrap signature is wide enough (num_experts/hidden_dim/ep_size exact match; per-call max_tokens_per_rank and recv_capacity_per_rank <= bootstrap values). Test mesh fixture bootstraps with the worst-case recv_pr across _CONFIGS so every parametrized config is compatible with a single per-process bootstrap.
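The compatibility rule described above (exact match on num_experts/hidden_dim/ep_size, per-call values no larger than the bootstrap values) can be sketched in plain Python. The `BootstrapSignature` dataclass and `assert_bootstrap_wide_enough` function are hypothetical names for illustration, not the helpers used in `_moe_fwd_rule`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BootstrapSignature:
    """Hypothetical record of the parameters one ep_bootstrap call was made with."""

    num_experts: int
    hidden_dim: int
    ep_size: int
    max_tokens_per_rank: int
    recv_capacity_per_rank: int


EXACT_FIELDS = ("num_experts", "hidden_dim", "ep_size")
UPPER_BOUND_FIELDS = ("max_tokens_per_rank", "recv_capacity_per_rank")


def assert_bootstrap_wide_enough(recorded: BootstrapSignature, call: BootstrapSignature) -> None:
    """Check that a per-call signature fits inside the recorded bootstrap."""
    for name in EXACT_FIELDS:
        if getattr(recorded, name) != getattr(call, name):
            raise ValueError(f"{name} must match the bootstrap exactly")
    for name in UPPER_BOUND_FIELDS:
        if getattr(call, name) > getattr(recorded, name):
            raise ValueError(f"{name} exceeds the bootstrap value")


# Bootstrap once with the worst case across all configs, as the test fixture does.
configs = [
    BootstrapSignature(8, 1024, 4, 256, 512),
    BootstrapSignature(8, 1024, 4, 128, 2048),
]
worst_case = BootstrapSignature(
    num_experts=8,
    hidden_dim=1024,
    ep_size=4,
    max_tokens_per_rank=max(c.max_tokens_per_rank for c in configs),
    recv_capacity_per_rank=max(c.recv_capacity_per_rank for c in configs),
)
for cfg in configs:
    assert_bootstrap_wide_enough(worst_case, cfg)  # every config is compatible
```

Because the check is plain Python on static integers, it can run at trace time without triggering any collective, which is the property the commit relies on.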
The cpp_extensions/ep.py API (post the per-layer EpHandle refactor in e927903) expects an EpHandle object plus a separate handle_mem buffer for every dispatch/combine call. The MoE wrapper was still passing the raw slots_per_expert int as the second positional and unpacking ep_dispatch_fwd as a 3-tuple, which now blows up with "AttributeError: 'int' object has no attribute 'handle_id'".

Changes:
- Cache one EpHandle per (top_k, alignment) at module scope so repeated jit traces don't burn the NVTE_EP_HANDLE_CACHE_SIZE pool.
- _moe_fwd_rule: mint/lookup the handle, call ep_prepare(topk_idx, handle) -> (token_counts, handle_mem), and pass (handle, handle_mem) into the fwd dispatch/combine calls. ep_dispatch_fwd now returns a 2-tuple.
- _Ctx: stash handle_mem alongside handle so the bwd can hand both back to ep_combine_bwd and ep_dispatch_bwd.
- _moe_bwd_rule: thread ctx.handle_mem into the bwd dispatch/combine calls.
te-ep-fixes plumbs NVTEEpGroupConfig.max_token_dtype through ep_bootstrap. Tests dispatch bf16 tokens; without this arg the group lands with the legacy kByte default (1 byte) and every dispatch aborts at the ep_backend.cpp:349 dtype check.
…ill fix for real in later commits Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile Summary
This PR adds a new JAX tutorial for Mixture-of-Experts with TransformerEngine's experimental
Confidence Score: 4/5
The tutorial is safe to merge; all issues are non-blocking style and consistency concerns with no impact on correctness of the documented TE path. The core benchmark helper
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant moe.py
participant NativeMoEBlock
participant TEMoEBlock
participant JAX/XLA
User->>moe.py: python moe.py
moe.py->>moe.py: build_ep_fsdp_mesh()
moe.py->>NativeMoEBlock: native_model.init(k_init, x)
NativeMoEBlock-->>moe.py: variables (gate_kernel, wi_0, wi_1, wo)
moe.py->>moe.py: shard_inputs_and_variables()
Note over moe.py,JAX/XLA: compare_forward()
moe.py->>JAX/XLA: jit(native_model.apply)(variables, x)
JAX/XLA-->>moe.py: native_out
moe.py->>JAX/XLA: jit(te_apply)(variables, x)
JAX/XLA-->>moe.py: te_out
moe.py->>User: "max |native - TE| = 0.0604"
Note over moe.py,JAX/XLA: run_benchmarks() — fwd+bwd timing
loop warmup + timing iters
moe.py->>JAX/XLA: jit(value_and_grad(loss_fn))(vars, x, dy)
JAX/XLA-->>moe.py: (loss, (grad_vars, grad_x))
moe.py->>moe.py: _block_until_ready_tree(result)
end
moe.py->>User: Mean time (native), Mean time (TE)
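The fwd+bwd timing loop in the diagram can be sketched with a toy model. This is a minimal illustration assuming a single dense layer in place of the MoE block; `loss_fn` and `benchmark_fwd_bwd` are hypothetical names and not the tutorial's actual helpers.

```python
import time

import jax
import jax.numpy as jnp


def loss_fn(variables, x, dy):
    # Toy stand-in for model.apply: one dense layer instead of the MoE block.
    y = x @ variables["kernel"]
    # Contract the output with an upstream gradient dy to get a scalar loss,
    # so differentiating it exercises the full backward pass.
    return jnp.sum(y * dy)


def benchmark_fwd_bwd(variables, x, dy, warmup=2, iters=5):
    """Time jit(value_and_grad(loss_fn)); returns mean seconds and the last result."""
    step = jax.jit(jax.value_and_grad(loss_fn, argnums=(0, 1)))
    for _ in range(warmup):
        # Warmup iterations absorb compilation time.
        jax.block_until_ready(step(variables, x, dy))
    start = time.perf_counter()
    for _ in range(iters):
        result = step(variables, x, dy)
        # JAX dispatch is asynchronous, so wait on the whole result pytree.
        jax.block_until_ready(result)
    mean_s = (time.perf_counter() - start) / iters
    return mean_s, result


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
variables = {"kernel": jax.random.normal(k1, (16, 8))}
x = jax.random.normal(k2, (4, 16))
dy = jax.random.normal(k3, (4, 8))

mean_s, (loss, (grad_vars, grad_x)) = benchmark_fwd_bwd(variables, x, dy)
print(f"Mean time: {mean_s * 1e3:.3f} ms")
```

Blocking on the result inside the timed loop matters because, without it, the loop would measure only dispatch overhead rather than the actual forward and backward compute.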
        return

    demo = setup_demo()
    compare_forward(demo)
    run_benchmarks(demo)


if __name__ == "__main__":
main() returns immediately when TE is unavailable (non-Blackwell hardware or missing build), so the native baseline never runs either. Users who just want to validate the native JAX path — or who are running on a pre-Blackwell GPU — see a silent skip with no output. Separating the TE guard from the native baseline makes the script more useful as a standalone tutorial.
-        return
-    demo = setup_demo()
-    compare_forward(demo)
-    run_benchmarks(demo)
-if __name__ == "__main__":
+    te_supported, te_reason = te_moe_supported()
+    if not te_supported:
+        print(f"[skipped TE comparison: {te_reason}]")
+    demo = setup_demo()
+    if te_supported:
+        compare_forward(demo)
+    run_benchmarks(demo)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
…rmerEngine into jberchtold/moe-tutorial
# Conflicts:
#	build_tools/jax.py
#	tests/cpp_distributed/CMakeLists.txt
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>