Chunked Prefill: Slicing the Prefill to Protect Decode Latency

the interference problem

continuous batching keeps the GPU busy by scheduling at iteration granularity. but one edge case breaks the latency story: long prefills.

when a request arrives with a 2048-token prompt, the scheduler runs it through prefill in a single iteration. on an A100, a 2048-token prefill for a 7B model takes roughly 200 ms. all the decode requests already in the batch are blocked for the entire duration.

time ────────────────────────────────────────────────────────────►

iter 1:  [Req A prefill: 2048 tokens, 200 ms                         ]
         ←──────────────────────────────────────────────────────────→
         Req B, C, D (decode) are ALL blocked for 200 ms

iter 2:  [A dec][B dec][C dec][D dec]  ← 5 ms
iter 3:  [A dec][B dec][C dec][D dec]  ← 5 ms
...
Figure 1: without chunked prefill, a 2048-token prefill blocks all decode requests for ~200 ms, spiking TPOT 41×. with chunked prefill at C = 512, decode runs every iteration with minimal overhead.

from the perspective of Req B, C, and D, their TPOT (time per output token) just jumped from 5 ms to 205 ms for one step. users notice this as an irregular “stutter” in streaming output.
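the 41× figure in the caption is just this ratio, spelled out:

\begin{equation}
\frac{200\text{ ms} + 5\text{ ms}}{5\text{ ms}} = 41
\end{equation}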

this is prefill-decode interference: the compute-bound prefill (GEMM) monopolizes the GPU, starving the latency-sensitive decode (GEMV).

the two metrics pull in opposite directions:

| optimization | TTFT (time to first token) | TPOT (time per output token) |
|---|---|---|
| large prefill, run all at once | ↓ low (KV cache ready quickly) | ↑ high (decode blocked) |
| delay prefill, prioritize decode | ↑ high (new request waits) | ↓ low (decode unaffected) |

you can’t minimize both simultaneously, unless you slice the prefill.

chunked prefill: the core idea

chunked prefill splits a long prompt into segments of size \(C\) (the chunk size) and processes one segment per iteration, interleaved with decode steps:

without chunked prefill (C = 2048, full prompt):

  iter 1: [A prefill: 2048 tokens]
  iter 2: [A dec][B dec][C dec]
  iter 3: [A dec][B dec][C dec]

with chunked prefill (C = 512):

  iter 1: [A: tokens    0–511]  [B dec][C dec]
  iter 2: [A: tokens  512–1023] [B dec][C dec]
  iter 3: [A: tokens 1024–1535] [B dec][C dec]
  iter 4: [A: tokens 1536–2047] [B dec][C dec]
  iter 5: [A dec]               [B dec][C dec]  ← A now in decode

each iteration works within a fixed token budget \(T\):

\begin{equation}
C_{\text{prefill}} + N_{\text{decode}} \leq T
\end{equation}

where \(C_{\text{prefill}}\) is the number of prefill tokens processed this iteration and \(N_{\text{decode}}\) is the number of running decode requests, each of which contributes one token. the scheduler fills the budget with decodes first, then prefill chunks.
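a minimal sketch of how a prompt turns into per-iteration prefill spans under this budget (chunk_spans is a hypothetical helper, not an engine API):

def chunk_spans(prompt_len, chunk_size, n_decode, token_budget):
    # each iteration, decode requests consume n_decode tokens of the budget;
    # the prefill gets whatever remains, capped at chunk_size
    start = 0
    while start < prompt_len:
        room = min(chunk_size, token_budget - n_decode)
        end = min(start + room, prompt_len)
        yield (start, end)
        start = end

# a 2048-token prompt with C = 512 and 32 running decodes:
print(list(chunk_spans(2048, 512, 32, token_budget=2048)))
# [(0, 512), (512, 1024), (1024, 1536), (1536, 2048)]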

decode requests continue running at every iteration. their TPOT becomes roughly:

\begin{equation}
\text{TPOT} \approx \frac{\text{compute}(C_{\text{prefill}} + N_{\text{decode}})}{\text{compute}(N_{\text{decode}})} \times \Delta t_{\text{decode}}
\end{equation}

with \(C = 512\) and \(N_{\text{decode}} = 32\), the TPOT overhead from a prefill chunk is small: 512 tokens of GEMM adds far less than a full 2048-token prefill would.

why chunked prefill is mathematically exact

does slicing the prefill change the result? no: it is exactly equivalent to running the full prefill at once. here is why.

recall that decoder-only transformers use causal attention: position \(i\) only attends to positions \(j \leq i\).

suppose the prompt is \([t_1, t_2, \ldots, t_L]\), split into chunks of size \(C\). chunk \(s\) handles tokens \([(s-1)C+1, \ldots, sC]\).

for any token \(t_i\) in chunk \(s\):

  • tokens \(j < (s-1)C + 1\) are from prior chunks: their \(k_j, v_j\) were computed in earlier iterations and stored in the KV cache
  • tokens \(j \in [(s-1)C+1, i]\) are in the current chunk: \(k_j, v_j\) are computed this iteration

so the attention for \(t_i\) splits into two parts:

\begin{align}
\text{attn}_{i} = \text{softmax\_merge}\Bigl(
&\underbrace{\frac{q_i \cdot K_{\text{cache}}^T}{\sqrt{d_k}}}_{\text{attend to prior chunks}},\;
\underbrace{\frac{q_i \cdot K_{\text{chunk}}^T}{\sqrt{d_k}}}_{\text{attend within current chunk}}
\Bigr) \cdot \begin{bmatrix} V_{\text{cache}} \\ V_{\text{chunk}} \end{bmatrix}
\end{align}

where softmax_merge is the online softmax merge (the same trick as paged attention’s block-level aggregation). FlashAttention’s flash_attn_varlen_func handles this natively: its cu_seqlens_q and cu_seqlens_k arguments let each chunk of queries attend over the full key sequence seen so far (cache history + current chunk).
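the merge itself fits in a few lines. a minimal numpy sketch, assuming each partial attention result carries its output and the log-sum-exp (lse) of its logits:

import numpy as np

def softmax_merge(out_a, lse_a, out_b, lse_b):
    # merge attention outputs computed over two disjoint key sets.
    # out_*: (d,) attention output normalized within its own key set
    # lse_*: scalar log-sum-exp of the raw logits over that key set
    lse = np.logaddexp(lse_a, lse_b)   # normalizer of the union, in log space
    w_a = np.exp(lse_a - lse)          # share of softmax mass in key set A
    w_b = np.exp(lse_b - lse)          # share of softmax mass in key set B
    return w_a * out_a + w_b * out_b   # exact softmax over the union

re-weighting each partial output by its share of the total normalizer is what makes the split exact rather than approximate.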

after each chunk, the newly computed \(k, v\) vectors are written to the KV cache:

\begin{equation}
K_{\text{cache}} \mathrel{+}= [k_{(s-1)C+1}, \ldots, k_{sC}]
\end{equation}

the next chunk sees this extended cache. by induction, after all \(\lceil L/C \rceil\) chunks, the KV cache contains exactly what a full-prompt prefill would have produced, and the final decode step is indistinguishable from the non-chunked case.
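the equivalence is easy to verify numerically. a self-contained single-head toy (no batching, no engine code):

import numpy as np

def causal_attn(q, k, v):
    # q holds the last len(q) positions of the sequence covered by k/v,
    # so the earlier rows of k/v play the role of the KV cache
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)            # (Lq, Lk)
    Lq, Lk = scores.shape
    # causal mask on the trailing (Lq, Lq) block: the query at global
    # position Lk-Lq+i may attend to keys 0 .. Lk-Lq+i only
    mask = np.triu(np.ones((Lq, Lq), dtype=bool), k=1)
    scores[:, Lk - Lq:][mask] = -np.inf
    w = np.exp(scores - scores.max(-1, keepdims=True))
    return (w / w.sum(-1, keepdims=True)) @ v

rng = np.random.default_rng(0)
L, C, d = 8, 2, 4
q, k, v = (rng.standard_normal((L, d)) for _ in range(3))

full = causal_attn(q, k, v)                  # one-shot prefill

# chunked prefill: C queries at a time against the growing "cache"
chunks = [causal_attn(q[s:s + C], k[:s + C], v[:s + C])
          for s in range(0, L, C)]
assert np.allclose(full, np.vstack(chunks))  # identical up to float error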

TTFT/TPOT tradeoff and chunk size selection

chunk size \(C\) is the key tuning knob:

\begin{equation}
\text{TTFT} \approx \left\lceil \frac{L_{\text{prompt}}}{C} \right\rceil \times \Delta t_{\text{iter}}
\end{equation}

\begin{equation}
\text{TPOT jitter} \propto \frac{C}{N_{\text{decode}}} \times \frac{\text{FLOP}_{\text{GEMM}}}{\text{FLOP}_{\text{GEMV}}}
\end{equation}

  • larger \(C\): fewer iterations to complete prefill → lower TTFT. but each prefill chunk is larger → more decode interference per iteration → higher TPOT jitter.
  • smaller \(C\): decode runs with minimal interference → stable TPOT. but prefill needs more iterations → higher TTFT.

the sweet spot depends on the ratio of active decode requests to prefill tokens. typical production defaults:

| engine | default chunk size |
|---|---|
| vLLM (v0.4+) | 512 tokens |
| SGLang | 512 tokens |
| TensorRT-LLM | 1024 tokens |

with \(C = 512\) and a 2048-token prompt: 4 iterations to complete prefill, each adding just 512 tokens of GEMM overhead to the decode step. total TTFT increase versus a single full prefill: 3 extra iterations at \(\approx 5\text{ ms}\) each, i.e. \(15\text{ ms}\), negligible for most use cases.
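plugging a few chunk sizes into the TTFT formula above (a toy calculation, holding \(\Delta t_{\text{iter}} = 5\) ms fixed as in the running example):

import math

dt_iter = 0.005                      # seconds per iteration (running example)
for C in (256, 512, 1024, 2048):
    iters = math.ceil(2048 / C)      # ceil(L / C) prefill iterations
    print(f"C={C:4d}  iters={iters}  TTFT~{iters * dt_iter * 1e3:.0f} ms")
# C= 256  iters=8  TTFT~40 ms
# C= 512  iters=4  TTFT~20 ms
# C=1024  iters=2  TTFT~10 ms
# C=2048  iters=1  TTFT~5 ms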

FLOPs analysis: no overhead from chunking

an important sanity check: does chunking add FLOPs? the answer is no.

for a single transformer layer, the causal-attention FLOPs for \(L\) tokens are:

\begin{equation}
\text{FLOP}_{\text{attn}}(L) = 2L^2 d + 4Ld^2
\end{equation}

the \(2L^2 d\) term comes from \(QK^T\) and \(\text{attn} \cdot V\) under the causal mask (each query attends to \(L/2\) keys on average); \(4Ld^2\) from the four projection matrices.

without chunking: one call with \(L\) tokens.

with chunking: \(\lceil L/C \rceil\) calls, each processing \(C\) prompt tokens against an expanding KV cache. a token in chunk \(s\) attends to \((s-1)C\) cached keys plus, on average, \(C/2\) keys within its own chunk, i.e. \((s - \tfrac{1}{2})C\) keys. the total attention FLOPs across all chunks:

\begin{align}
\text{FLOP}_{\text{chunk-attn}} &= 4d^2 \sum_{s=1}^{L/C} C + 4d \sum_{s=1}^{L/C} C \cdot \left(s - \tfrac{1}{2}\right) C \\
&= 4Ld^2 + 4d C^2 \cdot \frac{(L/C)^2}{2} \\
&= 4Ld^2 + 2L^2 d
\end{align}

exactly the same as the non-chunked case. chunking distributes the same FLOPs across more iterations; it does not add computation.
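a quick numeric check of the two formulas (the \(s - \tfrac{1}{2}\) term makes the match exact, not just asymptotic):

def flop_attn(L, d):
    # single-layer causal attention FLOPs, as derived above
    return 2 * L**2 * d + 4 * L * d**2

def flop_chunked(L, C, d):
    # the same work split into L/C chunks against a growing KV cache
    n = L // C                            # assumes C divides L evenly
    proj = 4 * d**2 * C * n               # projection cost, summed over chunks
    attn = sum(4 * d * C * (s - 0.5) * C for s in range(1, n + 1))
    return proj + attn

assert flop_chunked(2048, 512, 4096) == flop_attn(2048, 4096)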

IO overhead: negligible in practice

the one real cost is writing new KV vectors to HBM at the end of each chunk. for chunk size \(C\), model with \(n_h\) KV heads, head dim \(d_h\), \(L_{\text{layers}}\) layers, and BF16:

\begin{equation}
\text{write per chunk} = C \times 2 \times L_{\text{layers}} \times n_h \times d_h \times 2 \text{ bytes}
\end{equation}

for LLaMA-3 8B (\(L_{\text{layers}} = 32, n_h = 8, d_h = 128\)) and \(C = 512\):

\begin{equation}
512 \times 2 \times 32 \times 8 \times 128 \times 2 = 67{,}108{,}864 \text{ bytes} \approx 64 \text{ MB}
\end{equation}

at A100’s HBM bandwidth of ~2 TB/s:

\begin{equation}
\frac{64 \times 10^6 \text{ bytes}}{2 \times 10^{12} \text{ bytes/s}} = 32 \text{ μs}
\end{equation}

32 microseconds, compared to an iteration time of ~5 ms. the IO overhead is <1% of the iteration cost, genuinely negligible.
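the same arithmetic as a two-liner, using the model constants from above:

C, layers, n_kv, d_h = 512, 32, 8, 128        # LLaMA-3 8B (GQA KV heads)
kv_bytes = C * 2 * layers * n_kv * d_h * 2    # K and V, BF16 = 2 bytes each
print(kv_bytes)                               # 67108864 bytes = 64 MiB
print(kv_bytes / 2e12 * 1e6, "us")            # ~33.6 us at 2 TB/s HBM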

interaction with prefix caching

chunked prefill and prefix caching compose nicely. if the first \(k\) blocks of a prompt are already in the cache, those blocks are skipped entirely:

prompt: [system prompt, 1024 tokens][user query, 1024 tokens]
             (cached: skip)             (must compute)

with prefix cache + chunked prefill (C = 512):
  iter 1: [user query tokens   0–511]   ← only 2 chunks instead of 4
  iter 2: [user query tokens 512–1023]
  iter 3: [decode]

the effective prefill length after a cache hit is only the uncached suffix. TTFT drops further, and fewer iterations are consumed for prefill.
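a sketch of the composition (prefill_spans is a hypothetical helper): chunking simply starts from the cached length instead of zero:

def prefill_spans(prompt_len, cached_len, chunk):
    # chunk only the uncached suffix; cached blocks are skipped entirely
    return [(s, min(s + chunk, prompt_len))
            for s in range(cached_len, prompt_len, chunk)]

print(prefill_spans(2048, cached_len=1024, chunk=512))
# [(1024, 1536), (1536, 2048)]  <- 2 chunks instead of 4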

scheduler implementation

the scheduling logic in SGLang looks roughly like:

def schedule(self):
    budget = self.token_budget  # e.g., 2048 tokens

    # 1. running decode requests each consume 1 token
    budget -= len(self.running)

    # 2. prefill requests consume up to chunk_size tokens each
    prefill_batch = []
    for req in self.waiting:
        if budget <= 0:
            break
        chunk = min(req.remaining_prefill, budget, self.chunk_size)
        req.prefill_this_iter = chunk
        budget -= chunk
        prefill_batch.append(req)

    return self.running + prefill_batch

the key property: prefill_this_iter can be less than remaining_prefill, capturing the “partial” prefill. on the next iteration, the scheduler picks up where it left off.
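a toy driver for the sketch above (mock Req/Sched classes, purely illustrative) shows the partial-prefill bookkeeping across iterations:

class Req:
    def __init__(self, prompt_len):
        self.remaining_prefill = prompt_len
        self.prefill_this_iter = 0

class Sched:
    token_budget, chunk_size = 2048, 512
    def __init__(self):
        self.running, self.waiting = [], []
    schedule = schedule  # reuse the function sketched above as a method

s = Sched()
s.waiting.append(Req(prompt_len=2048))
for it in range(1, 5):
    s.schedule()
    req = s.waiting[0]
    req.remaining_prefill -= req.prefill_this_iter
    print(it, req.prefill_this_iter, req.remaining_prefill)
# 1 512 1536
# 2 512 1024
# 3 512 512
# 4 512 0   <- prefill done; the request would now move to decode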

comparison with disaggregated prefill

chunked prefill is the in-place solution to prefill-decode interference: both still share the same GPU, just interleaved more carefully.

disaggregated prefill takes the more radical approach: route prefill to separate machines entirely, so decode GPUs never see prefill traffic at all.

| dimension | chunked prefill | disaggregated prefill |
|---|---|---|
| hardware requirement | single GPU / node | separate prefill + decode pools |
| TPOT | significantly improved | optimal (zero interference) |
| TTFT | slightly increased (chunking adds iterations) | significantly better (dedicated prefill resources) |
| network overhead | none | KV cache migration across nodes |
| implementation complexity | low (scheduler-only change) | high (cluster coordination) |
| when to use | general production serving | large-scale, SLO-critical deployments |

a full treatment of disaggregated prefill is in the next post.

summary

chunked prefill is arguably the best-value optimization in LLM serving:

  • zero FLOPs overhead: chunking distributes the same work, not more work
  • negligible IO overhead: ~32 μs of KV writes per chunk vs a ~5 ms iteration time
  • straightforward implementation: only the scheduler changes; the attention kernel is unmodified
  • significant TPOT improvement: decode requests no longer stall for long prefills
  • composable: works naturally with prefix caching (skip cached chunks), paged attention (KV written block by block), and continuous batching (same iteration-level loop)

the only cost is a modest increase in TTFT, proportional to the \(\lceil L/C \rceil\) prefill iterations, which in practice is well within acceptable bounds.
