fix(topk): fix UB and prevent vector load splitting in standalone_topk #3088
solos wants to merge 5 commits into
Conversation
- Fix C++ UB by replacing union type-punning with `__builtin_memcpy`.
- Prevent NVCC from splitting 128-bit vector loads by using unconditional pointer selection for safe companion reads.
- Fix `len_cast_for_sync` underflow when `len_cast == 0`.
- Guard against OOB reads when input length is smaller than `sizeof(WideT)`.
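The first bullet can be illustrated with a minimal host-side C++ sketch. It is not the kernel code from this PR: `Wide128` and `unpack` are hypothetical names, and `std::memcpy` stands in for `__builtin_memcpy`. The idea is that copying the bytes of a wide value into an array of narrower elements avoids reading an inactive union member, while a copy of compile-time-constant size is normally lowered to plain register moves.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>

// Hypothetical 128-bit type standing in for WideT (for example uint4 in CUDA).
struct alignas(16) Wide128 {
  std::uint32_t x, y, z, w;
};

// Reinterpret one wide value as N narrower elements by copying its bytes.
// Unlike reading a non-active union member, this is well-defined C++.
template <typename T, std::size_t N>
void unpack(const Wide128& wide, T (&out)[N]) {
  static_assert(std::is_trivially_copyable<T>::value, "T must be trivially copyable");
  static_assert(sizeof(T) * N == sizeof(Wide128), "output must cover exactly one wide value");
  std::memcpy(out, &wide, sizeof(Wide128));
}
```

Whether NVCC emits identical code for the union and the `memcpy` variants is not shown here; that would need to be checked in the generated SASS.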
for more information, see https://pre-commit.ci
Greptile Summary
This PR refactors the vectorized memory-access helpers in
Confidence Score: 5/5

Safe to merge: changes are narrowly scoped to two helper overloads, each fix is logically complete, and no new code paths are introduced.

All four stated fixes (union UB, underflow, OOB companion read, predicated-load splitting) are correctly addressed. The `len_cast > 0` guard guarantees `in_cast[0]` is a live element before it is used as a safe fallback index, the index-clamping approach avoids the OOB pointer arithmetic flagged by the previous reviewer, and `__builtin_memcpy` is well-supported by NVCC for constant-size copies. No regressions in the scalar-tail path, which is unchanged.

No files require special attention.

Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[vectorized_process - sync overload] --> B{sizeof T >= sizeof WideT?}
B -- yes --> C[Scalar loop over tid..len]
B -- no --> D[Compute skip_cnt for alignment]
D --> E[Compute in_cast, len_cast]
E --> F{len_cast > 0?}
F -- no --> G[Skip vector loop - avoids underflow and OOB]
F -- yes --> H[Compute len_cast_for_sync = ceil len_cast / sync_width * sync_width]
H --> I[Loop i = tid to len_cast_for_sync]
I --> J{valid = i < len_cast?}
J -- yes --> K[safe_i = i]
J -- no --> L[safe_i = 0]
K --> M[Unconditional load in_cast safe_i → LDG.E.128]
L --> M
M --> N[__builtin_memcpy wide_data → local_array]
N --> O[Call f for each element with valid flag]
G --> P[Scalar tail: tid < sync_width handles skip and remain elements]
O --> P
Reviews (4): Last reviewed commit: "Merge branch 'main' into fix-topk"
Clamp the load index to 0 for invalid padding threads instead of selecting between two pointers. This eliminates language-level UB from OOB pointer formation without changing the emitted vectorized global load.
ptrendx left a comment
No. Using a union to achieve vectorized accesses is a very widespread practice in CUDA and is endorsed by the NVCC compiler experts. As for the second issue: in what case did you see it (shape of the inputs, etc.)? If `len_cast` is less than 0, then `len_cast_for_sync` is also going to be less than 0, so the loop will just not run. I don't believe this PR actually changes the behavior of the kernel in any way.
Regarding the claim that the loop won't run when `len_cast` is negative: that's not always true, because C++ integer division truncates toward zero. For example, if `len_cast = -1` and `sync_width = 32`, then `len_cast - 1 = -2`, and `-2 / 32` is 0 (not -1), so `len_cast_for_sync` becomes `(0 + 1) * 32 = 32`. The loop then executes multiple iterations even though `len_cast` is negative, with `valid` always false. While the invalid iterations are logically guarded by `valid`, the expression `&in_cast[i]` (or the pointer arithmetic in `in_cast[i]`) is still evaluated for out-of-bounds indices, which is undefined behavior. The added `if (len_cast > 0)` guard correctly prevents this edge case and does change the behavior for safety.
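The arithmetic in this reply can be checked at compile time with a small C++ sketch. The function names are hypothetical, and the round-up expression is reconstructed from the worked example above rather than copied from the kernel.

```cpp
// Round len_cast up to a multiple of sync_width, using the expression from the example above.
constexpr int round_up_unguarded(int len_cast, int sync_width) {
  return ((len_cast - 1) / sync_width + 1) * sync_width;
}

// Same computation behind a len_cast > 0 guard: the vector loop gets zero iterations otherwise.
constexpr int round_up_guarded(int len_cast, int sync_width) {
  return len_cast > 0 ? round_up_unguarded(len_cast, sync_width) : 0;
}

// Division truncates toward zero, so small non-positive lengths still round up to 32.
static_assert(round_up_unguarded(-1, 32) == 32, "-2 / 32 == 0, so the result is 32");
static_assert(round_up_unguarded(0, 32) == 32, "-1 / 32 == 0, so the result is 32");
static_assert(round_up_unguarded(33, 32) == 64, "ordinary round-up is unaffected");
static_assert(round_up_guarded(-1, 32) == 0, "the guard removes the spurious iterations");
static_assert(round_up_guarded(0, 32) == 0, "the guard also covers len_cast == 0");
```

Under this expression, any `len_cast` from -30 through 0 produces a bound of 32, which is why the `len_cast == 0` case named in the PR description is affected as well.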
Description
This PR refactors the vectorized processing logic in standalone_topk to eliminate C++ undefined behavior and ensure optimal GPU instruction emission. The changes address a critical issue where conditional vector loads were causing the compiler to split 128-bit memory transactions into scalar operations, degrading performance. Additionally, it fixes potential edge cases related to unsigned integer underflow and out-of-bounds memory access.
Fixes #
Fix C++ UB by replacing union type-punning with __builtin_memcpy.
Type of change
Changes
Please list the changes introduced in this PR:
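- Fix C++ UB by replacing union type-punning with `__builtin_memcpy`.
- Prevent NVCC from splitting 128-bit vector loads by using unconditional pointer selection for safe companion reads.
- Fix `len_cast_for_sync` underflow when `len_cast == 0`.
- Guard against OOB reads when input length is smaller than `sizeof(WideT)`.

Checklist: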