guarding max_logits fused attention for cudnn < 9.21.0 #3091
francesco-bertolotti wants to merge 1 commit
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
Greptile Summary
This PR fixes a runtime crash when
Confidence Score: 5/5
This PR is safe to merge; it adds a narrow version guard that disables a broken code path without touching any hot paths for cuDNN >= 9.21.0 users. Both changes are additive version-gating conditions that only activate on older cuDNN. The Python guard follows the established pattern.
No files require special attention.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[get_attention_backend called\nreturn_max_logit=True] --> B{use_flash_attention?}
B -- yes --> C[Disable FlashAttention\n'Disabling FA for max_logit']
B -- no --> D{use_fused_attention\nAND cudnn_version < 9.21.0?}
C --> D
D -- yes --> E[Disable FusedAttention\n'Disabling FA for max_logit < 9.21']
D -- no --> F{fp8 AND fp8_dpa?}
E --> F
F -- yes --> G[Disable all backends\nreturn no backend]
F -- no --> H{use_fused_attention\nstill enabled?}
H -- yes --> I[Use FusedAttention\nvia cuDNN >= 9.21 unified softmax node]
H -- no --> J[Fallback: UnfusedDotProductAttention]
Description
`get_attention_backend` selects FusedAttention for `return_max_logit=True` regardless of the cuDNN version, but cuDNN only supports emitting `Max` alongside the softmax `Stats` from cuDNN 9.21.0 onward. On older cuDNN versions the forward pass fails at graph-build time with:

Reproduction (observed on A100 / sm80 with cuDNN 9.10.2; the failure is cuDNN-version dependent, not architecture dependent):
Root cause
FusedAttention requests both the `Stats` and `Max` outputs from the SDPA node (`fused_attn_f16_arbitrary_seqlen.cu`, `sdpa_options.set_logit_max(Max)`). In cudnn-frontend, that combination is only representable through the unified softmax descriptor (`CUDNN_ATTR_OPERATION_SDPA_FWD_SOFTMAX_DESC`), which requires a cuDNN backend >= 9.21.0:

- `cudnn-frontend/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h`: the composite SDPA node only wires `Stats`/`Max`/`Sum_exp` into a `UnifiedSoftmaxNode` when `effective_cudnn_ver >= 92100`; on older versions only `Stats` is set (via `CUDNN_ATTR_OPERATION_SDPA_FWD_STATSDESC`).
- `cudnn-frontend/include/cudnn_frontend/node/softmax.h` (`CompositeSoftmaxNode::pre_validate_node`): the legacy composite softmax node only allows the output combinations `{}`, `{stats}`, or `{max, sum_exp}`. The `{stats, max}` combination requested for `return_max_logit=True` is rejected, producing the error above.
- `cudnn-frontend/include/cudnn_frontend/node/sdpa_support_surface.h`: the `Max` output is only added to the allowed outputs for `effective_cudnn_ver >= 92100`.

Fix
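For reviewers who want the root-cause rule in executable form, the following is a minimal Python model of the output-combination check described under Root cause. It is not cudnn-frontend code: the function and constant names are hypothetical, the integer encoding `major * 10000 + minor * 100 + patch` is assumed from the `92100` threshold quoted above, and the behavior of the unified path (accepting any subset of the three outputs) is an assumption made for illustration.

```python
from typing import Iterable

# Threshold quoted in the root-cause notes (cuDNN 9.21.0).
UNIFIED_SOFTMAX_MIN_VERSION = 92100

# Output sets the legacy composite softmax node is described as accepting.
LEGACY_ALLOWED_OUTPUTS = (
    frozenset(),
    frozenset({"stats"}),
    frozenset({"max", "sum_exp"}),
)


def encode_cudnn_version(major: int, minor: int, patch: int) -> int:
    """Hypothetical helper: pack a version so that 9.21.0 becomes 92100."""
    return major * 10000 + minor * 100 + patch


def softmax_outputs_supported(requested: Iterable[str], effective_cudnn_ver: int) -> bool:
    """Model of whether a requested softmax output set can be represented."""
    outputs = frozenset(requested)
    if effective_cudnn_ver >= UNIFIED_SOFTMAX_MIN_VERSION:
        # Assumption: the unified softmax path accepts any subset of these outputs.
        return outputs <= frozenset({"stats", "max", "sum_exp"})
    # Legacy path: only the exact combinations listed above are accepted.
    return outputs in LEGACY_ALLOWED_OUTPUTS
```

Under this model, the `{stats, max}` request made for `return_max_logit=True` is rejected for cuDNN 9.10.2 and accepted from 9.21.0, which matches the version-dependent failure reported in this PR.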
Add a filter in `get_attention_backend` that disables FusedAttention for `return_max_logit=True` when the cuDNN version is below 9.21.0, falling back to UnfusedDotProductAttention (FlashAttention is already disabled for `max_logit`). This follows the existing pattern of cuDNN-version filters in `transformer_engine/pytorch/attention/dot_product_attention/utils.py`.

The same policy is also enforced in `nvte_get_fused_attn_backend` (`transformer_engine/common/fused_attn/fused_attn.cpp`) so that non-PyTorch frontends are covered as well; the FP8 branch already checks `!return_max_logit` unconditionally:

Fixes # (issue)
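The Python-side filter can be sketched roughly as follows. This is a standalone illustration of the selection logic, not the actual diff: the function name `filter_backends_for_max_logit`, the logger name, and the log messages are hypothetical, and representing the cuDNN version as a `(major, minor, patch)` tuple is an assumption.

```python
import logging
from typing import Tuple

logger = logging.getLogger("attention_backend_sketch")  # hypothetical logger name

# First cuDNN version stated to support emitting Max alongside Stats.
MIN_CUDNN_FOR_FUSED_MAX_LOGIT = (9, 21, 0)


def filter_backends_for_max_logit(
    return_max_logit: bool,
    cudnn_version: Tuple[int, int, int],
    use_flash_attention: bool,
    use_fused_attention: bool,
) -> Tuple[bool, bool]:
    """Return the (flash, fused) flags after applying the max_logit filters."""
    if not return_max_logit:
        # The guard only applies when max logits are requested.
        return use_flash_attention, use_fused_attention
    if use_flash_attention:
        logger.debug("Disabling FlashAttention for max_logit")
        use_flash_attention = False
    if use_fused_attention and cudnn_version < MIN_CUDNN_FOR_FUSED_MAX_LOGIT:
        logger.debug("Disabling FusedAttention for max_logit with cuDNN < 9.21.0")
        use_fused_attention = False
    return use_flash_attention, use_fused_attention
```

When both flags come back `False`, backend selection would fall through to UnfusedDotProductAttention, as the description states. Tuple comparison keeps the check readable and orders versions such as `(9, 10, 2)` and `(9, 21, 0)` correctly.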
Type of change
Changes
- Disable FusedAttention in `get_attention_backend` when `return_max_logit=True` and cuDNN < 9.21.0, so backend selection falls back to UnfusedDotProductAttention instead of failing at cuDNN graph-build time.
- Enforce the same policy in `nvte_get_fused_attn_backend` (`fused_attn.cpp`, F16 arbitrary-seqlen condition) so non-PyTorch frontends are covered as well.

Checklist: