[WIP] Before refactoring compute_chunked_prefill.
This commit is contained in:
@@ -48,7 +48,7 @@ class Config:
|
||||
# XAttention BSA specific parameters
|
||||
sparse_block_size: int = 128 # Block size for BSA (tokens per block)
|
||||
sparse_samples_per_chunk: int = 128 # Samples per chunk for estimation
|
||||
sparse_threshold: float = 0.9 # Cumulative attention threshold (0-1)
|
||||
sparse_threshold: float = 0.95 # Cumulative attention threshold (tau in XAttention)
|
||||
sparse_use_triton: bool = True # Use Triton kernels for estimation
|
||||
sparse_stride: int = 8 # Stride for Q/K downsampling
|
||||
|
||||
|
||||
@@ -124,42 +124,6 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
"""
|
||||
return available_blocks
|
||||
|
||||
def _load_all_historical_kv(
|
||||
self,
|
||||
cpu_block_table: List[int],
|
||||
layer_id: int,
|
||||
offload_engine: "OffloadEngine",
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Load all historical K/V from CPU to GPU.
|
||||
|
||||
Args:
|
||||
cpu_block_table: List of CPU block IDs
|
||||
layer_id: Current layer index
|
||||
offload_engine: OffloadEngine instance
|
||||
|
||||
Returns:
|
||||
(k_hist, v_hist) with shape [total_tokens, kv_heads, head_dim]
|
||||
"""
|
||||
if not cpu_block_table:
|
||||
return None, None
|
||||
|
||||
k_list = []
|
||||
v_list = []
|
||||
|
||||
for cpu_block_id in cpu_block_table:
|
||||
k_block, v_block = offload_engine.load_block_full_from_cpu(
|
||||
cpu_block_id, layer_id
|
||||
)
|
||||
k_list.append(k_block)
|
||||
v_list.append(v_block)
|
||||
|
||||
# Concatenate: [num_blocks, block_size, kv_heads, head_dim] -> [total_tokens, kv_heads, head_dim]
|
||||
k_hist = torch.cat(k_list, dim=0)
|
||||
v_hist = torch.cat(v_list, dim=0)
|
||||
|
||||
return k_hist, v_hist
|
||||
|
||||
def compute_chunked_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
|
||||
@@ -258,8 +258,12 @@ class Attention(nn.Module):
|
||||
raise RuntimeError("sparse_policy is required for chunked decode")
|
||||
|
||||
# Check if policy supports decode phase
|
||||
# If not, fallback to FullAttentionPolicy (e.g., XAttentionBSAPolicy only supports prefill)
|
||||
if not sparse_policy.supports_decode:
|
||||
raise RuntimeError(f"{sparse_policy} does not support decode phase")
|
||||
from nanovllm.kvcache.sparse import FullAttentionPolicy
|
||||
sparse_policy = FullAttentionPolicy()
|
||||
logger.debug(f"[DEBUG] {kvcache_manager.sparse_policy} doesn't support decode, "
|
||||
f"falling back to FullAttentionPolicy")
|
||||
|
||||
# [DEBUG] Verify execution path
|
||||
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "
|
||||
|
||||
Reference in New Issue
Block a user