
Added thd cudnn guard #3092

Open
francesco-bertolotti wants to merge 2 commits into NVIDIA:main from francesco-bertolotti:f14-thd-guard


@francesco-bertolotti
Contributor

Description

nvte_get_fused_attn_backend selects the F16 arbitrary-seqlen backend for mixed THD layouts (e.g. thd_bshd_bshd, used by KV caching) on Ampere/Ada GPUs, where cuDNN does not support THD (ragged offset) tensors before 9.18.1. Instead of falling back to another attention backend, the forward pass fails at cuDNN graph-build time:

RuntimeError: transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu:418
in function operator(): cuDNN Error: THD (ragged offset) is only supported in Hopper and
above : 80.

Reproduction (observed on A100 / sm80 with cuDNN 9.10.2):

pytest -x tests/pytorch/attention/test_kv_cache.py
# fails in the first non-skipped thd case, e.g.
# test_kv_cache[False-False-TransformerLayer-FusedAttention-False-thd-infer_0-dtype0]
# with qkv_layout = thd_bshd_bshd (non-paged); paged_kv_thd_* layouts fail the same way

Root cause

The supported-format reference is cudnn-frontend's SDPA support surface (cudnn-frontend/include/cudnn_frontend/node/sdpa_support_surface.h), which rejects ragged tensors when sm < 90 && backend_version < 91801, matching the cuDNN 9.18.1 release notes ("Support for scaled dot-product attention backward with THD layout on RTX-PRO 6000 and Ampere-architecture GPUs for the F16 datatype has been added").
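The integer 91801 in the rule above corresponds to cuDNN 9.18.1. As a minimal sketch, assuming the cuDNN 9.x packing of major * 10000 + minor * 100 + patch (which is consistent with 9.18.1 mapping to 91801), the rejection rule can be modeled as follows. The function names `encode_cudnn_version` and `ragged_rejected` are hypothetical and do not exist in cudnn-frontend or Transformer Engine.

```python
def encode_cudnn_version(major: int, minor: int, patch: int) -> int:
    """Pack a cuDNN 9.x version triple into a single comparable integer.

    Assumption: major * 10000 + minor * 100 + patch, so 9.18.1 -> 91801.
    """
    return major * 10000 + minor * 100 + patch


def ragged_rejected(sm: int, backend_version: int) -> bool:
    """Simplified model of the quoted rule: ragged (THD) tensors are
    rejected when sm < 90 and backend_version < 91801."""
    return sm < 90 and backend_version < encode_cudnn_version(9, 18, 1)


# The reproduction environment from the description: A100 (sm80), cuDNN 9.10.2.
print(ragged_rejected(80, encode_cudnn_version(9, 10, 2)))
```

Under this model the reproduction setup (sm80 with cuDNN 9.10.2) falls inside the rejected region, while either upgrading cuDNN to 9.18.1 or moving to sm90 leaves it.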

TE's backend selection intends to enforce this ("THD requires sm90", fused_attn.cpp qkv-format clause), but mixed THD layouts slip through both of its layers:

  • nvte_get_fused_attn_backend (fused_attn.cpp): for thd_bshd_bshd, nvte_get_qkv_format returns NVTE_THD_2BSHD, not NVTE_THD, so the sm90-gated pure-THD branch never applies. The layout falls through to the mixed-format branch, which ORs the q_format and kv_format conditions together. kv_format == NVTE_BSHD alone satisfies that branch, so the failed (q_format == NVTE_THD && sm_arch_ >= 90) disjunct is skipped instead of vetoing the backend. The clause implements "at least one of q/kv has a supported format" where the requirement is "both do".
  • get_attention_backend (utils.py): the "Filter: QKV layout" section only gates qkv_format == "thd", but get_qkv_format reports "thd_2bshd" for these layouts, and the only architecture it checks for THD is sm120.
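The OR-masking behaviour described in the first bullet can be illustrated with a toy model. This is a sketch under simplifying assumptions, not the actual expression in fused_attn.cpp: formats are plain strings, only the architecture condition is modeled, and the helper names (`format_supported`, `mixed_clause_or`, `mixed_clause_and`) are hypothetical.

```python
def format_supported(fmt: str, sm_arch: int) -> bool:
    """Per-tensor rule in this toy model: THD needs sm90+,
    dense formats (sbhd/bshd) are always accepted."""
    if fmt == "thd":
        return sm_arch >= 90
    return fmt in ("sbhd", "bshd")


def mixed_clause_or(q_format: str, kv_format: str, sm_arch: int) -> bool:
    """'At least one of q/kv has a supported format' (the hole)."""
    return format_supported(q_format, sm_arch) or format_supported(kv_format, sm_arch)


def mixed_clause_and(q_format: str, kv_format: str, sm_arch: int) -> bool:
    """'Both q and kv have a supported format' (the intended requirement)."""
    return format_supported(q_format, sm_arch) and format_supported(kv_format, sm_arch)


# thd_bshd_bshd on sm80: q is THD (unsupported), kv is BSHD (supported).
print(mixed_clause_or("thd", "bshd", 80))   # True: the valid kv_format masks the invalid q_format
print(mixed_clause_and("thd", "bshd", 80))  # False: the unsupported q_format vetoes the backend
```

The contrast shows why a supported kv_format is enough to let the layout through when the conditions are combined with OR.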

As a result the F16 arbitrary-seqlen backend is selected and the failure surfaces later as a cuDNN graph-build error, rather than backend selection falling back to FlashAttention/UnfusedDotProductAttention.

Fix

Mirror the cudnn-frontend rule: THD requires sm90+, or cuDNN 9.18.1+ on Ampere/Ada, in both selection layers:

  1. fused_attn.cpp: add a guard that closes the OR-masking hole for any layout involving THD:
        // THD (ragged offset) support: Hopper+ (sm90) always; Ampere/Ada (sm80/sm89) only
        // from cuDNN 9.18.1 ("SDPA backward with THD layout on RTX-PRO 6000 and
        // Ampere-architecture GPUs"; fprop on Ampere is undocumented, so gate both).
        // The qkv format clause above ORs q_format and kv_format conditions together, so a
        // valid kv_format (e.g. paged_kv_thd_bshd_bshd, where kv is BSHD) would otherwise
        // mask an invalid THD q_format on sm80 with older cuDNN.
        ((q_format != NVTE_QKV_Format::NVTE_THD && kv_format != NVTE_QKV_Format::NVTE_THD) ||
         sm_arch_ >= 90 || cudnn_runtime_version >= 91801) &&
  2. fused_attn.cpp: relax the three sm_arch_ >= 90 THD conditions in the qkv-format clause to (sm_arch_ >= 90 || cudnn_runtime_version >= 91801), so that pure and mixed THD layouts are consistently enabled on Ampere/Ada with cuDNN 9.18.1+.

  3. get_attention_backend (utils.py): add the equivalent filter so PyTorch users get a clear debug message; it checks q_format/kv_format rather than qkv_format to also cover the thd_2bshd/thd_2sbhd KV-cache layouts (this requires capturing kv_format at the get_qkv_format call site, where it was previously discarded):

    # THD support on Ampere/Ada requires cuDNN 9.18.1+ ("SDPA backward with THD layout on
    # RTX-PRO 6000 and Ampere-architecture GPUs"). Check q_format/kv_format, not just
    # qkv_format, since KV-cache layouts (e.g. paged_kv_thd_bshd_bshd) have
    # qkv_format = thd_2bshd.
    if "thd" in (q_format, kv_format) and device_compute_capability < (9, 0):
        if cudnn_version < (9, 18, 1):
            if use_fused_attention:
                logger.debug(
                    "Disabling FusedAttention as qkv_format = thd is not supported for"
                    " compute capability < sm90 and cuDNN version < 9.18.1"
                )
            use_fused_attention = False

With these guards, backend selection on sm80/sm89 with cuDNN < 9.18.1 returns No_Backend for THD layouts and falls back to FlashAttention/UnfusedDotProductAttention, and tests/pytorch/attention/test_kv_cache.py passes (FusedAttention thd cases are skipped as unsupported).
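To make the expected outcomes concrete, the guard condition can be transcribed into a small standalone predicate and evaluated for a few configurations. This is an illustrative sketch only: `thd_guard_passes` is a hypothetical name, formats are plain strings rather than NVTE_QKV_Format values, and the real backend selection applies many further checks that are not modeled here.

```python
def thd_guard_passes(q_format: str, kv_format: str, sm_arch: int, cudnn_version: int) -> bool:
    """Model of the added guard: pass when no tensor is THD,
    or the GPU is sm90+, or cuDNN is at least 9.18.1 (91801)."""
    no_thd = q_format != "thd" and kv_format != "thd"
    return no_thd or sm_arch >= 90 or cudnn_version >= 91801


# (q_format, kv_format, sm_arch, cudnn_version)
cases = [
    ("thd", "bshd", 80, 91002),   # A100 + cuDNN 9.10.2: guard fails, fused backend rejected
    ("thd", "bshd", 80, 91801),   # A100 + cuDNN 9.18.1: guard passes
    ("thd", "bshd", 90, 91002),   # Hopper + older cuDNN: guard passes
    ("bshd", "bshd", 80, 91002),  # no THD tensor involved: guard passes
]
for case in cases:
    print(case, thd_guard_passes(*case))
```

Only the first case, which matches the reproduction environment, fails the guard.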

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add a THD architecture/version guard to the F16 arbitrary-seqlen condition in nvte_get_fused_attn_backend, closing the hole where a valid kv_format masked an unsupported THD q_format on sm80/sm89 with cuDNN < 9.18.1.
  • Relax the existing sm_arch_ >= 90 THD conditions to also accept cuDNN >= 9.18.1, matching cudnn-frontend's support surface and the cuDNN 9.18.1 release notes.
  • Add the equivalent filter to get_attention_backend (checking q_format/kv_format to cover thd_2bshd/thd_2sbhd KV-cache layouts) so PyTorch backend selection logs a clear reason and falls back cleanly.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

github-actions bot added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) on Jun 5, 2026
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
@greptile-apps
Contributor

greptile-apps Bot commented Jun 5, 2026

Greptile Summary

This PR fixes a backend-selection hole where nvte_get_fused_attn_backend (C++) and get_attention_backend (Python) incorrectly allowed the cuDNN F16 arbitrary-seqlen backend for mixed THD layouts (e.g. thd_bshd_bshd, paged_kv_thd_bshd_bshd) on Ampere/Ada GPUs with cuDNN < 9.18.1, causing a hard crash at cuDNN graph-build time.

  • C++ (fused_attn.cpp): Adds a cross-format THD guard that blocks backend selection when either q_format or kv_format is NVTE_THD on sm_arch_ < 90 with cudnn_runtime_version < 9.18.1, closing the OR-masking hole where a valid kv_format (e.g. NVTE_BSHD) masked an invalid THD q_format; also relaxes the existing sm_arch_ >= 90 conditions to (sm_arch_ >= 90 || cudnn_runtime_version >= 91801) for the pure-THD branch.
  • Python (utils.py): Adds an equivalent filter checking q_format/kv_format (not just qkv_format) for "thd" presence, so KV-cache layouts like thd_2bshd are caught and FusedAttention is disabled with a clear debug message before reaching cuDNN.

Confidence Score: 5/5

Safe to merge — the change only narrows an over-permissive backend-selection path; it cannot enable anything that wasn't supported before.

Both the C++ guard and the Python guard mirror the cudnn-frontend's own support predicate (sm < 90 && version < 91801 → reject THD). The OR-masking hole in the mixed-format branch is cleanly closed by the new outer AND-guard, and the relaxed sm90 conditions are consistent across all three THD sub-clauses. The Python side correctly inspects q_format/kv_format rather than the combined qkv_format, covering the KV-cache mixed-layout case that was the original bug report.

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/fused_attn/fused_attn.cpp | Adds an outer THD guard and relaxes three sm90 conditions to also accept cuDNN 9.18.1+; logic is correct and consistent with cudnn-frontend's support surface. |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | Adds a Python-layer guard checking q_format/kv_format for "thd" to catch KV-cache mixed layouts missed by the existing qkv_format == "thd" branch; correctly placed after the sm120 block and consistently mirrors the C++ logic. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[nvte_get_fused_attn_backend / get_attention_backend] --> B{F16 dtype?}
    B -- No --> Z[Check other backends]
    B -- Yes --> C{qkv_format check SBHD / BSHD / BHSD?}
    C -- Yes --> E[Format OK]
    C -- No --> D{Pure THD? qkv_format == NVTE_THD}
    D -- Yes --> D1{sm >= 90 OR cuDNN >= 9.18.1?}
    D1 -- No --> REJECT[NVTE_No_Backend]
    D1 -- Yes --> D2{cuDNN version for MHA/GQA ok?}
    D2 -- No --> REJECT
    D2 -- Yes --> E
    D -- No --> F{Mixed layout? q_format or kv_format == NVTE_THD}
    F -- No --> G[Other format - SBHD/BSHD/BHSD per-format OK]
    G --> E
    F -- Yes --> F1{sm >= 90 OR cuDNN >= 9.18.1? PER-FORMAT CHECK}
    F1 -- No --> REJECT
    F1 -- Yes --> F2{cuDNN >= 9.7.0?}
    F2 -- No --> REJECT
    F2 -- Yes --> E
    E --> H{NEW: Outer THD guard q_format == THD OR kv_format == THD?}
    H -- No THD --> PASS[Continue to other checks]
    H -- THD present --> H1{sm >= 90 OR cuDNN >= 9.18.1?}
    H1 -- No --> REJECT
    H1 -- Yes --> PASS
    PASS --> BACKEND[NVTE_F16_Arbitrary_Seqlen]

Reviews (2): Last reviewed commit: "Update transformer_engine/pytorch/attent..."

Comment thread transformer_engine/pytorch/attention/dot_product_attention/utils.py
…ls.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa
Collaborator

cyanguwa commented Jun 5, 2026

/te-ci L0
