WIP: Enhance sparse attention with density tracking and block selection improvements
- Added analysis documentation for xattn density alignment.
- Refactored ModelRunner to pre-allocate policy metadata buffers regardless of CPU offload configuration.
- Updated FullAttentionPolicy and SparsePolicy to accept query and key tensors for block selection.
- Enhanced QuestPolicy to use the query tensor for block selection and improved handling of selected blocks.
- Expanded XAttentionBSAPolicy to support chunked prefill, with improved attention score computation that handles historical and current chunks.
- Introduced DensityObserver to track compute and communication density for sparse attention layers.
- Updated the attention layer so block selection is always called, improving robustness in first-chunk scenarios.
- Added tests for attention kernel behavior with enhanced input patterns.
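The DensityObserver mentioned above can be pictured as a small per-layer accumulator. The sketch below is a minimal, hypothetical illustration — the names `DensityObserver`, `record`, and `density` are assumptions for exposition, not the commit's actual interface:

```python
# Hypothetical sketch of a density observer for sparse attention.
# "Density" here means the fraction of KV blocks a layer actually
# selects (compute) or transfers (communication) out of all blocks.

class DensityObserver:
    def __init__(self):
        # Each record: (layer_idx, selected_blocks, total_blocks)
        self.records = []

    def record(self, layer_idx, selected_blocks, total_blocks):
        self.records.append((layer_idx, selected_blocks, total_blocks))

    def density(self, layer_idx):
        # Aggregate density for one layer across all recorded steps.
        sel = sum(s for l, s, _ in self.records if l == layer_idx)
        tot = sum(t for l, _, t in self.records if l == layer_idx)
        return sel / tot if tot else 1.0

obs = DensityObserver()
obs.record(layer_idx=0, selected_blocks=24, total_blocks=96)
obs.record(layer_idx=0, selected_blocks=12, total_blocks=96)
print(obs.density(0))  # 36/192 = 0.1875
```

A real implementation would likely also distinguish compute density from communication density per layer; this sketch collapses both into one counter for brevity.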
@@ -41,9 +41,9 @@ K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda()
 for i in range(q_len):
     if i % 2 == 0:
-        Q[0, 0, i, :] = 1
+        Q[0, 0, i, :] = 1 * (i // stride + 1)
     else:
-        Q[0, 0, i, :] = 2
+        Q[0, 0, i, :] = 2 * (i // stride + 1)

 for i in range(kv_len):
     if i % 2 == 0:
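The hunk above changes the test's query fill from a flat alternating 1/2 pattern to one additionally scaled by the stride block the row falls in. The snippet below reproduces that pattern on plain Python lists (a stand-in for the torch tensor fill, so it runs without a GPU):

```python
# Reproduce the new query pattern from the diff: even rows get 1,
# odd rows get 2, and both are scaled by (i // stride + 1), i.e.
# by which stride-block row i belongs to.
q_len, stride = 8, 4
pattern = []
for i in range(q_len):
    base = 1 if i % 2 == 0 else 2
    pattern.append(base * (i // stride + 1))
print(pattern)  # [1, 2, 1, 2, 2, 4, 2, 4]
```

This makes rows distinguishable across stride blocks, which is presumably what the enhanced kernel tests rely on to detect block-selection mistakes.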
@@ -74,8 +74,11 @@ for k_chunk_idx in range(num_k_chunks):
         Q, K_chunk, stride,
+        chunk_start=0,
+        chunk_end=q_reshaped_len,
-        is_causal=False
+        is_causal=True
     )
+
+    __import__('pdb').set_trace()
+
     attn_scores_list.append(attn_chunk)

     # Concatenate the results from all K chunks
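The loop above scores each K chunk separately and concatenates the results, now with causal masking enabled. A minimal NumPy sketch of that chunked-and-masked pattern, under the assumption that the causal mask is applied in global key coordinates (the helper `chunk_scores` and its arguments are illustrative, not the repo's actual kernel):

```python
# Chunked attention scores with a causal mask applied per K chunk,
# using each chunk's global key offset so the concatenated result
# matches scoring the full K at once.
import numpy as np

def chunk_scores(Q, K_chunk, k_offset, is_causal=True):
    s = Q @ K_chunk.T                                 # (q_len, chunk_len)
    if is_causal:
        q_idx = np.arange(Q.shape[0])[:, None]        # global query positions
        k_idx = k_offset + np.arange(K_chunk.shape[0])[None, :]
        s = np.where(k_idx <= q_idx, s, -np.inf)      # mask future keys
    return s

rng = np.random.default_rng(0)
Q = rng.standard_normal((6, 4))
K = rng.standard_normal((6, 4))

# Score K in two chunks of 3, then concatenate along the key axis.
chunks = [chunk_scores(Q, K[i:i + 3], k_offset=i) for i in (0, 3)]
full = chunk_scores(Q, K, k_offset=0)
assert np.allclose(np.concatenate(chunks, axis=1), full)
```

The key point is that the mask must use each chunk's offset into the full key sequence; masking with local chunk indices would incorrectly treat every chunk as starting at position 0.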