[feat] Added Quest Sparsity Policy.
This commit is contained in:
@@ -188,9 +188,9 @@ class Attention(nn.Module):
|
||||
# Get prefilled CPU blocks (blocks from previous chunks)
|
||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||
|
||||
# Apply sparse policy if enabled
|
||||
prefill_policy = kvcache_manager.get_policy_for_phase(is_prefill=True)
|
||||
if cpu_block_table and prefill_policy is not None:
|
||||
# Apply sparse policy if enabled (Quest returns all blocks for prefill since query=None)
|
||||
sparse_policy = kvcache_manager.sparse_policy
|
||||
if cpu_block_table and sparse_policy is not None:
|
||||
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
|
||||
policy_ctx = PolicyContext(
|
||||
query_chunk_idx=current_chunk_idx,
|
||||
@@ -201,7 +201,7 @@ class Attention(nn.Module):
|
||||
block_size=kvcache_manager.block_size,
|
||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||
)
|
||||
cpu_block_table = prefill_policy.select_blocks(
|
||||
cpu_block_table = sparse_policy.select_blocks(
|
||||
cpu_block_table, policy_ctx
|
||||
)
|
||||
|
||||
@@ -512,9 +512,9 @@ class Attention(nn.Module):
|
||||
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
|
||||
last_block_valid_tokens = block_size # Last block was exactly full
|
||||
|
||||
# Apply sparse policy if enabled
|
||||
decode_policy = kvcache_manager.get_policy_for_phase(is_prefill=False)
|
||||
if decode_policy is not None:
|
||||
# Apply sparse policy if enabled (Quest does Top-K selection for decode)
|
||||
sparse_policy = kvcache_manager.sparse_policy
|
||||
if sparse_policy is not None:
|
||||
policy_ctx = PolicyContext(
|
||||
query_chunk_idx=0,
|
||||
num_query_chunks=1,
|
||||
@@ -524,7 +524,7 @@ class Attention(nn.Module):
|
||||
block_size=kvcache_manager.block_size,
|
||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||
)
|
||||
cpu_block_table = decode_policy.select_blocks(
|
||||
cpu_block_table = sparse_policy.select_blocks(
|
||||
cpu_block_table, policy_ctx
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user