introduction
softmax is one of the most ubiquitous operations in deep learning. it appears in attention mechanisms, classification heads, and anywhere we need to normalize a vector into a probability distribution.
the softmax function for a vector \(x\) of length \(N\) is:
\begin{equation}
\text{softmax}(x_i) = \frac{\exp(x_i - \max(x))}{\sum_{j=1}^{N} \exp(x_j - \max(x))}
\end{equation}
we subtract \(\max(x)\) for numerical stability: without it, \(\exp(x_i)\) can overflow for large \(x_i\).
for a matrix of shape \(M \times N\), softmax is applied row-wise. this means each of the \(M\) rows is independently normalized.
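both points can be checked numerically. a minimal NumPy sketch (illustrative only; the function name `softmax_rows` is ours, not from the benchmarked code):

```python
import numpy as np

# without the max subtraction, exp overflows for large inputs
x = np.array([1000.0, 1000.1])
with np.errstate(over="ignore", invalid="ignore"):
    unstable = np.exp(x) / np.exp(x).sum()      # inf / inf -> nan
stable = np.exp(x - x.max()) / np.exp(x - x.max()).sum()

# row-wise softmax on an M x N matrix: each row is normalized independently
def softmax_rows(X):
    shifted = X - X.max(axis=1, keepdims=True)  # stable shift, per row
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)
```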
the memory bottleneck
a naive PyTorch implementation decomposes softmax into several separate operations (max, subtract, exp, sum, divide), each of which launches its own kernel and round-trips its result through global memory.
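a sketch of such a decomposition in PyTorch (an illustrative reconstruction, not the post's exact code), with each op's global-memory traffic annotated:

```python
import torch

def naive_softmax(x):
    """row-wise softmax as five separate ops; each one reads its inputs
    from and writes its result to global memory."""
    x_max = x.max(dim=1, keepdim=True)[0]             # read MN, write M
    z = x - x_max                                     # read MN + M, write MN
    numerator = torch.exp(z)                          # read MN, write MN
    denominator = numerator.sum(dim=1, keepdim=True)  # read MN, write M
    return numerator / denominator                    # read MN + M, write MN
```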
total memory traffic: \(5MN + 2M\) reads and \(3MN + 2M\) writes. each intermediate result is written to global memory and read back again.
a fused kernel can reduce this to just \(MN\) reads and \(MN\) writes, a theoretical 4x reduction in memory traffic. the idea is simple: keep each row in GPU SRAM (shared memory / L2 cache), perform all computations on it, and write the result back once.
Figure 1: memory access comparison. naive softmax makes multiple trips to global memory, while fused softmax processes each row entirely in fast on-chip memory
sequenceDiagram
participant GM as Global Memory
participant GC as GPU / SRAM
note over GM,GC: Naive: 10 global memory trips
GM->>GC: (1) read row (find max)
GC->>GM: (2) write max (M values)
GM->>GC: (3) read row (subtract max)
GC->>GM: (4) write shifted row (MN values)
GM->>GC: (5) read row (exp)
GC->>GM: (6) write exp row (MN values)
GM->>GC: (7) read row (find sum)
GC->>GM: (8) write sum (M values)
GM->>GC: (9) read row (divide)
GC->>GM: (10) write output (MN values)
note over GM,GC: Fused: 2 global memory trips
GM->>GC: (1) read row once (MN values)
note over GC: max -> sub -> exp -> sum -> div (all in SRAM)
GC->>GM: (2) write output once (MN values)
the triton kernel
triton makes it straightforward to write fused kernels. the key insight: assign each GPU program (thread block) to one or more rows, load the entire row into SRAM, compute the max, exponentials, and sum in registers, then write the result back once.
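a sketch of such a kernel, closely following the official Triton fused-softmax tutorial (names like `softmax_kernel` and `BLOCK_SIZE` come from that tutorial, not necessarily this post's exact code); one program instance handles one row:

```python
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride,
                   n_cols, BLOCK_SIZE: tl.constexpr):
    # each program instance normalizes one row
    row_idx = tl.program_id(0)
    row_start_ptr = input_ptr + row_idx * input_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)  # BLOCK_SIZE >= n_cols
    mask = col_offsets < n_cols
    # phase 1: load the whole row into SRAM; -inf padding affects
    # neither the max nor the sum
    row = tl.load(row_start_ptr + col_offsets, mask=mask, other=-float('inf'))
    # phase 2: stable exponentiation, entirely in registers
    row_minus_max = row - tl.max(row, axis=0)
    numerator = tl.exp(row_minus_max)
    # phase 3: normalize and write the result back once
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    tl.store(output_row_start_ptr + col_offsets, softmax_output, mask=mask)
```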
the kernel processes each row in three phases (max, exp, and sum), all performed on data resident in fast on-chip memory. no intermediate results are written to global memory.
kernel launch and occupancy tuning
the wrapper function computes suitable launch parameters from the matrix shape and the hardware characteristics.
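the shape-dependent part of that logic can be sketched in plain python. the power-of-two rounding and the warp-count thresholds below follow the heuristics in the official Triton tutorial and are illustrative, not a definitive implementation:

```python
def next_power_of_2(n: int) -> int:
    """smallest power of two >= n."""
    return 1 << (n - 1).bit_length()

def launch_params(n_cols: int):
    # one program covers a whole row, so the block must span all n_cols
    block_size = next_power_of_2(n_cols)
    # wider rows benefit from more warps per program
    num_warps = 4
    if block_size >= 2048:
        num_warps = 8
    if block_size >= 4096:
        num_warps = 16
    return block_size, num_warps
```

with one program per row, the launch grid is simply \((M,)\).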
why fusion works
the speedup comes from eliminating redundant memory traffic, not from faster arithmetic. to understand this, consider the memory bandwidth bottleneck:
| metric | naive PyTorch | fused Triton |
|---|---|---|
| global memory reads | \(5MN + 2M\) | \(MN\) |
| global memory writes | \(3MN + 2M\) | \(MN\) |
| total traffic | \(8MN + 4M\) | \(2MN\) |
for large matrices, the factor approaches 4x. GPU compute units are fast; the bottleneck is almost always memory bandwidth, not FLOPs.
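the table's totals make this a one-line calculation: the ratio is \((8MN + 4M) / (2MN) = 4 + 2/N\), so it approaches 4x as \(N\) grows. a quick check:

```python
def traffic_ratio(M: int, N: int) -> float:
    """naive vs fused global-memory traffic (in elements), per the table."""
    naive = 8 * M * N + 4 * M  # reads + writes, naive
    fused = 2 * M * N          # read once + write once, fused
    return naive / fused       # = 4 + 2 / N

# e.g. at the benchmark's M = 4096:
# traffic_ratio(4096, 256)   -> 4.0078125
# traffic_ratio(4096, 65536) -> just above 4
```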
performance results
benchmarking on a matrix with \(M = 4096\) rows and varying column sizes \(N\):
Figure 2: performance comparison across different column sizes. triton fused softmax consistently outperforms both the naive and torch.softmax implementations
{
"type": "line",
"data": {
"labels": ["256", "1024", "4096", "16384", "65536", "262144"],
"datasets": [
{
"label": "Naive PyTorch",
"data": [20, 40, 65, 80, 95, 100],
"borderColor": "#e05252",
"backgroundColor": "transparent",
"borderWidth": 2,
"pointRadius": 4
},
{
"label": "torch.softmax",
"data": [12, 22, 40, 55, 70, 78],
"borderColor": "#f0a500",
"backgroundColor": "transparent",
"borderWidth": 2,
"pointRadius": 4
},
{
"label": "Triton Fused",
"data": [5, 10, 18, 25, 32, 38],
"borderColor": "#4caf50",
"backgroundColor": "transparent",
"borderWidth": 2,
"pointRadius": 4
}
]
},
"options": {
"title": {
"display": true,
"text": "Softmax Performance (M = 4096 rows)"
},
"scales": {
"xAxes": [{"scaleLabel": {"display": true, "labelString": "Matrix columns (N)"}}],
"yAxes": [{"scaleLabel": {"display": true, "labelString": "Time (us), lower is better"}, "ticks": {"min": 0}}]
}
}
}
key findings:
- triton is approximately 4x faster than the naive torch JIT implementation
- triton outperforms torch.softmax across most matrix sizes
- memory bandwidth utilization reaches up to 1448 GB/s for triton vs 1515 GB/s for PyTorch at peak
the triton kernel achieves near-peak memory bandwidth because it reads each element once and writes it once โ the theoretical minimum for this operation.
limitations
the fused softmax approach works best when each row fits in GPU SRAM. for very wide matrices (large \(N\)), the row may exceed shared memory capacity, requiring a different tiling strategy.
for such cases, an online softmax formulation can process rows in chunks, trading a small amount of extra computation for the ability to handle arbitrarily large rows while still avoiding redundant global memory traffic.
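the chunked approach can be sketched in NumPy as the standard online-softmax recurrence (illustrative only, not the code triton generates): a running maximum `m` and running sum `s` are carried between chunks, and `s` is rescaled whenever the maximum grows.

```python
import numpy as np

def online_softmax_row(row, chunk=1024):
    # pass 1: compute the row's max and normalizing sum chunk by chunk,
    # keeping only two scalars between chunks
    m = -np.inf  # running maximum
    s = 0.0      # running sum of exp(x - m)
    for start in range(0, len(row), chunk):
        c = row[start:start + chunk]
        m_new = max(m, float(c.max()))
        # rescale the old sum to the new maximum, then add this chunk
        s = s * np.exp(m - m_new) + np.exp(c - m_new).sum()
        m = m_new
    # pass 2: normalize (also done chunk-wise in a real kernel)
    return np.exp(row - m) / s
```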
summary
- naive softmax writes intermediate results (max, exp, sum) to global memory, causing \(O(MN)\) redundant reads and writes
- fused softmax keeps the entire row in fast on-chip memory, reducing memory traffic by ~4x
- triton makes it easy to write fused kernels with a python-like syntax, while automatically handling register allocation and shared memory management
- the key to performance is not faster arithmetic but reduced memory traffic, since memory bandwidth is the real bottleneck on modern GPUs
the full source code and benchmark scripts are available in the triton tutorials.