[WIP] Before refactoring the nanovllm sparse policy.

This commit is contained in:
Zijie Tian
2026-01-19 22:34:44 +08:00
parent b5da802dff
commit b97b0b96a0
8 changed files with 475 additions and 837 deletions

View File

@@ -210,22 +210,7 @@ class Attention(nn.Module):
# Apply sparse policy if enabled
sparse_policy = kvcache_manager.sparse_policy
# === XAttention BSA: Policy handles entire sparse prefill ===
# Check if policy has sparse_prefill_attention method (XAttention BSA)
if (sparse_policy is not None and
hasattr(sparse_policy, 'sparse_prefill_attention') and
getattr(sparse_policy, 'supports_prefill', False)):
# Use policy's sparse_prefill_attention method
# Pass softmax_scale from attention layer
# IMPORTANT: Don't return early - we still need to do KV offload below!
o = sparse_policy.sparse_prefill_attention(q, k, v, self.layer_id, self.scale)
# Convert back to batched format for consistency with standard flow
o_acc = o.unsqueeze(0) # [seq_len, heads, dim] -> [1, seq_len, heads, dim]
lse_acc = None # sparse_prefill_attention returns final output, not intermediate LSE
# Skip standard flow processing since we already computed attention
cpu_block_table = None # Signal to skip historical chunk processing
# === Standard sparse policy (Quest, etc.) ===
# === All sparse policies use select_blocks interface ===
if cpu_block_table and sparse_policy is not None:
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
policy_ctx = PolicyContext(
@@ -262,8 +247,7 @@ class Attention(nn.Module):
compute_stream = offload_engine.compute_stream if offload_engine is not None else None
# Compute attention against current chunk's KV from prefill buffer (with causal mask)
# Skip this if XAttention BSA already computed full attention (o_acc is set, lse_acc is None)
needs_current_chunk_attention = (lse_acc is not None or o_acc is None)
needs_current_chunk_attention = True
if needs_current_chunk_attention:
if compute_stream is not None:
@@ -294,24 +278,19 @@ class Attention(nn.Module):
# Merge with accumulated (all on compute_stream for consistency)
if o_acc is None:
# No accumulated attention (standard flow or XAttention BSA with no historical chunks)
final_o = current_o if needs_current_chunk_attention else o_acc
# No accumulated attention (no historical chunks processed)
final_o = current_o
else:
# Has accumulated attention (XAttention BSA with historical chunks)
if needs_current_chunk_attention:
# Need to merge historical (from XAttention BSA) with current chunk
if compute_stream is not None:
with torch.cuda.stream(compute_stream):
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
else:
# Has accumulated attention (historical chunks processed)
if compute_stream is not None:
with torch.cuda.stream(compute_stream):
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
else:
# XAttention BSA already computed everything
final_o = o_acc
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop() # ChunkedPrefill