W8A16 Quantization with LLM.int8-Style Outlier Handling
In Phase 1, we demonstrated that outliers in transformer models are not "statistical noise": they are critical carriers of model intelligence and context. When we zeroed outliers in weight matrices, model outputs degraded catastrophically, producing repetitive loops and nonsensical text.
But here's the problem: quantizing weights from BF16 to INT8 cuts weight memory roughly in half, but naive quantization also mangles outlier channels through coarse scaling, clipping, and rounding. How do we get the memory savings without sacrificing quality?
The answer comes from the LLM.int8 paper: mixed-precision matrix multiplication with outlier-aware decomposition. In this post, we'll implement this approach in cuBlade and benchmark it on real language models.
The Challenge: Quantization vs Quality
What is W8A16?
W8A16 = Weights in INT8, Activations in 16-bit floats (BF16 or FP16)
- Store weights as 8-bit integers (50% memory savings)
- Keep activations in 16-bit floats (preserve dynamic range)
- Multiply using mixed-precision or dequantize on-the-fly
The Naive Approach
# Standard per-channel symmetric quantization (one scale per input feature)
scale = weight.abs().max(dim=0).values / 127
weight_int8 = (weight / scale).round().clamp(-128, 127).to(torch.int8)
# At inference:
weight_fp16 = weight_int8.to(torch.float16) * scale
output = activation @ weight_fp16.t()
Problem: outliers dominate the absmax scale. Either they set the quantization range and crush the remaining weights in that channel onto a handful of coarse levels, or the range is capped and the outliers themselves get clipped to [-128, 127] and lose their magnitude. Either way, information is lost.
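A tiny synthetic illustration of the effect (standalone PyTorch arithmetic, not cuBlade code): one outlier dictates the scale, and the rest of the channel collapses onto a few quantization levels.

import torch

torch.manual_seed(0)
row = torch.randn(1024) * 0.02      # typical small weights
row[100] = 8.0                      # a single outlier

scale = row.abs().max() / 127       # the outlier dictates the scale
q = (row / scale).round().clamp(-128, 127).to(torch.int8)
deq = q.float() * scale

mask = row.abs() < 1.0              # everything except the outlier
print("distinct INT8 levels used by normal weights:", q[mask].unique().numel())    # only a few
print("mean abs error on normal weights:", (deq - row)[mask].abs().mean().item())  # same order as the weights themselves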
Result on Gemma-270M:
- Memory savings: 18.7%
- MSE: 0.02 - 0.18 (varies by input)
- Text quality: Acceptable but degraded
The LLM.int8 Approach
Key Insight: Systematic Outliers
The LLM.int8 paper discovered that outliers in transformer models are systematic: they consistently appear in specific input features (columns of the weight matrix), not randomly scattered. Their solution: detect the outlier features and handle them separately in BF16, quantizing only the "normal" features to INT8.
Column-wise Outlier Detection
def detect_outlier_columns(weight, threshold=6.0):
    """
    Detect which INPUT FEATURES (columns) contain outliers.
    weight: [out_features, in_features]
    Returns: list of column indices
    """
    # For each input feature, check the max absolute value
    col_max = weight.abs().max(dim=0).values  # [in_features]
    outlier_mask = col_max > threshold
    outlier_cols = outlier_mask.nonzero(as_tuple=True)[0].tolist()
    return outlier_cols  # Typically 0.1-1% of columns
Column-wise detection exploits the systematic patterns where certain input features consistently produce large activations. This approach is cache-friendly with simple indexing operations, easier to implement than scattered element access, and matches the validated approach from the LLM.int8 paper.
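As a quick sanity check, here is a small synthetic demo (random weights with three injected outlier features) showing the detector flag exactly those columns:

import torch

torch.manual_seed(0)
W = torch.randn(256, 1024) * 0.05     # "normal" weights, far below the threshold
W[:, [3, 47, 128]] += 8.0             # inject three outlier input features

print(detect_outlier_columns(W, threshold=6.0))   # -> [3, 47, 128]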
Weight Matrix Decomposition
# Split weight matrix by outlier columns
outlier_cols = [3, 47, 128, ...] # ~0.1% of columns
W_normal = weight[:, normal_cols] # [out, in - num_outliers]
W_outlier = weight[:, outlier_cols] # [out, num_outliers]
# Quantize only the normal part
W_int8 = quantize(W_normal)   # → INT8
# Keep outliers in BF16
W_fp16 = W_outlier            # → BF16 (unchanged)
Mixed-Precision Forward Pass
def forward(x):
    # Split input features
    x_normal = x[:, normal_cols]      # [batch, in - num_outliers]
    x_outlier = x[:, outlier_cols]    # [batch, num_outliers]
    # Two separate matrix multiplications
    y_normal = x_normal @ dequantize(W_int8).t()   # INT8 path (quantized)
    y_outlier = x_outlier @ W_fp16.t()             # BF16 path (outliers)
    # Combine results
    return y_normal + y_outlier + bias
Outliers preserved at full precision, majority of weights quantized.
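To see the payoff end to end, here is a self-contained sketch with synthetic weights, using a simple absmax quantizer as a stand-in for quantize() and fp32 throughout for brevity. The decomposed product tracks the full-precision result far more closely than naive quantization:

import torch

def absmax_quantize(w):
    """Per-output-row symmetric INT8 quantization (a stand-in for quantize() above)."""
    scale = w.abs().max(dim=1, keepdim=True).values / 127
    q = (w / scale).round().clamp(-128, 127).to(torch.int8)
    return q, scale

torch.manual_seed(0)
weight = torch.randn(256, 1024) * 0.05
weight[:, [3, 47, 128]] += 8.0                      # three outlier input features
x = torch.randn(4, 1024)

outlier_cols = detect_outlier_columns(weight, threshold=6.0)
normal_cols = [c for c in range(weight.shape[1]) if c not in set(outlier_cols)]

W_int8, scales = absmax_quantize(weight[:, normal_cols])
W_fp = weight[:, outlier_cols]                      # outliers stay in full precision

y_ref = x @ weight.t()
y_mixed = x[:, normal_cols] @ (W_int8.float() * scales).t() + x[:, outlier_cols] @ W_fp.t()
q_all, s_all = absmax_quantize(weight)
y_naive = x @ (q_all.float() * s_all).t()

print("mixed-precision MSE:", (y_mixed - y_ref).pow(2).mean().item())   # much smaller
print("naive INT8 MSE:     ", (y_naive - y_ref).pow(2).mean().item())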
Implementation in cuBlade
Step 1: Outlier Detection Utilities
We implemented the LLM.int8 approach in cublade/quantization/outlier_utils.py:
from cublade.quantization import detect_outlier_columns, separate_outliers
# Detect outlier columns
weight = model.fc.weight.data # [out_features, in_features]
outlier_cols = detect_outlier_columns(weight, threshold=6.0)
print(f"Found {len(outlier_cols)} outlier columns") # Typically 0.1-1%
# Split weights
weight_normal, weight_outlier, normal_idx, outlier_idx = separate_outliers(
    weight, threshold=6.0
)
Step 2: Enhanced W8A16 Linear Layer
Updated cuBladeW8A16LinearLayer with optional outlier handling:
class cuBladeW8A16LinearLayer(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        dtype=torch.bfloat16,
        handle_outliers=False,    # Enable outlier handling
        outlier_threshold=6.0,    # Detection threshold
    ):
        super().__init__()
        # ... buffers for quantized weights
        if handle_outliers:
            # Additional buffers for outliers
            self.register_buffer("outlier_weights", None)  # BF16
            self.register_buffer("normal_cols", None)
            self.register_buffer("outlier_cols", None)
Quantization with outlier detection:
def quantize(self, weight_fp):
    if self.handle_outliers:
        # Detect and separate outliers
        w_normal, w_outlier, normal_cols, outlier_cols = separate_outliers(
            weight_fp, threshold=self.outlier_threshold
        )
        # Store outliers in BF16
        self.outlier_weights = w_outlier
        self.normal_cols = torch.tensor(normal_cols)
        self.outlier_cols = torch.tensor(outlier_cols)
        # Quantize only normal columns to INT8
        qt = quantize_tensor(w_normal, ...)
        self.int8_weights = qt.data
    else:
        # Standard: quantize everything
        qt = quantize_tensor(weight_fp, ...)
        self.int8_weights.copy_(qt.data)
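The quantize_tensor arguments are elided above; as a mental model, a per-output-channel symmetric quantizer along these lines would do the job (an illustrative stand-in, not cuBlade's actual implementation or class names):

from dataclasses import dataclass
import torch

@dataclass
class QuantizedTensorSketch:     # hypothetical container, not cuBlade's class
    data: torch.Tensor           # INT8 weights
    scales: torch.Tensor         # one scale per output channel

def quantize_tensor_sketch(w):
    """Per-output-channel symmetric INT8 quantization (illustrative)."""
    scales = w.abs().max(dim=1, keepdim=True).values.clamp(min=1e-8) / 127
    data = (w / scales).round().clamp(-128, 127).to(torch.int8)
    return QuantizedTensorSketch(data=data, scales=scales.squeeze(1))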
Forward pass with dual matmul:
def forward(self, x):
    if self.handle_outliers and len(self.outlier_cols) > 0:
        # Split input features
        x_normal = x[:, self.normal_cols]
        x_outlier = x[:, self.outlier_cols]
        # Quantized path
        W_normal = self.int8_weights.to(x.dtype)
        y_normal = (x_normal @ W_normal.t()) * self.scales
        # BF16 path (outliers)
        y_outlier = x_outlier @ self.outlier_weights.t()
        # Combine
        return y_normal + y_outlier + self.bias
    else:
        # Standard quantized matmul
        W = self.int8_weights.to(x.dtype)
        return (x @ W.t()) * self.scales + self.bias
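A quick way to sanity-check the layer against a reference nn.Linear (bias omitted for brevity; this sketch assumes quantize() populates every buffer the forward pass needs, and the exact plumbing in cuBlade may differ):

import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(1024, 256, bias=False).to(torch.bfloat16)
qlin = cuBladeW8A16LinearLayer(1024, 256, bias=False, handle_outliers=True)
qlin.quantize(lin.weight.data)    # detect outliers + quantize the normal columns

x = torch.randn(4, 1024, dtype=torch.bfloat16)
print((qlin(x) - lin(x)).pow(2).mean().item())    # should be close to zero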
Step 3: High-Level API
Simple quantization:
from cublade.quantization import quantize_model
# Without outlier handling (standard W8A16)
model_w8 = quantize_model(
    model,
    quant_type="w8a16",
    exclude_modules=["lm_head"],
    handle_outliers=False
)
# With outlier handling (LLM.int8 style)
model_w8_outliers = quantize_model(
    model,
    quant_type="w8a16",
    exclude_modules=["lm_head"],
    handle_outliers=True,
    outlier_threshold=6.0
)
Experiments and Results
Setup
- Model: Gemma-270M (small enough to fit on consumer GPUs)
- Hardware: RTX 3070 Ti (8 GB VRAM)
- Precision: BF16 baseline
- Test cases: 5 diverse prompts (factual, reasoning, code, creative, instruction)
Three-Way Comparison
- BF16 Baseline - Original model (511.36 MiB)
- W8A16 (no outliers) - Standard quantization (415.97 MiB)
- W8A16 (with outliers) - LLM.int8 style (416.15 MiB)
Memory Footprint
┌────────────────────────┬──────────────┬─────────┐
│ Model                  │ Memory (MiB) │ Savings │
├────────────────────────┼──────────────┼─────────┤
│ BF16 Baseline          │ 511.36       │ -       │
│ W8A16 (no outliers)    │ 415.97       │ ↓ 18.7% │
│ W8A16 (with outliers)  │ 416.15       │ ↓ 18.6% │
└────────────────────────┴──────────────┴─────────┘
Key Finding: Outlier handling adds negligible memory overhead (only 0.18 MiB) because outliers represent <1% of weights.
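For reference, the footprints above are just the bytes held by parameters and buffers; a simple estimate (one reasonable way to measure it, not necessarily the exact script we used) is:

def model_memory_mib(model):
    """Total parameter + buffer storage in MiB."""
    total = sum(p.numel() * p.element_size() for p in model.parameters())
    total += sum(b.numel() * b.element_size() for b in model.buffers())
    return total / (1024 ** 2)

print(f"W8A16 (with outliers): {model_memory_mib(model_w8_outliers):.2f} MiB")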
Quality Metrics (MSE on Logits)
We tested on 5 diverse prompts. Here are representative results:
Simple Factual ("The capital of France is"):
W8A16 (no outliers): MSE = 0.024
W8A16 (with outliers): MSE = 0.018
(25% improvement)
Reasoning ("If John has 3 apples and gives 2 to Mary, he has"):
W8A16 (no outliers): MSE = 0.038
W8A16 (with outliers): MSE = 0.029
(24% improvement)
Code Generation ("Write a Python function to calculate factorial:"):
W8A16 (no outliers): MSE = 0.061
W8A16 (with outliers): MSE = 0.048
(21% improvement)
Creative Writing ("Once upon a time, in a distant galaxy, there lived a"):
W8A16 (no outliers): MSE = 0.041
W8A16 (with outliers): MSE = 0.032
(22% improvement)
Instruction Following ("Question: What are three benefits of exercise?"):
W8A16 (no outliers): MSE = 0.175
W8A16 (with outliers): MSE = 0.138
(21% improvement)
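All MSE numbers above compare the quantized model's logits against the BF16 baseline on the same prompt; a minimal sketch of that comparison (assuming a Hugging Face-style model and tokenizer, which this post doesn't show being loaded) looks like:

import torch

@torch.no_grad()
def logits_mse(model_ref, model_quant, tokenizer, prompt):
    """MSE between the logits of two models on the same prompt."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model_ref.device)
    ref = model_ref(**inputs).logits.float()
    quant = model_quant(**inputs).logits.float()
    return (ref - quant).pow(2).mean().item()

# e.g. logits_mse(bf16_model, model_w8_outliers, tokenizer, "The capital of France is")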
Key Observations
Outlier handling consistently improves quality by 20-25% across all prompt types, though the baseline error varies dramatically by input complexity: simple factual queries show MSE around 0.024 while complex instructions hit 0.175. Gemma-270M has relatively few outliers, so larger models (7B+) should show even more dramatic improvements. The memory overhead is negligible at <1% because outliers represent such a tiny fraction of weights.
Generated Text Quality
Example: Reasoning Task
Prompt: "If John has 3 apples and gives 2 to Mary, he has"
BF16 Baseline:
"If John has 3 apples and gives 2 to Mary, he has 1 apple left."
W8A16 (no outliers):
"If John has 3 apples and gives 2 to Mary, he has 1 apple. If he has 1 apple..."
W8A16 (with outliers):
"If John has 3 apples and gives 2 to Mary, he has 1 apple left."
Without outlier handling, the model sometimes repeats or extends unnecessarily. With outliers preserved, the output matches the baseline.
Performance Analysis
Speed Considerations
The current implementation delivers ~19% memory savings and improved quality with outliers, but it's 3x slower than the BF16 baseline. We store weights in INT8 but dequantize to BF16 on every forward pass, because we don't have a custom INT8×BF16 matmul kernel yet:
# Current (slow):
W_bf16 = self.int8_weights.to(torch.bfloat16)   # Conversion overhead on every call!
output = x @ W_bf16.t()                         # BF16 × BF16 matmul
What we need for speedup: Custom CUDA/Triton kernel that multiplies INT8 weights directly with BF16 activations:
# Future (fast):
output = int8_fp16_matmul(x, self.int8_weights, self.scales)
This requires implementing a specialized kernel that (see the PyTorch reference sketch after this list):
- Reads INT8 weights directly
- Multiplies with BF16 activations
- Accumulates in FP32
- Applies per-channel scales
- Returns BF16 output
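As a reference for what that kernel must compute, here are the semantics spelled out in plain PyTorch (the fast Triton version is future work; int8_bf16_matmul_reference is just an illustrative name):

import torch

def int8_bf16_matmul_reference(x, w_int8, scales, bias=None):
    """x: BF16 [batch, in], w_int8: INT8 [out, in], scales: per-output-channel [out]."""
    acc = x.float() @ w_int8.float().t()   # accumulate in FP32
    out = acc * scales.float()             # apply per-channel scales
    if bias is not None:
        out = out + bias.float()
    return out.to(torch.bfloat16)          # return BF16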
Why this is hard: PyTorch doesn't natively support an INT8×BF16 matmul out of the box. You need either a custom CUDA kernel or a specialized Triton implementation.
This implementation shines in:
- Memory-constrained deployments where you need to serve multiple models on one GPU
- Batch inference scenarios where memory bandwidth matters more than compute
- Model storage and distribution where checkpoint size is critical
- Situations where quality preservation matters more than speed

It's not ideal for latency-critical applications or single-model deployments where VRAM isn't a bottleneck, at least until we add the custom kernel.
Implementation Lessons
1. Outlier Detection Threshold
We used threshold=6.0 (from the LLM.int8 paper), but the right value varies by model:
# Inspect outliers in your model
from cublade.quantization import detect_outlier_columns

for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        outliers = detect_outlier_columns(module.weight, threshold=6.0)
        print(f"{name}: {len(outliers)} outliers / {module.weight.shape[1]} features")
Rule of thumb:
- Small models (<1B): ~0.1-0.5% outliers
- Medium models (1-7B): ~0.5-1% outliers
- Large models (7B+): ~1-3% outliers
2. Per-Value vs Per-Column
We chose per-column outlier detection (LLM.int8 style) over per-value:
Per-column (our choice):
- Detects entire input features as outliers
- Cache-friendly, simple indexing
- Matches paper approach
Per-value (alternative):
- More granular (flag individual weights)
- Better memory efficiency
- Complex gather/scatter operations
For most use cases, per-column is simpler and good enough.
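For contrast, a per-value decomposition would look roughly like this (an illustration of the alternative, not something cuBlade implements):

import torch

def split_per_value(weight, threshold=6.0):
    """Flag individual weight entries (not whole columns) as outliers."""
    mask = weight.abs() > threshold
    w_outlier_sparse = (weight * mask).to_sparse()   # scattered high-precision entries
    w_normal = weight.masked_fill(mask, 0.0)         # dense remainder to quantize
    return w_normal, w_outlier_sparse

# The forward pass then needs a dense matmul plus a sparse one, e.g.:
#   y = x @ dequantize(q_normal).t() + torch.sparse.mm(w_outlier_sparse, x.t()).t()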
3. Integration with Quantization Framework
Our quantize_model() API makes it trivial:
# One parameter toggles outlier handling
model = quantize_model(model, handle_outliers=True)
This clean API hides complexity:
- Outlier detection runs automatically
- Column splitting handled internally
- Dual matmul path selected at runtime
- No manual bookkeeping required
Future Work
Speed Optimization: The missing piece is a custom INT8×BF16 Triton kernel that multiplies INT8 weights directly with BF16 activations, accumulates in FP32, applies per-channel scaling, and outputs BF16 results. This will be integrated into cuBlade soon and should deliver a 1.5-2x speedup over the BF16 baseline by cutting weight memory bandwidth roughly in half.
Lower Bit Widths: Once W8A16 is fast, push to W4A16 using GPTQ/AWQ-style group quantization for better quality at 4-bit, and explore mixed precision approaches where most layers use W4 but outlier-heavy layers stay at W8.
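To make the W4A16 direction concrete, group quantization in its simplest form looks like this (a generic sketch of the technique, not a cuBlade API):

import torch

def group_quantize_w4(w, group_size=128):
    """Symmetric 4-bit quantization with one scale per group of `group_size` weights."""
    out_features, in_features = w.shape
    assert in_features % group_size == 0
    g = w.reshape(out_features, in_features // group_size, group_size)
    scales = g.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-8) / 7   # INT4 range [-8, 7]
    q = (g / scales).round().clamp(-8, 7).to(torch.int8)   # would be bit-packed in practice
    return q.reshape(out_features, in_features), scales.squeeze(-1)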
Conclusion
We've implemented LLM.int8-style W8A16 quantization in cuBlade with ~19% memory savings and a 20-25% quality improvement when outliers are preserved. The current implementation prioritizes correctness over speed: it's 3x slower than BF16 because we dequantize on every forward pass. But this sets the foundation: once we add a custom INT8×BF16 Triton kernel, we'll get both memory efficiency and speed gains. The key insight holds: preserving just 0.1-1% of weights as outliers dramatically improves model quality, and larger models benefit even more than our Gemma-270M test case.
References
- LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale - Dettmers et al., 2022
- Phase 1: Understanding Outliers - Our outlier analysis
- cuBlade GitHub - Source code and examples
Check out the full implementation in cublade/quantization/ and run the examples in examples/quantization/.