[PyTorch] Debug CPU offloading in grouped linear and grouped MLP#3047

Merged
timmoon10 merged 7 commits into
NVIDIA:mainfrom
lhb8125:feat/selective-offload-on-srelu-fuser
Jun 6, 2026

Conversation

@lhb8125
Contributor

@lhb8125 lhb8125 commented May 27, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 27, 2026
@lhb8125 lhb8125 force-pushed the feat/selective-offload-on-srelu-fuser branch from dba3531 to 53e6511 Compare May 27, 2026 07:26
@greptile-apps
Contributor

greptile-apps Bot commented May 27, 2026

Greptile Summary

This PR wires CPU activation offloading into the fused grouped-MLP forward path and into the standalone GroupedLinear op, and adds the get_data_tensors() method to GroupedTensorStorage that the V1 offload path requires. The core insight is that GroupedTensor is a torch.Tensor subclass, so the offload infrastructure's prepare_for_saving would treat it as a flat tensor; the fix is to ensure activations are represented as GroupedTensorStorage before start_offload is called.

  • forward_grouped_mlp.py: Sets fc1_input_quantizer.internal = True so the MXFP8/NVFP4 fast-path produces a GroupedTensorStorage, then repacks the pre-quantized GroupedTensor input path into a GroupedTensorStorage copy. After forward, calls start_offload/mark_activation_offload on grouped_fc1_x, activation_in, and saved_grouped_fc2_x.
  • grouped_linear.py: Adds start_offload for grouped_x in the grouped-tensor forward path and start_offload + mark_activation_offload for split-xs in the split-quantize path; also marks grouped activations in fuser_forward_save_ctx.
  • grouped_tensor_storage.py: Adds get_data_tensors() returning the same 10 tensor fields exposed by prepare_for_saving, used by the V1 code path.
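
The dispatch problem described in this summary can be illustrated with a minimal sketch. It uses toy stand-in classes rather than Transformer Engine or PyTorch: `PlainTensor`, `ToyGroupedStorage`, `ToyGroupedTensor`, and `tensors_to_offload` are all hypothetical names, and the class hierarchy is an assumption made only to reproduce the "subclass is treated as a flat tensor" behavior.

```python
class PlainTensor:
    """Stand-in for torch.Tensor (hypothetical, for illustration only)."""

    def __init__(self, name):
        self.name = name


class ToyGroupedStorage:
    """Stand-in for a storage container that holds several component buffers."""

    def __init__(self, rowwise, columnwise):
        self.rowwise = rowwise
        self.columnwise = columnwise

    def components(self):
        # Every buffer that would need to be scheduled for offloading.
        return [t for t in (self.rowwise, self.columnwise) if t is not None]


class ToyGroupedTensor(PlainTensor, ToyGroupedStorage):
    """Stand-in for a wrapper that is both a tensor and a storage container."""

    def __init__(self, rowwise, columnwise):
        PlainTensor.__init__(self, "wrapper")
        ToyGroupedStorage.__init__(self, rowwise, columnwise)


def tensors_to_offload(obj):
    """Toy dispatch: the plain-tensor check runs first, so a wrapper that
    subclasses the tensor type is never decomposed into its components."""
    if isinstance(obj, PlainTensor):
        return [obj]
    return obj.components()
```

With these toy classes, passing a `ToyGroupedTensor` yields only the wrapper, while passing a `ToyGroupedStorage` built from the same buffers yields both the rowwise and columnwise components. That is the motivation for representing activations as the storage type before offloading starts.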

Confidence Score: 3/5

The fused-MLP offload path is largely correct, but the standalone GroupedLinear grouped-tensor path has an incomplete offload that silently fails to schedule columnwise activations for CPU transfer in the non-V1 path.

The _fuser_forward_grouped_tensor method in grouped_linear.py omits input_quantizer.internal = True before tex.group_quantize, so grouped_x is a GroupedTensor (torch.Tensor subclass). When start_offload(grouped_x) is called, prepare_for_saving routes through the plain-tensor branch and attaches start_reload_event only to the wrapper's rowwise storage — the same storage that is nulled out two lines later. The columnwise_data needed for the wgrad backward is never scheduled for offloading. This directly undermines the feature this PR is implementing for the non-V1 path in that code branch.

transformer_engine/pytorch/ops/basic/grouped_linear.py — the _fuser_forward_grouped_tensor method needs input_quantizer.internal = True before tex.group_quantize, mirroring what forward_grouped_mlp.py already does for fc1_input_quantizer.

Important Files Changed

Filename | Overview
--- | ---
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Adds CPU offloading for fused grouped-MLP activations (fc1/fc2 inputs and activation_in). Repacks the MXFP8/NVFP4 fast-path input into GroupedTensorStorage to avoid prepare_for_saving treating GroupedTensor as a plain tensor; fc2_input_quantizer.internal is still not set to True, leaving grouped_fc2_x as a GroupedTensor in the NVFP4 path (previously flagged, not fixed here).
transformer_engine/pytorch/ops/basic/grouped_linear.py | Adds CPU offload hooks to both the split-quantize and grouped-tensor forward paths. In _fuser_forward_grouped_tensor (quantized path), tex.group_quantize returns GroupedTensor without internal=True, so start_offload only marks the rowwise wrapper (not columnwise_data or columnwise_scale_inv) for the non-V1 offload path.
transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py | Adds get_data_tensors() returning the same 10 fields as prepare_for_saving(). Required by the V1 cpu_offload path which calls tensor.get_data_tensors() for non-plain-tensor types. Implementation matches prepare_for_saving exactly.
transformer_engine/pytorch/ops/basic/basic_linear.py | Adds an explanatory comment for the existing mark_activation_offload call, clarifying why no special weight offload logic is needed. No functional change.
transformer_engine/pytorch/module/grouped_linear.py | Mechanical replacement of GroupedTensor with GroupedTensorStorage in two helper methods (_wrap_grouped_tensor and _pack_grouped_bias). Return types updated to match.
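
As a rough illustration of why get_data_tensors() and prepare_for_saving() should expose the same fields, here is a toy container in which both methods are driven by a single field list. This is a sketch under assumptions, not the actual GroupedTensorStorage: the class name `ToyGroupedStorage` and the four-entry `_FIELDS` tuple are hypothetical (the real class is described above as exposing 10 fields), and the method bodies only mimic the general save/restore pattern.

```python
from typing import List, Optional, Tuple


class ToyGroupedStorage:
    # Hypothetical subset of buffer names; the real class exposes 10 fields.
    _FIELDS = ("rowwise_data", "columnwise_data", "scale_inv", "columnwise_scale_inv")

    def __init__(self, **buffers):
        for name in self._FIELDS:
            setattr(self, name, buffers.get(name))

    def get_data_tensors(self) -> Tuple[Optional[object], ...]:
        """Return every buffer (possibly None) without modifying the container."""
        return tuple(getattr(self, name) for name in self._FIELDS)

    def prepare_for_saving(self) -> Tuple[List[Optional[object]], "ToyGroupedStorage"]:
        """Hand the buffers to the caller and clear them on the container."""
        tensors = list(self.get_data_tensors())
        for name in self._FIELDS:
            setattr(self, name, None)
        return tensors, self

    def restore_from_saved(self, tensors: List[Optional[object]]) -> List[Optional[object]]:
        """Reattach buffers in field order and return any leftover entries."""
        for name, tensor in zip(self._FIELDS, tensors):
            setattr(self, name, tensor)
        return tensors[len(self._FIELDS):]
```

Deriving both methods from one field list is a design choice of this sketch: it keeps the two code paths from drifting apart if a field is added later.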

Sequence Diagram

sequenceDiagram
    participant Caller
    participant ForwardGroupedMLP
    participant GroupedLinear
    participant cpu_offload

    Caller->>ForwardGroupedMLP: fuser_forward(input_)
    ForwardGroupedMLP->>ForwardGroupedMLP: "fc1_input_quantizer.internal = True"
    ForwardGroupedMLP->>ForwardGroupedMLP: group_quantize → GroupedTensorStorage (fc1)
    ForwardGroupedMLP->>ForwardGroupedMLP: FC1 GEMM + activation
    ForwardGroupedMLP->>ForwardGroupedMLP: group_quantize → grouped_fc2_x (GroupedTensor, NVFP4)
    ForwardGroupedMLP->>ForwardGroupedMLP: FC2 GEMM
    ForwardGroupedMLP->>cpu_offload: start_offload(grouped_fc1_x, activation_in, grouped_fc2_x)
    cpu_offload->>cpu_offload: prepare_for_saving → decompose GroupedTensorStorage
    cpu_offload->>cpu_offload: mark component tensors with start_reload_event
    cpu_offload->>cpu_offload: restore_from_saved
    ForwardGroupedMLP->>cpu_offload: mark_activation_offload(...) [V1 path only]
    cpu_offload->>cpu_offload: "get_data_tensors() → set activation_offloading=True"
    ForwardGroupedMLP->>ForwardGroupedMLP: save_for_backward(grouped_fc1_x, grouped_fc2_x, ...)
    Note over GroupedLinear: standalone (non-fused) path
    Caller->>GroupedLinear: _fuser_forward_grouped_tensor
    GroupedLinear->>GroupedLinear: "tex.group_quantize → GroupedTensor (no internal=True!)"
    GroupedLinear->>cpu_offload: start_offload(grouped_x)
    cpu_offload->>cpu_offload: "isinstance(GroupedTensor, Tensor)=True → plain-tensor branch"
    cpu_offload->>cpu_offload: start_reload_event on wrapper only (columnwise_data NOT marked)

Reviews (14): Last reviewed commit: "Construct internal grouped tensors withi..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
@lhb8125 lhb8125 force-pushed the feat/selective-offload-on-srelu-fuser branch from 2c59510 to 6e01d0a Compare May 27, 2026 07:49
Member

@timmoon10 timmoon10 left a comment

Overall looks good, but with one design suggestion.

Followup tasks after merging this PR:

  • Enable activation checkpointing in the unfused grouped linear op.
  • Update activation checkpointing to support v2 infrastructure from #1762, which is opt-out rather than opt-in.

Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
fanshiqing added a commit to fanshiqing/TransformerEngine that referenced this pull request Jun 2, 2026
@lhb8125 lhb8125 force-pushed the feat/selective-offload-on-srelu-fuser branch from 6f5ef0a to c14deb7 Compare June 2, 2026 10:58
@lhb8125
Contributor Author

lhb8125 commented Jun 2, 2026

/te-ci pytorch L1

Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
@lhb8125 lhb8125 force-pushed the feat/selective-offload-on-srelu-fuser branch from c14deb7 to f25a1f5 Compare June 2, 2026 11:09
Comment on lines +680 to +681
no_offload_fc1_activation = bool(getattr(fc1_op, "no_offload_activation", False))
no_offload_moe_activation = bool(getattr(activation_op, "no_offload_activation", False))
Member

@timmoon10 timmoon10 Jun 3, 2026

Fine, but we should avoid relying on these in Mcore. They should be considered as debugging tools and we don't make guarantees to maintain them.

If we do want to control this in Mcore, we should do it through proper, user-facing APIs. We can add it as an arg to the unfused ops, but we should also make sure the unfused ops actually respect it too.

Contributor Author

Yeah, I thought about the same thing but didn't want to add too many changes. I have now added offload_activation to all related modules, so the behavior should be more consistent.

Member

@timmoon10 timmoon10 Jun 5, 2026

I don't like that this is ballooning the scope of this PR and adds a maintenance burden to so many ops. It just makes it that much more painful to implement a new fusion. I've reverted the new arg so that this PR just fixes the offloading bug in grouped MLP.

For a more general solution to selective offloading, I wonder if we can expand the existing APIs for CPU offloading. If you want to explicitly disable in one layer, what if we allowed nesting like this:

with get_cpu_offload_context(...):
    x = layer0(x)  # CPU offloading
    with get_cpu_offload_context(enabled=False, ...):
        x = layer1(x)  # No CPU offloading
    x = layer2(x)  # CPU offloading

This would work automatically, with no extra logic needed in every op.
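
One way such nesting could behave is sketched below with a toy context manager. Nothing here is the Transformer Engine API: `toy_offload_context`, `offload_enabled`, and the module-level stack are hypothetical names, and the sketch only shows that an inner `enabled=False` scope can override an outer one and be undone on exit.

```python
from contextlib import contextmanager

# Hypothetical module-level state: the innermost active setting is last.
# The initial entry means "offloading is off outside any context".
_offload_enabled_stack = [False]


def offload_enabled():
    """Return the setting of the innermost active context."""
    return _offload_enabled_stack[-1]


@contextmanager
def toy_offload_context(enabled=True):
    """Push a setting on entry and restore the previous one on exit."""
    _offload_enabled_stack.append(enabled)
    try:
        yield
    finally:
        _offload_enabled_stack.pop()


def run_layers():
    """Record which toy layers would see offloading enabled."""
    seen = []
    with toy_offload_context():
        seen.append(("layer0", offload_enabled()))
        with toy_offload_context(enabled=False):
            seen.append(("layer1", offload_enabled()))
        seen.append(("layer2", offload_enabled()))
    seen.append(("outside", offload_enabled()))
    return seen
```

In this sketch an op would only need to call `offload_enabled()` at the point where it saves activations. A plain list is not thread-safe; `contextvars.ContextVar` from the standard library would be one alternative if that matters.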

and isinstance(input_quantizer, NVFP4Quantizer)
):
grouped_fc1_x = input_
grouped_fc1_x = GroupedTensorStorage(
Member

Now that we set the input quantizer with .internal = True, isn't it redundant to repack grouped_fc1_x into a GroupedTensorStorage?

Contributor Author

I think they target different cases. Setting .internal = True on the input quantizer takes effect for bf16 input, which we need to quantize with fc1_input_quantizer. The second case is when the input is already a quantized FP8 tensor, which we need to repack into a GroupedTensorStorage.

Member

@timmoon10 timmoon10 Jun 5, 2026

How could it already be quantized? The only way to create it is from the quantizer, either just now or from a previous step (e.g. with activation recompute). If the quantizer has .internal=True, it can only be GroupedTensorStorage.

If something is incorrectly producing GroupedTensor, then that's a bug. Fixing it here is papering over the real problem.

Actually, on second thought, it makes sense that input_ can be a GroupedTensor since it comes from outside the op. It would be useful to know what use-case hit this bug though. Activation recompute?

Really the root cause is that CPU offloading doesn't handle GroupedTensor gracefully, but that would be a more involved effort.
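
The two cases discussed in this thread (a quantizer with .internal = True versus an already-quantized input arriving from outside the op) can be sketched with toy classes. All names below (`ToyGroupedStorage`, `ToyGroupedTensor`, `ToyQuantizer`, `as_storage`) are hypothetical stand-ins, the int8-style scaling is arbitrary, and the sketch reuses the same buffers when repacking; whether the real repack shares or copies them is not something this example asserts.

```python
from typing import List, Sequence


class ToyGroupedStorage:
    """Lightweight internal container (stand-in, not the TE class)."""

    def __init__(self, data: List[int], scale: float) -> None:
        self.data = data
        self.scale = scale


class ToyGroupedTensor(ToyGroupedStorage):
    """Public-facing wrapper (stand-in); the real one is also a torch.Tensor."""


class ToyQuantizer:
    def __init__(self, internal: bool = False) -> None:
        self.internal = internal

    def quantize(self, values: Sequence[float]) -> ToyGroupedStorage:
        # Arbitrary symmetric scaling, only so the example has real data.
        amax = max((abs(v) for v in values), default=0.0)
        scale = amax / 127 if amax > 0 else 1.0
        data = [round(v / scale) for v in values]
        # The internal flag decides which container type is produced.
        cls = ToyGroupedStorage if self.internal else ToyGroupedTensor
        return cls(data, scale)


def as_storage(x: ToyGroupedStorage) -> ToyGroupedStorage:
    """Repack a public wrapper into the plain storage type.

    Inputs that are already plain storage pass through unchanged.
    """
    if type(x) is ToyGroupedStorage:
        return x
    return ToyGroupedStorage(x.data, x.scale)
```

Case one is `ToyQuantizer(internal=True).quantize(...)`, which never produces the wrapper. Case two is a wrapper created elsewhere and handed to the op, where the quantizer flag has no effect and `as_storage` is what normalizes the type.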

lhb8125 added 3 commits June 5, 2026 02:54
Signed-off-by: hongbinl <hongbinl@nvidia.com>
Signed-off-by: hongbinl <hongbinl@nvidia.com>
Signed-off-by: hongbinl <hongbinl@nvidia.com>
@lhb8125 lhb8125 force-pushed the feat/selective-offload-on-srelu-fuser branch from 07f4c97 to 933d64b Compare June 5, 2026 09:55
- Revert per-module offload_activation API added in commits 376d28c
  and 933d64b; that belongs in a separate PR.
- ops/basic/grouped_linear: add start_offload on input tensors before
  the GEMM, and mark_activation_offload / mark_not_offload in
  fuser_forward_save_ctx for both the split-quantize and grouped-tensor
  paths.
- ops/fused/forward_grouped_mlp: remove no_offload_activation attribute
  lookups and the activation mark_not_offload calls that gated on them;
  add start_offload + mark_activation_offload for all saved activation
  tensors (grouped_fc1_x, activation_in, saved_grouped_fc2_x) and keep
  mark_not_offload only for weight tensors. Document why grouped_fc1_x
  is repacked into GroupedTensorStorage.
- ops/basic/basic_linear: no change needed beyond the existing
  mark_activation_offload — unlike te.Linear there is no persistent
  weight cache, so the quantized weight workspace can be freely
  offloaded.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10
Member

/te-ci pytorch

pre-commit-ci Bot and others added 3 commits June 5, 2026 23:39
GroupedTensor should only be used when exposed externally. Otherwise GroupedTensorStorage has less CPU overhead. There also seems to be some issue with CPU offloading that has not yet been root-caused.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 changed the title Feat/selective offload on srelu fuser Debug CPU offloading in grouped linear and grouped MLP Jun 6, 2026
@timmoon10 timmoon10 changed the title Debug CPU offloading in grouped linear and grouped MLP [PyTorch] Debug CPU offloading in grouped linear and grouped MLP Jun 6, 2026
@timmoon10
Member

/te-ci pytorch

Member

@timmoon10 timmoon10 left a comment

I propose merging this quickly to unblock CPU offloading with grouped MLP. Afterwards, we can work on a more general interface for selectively enabling/disabling in specific modules.

@timmoon10 timmoon10 merged commit 3fffa55 into NVIDIA:main Jun 6, 2026
23 of 25 checks passed
