[WIP] Before integrating the xattn operator.

This commit is contained in:
Zijie Tian
2026-01-19 21:19:21 +08:00
parent 9e6fdc0650
commit b5da802dff
11 changed files with 949 additions and 32 deletions

View File

@@ -210,6 +210,21 @@ class Attention(nn.Module):
# Apply sparse policy if enabled
sparse_policy = kvcache_manager.sparse_policy
# === XAttention BSA: Policy handles entire sparse prefill ===
# Check if policy has sparse_prefill_attention method (XAttention BSA)
if (sparse_policy is not None and
hasattr(sparse_policy, 'sparse_prefill_attention') and
getattr(sparse_policy, 'supports_prefill', False)):
# Use policy's sparse_prefill_attention method
# Pass softmax_scale from attention layer
# IMPORTANT: Don't return early - we still need to do KV offload below!
o = sparse_policy.sparse_prefill_attention(q, k, v, self.layer_id, self.scale)
# Convert back to batched format for consistency with standard flow
o_acc = o.unsqueeze(0) # [seq_len, heads, dim] -> [1, seq_len, heads, dim]
lse_acc = None # sparse_prefill_attention returns final output, not intermediate LSE
# Skip standard flow processing since we already computed attention
cpu_block_table = None # Signal to skip historical chunk processing
# === Standard sparse policy (Quest, etc.) ===
if cpu_block_table and sparse_policy is not None:
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
@@ -247,11 +262,27 @@ class Attention(nn.Module):
compute_stream = offload_engine.compute_stream if offload_engine is not None else None
# Compute attention against current chunk's KV from prefill buffer (with causal mask)
if compute_stream is not None:
with torch.cuda.stream(compute_stream):
# Skip this if XAttention BSA already computed full attention (o_acc is set, lse_acc is None)
needs_current_chunk_attention = (lse_acc is not None or o_acc is None)
if needs_current_chunk_attention:
if compute_stream is not None:
with torch.cuda.stream(compute_stream):
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
# Get KV from per-layer prefill buffer
k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens)
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
v_batched,
softmax_scale=self.scale,
causal=True,
)
torch.cuda.nvtx.range_pop()
else:
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
# Get KV from per-layer prefill buffer
k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens)
k_batched = k.unsqueeze(0)
v_batched = v.unsqueeze(0)
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
@@ -260,32 +291,27 @@ class Attention(nn.Module):
causal=True,
)
torch.cuda.nvtx.range_pop()
else:
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
k_batched = k.unsqueeze(0)
v_batched = v.unsqueeze(0)
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
v_batched,
softmax_scale=self.scale,
causal=True,
)
torch.cuda.nvtx.range_pop()
# Merge with accumulated (all on compute_stream for consistency)
if o_acc is None:
final_o = current_o
# No accumulated attention (standard flow or XAttention BSA with no historical chunks)
final_o = current_o if needs_current_chunk_attention else o_acc
else:
if compute_stream is not None:
with torch.cuda.stream(compute_stream):
# Has accumulated attention (XAttention BSA with historical chunks)
if needs_current_chunk_attention:
# Need to merge historical (from XAttention BSA) with current chunk
if compute_stream is not None:
with torch.cuda.stream(compute_stream):
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
else:
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
else:
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
# XAttention BSA already computed everything
final_o = o_acc
torch.cuda.nvtx.range_pop() # ChunkedPrefill