[feat] Added sparse KVcache feature, NEED VERIFY.
This commit is contained in:
@@ -6,6 +6,7 @@ import triton.language as tl
|
||||
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||
from nanovllm.utils.context import get_context
|
||||
from nanovllm.kvcache.sparse.policy import PolicyContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -133,6 +134,22 @@ class Attention(nn.Module):
|
||||
# Get prefilled CPU blocks (blocks from previous chunks)
|
||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||
|
||||
# Apply sparse policy if enabled
|
||||
if cpu_block_table and kvcache_manager.sparse_policy is not None:
|
||||
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
|
||||
policy_ctx = PolicyContext(
|
||||
query_chunk_idx=current_chunk_idx,
|
||||
num_query_chunks=num_chunks,
|
||||
layer_id=self.layer_id,
|
||||
query=None, # Prefill typically doesn't use query for selection
|
||||
is_prefill=True,
|
||||
block_size=kvcache_manager.block_size,
|
||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||
)
|
||||
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
|
||||
cpu_block_table, policy_ctx
|
||||
)
|
||||
|
||||
if cpu_block_table:
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
|
||||
@@ -344,6 +361,21 @@ class Attention(nn.Module):
|
||||
if not cpu_block_table:
|
||||
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
|
||||
|
||||
# Apply sparse policy if enabled
|
||||
if kvcache_manager.sparse_policy is not None:
|
||||
policy_ctx = PolicyContext(
|
||||
query_chunk_idx=0,
|
||||
num_query_chunks=1,
|
||||
layer_id=self.layer_id,
|
||||
query=q_batched, # Decode provides query for query-aware selection
|
||||
is_prefill=False,
|
||||
block_size=kvcache_manager.block_size,
|
||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||
)
|
||||
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
|
||||
cpu_block_table, policy_ctx
|
||||
)
|
||||
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
|
||||
# Use prefetch_size as chunk size for double buffering
|
||||
|
||||
Reference in New Issue
Block a user