Add GB10 (sm_121 / DGX Spark) support to low-latency-llama #11
Open
chauhang wants to merge 1 commit into
Summary
This PR makes the low-latency Llama megakernel run on GB10 / NVIDIA DGX Spark (sm_121, consumer-class Blackwell) and adds a torch.compile baseline to generate.py for a fair production comparison.
On Llama-3.2-1B batch-1 decode, the megakernel is verified correct (by the repo's diff_test) and faster than PyTorch. Modes compared: torch (eager), torch.compile (reduce-overhead + cudagraphs), and mk (megakernel). mk is 1.49× faster than eager and 1.33× faster than compiled + cudagraphed PyTorch, saturating GB10's memory bandwidth (see Performance below). All changes are gated under #ifdef KITTENS_SM120, so H100/B200 builds are byte-identical.
What's in this PR (all #ifdef KITTENS_SM120)
GB10 forces two adjustments: (a) re-tiling the megakernel's shared-memory page layout from 13 to 5 pages, and (b) disabling a cluster-only codegen path the chip can't execute. The rest is build/config plumbing and the benchmark baseline. Everything else is unchanged.
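Adjustment (a) is at heart a page-budget constraint, and the arithmetic can be sketched in a few lines. This is only an illustration: the function names and the NUM_PAGES_GB10 constant are hypothetical and do not exist in the repo. Only the numbers (3 or 1 input stages, 4 pages per stage, 1 activation page, a 5-entry page table) come from this PR description.

```python
# Hypothetical helpers; only the numeric constants are taken from the PR text.
NUM_PAGES_GB10 = 5  # page-table entries available on GB10


def pages_needed(input_stages: int, stage_pages: int = 4, activation_pages: int = 1) -> int:
    """Pages one matvec op occupies: weight pages per stage times stages, plus activations."""
    return input_stages * stage_pages + activation_pages


def is_valid_release_order(order, num_pages: int) -> bool:
    """A page recycle order is usable only if it is a permutation of 0..num_pages-1."""
    return sorted(order) == list(range(num_pages))


original = pages_needed(input_stages=3)  # layout built for the larger page table
retiled = pages_needed(input_stages=1)   # layout after the SM120 re-tile

print(original, original <= NUM_PAGES_GB10)
print(retiled, retiled <= NUM_PAGES_GB10)
print(is_valid_release_order([0, 1, 2, 3, 4], NUM_PAGES_GB10))
print(is_valid_release_order(list(range(13)), NUM_PAGES_GB10))
```

The same check applies to any op that hands pages to its successor: a 13-entry order cannot be a valid permutation of a 5-entry page table, which is why every op's recycle order has to shrink along with the layout.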
__cluster_dims__ omitted for SM120 (include/megakernel.cuh). Even a size-1 cluster attribute makes nvcc emit thread-block-cluster addressing (S2UR SR_CgaCtaId plus cluster-scoped mbarriers) for all shared-memory barrier waits, which is an illegal instruction on sm_121 (consumer Blackwell has no cluster engine). The cluster syncs were already guarded by if (CLUSTER_BLOCKS > 1), so size-1 clusters are unaffected.
5-page re-tile (demos/low-latency-llama/matvec_pipeline.cuh). The matvec ops were built for 13 pages (INPUT_PIPELINE_STAGES (3) × STAGE_PAGES (4) + 1 activation), which overran GB10's 5-entry page table (an out-of-bounds access). The fix has two parts. First, INPUT_PIPELINE_STAGES goes from 3 to 1 (4 weight + 1 activation = exactly 5; outputs live in scratch, so OUTPUT_PIPELINE_STAGES stays 3). Second, release_lid(), which returns the order in which an op's shared-memory pages are recycled for the next op, now returns the identity ret_order[5] = {0,1,2,3,4}. The permutation only sets recycle order; actual page readiness is still gated by the page_finished semaphores.
release_lid fix for the attention ops (attention_partial.cu, attention_reduction.cu). These ops have their own release_lid with ret_order[13]; returning more than 5 values corrupted the next op's 5-entry pid_order, which surfaced as an OOB read in o_proj (the op after attention). Both are fixed to the identity [5].
include/config.cuh: SM120 overrides (MAX_SHARED_MEMORY = 99 KB, a NUM_PAGES == 5 assert, GB10_SM_COUNT = 48). llama.cuh takes the SM120 SM-count branch.
generate.py: adds a mode=compile baseline (torch.compile with reduce-overhead + cudagraphs), the fair production baseline to compare the megakernel against.
Dependency: needs a GB10-enabled ThunderKittens
The megakernel builds against ThunderKittens (TK). Megakernels currently pins TK as a submodule on an older branch (bvm-single-ctrl-pre-new-warps) that predates GB10/SM120 support, so it won't build for GB10 as-is.
The TK-side GB10 support this PR relies on, namely TK's sm_121 build target plus a warpgroup-MMA shim that lets H100-style kernels run on consumer Blackwell, is a companion PR: HazyResearch/ThunderKittens#204.
To build or test this PR, set THUNDERKITTENS_ROOT to a TK checkout that has #204's changes (e.g. chauhang/ThunderKittens@gb10-wgmma-shim) rather than the pinned submodule. No Megakernels code changes are needed for that: the megakernel is source-compatible with the newer TK and compiles unchanged. The new GPU=GB10 Makefile target handles the build flags.
Validation (real GB10, CUDA 13.0, sm_121a, Llama-3.2-1B batch-1)
Correctness is established the same way the repo's own harness does it: diff_test.py runs the identical instruction list through the Python-VM reference (pyvm) and the megakernel (mk) on byte-identical inputs (seeded; hidden states and K/V cache copied from the reference into the kernel). It then compares every per-op intermediate tensor and the final logits, using max absolute diff and mean symmetric relative diff (rdiff = 2|a−b| / (|a|+|b|)), against the bf16 noise band.
diff_test over all 16 layers passes at bf16 tolerance: attn intermediates exact (0.0), k/v cache adiff ~0.03, logits max adiff 0.125 / mean rdiff 0.044. This is accumulated bf16 rounding, not port-induced (the 1-stage change alters when weights load, not the arithmetic).
The deviation localizes to silu_out (the MLP matvec, unchanged CUDA-core code), not the attention MMA shim or the page-layout port, so neither change introduced numerical error.
Performance: measured on real GB10 (batch-1 decode)
Batch-1 decode is memory-bandwidth bound (it reads all 2.47 GB of bf16 weights per token), so the right denominator is achievable bandwidth, not the spec: a microbenchmark tops out at 236 GB/s (86% of the 273 GB/s LPDDR5X spec, which is not attainable on GB10). Against that, mk saturates the memory system. Modes compared: torch (eager), torch.compile (reduce-overhead + cudagraphs), and mk (megakernel).
Four independent tools corroborate the picture. Metrics reported: mk throughput (95.9 tok/s, i.e. 1.49× eager and 1.33× compile), mk effective bandwidth, and mk share of GPU time.
The throughput-derived (95.9 tok/s × 2.47 GB ≈ 237 GB/s) and duration-derived (2.47 GB ÷ 10.52 ms ≈ 235 GB/s) numbers bracket the 236 GB/s achievable ceiling. ncu confirms the kernel is bandwidth-bound, not compute-bound, and nsys confirms the fusion: a single persistent kernel does 96.5% of all GPU work, with only µs-scale torch glue (embedding, KV-cache concat, SiLU, elementwise) around it. (Bandwidth is read directly from the microbenchmark and cross-checked by the duration derivation; ncu's DRAM counters and nsys's BW path don't report on GB10's unified memory.)
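The two bandwidth derivations above can be reproduced with a few lines of arithmetic. The sketch below is illustrative only: the function and constant names are hypothetical, the inputs are the figures quoted in this description, and it assumes the 10.52 ms duration corresponds to one decoded token.

```python
# Figures quoted in the PR description; names are hypothetical.
WEIGHT_GB = 2.47         # bf16 weights read per decoded token
TOKENS_PER_S = 95.9      # measured mk throughput
DURATION_MS = 10.52      # duration used in the text's derivation (assumed per token)
ACHIEVABLE_GBPS = 236.0  # microbenchmark ceiling
SPEC_GBPS = 273.0        # LPDDR5X spec figure


def bw_from_throughput(tokens_per_s: float, gb_per_token: float) -> float:
    """Effective bandwidth in GB/s implied by end-to-end token throughput."""
    return tokens_per_s * gb_per_token


def bw_from_duration(gb_per_token: float, ms_per_token: float) -> float:
    """Effective bandwidth in GB/s implied by the per-token duration."""
    return gb_per_token / (ms_per_token / 1000.0)


throughput_bw = bw_from_throughput(TOKENS_PER_S, WEIGHT_GB)
duration_bw = bw_from_duration(WEIGHT_GB, DURATION_MS)

print(f"throughput-derived: {throughput_bw:.1f} GB/s")
print(f"duration-derived:   {duration_bw:.1f} GB/s")
print(f"ceiling vs spec:    {ACHIEVABLE_GBPS / SPEC_GBPS:.0%}")
```

Both estimates land within about 1 GB/s of the microbenchmark ceiling, one just above and one just below, which is what the bracketing claim rests on.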
Build & run
Limitations / follow-ups
Co-created with Claude Code