Unleashing Blackwell's 4-bit: a surgical look at MXFP4 and NVFP4
If you do kernel-level inference optimization, you eventually hit the cold truth: the enemy is not compute, it is VRAM bandwidth. FP32 numbers are huge in memory. The fix is to squeeze them into 4-bit boxes.
The internet is full of repos that do this. They say "find amax, divide by this, shift the bits, here is your FP4," and move on. I copied those formulas into my own code at first - and could not have told you what a single line of it did. As a programmer you can get by stitching API calls together. But this close to the hardware you have to operate like a surgeon, between the bits. So I dropped the formulas and went down to first principles, asking the only question that matters: what is the logic of this thing? Four facts came out of it.
The move underneath all of it: read bits, not numbers
Before the four facts, the one move they all depend on. Low-bit quantization is where you stop treating a float as a number and start treating its 32 bits as raw storage you can carve up. Take the value 6.0:
__float_as_uint(6.0f) returns 0x40C00000. That is not a conversion, and not even one instruction - it is the same 32 bits in the same register, relabeled. You told the compiler to read them through the integer lens instead of the float lens. std::bit_cast (C++20) is the same trick on the host, and __uint_as_float is the trip back. Zero instructions, zero memory traffic.
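A minimal host-side sketch of that relabeling, assuming plain C++ with memcpy standing in for the CUDA intrinsics. The helper names (float_bits, bits_float, split_fields, with_exponent_field) are mine, purely for illustration:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Host stand-in for __float_as_uint: same 32 bits, read as an integer.
std::uint32_t float_bits(float f) {
    static_assert(sizeof(float) == sizeof(std::uint32_t), "expects a 32-bit float");
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    return u;
}

// Host stand-in for __uint_as_float: the trip back.
float bits_float(std::uint32_t u) {
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// Carve the raw word into its three IEEE-754 binary32 fields.
struct Fp32Fields {
    std::uint32_t sign;      // 1 bit
    std::uint32_t exponent;  // 8 bits, stored with bias 127
    std::uint32_t mantissa;  // 23 bits
};

Fp32Fields split_fields(float f) {
    const std::uint32_t u = float_bits(f);
    return {u >> 31, (u >> 23) & 0xFFu, u & 0x7FFFFFu};
}

// Overwrite only the exponent field, leaving sign and mantissa untouched.
float with_exponent_field(float f, std::uint32_t exponent) {
    assert(exponent <= 0xFFu);
    const std::uint32_t u = float_bits(f);
    return bits_float((u & ~(0xFFu << 23)) | (exponent << 23));
}
```

For 6.0f the exponent field reads 129 (that is 2 + 127) and the top mantissa bit is set (the 1.5). Writing 127 into the exponent field turns the same bits into 1.5: the number was never "computed", only re-carved.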
That is the surgeon's move: operate under the abstraction, not above it. FP8 never forced it on me - it kept its scale as a full FP32 number and cast through hardware intrinsics, so the abstraction held. FP4 strips that away.
1. "Why can't I store 2.3?" (FP4 limits)
FP4 E2M1 is 4 bits: 1 sign, 2 exponent, 1 mantissa.
The thing that nagged at me first: normally I take a number like 2.3, round it a little, and move on. In the FP4 world it does not work that way.
The whole format hinges on that single mantissa bit. One bit is two states, so the mantissa multiplier is either 1.0 or 1.5 - nothing between. That fixes the entire positive grid at 8 values, {0, 0.5, 1, 1.5, 2, 3, 4, 6}, and the largest is 6.0 (1.5 x 2^2).
So 2.3 does not round to "2.3-ish." It snaps to 2.0 or 3.0. The coarse, uneven gaps are not a rounding bug waiting to be fixed - they are the format. Every decision after this is about working around an 8-value grid.
2. The mysterious -2 trick
A block's values rarely sit inside {0..6} as-is. Say the largest value in our block (amax) is 24.0. To fit that into FP4's capacity of 6.0 we need to divide by 4. On paper it is trivial: scale = amax / 6.0.
But I look at the code and there is no such division. Instead there is this absurd-looking bit operation:
scale_byte = ((__float_as_uint(amax) >> 23) & 0xFF) - 2;
"Where did this -2 come from, why are we not dividing by 6?" I wrestled with this for a while. But the secret was right there.
The scale is stored as UE8M0: 8 exponent bits, zero mantissa bits. It can only represent exact powers of two. There is no "divide by 6" it can hold - only divide by 1, 2, 4, 8, and so on.
So the kernel does not divide by 6. It divides by 4 - the nearest power of two below 6 - and accepts the gap. Dividing by 4 is just subtracting 2 from the exponent, and that is the entire -2: pull amax's 8-bit exponent field out (>> 23 & 0xFF), subtract 2. The 1.5 factor in 6.0 = 1.5 x 2^2 is simply thrown away. The cost is an occasional saturated value, never a correctness bug.
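Here is the -2 as a small host-side sketch. The function names are hypothetical, and the assert encodes an assumption of mine: amax is a positive normal float whose exponent field is at least 2. What a real kernel does for zero or denormal amax is not covered here.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Biased 8-bit exponent field of a float (bits 30..23).
std::uint32_t exponent_field(float f) {
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    return (u >> 23) & 0xFFu;
}

// UE8M0 scale byte for a block: amax's exponent field minus 2.
// Assumption: amax is finite, normal, and has exponent field >= 2.
std::uint8_t ue8m0_scale_byte(float amax) {
    const std::uint32_t e = exponent_field(amax);
    assert(e >= 2u && e <= 254u);
    return static_cast<std::uint8_t>(e - 2u);
}

// UE8M0 decodes to a pure power of two: 2^(byte - 127).
float ue8m0_decode(std::uint8_t byte) {
    return std::ldexp(1.0f, static_cast<int>(byte) - 127);
}
```

For amax = 24.0 the byte is 129, the scale decodes to 4.0, and 24 / 4 lands exactly on 6.0. For amax = 31.0 the byte is still 129, so 31 / 4 = 7.75 overshoots the grid and saturates: that is the gap the dropped 1.5 leaves behind.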
3. Where did the bias come from?
We’ve scaled the number, and now it’s time to pack it into that 4-bit S-EE-M box. I do the math, and my real exponent comes out to 2. Naturally, I expect to write binary 2 into the EE field. But the format aggressively demands I write a 3. Why?
If we were dealing with standard 32-bit (FP32) floating-point numbers, the answer would be obvious. A signed exponent field would make comparing two floats awkward for the hardware. Storing the exponent with a fixed offset of 127 (the Bias) keeps the field unsigned, so floats of the same sign order exactly like plain integers - a "hardware thermometer" with its zero point fixed at 127.
But the FP4 E2M1 world is entirely different. In E2M1, negative exponents simply do not exist. The exponents can only be 0, 1, or 2. So, if we don't have negative exponents to store or compare, why the hell is this Bias rule still hanging around?
The answer lies in a brilliant eviction strategy designed to clear out the ground floor of the format.
E2M1's 2-bit exponent field only has four possible codes:
00, 01, 10, and 11
The format explicitly dictates: the 00 code will never be handed to a normal number. That slot is strictly reserved for zero and the format's single subnormal value, 0.5. Normal numbers are only allowed to use 01, 10, and 11.
This is exactly where the Bias steps in. It completely abandons its original purpose of "avoiding negative numbers" and transforms into a pure "shift" operation. To forcefully keep normal numbers out of that forbidden 00 code, a bias of 1 is added to every real exponent before storing it. A real exponent of 2 gets stored as 2 + 1 = 3 (11). The smallest normal number (1.0), which has a real exponent of 0, gets stored as 0 + 1 = 1 (01).
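The field rules above fit in a few lines. This is a host-side decoding sketch of my own (decode_e2m1 is a hypothetical name), assuming the S-EE-M bit order with the sign in bit 3:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Decode one 4-bit E2M1 code (S-EE-M) to its float value.
float decode_e2m1(std::uint8_t code) {
    assert(code < 16);
    const int sign = (code >> 3) & 1;
    const int ee   = (code >> 1) & 3;  // stored (biased) exponent
    const int m    = code & 1;         // single mantissa bit
    float magnitude;
    if (ee == 0) {
        // Reserved 00 code: no implicit leading 1, so only 0.0 or 0.5.
        magnitude = 0.5f * static_cast<float>(m);
    } else {
        // Normal number: (1.0 or 1.5) * 2^(stored exponent - bias), bias = 1.
        magnitude = std::ldexp(1.0f + 0.5f * static_cast<float>(m), ee - 1);
    }
    return sign ? -magnitude : magnitude;
}
```

Codes 0 through 7 decode to 0, 0.5, 1, 1.5, 2, 3, 4, 6, and codes 8 through 15 are the same magnitudes with the sign bit set.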
The Payoff: Thanks to this hack and that reserved 00 code, an incredible hardware 'coincidence' (or frankly, a masterpiece of design) emerges. When you line up the 8 distinct values FP4 can actually hold from smallest to largest, their hardware bit equivalents line up perfectly in sequential order from 0 to 7!
Because of this, the quantization kernel doesn't waste time assembling S-EE-M bits by hand. It simply drops the values through an if-else ladder of 7 midpoints, takes the resulting index, ORs it with the sign bit, and calls it a day:
// midpoints between the 8 grid values: 0.25 0.75 1.25 1.75 2.5 3.5 5.0
if (a < 0.25f) idx = 0; // -> 0.0
else if (a < 0.75f) idx = 1; // -> 0.5
else if (a < 1.25f) idx = 2; // -> 1.0
...
else idx = 7; // -> 6.0
return sign | idx;
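A complete host-side version of that ladder, as a sketch. The rungs the snippet elides are filled in from the midpoint list in its first comment; the function name and the choice to put the sign in bit 3 (0x8) are my assumptions:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Quantize an already-scaled value to a 4-bit E2M1 code via the midpoint ladder.
// Strict '<' at every rung, so an exact tie goes to the larger magnitude.
std::uint8_t quantize_e2m1_ladder(float x) {
    assert(!std::isnan(x));
    const std::uint8_t sign = std::signbit(x) ? 0x8 : 0x0;
    const float a = std::fabs(x);
    std::uint8_t idx;
    if      (a < 0.25f) idx = 0;  // -> 0.0
    else if (a < 0.75f) idx = 1;  // -> 0.5
    else if (a < 1.25f) idx = 2;  // -> 1.0
    else if (a < 1.75f) idx = 3;  // -> 1.5
    else if (a < 2.5f)  idx = 4;  // -> 2.0
    else if (a < 3.5f)  idx = 5;  // -> 3.0
    else if (a < 5.0f)  idx = 6;  // -> 4.0
    else                idx = 7;  // -> 6.0 (everything above 5.0 saturates here)
    return static_cast<std::uint8_t>(sign | idx);
}
```

So 2.3 comes out as code 4 (2.0), 2.6 as code 5 (3.0), and anything from 5.0 upward pins to code 7 (6.0).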
Absolute cinema.
4. Two nibbles to a byte, because VRAM has no 4-bit address
FP4 is 4 bits, but VRAM has no 4-bit address - the smallest addressable unit is the 8-bit byte. So two FP4 values have to share one byte:
byte = (hi << 4) | (lo & 0x0F);
Shift one nibble into the high half, mask the other into the low half, OR them together. The line that looked like magic is exactly that and nothing more.
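The pack and its inverse, as a small host-side sketch. The helper names are mine, and which of the two elements goes in the high half is a layout choice this sketch does not try to settle:

```cpp
#include <cassert>
#include <cstdint>

// Pack two 4-bit codes into one byte: 'hi' in bits 7..4, 'lo' in bits 3..0.
std::uint8_t pack_nibbles(std::uint8_t hi, std::uint8_t lo) {
    assert(hi < 16 && lo < 16);
    return static_cast<std::uint8_t>((hi << 4) | (lo & 0x0F));
}

// The inverse: pull each nibble back out of the byte.
std::uint8_t high_nibble(std::uint8_t byte) {
    return static_cast<std::uint8_t>(byte >> 4);
}

std::uint8_t low_nibble(std::uint8_t byte) {
    return static_cast<std::uint8_t>(byte & 0x0F);
}
```

Packing code 7 (+6.0) above code 0xA (-1.0) gives the single byte 0x7A.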
The next two parts are about making this fast. This part was about earning the right to.
Part 2: Kernels, Assemble!
Part 1 was about the format - the boxes the numbers go into. Part 2 is about the kernel that actually moves them.
1. Correct is not fast
In v1, each warp owns one 32-element block: 32 lanes, lane i owns element i, and a 5-step warp shuffle reduces the 32 magnitudes down to the block amax. It is correct. I was pretty happy with it.
Then I ran it through NCU:
v1 MXFP4 quantizer @ 64M elements
occupancy: ~91%
compute: ~45%
DRAM busy: ~36%
Nothing is full. And that is the whole problem.
A kernel lives in exactly one of three states.
1. Compute-bound: the math pipes are saturated.
2. Memory-bound: the DRAM bus is saturated.
3. Latency-bound: neither is saturated and the GPU is just waiting.
v1 is the third one, and the tell is the pair of numbers, not either one alone: 91% occupancy means nearly every warp slot on the chip is filled, but 36% DRAM busy means those warps are sitting on their hands most of the time.
Picture a kitchen packed wall to wall with cooks. That is your 91% occupancy. But each cook chops one onion, then stands there waiting for the next delivery truck to pull up. A full kitchen is not a busy kitchen.
2. Which pipe do we actually want full?
My first instinct was the obvious one: compute is only at 45%, so push compute up. That instinct is wrong, and it is wrong in a way worth understanding.
Look at what a quantizer does to a single element: read it, take its absolute value, one max, one multiply, snap it to the grid. That is nothing. A quantizer is not a calculator, it is a data-mover. Its runtime is the bytes it shovels, not the math it does.
So a data-mover's destiny is to be memory-bound, and low compute is not a bug to fix - it is the correct, healthy state. That flips the win condition completely. The target is not "make compute high." The target is "make DRAM ~90%." A GEMM wants to be compute-bound. A data-mover wants to be memory-bound. Figuring out which animal your kernel is is step zero, and I spent about ten minutes being the wrong animal.
3. The two-stage trap: count the bytes
Here is a clever-sounding idea. Do not jump straight FP16 to FP4 - go in two easy hops, FP16 to FP8, then FP8 to FP4. Each hop is simpler than the full jump.
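It sounds reasonable right up until you count the bytes. For n input elements, with fp16 in (2 bytes each), FP8 in the middle (1 byte each), and packed FP4 plus one scale byte per 32 elements out:
direct FP16 -> FP4: read 2n + write 0.5n + 0.03n (scales) ~= 2.53n
FP16 -> FP8 -> FP4: read 2n + write n, then read n + write 0.5n + 0.03n (scales) ~= 4.53n
That is roughly 1.8x the bytes. The FP8 intermediate gets written all the way out to DRAM and read straight back in, a full round trip, for nothing.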
And here is the rule that governs the entire rest of this part: on a memory-bound kernel, the bytes are the runtime. 1.8x the bytes is 1.8x the time, full stop. There is no clever instruction scheduling that escapes the byte count. So count the bytes first, before you write a single line of the kernel. The byte total is your runtime budget, decided before you start.
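Counting the bytes first can be as small as this sketch. The element sizes are assumptions on my part: 2-byte fp16 input, a 1-byte FP8 intermediate, half a byte per FP4 output, one scale byte per block, and an FP8 stage whose own scale is per-tensor and therefore negligible. The struct and function names are hypothetical:

```cpp
#include <cassert>

// DRAM traffic in bytes per input element.
struct Traffic {
    double read;
    double write;
    double total() const { return read + write; }
};

// Single pass: read fp16, write packed FP4 plus one scale byte per block.
Traffic direct_fp16_to_fp4(int block_elems) {
    assert(block_elems > 0);
    return {2.0, 0.5 + 1.0 / block_elems};
}

// Two passes: fp16 -> FP8 written to DRAM, then FP8 read back -> FP4.
Traffic two_stage_fp16_fp8_fp4(int block_elems) {
    assert(block_elems > 0);
    const Traffic hop1{2.0, 1.0};                      // read fp16, write FP8
    const Traffic hop2{1.0, 0.5 + 1.0 / block_elems};  // read FP8, write FP4 + scales
    return {hop1.read + hop2.read, hop1.write + hop2.write};
}
```

With 32-element blocks this gives about 2.53 bytes per element for the direct path and about 4.53 for the two-hop path, a ratio of roughly 1.79.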
4. One load does not mean one element
v1 had a quiet mistake hiding in plain sight: it loads one fp16 value per thread. 16 bits at a time.
But the element being 16 bits wide does not cap the load at 16 bits. A single 128-bit load instruction pulls 8 fp16 values in one shot.
The books on a shelf are 2 bytes thick each. Nothing about that forces you to carry one book per trip to the desk. Your arms hold 8 of them. v1 walked back and forth carrying one book at a time. v2 grabs the whole armful.
5. A warp is one instruction wearing 32 hats
One more thing to lock in before the redesign. When 32 lanes each issue "load 16 bytes," that is not 32 load instructions racing each other. It is one warp-wide load instruction. The 32 lanes each hand in an address, and the memory system fuses the 32 neighbouring 16-byte requests into one coalesced access: 512 contiguous bytes pulled off DRAM, with no fetched byte going unused.
Wide loads and coalescing are not two separate tricks you stack. They are the same mechanism viewed from two angles. SIMT issues the instruction once for the whole warp - that is the entire idea.
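The address arithmetic behind that, as a host-side sketch. It only models which bytes each lane asks for. How the hardware splits the request into cache lines and sectors is deliberately left out, and the names are mine:

```cpp
#include <cassert>
#include <cstddef>

constexpr int kWarpSize = 32;

// Half-open byte range [begin, end) requested by one lane.
struct ByteRange {
    std::size_t begin;
    std::size_t end;
};

// Lane i reads the i-th consecutive chunk of bytes_per_lane bytes.
ByteRange lane_load_range(int lane, std::size_t bytes_per_lane) {
    assert(lane >= 0 && lane < kWarpSize);
    const std::size_t begin = static_cast<std::size_t>(lane) * bytes_per_lane;
    return {begin, begin + bytes_per_lane};
}

// Total bytes one warp-wide load covers, checking the lanes tile it gap-free.
std::size_t warp_footprint_bytes(std::size_t bytes_per_lane) {
    for (int lane = 0; lane + 1 < kWarpSize; ++lane) {
        assert(lane_load_range(lane, bytes_per_lane).end ==
               lane_load_range(lane + 1, bytes_per_lane).begin);
    }
    return lane_load_range(kWarpSize - 1, bytes_per_lane).end;
}
```

At 2 bytes per lane (v1) one warp-wide load covers 64 bytes. At 16 bytes per lane it covers 512, for the same single instruction.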
6. The v2 redesign
So the layout changes:
v1 was one warp per block. v2 is one warp per 8 blocks: 4 lanes cooperate on each 32-element block, and each lane pulls its 8 elements with a single 128-bit load. Two wins fall straight out of that one layout change.
Win 1: the reduction collapses. v1 needed a 5-step, 32-lane shuffle chain to find the block amax. In v2 each lane already holds 8 values sitting in its own registers, so it takes the max over those 8 for free - no communication at all - and then only needs a 2-step shuffle across the 4 lanes that share the block.
float amax = local_amax; // free: register-local max of 8
amax = fmaxf(amax, __shfl_xor_sync(0xffffffffu, amax, 1)); // across 2 lanes
amax = fmaxf(amax, __shfl_xor_sync(0xffffffffu, amax, 2)); // across 4 -> done
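To convince yourself the two XOR steps are enough, the butterfly can be simulated on the host. This is a sketch, not the kernel: a std::array of 32 floats plays the warp, and xor_max_step is my stand-in for one fmaxf(__shfl_xor_sync(...)) step.

```cpp
#include <algorithm>
#include <array>
#include <cassert>

using Warp = std::array<float, 32>;  // one value per lane

// Simulate: v = fmaxf(v, __shfl_xor_sync(full_mask, v, mask)) on every lane.
Warp xor_max_step(const Warp& v, int mask) {
    assert(mask > 0 && mask < 32);
    Warp out{};
    for (int lane = 0; lane < 32; ++lane) {
        out[lane] = std::max(v[lane], v[lane ^ mask]);
    }
    return out;
}

// v2: 2 steps, after which each aligned group of 4 lanes holds its group max.
Warp reduce_groups_of_4(const Warp& local_amax) {
    return xor_max_step(xor_max_step(local_amax, 1), 2);
}

// v1: 5 steps, after which every lane holds the max over all 32 lanes.
Warp reduce_full_warp(const Warp& v) {
    Warp r = v;
    for (int mask = 1; mask < 32; mask <<= 1) {
        r = xor_max_step(r, mask);
    }
    return r;
}
```

XOR with 1 and then 2 only ever pairs lanes that differ in their two lowest bits, so the exchange never leaves the 4-lane group that shares a block.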
Win 2: packing loses its shuffle. In v1 the even lane had to __shfl its neighbour's nibble across the warp before it could even build a byte. In v2 each lane holds 8 consecutive nibbles, so it packs 4 whole bytes entirely inside its own registers and writes them out as one 32-bit store. Zero cross-lane traffic.
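The register-local pack, sketched on the host. The nibble order inside the word (element 0 in the lowest 4 bits) is an assumption of mine, not something the layout above pins down, and pack8_nibbles is a hypothetical name:

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Pack 8 consecutive 4-bit codes into one 32-bit word, ready for a single store.
// Assumed order: element 0 lands in bits 3..0, element 7 in bits 31..28.
std::uint32_t pack8_nibbles(const std::array<std::uint8_t, 8>& nibbles) {
    std::uint32_t word = 0;
    for (int i = 0; i < 8; ++i) {
        assert(nibbles[i] < 16);
        word |= static_cast<std::uint32_t>(nibbles[i]) << (4 * i);
    }
    return word;
}
```

Everything the function touches lives in one lane's registers, which is why no shuffle is needed before the store.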
The mental picture: v1 is 32 people pairing off over 5 rounds to find who is tallest in the room. v2 is 4 people, each privately eyeballing their own stack of 8 cards, then just 2 quick rounds between the 4 of them. Fewer people in the conversation, shorter chain, faster answer.
Back to NCU:
v2 MXFP4 quantizer @ 64M elements
DRAM busy: ~91%
latency: 170 us
vs v1: 2.7x faster
91% DRAM busy. Memory-bound. Exactly the animal a data-mover is supposed to be. And notice what got us here: not a single exotic instruction. Two numbers read off an NCU report, one correctly named animal, and a byte count.
Part 3: NVFP4, or why one scale grew a second floor
Part 2's kernel quantizes MXFP4 and parks itself against the DRAM wall. Job done - for one format. Then NVFP4 walks in.
1. Same nibble, different wrapper
Here is the thing that surprised me: MXFP4 and NVFP4 store the exact same 4-bit element. The same E2M1 nibble, the same 8-value grid {0, .5, 1, 1.5, 2, 3, 4, 6} from Part 1. Nothing about the element changes. Not one bit.
What changes is the wrapper - the block-scaling scheme bolted around the nibble:
format   block size   block scale      extra
MXFP4    32           UE8M0 (1 byte)   none
NVFP4    16           e4m3 (1 byte)    1 FP32 global scale, per tensor
Two differences: NVFP4 uses half the block size, and it carries an extra per-tensor scale that MXFP4 simply does not have. That extra scale is the whole puzzle. Why does NVFP4 need a second scale when MXFP4 gets away with one?
2. The chaperone: why one scale was not enough
MXFP4's block scale is UE8M0 - pure exponent, no mantissa. From Part 1 you already know what that buys: it only stores powers of two. But it stores every power of two, from 2^-127 to 2^+127. Whatever scale a block could possibly need, UE8M0 can name it. One scale per block, done.
NVFP4's block scale is FP8 e4m3, and e4m3 has a 3-bit mantissa. That is the upgrade - a mantissa lets the scale land on values between the powers of two, so it tracks each block's true magnitude far more tightly. Tighter scale, lower quantization error.
But a mantissa costs range. e4m3 tops out at 448. And there is the trap: a tensor with a wide spread of magnitudes can contain a block whose required scale is bigger than 448. That scale overflows e4m3, and the whole block is destroyed - not because the FP4 nibbles were too coarse, but because the scale itself overflowed its own storage.
The fix is a second scale, one floor up: a single per-tensor FP32 global scale that pre-shrinks the whole tensor so every block's scale lands safely inside e4m3's range. Two range problems, one scale each:
- the block scale (e4m3) keeps each FP4 nibble inside {0..6}
- the global scale (FP32) keeps each block scale inside e4m3's range
It is a clean two-level hierarchy. Not a vague "shrink everything" knob - two scales, each solving a range problem at its own floor.
Here is the analogy that made it click. Picture a box of cheap plastic rulers, one per block of 16 numbers. Each ruler is precise - fine millimetre ticks, that is the e4m3 mantissa - but each one only spans 1cm to 30cm. Now your tensor has blocks the size of a grain of sand and blocks the size of a building. No single cheap ruler measures both. So before you measure anything, you put on a pair of zoom glasses - the global scale, set once from the whole tensor's amax - that resize the entire world until everything, sand and skyscrapers alike, fits between 1cm and 30cm. Now every cheap ruler works.
MXFP4 needs no glasses. Its UE8M0 ruler already stretches from a grain of sand to a galaxy on its own.
3. The recipe, and the thing that cancels
This is the NVFP4 quant recipe, all in FP32:
global_scale = (6 * 448) / amax_tensor // 2688 / amax, one per tensor
per 16-element block:
    vec_max     = max(|x|) over the block
    block_scale = e4m3_round( global_scale * vec_max / 6 )   // the stored e4m3 byte
    out_scale   = global_scale / e4m3_decode(block_scale)
    nibble      = cvt_rne_e2m1( x * out_scale )
Now watch the element quantization itself, the x * out_scale line. Expand it:
x * out_scale = x * global_scale / block_scale
~= x * global_scale / (global_scale * vec_max / 6)
= x * 6 / vec_max
global_scale cancels itself out. It sits in the numerator and the denominator, and it vanishes. The nibble ends up being plain x * 6 / vec_max - exactly the per-block scaling you would expect, with no global term anywhere in it.
So what does the global scale actually do? Exactly one thing, one line up: it sets the value of the stored block_scale byte, pushing that e4m3 number into a range where it does not overflow. That is its entire job. It is the chaperone - it never touches the data, it just makes sure the block scale behaves. The algebra proves the framing from section 2: the global scale is not in any element's path, it only governs the block scale's storage.
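The scale half of the recipe can be sketched on the host to watch the chaperone at work. Two things here are my assumptions, not the kernel's behaviour: e4m3_round_positive is a software emulation that returns the decoded value of the nearest e4m3 number and saturates at 448, and a block scale that rounds to zero yields an out_scale of zero. All names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Round a positive float to the nearest e4m3 value (3 mantissa bits, max 448),
// returning the decoded value. Emulation only; saturating is an assumption.
float e4m3_round_positive(float x) {
    if (x <= 0.0f) return 0.0f;
    int e = 0;
    std::frexp(x, &e);                 // x = m * 2^e with m in [0.5, 1)
    int exp2 = e - 1;                  // floor(log2(x))
    if (exp2 < -6) exp2 = -6;          // subnormals share the smallest normal exponent
    const float quantum = std::ldexp(1.0f, exp2 - 3);  // spacing with 3 mantissa bits
    const float rounded = std::nearbyint(x / quantum) * quantum;  // ties to even by default
    return std::min(rounded, 448.0f);
}

// One per tensor: maps the tensor amax onto the top of the e4m3 * FP4 range.
float nvfp4_global_scale(float amax_tensor) {
    assert(amax_tensor > 0.0f);
    return (6.0f * 448.0f) / amax_tensor;
}

struct BlockScales {
    float stored_block_scale;  // what the e4m3 byte decodes to
    float out_scale;           // what each element is multiplied by
};

BlockScales nvfp4_block_scales(float global_scale, float vec_max) {
    assert(global_scale > 0.0f && vec_max >= 0.0f);
    const float stored = e4m3_round_positive(global_scale * vec_max / 6.0f);
    const float out = stored > 0.0f ? global_scale / stored : 0.0f;
    return {stored, out};
}
```

Take a tensor with amax 10000. Without the global scale, the largest block would need a scale of 10000 / 6, about 1667, which e4m3 cannot hold. With it, that block's stored scale is exactly 448 and its largest element still maps to 6.0. A block with vec_max 100 stores 4.5 instead of the ideal 4.48, which is the "~=" in the expansion: the only error the element sees is the e4m3 rounding of its block scale.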
4. A smaller block makes the reduction easier, not harder
You might expect the 16-element block to be more annoying to handle than MXFP4's 32. The opposite is true, and it is a clean little result.
The v2 layout from Part 2 carries straight over: each lane does one 128-bit load = 8 fp16 elements. The only thing that changes is how many lanes have to cooperate on one block:
format   block size   lanes per block   shuffle steps
MXFP4    32           4                 2
NVFP4    16           2                 1
A smaller block means fewer lanes share it, which means a shorter reduction chain. NVFP4's per-block reduction is a single shuffle. Both kernels still have one warp cover 256 elements - only the fan-out inside a block shrinks. NVIDIA's own kernel lands on the same answer: CVT_FP4_NUM_THREADS_PER_SF == 2, two lanes per scale factor.
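The fan-out arithmetic, as a sketch with hypothetical helper names. It assumes a 32-lane warp and block sizes that divide evenly across lanes:

```cpp
#include <cassert>

// How many lanes share one scaling block, given elements held per lane.
int lanes_per_block(int block_elems, int elems_per_lane) {
    assert(elems_per_lane > 0 && block_elems % elems_per_lane == 0);
    return block_elems / elems_per_lane;
}

// XOR-shuffle steps needed to reduce across that many lanes (log2 of the count).
int shuffle_steps(int lanes) {
    assert(lanes > 0 && (lanes & (lanes - 1)) == 0);  // power of two
    int steps = 0;
    while (lanes > 1) {
        lanes >>= 1;
        ++steps;
    }
    return steps;
}

// How many scaling blocks one 32-lane warp covers.
int blocks_per_warp(int block_elems, int elems_per_lane) {
    assert(block_elems > 0 && (32 * elems_per_lane) % block_elems == 0);
    return (32 * elems_per_lane) / block_elems;
}
```

With 8 elements per lane, a 32-element block needs 4 lanes and 2 steps, a 16-element block needs 2 lanes and 1 step, and v1's one-element-per-lane layout needs all 32 lanes and 5 steps.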
5. The global scale costs a second pass: count the bytes, again
Remember the rule from Part 2: on a memory-bound kernel, the bytes are the runtime. NVFP4 hands us a fresh bill.
The two-level recipe needs amax over the whole tensor to build the global scale - and you need that global scale before you can touch element 0. It cannot be fused into the quant pass. So NVFP4 quant is two separate kernel launches: an amax-reduce, then the block-quantize.
Count the bytes for N elements:
caller provides the global scale: read 2N + write 0.5N + 0.06N (scales) ~= 2.56N
kernel computes the global scale: + a whole extra read of x (2N) ~= 4.56N
Roughly 1.8x - the same ratio as Part 2's two-stage trap, and the same cause: an extra 2N bytes of traffic that the single-pass path never pays. Measured on the RTX 5080 at 64M elements, the amax pass takes 156us and the quant pass 164us. The extra read really is about half the runtime, exactly as the byte count said it would be before I ran anything.
So why is this 1.8x acceptable when the two-stage trap's 1.8x was not? Because this is the offline weight quantizer. It runs once per weight tensor, at model-conversion time. You pay the 1.8x a single time, ever, and never again at inference. A hot-path quantizer - dynamic activations, every forward pass - would not do this; it would take a pre-calibrated global scale from the caller and skip the amax pass entirely, the 2.56N path. The kernel supports both. Same byte arithmetic, opposite verdict, because the context changed.
6. Who does the rounding: software or silicon
One more divergence, and it is in the rounding. MXFP4's quantizer does its FP4 rounding in software - that if (a < 0.25f) ... ladder from Part 1, strict < the whole way, ties rounding toward the larger magnitude.
NVFP4's kernel does not do that. It hands the rounding to the hardware: the PTX instruction cvt.rn.satfinite.e2m1x2.f32, which rounds to nearest even and saturates at 6.0. Letting the silicon do the convert is faster, and it is the format's intended path.
Round-to-nearest-even is a different rule from MXFP4's ties-toward-larger, and the non-uniform E2M1 grid makes the difference bite. The 7 midpoints between the 8 grid values are {0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0}, and at a tie RNE snaps to whichever neighbour has the even mantissa bit. Across the grid that comes out as an alternating rule - round down at one midpoint, up at the next - not one consistent direction.
And one quirk worth knowing: the sign of zero. A tiny negative value that rounds all the way down to zero magnitude still keeps its sign bit through the hardware convert, so it encodes -0 (code 0x8), not +0.
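Both rounding rules and the signed-zero quirk fit in one host-side sketch. This is a software model of the behaviour described above, not a trace of what the hardware convert does internally, and every name in it is mine:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Midpoint i sits between grid index i and grid index i + 1.
static const float kMidpoints[7] = {0.25f, 0.75f, 1.25f, 1.75f, 2.5f, 3.5f, 5.0f};

// Software ladder rule: an exact tie goes to the larger magnitude.
int index_ties_up(float a) {
    int i = 0;
    while (i < 7 && a >= kMidpoints[i]) ++i;
    return i;
}

// Round-to-nearest-even: an exact tie goes to the neighbour with the even code
// (mantissa bit 0), so ties stay at an even index and leave an odd one.
int index_rne(float a) {
    int i = 0;
    while (i < 7 && (a > kMidpoints[i] || (a == kMidpoints[i] && i % 2 == 1))) ++i;
    return i;
}

// Sign-preserving RNE encode: a tiny negative input keeps its sign bit (-0 = 0x8).
std::uint8_t encode_e2m1_rne(float x) {
    assert(!std::isnan(x));
    const int sign = std::signbit(x) ? 0x8 : 0x0;
    return static_cast<std::uint8_t>(sign | index_rne(std::fabs(x)));
}
```

The two rules agree everywhere except on exact ties at 0.25, 1.25, 2.5 and 5.0, where RNE rounds down and the ladder rounds up. At 0.75, 1.75 and 3.5 both round up.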
7. Where the pipes land
NCU, the only scoreboard that matters:
NVFP4, RTX 5080, kernel-level
amax-reduce     ~93% DRAM busy      memory-bound
quant (block)   ~91-94% DRAM busy   memory-bound, matches MXFP4 v2's 91%
dequant         ~78% DRAM busy
Quant and amax land against the DRAM wall, right where Part 2 taught us a data-mover belongs. Dequant lands lower, ~78%, and it is the same write-asymmetry cap the FP8 dequantizer hit: dequant writes 2N bytes of fp16 but only reads ~0.56N (the packed nibbles plus the scales). It is write-dominant, and the GDDR bus favours reads at peak. A write-heavy single-pass kernel leaves a slice of DRAM bandwidth on the table by design.
One honesty caveat on these numbers, and it bites if you ignore it. At 16M elements the working set fits in L2, and the figures flatter you - a chunk of the traffic never reaches DRAM at all. The honest, DRAM-bound numbers are the past-L2 ones at 64M and 100M. If a quantizer benchmark looks suspiciously good, check whether your tensor just fit in cache before you celebrate.
Sidebar: the -arch=sm_120a trap
A Blackwell-consumer gotcha worth a paragraph for anyone writing their own SM120 kernels. NVFP4's kernel needs the cvt.e2m1x2 PTX opcode, which is sm_120a-only - an "architecture-specific" instruction. The intuitive way to ask for it is -arch=sm_120a.
That flag is a trap in object-compile mode. torch.utils.cpp_extension always compiles with -c (compile to object). In that mode, -arch=sm_120a silently emits base compute_120 PTX and quietly drops every a-only opcode, cvt.e2m1x2 included. You find out only when ptxas rejects the result: not supported on .target 'sm_120'. Compiling with -cubin instead hides the problem, which makes it doubly confusing. The fix is the explicit gencode form:
-gencode=arch=compute_120a,code=sm_120a
That keeps the architecture-specific PTX. This was a latent bug in cublade's JIT loader - it surfaced the moment NVFP4 needed the opcode.
Closing: one nibble, three lessons
Three parts, and really one idea told at three scales.
Part 1: if you do not know where the speed comes from - the dropped 1.5 in the -2, the 8-value grid behind one mantissa bit, the reserved 00 code - you do not know your kernel's limits. Copying formulas does not teach you that. Carving the bits does.
Part 2: a quantizer is a data-mover, its destiny is memory-bound, and on a memory-bound kernel the byte count is the runtime. v1 to v2 was 36% to 91% DRAM, 2.7x, and every step was a number read off NCU, not a guess.
Part 3: MXFP4 and NVFP4 share the identical 4-bit element. The whole difference is the wrapper. NVFP4 spends an extra quarter-bit per element on scales (one byte per 16 elements instead of one per 32) and trades exponent range in that byte for a mantissa. The price of that mantissa is a range problem - solved by a second scale that chaperones the first. Two range problems, one scale each.
Both formats are productionized in cublade now: quant and dequant, fp16 and bf16. And there is one thread left dangling. At the GEMM level the two formats diverge again - SM120's matmul atom splits into .kind::mxf4 and .kind::mxf4nvf4 - and MXFP4 has no SM120 cuBLASLt kernel at all. That gap is the next post. This one stopped at quant and dequant, where the format stops being a diagram and starts being bytes.
Code: github.com/emre570/cublade