Skip to content

guarding max_logits fused attention for cudnn < 9.21.0#3091

Open
francesco-bertolotti wants to merge 1 commit into
NVIDIA:mainfrom
francesco-bertolotti:f14-max-logits-guard
Open

guarding max_logits fused attention for cudnn < 9.21.0#3091
francesco-bertolotti wants to merge 1 commit into
NVIDIA:mainfrom
francesco-bertolotti:f14-max-logits-guard

Conversation

@francesco-bertolotti
Copy link
Copy Markdown
Contributor

Description

get_attention_backend selects FusedAttention for return_max_logit=True regardless of the cuDNN version, but cuDNN only supports emitting Max alongside the softmax Stats from cuDNN 9.21.0. On older cuDNN versions the forward pass fails at graph-build time with:

RuntimeError: transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu:419
in function operator(): cuDNN Error: CompositeSoftmaxNode can only output certain
combinations of stats, max and sum_exp: stats only, max and sum_exp only, or none of the above.

Reproduction (observed on A100 / sm80 with cuDNN 9.10.2; the failure is cuDNN-version dependent, not architecture dependent):

pytest -x tests/pytorch/attention/test_attention.py::test_dpa_max_logit
# fails in the first non-skipped case, e.g.
# test_dpa_max_logit[sbhd_sbhd_sbhd-max_logit_1-model_configs0-dtype0]

Root cause

FusedAttention requests both the Stats and Max outputs from the SDPA node (fused_attn_f16_arbitrary_seqlen.cu, sdpa_options.set_logit_max(Max)). In cudnn-frontend, that combination is only representable through the unified softmax descriptor (CUDNN_ATTR_OPERATION_SDPA_FWD_SOFTMAX_DESC), which requires a cuDNN backend >= 9.21.0:

  • cudnn-frontend/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h: the composite SDPA node only wires Stats/Max/Sum_exp into a UnifiedSoftmaxNode when effective_cudnn_ver >= 92100; on older versions only Stats is set (via CUDNN_ATTR_OPERATION_SDPA_FWD_STATSDESC).
  • cudnn-frontend/include/cudnn_frontend/node/softmax.h (CompositeSoftmaxNode::pre_validate_node): the legacy composite softmax node only allows the output combinations {}, {stats}, or {max, sum_exp} — the {stats, max} combination requested for return_max_logit=True is rejected, producing the error above.
  • cudnn-frontend/include/cudnn_frontend/node/sdpa_support_surface.h: the Max output is only added to the allowed outputs for effective_cudnn_ver >= 92100.

Fix

Add a filter in get_attention_backend that disables FusedAttention for return_max_logit=True when the cuDNN version is below 9.21.0, falling back to UnfusedDotProductAttention (FlashAttention is already disabled for max_logit). This follows the existing pattern of cuDNN-version filters in transformer_engine/pytorch/attention/dot_product_attention/utils.py.

# Filter: Return max_logit
if return_max_logit:
    ...
    # FusedAttention emits max_logit alongside the softmax stats, which cuDNN only
    # supports through the unified softmax node introduced in cuDNN 9.21.0. On older
    # cuDNN the composite softmax node rejects the stats+max combination, so fall back
    # to UnfusedDotProductAttention.
    if use_fused_attention and cudnn_version < (9, 21, 0):
        use_fused_attention = False
        logger.debug("Disabling FusedAttention for max_logit for cuDNN < 9.21.0")

The same policy is also enforced in nvte_get_fused_attn_backend (transformer_engine/common/fused_attn/fused_attn.cpp) so that non-PyTorch frontends are covered as well — the FP8 branch already checks !return_max_logit unconditionally:

        // max_logit
        // pre-9.21: no (the composite softmax node rejects the Stats + Max output combination)
        // 9.21+: yes (Stats + Max via the unified softmax node)
        (!return_max_logit || cudnn_runtime_version >= 92100) &&

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Disable FusedAttention in get_attention_backend when return_max_logit=True and cuDNN < 9.21.0, so backend selection falls back to UnfusedDotProductAttention instead of failing at cuDNN graph-build time.
  • Enforce the same requirement in nvte_get_fused_attn_backend (fused_attn.cpp, F16 arbitrary-seqlen condition) so non-PyTorch frontends are covered as well.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 5, 2026
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 5, 2026

Greptile Summary

This PR fixes a runtime crash when return_max_logit=True is used with FusedAttention on cuDNN versions below 9.21.0. The composite softmax node in older cuDNN rejects the combined Stats + Max output request, causing a graph-build failure; the fix gates FusedAttention on cudnn_version >= (9, 21, 0) for this path, falling back to UnfusedDotProductAttention.

  • utils.py: Adds a version guard in get_attention_backend that disables FusedAttention for return_max_logit=True when cudnn_version < (9, 21, 0), consistent with the existing pattern of cuDNN-version filters in the same function.
  • fused_attn.cpp: Extends the F16 arbitrary-seqlen eligibility condition with (!return_max_logit || cudnn_runtime_version >= 92100), covering non-PyTorch frontends; the FP8 branch already had an unconditional !return_max_logit guard.

Confidence Score: 5/5

This PR is safe to merge; it adds a narrow version guard that disables a broken code path without touching any hot paths for cuDNN >= 9.21.0 users.

Both changes are additive version-gating conditions that only activate on older cuDNN. The Python guard follows the established pattern in get_attention_backend, and the C++ condition is inserted inline with identical logic. No existing behavior is changed for cuDNN >= 9.21.0 or for the FP8 branch.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/attention/dot_product_attention/utils.py Adds a cuDNN version guard for FusedAttention when return_max_logit=True; follows existing patterns, correctly placed before the FP8 total-disable block.
transformer_engine/common/fused_attn/fused_attn.cpp Adds (!return_max_logit

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[get_attention_backend called\nreturn_max_logit=True] --> B{use_flash_attention?}
    B -- yes --> C[Disable FlashAttention\n'Disabling FA for max_logit']
    B -- no --> D{use_fused_attention\nAND cudnn_version < 9.21.0?}
    C --> D
    D -- yes --> E[Disable FusedAttention\n'Disabling FA for max_logit < 9.21']
    D -- no --> F{fp8 AND fp8_dpa?}
    E --> F
    F -- yes --> G[Disable all backends\nreturn no backend]
    F -- no --> H{use_fused_attention\nstill enabled?}
    H -- yes --> I[Use FusedAttention\nvia cuDNN >= 9.21 unified softmax node]
    H -- no --> J[Fallback: UnfusedDotProductAttention]
Loading

Reviews (1): Last reviewed commit: "guarding max_logits fused attention for ..." | Re-trigger Greptile

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant