[WIP] Before integrating the xattn operator.
This commit is contained in:
@@ -210,6 +210,21 @@ class Attention(nn.Module):
|
||||
# Apply sparse policy if enabled
|
||||
sparse_policy = kvcache_manager.sparse_policy
|
||||
|
||||
# === XAttention BSA: Policy handles entire sparse prefill ===
|
||||
# Check if policy has sparse_prefill_attention method (XAttention BSA)
|
||||
if (sparse_policy is not None and
|
||||
hasattr(sparse_policy, 'sparse_prefill_attention') and
|
||||
getattr(sparse_policy, 'supports_prefill', False)):
|
||||
# Use policy's sparse_prefill_attention method
|
||||
# Pass softmax_scale from attention layer
|
||||
# IMPORTANT: Don't return early - we still need to do KV offload below!
|
||||
o = sparse_policy.sparse_prefill_attention(q, k, v, self.layer_id, self.scale)
|
||||
# Convert back to batched format for consistency with standard flow
|
||||
o_acc = o.unsqueeze(0) # [seq_len, heads, dim] -> [1, seq_len, heads, dim]
|
||||
lse_acc = None # sparse_prefill_attention returns final output, not intermediate LSE
|
||||
# Skip standard flow processing since we already computed attention
|
||||
cpu_block_table = None # Signal to skip historical chunk processing
|
||||
|
||||
# === Standard sparse policy (Quest, etc.) ===
|
||||
if cpu_block_table and sparse_policy is not None:
|
||||
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
|
||||
@@ -247,11 +262,27 @@ class Attention(nn.Module):
|
||||
compute_stream = offload_engine.compute_stream if offload_engine is not None else None
|
||||
|
||||
# Compute attention against current chunk's KV from prefill buffer (with causal mask)
|
||||
if compute_stream is not None:
|
||||
with torch.cuda.stream(compute_stream):
|
||||
# Skip this if XAttention BSA already computed full attention (o_acc is set, lse_acc is None)
|
||||
needs_current_chunk_attention = (lse_acc is not None or o_acc is None)
|
||||
|
||||
if needs_current_chunk_attention:
|
||||
if compute_stream is not None:
|
||||
with torch.cuda.stream(compute_stream):
|
||||
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
|
||||
# Get KV from per-layer prefill buffer
|
||||
k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens)
|
||||
current_o, current_lse = flash_attn_with_lse(
|
||||
q_batched,
|
||||
k_batched,
|
||||
v_batched,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
else:
|
||||
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
|
||||
# Get KV from per-layer prefill buffer
|
||||
k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens)
|
||||
k_batched = k.unsqueeze(0)
|
||||
v_batched = v.unsqueeze(0)
|
||||
current_o, current_lse = flash_attn_with_lse(
|
||||
q_batched,
|
||||
k_batched,
|
||||
@@ -260,32 +291,27 @@ class Attention(nn.Module):
|
||||
causal=True,
|
||||
)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
else:
|
||||
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
|
||||
k_batched = k.unsqueeze(0)
|
||||
v_batched = v.unsqueeze(0)
|
||||
current_o, current_lse = flash_attn_with_lse(
|
||||
q_batched,
|
||||
k_batched,
|
||||
v_batched,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
# Merge with accumulated (all on compute_stream for consistency)
|
||||
if o_acc is None:
|
||||
final_o = current_o
|
||||
# No accumulated attention (standard flow or XAttention BSA with no historical chunks)
|
||||
final_o = current_o if needs_current_chunk_attention else o_acc
|
||||
else:
|
||||
if compute_stream is not None:
|
||||
with torch.cuda.stream(compute_stream):
|
||||
# Has accumulated attention (XAttention BSA with historical chunks)
|
||||
if needs_current_chunk_attention:
|
||||
# Need to merge historical (from XAttention BSA) with current chunk
|
||||
if compute_stream is not None:
|
||||
with torch.cuda.stream(compute_stream):
|
||||
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
|
||||
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
else:
|
||||
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
|
||||
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
else:
|
||||
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
|
||||
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
# XAttention BSA already computed everything
|
||||
final_o = o_acc
|
||||
|
||||
torch.cuda.nvtx.range_pop() # ChunkedPrefill
|
||||
|
||||
|
||||
Reference in New Issue
Block a user