WIP: Enhance sparse attention with density tracking and block selection improvements

- Added analysis documentation for xattn density alignment.
- Refactored ModelRunner to pre-allocate policy metadata buffers regardless of CPU offload configuration.
- Updated FullAttentionPolicy and SparsePolicy to accept query and key tensors for block selection.
- Enhanced QuestPolicy to utilize query tensor for block selection and improved handling of selected blocks.
- Expanded XAttentionBSAPolicy to support chunked prefill and improved attention score computation with historical and current chunk handling.
- Introduced DensityObserver to track compute and communication density for sparse attention layers.
- Updated attention layer to ensure block selection is always called, improving robustness in first chunk scenarios.
- Added tests for attention kernel behavior with enhanced input patterns.
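The DensityObserver mentioned above is described only by its purpose (tracking compute and communication density for sparse attention layers). A minimal sketch of what such an observer might record, assuming density is the fraction of KV blocks actually computed or transferred per layer; the class name comes from the commit message, but the fields and `record` API here are assumptions:

```python
# Hypothetical sketch of a density observer for sparse attention.
# "Compute density" = fraction of KV blocks the kernel actually computes;
# "communication density" = fraction of KV blocks moved (e.g. from CPU
# offload storage). Field names and API are illustrative assumptions.
class DensityObserver:
    def __init__(self):
        self.records = []

    def record(self, layer_idx, selected_blocks, total_blocks, transferred_blocks):
        self.records.append({
            "layer": layer_idx,
            "compute_density": selected_blocks / total_blocks,
            "comm_density": transferred_blocks / total_blocks,
        })

obs = DensityObserver()
# e.g. a layer that selects 16 of 64 KV blocks and transfers 8 of them
obs.record(layer_idx=0, selected_blocks=16, total_blocks=64, transferred_blocks=8)
```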
Author: Zijie Tian
Date: 2026-01-31 14:48:23 +08:00
Parent: f6ac4ccdde
Commit: 2e96d1d97d
9 changed files with 490 additions and 152 deletions


@@ -41,9 +41,9 @@ K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda()
 for i in range(q_len):
     if i % 2 == 0:
-        Q[0, 0, i, :] = 1
+        Q[0, 0, i, :] = 1 * (i // stride + 1)
     else:
-        Q[0, 0, i, :] = 2
+        Q[0, 0, i, :] = 2 * (i // stride + 1)
 for i in range(kv_len):
     if i % 2 == 0:
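The hunk above changes the test's query pattern from a flat 1/2 alternation to one scaled by the stride group, so later groups carry larger magnitudes. A self-contained CPU version of the new fill pattern (shapes and the `stride` value are assumptions for illustration):

```python
import torch

# Updated test-input pattern: even rows get base value 1, odd rows base
# value 2, each scaled by (i // stride + 1) so successive stride groups
# are distinguishable by magnitude. q_len/head_dim/stride are assumed.
q_len, head_dim, stride = 8, 4, 2
Q = torch.zeros(1, 1, q_len, head_dim)
for i in range(q_len):
    base = 1 if i % 2 == 0 else 2
    Q[0, 0, i, :] = base * (i // stride + 1)
```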
@@ -74,8 +74,11 @@ for k_chunk_idx in range(num_k_chunks):
         Q, K_chunk, stride,
         chunk_start=0,
         chunk_end=q_reshaped_len,
-        is_causal=False
+        is_causal=True
     )
+    __import__('pdb').set_trace()
     attn_scores_list.append(attn_chunk)
     # Concatenate results from all K chunks