Added thd cudnn guard #3092
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
277847e to 9755745
Greptile Summary
This PR fixes a backend-selection hole where
Confidence Score: 5/5
Safe to merge: the change only narrows an over-permissive backend-selection path; it cannot enable anything that wasn't supported before.
Both the C++ guard and the Python guard mirror the cudnn-frontend's own support predicate (sm < 90 && version < 91801 → reject THD). The OR-masking hole in the mixed-format branch is cleanly closed by the new outer AND-guard, and the relaxed sm90 conditions are consistent across all three THD sub-clauses. The Python side correctly inspects q_format/kv_format rather than the combined qkv_format, covering the KV-cache mixed-layout case that was the original bug report.
No files require special attention.
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[nvte_get_fused_attn_backend / get_attention_backend] --> B{F16 dtype?}
B -- No --> Z[Check other backends]
B -- Yes --> C{qkv_format check SBHD / BSHD / BHSD?}
C -- Yes --> E[Format OK]
C -- No --> D{Pure THD? qkv_format == NVTE_THD}
D -- Yes --> D1{sm >= 90 OR cuDNN >= 9.18.1?}
D1 -- No --> REJECT[NVTE_No_Backend]
D1 -- Yes --> D2{cuDNN version for MHA/GQA ok?}
D2 -- No --> REJECT
D2 -- Yes --> E
D -- No --> F{Mixed layout? q_format or kv_format == NVTE_THD}
F -- No --> G[Other format - SBHD/BSHD/BHSD per-format OK]
G --> E
F -- Yes --> F1{sm >= 90 OR cuDNN >= 9.18.1? PER-FORMAT CHECK}
F1 -- No --> REJECT
F1 -- Yes --> F2{cuDNN >= 9.7.0?}
F2 -- No --> REJECT
F2 -- Yes --> E
E --> H{NEW: Outer THD guard q_format == THD OR kv_format == THD?}
H -- No THD --> PASS[Continue to other checks]
H -- THD present --> H1{sm >= 90 OR cuDNN >= 9.18.1?}
H1 -- No --> REJECT
H1 -- Yes --> PASS
PASS --> BACKEND[NVTE_F16_Arbitrary_Seqlen]
Reviews (2): Last reviewed commit: "Update transformer_engine/pytorch/attent..."
…ls.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
/te-ci L0
Description
nvte_get_fused_attn_backend selects the F16 arbitrary-seqlen backend for mixed THD layouts (e.g. thd_bshd_bshd, used by KV caching) on Ampere/Ada GPUs, where cuDNN does not support THD (ragged offset) tensors before 9.18.1. Instead of falling back to another attention backend, the forward pass fails at cuDNN graph-build time:
Reproduction (observed on A100 / sm80 with cuDNN 9.10.2):
Root cause
The supported-format reference is cudnn-frontend's SDPA support surface (cudnn-frontend/include/cudnn_frontend/node/sdpa_support_surface.h), which rejects ragged tensors when sm < 90 && backend_version < 91801, matching the cuDNN 9.18.1 release notes ("Support for scaled dot-product attention backward with THD layout on RTX-PRO 6000 and Ampere-architecture GPUs for the F16 datatype has been added").
TE's backend selection intends to enforce this ("THD requires sm90", fused_attn.cpp qkv-format clause), but mixed THD layouts slip through both of its layers:
- nvte_get_fused_attn_backend (fused_attn.cpp): for thd_bshd_bshd, nvte_get_qkv_format returns NVTE_THD_2BSHD, not NVTE_THD, so the sm90-gated pure-THD branch never applies. The layout falls to the mixed-format branch, which ORs the q_format and kv_format conditions together: kv_format == NVTE_BSHD alone satisfies it, so the failed (q_format == NVTE_THD && sm_arch_ >= 90) disjunct is simply skipped instead of vetoing the backend. The clause implements "at least one of q/kv has a supported format" where the requirement is "both do".
- get_attention_backend (utils.py): the "Filter: QKV layout" section only gates qkv_format == "thd", but get_qkv_format reports "thd_2bshd" for these layouts, and the only architecture it checks for THD is sm120.
As a result, the F16 arbitrary-seqlen backend is selected and the failure surfaces later as a cuDNN graph-build error, rather than backend selection falling back to FlashAttention/UnfusedDotProductAttention.
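The OR-masking problem described above can be illustrated with a minimal Python sketch. This is a simplified model of the mixed-format clause, not Transformer Engine's actual C++ code: the function names (format_ok, mixed_clause_or, mixed_clause_and) and the string format tags are hypothetical, and it assumes the pre-fix rule that a THD tensor passes only on sm90+.

```python
# Simplified model of the mixed-format clause; NOT Transformer Engine's code.
# All names below are hypothetical and exist only for illustration.
THD, BSHD, SBHD = "thd", "bshd", "sbhd"


def format_ok(fmt: str, sm_arch: int) -> bool:
    """Per-tensor rule (pre-fix): THD needs sm90+, BSHD/SBHD always pass."""
    if fmt == THD:
        return sm_arch >= 90
    return fmt in (BSHD, SBHD)


def mixed_clause_or(q_format: str, kv_format: str, sm_arch: int) -> bool:
    """Buggy shape: passes when at least one of q/kv has a supported format."""
    return format_ok(q_format, sm_arch) or format_ok(kv_format, sm_arch)


def mixed_clause_and(q_format: str, kv_format: str, sm_arch: int) -> bool:
    """Intended shape: passes only when both q and kv have a supported format."""
    return format_ok(q_format, sm_arch) and format_ok(kv_format, sm_arch)


# thd_bshd_bshd on sm80: q is THD (unsupported here), kv is BSHD (supported).
print(mixed_clause_or(THD, BSHD, 80))   # True: the BSHD side masks the THD failure
print(mixed_clause_and(THD, BSHD, 80))  # False: the THD side vetoes the backend
```

In this model the OR form accepts the layout because the kv side alone satisfies the clause, which is the behavior the PR describes as the hole.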
Fix
Mirror the cudnn-frontend rule: THD requires sm90+, or cuDNN 9.18.1+ on Ampere/Ada, in both selection layers:
- fused_attn.cpp: add a guard that closes the OR-masking hole for any layout involving THD:
- fused_attn.cpp: relax the three sm_arch_ >= 90 THD conditions in the qkv-format clause to (sm_arch_ >= 90 || cudnn_runtime_version >= 91801), so that pure and mixed THD layouts are consistently enabled on Ampere/Ada with cuDNN 9.18.1+.
- get_attention_backend (utils.py): add the equivalent filter so PyTorch users get a clear debug message; it checks q_format/kv_format rather than qkv_format to also cover the thd_2bshd/thd_2sbhd KV-cache layouts (this requires capturing kv_format at the get_qkv_format call site, where it was previously discarded):
With these guards, backend selection on sm80/sm89 with cuDNN < 9.18.1 returns No_Backend for THD layouts and falls back to FlashAttention/UnfusedDotProductAttention, and tests/pytorch/attention/test_kv_cache.py passes (FusedAttention thd cases are skipped as unsupported).
Fixes # (issue)
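As a rough illustration of the rule the guards mirror, the following Python sketch encodes the predicate "THD requires sm90+, or cuDNN 9.18.1+". It is not the code added by this PR: every helper name is hypothetical, the format strings are stand-ins for the real enums, and all other backend checks (dtype, head dimensions, and so on) are omitted. The version integers assume cuDNN 9.x's major*10000 + minor*100 + patch encoding, under which 9.18.1 is 91801.

```python
# Hypothetical helpers for illustration only; not Transformer Engine's API.

def encode_cudnn_version(major: int, minor: int, patch: int) -> int:
    """cuDNN 9.x integer version encoding, e.g. 9.18.1 -> 91801."""
    return major * 10000 + minor * 100 + patch


# Minimum cuDNN version that accepts THD (ragged) tensors below sm90.
THD_MIN_CUDNN_PRE_SM90 = encode_cudnn_version(9, 18, 1)


def thd_supported(sm_arch: int, cudnn_version: int) -> bool:
    """THD is allowed on sm90+, or on older architectures with cuDNN >= 9.18.1."""
    return sm_arch >= 90 or cudnn_version >= THD_MIN_CUDNN_PRE_SM90


def thd_guard_passes(q_format: str, kv_format: str,
                     sm_arch: int, cudnn_version: int) -> bool:
    """Outer guard: if either q or kv is THD, the THD rule must hold."""
    has_thd = q_format == "thd" or kv_format == "thd"
    return (not has_thd) or thd_supported(sm_arch, cudnn_version)


# Mixed layout (q is THD, kv is BSHD) on sm80 with cuDNN 9.10.2: rejected.
print(thd_guard_passes("thd", "bshd", 80, encode_cudnn_version(9, 10, 2)))  # False
# Same layout on sm80 with cuDNN 9.18.1: accepted.
print(thd_guard_passes("thd", "bshd", 80, encode_cudnn_version(9, 18, 1)))  # True
```

Checking q_format and kv_format separately, rather than a combined format tag, is what lets the guard catch mixed layouts where only one side is THD.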
Type of change
Changes
- Add a THD guard to nvte_get_fused_attn_backend, closing the hole where a valid kv_format masked an unsupported THD q_format on sm80/sm89 with cuDNN < 9.18.1.
- Relax the three sm_arch_ >= 90 THD conditions to also accept cuDNN >= 9.18.1, matching cudnn-frontend's support surface and the cuDNN 9.18.1 release notes.
- Add the equivalent filter to get_attention_backend (checking q_format/kv_format to cover thd_2bshd/thd_2sbhd KV-cache layouts) so PyTorch backend selection logs a clear reason and falls back cleanly.
Checklist: