[WIP] Before adding Quest policy.

This commit is contained in:
Zijie Tian
2026-01-07 02:32:30 +08:00
parent f240903013
commit c99a6f3d3f
11 changed files with 166 additions and 191 deletions

View File

@@ -189,7 +189,8 @@ class Attention(nn.Module):
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Apply sparse policy if enabled
if cpu_block_table and kvcache_manager.sparse_policy is not None:
prefill_policy = kvcache_manager.get_policy_for_phase(is_prefill=True)
if cpu_block_table and prefill_policy is not None:
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
@@ -200,7 +201,7 @@ class Attention(nn.Module):
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
cpu_block_table = prefill_policy.select_blocks(
cpu_block_table, policy_ctx
)
@@ -279,7 +280,11 @@ class Attention(nn.Module):
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
if current_chunk_idx < len(cpu_block_ids):
cpu_block_id = cpu_block_ids[current_chunk_idx]
offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
# k.shape[0] = number of tokens in current chunk
num_valid_tokens = k.shape[0]
offload_engine.offload_slot_layer_to_cpu(
write_slot, self.layer_id, cpu_block_id, num_valid_tokens
)
# CRITICAL: compute_stream must wait for offload to complete
# before the next layer's store_kvcache can overwrite the GPU slot.
@@ -508,7 +513,8 @@ class Attention(nn.Module):
last_block_valid_tokens = block_size # Last block was exactly full
# Apply sparse policy if enabled
if kvcache_manager.sparse_policy is not None:
decode_policy = kvcache_manager.get_policy_for_phase(is_prefill=False)
if decode_policy is not None:
policy_ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
@@ -518,7 +524,7 @@ class Attention(nn.Module):
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
cpu_block_table = decode_policy.select_blocks(
cpu_block_table, policy_ctx
)