[PyTorch] Add joint forward-backward op fusion pass#3080
Conversation
Introduce a third operation fusion pass for joint forward-backward fusions, applied before the forward-only and backward-only passes. A joint fused op implements both fuser_forward and fuser_backward, so the two halves can cooperate (e.g. the forward saving reduced state that only its own backward knows how to recompute) and need not be individually interchangeable with the unfused ops. Add register_forward_backward_fusion, split the fusion application and basic-op reconciliation in OperationFuser so the forward/backward passes build on the joint grouping, and spell out the interchangeability contracts in the register_*_fusion docstrings. Add a custom joint fusion unit test and a user guide section. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com>
Greptile SummaryThis PR adds a new "joint forward-backward" fusion pass to the
Confidence Score: 4/5Safe to merge; the production code path is correct and existing forward/backward fusion functions handle mixed op-type lists safely via isinstance checks. The refactoring of _fuse_ops into _apply_fusions + _match_basic_ops is clean and the three-pass ordering is correct. The test's _enabled flag-based one-shot pattern is fragile (silent fusion loss on re-run) and the fuse_ops helper lacks defensive bounds/type checks, but neither affects the production implementation. tests/pytorch/test_fusible_ops.py — the _enabled class-variable guard and bare indexing in fuse_ops are worth hardening before this test is extended. Important Files Changed
Sequence DiagramsequenceDiagram
participant Caller
participant OperationFuser
participant JointFuseFunc as register_forward_backward_fusion funcs
participant FwdFuseFunc as register_forward_fusion funcs
participant BwdFuseFunc as register_backward_fusion funcs
participant AutogradFn as _OperationFuserAutogradFunction
Caller->>OperationFuser: __call__(input, ...)
OperationFuser->>OperationFuser: maybe_fuse_ops()
Note over OperationFuser: Pass 1 — Joint forward-backward fusions
OperationFuser->>JointFuseFunc: _apply_fusions(basic_ops, joint_funcs)
JointFuseFunc-->>OperationFuser: joint_ops (may contain FusedOperation)
Note over OperationFuser: Pass 2 — Forward-only fusions (on joint_ops)
OperationFuser->>FwdFuseFunc: _apply_fusions(joint_ops, fwd_funcs)
FwdFuseFunc-->>OperationFuser: fwd_result
OperationFuser->>OperationFuser: _match_basic_ops(fwd_result, basic_ops) → _forward_ops
Note over OperationFuser: Pass 3 — Backward-only fusions (on joint_ops)
OperationFuser->>BwdFuseFunc: _apply_fusions(joint_ops, bwd_funcs)
BwdFuseFunc-->>OperationFuser: bwd_result
OperationFuser->>OperationFuser: _match_basic_ops(bwd_result, basic_ops) → _backward_ops
Caller->>AutogradFn: apply(input, fuser, ...)
AutogradFn->>AutogradFn: forward() — runs _forward_ops
AutogradFn->>AutogradFn: backward() — runs _backward_ops (same joint op instance)
Reviews (1): Last reviewed commit: "[PyTorch] Add joint forward-backward op ..." | Re-trigger Greptile |
| dsilu = s * (1 + y * (1 - s)) | ||
| dy = dout * dsilu | ||
|
|
||
| # Linear backward | ||
| dx = torch.matmul(dy, w).to(dtype=dtype) | ||
| dw = torch.matmul(dy.T, x).to(dtype=dtype) | ||
|
|
||
| # grad_input, grad params per basic op, grad extra inputs per basic op | ||
| return dx, [(dw,), ()], [(), ()] | ||
|
|
There was a problem hiding this comment.
_enabled flag silently drops the joint fusion on re-fuse
CustomLinearSiLU._enabled is set to False after the first call to fuse_ops, so if maybe_fuse_ops is triggered a second time (e.g. recipe type changes, first_op_requiring_backward shifts, or amax-history length changes), the fusion function returns the ops list unchanged. On that re-run _forward_ops and _backward_ops would revert to the two unfused basic ops, causing the assert isinstance(forward_ops[0][0], CustomLinearSiLU) assertions below to fail silently or with a confusing error rather than a clear "joint fusion was not reapplied" message. The current test is safe because it only calls the model once, but the pattern is fragile: any future extension that adds a second forward call (e.g., to test different recipe configurations) will break without an obvious explanation. Consider either resetting _enabled at the start of each maybe_fuse_ops-triggering call, or restructuring the fuse function to be idempotent (fuse only if the first op is not already a CustomLinearSiLU).
There was a problem hiding this comment.
The test is quick and hacky. The right fix would be a way to unregister fusions, but that's outside the scope of this PR.
| def fuse_ops( | ||
| ops: list[FusibleOperation], | ||
| **unused, | ||
| ) -> list[FusibleOperation]: | ||
| """Apply fusion the first time this function is called""" | ||
| if CustomLinearSiLU._enabled: | ||
| CustomLinearSiLU._enabled = False | ||
| op = CustomLinearSiLU(linear=ops[0], silu=ops[1]) | ||
| return [op] + ops[2:] |
There was a problem hiding this comment.
fuse_ops indexes ops without bounds or type guards
When _enabled is True, the function unconditionally accesses ops[0] and ops[1] and constructs CustomLinearSiLU(linear=ops[0], silu=ops[1]) without checking len(ops) >= 2 or that the ops are the expected types. If a future pipeline change reduces the number of basic ops to fewer than two (or the joint-fusion pass is called earlier in a different context), this raises an uncaught IndexError with no diagnostic message. Compared to the documented sliding-window pattern in op_fuser.rst — which uses isinstance checks before fusing — the test's fuse_ops skips these guards entirely, making it a less reliable reference for users adapting this code.
There was a problem hiding this comment.
The test is quick and hacky. This function prioritizes simplicity over robustness.
|
/te-ci pytorch |
Description
The op fuser assumes that forward fusions and backward fusions can be applied independently, so we enforce a contract that fused ops are interchangeable with the corresponding unfused ops. However, in the process of developing the grouped MLP fused ops, we have identified several optimizations that only make sense when the forward and backward fusions are performed together (e.g. recomputing FC2 input tensors instead of caching for backward).
This PR adds infrastructure to support joint forward-backward fusions, which relax the interchangeability contract so the forward and backward passes can have coupled implementations. Refactoring the grouped MLP ops as such a joint fused op will be deferred to a follow-up PR.
Type of change
Changes
Checklist: