[WIP] Before refactoring compute_chunked_prefill.
This commit is contained in:
@@ -124,42 +124,6 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
"""
|
||||
return available_blocks
|
||||
|
||||
def _load_all_historical_kv(
|
||||
self,
|
||||
cpu_block_table: List[int],
|
||||
layer_id: int,
|
||||
offload_engine: "OffloadEngine",
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Load all historical K/V from CPU to GPU.
|
||||
|
||||
Args:
|
||||
cpu_block_table: List of CPU block IDs
|
||||
layer_id: Current layer index
|
||||
offload_engine: OffloadEngine instance
|
||||
|
||||
Returns:
|
||||
(k_hist, v_hist) with shape [total_tokens, kv_heads, head_dim]
|
||||
"""
|
||||
if not cpu_block_table:
|
||||
return None, None
|
||||
|
||||
k_list = []
|
||||
v_list = []
|
||||
|
||||
for cpu_block_id in cpu_block_table:
|
||||
k_block, v_block = offload_engine.load_block_full_from_cpu(
|
||||
cpu_block_id, layer_id
|
||||
)
|
||||
k_list.append(k_block)
|
||||
v_list.append(v_block)
|
||||
|
||||
# Concatenate: [num_blocks, block_size, kv_heads, head_dim] -> [total_tokens, kv_heads, head_dim]
|
||||
k_hist = torch.cat(k_list, dim=0)
|
||||
v_hist = torch.cat(v_list, dim=0)
|
||||
|
||||
return k_hist, v_hist
|
||||
|
||||
def compute_chunked_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user