Skip to content

[PyTorch] Add joint forward-backward op fusion pass#3080

Open
timmoon10 wants to merge 1 commit into
NVIDIA:mainfrom
timmoon10:tmoon/joint-forward-backward-fusions
Open

[PyTorch] Add joint forward-backward op fusion pass#3080
timmoon10 wants to merge 1 commit into
NVIDIA:mainfrom
timmoon10:tmoon/joint-forward-backward-fusions

Conversation

@timmoon10
Copy link
Copy Markdown
Member

Description

The op fuser assumes that forward fusions and backward fusions can be applied independently, so we enforce a contract that fused ops are interchangeable with the corresponding unfused ops. However, in the process of developing the grouped MLP fused ops, we have identified several optimizations that only make sense when the forward and backward fusions are performed together (e.g. recomputing FC2 input tensors instead of caching for backward).

This PR adds infrastructure to support joint forward-backward fusions, which relax the interchangeability contract so the forward and backward passes can have coupled implementations. Refactoring the grouped MLP ops as such a joint fused op will be deferred to a follow-up PR.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add joint forward-backward op fusion pass

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Introduce a third operation fusion pass for joint forward-backward
fusions, applied before the forward-only and backward-only passes. A
joint fused op implements both fuser_forward and fuser_backward, so the
two halves can cooperate (e.g. the forward saving reduced state that
only its own backward knows how to recompute) and need not be
individually interchangeable with the unfused ops.

Add register_forward_backward_fusion, split the fusion application and
basic-op reconciliation in OperationFuser so the forward/backward passes
build on the joint grouping, and spell out the interchangeability
contracts in the register_*_fusion docstrings. Add a custom joint
fusion unit test and a user guide section.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 requested a review from vthumbe1503 June 4, 2026 00:27
@timmoon10 timmoon10 added the enhancement New feature or request label Jun 4, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 4, 2026

Greptile Summary

This PR adds a new "joint forward-backward" fusion pass to the OperationFuser pipeline in TransformerEngine's PyTorch ops framework. Joint fusions let a single FusedOperation own both fuser_forward and fuser_backward, relaxing the interchangeability contract so the forward pass can save reduced state that only its own backward knows how to consume.

  • Refactors _fuse_ops (classmethod) into two focused statics — _apply_fusions (applies a chain of fusion functions) and _match_basic_ops (maps results back to basic-op indices with an added bounds guard) — and inserts the joint-fusion pass before the existing forward-only and backward-only passes.
  • Adds register_forward_backward_fusion, exports it from ops/__init__.py, documents it in both the API reference and the user guide with a fully worked LinearSiLU example, and includes a comprehensive test (test_custom_forward_backward_fused_op) that verifies numerics and asserts the same fused-op object appears in both _forward_ops and _backward_ops.

Confidence Score: 4/5

Safe to merge; the production code path is correct and existing forward/backward fusion functions handle mixed op-type lists safely via isinstance checks.

The refactoring of _fuse_ops into _apply_fusions + _match_basic_ops is clean and the three-pass ordering is correct. The test's _enabled flag-based one-shot pattern is fragile (silent fusion loss on re-run) and the fuse_ops helper lacks defensive bounds/type checks, but neither affects the production implementation.

