
fix(topk): fix UB and prevent vector load splitting in standalone_topk#3088

Open
solos wants to merge 5 commits into NVIDIA:main from solos:fix-topk



@solos solos commented Jun 5, 2026

Description

This PR refactors the vectorized processing logic in standalone_topk to eliminate C++ undefined behavior and ensure optimal GPU instruction emission. The changes address a critical issue where conditional vector loads were causing the compiler to split 128-bit memory transactions into scalar operations, degrading performance. Additionally, it fixes potential edge cases related to unsigned integer underflow and out-of-bounds memory access.

Fixes #


Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Fix C++ UB by replacing union type-punning with __builtin_memcpy.
  • Prevent NVCC from splitting 128-bit vector loads by using unconditional pointer selection for safe companion reads.
  • Fix len_cast_for_sync underflow when len_cast == 0.
  • Guard against OOB reads when input length is smaller than sizeof(WideT).
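
The first item can be illustrated with a minimal host-side sketch of memcpy-based type punning. It uses `std::memcpy` so it compiles without a GPU toolchain; the PR itself uses `__builtin_memcpy`. The names `Wide16`, `unpack`, and `roundtrip_sum` are hypothetical and do not come from `standalone_topk`.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical 16-byte "wide" type standing in for the kernel's WideT.
struct alignas(16) Wide16 {
  std::uint32_t w[4];
};

// Unpack one wide value into kN elements of T by copying its bytes,
// instead of writing one union member and reading another.
template <typename T, int kN>
void unpack(const Wide16& wide, T (&out)[kN]) {
  static_assert(sizeof(T) * kN == sizeof(Wide16), "size mismatch");
  std::memcpy(out, &wide, sizeof(Wide16));
}

// Illustration only: pack four floats into a Wide16, unpack them again,
// and return their sum.
inline float roundtrip_sum(float a, float b, float c, float d) {
  const float src[4] = {a, b, c, d};
  Wide16 wide;
  std::memcpy(&wide, src, sizeof(wide));
  float dst[4];
  unpack(wide, dst);
  return dst[0] + dst[1] + dst[2] + dst[3];
}
```

For a fixed, small size such as 16 bytes, compilers commonly lower the copy to register moves, but whether that holds for a given kernel should be checked in the generated code.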

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions (bot) added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) on Jun 5, 2026

greptile-apps Bot commented Jun 5, 2026

Greptile Summary

This PR refactors the vectorized memory-access helpers in standalone_topk.cuh to eliminate C++ undefined behavior and ensure NVCC emits full-width (LDG.E.128) 128-bit vector loads.

  • Replaces union-based type-punning with __builtin_memcpy in both vectorized_process overloads, fixing strict-aliasing UB.
  • Adds a len_cast > 0 guard in the sync overload to prevent idxT underflow when computing len_cast_for_sync and to ensure in_cast[0] is a valid read target before the companion-load trick is used.
  • Uses index clamping (safe_i = valid ? i : 0) instead of a pointer ternary for invalid-thread reads, ensuring NVCC sees an unconditional load and does not split it into scalar operations.
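
The index-clamping pattern from the last bullet can be sketched on the host as follows. This is an illustration of the memory-safety logic only, under the assumption that the padded loop visits every index up to a rounded-up length; it says nothing about the instructions NVCC emits. `padded_for_each`, `Tally`, and `tally_valid` are hypothetical names.

```cpp
#include <cstddef>
#include <vector>

// Every index in [0, padded_len) performs a load, but indices at or past
// `len` read element 0 and are flagged invalid.
template <typename T, typename F>
void padded_for_each(const T* in, std::size_t len, std::size_t padded_len, F f) {
  if (len == 0) return;  // mirrors the len_cast > 0 guard: in[0] must exist
  for (std::size_t i = 0; i < padded_len; ++i) {
    const bool valid = i < len;
    const std::size_t safe_i = valid ? i : 0;  // clamp the index, not the pointer
    const T value = in[safe_i];                // unconditional, always in bounds
    f(value, valid);
  }
}

struct Tally {
  long sum;
  int calls;
};

// Illustration only: sum the valid elements and count callback invocations.
inline Tally tally_valid(const std::vector<int>& data, std::size_t padded_len) {
  Tally t{0, 0};
  padded_for_each(data.data(), data.size(), padded_len,
                  [&t](int v, bool valid) {
                    ++t.calls;
                    if (valid) t.sum += v;
                  });
  return t;
}
```

With three elements and a padded length of 8, the callback runs 8 times but only the three in-range values contribute to the sum; with an empty input it never runs.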

Confidence Score: 5/5

Safe to merge — changes are narrowly scoped to two helper overloads, each fix is logically complete, and no new code paths are introduced.

All four stated fixes (union UB, underflow, OOB companion read, predicated-load splitting) are correctly addressed. The len_cast > 0 guard guarantees in_cast[0] is a live element before it is used as a safe fallback index, the index-clamping approach avoids OOB pointer arithmetic flagged by the previous reviewer, and __builtin_memcpy is well-supported by NVCC for constant-size copies. No regressions in the scalar-tail path, which is unchanged.

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/util/standalone_topk.cuh | Refactors vectorized_process in both overloads: replaces union type-punning with __builtin_memcpy, adds len_cast > 0 guard against underflow/OOB, and uses index clamping instead of a pointer ternary for unconditional 128-bit vector loads. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[vectorized_process - sync overload] --> B{sizeof T >= sizeof WideT?}
    B -- yes --> C[Scalar loop over tid..len]
    B -- no --> D[Compute skip_cnt for alignment]
    D --> E[Compute in_cast, len_cast]
    E --> F{len_cast > 0?}
    F -- no --> G[Skip vector loop - avoids underflow and OOB]
    F -- yes --> H[Compute len_cast_for_sync = ceil len_cast / sync_width * sync_width]
    H --> I[Loop i = tid to len_cast_for_sync]
    I --> J{valid = i < len_cast?}
    J -- yes --> K[safe_i = i]
    J -- no --> L[safe_i = 0]
    K --> M[Unconditional load in_cast safe_i → LDG.E.128]
    L --> M
    M --> N[__builtin_memcpy wide_data → local_array]
    N --> O[Call f for each element with valid flag]
    G --> P[Scalar tail: tid < sync_width handles skip and remain elements]
    O --> P

Reviews (4): Last reviewed commit: "Merge branch 'main' into fix-topk"

Comment thread transformer_engine/common/util/standalone_topk.cuh Outdated
Clamp the load index to 0 for invalid padding threads instead of
selecting between two pointers. This eliminates language-level UB from
OOB pointer formation without changing the emitted vectorized global load.

@ptrendx ptrendx left a comment


No. Using the union to achieve the vectorized accesses is a very wide practice in CUDA and endorsed by the NVCC compiler experts. As for the second issue - what is the case where you saw the issue (shape of the inputs etc.)? If len_cast is less than 0 then len_cast_for_sync is also going to be less than 0 and so the loop will just not run - I don't believe this PR actually changes the behavior of the kernel in any way.

@solos solos closed this Jun 6, 2026
@solos solos reopened this Jun 6, 2026

solos commented Jun 6, 2026

> No. Using the union to achieve the vectorized accesses is a very wide practice in CUDA and endorsed by the NVCC compiler experts. As for the second issue - what is the case where you saw the issue (shape of the inputs etc.)? If len_cast is less than 0 then len_cast_for_sync is also going to be less than 0 and so the loop will just not run - I don't believe this PR actually changes the behavior of the kernel in any way.

Regarding the claim that the loop won't run when len_cast is negative: that's not always true, because C++ integer division truncates toward zero. For example, if len_cast = -1 and sync_width = 32, then len_cast - 1 = -2, and -2 / 32 is 0 (not -1), so len_cast_for_sync becomes (0 + 1) * 32 = 32. The loop then executes multiple iterations even though len_cast is negative, with valid always false. Although the invalid iterations are logically guarded by valid, the expression &in_cast[i] (or the pointer arithmetic in in_cast[i]) is still evaluated for out-of-bounds indices, which is undefined behavior. The added if (len_cast > 0) guard prevents this edge case, so the PR does change the kernel's behavior, for safety.
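
The arithmetic in this reply can be checked with a small sketch. It assumes signed `int` operands and the round-up formula implied by the worked example, `((len_cast - 1) / sync_width + 1) * sync_width`; the function names are hypothetical, and the actual `idxT` in the kernel may differ.

```cpp
// Round-up formula as described in the reply, with signed operands
// (the signed type is an assumption for illustration).
constexpr int len_for_sync(int len_cast, int sync_width) {
  return ((len_cast - 1) / sync_width + 1) * sync_width;
}

// Same formula behind a guard like the one the PR adds: a non-positive
// length yields zero iterations.
constexpr int guarded_len_for_sync(int len_cast, int sync_width) {
  return len_cast > 0 ? len_for_sync(len_cast, sync_width) : 0;
}

// Integer division truncates toward zero, so -2 / 32 == 0 and -1 / 32 == 0.
static_assert(len_for_sync(-1, 32) == 32, "negative length still rounds up to 32");
static_assert(len_for_sync(0, 32) == 32, "zero length also rounds up to 32");
static_assert(len_for_sync(33, 32) == 64, "ordinary round-up");
static_assert(guarded_len_for_sync(-1, 32) == 0, "guard suppresses the loop");
```

The unguarded formula only yields a non-positive bound once `len_cast - 1` reaches `-sync_width` or below, which is why small negative or zero lengths still produce a full `sync_width` of iterations.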
