[feat] Added chunked prefill and kvcache offload mechenism.
This commit is contained in:
@@ -55,21 +55,164 @@ class Attention(nn.Module):
|
||||
self.scale = scale
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.k_cache = self.v_cache = torch.tensor([])
|
||||
# Layer ID set by model_runner after model creation
|
||||
self.layer_id: int = -1
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
context = get_context()
|
||||
k_cache, v_cache = self.k_cache, self.v_cache
|
||||
if k_cache.numel() and v_cache.numel():
|
||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||
|
||||
if context.is_prefill:
|
||||
if context.block_tables is not None: # prefix cache
|
||||
if context.is_chunked_prefill:
|
||||
# Chunked prefill: merge attention from previous KV
|
||||
o = self._chunked_prefill_attention(q, k, v, context)
|
||||
elif context.block_tables is not None: # prefix cache
|
||||
k, v = k_cache, v_cache
|
||||
o = flash_attn_varlen_func(q, k, v,
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||
o = flash_attn_varlen_func(q, k, v,
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||
else:
|
||||
o = flash_attn_varlen_func(q, k, v,
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||
else: # decode
|
||||
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||
softmax_scale=self.scale, causal=True)
|
||||
if context.is_chunked_prefill:
|
||||
# Chunked decode: need to load all KV from CPU+GPU
|
||||
o = self._chunked_decode_attention(q, k, v, context)
|
||||
else:
|
||||
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||
softmax_scale=self.scale, causal=True)
|
||||
return o
|
||||
|
||||
def _chunked_prefill_attention(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
context,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute attention with chunked KV from CPU cache.
|
||||
|
||||
For chunked prefill:
|
||||
1. Load previous KV from CPU for this layer
|
||||
2. Compute attention against previous KV (no causal mask)
|
||||
3. Compute attention against current chunk's KV (causal)
|
||||
4. Merge results using online softmax
|
||||
"""
|
||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
# q, k, v shape: [total_tokens, num_heads, head_dim]
|
||||
total_tokens = q.shape[0]
|
||||
|
||||
# Reshape for flash attention: [batch, seq, heads, dim]
|
||||
q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
|
||||
k_batched = k.unsqueeze(0)
|
||||
v_batched = v.unsqueeze(0)
|
||||
|
||||
accumulated_o = None
|
||||
accumulated_lse = None
|
||||
|
||||
# Load previous KV from CPU for this layer
|
||||
if context.offload_engine is not None and self.layer_id >= 0:
|
||||
# Get the kvcache_manager from context
|
||||
kvcache_manager = context.offload_engine
|
||||
|
||||
# For each sequence in the chunk, load previous KV
|
||||
# Currently assuming single sequence
|
||||
if hasattr(context, 'chunked_seq') and context.chunked_seq is not None:
|
||||
prev_k, prev_v = kvcache_manager.load_prev_kv_for_layer(
|
||||
context.chunked_seq,
|
||||
self.layer_id,
|
||||
)
|
||||
|
||||
if prev_k is not None and prev_v is not None:
|
||||
# Compute attention against previous KV (no causal mask)
|
||||
prev_o, prev_lse = flash_attn_with_lse(
|
||||
q_batched,
|
||||
prev_k,
|
||||
prev_v,
|
||||
softmax_scale=self.scale,
|
||||
causal=False, # No causal mask for previous context
|
||||
)
|
||||
accumulated_o = prev_o
|
||||
accumulated_lse = prev_lse
|
||||
|
||||
# Compute attention against current chunk's KV (with causal mask)
|
||||
current_o, current_lse = flash_attn_with_lse(
|
||||
q_batched,
|
||||
k_batched,
|
||||
v_batched,
|
||||
softmax_scale=self.scale,
|
||||
causal=True, # Causal mask for current chunk
|
||||
)
|
||||
|
||||
# Merge with accumulated
|
||||
if accumulated_o is None:
|
||||
final_o = current_o
|
||||
else:
|
||||
final_o, _ = merge_attention_outputs(
|
||||
accumulated_o, accumulated_lse,
|
||||
current_o, current_lse,
|
||||
)
|
||||
|
||||
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
|
||||
return final_o.squeeze(0)
|
||||
|
||||
def _chunked_decode_attention(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
context,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute decode attention with KV spread across CPU and GPU.
|
||||
|
||||
For decode with chunked KV:
|
||||
1. Load all KV for this layer from CPU+GPU
|
||||
2. Compute attention (1 query token vs all KV)
|
||||
3. Return output
|
||||
"""
|
||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse
|
||||
|
||||
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
||||
# We need to attend to ALL previous tokens
|
||||
|
||||
# Load all KV for this layer
|
||||
if context.offload_engine is not None and self.layer_id >= 0:
|
||||
kvcache_manager = context.offload_engine
|
||||
|
||||
if hasattr(context, 'chunked_seq') and context.chunked_seq is not None:
|
||||
# Load all KV from both GPU and CPU for this layer
|
||||
k_all, v_all = kvcache_manager.load_all_kv_for_layer(
|
||||
context.chunked_seq,
|
||||
self.layer_id,
|
||||
)
|
||||
|
||||
if k_all is not None and v_all is not None:
|
||||
# q shape: [batch_size, num_heads, head_dim]
|
||||
# Need: [batch, seqlen, heads, dim]
|
||||
# Insert seqlen dimension at position 1
|
||||
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
||||
|
||||
# k_all, v_all shape: [1, total_kv_tokens, kv_heads, head_dim]
|
||||
# Compute attention (no causal mask for decode - we want all KV)
|
||||
out, _ = flash_attn_with_lse(
|
||||
q_batched,
|
||||
k_all,
|
||||
v_all,
|
||||
softmax_scale=self.scale,
|
||||
causal=False, # No causal mask for decode
|
||||
)
|
||||
|
||||
# Output shape: [batch, 1, heads, dim] -> [batch, heads, dim]
|
||||
return out.squeeze(1)
|
||||
|
||||
# Fallback: shouldn't reach here
|
||||
raise RuntimeError("Chunked decode attention failed: no KV available")
|
||||
|
||||
Reference in New Issue
Block a user