[WIP] Before adding Quest policy.
This commit is contained in:
@@ -189,7 +189,8 @@ class Attention(nn.Module):
|
||||
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||
|
||||
# Apply sparse policy if enabled
|
||||
if cpu_block_table and kvcache_manager.sparse_policy is not None:
|
||||
prefill_policy = kvcache_manager.get_policy_for_phase(is_prefill=True)
|
||||
if cpu_block_table and prefill_policy is not None:
|
||||
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
|
||||
policy_ctx = PolicyContext(
|
||||
query_chunk_idx=current_chunk_idx,
|
||||
@@ -200,7 +201,7 @@ class Attention(nn.Module):
|
||||
block_size=kvcache_manager.block_size,
|
||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||
)
|
||||
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
|
||||
cpu_block_table = prefill_policy.select_blocks(
|
||||
cpu_block_table, policy_ctx
|
||||
)
|
||||
|
||||
@@ -279,7 +280,11 @@ class Attention(nn.Module):
|
||||
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
|
||||
if current_chunk_idx < len(cpu_block_ids):
|
||||
cpu_block_id = cpu_block_ids[current_chunk_idx]
|
||||
offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
|
||||
# k.shape[0] = number of tokens in current chunk
|
||||
num_valid_tokens = k.shape[0]
|
||||
offload_engine.offload_slot_layer_to_cpu(
|
||||
write_slot, self.layer_id, cpu_block_id, num_valid_tokens
|
||||
)
|
||||
|
||||
# CRITICAL: compute_stream must wait for offload to complete
|
||||
# before the next layer's store_kvcache can overwrite the GPU slot.
|
||||
@@ -508,7 +513,8 @@ class Attention(nn.Module):
|
||||
last_block_valid_tokens = block_size # Last block was exactly full
|
||||
|
||||
# Apply sparse policy if enabled
|
||||
if kvcache_manager.sparse_policy is not None:
|
||||
decode_policy = kvcache_manager.get_policy_for_phase(is_prefill=False)
|
||||
if decode_policy is not None:
|
||||
policy_ctx = PolicyContext(
|
||||
query_chunk_idx=0,
|
||||
num_query_chunks=1,
|
||||
@@ -518,7 +524,7 @@ class Attention(nn.Module):
|
||||
block_size=kvcache_manager.block_size,
|
||||
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||
)
|
||||
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
|
||||
cpu_block_table = decode_policy.select_blocks(
|
||||
cpu_block_table, policy_ctx
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user