docs: reorganize documentation files
- Move notes.md to docs/development_notes.md
- Move Xattention_analysis.md to docs/xattention_analysis.md
- Delete DEBUG_SUMMARY.md (no longer needed)
- Update CLAUDE.md with documentation index entries

Co-Authored-By: Claude <noreply@anthropic.com>
@@ -61,6 +61,8 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py
| [`docs/gpu_only_performance_issue.md`](docs/gpu_only_performance_issue.md) | GPU-only mode slower than offload due to PagedAttention scatter overhead, optimization proposals |
| [`docs/offload_accuracy_issue.md`](docs/offload_accuracy_issue.md) | **BUG**: CPU offload mode 66% accuracy vs 100% non-offload on RULER NIAH benchmark |
| [`docs/64k_memory_analysis.md`](docs/64k_memory_analysis.md) | 64k inference memory analysis: GPU-only vs offload, OOM root cause (fragmentation), RTX 3090 limitations |
| [`docs/xattention_analysis.md`](docs/xattention_analysis.md) | XAttention algorithm analysis: chunked estimation, block sparse attention, integration design |
| [`docs/development_notes.md`](docs/development_notes.md) | Development notes and scratchpad for ongoing work |
## Configuration
103
DEBUG_SUMMARY.md
@@ -1,103 +0,0 @@
# Chunked Prefill Bug Debug Summary

## Problem

`test_needle.py --enable-offload --input-len 8192` fails with garbage output.

The model generates completely wrong tokens instead of the expected "7492".

## Investigation Progress

### 1. Stream Synchronization Fix (Completed)

- Replaced Triton `store_kvcache` kernel with pure PyTorch operations
- Moved `store_kvcache` to `compute_stream` in chunked prefill mode
- Added sync: `compute_stream.wait_event(offload_done)` after per-layer offload
- Added sync: `default_stream.wait_stream(compute_stream)` before return
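The sync pattern in the list above can be sketched in plain PyTorch. The function name, stream setup, and tensor shapes below are illustrative, not nano-vllm's actual API; the point is the event/stream ordering:

```python
import torch

def offload_layer_with_sync(kv_gpu: torch.Tensor, kv_cpu: torch.Tensor) -> None:
    """Illustrative per-layer offload sync (names are hypothetical).

    The offload stream copies KV to pinned CPU memory; `offload_done` makes
    the compute stream wait before reusing the GPU buffer; the final
    wait_stream re-joins the default stream before returning.
    """
    if not torch.cuda.is_available():
        kv_cpu.copy_(kv_gpu)  # CPU-only fallback: nothing to synchronize
        return

    compute_stream = torch.cuda.Stream()
    offload_stream = torch.cuda.Stream()
    offload_done = torch.cuda.Event()

    # The offload stream must see the fully written kv_gpu.
    offload_stream.wait_stream(torch.cuda.default_stream())
    with torch.cuda.stream(offload_stream):
        kv_cpu.copy_(kv_gpu, non_blocking=True)  # async D2H copy
        offload_done.record(offload_stream)

    # Compute may not overwrite kv_gpu until the copy finishes.
    compute_stream.wait_event(offload_done)
    # Re-join the default stream before returning.
    torch.cuda.default_stream().wait_stream(compute_stream)

device = "cuda" if torch.cuda.is_available() else "cpu"
kv_gpu = torch.randn(4, 8, device=device)
kv_cpu = torch.empty(4, 8, pin_memory=torch.cuda.is_available())
offload_layer_with_sync(kv_gpu, kv_cpu)
if torch.cuda.is_available():
    torch.cuda.synchronize()
```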
### 2. KV Cache Alignment Verification (Completed)

Created alignment tests to compare K/V tensors between the torch reference and nanovllm:

**RoPE Alignment:**
- RoPE implementations match perfectly (max_diff=0.002, cosine ~1.0)
- Confirmed RoPE is NOT the cause of the bug

**K/V Cache Alignment (Chunk 0):**
- Cosine similarity: ~1.0 for all layers
- Max diff: 2-7 (grows linearly with position, characteristic of FP16 precision)
- Mean diff: < 0.001
- **Conclusion: K/V cache offload is working correctly**
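A check of this kind fits in a few lines; `kv_alignment_report` is an illustrative helper (not code from the repo), but the metrics are the ones quoted above:

```python
import torch
import torch.nn.functional as F

def kv_alignment_report(ref: torch.Tensor, test: torch.Tensor) -> dict:
    """Cosine similarity, max diff, and mean diff between a reference
    K/V tensor and the offloaded copy (illustrative helper)."""
    a, b = ref.float().flatten(), test.float().flatten()
    cos = F.cosine_similarity(a, b, dim=0).item()
    diff = (a - b).abs()
    return {"cosine": cos, "max_diff": diff.max().item(), "mean_diff": diff.mean().item()}

# An FP16 round-trip reproduces the expected signature: cosine ~1.0 and a
# tiny mean diff, with max diff set by FP16 precision.
ref = torch.randn(8, 1024, 128)
report = kv_alignment_report(ref, ref.half().float())
```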
### 3. Layer Output Divergence Analysis (Completed)

Created a per-chunk layer output comparison:

**Chunk 0 (tokens 0-4096):**
- All layers pass with excellent cosine similarity (0.999+)
- Max diff grows in later layers but stays within an acceptable range

**Chunk 1 (tokens 4096-8192):**
- Layers 0-19: OK (cosine ~1.0)
- Layers 20-27: diverge (cosine 0.83-0.96, max_diff up to 114)
- Divergence correlates with the later transformer layers

### 4. Critical Discovery: Single-Chunk Offload Also Fails

**Key finding:** Even with input_len=2048 (a single chunk, no chunked attention), the model produces garbage output with CPU offload enabled.

```
# Without offload: PASSES
python tests/test_needle.py --input-len 2048
# Output: "7492" (correct)

# With offload: FAILS
python tests/test_needle.py --enable-offload --input-len 2048
# Output: "The Ble White Th G Lopsiswin..." (garbage)
```

**This proves the bug is NOT in:**
- Chunked attention logic (`merge_attention_outputs`)
- Multi-chunk KV loading
- Ring buffer pipeline

**The bug IS in:**
- The decode path when CPU offload is enabled
- How prefilled KV is loaded/used during decode

### 5. Decode Path Analysis (In Progress)

The decode path in CPU offload mode:
1. Prefill writes KV to GPU, offloads to CPU
2. Decode loads prefilled KV from CPU via `_decode_ring_buffer_pipeline`
3. Attend to prefilled KV + accumulated decode tokens
4. Merge results
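Step 4 can be illustrated with the standard log-sum-exp merge used to combine partial attention outputs from disjoint KV segments. `merge_attention_outputs` below is a sketch with an assumed signature, not the repo's actual function:

```python
import torch

def merge_attention_outputs(o1, lse1, o2, lse2):
    """Merge partial attention outputs over two disjoint KV segments using
    each segment's log-sum-exp (lse), as flash-attention-style kernels
    return it. Hypothetical signature, for illustration only."""
    m = torch.maximum(lse1, lse2)  # shared max for numerical stability
    w1 = torch.exp(lse1 - m)       # un-normalized weight of segment 1
    w2 = torch.exp(lse2 - m)
    denom = w1 + w2
    return o1 * (w1 / denom).unsqueeze(-1) + o2 * (w2 / denom).unsqueeze(-1)

# Check against dense attention over the concatenated KV.
d = 16
q, k, v = torch.randn(2, d), torch.randn(8, d), torch.randn(8, d)
s = q @ k.T / d**0.5
full = torch.softmax(s, dim=-1) @ v

def partial(sl):
    ssub = s[:, sl]
    return torch.softmax(ssub, dim=-1) @ v[sl], torch.logsumexp(ssub, dim=-1)

o1, lse1 = partial(slice(0, 4))
o2, lse2 = partial(slice(4, 8))
merged = merge_attention_outputs(o1, lse1, o2, lse2)
```

If the merge is correct, `merged` matches `full` up to floating-point error; a bug in this step is exactly the kind of thing the chunk-1 divergence above would surface.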
**Observations:**
- The `prefilled_blocks` set is empty after decode (it should contain block IDs)
- The CPU cache has valid data (reasonable mean/std values)
- The decode buffer contains zeros (decode tokens not being stored correctly?)

## Current Status

### Working
- Stream synchronization fixes
- K/V cache offload to CPU (verified alignment)
- RoPE implementation
- Chunked prefill attention for the first chunk

### Not Working
- Decode with CPU offload (even for single-chunk inputs)
- Multi-chunk attention (divergence in later layers for chunk 1)

## Next Steps

1. Debug why `prefilled_blocks` is empty after decode
2. Check whether the decode path correctly loads KV from CPU
3. Verify the decode buffer is being written correctly
4. Compare decode attention outputs between offload and non-offload modes

## Key Files

- `nanovllm/layers/attention.py` - Main attention implementation with chunked paths
- `nanovllm/kvcache/offload_engine.py` - CPU-GPU transfer engine
- `nanovllm/kvcache/hybrid_manager.py` - KV cache management with `prefilled_blocks`
- `nanovllm/engine/model_runner.py` - Prefill/decode orchestration

## Hypothesis

The decode path fails because:
1. `prefilled_blocks` is not being tracked correctly, causing `get_prefilled_cpu_blocks()` to return empty
2. OR the decode attention is not correctly loading/using the prefilled KV from CPU
3. OR there is a stream synchronization issue specific to the decode path
597
docs/xattention_analysis.md
Normal file
@@ -0,0 +1,597 @@
# COMPASS XAttention Implementation Analysis

**Analysis Date**: 2026-01-14
**Researcher**: Claude Code Agent
**Source**: `/home/zijie/Code/COMPASS/compass/src/`

---

## Executive Summary

COMPASS XAttention is a **block sparse attention** implementation that uses:
1. An **approximation phase** (`xattn_estimate`) to compute attention importance and select blocks
2. A **computation phase** (`Xattention_prefill`) to compute sparse attention using `block_sparse_attn_func`
3. **Triton kernels** for efficient block-wise GEMM and softmax operations

**Key Integration Constraint**: Requires `block_sparse_attn_func` from the flash-attention library, a **C++ CUDA extension** that must be compiled separately.

---

## 1. Function: `xattn_estimate()`

**Purpose**: Estimate attention importance and select which blocks to compute

### Input Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `query_states` | Tensor | - | Shape: `(batch, num_heads, q_len, head_dim)` |
| `key_states` | Tensor | - | Shape: `(batch, num_kv_heads, k_len, head_dim)` |
| `block_size` | int | - | Size of attention blocks (typically 128) |
| `stride` | int | - | Downsampling stride for approximation |
| `norm` | float | 1 | Normalization factor for attention scaling |
| `softmax` | bool | True | Whether to apply softmax in estimation |
| `threshold` | float | 0.9 | Block selection threshold (0-1) |
| `chunk_size` | int | 16384 | Processing chunk size |
| `select_mode` | str | "inverse" | Pattern selection mode |
| `use_triton` | bool | True | Use Triton kernels (requires SM 80+) |
| `causal` | bool | True | Apply causal masking |
| `kdb` | int | 1 | Key downsampling factor |
| `keep_sink` | bool | False | Always attend to the first token |
| `keep_recent` | bool | False | Always attend to recent tokens |

### Output

```python
returns: (attn_sums, simple_masks)

attn_sums: Tensor[float32]
    Shape: (batch, num_heads, num_q_blocks, num_k_blocks_per_chunk)
    Contains aggregated attention weights per block

simple_masks: Tensor[bool]
    Shape: (batch, num_heads, num_q_blocks, num_k_blocks)
    Boolean mask indicating which blocks to compute
```

### Algorithm

#### Step 1: Padding and Chunking

```python
# Pad sequences up to chunk_size boundaries
k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len

# Compute the number of blocks and chunks
k_chunk_num = (k_len + k_num_to_pad) // chunk_size
k_block_num = (k_len + k_num_to_pad) // block_size
q_chunk_num = (q_len + q_num_to_pad) // chunk_size
q_block_num = (q_len + q_num_to_pad) // block_size
```
#### Step 2: Pattern Selection (stride-based downsampling)

**Purpose**: Reduce computation by a factor of `stride` using patterned selection

**Modes**:

1. **`"inverse"`** (default): Inverse stride pattern
   ```python
   # Key: regular stride offsets [0, stride, 2*stride, ...]
   # Query: reversed stride offsets [stride-1, stride-2, ..., 0]
   reshaped_key = torch.cat([key_states[:, :, k::stride, :] for k in range(stride)])
   reshaped_query = torch.cat([query_states[:, :, (stride-1-q)::stride*kdb, :] for q in range(stride)])
   ```

2. **`"slash"`**: Slash pattern (diagonal)
   ```python
   # Both use the regular stride
   reshaped_key = torch.cat([key_states[:, :, k::stride, :] for k in range(stride)])
   reshaped_query = torch.cat([query_states[:, :, q::stride, :] for q in range(stride)])
   ```

3. **`"random"`**: Random permutation
4. **`"double"`, `"triple"`**: Data augmentation modes

#### Step 3: Chunk-wise Attention Estimation

For each query chunk:

**If `use_triton=True`** (fast path):
```python
# Triton kernel 1: compute attention scores with fused reshape
attn_weights_slice = flat_group_gemm_fuse_reshape(
    query_chunk, key_states, stride,
    chunk_start, chunk_end, is_causal=causal
)

# Triton kernel 2: softmax + block aggregation
attn_sum = softmax_fuse_block_sum(
    attn_weights_slice, reshaped_block_size, segment_size,
    chunk_start, chunk_end, real_q_len, scale, is_causal
)
```

**If `use_triton=False`** (PyTorch fallback):
```python
# Standard matrix multiplication
attn_weights_slice = torch.matmul(chunked_query, reshaped_key.transpose(2, 3))

# Scale and apply the causal mask
attn_weights_slice = attn_weights_slice / sqrt(head_dim) / stride / norm
attn_weights_slice = attn_weights_slice + causal_mask

# Softmax
attn_weights_slice = F.softmax(attn_weights_slice, dim=-1)

# Aggregate to block level
attn_sum = attn_weights_slice.view(
    batch, heads, num_blocks_per_chunk, block_size//kdb, -1, block_size
).sum(dim=-1).sum(dim=-2)
```

#### Step 4: Block Selection

```python
# Select blocks based on the threshold
simple_mask = find_blocks_chunked(
    attn_sum,
    current_index,  # Starting block index
    threshold,      # 0.9 = select blocks covering 90% of attention mass
    None,           # or num_to_choose for top-k selection
    decoding=False,
    mode="prefill",
    causal=True
)
```

**Selection Algorithm** (`find_blocks_chunked`):
1. Sort blocks by attention weight (descending)
2. Compute the cumulative sum
3. Select blocks until `cumulative_sum >= total_sum * threshold`
4. Enforce causal constraints (no future blocks)
5. Always include the sink token (first block) if `keep_sink=True`
6. Always include diagonal blocks if `keep_recent=True`
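Steps 1-3 of the selection algorithm can be sketched in a few lines of PyTorch. This is a simplified stand-in for `find_blocks_chunked`, not its actual code; the causal/sink/recent handling of steps 4-6 is omitted:

```python
import torch

def select_blocks_by_threshold(attn_sum: torch.Tensor, threshold: float) -> torch.Tensor:
    """Sort block weights descending, cumulate, and keep blocks until the
    kept mass reaches `threshold` of the total (simplified sketch)."""
    sorted_w, order = attn_sum.sort(dim=-1, descending=True)
    csum = sorted_w.cumsum(dim=-1)
    total = attn_sum.sum(dim=-1, keepdim=True)
    # Keep a block if the mass accumulated *before* it is still below the
    # threshold, so the block that crosses the threshold is included too.
    keep_sorted = (csum - sorted_w) < threshold * total
    mask = torch.zeros_like(attn_sum, dtype=torch.bool)
    mask.scatter_(-1, order, keep_sorted)
    return mask

# Four blocks holding 40/30/20/10% of the mass; threshold 0.6 keeps the top two.
mask = select_blocks_by_threshold(torch.tensor([0.4, 0.3, 0.2, 0.1]), 0.6)
```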
---

## 2. Function: `Xattention_prefill()`

**Purpose**: Compute sparse attention using the estimated block mask

### Input Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `query_states` | Tensor | - | `(batch, num_heads, q_len, head_dim)` |
| `key_states` | Tensor | - | `(batch, num_heads, k_len, head_dim)` |
| `value_states` | Tensor | - | `(batch, num_heads, k_len, head_dim)` |
| `stride` | int | - | Downsampling stride for estimation |
| `norm` | float | 1 | Normalization factor |
| `threshold` | float | 0.8 | Block selection threshold |
| `block_size` | int | 128 | **MUST be 128** (hardcoded requirement) |
| `use_triton` | bool | True | Use Triton kernels in estimation |
| `causal` | bool | True | Apply causal masking |
| `kdb` | int | 1 | Key downsampling factor |
| `chunk_size` | int | None | Auto-computed if None |
| `keep_sink` | bool | False | Always attend to the first token |
| `keep_recent` | bool | False | Always attend to recent tokens |

### Output

```python
returns: attn_output

attn_output: Tensor
    Shape: (batch, num_heads, q_len, head_dim)
    Sparse attention output
```

### Algorithm Flow

#### Step 1: Auto-compute chunk_size
```python
if chunk_size is None:
    chunk_size = int(max(
        min(
            max(2048, 1 << (k_len - 1).bit_length()),               # Round up to a power of 2
            128 * 1024 * 2048 // (1 << (k_len - 1).bit_length()),   # Memory constraint
        ),
        2048,  # Minimum
    ))
```

**Example**:
- `k_len=8192` → `chunk_size=8192`
- `k_len=32768` → `chunk_size=16384`
- `k_len=65536` → `chunk_size=16384`

#### Step 2: Estimate attention and select blocks
```python
attn_sums, approx_simple_mask = xattn_estimate(
    query_states, key_states,
    block_size=block_size, stride=stride, norm=norm,
    threshold=threshold, select_mode="inverse",
    use_triton=use_triton, causal=causal,
    chunk_size=chunk_size, kdb=kdb,
    keep_sink=keep_sink, keep_recent=keep_recent
)
```

#### Step 3: Prepare inputs for block_sparse_attn_func
```python
# Hard constraints
assert block_size == 128
assert batch_size == 1

# Reshape to (seq_len, num_heads, head_dim)
query_states = query_states.transpose(1, 2).view(q_len, num_heads, head_dim)
key_states = key_states.transpose(1, 2).view(k_len, num_heads, head_dim)
value_states = value_states.transpose(1, 2).view(k_len, num_heads, head_dim)

# Cumulative sequence lengths
q_cu_seq_lens = torch.tensor([0, q_len], dtype=torch.int32, device=device)
k_cu_seq_lens = torch.tensor([0, k_len], dtype=torch.int32, device=device)

# Head mask type (all heads use the block mask)
head_mask_type = torch.tensor([1 for _ in range(num_heads)], dtype=torch.int32)
```

#### Step 4: Call block_sparse_attn_func
```python
attn_output = block_sparse_attn_func(
    query_states,    # (q_len, num_heads, head_dim)
    key_states,      # (k_len, num_heads, head_dim)
    value_states,    # (k_len, num_heads, head_dim)
    q_cu_seq_lens,   # [0, q_len]
    k_cu_seq_lens,   # [0, k_len]
    head_mask_type,  # [1, 1, ..., 1]
    None,            # No custom layout
    approx_simple_mask[:, :, :q_block_num, :k_block_num].contiguous(),  # Block mask
    q_len,
    k_len,
    p_dropout=0.0,
    deterministic=True,
    is_causal=causal
)
```

#### Step 5: Reshape output
```python
attn_output = attn_output.view(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
# Output shape: (batch, num_heads, q_len, head_dim)
```

---
## 3. Triton Kernel Dependencies

### Kernel 1: `flat_group_gemm_fuse_reshape_kernel`

**Purpose**: Compute QK^T with stride-based reshaping

**Key Features**:
- Loads `stride` keys and queries at once
- Fused strided access pattern
- Causal masking support
- Block size auto-selected based on GPU memory

**Block Size Selection**:
```python
# RTX 3090 (<30GB): BLOCK_M=64, BLOCK_N=64
# A100/H100 (>=30GB): BLOCK_M=128, BLOCK_N=128
```

**Signature**:
```python
flat_group_gemm_fuse_reshape(
    query_states,  # (batch, heads, q_len, head_dim)
    key_states,    # (batch, heads, k_len, head_dim)
    stride,        # Downsampling factor
    chunk_start,   # Start position in keys
    chunk_end,     # End position in keys
    is_causal=True
)
# Returns: (batch, heads, q_len//stride, k_len//stride)
```

### Kernel 2: `softmax_fuse_block_sum_kernel_causal` / `_non_causal`

**Purpose**: Online softmax with block aggregation

**Algorithm**:
1. **Forward pass** (accumulate the running max `m_i` and normalizer `l_i`):
   ```
   m_new = max(m_i, m_local)
   alpha = exp(m_i - m_new)
   l_i = l_i * alpha + sum(exp(X - m_new))
   m_i = m_new
   ```
2. **Backward pass** (compute softmax with scaling):
   ```
   softmax = exp(X - m_i) / l_i
   aggregate to blocks: sum(softmax) over block_size
   ```
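The two passes can be illustrated with a pure-PyTorch (non-Triton) sketch that walks the key axis segment by segment, where segments stand in for the kernel's tiles. This is a reference for the math, not the kernel itself:

```python
import torch

def online_softmax_block_sum(scores: torch.Tensor, block_size: int) -> torch.Tensor:
    """Streaming softmax over the key axis, then per-block sums
    (illustrative reference for the Triton kernel's two passes)."""
    m = torch.full(scores.shape[:-1], float("-inf"))  # running max
    l = torch.zeros(scores.shape[:-1])                # running normalizer
    # Pass 1: accumulate m and l one segment at a time.
    for seg in scores.split(block_size, dim=-1):
        m_new = torch.maximum(m, seg.max(dim=-1).values)
        l = l * torch.exp(m - m_new) + torch.exp(seg - m_new.unsqueeze(-1)).sum(dim=-1)
        m = m_new
    # Pass 2: normalized softmax, aggregated per key block.
    block_sums = [
        (torch.exp(seg - m.unsqueeze(-1)) / l.unsqueeze(-1)).sum(dim=-1)
        for seg in scores.split(block_size, dim=-1)
    ]
    return torch.stack(block_sums, dim=-1)  # (..., num_k_blocks)

scores = torch.randn(2, 8)
block_sums = online_softmax_block_sum(scores, block_size=4)
```

The result matches a plain softmax followed by block summation, while only holding one segment of scores at a time, which is what lets the kernel avoid materializing the full attention matrix.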
**Key Features**:
- Single-pass softmax (never materializes the full attention matrix)
- Causal masking integrated
- Outputs block-level sums directly

**Signature**:
```python
softmax_fuse_block_sum(
    attn_weights_slice,   # (batch, heads, q_len, k_len)
    reshaped_block_size,  # Block size (128//stride)
    segment_size,         # Processing segment (min(4096, block_size))
    chunk_start,          # Start position
    chunk_end,            # End position
    real_q_len,           # Actual query length (before padding)
    scale,                # 1.4426950408889634 / sqrt(head_dim) / stride / norm
    is_causal=True
)
# Returns: (batch, heads, q_len//block_size, k_len//block_size)
```

---

## 4. Key Parameters and Their Meanings

### Critical Parameters

| Parameter | Meaning | Typical Value | Impact |
|-----------|---------|---------------|--------|
| `block_size` | Block granularity | 128 | **Fixed at 128**, affects mask granularity |
| `stride` | Downsampling factor | 4-16 | Higher = faster but less accurate |
| `threshold` | Sparsity level | 0.8-0.9 | Higher = denser mask, more computation |
| `chunk_size` | Processing chunk | 16384 | Affects memory and efficiency |
| `kdb` | Key downsampling boost | 1 | Experimental, use 1 |
| `norm` | Scaling factor | 1.0 | Attention temperature control |

### Trade-offs

**Stride (`stride`)**:
- `stride=1`: No approximation, same cost as dense attention
- `stride=4`: 4x faster estimation, good accuracy
- `stride=8`: 8x faster, moderate accuracy loss
- `stride=16`: 16x faster, significant accuracy loss

**Threshold (`threshold`)**:
- `threshold=0.8`: Select blocks covering 80% of attention mass (~20% sparsity)
- `threshold=0.9`: Select blocks covering 90% of attention mass (~10% sparsity)
- `threshold=0.95`: Very dense, only prunes ~5% of blocks

---
## 5. Dependencies

### Required Libraries

1. **`block_sparse_attn`** (CRITICAL)
   - Source: `/home/zijie/Code/COMPASS/3rdparty/flash-attention/`
   - Function: `block_sparse_attn_func`
   - Type: **C++ CUDA extension**
   - Build: requires compilation with `torch.utils.cpp_extension`

2. **Triton** (optional but recommended)
   - Required for: `use_triton=True`
   - GPU requirement: SM 80+ (A100, RTX 3090, H100, etc.)
   - Check: `torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8`

3. **PyTorch**
   - Version: compatible with flash-attention
   - Features used: `F.pad`, `matmul`, `softmax`, `view`, `transpose`

### Dependency Tree

```
Xattention_prefill
├── xattn_estimate
│   ├── flat_group_gemm_fuse_reshape (Triton)
│   ├── softmax_fuse_block_sum (Triton)
│   └── find_blocks_chunked (PyTorch)
└── block_sparse_attn_func (C++ CUDA)
```
---

## 6. Integration Issues for nano-vllm

### Critical Issue 1: `block_sparse_attn_func` Dependency

**Problem**: `block_sparse_attn_func` is a **C++ CUDA extension** that must be compiled from the flash-attention source.

**Options**:
1. **Compile flash-attention with block sparse support**
   ```bash
   cd /home/zijie/Code/COMPASS/3rdparty/flash-attention
   python setup.py install
   ```
   - Risk: may conflict with an existing flash-attention installation
   - Complexity: high (C++ compilation)

2. **Replace with FlashInfer block sparse**
   - FlashInfer is already a dependency
   - Has similar block sparse attention
   - Need to adapt the interface

3. **Custom CUDA kernel**
   - Implement a simplified block sparse attention
   - High development cost
   - Maintenance burden

### Critical Issue 2: Hard-coded Constraints

```python
assert block_size == 128  # Line 358
assert batch_size == 1    # Line 359
```

**Impact**:
- Cannot process multiple sequences in one batch
- Fixed block size limits flexibility
- Must work around these constraints

### Critical Issue 3: Triton GPU Requirement

```python
props = torch.cuda.get_device_properties(torch.cuda.current_device())
if props.major < 8:
    use_triton = False
```

**Impact**:
- Triton kernels only work on SM 80+ (A100, RTX 3090, H100)
- Older GPUs (V100, T4, RTX 2080) fall back to the slow PyTorch implementation
- RTX 3090 works but uses smaller Triton block sizes (64 vs 128)

### Issue 4: Memory Layout

**XAttention expects**:
```python
query_states: (batch, num_heads, q_len, head_dim)
```

**nano-vllm uses**:
```python
query_states: (num_heads, total_tokens, head_dim)  # Flattened batch
```

**Required**: Transpose and reshape before/after calling XAttention
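A minimal sketch of that adaptation (helper names illustrative; with `batch_size=1` the round trip is just adding and dropping a batch dimension):

```python
import torch

def to_xattention_layout(x: torch.Tensor) -> torch.Tensor:
    # (num_heads, tokens, head_dim) -> (1, num_heads, tokens, head_dim)
    return x.unsqueeze(0)

def from_xattention_layout(x: torch.Tensor) -> torch.Tensor:
    # (1, num_heads, tokens, head_dim) -> (num_heads, tokens, head_dim)
    return x.squeeze(0)

q_nanovllm = torch.randn(8, 1024, 128)
q_xattn = to_xattention_layout(q_nanovllm)
roundtrip = from_xattention_layout(q_xattn)
```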
### Issue 5: Chunking Incompatibility

**XAttention**: processes fixed-size chunks (e.g., 16384 tokens)
- Requires padding to chunk boundaries
- Adds overhead for short sequences

**nano-vllm**: processes variable-length requests
- No padding requirement
- Dynamic batch sizing

---

## 7. Integration Strategy

### Recommended Approach: **Wrapper with FlashInfer**

1. **Keep `xattn_estimate`** (pure PyTorch + Triton)
   - No external dependencies
   - Computes the block mask

2. **Replace `block_sparse_attn_func` with FlashInfer**
   - FlashInfer: `flashinfer.single_prefill_with_kv_cache`
   - Similar API, already compiled
   - Supports block sparsity

3. **Adapt the mask format**
   - XAttention: `(batch, heads, q_blocks, k_blocks)` boolean mask
   - FlashInfer: `(num_qo, num_kv)` boolean mask or custom format

4. **Handle constraints**
   - Enforce `batch_size=1` by processing one request at a time
   - Keep `block_size=128` as a requirement

### Alternative: **Pure PyTorch Implementation**

1. Extract the estimation algorithm
2. Implement sparse attention using PyTorch operations
3. Use FlashInfer for the final computation
4. No Triton dependency

---
## 8. Code Example: Adaptation

```python
def xattention_prefill_adapted(
    query_states,   # (num_heads, q_len, head_dim)
    key_states,     # (num_heads, k_len, head_dim)
    value_states,   # (num_heads, k_len, head_dim)
    stride=4,
    threshold=0.9,
    block_size=128,
    causal=True,
):
    # Step 1: Add a batch dimension
    q = query_states.unsqueeze(0)  # (1, heads, q_len, dim)
    k = key_states.unsqueeze(0)
    v = value_states.unsqueeze(0)

    # Step 2: Estimate the mask (no external dependency)
    _, block_mask = xattn_estimate(
        q, k,
        block_size=block_size,
        stride=stride,
        threshold=threshold,
        use_triton=True,
        causal=causal,
    )
    # block_mask: (1, heads, q_blocks, k_blocks)

    # Step 3: Convert the block mask to a token-level mask
    token_mask = block_mask.repeat_interleave(block_size, dim=-2)
    token_mask = token_mask.repeat_interleave(block_size, dim=-1)
    token_mask = token_mask[:, :, :q.size(2), :k.size(2)]  # Trim padding

    # Step 4: Use FlashInfer with the mask
    from flashinfer import single_prefill_with_kv_cache
    output = single_prefill_with_kv_cache(
        q.squeeze(0),
        k.squeeze(0),
        v.squeeze(0),
        custom_mask=token_mask.squeeze(0),
    )

    return output  # (num_heads, q_len, head_dim)
```
---

## 9. Summary of Findings

### Advantages

1. **Accurate approximation**: pattern-based stride selection preserves attention patterns
2. **Flexible sparsity**: threshold-based control over computation
3. **GPU optimization**: Triton kernels for the estimation phase
4. **Proven in practice**: used in the COMPASS system

### Challenges

1. **Hard dependency**: `block_sparse_attn_func` requires C++ compilation
2. **Rigid constraints**: `block_size=128`, `batch_size=1`
3. **GPU-specific**: Triton only on SM 80+
4. **Memory layout mismatch**: requires reshape/transpose
5. **Chunking overhead**: padding to chunk boundaries

### Integration Complexity

| Component | Complexity | Risk |
|-----------|------------|------|
| `xattn_estimate` | Medium | Low (PyTorch + Triton) |
| `block_sparse_attn_func` | High | **Critical** (C++ dependency) |
| Interface adaptation | Low | Low (reshape) |
| Constraint handling | Medium | Medium (workarounds) |

**Overall Integration Risk**: **HIGH** (due to the C++ dependency)

---

## 10. Next Steps

1. **Evaluate FlashInfer compatibility**
   - Can FlashInfer replace `block_sparse_attn_func`?
   - What mask format does it expect?

2. **Prototype the estimation phase**
   - Extract the `xattn_estimate` function
   - Test with nano-vllm inputs
   - Validate mask quality

3. **Benchmark the Triton kernels**
   - Compare Triton vs PyTorch estimation
   - Measure speedup on RTX 3090
   - Profile memory usage

4. **Design the interface**
   - Define nano-vllm's sparse attention API
   - Specify the mask format
   - Plan integration points