WIP: Enhance sparse attention with density tracking and block selection improvements

- Added analysis documentation for xattn density alignment.
- Refactored ModelRunner to pre-allocate policy metadata buffers regardless of CPU offload configuration.
- Updated FullAttentionPolicy and SparsePolicy to accept query and key tensors for block selection.
- Enhanced QuestPolicy to utilize query tensor for block selection and improved handling of selected blocks.
- Expanded XAttentionBSAPolicy to support chunked prefill and improved attention score computation with historical and current chunk handling.
- Introduced DensityObserver to track compute and communication density for sparse attention layers.
- Updated attention layer to ensure block selection is always called, improving robustness in first chunk scenarios.
- Added tests for attention kernel behavior with enhanced input patterns.
This commit is contained in:
Zijie Tian
2026-01-31 14:48:23 +08:00
parent f6ac4ccdde
commit 2e96d1d97d
9 changed files with 490 additions and 152 deletions

View File

@@ -229,16 +229,16 @@ class ModelRunner:
# GPU-only mode: pre-allocate policy metadata buffers
# This avoids dynamic GPU memory allocation during forward pass
if not config.enable_cpu_offload:
num_heads = hf_config.num_attention_heads // self.world_size
self.kvcache_manager.sparse_policy.alloc_policy_metadata(
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
max_seq_len=config.max_model_len,
dtype=hf_config.torch_dtype,
device=torch.device("cuda"),
)
# if not config.enable_cpu_offload:
num_heads = hf_config.num_attention_heads // self.world_size
self.kvcache_manager.sparse_policy.alloc_policy_metadata(
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
max_seq_len=config.max_model_len,
dtype=hf_config.torch_dtype,
device=torch.device("cuda"),
)
# Log policy info (handle both enum and None cases)
policy_name = config.sparse_policy.name if config.sparse_policy is not None else "FULL"