Skip to content

NVFP4: cache GEMM-swizzled weight scale factors across micro-batches#3093

Open
cael-ling wants to merge 2 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-weight-swizzle-cache
Open

NVFP4: cache GEMM-swizzled weight scale factors across micro-batches#3093
cael-ling wants to merge 2 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-weight-swizzle-cache

Conversation

@cael-ling
Copy link
Copy Markdown
Contributor

Description

For block-scaled NVFP4, a cached weight is used in two GEMMs per step — fprop (row-wise scales) and dgrad (column-wise scales) — and each GEMM needs its scale factors in the GEMM-swizzled layout. Today that swizzle is recomputed lazily inside general_gemm on every micro-batch and thrown away, so with N micro-batches the weight scale swizzle runs 2*N times per step even though the weight is quantized only once, which hurts performance. (Activation quantizers already set optimize_for_gemm=True and were pre-swizzled; only the weight was missed.)

This PR sets weight_quantizer.optimize_for_gemm=True on the cached, non-FSDP path so the swizzle is done once at quantize time, persisted on the cached workspace (_with_gemm_swizzled_scales=True), and reused by every GEMM — 2*N2 swizzles per step.

  • Applied to Linear, LayerNormLinear, LayerNormMLP (fc1 + fc2) and GroupedLinear (per expert).

  • Gated to the cached path (is_first_microbatch is not None) with fsdp_group is None and not is_fsdp2: FSDP/FSDP2 all-gather weights using the un-swizzled scale layout, so pre-swizzling is unsupported there.

  • No-op for recipes whose scales do not require swizzling (e.g. per-tensor FP8).

  • Swizzling is a pure layout permutation, so numerics are unchanged.

  • New tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py: asserts the cached eager-swizzle path matches the lazy-swizzle baseline (fprop + dgrad) for Linear / LayerNormLinear / GroupedLinear, and that _with_gemm_swizzled_scales is set and persisted on the cached workspace.

  • pytest tests/pytorch/test_numerics.py -k "linear or layernorm or mlp" — no regressions.

  • pytest tests/pytorch/test_grouped_linear.py -k "not grouped_tensor and not fused_path" — no regressions.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…obatches

For block-scaled NVFP4 a cached weight participates in two GEMMs per step:
fprop (rowwise scales) and dgrad (columnwise scales). The GEMM-ready scale
swizzle was recomputed lazily inside every GEMM and discarded, so with N
microbatches the weight scale swizzle ran 2*N times per step even though the
weight is quantized only once.

Because weight RHT is disabled, the weight scales are not swizzled by the
cast-fusion path; with optimize_for_gemm off they also skip the post-quantize
fallback swizzle, so the only swizzle site left for the weight is the lazy one
inside general_gemm (swizzle_scales_for_gemm), which re-runs on every GEMM.
(Activation input/grad_output quantizers already set optimize_for_gemm=True, so
they were pre-swizzled via cast-fusion/fallback; only the weight was missed.)

Set weight_quantizer.optimize_for_gemm=True on the cached, non-FSDP path so the
swizzle is done once at quantize time (via the post-quantize fallback),
persisted on the cached workspace (_with_gemm_swizzled_scales=True), and reused
by every GEMM (swizzle_scales_for_gemm early-returns) -> 2 swizzles per step
instead of 2*N. Applied to Linear, LayerNormLinear, LayerNormMLP (fc1+fc2) and
GroupedLinear (per expert).

Gated to the cached path (is_first_microbatch is not None) with fsdp_group is
None and not is_fsdp2: FSDP/FSDP2 all-gather weights using the un-swizzled
scale layout, so pre-swizzling is unsupported there. No-op for recipes whose
scales do not require swizzling (e.g. per-tensor FP8). Swizzling is a pure
layout permutation, so numerics are unchanged.

Add tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py verifying the cached
eager-swizzle path matches the lazy-swizzle baseline (fprop + dgrad) for
Linear/LayerNormLinear/GroupedLinear and that the swizzled flag is persisted.

Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling cael-ling requested a review from ksivaman as a code owner June 5, 2026 14:29
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 5, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 5, 2026

Greptile Summary

