✨ feat(xattn): implement compute_chunked_prefill with ring buffer pipeline
- Copy compute_chunked_prefill implementation from FullAttentionPolicy - Set default threshold to 0.95 for accuracy testing - Remove debug code (sys.exit, verbose prints) - Use ring buffer pipeline for historical block loading - Merge with current chunk attention using flash_attn_with_lse RULER NIAH test passed with 5/5 samples (100% accuracy). Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -88,7 +88,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
threshold: float = 0.1, # Very low threshold for aggressive sparsity testing
|
||||
threshold: float = 0.95, # High threshold for accuracy testing
|
||||
stride: int = 8,
|
||||
chunk_size: int = 16384,
|
||||
block_size: int = 128,
|
||||
@@ -298,27 +298,11 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
block_selected = vote_ratio > vote_threshold
|
||||
selected_block_ids = [available_blocks[i] for i, sel in enumerate(block_selected.tolist()) if sel]
|
||||
|
||||
# Compute density = selected / total
|
||||
density = len(selected_block_ids) / len(available_blocks) if available_blocks else 1.0
|
||||
|
||||
# Debug output: show block selection results
|
||||
if layer_id == 0: # Only log for layer 0 to avoid spam
|
||||
# Count True per head to see head-level sparsity
|
||||
# mask shape: [batch, num_heads, q_blocks, k_blocks]
|
||||
per_head_selected = mask[0, :, 0, :].sum(dim=-1) # [num_heads] - selected blocks per head
|
||||
per_head_density = per_head_selected.float() / k_blocks
|
||||
|
||||
print(f"[XAttn DEBUG] chunk={ctx.query_chunk_idx}, "
|
||||
f"blocks={len(available_blocks)}, "
|
||||
f"final_selected={len(selected_block_ids)}, "
|
||||
f"final_density={density:.1%}, "
|
||||
f"per_head_density={[f'{d:.0%}' for d in per_head_density[:8].tolist()]}...") # First 8 heads
|
||||
|
||||
# Exit early after 40 chunks for faster debugging
|
||||
if ctx.query_chunk_idx >= 40:
|
||||
print(f"[XAttn DEBUG] Exiting early after {ctx.query_chunk_idx} chunks for debugging")
|
||||
import sys
|
||||
sys.exit(0)
|
||||
# Log density for layer 0 only
|
||||
if layer_id == 0:
|
||||
density = len(selected_block_ids) / len(available_blocks) if available_blocks else 1.0
|
||||
logger.debug(f"[XAttn] chunk={ctx.query_chunk_idx}, blocks={len(available_blocks)}, "
|
||||
f"selected={len(selected_block_ids)}, density={density:.1%}")
|
||||
|
||||
# Always include first block (sink) and last block for safety
|
||||
if available_blocks and available_blocks[0] not in selected_block_ids:
|
||||
@@ -345,12 +329,15 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
"""
|
||||
Compute attention for chunked prefill using XAttention sparse block selection.
|
||||
|
||||
TODO: Implement sparse attention computation using selected_blocks.
|
||||
This method handles the chunked prefill computation:
|
||||
1. Load and compute attention to historical chunks (using selected_blocks)
|
||||
2. Compute attention to current chunk
|
||||
3. Merge all results
|
||||
|
||||
Args:
|
||||
q: Query tensor [seq_len, num_heads, head_dim]
|
||||
k: Key tensor [seq_len, num_kv_heads, head_dim] (current chunk)
|
||||
v: Value tensor [seq_len, num_kv_heads, head_dim] (current chunk)
|
||||
k: Key tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer)
|
||||
v: Value tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer)
|
||||
layer_id: Current layer index
|
||||
softmax_scale: Softmax scaling factor
|
||||
offload_engine: OffloadEngine for loading blocks
|
||||
@@ -363,9 +350,94 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
Returns:
|
||||
Attention output [seq_len, num_heads, head_dim]
|
||||
"""
|
||||
# TODO: Implement sparse attention with selected_blocks
|
||||
# For now, return zeros as placeholder
|
||||
return torch.zeros_like(q)
|
||||
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
|
||||
o_acc = None
|
||||
lse_acc = None
|
||||
compute_stream = offload_engine.compute_stream
|
||||
|
||||
# Use the pre-selected blocks directly
|
||||
cpu_block_table = selected_blocks
|
||||
|
||||
if cpu_block_table:
|
||||
load_slots = list(range(offload_engine.num_ring_slots))
|
||||
num_blocks = len(cpu_block_table)
|
||||
|
||||
if len(load_slots) == 1:
|
||||
# Only 1 slot - use synchronous mode
|
||||
slot = load_slots[0]
|
||||
for block_idx in range(num_blocks):
|
||||
cpu_block_id = cpu_block_table[block_idx]
|
||||
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
|
||||
offload_engine.wait_slot_layer(slot)
|
||||
|
||||
with torch.cuda.stream(compute_stream):
|
||||
prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
|
||||
prev_o, prev_lse = flash_attn_with_lse(
|
||||
q_batched, prev_k, prev_v,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=False,
|
||||
)
|
||||
if o_acc is None:
|
||||
o_acc, lse_acc = prev_o, prev_lse
|
||||
else:
|
||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
||||
offload_engine.record_slot_compute_done(slot)
|
||||
else:
|
||||
# Multiple slots - use pipeline
|
||||
num_slots = len(load_slots)
|
||||
num_preload = min(num_slots, num_blocks)
|
||||
for i in range(num_preload):
|
||||
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
|
||||
|
||||
for block_idx in range(num_blocks):
|
||||
current_slot = load_slots[block_idx % num_slots]
|
||||
|
||||
offload_engine.wait_slot_layer(current_slot)
|
||||
|
||||
with torch.cuda.stream(compute_stream):
|
||||
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
|
||||
prev_o, prev_lse = flash_attn_with_lse(
|
||||
q_batched, prev_k, prev_v,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=False,
|
||||
)
|
||||
offload_engine.record_slot_compute_done(current_slot)
|
||||
|
||||
if o_acc is None:
|
||||
o_acc, lse_acc = prev_o, prev_lse
|
||||
else:
|
||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
||||
|
||||
# Issue next transfer
|
||||
next_block_idx = block_idx + num_slots
|
||||
if next_block_idx < num_blocks:
|
||||
next_slot = load_slots[next_block_idx % num_slots]
|
||||
next_cpu_block_id = cpu_block_table[next_block_idx]
|
||||
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
|
||||
|
||||
# Compute attention to current chunk (causal mask)
|
||||
with torch.cuda.stream(compute_stream):
|
||||
k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
|
||||
current_o, current_lse = flash_attn_with_lse(
|
||||
q_batched, k_curr, v_curr,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
# Merge historical and current attention
|
||||
with torch.cuda.stream(compute_stream):
|
||||
if o_acc is None:
|
||||
final_o = current_o
|
||||
else:
|
||||
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
|
||||
|
||||
# Sync default stream with compute_stream before returning
|
||||
torch.cuda.default_stream().wait_stream(compute_stream)
|
||||
|
||||
# Remove batch dimension: [1, seq_len, num_heads, head_dim] -> [seq_len, num_heads, head_dim]
|
||||
return final_o.squeeze(0)
|
||||
|
||||
def compute_chunked_decode(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user