Fix GroupedLinear FP8 calibration loop#3101
Conversation
Signed-off-by: Minh Vu <vuhoangminh97@gmail.com>
506ef2d to
08c86d7
Compare
Greptile SummaryThis PR fixes a bug in
Confidence Score: 5/5Safe to merge — the change is a minimal, well-scoped removal of a redundant nested loop in the FP8 calibration path with no behavioural side effects beyond fixing the over-calibration. The diff is six lines touching a single code path that only runs during FP8 calibration. The original loop structure was unambiguously buggy (outer loop variable shadowed by two inner loops), and the replacement correctly calls each quantizer exactly once. The surrounding GEMM code, the grad bookkeeping, and every other code path are untouched. No files require special attention. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[forward called with fp8_calibration=True] --> B[Build inputmats via torch.split]
B --> C[Build weights_fp8]
C --> D[general_grouped_gemm]
D --> E{fp8_calibration?}
E -- Yes --> F["for i in range(num_gemms)"]
F --> G["input_quantizers[i].calibrate(inputmats[i])"]
G --> H["weight_quantizers[i].calibrate(weights[i])"]
H --> F
F -- done --> I[cpu_offloading / grad bookkeeping]
E -- No --> I
Reviews (2): Last reviewed commit: "Merge branch 'main' into fix/grouped-lin..." | Re-trigger Greptile |
|
/te-ci pytorch |
Summary
GroupedLinearnum_gemmstimes for every GEMMValidation
python3 -m py_compile transformer_engine/pytorch/module/grouped_linear.pygit diff --check -- transformer_engine/pytorch/module/grouped_linear.py