tests/pytorch/test_fusible_ops.py — the _enabled class-variable guard and bare indexing in fuse_ops are worth hardening before this test is extended.

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/fuser.py Core change: _fuse_ops split into _apply_fusions + _match_basic_ops; joint fusion pass inserted before forward/backward passes; register_forward_backward_fusion added with clear docstring. Logic is sound.
tests/pytorch/test_fusible_ops.py Adds test_custom_forward_backward_fused_op; uses a class-level _enabled flag to restrict fusion to a single application, but the flag persists as False if maybe_fuse_ops is re-triggered (e.g. recipe change), which would silently drop the joint fusion on subsequent calls.
transformer_engine/pytorch/ops/op.py Docstring-only update to FusedOperation clarifying the equivalence contracts for forward-only, backward-only, and joint fused ops. No logic changes.
docs/examples/op_fuser/op_fuser.rst New Joint forward-backward fusions section added with a well-explained LinearSiLU worked example and a fuse_linear_silu sliding-window function. No issues found.
transformer_engine/pytorch/ops/init.py Adds register_forward_backward_fusion to the module's public exports alongside the two existing register functions. Alphabetically sorted.
docs/api/pytorch.rst Adds autoapifunction directive for register_forward_backward_fusion in the correct location.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant OperationFuser
    participant JointFuseFunc as register_forward_backward_fusion funcs
    participant FwdFuseFunc as register_forward_fusion funcs
    participant BwdFuseFunc as register_backward_fusion funcs
    participant AutogradFn as _OperationFuserAutogradFunction

    Caller->>OperationFuser: __call__(input, ...)
    OperationFuser->>OperationFuser: maybe_fuse_ops()

    Note over OperationFuser: Pass 1 — Joint forward-backward fusions
    OperationFuser->>JointFuseFunc: _apply_fusions(basic_ops, joint_funcs)
    JointFuseFunc-->>OperationFuser: joint_ops (may contain FusedOperation)

    Note over OperationFuser: Pass 2 — Forward-only fusions (on joint_ops)
    OperationFuser->>FwdFuseFunc: _apply_fusions(joint_ops, fwd_funcs)
    FwdFuseFunc-->>OperationFuser: fwd_result
    OperationFuser->>OperationFuser: _match_basic_ops(fwd_result, basic_ops) → _forward_ops

    Note over OperationFuser: Pass 3 — Backward-only fusions (on joint_ops)
    OperationFuser->>BwdFuseFunc: _apply_fusions(joint_ops, bwd_funcs)
    BwdFuseFunc-->>OperationFuser: bwd_result
    OperationFuser->>OperationFuser: _match_basic_ops(bwd_result, basic_ops) → _backward_ops

    Caller->>AutogradFn: apply(input, fuser, ...)
    AutogradFn->>AutogradFn: forward() — runs _forward_ops
    AutogradFn->>AutogradFn: backward() — runs _backward_ops (same joint op instance)
Loading

Reviews (1): Last reviewed commit: "[PyTorch] Add joint forward-backward op ..." | Re-trigger Greptile

Comment on lines +5185 to +5194
dsilu = s * (1 + y * (1 - s))
dy = dout * dsilu

# Linear backward
dx = torch.matmul(dy, w).to(dtype=dtype)
dw = torch.matmul(dy.T, x).to(dtype=dtype)

# grad_input, grad params per basic op, grad extra inputs per basic op
return dx, [(dw,), ()], [(), ()]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 _enabled flag silently drops the joint fusion on re-fuse

CustomLinearSiLU._enabled is set to False after the first call to fuse_ops, so if maybe_fuse_ops is triggered a second time (e.g. recipe type changes, first_op_requiring_backward shifts, or amax-history length changes), the fusion function returns the ops list unchanged. On that re-run _forward_ops and _backward_ops would revert to the two unfused basic ops, causing the assert isinstance(forward_ops[0][0], CustomLinearSiLU) assertions below to fail silently or with a confusing error rather than a clear "joint fusion was not reapplied" message. The current test is safe because it only calls the model once, but the pattern is fragile: any future extension that adds a second forward call (e.g., to test different recipe configurations) will break without an obvious explanation. Consider either resetting _enabled at the start of each maybe_fuse_ops-triggering call, or restructuring the fuse function to be idempotent (fuse only if the first op is not already a CustomLinearSiLU).

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test is quick and hacky. The right fix would be a way to unregister fusions, but that's outside the scope of this PR.

Comment on lines +5196 to +5204
def fuse_ops(
ops: list[FusibleOperation],
**unused,
) -> list[FusibleOperation]:
"""Apply fusion the first time this function is called"""
if CustomLinearSiLU._enabled:
CustomLinearSiLU._enabled = False
op = CustomLinearSiLU(linear=ops[0], silu=ops[1])
return [op] + ops[2:]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 fuse_ops indexes ops without bounds or type guards

When _enabled is True, the function unconditionally accesses ops[0] and ops[1] and constructs CustomLinearSiLU(linear=ops[0], silu=ops[1]) without checking len(ops) >= 2 or that the ops are the expected types. If a future pipeline change reduces the number of basic ops to fewer than two (or the joint-fusion pass is called earlier in a different context), this raises an uncaught IndexError with no diagnostic message. Compared to the documented sliding-window pattern in op_fuser.rst — which uses isinstance checks before fusing — the test's fuse_ops skips these guards entirely, making it a less reliable reference for users adapting this code.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test is quick and hacky. This function prioritizes simplicity over robustness.

@timmoon10
Copy link
Copy Markdown
Member Author

/te-ci pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant