Writing A Megakernel For LLM Decode - A Worklog
This is the story of a megakernel: the 28-layer text decoder of Qwen3-ASR-0.6B, plus its lm_head, running as a single persistent CUDA kernel on an RTX 5080. The baseline fires about 466 kernel launches per generated token. This fires one.
Every op in it is hand-written - the GEMVs, the norms, the rope, the KV-cache write, the attention, the sampling argmax - and every op was validated against vLLM's own modules before I trusted it. The end-to-end test is not a tolerance check: the megakernel picks the same greedy token as vLLM, 7 out of 7 decode steps, with wide logit margins.
The performance ending is mixed, and I'm going to report it the way it landed. The megakernel decodes at 3.49 ms/token. That's about 7x faster than vLLM running eager. It is also 1.65x slower than vLLM's CUDA-graph path. Along the way three of my fusion ideas lost to measurement - one of them by 19x - and two attention optimizations worked exactly as designed and still got rejected. Those negative results taught me more than the wins, so they're all in here.
I'll walk through it the way it happened: the profile that motivated the whole thing, the design, the op-by-op build, the experiments, the optimization pass, and an honest accounting of the gap that remains.
Part I: why - the shape of M=1 decode
Qwen3-ASR-0.6B is a speech-to-text model: a Whisper-style audio encoder feeds a 28-layer Qwen3 LLM that autoregressively writes the transcript. The decoder is standard Qwen3: hidden 1024, 16 query heads / 8 KV heads (GQA), head_dim 128, SiLU MLP at 3072, RMSNorm everywhere including per-head QK-norm, rope. My target is the decode loop only - batch=1, one token at a time. The encoder and prefill are a different, larger-tile regime and stay as normal kernels.
At batch=1 decode, every matmul in the model is a matrix-vector product. M=1. That single fact sets everything:
- A GEMV reads each weight byte once and does ~1 flop with it. Arithmetic intensity is under 1 flop/byte, hopelessly memory-bound. Tensor cores have no matrix to tile - they'd be doing 16x16 tiles of which one row is real.
- So decoding one token means streaming the entire model - 1.19 GB of bf16 weights - through the GPU, once per token.
- The RTX 5080 sustains about 946 GB/s of real DRAM bandwidth (measured, 98.5% of the 960 spec). 1.19 GB / 946 GB/s = 1.26 ms. That's the floor. No decode of this model on this GPU beats 1.26 ms/token, megakernel or not.
That floor is the number every result below is measured against. "Fast" means "close to 1.26 ms".
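To make the floor arithmetic concrete, here is a small host-side C++ sketch. It is plain C++, which nvcc also accepts as host code. The helper names are made up for illustration. The 1.19 GB and 946 GB/s inputs are the figures quoted above. The intensity helper counts only weight bytes and ignores activation traffic.

```cpp
#include <cassert>

// Figures quoted in the post; helper names are hypothetical.
constexpr double kWeightBytes = 1.19e9;   // bf16 weights streamed per token
constexpr double kDramBytesPerS = 946e9;  // measured sustained DRAM bandwidth

// Lower bound on ms/token when every weight byte crosses DRAM exactly once.
double floor_ms(double bytes, double bytes_per_s) {
    assert(bytes_per_s > 0.0);
    return bytes / bytes_per_s * 1e3;
}

// How many times slower than the weight-stream floor a measured run is.
double x_over_floor(double measured_ms) {
    return measured_ms / floor_ms(kWeightBytes, kDramBytesPerS);
}

// GEMV over bf16 weights for m_tokens activations: each 2-byte weight feeds
// m_tokens multiply-adds (2 flops each). Activation bytes are ignored, so
// the true intensity is slightly lower than this.
double gemv_flops_per_byte(int m_tokens) {
    return 2.0 * m_tokens / 2.0;
}
```

With these inputs, floor_ms comes out at about 1.258 ms, and a 2.11 ms/token run sits about 1.68x above it. The intensity helper also shows why M=1 is hopeless for tensor cores: one flop per weight byte, rising only linearly with the number of tokens sharing the stream.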
What the profile actually said
Before writing any kernel code I profiled vLLM's decode with Nsight Systems and Nsight Compute, eager mode on purpose - CUDA graphs fuse the launch stream and hide exactly the per-op structure I wanted to see.
One decode step is ~466 kernel launches that together do about 2.2 ms of GPU work. The mean kernel is 4.77 us long. And per-op, the counters say the GPU is never actually busy:
| signal (median over decode ops) | value | reading |
|---|---|---|
| waves per SM | < 1 for most ops | the op can't fill the GPU even once |
| achieved occupancy | 6-43% | warp slots mostly empty |
| SM throughput | 2-33% | compute pipes idle |
| DRAM throughput | 0.3-47% | memory idle |
Both pipes idle at the same time. The model isn't compute-bound or memory-bound at decode - it's launch-bound and latency-bound. There simply isn't enough work in any single op to cover 84 SMs, and the ops run strictly one after another on one stream, so every launch is a global barrier: kernel N+1 can't start until kernel N's last CTA retires, even when kernel N used 4 CTAs out of 84.
The stall reasons are the interesting part. The glue ops (norms, residuals, activation, rope, cache write) all stall on long_scoreboard - warps sitting there waiting for a global-memory load. Which load? The previous op's output. Op N writes the activation to HBM, op N+1 reads it back microseconds later, and the warps eat the full DRAM round-trip latency in between. The activation is 2-6 KB. It bounces through HBM seventeen times per layer because that's what a kernel boundary is: state dies at the end of a grid, the next grid re-reads it.
One caveat for honesty: in eager mode with a profiler attached, the wall-clock reads ~92% idle, and that number is inflated by Python dispatch that CUDA graphs also remove. I'm not quoting it as a hardware fact. The per-op counters - waves, occupancy, stalls - are profiler-independent, and they're what the design is built on. This also frames what a megakernel has to prove: beating eager just means deleting launch overhead, which CUDA graphs already do. The megakernel's real claim has to be the on-chip part - keeping intermediates out of HBM and removing the per-op barriers. Hold that thought for the ending.
Part II: the design
Five decisions, locked before building, mostly unchanged after:
- One persistent kernel. A cooperative grid of 84 CTAs - one per SM - launched once, resident across the whole forward pass.
- grid.sync() between ops, for v1. A cooperative grid can barrier itself device-side. Coarse, correct, and the honest starting point; finer-grained signalling is where cross-op overlap would come from later.
- Weights partitioned, activation replicated. Each CTA owns a slice of every projection's output rows and streams its slice of the weight from HBM. The activation vector is 1024 floats - 4 KB - so every CTA just keeps a full copy in shared memory. Partition the big thing, replicate the small thing.
- Intermediates never touch HBM. Registers and SMEM where possible; a small global scratch for cross-CTA handoffs, which at 2-6 KB stays L2-resident. The round-trip the profile flagged was HBM; an L2 round-trip costs ~10x less.
- Static schedule. Fixed compile-time work split, no dynamic work queue. When a CTA's job is "my slice of every op in the layer" instead of "this one tiny op", the per-op imbalance from the profile disappears on its own.
The per-layer chain is 17 ops, repeated for all 28 layers, followed by a final norm, the lm_head (1024 -> 151936, the fat one), and argmax.
The GEMV that carries most of the weight stream is warp-per-row: each warp owns output rows in a grid-stride loop, lanes split the 1024-wide dot product with vectorized int4 loads, and a warp shuffle folds the partial sums:
```cuda
#include <cuda_bf16.h>

using bf16 = __nv_bfloat16;

// Helper used below: sum a float across the 32 lanes of a warp;
// the full sum ends up in lane 0.
__device__ __forceinline__ float warpReduceSum(float v)
{
    for (int off = 16; off > 0; off >>= 1) {
        v += __shfl_down_sync(0xffffffffu, v, off);
    }
    return v;
}

// Assumes in_dim is a multiple of 8 and every weight row is 16-byte aligned,
// so each lane's int4 load is a legal, aligned 16-byte access.
__device__ __forceinline__
void op_gemv(const bf16* __restrict__ W, const float* __restrict__ x,
             float* __restrict__ out, int in_dim, int out_dim)
{
    const int lane  = threadIdx.x & 31;                                       // my position in the warp (0-31)
    const int gwarp = blockIdx.x * (blockDim.x >> 5) + (threadIdx.x >> 5);   // my warp's global id
    const int nwarp = gridDim.x * (blockDim.x >> 5);                         // total warps in the grid
    // each warp owns output rows gwarp, gwarp+nwarp, ... (grid-stride over rows)
    for (int row = gwarp; row < out_dim; row += nwarp) {
        const bf16* wr = W + (size_t)row * in_dim;  // this row of the weight matrix
        float acc = 0.f;
        // the 32 lanes split the row's dot product; each lane grabs
        // 8 bf16s (16 bytes) per step with one int4 load - coalesced,
        // because neighboring lanes read neighboring 16-byte chunks
        for (int c = lane * 8; c < in_dim; c += 32 * 8) {
            int4 raw = *reinterpret_cast<const int4*>(wr + c);
            const bf16* wv = reinterpret_cast<const bf16*>(&raw);
            #pragma unroll
            for (int j = 0; j < 8; ++j) {
                acc += __bfloat162float(wv[j]) * x[c + j];  // fp32 accumulate
            }
        }
        acc = warpReduceSum(acc);  // fold 32 partial sums -> lane 0
        if (lane == 0) {
            out[row] = __bfloat162float(__float2bfloat16(acc));  // bf16-round on store, matches vLLM
        }
    }
}
```
Nothing exotic. I benchmarked the optimization ladder on the lm_head shape (1024 x 151936, ~300 MB - big enough that L2 can't fake the numbers) before settling here:
| GEMV variant | % of DRAM peak | note |
|---|---|---|
| naive, thread-per-row | 31% | each thread strides the row alone - uncoalesced |
| warp-per-row, coalesced | ~100% | the whole win, 3.3x |
| + int4 vectorize / x in SMEM / persistence | ~same | already saturated, hygiene only |
Coalescing is the entire win at M=1; the rest is plumbing. One measurement trap from the same bench: on the small qkv shape (8 MiB) the counter read over 200% of DRAM peak - impossible, and the tell that the tensor was L2-resident across iterations. Cache numbers, not DRAM numbers. Every DRAM% in this post is quoted on shapes that don't fit in L2, or from NCU's per-kernel DRAM counters directly.
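A toy model makes the coalescing win visible. It counts the distinct 32-byte DRAM sectors that one warp-wide load instruction touches under each access pattern. This is a sketch with hypothetical helper names. It deliberately ignores cache reuse across loop iterations, so it describes first-touch cost only. That is why the real naive kernel still reached 31% rather than the ~6% efficiency this model gives for a single instruction.

```cpp
#include <cassert>
#include <cstdint>
#include <set>

constexpr int kSectorBytes = 32;  // DRAM access granularity assumed here

// Distinct 32-byte sectors touched by one warp-wide load: lane i reads
// width_bytes starting at addr[i]. No cache reuse is modeled.
int sectors_touched(const uint64_t addr[32], int width_bytes) {
    assert(width_bytes > 0);
    std::set<uint64_t> sectors;
    for (int lane = 0; lane < 32; ++lane) {
        const uint64_t first = addr[lane] / kSectorBytes;
        const uint64_t last = (addr[lane] + width_bytes - 1) / kSectorBytes;
        for (uint64_t s = first; s <= last; ++s) sectors.insert(s);
    }
    return static_cast<int>(sectors.size());
}

// Useful bytes divided by bytes actually moved.
double efficiency(const uint64_t addr[32], int width_bytes) {
    return 32.0 * width_bytes / (sectors_touched(addr, width_bytes) * kSectorBytes);
}

// Naive thread-per-row: lane i reads one bf16 (2 bytes) at column col of row i.
void thread_per_row(uint64_t addr[32], int in_dim, int col) {
    for (int lane = 0; lane < 32; ++lane)
        addr[lane] = (uint64_t)lane * in_dim * 2 + (uint64_t)col * 2;
}

// Warp-per-row with int4: lane i reads 8 bf16s (16 bytes) at column c0 + 8*i.
void warp_per_row(uint64_t addr[32], int in_dim, int row, int c0) {
    for (int lane = 0; lane < 32; ++lane)
        addr[lane] = ((uint64_t)row * in_dim + c0 + 8 * lane) * 2;
}
```

For a 1024-wide row, the naive pattern touches 32 sectors to use 64 bytes (6.25%). The int4 warp-per-row pattern moves 512 bytes in 16 fully used sectors.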
Part III: the build, op by op
The golden harness
A 28-layer chain of hand-written ops fails silently. A bug in op 7 of layer 3 doesn't crash - it just produces slightly wrong numbers that wash through 25 more layers of matmuls and come out as a plausible-looking but wrong token. So the rule was: no op gets trusted until it matches a golden.
The golden generator hooks vLLM's actual model modules - not a reimplementation - at a fixed decode state and dumps every op's input and output to disk as f32. Each kernel diffs against its golden the day it's written. One subtlety worth recording: the goldens are dumps of bf16 tensors, so a correct fp32 kernel diffs at ~2e-3 relative error. That's bf16 rounding noise, not a bug: bf16 keeps 8 significant bits, so a single rounding moves a value by at most 2^-8, about 0.4%. The pass gate was rel < 2e-2, an order of magnitude above the noise.
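A host-side sketch of the golden comparison shows where that noise floor comes from. It rests on two assumptions. First, the relative error is taken as an L2-norm ratio; the post does not say which norm the harness uses. Second, bf16 rounding is emulated with round-to-nearest-even on the raw fp32 bits, with no NaN handling. All names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Emulate an fp32 -> bf16 -> fp32 round trip with round-to-nearest-even on
// the 16 dropped mantissa bits. NaN/Inf inputs are not handled.
float round_to_bf16(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof bits);
    const uint32_t lsb = (bits >> 16) & 1u;  // ties go to the even bf16 value
    bits += 0x7FFFu + lsb;
    bits &= 0xFFFF0000u;
    float out;
    std::memcpy(&out, &bits, sizeof out);
    return out;
}

// Relative error as ||got - ref|| / ||ref|| (an assumed choice of norm),
// accumulated in double.
double rel_error(const std::vector<float>& got, const std::vector<float>& ref) {
    assert(got.size() == ref.size());
    double num = 0.0, den = 0.0;
    for (std::size_t i = 0; i < ref.size(); ++i) {
        const double d = double(got[i]) - double(ref[i]);
        num += d * d;
        den += double(ref[i]) * double(ref[i]);
    }
    return std::sqrt(num / den);
}

constexpr double kPassGate = 2e-2;  // the gate quoted above

bool passes_gate(const std::vector<float>& got, const std::vector<float>& ref) {
    return rel_error(got, ref) < kPassGate;
}
```

Suppose an exact fp32 result is compared against its own bf16-rounded copy. The error is then bounded by 2^-8 / (1 - 2^-8), about 3.9e-3, by construction. That is why a correct kernel lands in the low 1e-3 range and a 2e-2 gate still has headroom.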
Seventeen ops went green in order. Two findings from the bring-up are worth pulling out.
The rope turned out to be simpler than the config says. The model declares mRoPE - multimodal rope with sections [24,20,20], three interleaved position axes. It looked like the riskiest op in the chain, so I golden-verified the convention before building anything general. The result: at text decode, all three position axes carry the same scalar position, and the whole thing collapses to standard NeoX rotate_half. The [24,20,20] machinery only matters for the audio-prefix prefill, which is out of scope. All the mRoPE complexity lives in host-side cos/sin construction; the rotation kernel never sees it:
```cuda
// op7: at decode the 3 mRoPE axes collapse to plain NeoX rotate_half (golden-proved)
float x0 = x[base];
float x1 = x[base + HALF];
out[base]        = x0 * cos[d] - x1 * sin[d];
out[base + HALF] = x1 * cos[d] + x0 * sin[d];
```
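The collapse claim can be checked off-device with a host-side C++ sketch of the rotation plus a cos/sin table builder. Several details here are illustrative assumptions, not the model's exact host code: the names, the rope base theta (passed in; 1e6 in the check below), and the mRoPE axis layout (contiguous [24,20,20] sections). The argument itself holds for any axis layout. When all three positions are equal, every frequency sees the same angle, so the tables match plain rope.

```cpp
#include <cassert>
#include <cmath>

constexpr int HD = 128;       // head_dim
constexpr int HALF = HD / 2;  // rotate_half pairs dim d with dim d + HALF

// In-place NeoX rotate_half on one head, same math as the op7 fragment.
void rope_neox(float* x, const float* cos_t, const float* sin_t) {
    for (int d = 0; d < HALF; ++d) {
        const float x0 = x[d];
        const float x1 = x[d + HALF];
        x[d] = x0 * cos_t[d] - x1 * sin_t[d];
        x[d + HALF] = x1 * cos_t[d] + x0 * sin_t[d];
    }
}

// Standard rope tables for one scalar position.
void rope_tables(int pos, double theta, float* cos_t, float* sin_t) {
    for (int i = 0; i < HALF; ++i) {
        const double inv_freq = std::pow(theta, -2.0 * i / HD);
        const double ang = pos * inv_freq;
        cos_t[i] = static_cast<float>(std::cos(ang));
        sin_t[i] = static_cast<float>(std::sin(ang));
    }
}

// mRoPE tables: frequency i takes its position from one of three axes.
// ASSUMPTION: contiguous sections {24, 20, 20}; the model's real assignment
// may interleave axes, but the collapse argument does not depend on it.
void mrope_tables(const int pos[3], double theta, float* cos_t, float* sin_t) {
    const int sections[3] = {24, 20, 20};
    int axis = 0, left = sections[0];
    for (int i = 0; i < HALF; ++i) {
        if (left == 0) { ++axis; left = sections[axis]; }
        --left;
        const double inv_freq = std::pow(theta, -2.0 * i / HD);
        const double ang = pos[axis] * inv_freq;
        cos_t[i] = static_cast<float>(std::cos(ang));
        sin_t[i] = static_cast<float>(std::sin(ang));
    }
}
```

With positions {37, 37, 37}, mrope_tables produces exactly the same floats as rope_tables(37, ...). Only distinct per-axis positions, as in the audio prefill, make the sections matter.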
bf16 drift over 28 layers saturates instead of compounding. Before chaining all 28 layers I wanted to know whether bf16 rounding error compounds geometrically - whether the residual stream needs to be fp32. Measured against an fp32-truth reference at 4, 8, 16, 28 layers: there's a one-time x3 jump in the first layer, and after that the error grows about x1.07 per layer and then flattens into a 1-5e-2 band. Layer 27 sits at 2.2e-2. The reason it saturates: every op re-quantizes to bf16, which caps how much error a layer can hand to the next. So the residual stream stays bf16. I measured the curve instead of assuming fp32 was needed, and it wasn't.
The fusion that ran 19x slower
The first natural fusion opportunity in the layer looked perfect, and it's the single most instructive failure in this project.
After the q and k projections, Qwen3 applies a per-head RMSNorm - a normalization over each head's 128 dims. Two ways to structure "GEMV, then per-head norm":
- Fused, head-tiled: each warp owns a whole head. It computes that head's 128 GEMV rows into registers, RMSNorms them in place, writes out the normed result. The intermediate never leaves the chip. This is the megakernel thesis in miniature - it's why you build a megakernel.
- Split: warp-per-row GEMV writes raw q/k to scratch, and a second warp-per-head pass reads it back and norms it. Costs a round-trip.
I built the fused version first, because keeping data on-chip is the whole point. Then I benchmarked both (2000-iteration average, both correct against golden):
| variant | q (16 heads) | k (8 heads) |
|---|---|---|
| FUSED | 0.1168 ms | 0.1174 ms |
| SPLIT | 0.0082 ms | 0.0061 ms |
| speedup | 14.3x | 19.1x |
The fused version isn't a little slower. It's an order of magnitude slower: 14.3x on the q projection, 19.1x on k. When this post says "19x", that's the k number - q pays less only because it has twice the heads and therefore twice the active SMs. Same disease, different dose.
The mechanism: head-tiling assigns one warp per head, and at M=1 this model has 16 q-heads and 8 k-heads. Sixteen warps fit in two CTAs. So the fused kernel streams a multi-megabyte weight matrix through 2 of the 84 SMs while the other 82 idle at the barrier. NCU makes it plain. The fused path runs DRAM at about 1.9% of peak; the split path, with warp-per-row spreading rows across all 84 SMs, runs at about 33%. A GEMV's speed is its memory throughput, so that ratio is the speedup.
There's a detail in the fused row that tells the whole story. q takes 0.1168 ms and k takes 0.1174 ms - the same - even though q's weight matrix is twice the size of k's. A kernel bound by total work would take twice as long on twice the work. These match because each active SM chews through the same ~1024 rows either way (q: 16 heads across 2 CTAs, k: 8 heads across 1 CTA - eight heads per CTA both times). When doubling the work doesn't change the runtime, the limiter is how many workers showed up, not how much work exists. That's occupancy starvation, fingerprinted.
And you can't rescue it. Head-tiling is the only way to fuse this pair: in the fast warp-per-row GEMV, a head's 128 rows belong to 128 different warps, and the norm has to reduce across all of them. The norm needs head-locality; the GEMV needs row-parallelism; they can't have both. Fusing forces head-tiling, head-tiling forces starvation. The correct structure is to not fuse - and the round-trip the split pays turns out to be 4-6 KB of L2-resident data, close to free.
The rule I took from it: Fusion pays only when the fused unit still fills the machine. Count the parallelism after fusing. If it drops below the SM count, the round-trip you saved was cheaper than the SMs you lost. This rule gets confirmed twice more below.
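The rule reduces to a count. Here is a minimal sketch, assuming the layout used throughout: one warp per work unit, 8 warps per CTA, and at most one CTA per SM on 84 SMs. The helper names are hypothetical.

```cpp
#include <cassert>

constexpr int kSMs = 84;         // RTX 5080, one persistent CTA per SM
constexpr int kWarpsPerCta = 8;  // 256-thread CTAs

// CTAs that get any work when each work unit is owned by one warp.
int busy_ctas(int work_units) {
    assert(work_units > 0);
    const int ctas = (work_units + kWarpsPerCta - 1) / kWarpsPerCta;
    return ctas < kSMs ? ctas : kSMs;
}

// The rule: a fused unit only pays if there are still enough units
// to give every SM work.
bool fusion_keeps_machine_full(int fused_units) {
    return busy_ctas(fused_units) == kSMs;
}
```

busy_ctas(16) is 2 and busy_ctas(8) is 1, which are the head-tiled q and k cases above. The 2048 warp-per-row q rows fill all 84.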
Attention: getting it right, then getting it spread
I'll be honest about op9: it was the one I'd been circling. Eight ops in, everything so far had been a norm, a rotation, or a GEMV - and attention sat there with a softmax reduction in the middle, a GQA head mapping, and the online-softmax merge trick waiting at the end. I'd also set the bar deliberately high: I didn't want a correctness stub, I wanted the complete ladder - a naive version I understood line by line, then a real online softmax, as close to the FlashAttention family as M=1 decode allows. Ambition plus reputation was enough that the work slipped by a couple of days around this op. The way through was to refuse to be clever on the first pass: one warp per head, two-pass softmax, correctness before speed, evolve after.
And once it was on the page, it was smaller than its reputation. Decode attention at M=1 is GEMV-shaped too. Per head: q[128] . K[ctx,128] -> softmax over ctx -> weights . V[ctx,128] -> out[128] - a dot product per cached position, a softmax, a weighted sum. I'd already written five GEMVs by then. No tensor cores, nothing to tile - the question is purely how to spread a few hundred context positions across 84 SMs.
The simple version passed golden on the first build. It also put 16 warps on an 84-SM GPU - two CTAs busy, 82 idle. Correct and slow, by design.
The shipped version is split-KV. To be precise about the lineage: this is not FlashAttention - at M=1 there is no query block to tile and no tensor-core matmul, which is most of what FA is - but it takes the features of FA that survive at decode: tile the context across workers, keep a running max and exp-sum per tile, merge partials with the online-softmax rescale. One head per CTA; the CTA's 8 warps each take a contiguous chunk of the context and compute a partial attention over their slice - local max m, local exp-sum z, and an unnormalized accumulator. Then one intra-CTA merge folds the 8 partials with the online-softmax correction: a warp that computed against the wrong (local) max fixes itself with a single multiplier, exp(m_i - m_star), no recompute:
```cuda
// stage 2: merge the 8 partials (online-softmax rescale)
if (threadIdx.x == 0) {
    float mstar = -INFINITY;
    for (int i = 0; i < nwarp; ++i) {
        mstar = fmaxf(mstar, sm_m[i]);
    }
    float gz = 0.f;
    for (int i = 0; i < nwarp; ++i) {
        float s = __expf(sm_m[i] - mstar);  // empty chunk: e^{-inf} = 0
        sm_s[i] = s;
        gz += sm_z[i] * s;
    }
    sm_gz = gz;
}
__syncthreads();
// out[d] = (sum_i acc_i[d] * s_i) / gz
if (threadIdx.x < HD) {
    const int d = threadIdx.x;
    float g = 0.f;
    for (int i = 0; i < nwarp; ++i) {
        g += sm_acc[i * HD + d] * sm_s[i];
    }
    attn_out[head * HD + d] = g / sm_gz;
}
```
The same factor corrects both the exp-sum and the accumulator, which is the elegance of the online-softmax trick - you divide by z exactly once, at the very end.
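A host-side C++ reference for the same merge is handy as a golden for the device code. It splits one head's context into chunks the way op9 splits it across warps. Each chunk's (m, z, acc) is computed against its local max, and the merge applies the exp(m_i - m_star) rescale. Scores are assumed to be precomputed, already-scaled q·k values. All names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

const float kNegInf = -std::numeric_limits<float>::infinity();

// One chunk's partial: local max m, local exp-sum z, unnormalized acc[hd].
struct Partial { float m; float z; std::vector<float> acc; };

// Partial attention over cached positions [begin, end) of one head.
// score[t] is q . k_t; V is row-major [ctx][hd].
Partial chunk_attention(const std::vector<float>& score,
                        const std::vector<float>& V, int hd, int begin, int end) {
    Partial p{kNegInf, 0.f, std::vector<float>(hd, 0.f)};
    for (int t = begin; t < end; ++t) p.m = std::max(p.m, score[t]);
    for (int t = begin; t < end; ++t) {
        const float e = std::exp(score[t] - p.m);  // relative to the LOCAL max
        p.z += e;
        for (int d = 0; d < hd; ++d) p.acc[d] += e * V[t * hd + d];
    }
    return p;
}

// Online-softmax merge: rescale each partial by exp(m_i - m_star), divide once.
std::vector<float> merge_partials(const std::vector<Partial>& parts, int hd) {
    float mstar = kNegInf;
    for (const Partial& p : parts) mstar = std::max(mstar, p.m);
    float gz = 0.f;
    std::vector<float> out(hd, 0.f);
    for (const Partial& p : parts) {
        const float s = std::exp(p.m - mstar);  // empty chunk: exp(-inf) = 0
        gz += p.z * s;
        for (int d = 0; d < hd; ++d) out[d] += p.acc[d] * s;
    }
    for (int d = 0; d < hd; ++d) out[d] /= gz;
    return out;
}

// Split the context into nchunk contiguous chunks (one per warp in op9).
std::vector<float> split_kv_attention(const std::vector<float>& score,
                                      const std::vector<float>& V, int hd, int nchunk) {
    const int ctx = static_cast<int>(score.size());
    const int per = (ctx + nchunk - 1) / nchunk;
    std::vector<Partial> parts;
    for (int c = 0; c < nchunk; ++c) {
        const int b = std::min(ctx, c * per);
        const int e = std::min(ctx, b + per);
        parts.push_back(chunk_attention(score, V, hd, b, e));
    }
    return merge_partials(parts, hd);
}

// Reference: plain two-pass softmax over the whole context.
std::vector<float> direct_attention(const std::vector<float>& score,
                                    const std::vector<float>& V, int hd) {
    Partial whole = chunk_attention(score, V, hd, 0, static_cast<int>(score.size()));
    for (float& a : whole.acc) a /= whole.z;
    return whole.acc;
}
```

split_kv_attention(score, V, hd, 8) and direct_attention on the same inputs should agree to float-reassociation noise. That includes short contexts where some chunks are empty and contribute exactly zero.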
The correctness check for the merge is worth stealing. Split-KV is the same math as warp-per-head, just reorganized, so its output should match not just the golden but the error signature of the simple version. It did: the per-head relative errors matched exactly - worst head identical, best head identical. If the rescale were wrong, those per-head numbers would drift. Cross-checked at every context length, the two variants agree to ~1e-6, pure float-reassociation noise.
Speed. First a measurement lesson: wall-clock said split-KV was 4.6x faster at ctx=65, but a 4 us kernel is sitting on the launch-latency floor, so I took NCU's device-only times instead: 45.5 us vs 8.8 us, 5.2x. Then the context sweep:
| ctx | warp-per-head | split-KV | speedup |
|---|---|---|---|
| 65 | 18.6 us | 4.1 us | 4.6x |
| 256 | 69.6 us | 10.3 us | 6.7x |
| 1024 | 273 us | 36 us | 7.5x |
| 2048 | 547 us | 71 us | 7.7x |
| 4096 | cannot launch | 141 us | - |
Two things in that table. The speedup climbs toward the ideal 8x as context grows - at short ctx the merge tax and tail imbalance eat a chunk, at long ctx the serial K/V sweep dominates and they amortize away. And the last row is structural, not incremental: warp-per-head keeps all of its context's scores in shared memory, ctx floats per warp and therefore 8·ctx floats per 8-warp CTA. At ctx=4096 that is 128 KB, past the 99 KB SMEM limit, so it simply cannot launch. Split-KV holds ctx/8 per warp and doesn't care. The reorganization isn't just faster - past a point, it's the only one of the two that runs.
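The launch limit is plain arithmetic. Below is a sketch of the score-buffer footprint for one 8-warp attention CTA. It uses fp32 scores and the 99 KB per-CTA limit quoted above, and it counts score buffers only. The function names are hypothetical.

```cpp
#include <cassert>

constexpr long kSmemLimitBytes = 99L * 1024;  // per-CTA SMEM limit quoted above
constexpr long kWarps = 8;                    // warps per attention CTA
constexpr long kFloat = 4;                    // fp32 scores

// warp-per-head: each warp keeps all ctx scores for its head.
long warp_per_head_bytes(long ctx) { return kWarps * ctx * kFloat; }

// split-KV: each warp keeps scores only for its ctx/8 chunk (rounded up).
long split_kv_bytes(long ctx) { return kWarps * ((ctx + kWarps - 1) / kWarps) * kFloat; }

bool fits(long bytes) {
    assert(bytes >= 0);
    return bytes <= kSmemLimitBytes;
}
```

Under this model warp-per-head tops out at ctx = 3168 (3168 × 32 bytes = 101,376 bytes). Split-KV at ctx=4096 needs only 16 KB.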
Even so, honesty about what got won: the "fast" version runs at 2.3% SM throughput. It's 5x of nearly-nothing - op9 at decode is a latency-bound speck, and what matters is its ~9 us on the critical path of the fused layer, not its standalone occupancy.
Two optimizations that worked as designed and lost anyway
Flash-decoding. This is the top rung of the ladder I'd set out to climb - the last FA feature left to take at decode. Split-KV still pins attention to 16 CTAs (one per head), leaving 68 SMs idle. The fix is to split each head across N CTAs and merge the partials through global memory - same online-softmax correction, one level up. I benchmarked N = 2, 4, 8 at ctx 1024 to 8192:
| ctx | N=1 (time) | N=2 (speedup) | N=4 (speedup) | N=8 (speedup) |
|---|---|---|---|---|
| 1024 | 36.0 us | 1.73x | 2.98x | 4.29x |
| 2048 | 70.7 us | 1.83x | 3.35x | 5.21x |
| 4096 | 140.0 us | 1.89x | 3.59x | 5.83x |
| 8192 | 281.3 us | 1.92x | 3.63x | 5.82x |
It works, and the merge is nearly free - NCU shows the partials buffer never touches DRAM (it's at most 64 KB, fully L2-resident), so the cross-CTA merge costs a flat ~2 us regardless of ctx or N. The ceiling is SM count: 16 heads x N fills 84 SMs around N=5, and N=8 means 128 CTAs on 84 SMs, 1.5 waves, diminishing returns around 5.8x.
Then the regime check. This model transcribes 30-second audio chunks; real decode context lands around 400-470. The N=2 crossover is at ctx >= 1024. Below it, a flat 2 us merge tax on a 4-10 us op is a guaranteed regression. So flash-decoding is measured, understood, and not wired in. It goes in the pocket for a long-context regime this model doesn't run in.
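The sizing behind those two paragraphs fits in a small sketch. It computes the waves of CTAs for N splits and the fp32 accumulator bytes exchanged through global memory. It also includes a dispatch policy that keeps N=1 below the measured ~1024 crossover. That policy function and its default N=8 are hypothetical, not something wired into the kernel.

```cpp
#include <cassert>

constexpr int kHeads = 16;  // q heads
constexpr int kHD = 128;    // head_dim
constexpr int kSMs = 84;

// Waves when each head is split across n_splits CTAs.
double waves(int n_splits) {
    return double(kHeads * n_splits) / kSMs;
}

// fp32 partial accumulators written to global memory for the cross-CTA merge
// (the per-split m and z scalars add 8 bytes each on top of this).
long partial_acc_bytes(int n_splits) {
    return long(kHeads) * n_splits * kHD * 4;
}

// HYPOTHETICAL dispatch policy: stay at one CTA per head below the measured
// N=2 crossover, otherwise split by n_long.
int choose_kv_splits(int ctx, int n_long = 8) {
    constexpr int kCrossoverCtx = 1024;
    assert(n_long >= 1);
    return ctx < kCrossoverCtx ? 1 : n_long;
}
```

At N=8 the accumulators come to exactly 64 KB, consistent with the buffer staying L2-resident. 128 CTAs over 84 SMs is about 1.52 waves, and at the ~470 ctx this model actually decodes at, the policy never splits.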
GQA KV-sharing. With grouped-query attention, each pair of q-heads reads the same KV head's cache - in my split-KV, twice. Sharing the load means putting both q-heads of a group in one CTA: each K/V element gets read once and used for both heads' dots. The cost is visible up front: 16 CTAs become 8.
Measured at ctx=2048, device-only:
| metric | baseline (16 CTA) | shared (8 CTA) |
|---|---|---|
| KV bytes loaded | 134.5 MB | 67.4 MB |
| output vs baseline | - | bit-identical |
| GPU time | 168.4 us | 172.4 us |
| DRAM bytes read | 0 (L2-resident) | 0 (L2-resident) |
The optimization did exactly what it was designed to do - bytes halved to the megabyte, output bit-identical - and it's slower, at every context length I swept (0.88x to 0.95x). The DRAM row is the autopsy: at these sizes the whole KV cache is L2-resident, so the bytes I halved were L2 reads, and L2 bandwidth was never the bottleneck. What the sharing spent was CTA-level parallelism, the one resource op9 is actually starved for. I paid the scarce thing to save the free thing. Reverted, bench deleted.
One discipline note on the reverting. 0.95x is "only" 5% slower, and the mechanism was genuinely satisfying - it's tempting to keep it because it nearly breaks even and the code is clever. No. A regression is a regression; across a 17-op chain, small "harmless" ones stack, and the complexity never leaves. Anything that measures slower comes out, regardless of magnitude or elegance.
Flash-decoding and GQA-sharing are the same lesson from opposite directions. Profile first, read which resource is actually saturated, and pick the lever that frees that one. At M=1 decode, attention is starved for SM parallelism and swimming in cheap L2 bandwidth - so "spread across more CTAs" wins and "read fewer bytes" loses. At very long context or a bigger model, where KV spills out of L2 into DRAM, both verdicts flip. The levers aren't good or bad; they're matched to a regime or they aren't.
The real test: tokens, not tolerances
Per-op error gates are means, not ends. Per-op bf16 error of ~1e-2 across 28 layers could in principle flip an argmax, and no amount of per-op green tells you it didn't. So the final gate hooks vLLM's real audio decode - actual audio in, actual accumulated KV state, actual greedy sampling - captures its per-step state, runs the megakernel on the same state, and compares chosen tokens.
7/7 greedy tokens match. And not by luck: the winning logit's margin over the runner-up was 4.75 to 12.75 at every step - comfortable wins, not coin flips waiting for a rounding error. The accumulated drift flipped nothing.
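The margin check itself is a single top-2 scan over the logits. Here is a sketch with hypothetical names. Ties break toward the lower token id, as a left-to-right argmax would.

```cpp
#include <cassert>
#include <limits>
#include <vector>

struct Top2 {
    int index;     // greedy token id
    float margin;  // winning logit minus runner-up logit
};

// One pass over the logits tracking the best and second-best values.
// Ties keep the lower index, so a tie shows up as margin 0.
Top2 greedy_with_margin(const std::vector<float>& logits) {
    assert(logits.size() >= 2);
    int best = 0;
    float b1 = logits[0];
    float b2 = -std::numeric_limits<float>::infinity();
    for (int i = 1; i < static_cast<int>(logits.size()); ++i) {
        const float v = logits[i];
        if (v > b1) { b2 = b1; b1 = v; best = i; }
        else if (v > b2) { b2 = v; }
    }
    return {best, b1 - b2};
}
```

A margin of 4.75 logits means the runner-up would need to gain almost five units from accumulated drift to flip the token. Per-op errors around 1e-2 do not come close.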
At this point the megakernel works. One launch, 28 layers, correct transcript tokens. Then I timed it properly.
Part IV: the performance pass
Baselines first, measured cleanly and labeled, because this is where megakernel posts usually cheat:
| system | ms/token | % of DRAM peak | x over floor |
|---|---|---|---|
| weight-stream floor | 1.26 | 100% | 1.0x |
| vLLM CUDA graphs | 2.11 | 60% | 1.68x |
| megakernel v1 | 4.31 | 27% | 3.43x |
| vLLM eager | 25.04 | 5% | 19.9x |
Two true sentences from that table. Megakernel v1 is already ~6x faster than eager (25.04 / 4.31 = 5.8x; the finished kernel below ends at ~7x) - but eager is the 466-launch regime, the thing any fusion kills, so that's the flattering comparison. Against vLLM's CUDA-graph path - launch overhead already gone, kernels hand-tuned - v1 is 2x slower. Fusion deleted the launches, but 476 coarse grid.sync() barriers per token serialize everything the launches used to serialize, and at 27% of DRAM peak the weight stream has long gaps in it. The honest target isn't a headline multiple over eager; it's DRAM%. The graph path runs the same floor-limited problem at 60%.
So the optimization pass was about feeding the stream. Every change was gated on the 7/7 token match staying green.
Win 1: L2 prefetch during attention (4.31 -> 3.90 ms). Attention occupies CTAs 0-15; CTAs 16-83 idle through op9. That's the one window in the layer where HBM has slack, so the idle CTAs walk the upcoming weights and warm them into L2:
```cuda
// idle CTAs (blockIdx >= NH) during op9: fire-and-forget L2 prefetch,
// no SMEM, no registers to speak of, no sync with the attention CTAs
__device__ void prefetch_l2(const bf16* W, long nelem) {
    ...
    asm volatile("prefetch.global.L2 [%0];" :: "l"(W + e));
}
```
Two design points that came from measurement, not intuition. First, SMEM staging - the classic pipelining move - is a dead end here twice over: a layer's weights are 30 MB against 100 KB of SMEM per CTA, and at M=1 each weight byte is used exactly once, so there is nothing to reuse from a staging buffer anyway. L2 is the right and only target. Second, prefetching the next layer's qkv weights gained 1.7% - the lines got evicted before use. Prefetching this layer's post-attention weights (o_proj, gate/up, down - needed microseconds after the sync) gained 9.5%. Prefetch distance matters more than prefetch existence.
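The byte budget behind those two choices follows from the model shape. The sketch below assumes bf16 weights and ignores the norm vectors. The stripe function is a hypothetical way to split the post-attention range across the 68 idle CTAs in 128-byte lines, assuming one L2 line per prefetch. It is not the elided loop body above.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Per-layer bf16 weight bytes for hidden 1024, 16 q / 8 kv heads,
// head_dim 128, MLP 3072 (norm vectors ignored).
constexpr int64_t kH = 1024, kHD = 128, kNQ = 16, kNKV = 8, kFF = 3072, kB = 2;
constexpr int64_t kQkv = (kNQ + 2 * kNKV) * kHD * kH * kB;  // needed before attention
constexpr int64_t kOProj = kH * kNQ * kHD * kB;
constexpr int64_t kGateUp = 2 * kFF * kH * kB;
constexpr int64_t kDown = kH * kFF * kB;
constexpr int64_t kPostAttn = kOProj + kGateUp + kDown;     // needed after attention

struct Range { int64_t begin, end; };  // byte offsets, end exclusive

// HYPOTHETICAL split of [0, total) across the CTAs that idle during attention
// (blockIdx >= nh): contiguous stripes of whole 128-byte lines.
Range prefetch_stripe(int block, int nh, int grid, int64_t total) {
    assert(block >= nh && block < grid);
    const int64_t kLine = 128;
    const int64_t idle = grid - nh;
    const int64_t lines = (total + kLine - 1) / kLine;
    const int64_t per = (lines + idle - 1) / idle;  // lines per idle CTA
    const int64_t k = block - nh;
    return {std::min(total, k * per * kLine), std::min(total, (k + 1) * per * kLine)};
}
```

The per-layer total is exactly 30 MiB: 8 MiB of qkv ahead of attention and 22 MiB after it. Spread over 68 idle CTAs, the post-attention range is roughly 340 KB each, far beyond any SMEM staging buffer but a natural fit for L2.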
Win 2: killing barriers with redundant compute (3.90 -> 3.49 ms). The norm -> GEMV barriers existed because a norm wrote its result to global scratch for all CTAs to read. But every CTA already holds the full activation in its own SMEM - so each CTA just norms its own copy, in place:
```cuda
for (int i = threadIdx.x; i < H; i += BLK) { s_act[i] = h_in[i]; }
__syncthreads();
op_rmsnorm_inplace(s_act, w.in_ln, s_red, H);  // every CTA, redundantly
__syncthreads();                               // block-local - no grid.sync
op_gemv(w.qkv, s_act, ...);                    // reads own SMEM, already correct
```
Eighty-four CTAs computing the same 1024-element RMSNorm is wasted flops by any classical measure, and it doesn't matter at all - the flops were free, the grid.sync() plus global round-trip were not. Two barriers per layer gone, zero races by construction.
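What each CTA runs is an ordinary RMSNorm. Here is a host-side reference sketch, assuming eps = 1e-6 and fp32 math with a double accumulator; the device version's exact rounding points may differ. The property that makes the redundancy safe is determinism. Every CTA executes the same code on the same input with the same reduction order, so the 84 copies agree bitwise and no exchange is needed.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// x <- x * rsqrt(mean(x^2) + eps) * w. eps = 1e-6 is an assumption.
void rmsnorm_inplace(std::vector<float>& x, const std::vector<float>& w,
                     float eps = 1e-6f) {
    assert(x.size() == w.size() && !x.empty());
    double ss = 0.0;
    for (float v : x) ss += double(v) * v;
    const float inv_rms = 1.0f / std::sqrt(float(ss / x.size()) + eps);
    for (std::size_t i = 0; i < x.size(); ++i) x[i] = x[i] * inv_rms * w[i];
}
```

Two "CTAs" normalizing their own copies of the same activation end up with identical vectors. With unit weights, each copy has a mean square of essentially 1.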
Rejection number three: the glue fusion. Between wins 1 and 2 I tried fusing the whole norm -> rope -> cache-write glue strip into one warp-per-head pass, keeping each head register-resident through all three ops. It was correct - including a lane layout ({L, L+32, L+64, L+96}) that makes the rope's pair-exchange happen inside each lane, no shuffles at all - and it was 10% slower end-to-end. Twenty-four warps against a full grid, the qkv fusion story again at smaller scale. Reverted. Third confirmation of the same law, and this time the register choreography was genuinely pretty. The machine does not award style points.
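The lane layout is easy to check mechanically. This sketch, with hypothetical helper names, tests whether each lane already owns the rope partner (d ± 64) of every dim it holds. The per-head RMSNorm in the same pass still needs a cross-lane reduction; only the rope exchange disappears.

```cpp
#include <cassert>

constexpr int HD = 128, HALF = HD / 2;

// The layout from the glue fusion: lane L holds dims {L, L+32, L+64, L+96}.
void strided_dims(int lane, int dims[4]) {
    for (int k = 0; k < 4; ++k) dims[k] = lane + 32 * k;
}

// A contiguous alternative for contrast: lane L holds dims {4L, ..., 4L+3}.
void contiguous_dims(int lane, int dims[4]) {
    for (int k = 0; k < 4; ++k) dims[k] = 4 * lane + k;
}

// True if, for every dim this lane holds, its rotate_half partner
// (d + HALF or d - HALF) is held by the same lane: no shuffle needed.
bool rope_pairs_lane_local(const int dims[4]) {
    for (int k = 0; k < 4; ++k) {
        const int partner = dims[k] < HALF ? dims[k] + HALF : dims[k] - HALF;
        bool owned = false;
        for (int j = 0; j < 4; ++j) owned = owned || dims[j] == partner;
        if (!owned) return false;
    }
    return true;
}
```

The strided layout passes for all 32 lanes. The contiguous one fails immediately, since lane 0 holds dims 0-3 but their partners 64-67 live on lane 16.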
The two wins compose into one sentence: removing a barrier pays when it costs no parallelism (norm-in-SMEM), and costs you the kernel when it does (glue fusion).
Final: 3.49 ms/token, 33.4% of DRAM peak, 7/7 tokens. 19% faster than v1, and still 1.65x behind CUDA graphs.
Where the remaining gap lives
I want to be precise about why I stopped here, because "I optimized until I stopped" is where posts usually go vague.
The safe wins are exhausted. What's left of the barrier count protects real data dependencies, and they exist because of a partition mismatch baked into design decision 3: GEMVs partition work by output row, but norms and elementwise ops want to partition by hidden-dim slice. Every time the chain switches views, someone needs everyone else's data, and that's a genuine global exchange - a barrier no local trick removes. The fix is a redesign, not a patch: a partitioning where each CTA owns the same hidden-slice through the entire chain, plus replacing grid.sync() with fine-grained producer-consumer signalling so layer N+1's weight stream can start while layer N's tail is still draining. That's where the remaining 27 points of DRAM% live. It's the next project, and it's the size of one.
Two effects to keep separate when reading the result, because conflating them is how megakernel numbers get oversold:
- The megakernel benefit - deleting launch, barrier, and round-trip fragmentation - is model-independent. That part worked and is measured above.
- The model-size effect is not. A 0.6B model makes the eager baseline maximally bad (tiny ops, worst-case launch overhead - a big relative win just sitting there), while its fixed 151936-token vocab makes the lm_head ~26% of the weight stream - a GEMV that's already grid-filled and gets nothing from fusion, capping the absolute ceiling near 3x. Small models flatter the before and starve the after.
And the production honesty. This kernel assumes a contiguous KV cache and batch=1. Real vLLM serving is paged KV with block tables (ops 8 and 9 would both need to speak paged), continuous batching across sequences of different lengths (the whole static work partition breaks), and custom-op plumbing. The kernel is the easy 20% - that machinery is most of why hand-written megakernels are rare in production. The defensible on-ramp is exactly what this model does for a living: batch=1, single-stream, greedy, low-latency ASR - the one regime where the launch-bound profile at the top of this post is the whole story and the paged/batched machinery mostly collapses away.
The levers still on the table
For completeness, the concrete list I'd work through in a v2, roughly in expected-payoff order:
- Unified hidden-slice partitioning. The redesign above - one work partition for the whole chain, killing the remaining genuine barriers. This is where most of the 33% -> 60% DRAM gap lives.
- Fine-grained producer-consumer signalling in place of grid.sync(). A CTA that finished its slice of layer N's down-projection has no reason to wait before starting layer N+1's qkv rows that only depend on data it already holds. This turns the prefetch hack into real cross-layer overlap.
- Fused greedy argmax in the lm_head epilogue. The 151936-logit vector currently round-trips through global memory just to be argmax'd. Greedy decode needs one integer.
- DSMEM via thread-block clusters. SM120 supports clusters of up to 8 CTAs; intra-cluster handoffs could skip even L2. Marginal next to the first two items, but it stacks.
- Flash-decoding, wired behind a ctx threshold - already measured, activates if a long-form audio mode ever pushes decode ctx past ~1024.
- Small-batch M=2-8. The weight stream is the cost and it's identical for 1 or 8 tokens; a second sequence rides along almost free until the glue ops stop being negligible. This is also the honest answer to "what if I have a little batching" - the megakernel doesn't have to mean batch=1 forever.
These are roadmap items, not promises of a follow-up post on each.
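The fused argmax from the list above is a two-stage reduction. Each CTA reduces its slice of the 151936 logits to one (value, index) pair, then one final pass picks the winner. Here is a host-side sketch of that shape, with hypothetical names. The tie rule (lower index wins) makes the result match a single left-to-right argmax no matter how the vocab is sliced.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

struct ArgMax { float value; int index; };

// Larger value wins; on a tie the lower index wins.
ArgMax better(ArgMax a, ArgMax b) {
    if (b.value > a.value || (b.value == a.value && b.index < a.index)) return b;
    return a;
}

// Stage 1: each slice (one per CTA in a fused epilogue) reduces its part of
// the logits to one (value, index). Stage 2: reduce the per-slice winners.
int two_stage_argmax(const std::vector<float>& logits, int nslices) {
    const int n = static_cast<int>(logits.size());
    assert(n > 0 && nslices > 0);
    const int per = (n + nslices - 1) / nslices;
    std::vector<ArgMax> winners;
    for (int s = 0; s < nslices; ++s) {
        const int b = std::min(n, s * per);
        const int e = std::min(n, b + per);
        if (b == e) continue;  // slice starts past the end of the vocab
        ArgMax local{logits[b], b};
        for (int i = b + 1; i < e; ++i) local = better(local, {logits[i], i});
        winners.push_back(local);
    }
    ArgMax g = winners[0];
    for (std::size_t k = 1; k < winners.size(); ++k) g = better(g, winners[k]);
    return g.index;
}
```

Slicing the vocab 84 ways gives the same token as one serial scan, even when the maximum value repeats across slices. Only the 84 per-CTA pairs, not the 151936 logits, need to cross between CTAs.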
Where a megakernel actually makes sense
This is the section I wish someone had written before I started - the decision logic, separated from my particular model and GPU.
The profile decides, not the enthusiasm. The qualifying signal is the one at the top of this post: GPU-busy a small fraction of wall-clock, most ops under one wave per SM, glue ops stalled on long_scoreboard re-reading what the previous op just wrote. That profile says the machine is idling between ops, and only fusion reclaims that. If your profile instead shows big grids and a pipe near saturation - prefill, training, large-batch serving - a megakernel buys almost nothing, because there's no idle time to reclaim.
Try CUDA graphs first. Graphs are a config flag and they delete launch overhead completely - that's the free 80%, and this post's own numbers show a tuned graph path is hard to beat. What graphs cannot do is remove the HBM round-trips between dependent ops or the implicit barrier at every kernel boundary; kernels in a graph are still separate kernels. The megakernel is the (expensive) instrument for exactly those two things - if graphs already put you near your weight-stream floor, stop there.
The regime where it pays is the intersection: small model, batch~1, latency as the product. The smaller the model, the larger the fixed per-op costs loom relative to useful work - a 4 us kernel tax matters at 2.2 ms of GPU work per token and vanishes at 50 ms. Concretely that's low-latency ASR and TTS, speculative-decoding draft models (small by construction, decode-only, latency-critical by definition), and real-time control loops on edge GPUs. It's not a coincidence that this project's model is one of them.
Where it doesn't pay: big models at decode (the weight stream saturates DRAM anyway - graphs get you to the floor), prefill and training (tile regime, tensor cores, occupancy already fine), and batched throughput serving, where continuous batching is the product and a static 84-CTA partition is the wrong shape entirely.
And the cost nobody prices in: a hand-written megakernel freezes the architecture. This kernel knows it has 28 layers, hidden 1024, 16/8 heads, SiLU. A model revision upstream is not a config change here; it's surgery. So the approach fits where the model is stable and the latency win is worth owning the code - and doesn't fit a research loop where the architecture changes monthly.
One sentence version: profile first, graphs second, megakernel only when the profile says launch-bound and the regime says small-model latency - which is rarer than the hype and exactly as real as this post's 7x-over-eager measurement when it applies.
Closing
What I'd actually carry to the next kernel, in one place:
Fusion is a bet that the fused unit still fills the machine - at M=1 a head is too coarse a unit, and I lost that bet three times (19x on the qk-norm fusion, 0.88-0.95x on GQA sharing, and 10% on the glue fusion, the pretty one). The scarce resource decides everything: attention here was starved for SMs and swimming in free L2 bandwidth, so spreading won and byte-saving lost, and both verdicts would flip in a different regime. bf16 drift saturates instead of compounding - measured, not assumed, and it saved me a precision migration. And per-op goldens made correctness the cheap part of the project, which is the opposite of the usual megakernel experience.
The through-line is older than any of those: every fork in this project was decided by a measurement, and four of my better-sounding ideas are dead because of it. 3.49 ms stands, 2.11 ms is the number to beat, and the redesign that might get there is scoped. That's the honest state of it.