[WIP] Before refactoring the nanovllm sparse policy.
This commit is contained in:
@@ -210,22 +210,7 @@ class Attention(nn.Module):
|
||||
# Apply sparse policy if enabled
|
||||
sparse_policy = kvcache_manager.sparse_policy
|
||||
|
||||
# === XAttention BSA: Policy handles entire sparse prefill ===
|
||||
# Check if policy has sparse_prefill_attention method (XAttention BSA)
|
||||
if (sparse_policy is not None and
|
||||
hasattr(sparse_policy, 'sparse_prefill_attention') and
|
||||
getattr(sparse_policy, 'supports_prefill', False)):
|
||||
# Use policy's sparse_prefill_attention method
|
||||
# Pass softmax_scale from attention layer
|
||||
# IMPORTANT: Don't return early - we still need to do KV offload below!
|
||||
o = sparse_policy.sparse_prefill_attention(q, k, v, self.layer_id, self.scale)
|
||||
# Convert back to batched format for consistency with standard flow
|
||||
o_acc = o.unsqueeze(0) # [seq_len, heads, dim] -> [1, seq_len, heads, dim]
|
||||
lse_acc = None # sparse_prefill_attention returns final output, not intermediate LSE
|
||||
# Skip standard flow processing since we already computed attention
|
||||
cpu_block_table = None # Signal to skip historical chunk processing
|
||||
|
||||
# === Standard sparse policy (Quest, etc.) ===
|
||||
# === All sparse policies use select_blocks interface ===
|
||||
if cpu_block_table and sparse_policy is not None:
|
||||
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
|
||||
policy_ctx = PolicyContext(
|
||||
@@ -262,8 +247,7 @@ class Attention(nn.Module):
|
||||
compute_stream = offload_engine.compute_stream if offload_engine is not None else None
|
||||
|
||||
# Compute attention against current chunk's KV from prefill buffer (with causal mask)
|
||||
# Skip this if XAttention BSA already computed full attention (o_acc is set, lse_acc is None)
|
||||
needs_current_chunk_attention = (lse_acc is not None or o_acc is None)
|
||||
needs_current_chunk_attention = True
|
||||
|
||||
if needs_current_chunk_attention:
|
||||
if compute_stream is not None:
|
||||
@@ -294,24 +278,19 @@ class Attention(nn.Module):
|
||||
|
||||
# Merge with accumulated (all on compute_stream for consistency)
|
||||
if o_acc is None:
|
||||
# No accumulated attention (standard flow or XAttention BSA with no historical chunks)
|
||||
final_o = current_o if needs_current_chunk_attention else o_acc
|
||||
# No accumulated attention (no historical chunks processed)
|
||||
final_o = current_o
|
||||
else:
|
||||
# Has accumulated attention (XAttention BSA with historical chunks)
|
||||
if needs_current_chunk_attention:
|
||||
# Need to merge historical (from XAttention BSA) with current chunk
|
||||
if compute_stream is not None:
|
||||
with torch.cuda.stream(compute_stream):
|
||||
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
|
||||
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
else:
|
||||
# Has accumulated attention (historical chunks processed)
|
||||
if compute_stream is not None:
|
||||
with torch.cuda.stream(compute_stream):
|
||||
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
|
||||
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
else:
|
||||
# XAttention BSA already computed everything
|
||||
final_o = o_acc
|
||||
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
|
||||
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
torch.cuda.nvtx.range_pop() # ChunkedPrefill
|
||||
|
||||
|
||||
Reference in New Issue
Block a user