[WIP] Before refactoring compute_chunked_prefill.

This commit is contained in:
Zijie Tian
2026-01-23 03:36:12 +08:00
parent edc006463b
commit ca32ea6f93
7 changed files with 914 additions and 114 deletions

View File

@@ -48,7 +48,7 @@ class Config:
# XAttention BSA specific parameters
sparse_block_size: int = 128 # Block size for BSA (tokens per block)
sparse_samples_per_chunk: int = 128 # Samples per chunk for estimation
sparse_threshold: float = 0.9 # Cumulative attention threshold (0-1)
sparse_threshold: float = 0.95 # Cumulative attention threshold (tau in XAttention)
sparse_use_triton: bool = True # Use Triton kernels for estimation
sparse_stride: int = 8 # Stride for Q/K downsampling

View File

@@ -124,42 +124,6 @@ class XAttentionBSAPolicy(SparsePolicy):
"""
return available_blocks
def _load_all_historical_kv(
self,
cpu_block_table: List[int],
layer_id: int,
offload_engine: "OffloadEngine",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Load all historical K/V from CPU to GPU.
Args:
cpu_block_table: List of CPU block IDs
layer_id: Current layer index
offload_engine: OffloadEngine instance
Returns:
(k_hist, v_hist) with shape [total_tokens, kv_heads, head_dim]
"""
if not cpu_block_table:
return None, None
k_list = []
v_list = []
for cpu_block_id in cpu_block_table:
k_block, v_block = offload_engine.load_block_full_from_cpu(
cpu_block_id, layer_id
)
k_list.append(k_block)
v_list.append(v_block)
# Concatenate: [num_blocks, block_size, kv_heads, head_dim] -> [total_tokens, kv_heads, head_dim]
k_hist = torch.cat(k_list, dim=0)
v_hist = torch.cat(v_list, dim=0)
return k_hist, v_hist
def compute_chunked_prefill(
self,
q: torch.Tensor,

View File

@@ -258,8 +258,12 @@ class Attention(nn.Module):
raise RuntimeError("sparse_policy is required for chunked decode")
# Check if policy supports decode phase
# If not, fallback to FullAttentionPolicy (e.g., XAttentionBSAPolicy only supports prefill)
if not sparse_policy.supports_decode:
raise RuntimeError(f"{sparse_policy} does not support decode phase")
from nanovllm.kvcache.sparse import FullAttentionPolicy
sparse_policy = FullAttentionPolicy()
logger.debug(f"[DEBUG] {kvcache_manager.sparse_policy} doesn't support decode, "
f"falling back to FullAttentionPolicy")
# [DEBUG] Verify execution path
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "