WIP: Enhance sparse attention with density tracking and block selection improvements
- Added analysis documentation for xattn density alignment.
- Refactored ModelRunner to pre-allocate policy metadata buffers regardless of CPU offload configuration.
- Updated FullAttentionPolicy and SparsePolicy to accept query and key tensors for block selection.
- Enhanced QuestPolicy to utilize the query tensor for block selection and improved handling of selected blocks.
- Expanded XAttentionBSAPolicy to support chunked prefill and improved attention score computation with historical and current chunk handling.
- Introduced DensityObserver to track compute and communication density for sparse attention layers.
- Updated the attention layer to ensure block selection is always called, improving robustness in first-chunk scenarios.
- Added tests for attention kernel behavior with enhanced input patterns.
This commit is contained in:
@@ -1,13 +1,22 @@
|
||||
"""
DensityObserver - observer that collects sparse attention density statistics.

Tracks the sparse attention density of every layer. Under causal attention
only the lower-triangular region is counted.

Two kinds of density are tracked:

1. Compute density: based on the BSA block size (128)
   - density = selected_bsa_blocks / total_causal_bsa_blocks
   - GPU-only and offload modes should report the same value

2. Communication density: based on the CPU block size (e.g. 4096)
   - comm_density = selected_cpu_blocks / total_cpu_blocks
   - Only used in offload mode; because the granularity is coarser, it is
     necessarily >= the compute density

Where the statistics are recorded:
- GPU-only: xattn_bsa.py compute_prefill() - records compute density only
- Offload: xattn_bsa.py select_blocks() - records both densities

Density computation in offload mode:
- It is not a plain avg or min
- It is sum(selected) / sum(total), which weights chunks of different
  sizes correctly
"""
|
||||
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
@@ -26,16 +35,26 @@ class DensityObserver(Observer):
|
||||
DensityObserver.complete_reset()
|
||||
# ... run inference ...
|
||||
DensityObserver.record(layer_id, mask, causal=True)
|
||||
# Or use the cumulative mode (offload):
|
||||
DensityObserver.record_counts(layer_id, selected, total)
|
||||
# ...
|
||||
DensityObserver.print_summary()
|
||||
"""
|
||||
|
||||
_enabled: bool = False # 默认禁用
|
||||
|
||||
# Per-layer compute density records (BSA block granularity)
# key: layer_id, value: list of density values (one per prefill chunk)
|
||||
_layer_densities: Dict[int, List[float]] = {}
|
||||
|
||||
# Per-layer communication density records (CPU block granularity, offload mode only)
|
||||
_layer_comm_densities: Dict[int, List[float]] = {}
|
||||
|
||||
# Cumulative mode: store selected/total counts (used in offload mode) so that,
# once all chunks are done, density = sum(selected) / sum(total) is computed correctly
|
||||
_layer_selected_counts: Dict[int, List[int]] = {}
|
||||
_layer_total_counts: Dict[int, List[int]] = {}
|
||||
|
||||
# Mask shape records (for debugging)
|
||||
_last_q_blocks: int = 0
|
||||
_last_k_blocks: int = 0
|
||||
@@ -56,7 +75,7 @@ class DensityObserver(Observer):
|
||||
causal: bool = True,
|
||||
) -> float:
|
||||
"""
|
||||
记录一层的 density。
|
||||
记录一层的 density (适用于 GPU-only 模式)。
|
||||
|
||||
Args:
|
||||
layer_id: 层 ID
|
||||
@@ -82,6 +101,72 @@ class DensityObserver(Observer):
|
||||
|
||||
return density
|
||||
|
||||
@classmethod
def record_counts(
    cls,
    layer_id: int,
    selected_blocks: int,
    total_blocks: int,
) -> None:
    """
    Append one chunk's selected/total block counts for a layer (offload cumulative mode).

    Raw counts are stored instead of a per-chunk ratio so that, once every
    chunk has been processed, the density can be computed as:

        overall_density = sum(selected) / sum(total)

    This is more accurate than avg(density) because the Q and K lengths
    differ from chunk to chunk.

    Nothing is recorded while the observer is disabled.

    Args:
        layer_id: Layer ID.
        selected_blocks: Number of blocks selected for this chunk.
        total_blocks: Total number of possible blocks for this chunk.
    """
    if cls._enabled:
        # setdefault creates the per-layer list the first time a layer is seen.
        cls._layer_selected_counts.setdefault(layer_id, []).append(selected_blocks)
        cls._layer_total_counts.setdefault(layer_id, []).append(total_blocks)
|
||||
|
||||
@classmethod
def record_comm_density(
    cls,
    layer_id: int,
    selected_cpu_blocks: int,
    total_cpu_blocks: int,
) -> float:
    """
    Record the communication density of one layer (CPU block granularity).

    Args:
        layer_id: Layer ID.
        selected_cpu_blocks: Number of CPU blocks that were selected.
        total_cpu_blocks: Total number of CPU blocks.

    Returns:
        The communication density selected_cpu_blocks / total_cpu_blocks.
        Returns 0.0 when the observer is disabled and 1.0 when
        total_cpu_blocks is 0; in both of those cases nothing is recorded.
    """
    if not cls._enabled:
        return 0.0

    # Guard against division by zero; an empty chunk is reported as fully dense.
    if total_cpu_blocks == 0:
        return 1.0

    ratio = selected_cpu_blocks / total_cpu_blocks
    # setdefault creates the per-layer list the first time a layer is seen.
    cls._layer_comm_densities.setdefault(layer_id, []).append(ratio)
    return ratio
