KVCache and Prefill phase in LLMs

A simple explanation of the KVCache and the Prefill phase in LLMs

The Inference

  1. The initial sentence “what is a LLM?” is tokenized (this input is also commonly referred to as context + user prompt)
  2. Each token is converted to an embedding of dimension 768 (GPT-2, for example)
  3. The embeddings are forwarded to the transformer blocks to contextualize them (see the sketch after this list)
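
A minimal sketch of those three steps, assuming the Hugging Face transformers library and the gpt2 checkpoint (my choice of example, not prescribed here):

  import torch
  from transformers import GPT2Tokenizer, GPT2Model

  tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
  model = GPT2Model.from_pretrained("gpt2")

  # 1. tokenize the context + user prompt
  inputs = tokenizer("what is a LLM?", return_tensors="pt")

  # 2. each token id maps to a 768-dim embedding (GPT-2's hidden size)
  embeddings = model.wte(inputs.input_ids)
  print(embeddings.shape)                  # (1, num_tokens, 768)

  # 3. the transformer blocks contextualize those embeddings
  with torch.no_grad():
      hidden_states = model(**inputs).last_hidden_state
  print(hidden_states.shape)               # (1, num_tokens, 768)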

The Problem

During inference, the attention block needs to compute the scaled dot product of the Queries and Keys, and weight the Values with the result, for every token seen so far.

📔: Materialise is a fancy way of saying that you need to allocate memory for a 1000×1000 attention-score tensor (every Query dotted with every Key for a 1000-token sequence, each vector being of dimension 768) 🙂
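
A toy single-head sketch of what gets materialised, using the 1000-token / 768-dim numbers from the note above (real models split this across heads, but the shape problem is the same):

  import torch

  seq_len, dim = 1000, 768
  Q = torch.randn(seq_len, dim)
  K = torch.randn(seq_len, dim)
  V = torch.randn(seq_len, dim)

  # every query is dotted with every key: a (1000, 1000) tensor is materialised
  scores = (Q @ K.T) / dim ** 0.5
  weights = torch.softmax(scores, dim=-1)
  out = weights @ V                        # (1000, 768)

  # without a cache, all of this is recomputed at every decoding step
  print(scores.shape, f"{scores.numel() * scores.element_size() / 1e6:.0f} MB")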

Enter KVCache

  1. The trick is to identify what is important at each timestep of the decoding phase (i.e. the token-by-token part of inference). It turns out only the last token forwarded to the attention block contains the context of all previous tokens, while the Keys and Values of earlier tokens never change. We sure can exploit this.
  2. Use a list data structure to store the Key/Value states from previous timesteps (a toy sketch follows this list)
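
A toy sketch of that list data structure (the class and method names are mine, but they mirror the Hugging Face-style snippets further down):

  import torch

  class SimpleKVCache:
      def __init__(self, num_layers):
          # one growing Key tensor and one growing Value tensor per layer
          self.key_cache = [None] * num_layers
          self.value_cache = [None] * num_layers

      def update(self, key_states, value_states, layer_idx):
          if self.key_cache[layer_idx] is None:
              # prefill: nothing cached yet, store the whole prompt's K/V
              self.key_cache[layer_idx] = key_states
              self.value_cache[layer_idx] = value_states
          else:
              # decoding: append only the new token's K/V along the sequence dim
              self.key_cache[layer_idx] = torch.cat(
                  [self.key_cache[layer_idx], key_states], dim=-2)
              self.value_cache[layer_idx] = torch.cat(
                  [self.value_cache[layer_idx], value_states], dim=-2)
          return self.key_cache[layer_idx], self.value_cache[layer_idx]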

What is the Prefill Phase

  1. This is the first step of inference, at T=1 (T>1 is the decoding phase)
  2. All of the input (context + prompt) is forwarded to the model in one go
  3. Unlike the decoding phase, there is no “for loop” (see the sketch after this list)
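
A sketch of that difference, again assuming Hugging Face transformers and gpt2: prefill is a single forward pass over the whole prompt, and only the decoding phase runs a token-by-token loop that reuses past_key_values:

  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer

  tokenizer = AutoTokenizer.from_pretrained("gpt2")
  model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
  input_ids = tokenizer("what is a LLM?", return_tensors="pt").input_ids

  with torch.no_grad():
      # Prefill (T=1): the whole context + prompt in one forward pass, no loop
      out = model(input_ids, use_cache=True)
      past = out.past_key_values
      next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)

      # Decoding (T>1): the "for loop", one new token per step, cache reused
      for _ in range(5):
          out = model(next_token, past_key_values=past, use_cache=True)
          past = out.past_key_values
          next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)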

Prefill Phase

The Transformer is a sequence-to-sequence model. In simple terms, if 6 words (tokens) are forwarded to the model, it outputs 6 next-word predictions, one for every input position, as shown below.
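
A small sketch of this, assuming gpt2 (each word in the toy prompt happens to be a single token):

  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer

  tokenizer = AutoTokenizer.from_pretrained("gpt2")
  model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

  input_ids = tokenizer("the cat sat on the mat", return_tensors="pt").input_ids
  print(input_ids.shape)                   # (1, 6): six input tokens

  with torch.no_grad():
      logits = model(input_ids).logits     # (1, 6, vocab_size)

  # position i holds the prediction for token i+1: six next-token predictions
  predicted = logits.argmax(dim=-1)[0]
  print(tokenizer.convert_ids_to_tokens(predicted.tolist()))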

Adding to the Cache

  # concatenate the new Key/Value states to this layer's cached tensors along the sequence dimension
  self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
  self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

Using the Cache

  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  if kv_cache is not None:
    # if a cache is present, append this step's Key/Value states and use the full cached tensors
    key_states, value_states = kv_cache.update(key_states, value_states, self.layer_idx)
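
A toy follow-up on why this pays off at decode time, assuming PyTorch 2.x for F.scaled_dot_product_attention (the sizes are made up): only the new token's Query/Key/Value are computed, while attention still sees every cached Key and Value.

  import torch
  import torch.nn.functional as F

  bsz, num_heads, head_dim = 1, 12, 64
  cached_k = torch.randn(bsz, num_heads, 1000, head_dim)   # 1000 cached tokens
  cached_v = torch.randn(bsz, num_heads, 1000, head_dim)

  # during decoding q_len == 1: only the newest token is projected
  q_new = torch.randn(bsz, num_heads, 1, head_dim)
  k_new = torch.randn(bsz, num_heads, 1, head_dim)
  v_new = torch.randn(bsz, num_heads, 1, head_dim)

  # the same concatenation as kv_cache.update above
  key_states = torch.cat([cached_k, k_new], dim=-2)        # (1, 12, 1001, 64)
  value_states = torch.cat([cached_v, v_new], dim=-2)

  # attention scores are (1, 12, 1, 1001) instead of (1, 12, 1001, 1001)
  out = F.scaled_dot_product_attention(q_new, key_states, value_states)
  print(out.shape)                                          # (1, 12, 1, 64)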