Emre's Blog

W8A16 Quantization with LLM.int8-Style Outlier Handling

In Phase 1, we demonstrated that outliers in transformer models are not "statistical noise": they are critical carriers of model intelligence and context. When we zeroed outliers in weight matrices, model outputs degraded catastrophically, producing repetitive loops and nonsensical text.

But here's the problem: quantizing weights from BF16 to INT8 reduces weight memory by ~50%, yet outliers wreck the process, either getting clipped or stretching the quantization grid so far that precision is lost. How do we get the memory savings without sacrificing quality?

The answer comes from the LLM.int8 paper: mixed-precision matrix multiplication with outlier-aware decomposition. In this post, we'll implement this approach in cuBlade and benchmark it on real language models.


The Challenge: Quantization vs Quality

What is W8A16?

W8A16 = Weights in INT8, Activations in BF16/FP16

The Naive Approach

# Standard per-channel symmetric quantization (one scale per column)
scale = weight.abs().max(dim=0).values / 127
weight_int8 = (weight / scale).round().clamp(-128, 127).to(torch.int8)

# At inference:
weight_fp16 = weight_int8.to(torch.float16) * scale
output = activation @ weight_fp16.t()

Problem: the quantization scale is set by the largest magnitude it has to cover, so outliers force a very coarse grid. The ordinary weights that share a scale with an outlier get rounded into just a few INT8 levels (often straight to zero), and that fine-grained information is lost.
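
To make the failure mode concrete, here's a toy group of weights that share one scale (illustrative numbers, not taken from Gemma): the outlier stretches the scale so far that the ordinary weights round straight to zero.

import torch

col = torch.tensor([0.02, -0.03, 0.01, 8.0])   # one outlier among small weights
scale = col.abs().max() / 127                   # scale is driven entirely by the outlier
q = (col / scale).round().clamp(-128, 127).to(torch.int8)
deq = q.to(torch.float32) * scale

print(q)    # tensor([  0,   0,   0, 127], dtype=torch.int8)
print(deq)  # ~[0., 0., 0., 8.0]: the small weights are gone, only the outlier survives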


The LLM.int8 Approach

Key Insight: Systematic Outliers

The LLM.int8 paper discovered that outliers in transformer models are systematic: they consistently appear in specific input features (columns of the weight matrix), not randomly scattered. Their solution: detect outlier features and handle them separately in BF16, quantizing only the "normal" features to INT8.

Column-wise Outlier Detection

def detect_outlier_columns(weight, threshold=6.0):
    """
    Detect which INPUT FEATURES (columns) contain outliers.
  
    weight: [out_features, in_features]
    Returns: List of column indices
    """
    # For each input feature, check max absolute value
    col_max = weight.abs().max(dim=0).values  # [in_features]
    outlier_mask = col_max > threshold
    outlier_cols = outlier_mask.nonzero(as_tuple=True)[0].tolist()
  
    return outlier_cols  # Typically 0.1-1% of columns

Column-wise detection exploits the systematic patterns where certain input features consistently produce large activations. This approach is cache-friendly with simple indexing operations, easier to implement than scattered element access, and matches the validated approach from the LLM.int8 paper.
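
As a quick sanity check (a toy example, not from the cuBlade test suite), you can verify the detector on synthetic data with a few injected outlier columns:

import torch
from cublade.quantization import detect_outlier_columns

torch.manual_seed(0)
weight = torch.randn(256, 512) * 0.5   # "normal" weights, comfortably below the threshold
weight[:, [3, 47, 128]] += 12.0        # inject three outlier columns

print(detect_outlier_columns(weight, threshold=6.0))  # -> [3, 47, 128]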

Weight Matrix Decomposition

# Split weight matrix by outlier columns
outlier_cols = [3, 47, 128, ...]  # ~0.1% of columns
normal_cols = [c for c in range(weight.shape[1]) if c not in outlier_cols]

W_normal = weight[:, normal_cols]    # [out, in - num_outliers]
W_outlier = weight[:, outlier_cols]  # [out, num_outliers]

# Quantize only the normal part
W_int8 = quantize(W_normal)  # → INT8
# Keep outliers in BF16
W_fp16 = W_outlier  # → BF16 (unchanged)

Mixed-Precision Forward Pass

def forward(x):
    # Split input features
    x_normal = x[:, normal_cols]    # [batch, in - num_outliers]
    x_outlier = x[:, outlier_cols]  # [batch, num_outliers]
  
    # Two separate matrix multiplications
    y_normal = x_normal @ dequantize(W_int8).t()  # INT8 path (quantized)
    y_outlier = x_outlier @ W_fp16.t()            # BF16 path (outliers)
  
    # Combine results
    return y_normal + y_outlier + bias

Outliers preserved at full precision, majority of weights quantized.
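
Before quantization, the decomposition is exact; a quick toy check (arbitrary shapes, nothing model-specific) shows that splitting the matmul by columns doesn't change the result:

import torch

torch.manual_seed(0)
x = torch.randn(4, 512)
weight = torch.randn(1024, 512)
outlier_cols = [3, 47, 128]
normal_cols = [c for c in range(weight.shape[1]) if c not in outlier_cols]

full = x @ weight.t()
split = x[:, normal_cols] @ weight[:, normal_cols].t() + x[:, outlier_cols] @ weight[:, outlier_cols].t()

print(torch.allclose(full, split, atol=1e-4))  # True: any error comes from quantization, not the split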


Implementation in cuBlade

Step 1: Outlier Detection Utilities

We implemented the LLM.int8 approach in cublade/quantization/outlier_utils.py:

from cublade.quantization import detect_outlier_columns, separate_outliers

# Detect outlier columns
weight = model.fc.weight.data  # [out_features, in_features]
outlier_cols = detect_outlier_columns(weight, threshold=6.0)
print(f"Found {len(outlier_cols)} outlier columns")  # Typically 0.1-1%

# Split weights
weight_normal, weight_outlier, normal_idx, outlier_idx = separate_outliers(
    weight, threshold=6.0
)

Step 2: Enhanced W8A16 Linear Layer

Updated cuBladeW8A16LinearLayer with optional outlier handling:

class cuBladeW8A16LinearLayer(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        dtype=torch.bfloat16,
        handle_outliers=False,      # Enable outlier handling
        outlier_threshold=6.0,      # Detection threshold
    ):
        super().__init__()
        self.handle_outliers = handle_outliers
        self.outlier_threshold = outlier_threshold
        # ... buffers for quantized weights
  
        if handle_outliers:
            # Additional buffers for outliers
            self.register_buffer("outlier_weights", None)  # BF16
            self.register_buffer("normal_cols", None)
            self.register_buffer("outlier_cols", None)

Quantization with outlier detection:

def quantize(self, weight_fp):
    if self.handle_outliers:
        # Detect and separate outliers
        w_normal, w_outlier, normal_cols, outlier_cols = separate_outliers(
            weight_fp, threshold=self.outlier_threshold
        )
  
        # Store outliers in BF16
        self.outlier_weights = w_outlier
        self.normal_cols = torch.tensor(normal_cols)
        self.outlier_cols = torch.tensor(outlier_cols)
  
        # Quantize only normal columns to INT8
        qt = quantize_tensor(w_normal, ...)
        self.int8_weights = qt.data
    else:
        # Standard: quantize everything
        qt = quantize_tensor(weight_fp, ...)
        self.int8_weights.copy_(qt.data)

Forward pass with dual matmul:

def forward(self, x):
    if self.handle_outliers and len(self.outlier_cols) > 0:
        # Split input features
        x_normal = x[:, self.normal_cols]
        x_outlier = x[:, self.outlier_cols]
  
        # Quantized path
        W_normal = self.int8_weights.to(x.dtype)
        y_normal = (x_normal @ W_normal.t()) * self.scales
  
        # BF16 path (outliers)
        y_outlier = x_outlier @ self.outlier_weights.t()
  
        # Combine
        return y_normal + y_outlier + self.bias
    else:
        # Standard quantized matmul
        W = self.int8_weights.to(x.dtype)
        return (x @ W.t()) * self.scales + self.bias

Step 3: High-Level API

Simple quantization:

from cublade.quantization import quantize_model

# Without outlier handling (standard W8A16)
model_w8 = quantize_model(
    model,
    quant_type="w8a16",
    exclude_modules=["lm_head"],
    handle_outliers=False
)

# With outlier handling (LLM.int8 style)
model_w8_outliers = quantize_model(
    model,
    quant_type="w8a16",
    exclude_modules=["lm_head"],
    handle_outliers=True,
    outlier_threshold=6.0
)

Experiments and Results

Setup

Model: Gemma-270M (small enough to fit on consumer GPUs)
Hardware: RTX 3070 Ti (8GB VRAM)
Precision: BF16 baseline
Test Cases: 5 diverse prompts (factual, reasoning, code, creative, instruction)

Three-Way Comparison

  1. BF16 Baseline - Original model (511.36 MiB)
  2. W8A16 (no outliers) - Standard quantization (415.97 MiB)
  3. W8A16 (with outliers) - LLM.int8 style (416.15 MiB)

Memory Footprint

┌─────────────────────────┬──────────────┬─────────┐
│ Model                   │ Memory (MiB) │ Savings │
├─────────────────────────┼──────────────┼─────────┤
│ BF16 Baseline           │ 511.36       │ -       │
│ W8A16 (no outliers)     │ 415.97       │ ↓ 18.7% │
│ W8A16 (with outliers)   │ 416.15       │ ↓ 18.6% │
└─────────────────────────┴──────────────┴─────────┘

Key Finding: Outlier handling adds negligible memory overhead (only 0.18 MiB) because outliers represent <1% of weights.
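
The overhead is easy to estimate with back-of-envelope arithmetic (illustrative layer shape and outlier fraction, not Gemma's actual numbers): each outlier column costs 2 bytes per weight in BF16 instead of 1 byte in INT8, plus small index buffers.

out_features, in_features = 2048, 2048
outlier_frac = 0.005                                   # ~0.5% of input features flagged
n_outlier = int(in_features * outlier_frac)

int8_only = out_features * in_features                 # everything in INT8: 1 byte per weight
mixed = (out_features * (in_features - n_outlier)      # INT8 part: 1 byte per weight
         + out_features * n_outlier * 2)               # BF16 outlier columns: 2 bytes per weight

print(f"extra bytes per layer: {mixed - int8_only}")   # = out_features * n_outlier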

Quality Metrics (MSE on Logits)

We tested on 5 diverse prompts; a sketch of how the logit MSE can be computed follows the results. Here are representative results:

Simple Factual ("The capital of France is"):

W8A16 (no outliers):   MSE = 0.024
W8A16 (with outliers): MSE = 0.018
(25% improvement)

Reasoning ("If John has 3 apples and gives 2 to Mary, he has"):

W8A16 (no outliers):   MSE = 0.038
W8A16 (with outliers): MSE = 0.029
(24% improvement)

Code Generation ("Write a Python function to calculate factorial:"):

W8A16 (no outliers):   MSE = 0.061
W8A16 (with outliers): MSE = 0.048
(21% improvement)

Creative Writing ("Once upon a time, in a distant galaxy, there lived a"):

W8A16 (no outliers):   MSE = 0.041
W8A16 (with outliers): MSE = 0.032
(22% improvement)

Instruction Following ("Question: What are three benefits of exercise?"):

W8A16 (no outliers):   MSE = 0.175
W8A16 (with outliers): MSE = 0.138
(21% improvement)
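
For reference, a logit-MSE comparison against the BF16 baseline only takes a few lines. This is a sketch assuming Hugging Face-style models and a tokenizer, not the exact benchmark script:

import torch

@torch.no_grad()
def logit_mse(model_ref, model_quant, tokenizer, prompt, device="cuda"):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    ref = model_ref(**inputs).logits.float()
    quant = model_quant(**inputs).logits.float()
    return torch.mean((ref - quant) ** 2).item()

print(logit_mse(model_bf16, model_w8_outliers, tokenizer, "The capital of France is"))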

Key Observations

Outlier handling consistently improves quality by 20-25% across all prompt types, though the baseline error varies dramatically with input complexity: simple factual queries show MSE around 0.024 while complex instructions hit 0.175. Gemma-270M has relatively few outliers, so larger models (7B+) should show even more dramatic improvements. The memory overhead is negligible at <1% because outliers represent such a tiny fraction of weights.

Generated Text Quality

Example: Reasoning Task

Prompt: "If John has 3 apples and gives 2 to Mary, he has"

BF16 Baseline:
"If John has 3 apples and gives 2 to Mary, he has 1 apple left."

W8A16 (no outliers):
"If John has 3 apples and gives 2 to Mary, he has 1 apple. If he has 1 apple..."

W8A16 (with outliers):
"If John has 3 apples and gives 2 to Mary, he has 1 apple left."

Without outlier handling, the model sometimes repeats or extends unnecessarily. With outliers preserved, output matches baseline.


Performance Analysis

Speed Considerations

The current implementation delivers ~19% memory savings and improved quality with outliers, but it's 3x slower than the BF16 baseline. Our implementation stores weights in INT8 but dequantizes to BF16 on every forward pass because we don't have a custom INT8×BF16 matmul kernel yet:

# Current (slow):
W_bf16 = self.int8_weights.to(torch.bfloat16)  # Conversion overhead on every forward pass!
output = x @ W_bf16.t()  # BF16 × BF16 matmul

What we need for the speedup: a custom CUDA/Triton kernel that multiplies INT8 weights directly with BF16 activations:

# Future (fast):
output = int8_fp16_matmul(x, self.int8_weights, self.scales)

This requires implementing a specialized kernel that:

  1. Reads INT8 weights directly
  2. Multiplies with BF16 activations
  3. Accumulates in FP32
  4. Applies per-channel scales
  5. Returns BF16 output

Why this is hard: PyTorch doesn't natively support INT8×BF16 matmuls out of the box, so you need either a custom CUDA kernel or a specialized Triton implementation.
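
In the meantime, the semantics such a kernel has to reproduce fit in a few lines of plain PyTorch. This is only a reference sketch of the hypothetical int8_fp16_matmul above: numerically equivalent, but no faster than the current dequantize-then-matmul path.

import torch

def int8_fp16_matmul_reference(x, int8_weights, scales):
    # x: [batch, in] BF16 activations; int8_weights: [out, in]; scales: [out] per output channel
    acc = x.float() @ int8_weights.float().t()        # accumulate in FP32
    return (acc * scales.float()).to(torch.bfloat16)  # apply per-channel scales, return BF16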

This implementation shines in memory-constrained deployments where you need to serve multiple models on one GPU, batch inference scenarios where memory bandwidth matters more than compute, model storage and distribution where checkpoint size is critical, and situations where quality preservation matters more than speed. It's not ideal for latency-critical applications or single-model deployments where VRAM isn't a bottleneck, at least not until we add the custom kernel.


Implementation Lessons

1. Outlier Detection Threshold

We used threshold=6.0 (from LLM.int8 paper), but this varies by model:

# Inspect outliers in your model
from cublade.quantization import detect_outlier_columns

for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        outliers = detect_outlier_columns(module.weight, threshold=6.0)
        print(f"{name}: {len(outliers)} outliers / {module.weight.shape[1]} features")

Rule of thumb: pick a threshold that flags roughly 0.1-1% of columns. A lower threshold catches more outliers (and can help quality), but every extra column stays in BF16 and eats into the memory savings.
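
If you're unsure where to set it, a quick sweep makes the trade-off visible. A minimal sketch, assuming model is the Gemma model loaded earlier and probing an arbitrary linear layer:

import torch.nn as nn
from cublade.quantization import detect_outlier_columns

# Pick any linear layer's weight to probe (here: the first nn.Linear found)
weight = next(m.weight for m in model.modules() if isinstance(m, nn.Linear))

for threshold in (4.0, 5.0, 6.0, 7.0, 8.0):
    cols = detect_outlier_columns(weight, threshold=threshold)
    frac = 100.0 * len(cols) / weight.shape[1]
    print(f"threshold={threshold}: {len(cols)} outlier columns ({frac:.2f}%)")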

2. Per-Value vs Per-Column

We chose per-column outlier detection (LLM.int8 style) over per-value:

Per-column (our choice): one decision per input feature, simple slicing, cache-friendly access, and the approach validated in the LLM.int8 paper.

Per-value (alternative): flags individual extreme weights for finer granularity, but requires scattered element access and more complex bookkeeping.

For most use cases, per-column is simpler and good enough.
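
For contrast, a per-value detector takes only a couple of lines; this is a hypothetical sketch for illustration, not part of cuBlade:

import torch

def detect_outlier_values(weight: torch.Tensor, threshold: float = 6.0) -> torch.Tensor:
    """Boolean mask of individual weights whose magnitude exceeds the threshold."""
    return weight.abs() > threshold  # [out_features, in_features]

# The mask is extremely sparse, but exploiting it requires scattered access
# (a sparse/dense split) instead of simple column slicing.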

3. Integration with Quantization Framework

Our quantize_model() API makes it trivial:

# One parameter toggles outlier handling
model = quantize_model(model, handle_outliers=True)

This clean API hides the complexity: outlier detection, column decomposition, and the dual-matmul forward pass all happen behind a single flag.


Future Work

Speed Optimization: The missing piece is a custom INT8×BF16 Triton kernel that multiplies INT8 weights directly with BF16 activations, accumulates in FP32, applies per-channel scaling, and outputs BF16 results. This will be integrated into cuBlade soon and should deliver a 1.5-2x speedup over the BF16 baseline by roughly halving weight memory traffic.

Lower Bit Widths: Once W8A16 is fast, push to W4A16 using GPTQ/AWQ-style group quantization for better quality at 4-bit, and explore mixed precision approaches where most layers use W4 but outlier-heavy layers stay at W8.


Conclusion

We've implemented LLM.int8-style W8A16 quantization in cuBlade with ~19% memory savings and 20-25% quality improvement when outliers are preserved. The current implementation prioritizes correctness over speed; it's 3x slower than BF16 because we dequantize on every forward pass. But this sets the foundation: once we add a custom INT8×BF16 Triton kernel, we'll get both memory efficiency and speed gains. The key insight holds: preserving just 0.1-1% of weights as outliers dramatically improves model quality, and larger models benefit even more than our Gemma-270M test case.


References

Dettmers, T., Lewis, M., Belkada, Y., & Zettlemoyer, L. (2022). LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale. arXiv:2208.07339.

Check out the full implementation in cublade/quantization/ and run the examples in examples/quantization/.