|
||||
|
||||
@classmethod
|
||||
def _compute_density(cls, mask: torch.Tensor, causal: bool) -> float:
|
||||
"""计算 mask 的 density"""
|
||||
@@ -107,22 +192,63 @@ class DensityObserver(Observer):
|
||||
def complete_reset(cls) -> None:
    """Reset every recorded statistic to its initial state."""
    # Each per-layer record gets its own fresh, empty dict.
    for attr_name in (
        "_layer_densities",
        "_layer_comm_densities",
        "_layer_selected_counts",
        "_layer_total_counts",
    ):
        setattr(cls, attr_name, {})
    cls._last_q_blocks = cls._last_k_blocks = 0
    cls._mode = "unknown"
|
||||
|
||||
@classmethod
def get_per_layer_density(cls) -> Dict[int, float]:
    """
    Return the density of every layer.

    Cumulative mode (offload): density = sum(selected) / sum(total)
    Direct-record mode (gpu_only): density = avg(density_values)

    Cumulative counts take precedence: as soon as any selected/total counts
    have been recorded, only those are used and directly recorded density
    values are ignored. Layers whose total count is 0 (cumulative mode) or
    that have no recorded values (direct mode) are left out of the result.

    Returns:
        Mapping from layer_id to that layer's density.
    """
    # Cumulative mode (offload) takes precedence.
    if cls._layer_selected_counts:
        result = {}
        for layer_id, selected_list in cls._layer_selected_counts.items():
            total_total = sum(cls._layer_total_counts.get(layer_id, []))
            if total_total > 0:
                result[layer_id] = sum(selected_list) / total_total
        return result

    # Direct-record mode (gpu_only).
    return {
        layer_id: sum(densities) / len(densities)
        for layer_id, densities in cls._layer_densities.items()
        if densities
    }
|
||||
|
||||
@classmethod
|
||||
def get_overall_density(cls) -> float:
|
||||
"""获取所有层的平均 density"""
|
||||
"""
|
||||
获取所有层的总体 compute density。
|
||||
|
||||
对于累积模式 (offload): density = sum(all_selected) / sum(all_total)
|
||||
对于直接记录模式 (gpu_only): density = avg(all_density_values)
|
||||
|
||||
注意: 总体 density 不是简单的 avg(per_layer_density),
|
||||
而是 sum(all_selected) / sum(all_total),这样可以正确处理权重。
|
||||
"""
|
||||
# 优先使用累积模式 (offload)
|
||||
if cls._layer_selected_counts:
|
||||
total_selected = 0
|
||||
total_total = 0
|
||||
for layer_id in cls._layer_selected_counts:
|
||||
total_selected += sum(cls._layer_selected_counts[layer_id])
|
||||
total_total += sum(cls._layer_total_counts.get(layer_id, []))
|
||||
if total_total > 0:
|
||||
return total_selected / total_total
|
||||
return 0.0
|
||||
|
||||
# 直接记录模式 (gpu_only)
|
||||
all_densities = []
|
||||
for densities in cls._layer_densities.values():
|
||||
all_densities.extend(densities)
|
||||
@@ -130,6 +256,16 @@ class DensityObserver(Observer):
|
||||
return 0.0
|
||||
return sum(all_densities) / len(all_densities)
|
||||
|
||||
@classmethod
def get_overall_comm_density(cls) -> float:
    """Return the mean of all recorded communication density values across all layers.

    Returns 0.0 when no communication density has been recorded.
    """
    values = [
        value
        for layer_values in cls._layer_comm_densities.values()
        for value in layer_values
    ]
    return sum(values) / len(values) if values else 0.0
|
||||
|
||||
@classmethod
|
||||
def get_summary(cls) -> dict:
|
||||
"""返回统计摘要"""
|
||||
@@ -160,8 +296,13 @@ class DensityObserver(Observer):
|
||||
per_layer = cls.get_per_layer_density()
|
||||
overall = cls.get_overall_density()
|
||||
min_layer, min_density = cls.get_min_density()
|
||||
overall_comm = cls.get_overall_comm_density()
|
||||
|
||||
print(f"[DensityObserver] Mode: {cls._mode}")
|
||||
print(f" Overall density: {overall:.4f}")
|
||||
print(f" Min density: {min_density:.4f} (layer {min_layer})")
|
||||
print(f" Compute density: {overall:.4f} (min: {min_density:.4f} @ layer {min_layer})")
|
||||
if overall_comm > 0:
|
||||
print(f" Comm density: {overall_comm:.4f}")
|
||||
print(f" Num layers: {len(per_layer)}")
|
||||
# 输出 layer 0 的 density 用于对比
|
||||
if 0 in per_layer:
|
||||
print(f" Layer 0 density: {per_layer[0]:.6f}")
|
||||
|
||||
Reference in New Issue
Block a user