[WIP] Before refactoring compute_chunked_prefill.

This commit is contained in:
Zijie Tian
2026-01-23 03:36:12 +08:00
parent edc006463b
commit ca32ea6f93
7 changed files with 914 additions and 114 deletions

View File

@@ -124,42 +124,6 @@ class XAttentionBSAPolicy(SparsePolicy):
"""
return available_blocks
def _load_all_historical_kv(
    self,
    cpu_block_table: List[int],
    layer_id: int,
    offload_engine: "OffloadEngine",
) -> "Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]":
    """Load all historical K/V blocks for one layer from CPU to GPU.

    Args:
        cpu_block_table: CPU block IDs to fetch, in sequence order.
        layer_id: Index of the layer whose K/V cache is being loaded.
        offload_engine: Engine providing per-block CPU->GPU transfers.

    Returns:
        Tuple ``(k_hist, v_hist)``, each of shape
        ``[total_tokens, kv_heads, head_dim]``, or ``(None, None)`` when
        ``cpu_block_table`` is empty. Callers must handle the ``None``
        case explicitly (the annotation is quoted so ``Optional`` need
        not be imported at runtime).
    """
    if not cpu_block_table:
        # No history has been offloaded yet; signal that explicitly
        # rather than returning empty tensors of unknown device/dtype.
        return None, None

    # Fetch each block exactly once. Each call is assumed to return a
    # (k, v) pair of shape [block_size, kv_heads, head_dim] — confirm
    # against OffloadEngine.load_block_full_from_cpu.
    pairs = [
        offload_engine.load_block_full_from_cpu(block_id, layer_id)
        for block_id in cpu_block_table
    ]
    k_blocks, v_blocks = zip(*pairs)

    # Concatenate along the token axis:
    # num_blocks x [block_size, ...] -> [total_tokens, ...]
    k_hist = torch.cat(k_blocks, dim=0)
    v_hist = torch.cat(v_blocks, dim=0)
    return k_hist, v_hist
def compute_chunked_prefill(
self,
q: torch.Tensor,