[feat] Add sparse KV cache feature (needs verification).

This commit is contained in:
Zijie Tian
2025-12-22 08:51:02 +08:00
parent 8df0c7517b
commit 051f2295c9
14 changed files with 1215 additions and 12 deletions

View File

@@ -6,6 +6,7 @@ import triton.language as tl
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context
from nanovllm.kvcache.sparse.policy import PolicyContext
logger = logging.getLogger(__name__)
@@ -133,6 +134,22 @@ class Attention(nn.Module):
# Get prefilled CPU blocks (blocks from previous chunks)
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Apply sparse policy if enabled
if cpu_block_table and kvcache_manager.sparse_policy is not None:
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
num_query_chunks=num_chunks,
layer_id=self.layer_id,
query=None, # Prefill typically doesn't use query for selection
is_prefill=True,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
cpu_block_table, policy_ctx
)
if cpu_block_table:
offload_engine = kvcache_manager.offload_engine
@@ -344,6 +361,21 @@ class Attention(nn.Module):
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
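# Apply sparse policy if enabled.
# NOTE(review): the emptiness check above runs before selection; if
# select_blocks() can return an empty list, it is not re-checked here
# (the prefill path does re-check) - confirm the code below handles that.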
if kvcache_manager.sparse_policy is not None:
policy_ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
layer_id=self.layer_id,
query=q_batched, # Decode provides query for query-aware selection
is_prefill=False,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
cpu_block_table, policy_ctx
)
offload_engine = kvcache_manager.offload_engine
# Use prefetch_size as chunk size for double buffering