[feat] Added Quest Sparsity Policy.

This commit is contained in:
Zijie Tian
2026-01-07 03:29:21 +08:00
parent c99a6f3d3f
commit 2a6e0a2c02
9 changed files with 92 additions and 92 deletions

View File

@@ -188,9 +188,9 @@ class Attention(nn.Module):
# Get prefilled CPU blocks (blocks from previous chunks)
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Apply sparse policy if enabled
prefill_policy = kvcache_manager.get_policy_for_phase(is_prefill=True)
if cpu_block_table and prefill_policy is not None:
# Apply sparse policy if enabled (Quest returns all blocks for prefill since query=None)
sparse_policy = kvcache_manager.sparse_policy
if cpu_block_table and sparse_policy is not None:
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
@@ -201,7 +201,7 @@ class Attention(nn.Module):
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = prefill_policy.select_blocks(
cpu_block_table = sparse_policy.select_blocks(
cpu_block_table, policy_ctx
)
@@ -512,9 +512,9 @@ class Attention(nn.Module):
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
last_block_valid_tokens = block_size # Last block was exactly full
# Apply sparse policy if enabled
decode_policy = kvcache_manager.get_policy_for_phase(is_prefill=False)
if decode_policy is not None:
# Apply sparse policy if enabled (Quest does Top-K selection for decode)
sparse_policy = kvcache_manager.sparse_policy
if sparse_policy is not None:
policy_ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
@@ -524,7 +524,7 @@ class Attention(nn.Module):
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = decode_policy.select_blocks(
cpu_block_table = sparse_policy.select_blocks(
cpu_block_table, policy_ctx
)