This PR eliminates redundant GEMM-swizzle operations for cached NVFP4 block-scaled weights by setting weight_quantizer.optimize_for_gemm = True on the cached, non-FSDP forward path in Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear, reducing 2*N swizzle kernels per step to 2.

  • Four module files each receive a new block that sets optimize_for_gemm based on whether a weight cache name is active (cache_name is not None) and FSDP is absent (fsdp_group is None); the is_fsdp2 exclusion is handled implicitly through the cache_name derivation in linear.py/layernorm_*.py and explicitly via not self.is_fsdp2 in grouped_linear.py — both are correct.
  • New test file test_nvfp4_weight_swizzle_cache.py verifies fprop/dgrad numerical parity and _with_gemm_swizzled_scales flag persistence for Linear, LayerNormLinear, and GroupedLinear, but LayerNormMLP (which has a structurally distinct fc1/fc2 two-quantizer path) is not exercised.

Confidence Score: 4/5

The four module changes are a clean, consistent gating of optimize_for_gemm on the cached non-FSDP path; the only gap is that LayerNormMLP's fc1/fc2 two-quantizer path is not covered by the new tests.

The optimization logic is correct and consistent across all four modules — FSDP1 and FSDP2 exclusions are properly handled. The new test file exercises three of the four modified code paths and verifies both numerical parity and flag persistence. LayerNormMLP is modified in a structurally distinct way (two independent quantizers) but has no corresponding test, leaving a gap that could hide a future regression in that path.

tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py — missing LayerNormMLP test case

Important Files Changed

Filename Overview
tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py New test file covering fprop/dgrad numerical parity and swizzle-flag persistence; misses LayerNormMLP despite it being one of the four modified modules
transformer_engine/pytorch/module/linear.py Correctly gates optimize_for_gemm=True on cache_name is not None (excludes is_fsdp2 implicitly) and fsdp_group is None
transformer_engine/pytorch/module/layernorm_linear.py Same gating pattern as linear.py; is_fsdp2 excluded via cache_name=None, fsdp_group exclusion explicit — correct
transformer_engine/pytorch/module/layernorm_mlp.py fc1/fc2 quantizers each independently gated; logic matches linear.py pattern and is correct, but lacks a corresponding test
transformer_engine/pytorch/module/grouped_linear.py Explicit not self.is_fsdp2 guard added alongside fsdp_group is None; consistent with other modules' effective gating despite a different cache_weight derivation

Sequence Diagram

sequenceDiagram
    participant T as Training Loop
    participant M as TE Module
    participant Q as WeightQuantizer
    participant G as general_gemm

    Note over T,G: Cached path (is_first_microbatch is not None, no FSDP)
    T->>M: "forward(x, is_first_microbatch=True)"
    M->>Q: "set optimize_for_gemm = True"
    M->>Q: quantize(weight) → swizzle scales eagerly
    Q-->>M: FP4Tensor (swizzled scales cached in workspace)
    M->>G: fprop GEMM (cached swizzled rowwise scales)
    G-->>M: output y
    M->>G: dgrad GEMM (cached swizzled columnwise scales)
    G-->>M: dx
    T->>M: "forward(x2, is_first_microbatch=False)"
    M->>Q: "set optimize_for_gemm = True"
    Note over M,Q: Weight already cached — skip requantize
    M->>G: fprop GEMM (reuse cached swizzled scales)
    G-->>M: output y2
    M->>G: dgrad GEMM (reuse cached swizzled scales)
    G-->>M: dx2

    Note over T,G: Uncached path (is_first_microbatch=None) or FSDP
    T->>M: "forward(x, is_first_microbatch=None)"
    M->>Q: "set optimize_for_gemm = False"
    M->>G: fprop GEMM (lazy swizzle rowwise inside GEMM)
    G-->>M: output y
    M->>G: dgrad GEMM (lazy swizzle columnwise inside GEMM)
    G-->>M: dx
Loading

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines +67 to +72
with te.autocast(enabled=True, recipe=recipe):
out = module(x, is_first_microbatch=is_first)
out.sum().backward()
return out.detach().float(), x.grad.detach().float()


Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing LayerNormMLP test coverage

layernorm_mlp.py is one of four files modified by this PR, yet the test suite parametrizes only over ["Linear", "LayerNormLinear"] for both test_weight_swizzle_cache_numerics and test_lazy_path_not_swizzled. The fc1/fc2 two-quantizer path in LayerNormMLP is structurally different from the single-quantizer modules: it independently gates fc1_weight_quantizer.optimize_for_gemm and fc2_weight_quantizer.optimize_for_gemm using separate cache_name_fc1/cache_name_fc2 variables. If either gating expression were wrong (e.g. swapping fc1/fc2 names), existing tests would not catch it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant