Emre's Blog

From 429 GB/s to the DRAM wall: writing an FP8 quantizer on an RTX 5080

This is the story of one CUDA kernel pair in cublade, my personal kernel library. The kernels do per-tensor symmetric FP8 (E4M3) quantization and dequantization. The quantizer started at 429 GB/s on a v1 build. An NCU-driven optimization pass then pushed it to roughly 880 GB/s at the kernel level, about 93% of the RTX 5080's measured DRAM peak. The dequantizer, its mirror image, landed DRAM-bound on its first compile.

I'm going to walk through it the way it happened. v1 from first principles, then the NCU-driven optimization pass, then the mirror-image dequantizer. There are a couple of teaching moments and one place where I did pedagogy wrong and had to recover. I'm including those.

The repo: github.com/emre570/cublade. Code paths in this post are real and resolve in the repo.

What FP8 is, briefly

FP8 E4M3 is an 8-bit float: 1 sign, 4 exponent, 3 mantissa. The dynamic range covers roughly +/- 448 with a granularity that gets coarser as you move away from zero. The cheap way to use it is per-tensor symmetric quantization:

  1. Find amax = max(|x|) over the whole tensor.
  2. Compute scale = 448 / amax.
  3. Write q[i] = round_to_fp8(x[i] * scale) and store scale_recip = amax / 448.

To recover floats: y[i] = q[i] * scale_recip. The numbers don't come back exactly - FP8 only has 256 distinct values - but they come back close enough for inference weights and activations in modern transformers.
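To make the three steps concrete, here is a host-side C++ sketch of per-tensor E4M3 quantization. It assumes the standard E4M3 layout (bias 7, no infinities, S.1111.111 as the only NaN pattern) and encodes by brute-force search with round-to-nearest-even and saturation to +/-448, which is meant to mimic __NV_SATFINITE. The function names are hypothetical; this is a reference sketch, not the cublade code.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Decode one E4M3 byte: 1 sign, 4 exponent (bias 7), 3 mantissa, no infinities.
float e4m3_decode(uint8_t b) {
    const int exp = (b >> 3) & 0xF;
    const int man = b & 0x7;
    if (exp == 0xF && man == 0x7) return NAN;                  // S.1111.111 is the NaN pattern
    const float mag = (exp == 0)
        ? std::ldexp(man / 8.0f, -6)                           // subnormal: 0.mmm * 2^-6
        : std::ldexp(1.0f + man / 8.0f, exp - 7);              // normal:    1.mmm * 2^(e-7)
    return (b & 0x80) ? -mag : mag;
}

// Encode with round-to-nearest-even, clamping anything at or beyond 448 to +/-448.
// Brute force over the 127 non-negative finite codes: slow, but easy to audit.
uint8_t e4m3_encode_satfinite(float v) {
    if (std::isnan(v)) return 0x7F;
    const uint8_t sign = std::signbit(v) ? 0x80 : 0x00;
    const float a = std::fabs(v);
    if (a >= 448.0f) return sign | 0x7E;                       // 0x7E decodes to 448
    uint8_t best = 0;
    for (uint8_t c = 1; c <= 0x7E; ++c) {
        const float dc = std::fabs(e4m3_decode(c) - a);
        const float db = std::fabs(e4m3_decode(best) - a);
        // Codes grow monotonically with value, so a tie is between neighbors c-1 and c:
        // keep whichever has an even mantissa (lowest bit 0).
        if (dc < db || (dc == db && (c & 1) == 0)) best = c;
    }
    return sign | best;
}

struct QuantResult {
    std::vector<uint8_t> q;
    float scale_recip;                                         // ~amax / 448, used by dequant
};

QuantResult quantize_per_tensor(const std::vector<float>& x, float eps = 1e-12f) {
    float amax = 0.0f;
    for (float v : x) amax = std::fmax(amax, std::fabs(v));    // step 1: amax = max |x|
    const float scale = 448.0f / std::fmax(amax, eps);         // step 2
    QuantResult r{std::vector<uint8_t>(x.size()), 1.0f / scale};
    for (std::size_t i = 0; i < x.size(); ++i)                 // step 3
        r.q[i] = e4m3_encode_satfinite(x[i] * scale);
    return r;
}

std::vector<float> dequantize_per_tensor(const std::vector<uint8_t>& q, float scale_recip) {
    std::vector<float> y(q.size());
    for (std::size_t i = 0; i < q.size(); ++i) y[i] = e4m3_decode(q[i]) * scale_recip;
    return y;
}
```

On the input [1, 2, ..., 8] the scale is 56, and 3 * 56 = 168 lands exactly between the E4M3 neighbors 160 and 176, so round-to-nearest-even picks 160 and 3 comes back as about 2.857. That is inside the 1/16 relative rounding bound E4M3 normals guarantee.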

Per-tensor is the simplest scheme. There are fancier ones (per-channel, per-group, block-scaled MXFP4/NVFP4) that I'm building in cublade as separate modules. This post is about the simplest one done well.

The structure forces two kernels:

amax_reduce      : x[n] -> amax (scalar)
quantize_cast    : x[n], amax -> q[n], scale_recip

The reduction has to finish before the cast can run. They're two grid launches with a global barrier between them.

Part I: v1 from first principles

I'm going to spend a few sections on v1 because it's where the warp-level primitives get built. If you already know warp shuffles by heart, you can skip to Part II.

Warp shuffle from scratch

Each block here has 256 threads, organized as 8 warps of 32 threads. To find the max across all 256, you reduce inside each warp first (warp shuffle), then across warps (shared memory), then across blocks (atomic).

Let me start small to make the warp part concrete. Say warp size is 8 threads (it's actually 32, the picture is the same) and each one holds one number:

thread:  0   1   2   3   4   5   6   7
value:   1   2   3   4   5   6   7   8

I want every thread to end up holding 8 (the max). The classic primitive is __shfl_xor_sync(mask, val, offset): thread i exchanges its value with thread i ^ offset. Then both sides take the max. Three rounds with offsets 4, 2, 1 are enough to give every thread the warp-wide max:

Warp shuffle butterfly: three rounds of XOR-pair reductions take eight threads from [1..8] to all 8s.

Three rounds, log2(8). For a real 32-thread warp it's five rounds with offsets 16, 8, 4, 2, 1. Code:

float my_val = ...;  // each thread's value
for (int offset = 16; offset > 0; offset >>= 1) {
    float v = __shfl_xor_sync(0xffffffff, my_val, offset);
    my_val = fmaxf(my_val, v);
}
// every thread in the warp now holds the warp-wide max

^ is XOR. The reason it works: at each step, i ^ offset is the partner thread that holds the half of the data you don't have yet. After log2(N) steps every thread has seen every value.
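A host-side C++ simulation of the butterfly makes the XOR partner pattern easy to check without a GPU. The shfl_xor below is a hypothetical stand-in for __shfl_xor_sync: it models the fact that all 32 lanes read their partner's old value before any lane updates, which is what the warp-synchronous intrinsic gives you on the device.

```cpp
#include <algorithm>
#include <array>
#include <cassert>

constexpr int kWarp = 32;
using Warp = std::array<float, kWarp>;

// Hypothetical stand-in for __shfl_xor_sync: every lane reads its partner's
// old value (lane ^ offset) before any lane updates, as in a real warp.
Warp shfl_xor(const Warp& v, int offset) {
    Warp out{};
    for (int lane = 0; lane < kWarp; ++lane) out[lane] = v[lane ^ offset];
    return out;
}

// Butterfly max with offsets 16, 8, 4, 2, 1: log2(32) = 5 rounds.
Warp warp_max_butterfly(Warp v) {
    for (int offset = kWarp / 2; offset > 0; offset >>= 1) {
        const Warp partner = shfl_xor(v, offset);
        for (int lane = 0; lane < kWarp; ++lane) v[lane] = std::max(v[lane], partner[lane]);
    }
    return v;
}
```

Setting kWarp to 8 turns the loop into the three-round version from the diagram, with offsets 4, 2, 1.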

Teaching moment: how not to explain this

The first time Claude explained warp shuffle in conversation, it led with a four-level hierarchy: "Level 0 is thread, Level 1 is warp via shfl_xor in PTX, Level 2 is block via shared memory, Level 3 is grid via atomic, and..."

Then I said "I did not understand any shit."

Fair. Hierarchies and PTX vocabulary don't help me understand butterfly reductions, at least not yet. Concrete numbers like [1, 2, 3, 4, 5, 6, 7, 8] and three little ASCII diagrams do. That's how I came to understand the version above.

Up the layers: SMEM and atomic

After the warp shuffle, every thread in the warp holds the warp-wide max. We only need one thread per warp to publish it. Thread 0 of each warp writes to shared memory:

__shared__ float warp_max[8];          // 8 warps per block at BLOCK=256
int lane    = threadIdx.x % 32;
int warp_id = threadIdx.x / 32;
if (lane == 0) warp_max[warp_id] = my_val;
__syncthreads();

Eight values sitting in shared memory. We need to reduce them to one. The first warp does it: each of its first 8 lanes picks up one slot, the rest read zero, and we run the same butterfly with offsets 4, 2, 1:

if (warp_id == 0) {
    my_val = (lane < 8) ? warp_max[lane] : 0.0f;
    for (int offset = 4; offset > 0; offset >>= 1) {
        float v = __shfl_xor_sync(0xffffffff, my_val, offset);
        my_val = fmaxf(my_val, v);
    }
    if (lane == 0) atomicMaxFloat(amax_buf, my_val);
}
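The same simulation extends one level up. This host-side C++ sketch (hypothetical names, non-negative inputs assumed) runs the per-warp butterfly, parks each warp's result in an 8-slot array standing in for __shared__ warp_max[8], then lets "warp 0" reduce the slots with offsets 4, 2, 1 while lanes 8-31 carry zeros. The value lane 0 ends up with is what the kernel hands to atomicMaxFloat.

```cpp
#include <algorithm>
#include <array>
#include <cassert>

constexpr int kWarp = 32;
constexpr int kBlock = 256;
constexpr int kWarps = kBlock / kWarp;                      // 8 warps per block

using WarpRegs = std::array<float, kWarp>;

// Butterfly max over a simulated warp, starting at `first_offset` and halving.
void butterfly_max(WarpRegs& v, int first_offset) {
    for (int offset = first_offset; offset > 0; offset >>= 1) {
        WarpRegs partner{};
        for (int lane = 0; lane < kWarp; ++lane) partner[lane] = v[lane ^ offset];
        for (int lane = 0; lane < kWarp; ++lane) v[lane] = std::max(v[lane], partner[lane]);
    }
}

// Block max for non-negative inputs, following the kernel's stages.
// Returns the value lane 0 of warp 0 would pass to atomicMaxFloat.
float block_max(const std::array<float, kBlock>& x) {
    std::array<float, kWarps> warp_max{};                   // stands in for __shared__ warp_max[8]
    for (int w = 0; w < kWarps; ++w) {
        WarpRegs v{};
        for (int lane = 0; lane < kWarp; ++lane) v[lane] = x[w * kWarp + lane];
        butterfly_max(v, 16);                               // offsets 16, 8, 4, 2, 1
        warp_max[w] = v[0];                                 // lane 0 publishes
    }
    // After __syncthreads(): warp 0 picks up the 8 slots, lanes 8..31 hold 0.
    WarpRegs v{};
    for (int lane = 0; lane < kWarp; ++lane) v[lane] = (lane < kWarps) ? warp_max[lane] : 0.0f;
    butterfly_max(v, 4);                                    // offsets 4, 2, 1 stay within lanes 0..7
    return v[0];
}
```

Zero is a safe filler for the idle lanes only because every input is an absolute value; a general-purpose max would need -INFINITY there instead.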

atomicMaxFloat doesn't exist as a CUDA intrinsic. Floats don't have an atomic max. The trick is to atomic-CAS on the bit-cast int:

__device__ float atomicMaxFloat(float* addr, float val) {
    int* addr_i = (int*)addr;
    int old = *addr_i, assumed;
    do {
        assumed = old;
        if (__int_as_float(assumed) >= val) break;
        old = atomicCAS(addr_i, assumed, __float_as_int(val));
    } while (assumed != old);
    return __int_as_float(old);
}

It works because non-negative IEEE-754 floats sort the same as their int bit pattern. amax values are non-negative (we took the abs), so this is safe.
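On the host, std::atomic<int32_t>::compare_exchange_weak plays the role of atomicCAS, so the same loop can be sketched and tested in plain C++. This is an illustration under that substitution, not the device code; float_bits and bits_float are hypothetical helpers that reinterpret through memcpy.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>
#include <cstring>

// Reinterpret float <-> int32 bits without an aliasing cast.
int32_t float_bits(float f) { int32_t i; std::memcpy(&i, &f, sizeof i); return i; }
float bits_float(int32_t i) { float f; std::memcpy(&f, &i, sizeof f); return f; }

// Host analogue of atomicMaxFloat: compare_exchange_weak stands in for atomicCAS.
// Returns the value that was in the slot before this call's update (if any).
float atomic_max_float(std::atomic<int32_t>& slot, float val) {
    int32_t old = slot.load();
    // Stop once the slot already holds >= val, or once our exchange succeeds.
    // On failure, compare_exchange_weak reloads `old` with the current contents.
    while (bits_float(old) < val &&
           !slot.compare_exchange_weak(old, float_bits(val))) {
        // retry with the refreshed `old`
    }
    return bits_float(old);
}
```

For negative floats the int ordering runs backwards (-2.0f has a larger int32 bit pattern than -1.0f), which is why the abs has to happen before any of this. Because the inputs are guaranteed non-negative, a common shortcut is to skip the loop and call integer atomicMax on __float_as_int(val); the CAS loop is the more general form.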

For the analogy: atomicCAS is how one thread out of many gets to update a global slot. Picture a hotel lobby with one shared pen. Each block sends a runner with its block-wide max written on a slip. They line up at the desk. The clerk compares the slip with the number on the wall. If the slip is bigger, the clerk replaces the number and the runner walks away; if not, the runner walks away anyway. If someone else changed the wall between the runner's first look and the clerk's check, the runner reads the wall again and retries. Nobody ever overwrites a bigger number with a smaller one.

Reading production code as reference

Before writing the kernels I read NVIDIA TransformerEngine's current_scaling.cu, which implements the same per-tensor amax pattern in production code. About 40 lines matter. The shape was:

load chunk -> compute thread-local amax -> warp shuffle reduce
   -> SMEM noticeboard -> first warp final reduce -> atomicMax

Side-by-side adaptation table, TFEngine vs cublade:

step   | TFEngine                       | cublade v1
load   | scalar with manual unroll      | scalar (1 elem/thread)
warp   | __shfl_xor_sync, log2(32)      | same
smem   | __shared__ float warp_max[8]   | same, fixed at 8 (BLOCK=256)
sync   | __syncthreads                  | same
final  | first warp 8-way reduce        | same
atomic | atomicMaxFloat via CAS         | same

The thing I changed: TFEngine's launch sizing is dynamic. I hardcoded BLOCK=256 because (a) it gives 8 warps, a power of 2, which keeps the final reduce clean, (b) at BLOCK=256 the SM120 warp cap allows 6 resident blocks per SM (good occupancy), (c) one launch param is one less knob to worry about during the first build, and (d) I only have one hardware target.

Writing the kernel

I wrote amax_reduce_kernel first - about 25 lines: thread-local max via the __habs + fmaxf pair, then the warp-shuffle butterfly above, then SMEM noticeboard, then atomicMaxFloat over blocks.

10 unit tests pass against a CPU reference: deterministic inputs, explicit thresholds, all green.

The cast kernel is shorter, about 10 lines:

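#include <cuda_fp16.h>   // half, __half2float
#include <cuda_fp8.h>    // __nv_fp8_storage_t, __nv_cvt_float_to_fp8

__global__ void quantize_cast_fp8(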
    const half* __restrict__ x,
    const float* __restrict__ amax_buf,
    __nv_fp8_storage_t* __restrict__ q,
    float* __restrict__ scale_out,
    int n, float eps)
{
    float scale = 448.0f / fmaxf(*amax_buf, eps);
    if (blockIdx.x == 0 && threadIdx.x == 0)
        *scale_out = 1.0f / scale;

    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    float v = __half2float(x[i]) * scale;
    q[i] = __nv_cvt_float_to_fp8(v, __NV_SATFINITE, __NV_E4M3);
}

Each thread reads one half, multiplies in FP32, and casts to FP8 via the intrinsic. __NV_SATFINITE clamps out-of-range values to +/-448 instead of producing the FP8 NaN.

The output scale_out is the reciprocal: amax / 448. That's what dequant will multiply by later. Writing the reciprocal in cast, not in dequant, avoids a divide on every dequant call.

Bit-exact contract from day one

The C++ test harness has a CPU reference:

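#include <algorithm>     // std::max
#include <vector>
#include <cuda_fp16.h>   // half, __half2float (host-callable)
#include <cuda_fp8.h>    // __nv_fp8_e4m3, __nv_fp8_storage_t

static std::vector<__nv_fp8_storage_t>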
cpu_quantize(const std::vector<half>& x, float amax) {
    float scale = 448.0f / std::max(amax, 1e-12f);
    std::vector<__nv_fp8_storage_t> q(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        float v = __half2float(x[i]) * scale;
        __nv_fp8_e4m3 fp8(v);
        q[i] = fp8.__x;
    }
    return q;
}

Five tests at the start: toy 8-element fixed input, n=1023 (odd, random), n=2048, n=1M random, n=1M with one outlier near 448 to stress saturation. The check is q_gpu[i] == q_cpu[i] as raw bytes, not "within tolerance". E4M3 is a discrete grid; tolerance-based checks would let single-lane bugs slip past.

All five pass.

v1 bench

The 4-way bench across CUDA, the Python cublade FP8 path, the Python cublade INT8 path (legacy torch-math fallback at this point), and torch.ops.aten._scaled_mm's underlying FP8 quant:

n           | CUDA (ms) | cublade FP8 (ms) | cublade INT8 (ms) | torch raw FP8 (ms) | CUDA GB/s
4,096       | 0.025     | 0.203            | 0.252             | 0.159              | 0.5
65,536      | 0.027     | 0.205            | 0.247             | 0.135              | 7.2
1,000,000   | 0.026     | 0.211            | 0.256             | 0.151              | 114.0
16,000,000  | 0.112     | 0.526            | 0.241             | 0.524              | 429.1
100,000,000 | 1.026     | 3.998            | 2.882             | 3.997              | 292.5

429 GB/s at n=16M, about 45% of the 5080's 960 GB/s DRAM spec (GB/s here counts 3 bytes per element: 2 read, 1 written). That's 2-10x faster than every Python path. Small N is launch-overhead-dominated: the CUDA times bottom out around 22-27 us because kernel launch costs ~10 us each and there are two launches.

I was happy with this for about an hour.

Part II: NCU shoves us to the DRAM wall

After the v1 was working I sat with the numbers. 429 GB/s is fine. It's also a long way from 960. The question is: where are the missing 530 GB/s going? You can't answer that by staring at the code. You have to measure.

Before pulling up NCU, a roofline sanity check. Per element, the amax pass reads 2 bytes (a half), and the cast pass reads those 2 bytes again and writes 1 (an FP8 byte), with about 5 flops spread across both (abs, max, multiply, conversion). That's an arithmetic intensity on the order of 1 flop per byte. On SM120, with FP32 throughput in the tens of TFLOP/s and DRAM at ~960 GB/s, the ridge point sits at tens of flops per byte, so this kernel is firmly in the memory-bound region. Compute is not the wall; bandwidth is. So 429/960 = 45% of DRAM peak means roughly 55% headroom is left, and any recovery has to come from the memory subsystem. That is exactly what NCU is good at telling you about.
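The roofline argument reduces to two lines of arithmetic. In this C++ sketch the 50 TFLOP/s FP32 figure is an assumed placeholder for "tens of TFLOP/s", and the intensity of ~1 flop/byte counts both passes (about 5 flops against 2 + 2 + 1 bytes per element). Helper names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>

// Simple roofline model: attainable FLOP/s = min(compute roof, intensity * bandwidth).
double attainable_flops(double peak_flops, double peak_bytes_per_s, double flops_per_byte) {
    return std::min(peak_flops, flops_per_byte * peak_bytes_per_s);
}

// Ridge point: the intensity (FLOP/byte) above which a kernel stops being memory-bound.
double ridge_point(double peak_flops, double peak_bytes_per_s) {
    return peak_flops / peak_bytes_per_s;
}

// Memory-bound means the bandwidth slope, not the compute roof, sets the ceiling.
bool is_memory_bound(double peak_flops, double peak_bytes_per_s, double flops_per_byte) {
    return flops_per_byte < ridge_point(peak_flops, peak_bytes_per_s);
}
```

With those numbers the ridge point is about 52 flop/byte, roughly fifty times the kernel's intensity, so bandwidth sets the ceiling by a wide margin.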

That's when I ran NCU on v1.

The 9-column dump that flopped

ncu --set detailed produces a wall of metrics per kernel: DRAM throughput, L1 hit rate, L2 hit rate, issue rate, achieved occupancy, warp stall reasons, register pressure, shared memory usage, and so on. Claude's first attempt at explaining the v1 result was to dump nine columns in a table:

kernel @ N  | DRAM % | L1/TEX % | DRAM GB/s | L1 hit | Occ   | IPC  | Duration
amax @ 16M  | 32.8%  | 44.3%    | 310       | 5.85%  | 70.9% | 0.94 | 103 us
cast @ 16M  | 43.3%  | 33.8%    | 409       | 43.75% | 74.7% | 1.05 | 78 us

It didn't land. I asked, again: "I did not understand the two bottlenecks properly. Can you simplify it please?"

So I threw the table out and drew a pipe diagram:

Two pipes between threads and DRAM: an issue pipe from THREADS to L1 cache (how fast threads can ask for bytes), and a delivery pipe from L1 to DRAM (how fast DRAM can hand them over). The fullest pipe is the bottleneck.

Two pipes only. The issue pipe is "how fast can threads ask for bytes." The delivery pipe is "how fast can DRAM hand them over." Whichever is more full is the bottleneck.

Reduced to two numbers per kernel:

kernel | issue pipe busy | DRAM busy | who's slow?
amax   | 44%             | 33%       | issue pipe (DRAM has spare capacity)
cast   | 31%             | 56%       | DRAM (threads asking faster than DRAM delivers)

Two kernels, two different slow points.

amax is bound on issue. Each thread loads one half (2 bytes), does a tiny amount of work, then has to launch another load. The DRAM controller is sitting half-idle waiting for the threads to ask for more.

cast is bound on DRAM, but only partially - and it has a hidden cheat. cast's L1 hit rate is 44%, meaning almost half its loads are served from L1, not DRAM. Why? Because cast ran right after amax on the same tensor; chunks of x are still sitting in L1 from the amax pass. So cast's effective "DRAM bandwidth" is inflated by free L1 reuse.

Either way: both kernels are under-utilizing DRAM. The path to more throughput is to ask DRAM for more bytes per instruction.

The vectorization plan, each choice tied to data

The fix is wide loads. Each thread takes 8 elements via one 128-bit int4 load, instead of one element per load. Concretely:

  1. VEC = 8, BLOCK = 256. Each thread owns 8 elements; each block owns 2048.
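  2. Vector load via __ldg(reinterpret_cast<const int4*>(x + base)). __ldg routes the load through the read-only data cache; the compiler emits LDG.E.128. The 128-bit load needs x + base to be 16-byte aligned, which holds whenever x itself is, because base is a multiple of 8 halves (16 bytes).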
  3. Type punning via a union Pack16 { int4 raw; half2 h2[4]; half h[8]; }. Safer than raw reinterpret_cast aliasing. Same SASS.
  4. FP32 multiply path required. Tempting to multiply in half2 and save instructions; doesn't work. The Python reference rounds in FP32 (the chain is half -> float -> mul -> fp8), and we have a torch.equal gate at every bench point. half2 multiply rounds differently. Stay in FP32.
  5. Packed conversion __nv_cvt_float2_to_fp8x2. Two FP32s in, two FP8 bytes out per call. Four calls per thread, packed into one uint64 store.
  6. Scalar fallback for the tail. Any thread whose base + VEC > n falls back to one element per iteration. Threads whose base >= n contribute 0 to amax, which is safe because amax is over non-negative values.
  7. int64_t indexing throughout. n * sizeof(half) overflows int32_t past about 2 GB.
  8. No __launch_bounds__ preemptively. Measure regs/thread first, only add the directive if it crosses ~42 (where Block Limit Registers becomes binding on SM120).
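The indexing and tail rules above can be checked on the host before touching a GPU. This C++ sketch walks every (block, thread) pair the way the v2 kernels index, assuming a ceil-div grid of (n + 2047) / 2048 blocks (an assumption; the launch code isn't shown here), and records which elements each thread touches on the vector path or the scalar tail. Names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int64_t BLOCK = 256, VEC = 8;

struct Coverage {
    std::vector<int> hits;            // how many threads touched each element
    int64_t vector_threads = 0;       // threads that took the 16-byte path
    int64_t tail_threads = 0;         // threads that took the scalar tail
};

Coverage simulate_partition(int64_t n) {
    Coverage c;
    c.hits.assign(static_cast<std::size_t>(n), 0);
    const int64_t grid = (n + BLOCK * VEC - 1) / (BLOCK * VEC);   // assumed ceil-div launch
    for (int64_t b = 0; b < grid; ++b)
        for (int64_t t = 0; t < BLOCK; ++t) {
            const int64_t base = b * BLOCK * VEC + t * VEC;          // int64_t before multiplying
            if (base + VEC <= n) {                                   // fast path: one wide load
                ++c.vector_threads;
                for (int64_t k = base; k < base + VEC; ++k) ++c.hits[k];
            } else if (base < n) {                                   // scalar tail
                ++c.tail_threads;
                for (int64_t k = base; k < n; ++k) ++c.hits[k];
            }                                                        // base >= n: contributes 0
        }
    return c;
}
```

For any n, every element is visited exactly once, n / 8 threads take the wide path, and at most one thread lands in the scalar tail; every thread past the end does nothing.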

The v2 amax kernel (just the per-thread phase)

Phases 2+ (warp shuffle, SMEM, atomic) are byte-for-byte v1. The only new code is per-thread:

const int64_t base = (int64_t)blockIdx.x * blockDim.x * VEC
                   + (int64_t)threadIdx.x * VEC;
float my_val = 0.0f;

if (base + VEC <= n) {
    Pack16 pk;
    pk.raw = __ldg(reinterpret_cast<const int4*>(x + base));
    half2 a0 = __habs2(pk.h2[0]);
    half2 a1 = __habs2(pk.h2[1]);
    half2 a2 = __habs2(pk.h2[2]);
    half2 a3 = __habs2(pk.h2[3]);
    half2 m01 = __hmax2(a0, a1);
    half2 m23 = __hmax2(a2, a3);
    half2 m   = __hmax2(m01, m23);
    half  lm  = __hmax(__low2half(m), __high2half(m));
    my_val    = __half2float(lm);
} else {
    int64_t end = (base + VEC < n) ? (base + VEC) : n;
    for (int64_t k = base; k < end; ++k) {
        my_val = fmaxf(my_val, __half2float(__habs(x[k])));
    }
}
// ... warp shuffle / SMEM / atomic unchanged.

8 halves load -> 4 half2 abs -> 2 max pairs -> 1 max pair -> 1 scalar max -> 1 float. Fully unrolled, no loop.

The v2 cast kernel

const float scale = 448.0f / fmaxf(*amax_buf, eps);   // computed as in v1
const int64_t base = (int64_t)blockIdx.x * blockDim.x * VEC
                   + (int64_t)threadIdx.x * VEC;

if (base + VEC <= n) {
    Pack16 pk;
    pk.raw = __ldg(reinterpret_cast<const int4*>(x + base));

    __nv_fp8x2_storage_t p[4];
    #pragma unroll
    for (int k = 0; k < 4; ++k) {
        float2 f2 = __half22float2(pk.h2[k]);
        f2.x *= scale;
        f2.y *= scale;
        p[k] = __nv_cvt_float2_to_fp8x2(f2, __NV_SATFINITE, __NV_E4M3);
    }
    uint64_t packed = (uint64_t)p[0]
                    | ((uint64_t)p[1] << 16)
                    | ((uint64_t)p[2] << 32)
                    | ((uint64_t)p[3] << 48);
    *reinterpret_cast<uint64_t*>(q + base) = packed;
} else {
    int64_t end = (base + VEC < n) ? (base + VEC) : n;
    for (int64_t k = base; k < end; ++k) {
        float v = __half2float(x[k]) * scale;
        q[k] = __nv_cvt_float_to_fp8(v, __NV_SATFINITE, __NV_E4M3);
    }
}

Load 8 halves in one instruction. Widen each half2 to float2. Multiply pair-wise in FP32. Pack to fp8x2. OR four uint16s into one uint64. Store in one instruction. 24 bytes of memory traffic per thread, 1 LDG, 1 STG.
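The shift-and-OR store here and the shift-and-mask unpack in the dequant kernel are inverses. A host-side C++ sketch with hypothetical pack4/unpack4 helpers shows the round trip and why memory order comes out right: on a little-endian machine (NVIDIA GPUs and common x86/Arm hosts are), pair k lands in bytes 2k and 2k+1. That assumes each fp8x2 value carries its first element in the low byte.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>

// Pack four 16-bit FP8 pairs into one 64-bit word, pair k in bits [16k, 16k + 16).
uint64_t pack4(const std::array<uint16_t, 4>& p) {
    return  (uint64_t)p[0]
         | ((uint64_t)p[1] << 16)
         | ((uint64_t)p[2] << 32)
         | ((uint64_t)p[3] << 48);
}

// Inverse: mask each 16-bit lane back out.
std::array<uint16_t, 4> unpack4(uint64_t w) {
    return { (uint16_t)( w        & 0xFFFFu), (uint16_t)((w >> 16) & 0xFFFFu),
             (uint16_t)((w >> 32) & 0xFFFFu), (uint16_t)((w >> 48) & 0xFFFFu) };
}
```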

New edge-case tests

The v2 kernel partitions on base + VEC <= n. That introduces two new failure modes that v1 didn't have:

Both got added. Plus the 10 v1 tests, untouched. 14/14 pass on v2.

Compile check

nvcc -Xptxas=-v reports:

amax_reduce_kernel_v2 : 29 regs/thread, 0 spills, 0 stack
quantize_cast_fp8_v2  : 22 regs/thread, 0 spills, 0 stack

Both are well under the ~42 threshold where register pressure would start limiting block count. SM120 has 65,536 registers per SM. At BLOCK=256 the warp cap allows 6 blocks per SM, so the worst case is 29 * 256 * 6 = 44,544 registers. That fits with room to spare. No __launch_bounds__ needed.
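The occupancy arithmetic can be written as a small function. This is a simplified model under assumed SM120 limits: 48 warps per SM (the cap implied by 6 blocks of 8 warps), the 65,536-register file, and registers allocated per warp in units of 256, assumed to match recent NVIDIA architectures. Shared memory and the per-SM block cap are ignored, and the function name is hypothetical.

```cpp
#include <algorithm>
#include <cassert>

// Simplified resident-blocks-per-SM model: only the warp cap and the register
// file are modeled; shared memory and the per-SM block cap are ignored.
// ASSUMED SM120 limits: 48 warps/SM, 65,536 registers/SM, and registers
// allocated per warp in units of 256.
int blocks_per_sm(int regs_per_thread, int threads_per_block = 256) {
    const int kMaxWarpsPerSM = 48;
    const int kRegsPerSM = 65536;
    const int kAllocUnit = 256;
    const int warps_per_block = threads_per_block / 32;
    // Round each warp's register footprint up to the allocation unit.
    const int regs_per_warp =
        (regs_per_thread * 32 + kAllocUnit - 1) / kAllocUnit * kAllocUnit;
    const int by_warps = kMaxWarpsPerSM / warps_per_block;
    const int by_regs = kRegsPerSM / (regs_per_warp * warps_per_block);
    return std::min(by_warps, by_regs);
}
```

Under that rounding assumption, 22 and 29 registers both leave the warp cap as the binding limit. The cliff moves from just above 42 registers to just above 40, because anything from 41 to 48 registers rounds up to the same 1,536 registers per warp.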

v2 bench

n           | v1 (ms) | v2 (ms) | speedup | v1 GB/s | v2 GB/s | equality
4,096       | 0.0221  | 0.0234  | 0.94x   | 0.6     | 0.5     | v1 == v2 == py
65,536      | 0.0227  | 0.0236  | 0.96x   | 8.6     | 8.3     | v1 == v2 == py
1,000,000   | 0.0222  | 0.0232  | 0.96x   | 135.1   | 129.1   | v1 == v2 == py
16,000,000  | 0.1114  | 0.0229  | 4.87x*  | 430.8   | 2096.5* | v1 == v2 == py
100,000,000 | 1.0231  | 0.5727  | 1.79x   | 293.2   | 523.8   | v1 == v2 == py

*The n=16M row is L2-resident, not DRAM-bound: 16M halves is 32 MB and the 5080's L2 is 64 MB, so the tensor never leaves cache across iterations. See Disclaimer 1 below. The n=100M row is the honest DRAM-bound number.

Bit-exact at every size. 4.87x at n=16M, 1.79x at n=100M. Small N sits at the ~22 us launch-overhead floor; vectorization can't fix that.

Two of those numbers need disclaimers.

Disclaimer 1: the n=16M number is not what it looks like

The bench says v2 hits 2096 GB/s at n=16M. The 5080's DRAM spec is 960 GB/s. 2096 > 960. That should be a red flag.

It is. A 16M-element half tensor is 32 MB. The 5080's L2 cache is 64 MB. Across 100 bench iterations the tensor stays resident in L2, so only iteration 1 pays full DRAM cost. So 2096 GB/s is L2 bandwidth, not DRAM bandwidth.

I'm leaving that number in the table because real production workloads with hot tensors will see the same L2 benefit. It's not fake. But it is not the DRAM-bound number.

The n=100M number (524 GB/s) is the honest DRAM-bound number. 100M halves is 200 MB, well past L2. That's the one to quote when you want to know how fast the kernel can be when the data has to come from DRAM.

v2 NCU: the bottleneck moved

ncu --set detailed --target-processes all \
    --kernel-name "regex:(amax_reduce_kernel_v2|quantize_cast_fp8_v2)" \
    uv run python profile_v2.py \
    > ncu_v2.txt

Side-by-side with v1:

kernel @ N   | v1 DRAM % | v2 DRAM % | v1 GB/s | v2 GB/s | speedup
amax @ 16M   | 32.8%     | 86.3%     | 310     | 816     | 2.63x
cast @ 16M   | 43.3%     | 87.2%     | 409     | 824     | 2.00x
amax @ 100M  | 33.9%     | 92.8%     | 321     | 878     | 2.74x
cast @ 100M  | 55.8%     | 91.0%     | 528     | 861     | 1.72x

Both kernels are now DRAM-bound at 86-93% of peak. The single-pass per-tensor FP8 quantizer is essentially at the hardware ceiling.

Two-pipes view for v2:

kernel         | issue pipe busy | DRAM busy | who's slow?
amax v2 @100M  | 15%             | 93%       | DRAM (hardware wall)
cast v2 @100M  | 13%             | 91%       | DRAM (hardware wall)

Two pipes between threads and DRAM. v1 (top): issue pipe 44% full, DRAM 33% full - software-bound. v2 (bottom): issue pipe 15%, DRAM 93% - physics-bound.

The issue pipe on amax dropped from 44% to 15%. DRAM took the headroom and ran with it. That's exactly the shape you want: the bottleneck migrated from a software-fixable place to a physics-bound place.

Disclaimer 2: cast lost its 44% L1 cheat (and that's fine)

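v1 cast had a 44% L1 hit rate, which I read as reuse from running on the same tensor amax had just streamed. v2 cast has a 4% L1 hit rate. The drop can't be because the v2 kernels finish too fast for that residue to help, since kernel speed doesn't decide what stays in cache. A more plausible lever is access width: v1 pairs each 2-byte load of x with a scalar *amax_buf load that hits in L1 after the first warp, while v2 issues one 16-byte load per 8 elements, so the same handful of hits is diluted by far more missed sectors.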

In other words, v1's 528 GB/s effective at cast was DRAM bandwidth plus free L1 reuse. v2's 861 GB/s is pure DRAM. The "2x speedup" understates the actual memory-system efficiency gain.

This took me a minute to sit with. The intuition is: when both kernels were slow, cast benefited from L1 staleness in a way that doesn't survive once both kernels are fast. The cache crutch goes away and the real numbers come up.

The IPC paradox

v1 IPC: 0.94 to 1.05. v2 IPC: 0.39 to 0.49. Lower.

If you grew up reading CPU profilers, low IPC is a regression. The first reflex is to chase it. I didn't, and here's why.

Look at what vectorization did to the instruction stream. v1 per thread: 1 narrow load, 1 cvt, 1 narrow store, so 3 work-instructions for 1 element. v2 per thread: 1 wide int4 load, 4 packed __nv_cvt_float2_to_fp8x2 calls, 1 wide uint64 store, so 6 work-instructions for 8 elements. Instructions per element dropped by roughly 4x. The kernels got faster too, 1.7-2.7x at n=100M, but not 4x faster. IPC is instructions divided by cycles, so when the instruction count shrinks ~4x and the cycle count shrinks only ~2x, IPC roughly halves by construction.
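The effect can be estimated directly. Plugging in cast's numbers from the tables above (v1 IPC of 1.05 at n=16M, roughly 4x fewer instructions per element, and the 1.72x kernel speedup at n=100M) into this hypothetical helper predicts an IPC around 0.45. Mixing sizes makes it a rough check only, but it lands inside the 0.39-0.49 range NCU reported.

```cpp
#include <cassert>
#include <cmath>

// IPC = instructions / cycles. If a change divides the instruction count by
// `instr_reduction` and the cycle count by `speedup`, IPC scales by speedup / instr_reduction.
double ipc_after(double ipc_before, double instr_reduction, double speedup) {
    return ipc_before * speedup / instr_reduction;
}
```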

A lower IPC here is the signature of a better kernel: each instruction does more work, and the ALUs sit idle waiting on memory because that's what memory-bound looks like. IPC stops being a meaningful health signal the moment DRAM is the wall.

The simple rule: low IPC + low DRAM % means you have headroom to fix. Low IPC + high DRAM % means you're done.

The 946 vs 960 finding

NCU's dram__throughput.avg.pct_of_peak_sustained_elapsed reports against the GPU's computed peak. At 92.77% pct-of-peak we measure 878 GB/s, so the implied real peak is 878 / 0.9277 = 946.4 GB/s.

The spec says 960 GB/s, so NCU's computed peak is about 98.5% of the listed spec.

So when sizing future SM120 elementwise kernels, the realistic ceiling for memory-bound work is around 880 GB/s, not 960. The remaining ~65 GB/s below NCU's 946 GB/s peak is the gap between sustained and theoretical bandwidth that GDDR controllers typically leave on the table. I now hold ~880 in my head as the "good enough" target for any SM120 single-pass elementwise kernel.

Bench versus NCU reconciliation

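The bench says 524 GB/s at n=100M. NCU says 861-878 GB/s per kernel at the same size. Both are right; they measure different things. The bench counts 3 bytes per element (2 in, 1 out) against the wall time of both launches, but the pipeline actually moves 5 bytes per element through DRAM, because amax reads x and then cast reads it again. Scale 524 GB/s by 5/3 and you get ~873 GB/s, right inside NCU's per-kernel range.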

Both belong in the post. Bench is what users feel. NCU is what the hardware does. The 1.79x end-to-end speedup is what ships to the README. The kernel-level 92% DRAM busy is the engineering health metric.
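Both bandwidth reconciliations in this part are one-line formulas. This C++ sketch (hypothetical helper names) reproduces them from the numbers above: 3 bytes per element over the 0.5727 ms end-to-end time gives the bench's 524 GB/s; the 5 bytes per element the two kernels actually move through DRAM (amax reads x, cast reads it again and writes q) give about 873 GB/s; and 878 GB/s at 92.77% of NCU's peak implies a peak of about 946 GB/s.

```cpp
#include <cassert>
#include <cmath>

// Effective bandwidth in GB/s for n elements moving `bytes_per_elem` bytes in `seconds`.
double effective_gbps(double n, double bytes_per_elem, double seconds) {
    return n * bytes_per_elem / seconds / 1e9;
}

// NCU reports DRAM throughput as a fraction of its computed peak; invert to recover that peak.
double implied_peak_gbps(double measured_gbps, double fraction_of_peak) {
    return measured_gbps / fraction_of_peak;
}
```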

Part III: dequant is the easy half

Quant was the hard direction. Dequant is the inverse, and strictly easier: no reduction, no atomic, single pass, pure elementwise.

The full FP8 round trip with all three kernels chained on-device:

FP8 round trip pipeline: half input flows through amax_reduce, quantize_cast, and dequantize back to half output. The scale buffer (amax/448) is written by quantize_cast and read by dequantize, threading device-side state without ever returning to the host.

Same scale pointer chains all three. Device-side flow, no host round-trip.

The mirror, byte for byte

cast v2 and dequant v2 are structural opposites:

step    | cast v2                                           | dequant v2
load    | int4 of 8 halves (16 B)                           | uint64 of 8 FP8 bytes (8 B)
convert | 4x __half22float2 -> mul -> cvt_float2_to_fp8x2   | 4x cvt_fp8x2_to_halfraw2 -> mul -> float22half2_rn
store   | 1x uint64 of 8 FP8 bytes (8 B)                    | int4 of 8 halves (16 B)
total   | 24 B/thread (16 in + 8 out)                       | 24 B/thread (8 in + 16 out)

Same VEC=8, same BLOCK=256, same Pack16 union, same FP32 multiply path. Only the load/store widths swap.

The one new intrinsic

The inverse of __nv_cvt_float2_to_fp8x2 is __nv_cvt_fp8x2_to_halfraw2(uint16_t pair, __NV_E4M3), returning a __half2_raw. One-line helper:

__device__ __forceinline__ half2 fp8x2_to_half2(uint16_t pair) {
    __half2_raw hr = __nv_cvt_fp8x2_to_halfraw2(pair, __NV_E4M3);
    return half2(hr);
}

half2(__half2_raw) is a CUDA 11.4+ constructor. On the 13.0 toolkit it compiles clean. If it ever breaks, the fallback is *reinterpret_cast<half2*>(&hr). Same SASS either way.

The kernel body

__global__ void dequantize_fp8(
    const __nv_fp8_storage_t* __restrict__ q,
    const float* __restrict__ scale,
    half* __restrict__ y,
    int64_t n)
{
    const float s = __ldg(scale);

    const int64_t base = (int64_t)blockIdx.x * blockDim.x * VEC
                       + (int64_t)threadIdx.x * VEC;

    if (base + VEC <= n) {
        const uint64_t packed =
            __ldg(reinterpret_cast<const uint64_t*>(q + base));

        uint16_t p[4];
        p[0] = (uint16_t)( packed        & 0xFFFFu);
        p[1] = (uint16_t)((packed >> 16) & 0xFFFFu);
        p[2] = (uint16_t)((packed >> 32) & 0xFFFFu);
        p[3] = (uint16_t)((packed >> 48) & 0xFFFFu);

        Pack16 pk;
        #pragma unroll
        for (int k = 0; k < 4; ++k) {
            half2  h2 = fp8x2_to_half2(p[k]);
            float2 f2 = __half22float2(h2);
            f2.x *= s;
            f2.y *= s;
            pk.h2[k] = __float22half2_rn(f2);
        }

        *reinterpret_cast<int4*>(y + base) = pk.raw;
    } else {
        int64_t end = (base + VEC < n) ? (base + VEC) : n;
        for (int64_t k = base; k < end; ++k) {
            __nv_fp8_e4m3 fp8; fp8.__x = q[k];
            float v = static_cast<float>(fp8) * s;
            y[k] = __float2half_rn(v);
        }
    }
}

Compile: 22 regs, 0 spills, 0 stack. Exact peer to quantize_cast_fp8_v2 (also 22).

The bit-exact gate (stronger than round-trip)

Quant has an obvious correctness gate: round-trip the data, compare under FP8 tolerance. Dequant allows something stronger: compare the kernel output to a CPU reference at the raw 16-bit level, not within tolerance.

CPU reference:

static std::vector<half> cpu_dequant(
    const std::vector<__nv_fp8_storage_t>& q, float s) {
    std::vector<half> y(q.size());
    for (size_t i = 0; i < q.size(); ++i) {
        __nv_fp8_e4m3 fp8; fp8.__x = q[i];
        float f = static_cast<float>(fp8) * s;
        y[i] = __float2half_rn(f);
    }
    return y;
}

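Why this works: the kernel narrows via __float22half2_rn and the CPU reference narrows via __float2half_rn, and both round to nearest even. The inputs match too. The kernel's FP8 -> half -> float path is exact, because every E4M3 value is representable in FP16, so both sides multiply the same float by the same s. They produce bit-identical halves, lane by lane.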
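The exactness claim can be checked exhaustively, since E4M3 has only 256 bit patterns. This host-side C++ sketch (hypothetical helper names; standard E4M3 and IEEE binary16 layouts assumed) converts every E4M3 byte to FP16 bits using integer operations only and checks that the FP16 value equals the E4M3 value. E4M3's exponent range and 3 mantissa bits fit inside FP16's 5 exponent and 10 mantissa bits, so no rounding step exists.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// E4M3 byte -> float (bias 7, no infinities, S.1111.111 is NaN).
float e4m3_to_float(uint8_t b) {
    const int e = (b >> 3) & 0xF, m = b & 0x7;
    if (e == 0xF && m == 0x7) return NAN;
    const float mag = (e == 0) ? std::ldexp(float(m), -9)          // 0.mmm * 2^-6 = m * 2^-9
                               : std::ldexp(float(8 + m), e - 10); // 1.mmm * 2^(e-7)
    return (b & 0x80) ? -mag : mag;
}

// E4M3 byte -> IEEE binary16 bits using integer ops only.
uint16_t e4m3_to_fp16_bits(uint8_t b) {
    const uint16_t sign = uint16_t(b & 0x80) << 8;
    int e = (b >> 3) & 0xF, m = b & 0x7;
    if (e == 0xF && m == 0x7) return sign | 0x7E00;                // NaN stays NaN
    if (e == 0 && m == 0) return sign;                             // +/-0
    if (e == 0) {                                                  // subnormal: normalize
        e = 1;
        while ((m & 0x8) == 0) { m <<= 1; --e; }
        m &= 0x7;                                                  // drop the implicit bit
    }
    // Rebias the exponent (7 -> 15) and widen the 3-bit mantissa to 10 bits.
    return sign | uint16_t((e + 8) << 10) | uint16_t(m << 7);
}

// binary16 bits -> float, used only to check the conversion.
float fp16_bits_to_float(uint16_t h) {
    const int e = (h >> 10) & 0x1F, m = h & 0x3FF;
    float mag;
    if (e == 0x1F)   mag = m ? NAN : INFINITY;
    else if (e == 0) mag = std::ldexp(float(m), -24);
    else             mag = std::ldexp(float(1024 + m), e - 25);
    return (h & 0x8000) ? -mag : mag;
}
```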

The comparison loop just checks raw bits:

uint16_t a = *reinterpret_cast<const uint16_t*>(&got[i]);
uint16_t b = *reinterpret_cast<const uint16_t*>(&ref[i]);
if (a != b) ++mismatches;
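A bit-level comparison is stricter than == in one visible way: == treats +0.0 and -0.0 as equal while their bit patterns differ, and dequant can legitimately produce -0 from the FP8 byte 0x80. Standard C++ has no 16-bit half type, so this host-side sketch (hypothetical helper name) uses float with uint32_t; the same template applies to halves with uint16_t. It copies through memcpy rather than reinterpret_cast, which sidesteps strict-aliasing questions in host code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Count lanes whose raw bit patterns differ. Bits must be an unsigned integer
// type exactly as wide as T (uint16_t for half, uint32_t for float).
template <typename T, typename Bits>
std::size_t count_bit_mismatches(const std::vector<T>& got, const std::vector<T>& ref) {
    static_assert(sizeof(T) == sizeof(Bits), "Bits must be as wide as T");
    assert(got.size() == ref.size());
    std::size_t mismatches = 0;
    for (std::size_t i = 0; i < got.size(); ++i) {
        Bits a, b;
        std::memcpy(&a, &got[i], sizeof a);   // copy the bytes, no aliasing cast
        std::memcpy(&b, &ref[i], sizeof b);
        if (a != b) ++mismatches;
    }
    return mismatches;
}
```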

Direct dequant test cases:

  1. Toy {-3..4}, s=1.0. E4M3 represents -3..4 exactly. Lane-by-lane test.
  2. All zeros, n=1000. Trivial.
  3. Alternating +/-448, s=0.1, n=100. Hits the FP8 saturation boundary; output halves are +/-44.8 (44.8125 after rounding to half).
  4. n=2049 structured. Exercises the scalar tail with valid (no-NaN) FP8 bytes.

All four: zero mismatches at the 16-bit level.

Plus 7 round-trip tests through the full amax -> cast -> dequant pipeline (toy, 1M random, all-zeros, single-outlier, n=1023, n=2048, n=2049) under tolerance gates. All pass. Plus the 14 v2 quant tests, untouched. 25/25.

NCU: DRAM-bound on the first compile

ncu --set detailed --target-processes all \
    --kernel-name "regex:dequant_fp8_kernel_v2" \
    ./profile_dequant_v2 \
    > ncu_dequant_v2.txt

Two-pipes:

kernel @ N        | issue pipe busy | DRAM busy | DRAM GB/s | who's slow?
dequant v2 @ 16M  | 20%             | 91.8%     | 866       | DRAM wall
dequant v2 @ 100M | 8%              | 85.3%     | 807       | DRAM wall

Against the 5080's measured 946 GB/s real peak: 91.5% at 16M, 85.3% at 100M. DRAM-bound on the first compile, no further iteration needed.

Unlike cast v2 at n=16M (which inherited ~4% L1 hit from amax's residue in v2, or 44% in v1), dequant has no upstream kernel sharing its working set in this harness. Both the n=16M and n=100M numbers are honest DRAM-bound results; there's no L2-resident artifact to disclaim.

Three kernels side-by-side: the read/write asymmetry

kernel @ 100M | DRAM % | GB/s | IPC  | Occ
amax v2       | 92.8%  | 878  | 0.49 | 71%
cast v2       | 91.0%  | 861  | 0.39 | 77%
dequant v2    | 85.3%  | 807  | 0.32 | 82%

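Dequant lands a notch behind cast and amax. The pattern shows up if you sort by byte direction: amax only reads (2 bytes in per element), cast reads twice what it writes (2 in, 1 out), and dequant writes twice what it reads (1 in, 2 out). DRAM % falls in that same order: 92.8%, 91.0%, 85.3%.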

The 5080's DRAM subsystem favors read traffic over write traffic at peak, and there are concrete mechanisms behind that. GDDR is a half-duplex bus: switching between read and write forces the controller to flip drive direction (bus turnaround), and each flip burns cycles where no useful traffic flows. Read-heavy kernels like amax spend long stretches pointing the bus one way; write-dominant kernels like dequant force more flips per byte moved. Writes can also cost extra traffic in the cache hierarchy: a store that covers only part of a cache sector may force the rest to be fetched before the line is written back, a tax reads never pay. Fully coalesced 16-byte stores like dequant's mostly avoid that, so turnaround is likely the bigger factor here. Together these effects leave several points of peak DRAM bandwidth on the table for write-heavy kernels, and a single-pass kernel can't recover them.

The next byte-reduction win is fusion: combine dequant with the consumer (e.g. a GEMM epilogue) so the half output is consumed in-register and never round-trips through DRAM. That's Module 4 territory in cublade, not this post.

Closing notes

A few things that didn't fit cleanly into the narrative:

Pass C (BF16 default). The production code in quantize_per_tensor_fp8.cu is templated on T in {half, __nv_bfloat16} with shared dtype primitives in fp8_pack.cuh. BF16 is the default; FP16 is retained as legacy. The reduction routes through float for both dtypes instead of the packed __habs2/__hmax2 path, and the swap costs a few fmaxf instructions at 92% DRAM busy, which is in the noise. BF16 vs FP16 NCU numbers were within 0.2% of each other.

The autotune attempt that didn't land. I tried a 9-variant BLOCK x ITEMS autotuner. Bench at n=100M showed variant spread of 1-2.6% on every kernel/dtype combo, with winners flipping between b128_i1, b256_i1, b512_i1 across runs of the same kernel template. GPU-state noise level, not signal. Removed it. The kernels are at the DRAM roofline; there's no autotune surface worth exposing. The lesson: don't reach for an autotuner on DRAM-bound elementwise kernels. Build one alongside a GEMM where tile-shape sweeps are the real consumer.

Roadmap. The cublade roadmap from here is block-scaled quant first (MXFP4 and NVFP4 - block-scaled, no cuBLAS competitor on SM120 yet), then fused RMSNorm-then-quant for the hot path, and eventually SM120 FP8 GEMM with the scale + amax + GELU EVT epilogue. These are roadmap items, not promises of a follow-up post on each.

The arc from "where do the missing 530 GB/s go" to "we're at ~93% of the measured DRAM peak" took two NCU runs and one careful day of code. The reason it worked: every change had a measurement justifying it and a measurement validating it. Vectorization tricks alone are unconvincing without before/after data.

If there's one thing to take from this post, it's this: measure first, optimize second. The 9-column NCU dump that flopped, then the two pipes that landed: that's the same lesson stated twice in different sizes. Pick the right two numbers and the rest follows.


Code: github.com/emre570/cublade

Relevant files: