Compare commits
4 Commits
aa953ecb59...6575099a06

| Author | SHA1 | Date | |
|---|---|---|---|
| | 6575099a06 | | |
| | 8fd25d72d7 | | |
| | ccf27d3a74 | | |
| | 0ad86eb449 | | |
17
CLAUDE.md
17
CLAUDE.md
@@ -46,6 +46,17 @@ python bench_offload.py
|
|||||||
|
|
||||||
## Local Package Installation for Multi-Instance
|
## Local Package Installation for Multi-Instance
|
||||||
|
|
||||||
|
**CRITICAL**: After ANY code modification in the `nanovllm/` directory, you MUST reinstall the package before running tests or benchmarks:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e . --prefix=./.local --no-deps
|
||||||
|
```
|
||||||
|
|
||||||
|
Then run with PYTHONPATH:
|
||||||
|
```bash
|
||||||
|
PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH python <script.py>
|
||||||
|
```
|
||||||
|
|
||||||
**IMPORTANT**: When running multiple Claude instances on different worktrees, do NOT use `pip install -e .` globally as it will affect other instances. Instead, use local installation:
|
**IMPORTANT**: When running multiple Claude instances on different worktrees, do NOT use `pip install -e .` globally as it will affect other instances. Instead, use local installation:
|
||||||
|
|
||||||
1. **Install to worktree-local directory**:
|
1. **Install to worktree-local directory**:
|
||||||
@@ -66,6 +77,12 @@ python bench_offload.py
|
|||||||
|
|
||||||
**Note**: The Python version in the path (python3.10) should match your environment.
|
**Note**: The Python version in the path (python3.10) should match your environment.
|
||||||
|
|
||||||
|
**CRITICAL**: After making code changes to `nanovllm/` source files, you MUST reinstall the package for changes to take effect:
|
||||||
|
```bash
|
||||||
|
pip install -e . --prefix=./.local --no-deps
|
||||||
|
```
|
||||||
|
Without reinstallation, Python will use the old cached version and your changes will NOT be reflected!
|
||||||
|
|
||||||
## Sparse Attention
|
## Sparse Attention
|
||||||
|
|
||||||
For sparse attention related content (block sparse attention, MInference, FlexPrefill, XAttention, AvgPool, etc.), refer to [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md).
|
For sparse attention related content (block sparse attention, MInference, FlexPrefill, XAttention, AvgPool, etc.), refer to [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md).
|
||||||
|
|||||||
@@ -455,8 +455,6 @@ class ModelRunner:
|
|||||||
3. After each chunk, offload from ring buffer slot to CPU
|
3. After each chunk, offload from ring buffer slot to CPU
|
||||||
4. All N-1 other slots are used to load previous chunks for attention
|
4. All N-1 other slots are used to load previous chunks for attention
|
||||||
"""
|
"""
|
||||||
import sys
|
|
||||||
|
|
||||||
assert len(seqs) == 1, "Ring buffer prefill only supports single sequence"
|
assert len(seqs) == 1, "Ring buffer prefill only supports single sequence"
|
||||||
seq = seqs[0]
|
seq = seqs[0]
|
||||||
|
|
||||||
@@ -466,10 +464,9 @@ class ModelRunner:
|
|||||||
|
|
||||||
total_tokens = len(seq)
|
total_tokens = len(seq)
|
||||||
num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk
|
num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk
|
||||||
print(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, "
|
logger.debug(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, "
|
||||||
f"ring_slots={offload_engine.num_ring_slots}, chunk={tokens_per_chunk} tokens, "
|
f"ring_slots={offload_engine.num_ring_slots}, chunk={tokens_per_chunk} tokens, "
|
||||||
f"total_chunks={num_chunks}",
|
f"total_chunks={num_chunks}")
|
||||||
file=sys.stderr)
|
|
||||||
|
|
||||||
chunk_idx = 0
|
chunk_idx = 0
|
||||||
logits = None
|
logits = None
|
||||||
@@ -488,9 +485,8 @@ class ModelRunner:
|
|||||||
# CPU block index for this chunk
|
# CPU block index for this chunk
|
||||||
block_idx = chunk_idx
|
block_idx = chunk_idx
|
||||||
|
|
||||||
print(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, "
|
logger.debug(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, "
|
||||||
f"write_slot={write_slot}",
|
f"write_slot={write_slot}")
|
||||||
file=sys.stderr)
|
|
||||||
|
|
||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
input_ids, positions = self._prepare_chunked_offload_chunk(
|
input_ids, positions = self._prepare_chunked_offload_chunk(
|
||||||
@@ -509,27 +505,17 @@ class ModelRunner:
|
|||||||
logical_id = seq.block_table[block_idx]
|
logical_id = seq.block_table[block_idx]
|
||||||
self.kvcache_manager.prefilled_blocks.add(logical_id)
|
self.kvcache_manager.prefilled_blocks.add(logical_id)
|
||||||
|
|
||||||
# NOTE: Per-layer offloading is now done in attention.forward
|
# NOTE: Per-layer async offloading is now done in attention.forward
|
||||||
# Each layer offloads its KV to CPU immediately after computing attention.
|
# Each layer offloads from its own prefill buffer - no waiting required!
|
||||||
# We just need to wait for the last offload to complete before reusing the slot.
|
# The sparse policy hook is called in offload_prefill_buffer_async.
|
||||||
if block_idx < len(cpu_block_ids):
|
|
||||||
# TODO: Sparse policy hook needs update for new GPU cache architecture
|
|
||||||
# The GPU cache no longer has layer dimension, so we can't access
|
|
||||||
# k_cache_gpu[layer_id, write_slot]. Sparse policy should be called
|
|
||||||
# in attention.forward after per-layer offload.
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Wait for offload to complete before next chunk
|
|
||||||
# (slot will be reused after N chunks)
|
|
||||||
offload_engine.wait_slot_offload(write_slot)
|
|
||||||
|
|
||||||
processed_tokens = chunk_end
|
processed_tokens = chunk_end
|
||||||
chunk_idx += 1
|
chunk_idx += 1
|
||||||
|
|
||||||
# Wait for all offloads to complete
|
# Wait for all async prefill offloads to complete
|
||||||
offload_engine.wait_all_offload_done()
|
offload_engine.wait_all_prefill_offloads()
|
||||||
|
|
||||||
print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr)
|
logger.debug(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks")
|
||||||
|
|
||||||
# Sample from last logits
|
# Sample from last logits
|
||||||
# For chunked prefill, ParallelLMHead automatically selects last position's logits
|
# For chunked prefill, ParallelLMHead automatically selects last position's logits
|
||||||
@@ -590,14 +576,15 @@ class ModelRunner:
|
|||||||
|
|
||||||
def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
|
def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
|
||||||
"""
|
"""
|
||||||
Run decode with ring buffer (CPU is primary storage).
|
Run decode with cross-layer pipeline (CPU is primary storage).
|
||||||
|
|
||||||
All KV is on CPU. Uses decode_slot (slot[0]) to write new KV.
|
All KV is on CPU. Uses decode_slot (slot[0]) to write new KV.
|
||||||
Other slots (slots[1:]) are used to load previous KV chunks via pipeline.
|
Optimized with cross-layer pipeline: Layer N's data is loaded while
|
||||||
New token's KV is written to decode_slot then offloaded to CPU only when block is full.
|
Layer N-1 computes, achieving transfer/compute overlap.
|
||||||
|
|
||||||
Key: decode_slot is dedicated to writing new KV, never used for loading.
|
Key: decode_slot is dedicated to writing new KV, never used for loading.
|
||||||
Optimization: Batch offloads - only offload when block is full, attend to all accumulated tokens.
|
Optimization: Cross-layer pipeline reduces effective latency by overlapping
|
||||||
|
H2D transfers with attention computation across layers.
|
||||||
"""
|
"""
|
||||||
assert len(seqs) == 1, "Ring buffer decode only supports single sequence"
|
assert len(seqs) == 1, "Ring buffer decode only supports single sequence"
|
||||||
seq = seqs[0]
|
seq = seqs[0]
|
||||||
@@ -618,6 +605,12 @@ class ModelRunner:
|
|||||||
# Get decode start position for accumulated token tracking
|
# Get decode start position for accumulated token tracking
|
||||||
decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)
|
decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)
|
||||||
|
|
||||||
|
# Get prefilled CPU blocks for pipeline initialization
|
||||||
|
cpu_block_table = self.kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||||
|
|
||||||
|
# Start cross-layer pipeline (preloads Layer 0's data)
|
||||||
|
offload_engine.start_decode_pipeline(cpu_block_table)
|
||||||
|
|
||||||
# Set up context for chunked decode
|
# Set up context for chunked decode
|
||||||
set_context(
|
set_context(
|
||||||
is_prefill=False,
|
is_prefill=False,
|
||||||
@@ -634,6 +627,9 @@ class ModelRunner:
|
|||||||
logits = self.run_model(input_ids, positions, is_prefill=False)
|
logits = self.run_model(input_ids, positions, is_prefill=False)
|
||||||
reset_context()
|
reset_context()
|
||||||
|
|
||||||
|
# End cross-layer pipeline
|
||||||
|
offload_engine.end_decode_pipeline()
|
||||||
|
|
||||||
# Only offload when block is full (pos_in_block == block_size - 1)
|
# Only offload when block is full (pos_in_block == block_size - 1)
|
||||||
# This avoids unnecessary offloading on every decode step
|
# This avoids unnecessary offloading on every decode step
|
||||||
if pos_in_block == self.block_size - 1:
|
if pos_in_block == self.block_size - 1:
|
||||||
|
|||||||
@@ -40,14 +40,13 @@ class OffloadEngine:
|
|||||||
High-performance CPU-GPU async transfer engine for KV cache offloading.
|
High-performance CPU-GPU async transfer engine for KV cache offloading.
|
||||||
|
|
||||||
Memory layout:
|
Memory layout:
|
||||||
- GPU cache: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
|
- GPU cache: [num_gpu_blocks, block_size, kv_heads, head_dim] (no layer dimension)
|
||||||
- CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] (pinned)
|
- CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] (pinned)
|
||||||
- Gather indices: [num_layers, num_gpu_blocks] (fixed address, variable content)
|
|
||||||
|
|
||||||
CUDA Graph compatibility:
|
Features:
|
||||||
- gathered_h2d_layer() can be captured into CUDA graphs
|
- Unified ring buffer for chunked prefill/decode
|
||||||
- update_gather_indices() is called outside graphs to prepare indices
|
- Per-layer prefill buffer for async offload
|
||||||
- All tensor addresses remain fixed across graph replays
|
- Cross-layer pipeline for decode with double-buffering
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -142,6 +141,64 @@ class OffloadEngine:
|
|||||||
decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
|
decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
|
||||||
logger.info(f" Per-layer decode buffer: {decode_buf_mb:.1f} MB")
|
logger.info(f" Per-layer decode buffer: {decode_buf_mb:.1f} MB")
|
||||||
|
|
||||||
|
# ========== Cross-layer pipeline buffers for decode ==========
|
||||||
|
# Double-buffered layer cache for pipelined decode:
|
||||||
|
# - Buffer A: Current layer's prefilled KV being computed
|
||||||
|
# - Buffer B: Next layer's prefilled KV being loaded
|
||||||
|
# Shape: [max_prefill_blocks, block_size, kv_heads, head_dim]
|
||||||
|
# Memory: 4 * max_prefill_blocks * block_size * kv_heads * head_dim * dtype_size (K and V for each of buffers A and B)
|
||||||
|
max_prefill_blocks = num_cpu_blocks # Can hold all prefill blocks
|
||||||
|
self.layer_k_buffer_a = torch.zeros(
|
||||||
|
max_prefill_blocks, block_size, num_kv_heads, head_dim,
|
||||||
|
dtype=dtype, device="cuda"
|
||||||
|
)
|
||||||
|
self.layer_v_buffer_a = torch.zeros(
|
||||||
|
max_prefill_blocks, block_size, num_kv_heads, head_dim,
|
||||||
|
dtype=dtype, device="cuda"
|
||||||
|
)
|
||||||
|
self.layer_k_buffer_b = torch.zeros(
|
||||||
|
max_prefill_blocks, block_size, num_kv_heads, head_dim,
|
||||||
|
dtype=dtype, device="cuda"
|
||||||
|
)
|
||||||
|
self.layer_v_buffer_b = torch.zeros(
|
||||||
|
max_prefill_blocks, block_size, num_kv_heads, head_dim,
|
||||||
|
dtype=dtype, device="cuda"
|
||||||
|
)
|
||||||
|
layer_buf_mb = 4 * max_prefill_blocks * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
|
||||||
|
logger.info(f" Cross-layer pipeline buffers: {layer_buf_mb:.1f} MB ({max_prefill_blocks} blocks × 2)")
|
||||||
|
|
||||||
|
# Pipeline state tracking
|
||||||
|
self._pipeline_active = False
|
||||||
|
self._pipeline_current_buffer = 0 # 0 = buffer A, 1 = buffer B
|
||||||
|
self._pipeline_next_layer_event = torch.cuda.Event()
|
||||||
|
self._pipeline_cpu_blocks: list = [] # CPU block IDs to load
|
||||||
|
self._pipeline_num_blocks = 0
|
||||||
|
self._pipeline_layer_stream = torch.cuda.Stream() # Dedicated stream for layer loading
|
||||||
|
|
||||||
|
# ========== Per-layer prefill buffer for async offload ==========
|
||||||
|
# During chunked prefill, all layers share the same GPU slot. This means
|
||||||
|
# each layer must wait for offload to complete before the next layer can
|
||||||
|
# write to the same slot. This serializes offloads and hurts performance.
|
||||||
|
# Solution: Maintain separate per-layer buffers for prefill.
|
||||||
|
# Each layer writes to its own buffer, enabling fully async offloads.
|
||||||
|
# Shape: [num_layers, block_size, kv_heads, head_dim]
|
||||||
|
self.prefill_k_buffer = torch.zeros(
|
||||||
|
num_layers, block_size, num_kv_heads, head_dim,
|
||||||
|
dtype=dtype, device="cuda"
|
||||||
|
)
|
||||||
|
self.prefill_v_buffer = torch.zeros(
|
||||||
|
num_layers, block_size, num_kv_heads, head_dim,
|
||||||
|
dtype=dtype, device="cuda"
|
||||||
|
)
|
||||||
|
prefill_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
|
||||||
|
logger.info(f" Per-layer prefill buffer: {prefill_buf_mb:.1f} MB")
|
||||||
|
|
||||||
|
# Per-layer offload events for async prefill offload
|
||||||
|
# Each layer has its own event to track offload completion
|
||||||
|
self.prefill_offload_events = [torch.cuda.Event() for _ in range(num_layers)]
|
||||||
|
# Per-layer transfer streams for parallel offloads
|
||||||
|
self.prefill_offload_streams = [torch.cuda.Stream() for _ in range(num_layers)]
|
||||||
|
|
||||||
# ========== Fixed-address CPU KV cache (pinned memory) ==========
|
# ========== Fixed-address CPU KV cache (pinned memory) ==========
|
||||||
self.k_cache_cpu = torch.zeros(
|
self.k_cache_cpu = torch.zeros(
|
||||||
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
|
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
|
||||||
@@ -152,19 +209,6 @@ class OffloadEngine:
|
|||||||
dtype=dtype, device="cpu", pin_memory=True
|
dtype=dtype, device="cpu", pin_memory=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# ========== Fixed-address gather indices (content is variable) ==========
|
|
||||||
# gather_indices[layer][i] = CPU block id to copy to GPU slot i
|
|
||||||
# -1 means no-op (skip this slot)
|
|
||||||
self.gather_indices_cpu = torch.empty(
|
|
||||||
num_layers, num_gpu_blocks,
|
|
||||||
dtype=torch.int64, device="cpu", pin_memory=True
|
|
||||||
)
|
|
||||||
self.gather_indices_cpu.fill_(-1)
|
|
||||||
self.gather_indices_gpu = torch.full(
|
|
||||||
(num_layers, num_gpu_blocks), -1,
|
|
||||||
dtype=torch.int64, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Log memory allocation
|
# Log memory allocation
|
||||||
gpu_mem_mb = self.gpu_memory_bytes() / (1024 * 1024)
|
gpu_mem_mb = self.gpu_memory_bytes() / (1024 * 1024)
|
||||||
cpu_mem_mb = self.cpu_memory_bytes() / (1024 * 1024)
|
cpu_mem_mb = self.cpu_memory_bytes() / (1024 * 1024)
|
||||||
@@ -219,321 +263,6 @@ class OffloadEngine:
|
|||||||
# ========== Sparse attention policy (set at construction time) ==========
|
# ========== Sparse attention policy (set at construction time) ==========
|
||||||
self.sparse_policy = sparse_policy
|
self.sparse_policy = sparse_policy
|
||||||
|
|
||||||
def _get_next_stream(self) -> torch.cuda.Stream:
|
|
||||||
"""Round-robin stream selection for parallel transfers."""
|
|
||||||
stream = self.transfer_streams[self._stream_idx]
|
|
||||||
self._stream_idx = (self._stream_idx + 1) % len(self.transfer_streams)
|
|
||||||
return stream
|
|
||||||
|
|
||||||
# ========== CUDA Graph compatible methods ==========
|
|
||||||
# NOTE: These methods need to be updated for the new GPU cache architecture.
|
|
||||||
# GPU cache no longer has layer dimension, so gathered copy semantics change.
|
|
||||||
# For now, these are kept for reference but should not be used without updating.
|
|
||||||
|
|
||||||
def gathered_h2d_layer(self, layer_id: int) -> None:
|
|
||||||
"""
|
|
||||||
Execute gathered H2D copy for a single layer.
|
|
||||||
|
|
||||||
WARNING: This method needs updating for new GPU cache architecture.
|
|
||||||
GPU cache no longer has layer dimension.
|
|
||||||
"""
|
|
||||||
# GPU cache has no layer dimension - use flat indexing
|
|
||||||
# Source is CPU[layer_id], dest is GPU (shared across layers)
|
|
||||||
gathered_copy_kv(
|
|
||||||
k_src=self.k_cache_cpu[layer_id],
|
|
||||||
v_src=self.v_cache_cpu[layer_id],
|
|
||||||
k_dst=self.k_cache_gpu, # No layer indexing
|
|
||||||
v_dst=self.v_cache_gpu, # No layer indexing
|
|
||||||
indices=self.gather_indices_gpu[layer_id],
|
|
||||||
)
|
|
||||||
|
|
||||||
def gathered_h2d_all_layers(self) -> None:
|
|
||||||
"""
|
|
||||||
Execute gathered H2D copy for all layers.
|
|
||||||
|
|
||||||
WARNING: In new architecture, GPU slots are shared across layers.
|
|
||||||
This method would overwrite slots multiple times. Not recommended.
|
|
||||||
"""
|
|
||||||
for layer_id in range(self.num_layers):
|
|
||||||
self.gathered_h2d_layer(layer_id)
|
|
||||||
|
|
||||||
def update_gather_indices(
|
|
||||||
self,
|
|
||||||
layer_id: int,
|
|
||||||
mappings: List[Tuple[int, int]],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Update gather indices for a layer (call OUTSIDE CUDA graph).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layer_id: Layer index
|
|
||||||
mappings: List of (cpu_block_id, gpu_slot) tuples
|
|
||||||
Only these slots will be updated; others keep their values
|
|
||||||
"""
|
|
||||||
for cpu_block_id, gpu_slot in mappings:
|
|
||||||
self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id
|
|
||||||
|
|
||||||
# Async copy to GPU
|
|
||||||
self.gather_indices_gpu[layer_id].copy_(
|
|
||||||
self.gather_indices_cpu[layer_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_gather_indices_all_layers(
|
|
||||||
self,
|
|
||||||
mappings_per_layer: List[List[Tuple[int, int]]],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Update gather indices for all layers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mappings_per_layer: mappings_per_layer[layer_id] = [(cpu_block_id, gpu_slot), ...]
|
|
||||||
"""
|
|
||||||
for layer_id, mappings in enumerate(mappings_per_layer):
|
|
||||||
for cpu_block_id, gpu_slot in mappings:
|
|
||||||
self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id
|
|
||||||
|
|
||||||
# Batch copy all layers
|
|
||||||
self.gather_indices_gpu.copy_(self.gather_indices_cpu, non_blocking=True)
|
|
||||||
|
|
||||||
def clear_gather_indices(self, layer_id: Optional[int] = None) -> None:
|
|
||||||
"""
|
|
||||||
Clear gather indices (set all to -1, meaning no-op).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layer_id: If provided, clear only this layer; otherwise clear all
|
|
||||||
"""
|
|
||||||
if layer_id is not None:
|
|
||||||
self.gather_indices_cpu[layer_id].fill_(-1)
|
|
||||||
self.gather_indices_gpu[layer_id].fill_(-1)
|
|
||||||
else:
|
|
||||||
self.gather_indices_cpu.fill_(-1)
|
|
||||||
self.gather_indices_gpu.fill_(-1)
|
|
||||||
|
|
||||||
# ========== Async transfer methods (for prefill, outside CUDA graph) ==========
|
|
||||||
|
|
||||||
def prefetch_block_async(
|
|
||||||
self,
|
|
||||||
layer_id: int,
|
|
||||||
cpu_block_id: int,
|
|
||||||
gpu_block_id: int,
|
|
||||||
) -> torch.cuda.Event:
|
|
||||||
"""
|
|
||||||
Async prefetch a single block from CPU to GPU.
|
|
||||||
|
|
||||||
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layer_id: Layer index (for CPU cache)
|
|
||||||
cpu_block_id: Source block in CPU cache
|
|
||||||
gpu_block_id: Destination slot in GPU cache
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
CUDA event that signals completion
|
|
||||||
"""
|
|
||||||
stream = self._get_next_stream()
|
|
||||||
event = torch.cuda.Event()
|
|
||||||
|
|
||||||
logger.debug(f"H2D prefetch: layer={layer_id}, CPU[{cpu_block_id}] -> GPU[{gpu_block_id}]")
|
|
||||||
|
|
||||||
with torch.cuda.stream(stream):
|
|
||||||
# GPU: no layer dimension, CPU: has layer dimension
|
|
||||||
self.k_cache_gpu[gpu_block_id].copy_(
|
|
||||||
self.k_cache_cpu[layer_id, cpu_block_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
self.v_cache_gpu[gpu_block_id].copy_(
|
|
||||||
self.v_cache_cpu[layer_id, cpu_block_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
event.record()
|
|
||||||
|
|
||||||
self.pending_events[(layer_id, gpu_block_id)] = event
|
|
||||||
return event
|
|
||||||
|
|
||||||
def prefetch_blocks_batch_async(
|
|
||||||
self,
|
|
||||||
transfers: List[Tuple[int, int, int]], # [(layer_id, cpu_block_id, gpu_block_id), ...]
|
|
||||||
) -> List[torch.cuda.Event]:
|
|
||||||
"""
|
|
||||||
Batch async prefetch multiple blocks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
transfers: List of (layer_id, cpu_block_id, gpu_block_id) tuples
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of CUDA events for each transfer
|
|
||||||
"""
|
|
||||||
events = []
|
|
||||||
for layer_id, cpu_block_id, gpu_block_id in transfers:
|
|
||||||
event = self.prefetch_block_async(layer_id, cpu_block_id, gpu_block_id)
|
|
||||||
events.append(event)
|
|
||||||
return events
|
|
||||||
|
|
||||||
def offload_block_async(
|
|
||||||
self,
|
|
||||||
layer_id: int,
|
|
||||||
gpu_block_id: int,
|
|
||||||
cpu_block_id: int,
|
|
||||||
) -> torch.cuda.Event:
|
|
||||||
"""
|
|
||||||
Async offload a block from GPU to CPU.
|
|
||||||
|
|
||||||
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layer_id: Layer index (for CPU cache)
|
|
||||||
gpu_block_id: Source slot in GPU cache
|
|
||||||
cpu_block_id: Destination block in CPU cache
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
CUDA event that signals completion
|
|
||||||
"""
|
|
||||||
stream = self._get_next_stream()
|
|
||||||
event = torch.cuda.Event()
|
|
||||||
|
|
||||||
logger.debug(f"D2H offload: layer={layer_id}, GPU[{gpu_block_id}] -> CPU[{cpu_block_id}]")
|
|
||||||
|
|
||||||
with torch.cuda.stream(stream):
|
|
||||||
# Wait for any compute using this block
|
|
||||||
stream.wait_stream(self.compute_stream)
|
|
||||||
|
|
||||||
# GPU: no layer dimension, CPU: has layer dimension
|
|
||||||
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
|
|
||||||
self.k_cache_gpu[gpu_block_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
|
|
||||||
self.v_cache_gpu[gpu_block_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
event.record()
|
|
||||||
|
|
||||||
return event
|
|
||||||
|
|
||||||
def offload_blocks_batch_async(
|
|
||||||
self,
|
|
||||||
transfers: List[Tuple[int, int, int]], # [(layer_id, gpu_block_id, cpu_block_id), ...]
|
|
||||||
) -> List[torch.cuda.Event]:
|
|
||||||
"""
|
|
||||||
Batch async offload multiple blocks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
transfers: List of (layer_id, gpu_block_id, cpu_block_id) tuples
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of CUDA events
|
|
||||||
"""
|
|
||||||
events = []
|
|
||||||
for layer_id, gpu_block_id, cpu_block_id in transfers:
|
|
||||||
event = self.offload_block_async(layer_id, gpu_block_id, cpu_block_id)
|
|
||||||
events.append(event)
|
|
||||||
return events
|
|
||||||
|
|
||||||
# ========== Chunked Decode: Load CPU blocks to GPU slots ==========
|
|
||||||
|
|
||||||
def load_cpu_blocks_to_gpu_slots(
|
|
||||||
self,
|
|
||||||
layer_id: int,
|
|
||||||
cpu_block_ids: List[int],
|
|
||||||
gpu_slot_ids: List[int],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Load CPU blocks to specific GPU slots for chunked decode.
|
|
||||||
|
|
||||||
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layer_id: Layer index (for CPU cache)
|
|
||||||
cpu_block_ids: List of CPU block IDs to load
|
|
||||||
gpu_slot_ids: List of GPU slot IDs to load into (must be same length)
|
|
||||||
"""
|
|
||||||
assert len(cpu_block_ids) == len(gpu_slot_ids)
|
|
||||||
|
|
||||||
if cpu_block_ids:
|
|
||||||
logger.debug(f"H2D chunked load: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
|
|
||||||
|
|
||||||
stream = self._get_next_stream()
|
|
||||||
|
|
||||||
with torch.cuda.stream(stream):
|
|
||||||
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
|
|
||||||
# GPU: no layer dimension, CPU: has layer dimension
|
|
||||||
self.k_cache_gpu[gpu_slot].copy_(
|
|
||||||
self.k_cache_cpu[layer_id, cpu_block_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
self.v_cache_gpu[gpu_slot].copy_(
|
|
||||||
self.v_cache_cpu[layer_id, cpu_block_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wait for transfer to complete
|
|
||||||
stream.synchronize()
|
|
||||||
|
|
||||||
def load_cpu_blocks_to_gpu_slots_async(
|
|
||||||
self,
|
|
||||||
layer_id: int,
|
|
||||||
cpu_block_ids: List[int],
|
|
||||||
gpu_slot_ids: List[int],
|
|
||||||
) -> torch.cuda.Event:
|
|
||||||
"""
|
|
||||||
Async version: Load CPU blocks to GPU slots.
|
|
||||||
|
|
||||||
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layer_id: Layer index (for CPU cache)
|
|
||||||
cpu_block_ids: List of CPU block IDs to load
|
|
||||||
gpu_slot_ids: List of GPU slot IDs to load into
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
CUDA event to wait on
|
|
||||||
"""
|
|
||||||
assert len(cpu_block_ids) == len(gpu_slot_ids)
|
|
||||||
|
|
||||||
if cpu_block_ids:
|
|
||||||
logger.debug(f"H2D chunked load async: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
|
|
||||||
|
|
||||||
stream = self._get_next_stream()
|
|
||||||
event = torch.cuda.Event()
|
|
||||||
|
|
||||||
with torch.cuda.stream(stream):
|
|
||||||
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
|
|
||||||
# GPU: no layer dimension, CPU: has layer dimension
|
|
||||||
self.k_cache_gpu[gpu_slot].copy_(
|
|
||||||
self.k_cache_cpu[layer_id, cpu_block_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
self.v_cache_gpu[gpu_slot].copy_(
|
|
||||||
self.v_cache_cpu[layer_id, cpu_block_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
event.record()
|
|
||||||
|
|
||||||
return event
|
|
||||||
|
|
||||||
# NOTE: load_cpu_blocks_to_gpu_slots_all_layers removed - GPU cache no longer has
|
|
||||||
# layer dimension. Each GPU slot holds data for ONE layer at a time.
|
|
||||||
|
|
||||||
# ========== Synchronization methods ==========
|
|
||||||
|
|
||||||
def wait_for_block(self, layer_id: int, gpu_block_id: int) -> None:
|
|
||||||
"""Wait for a specific block's transfer to complete."""
|
|
||||||
key = (layer_id, gpu_block_id)
|
|
||||||
if key in self.pending_events:
|
|
||||||
self.pending_events[key].synchronize()
|
|
||||||
del self.pending_events[key]
|
|
||||||
|
|
||||||
def wait_all_transfers(self) -> None:
|
|
||||||
"""Wait for all pending transfers to complete."""
|
|
||||||
for stream in self.transfer_streams:
|
|
||||||
stream.synchronize()
|
|
||||||
self.pending_events.clear()
|
|
||||||
|
|
||||||
def sync_indices(self) -> None:
|
|
||||||
"""Synchronize to ensure all index updates are complete."""
|
|
||||||
torch.cuda.default_stream().synchronize()
|
|
||||||
|
|
||||||
# ========== Cache access methods ==========
|
# ========== Cache access methods ==========
|
||||||
|
|
||||||
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||||
@@ -547,54 +276,22 @@ class OffloadEngine:
|
|||||||
(k_cache, v_cache) tensors
|
(k_cache, v_cache) tensors
|
||||||
Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
|
Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
|
||||||
"""
|
"""
|
||||||
# GPU cache is shared across all layers (no layer dimension)
|
|
||||||
return self.k_cache_gpu, self.v_cache_gpu
|
return self.k_cache_gpu, self.v_cache_gpu
|
||||||
|
|
||||||
def get_all_gpu_cache(self) -> Tuple[Tensor, Tensor]:
|
|
||||||
"""
|
|
||||||
Get full GPU K/V cache tensors.
|
|
||||||
|
|
||||||
NOTE: GPU cache has no layer dimension in the new architecture.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(k_cache, v_cache) tensors
|
|
||||||
Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
|
|
||||||
"""
|
|
||||||
return self.k_cache_gpu, self.v_cache_gpu
|
|
||||||
|
|
||||||
def get_cpu_block(
|
|
||||||
self,
|
|
||||||
layer_id: int,
|
|
||||||
cpu_block_id: int,
|
|
||||||
) -> Tuple[Tensor, Tensor]:
|
|
||||||
"""
|
|
||||||
Get a specific CPU block's K/V cache.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(k_cache, v_cache) for the block
|
|
||||||
Shape: [block_size, kv_heads, head_dim]
|
|
||||||
"""
|
|
||||||
return (
|
|
||||||
self.k_cache_cpu[layer_id, cpu_block_id],
|
|
||||||
self.v_cache_cpu[layer_id, cpu_block_id],
|
|
||||||
)
|
|
||||||
|
|
||||||
# ========== Memory info ==========
|
# ========== Memory info ==========
|
||||||
|
|
||||||
def gpu_memory_bytes(self) -> int:
|
def gpu_memory_bytes(self) -> int:
|
||||||
"""Total GPU memory used by KV caches."""
|
"""Total GPU memory used by KV caches."""
|
||||||
return (
|
return (
|
||||||
self.k_cache_gpu.numel() * self.k_cache_gpu.element_size() +
|
self.k_cache_gpu.numel() * self.k_cache_gpu.element_size() +
|
||||||
self.v_cache_gpu.numel() * self.v_cache_gpu.element_size() +
|
self.v_cache_gpu.numel() * self.v_cache_gpu.element_size()
|
||||||
self.gather_indices_gpu.numel() * self.gather_indices_gpu.element_size()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def cpu_memory_bytes(self) -> int:
|
def cpu_memory_bytes(self) -> int:
|
||||||
"""Total CPU memory used by KV caches."""
|
"""Total CPU memory used by KV caches."""
|
||||||
return (
|
return (
|
||||||
self.k_cache_cpu.numel() * self.k_cache_cpu.element_size() +
|
self.k_cache_cpu.numel() * self.k_cache_cpu.element_size() +
|
||||||
self.v_cache_cpu.numel() * self.v_cache_cpu.element_size() +
|
self.v_cache_cpu.numel() * self.v_cache_cpu.element_size()
|
||||||
self.gather_indices_cpu.numel() * self.gather_indices_cpu.element_size()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@@ -897,102 +594,6 @@ class OffloadEngine:
|
|||||||
v = v.unsqueeze(0)
|
v = v.unsqueeze(0)
|
||||||
return k, v
|
return k, v
|
||||||
|
|
||||||
# ----- Legacy compatibility methods (for decode double-buffering) -----
|
|
||||||
# NOTE: GPU cache has no layer dimension. Layer ID is used for CPU cache indexing only.
|
|
||||||
|
|
||||||
def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
|
|
||||||
"""
|
|
||||||
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
|
|
||||||
|
|
||||||
Uses first half of decode_load_slots as 'compute' region.
|
|
||||||
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
|
|
||||||
"""
|
|
||||||
if not cpu_block_ids:
|
|
||||||
return
|
|
||||||
|
|
||||||
half = max(1, len(self.decode_load_slots) // 2)
|
|
||||||
slots = self.decode_load_slots[:half]
|
|
||||||
num_to_load = min(len(cpu_block_ids), len(slots))
|
|
||||||
|
|
||||||
with torch.cuda.stream(self.transfer_stream_main):
|
|
||||||
for i in range(num_to_load):
|
|
||||||
cpu_id = cpu_block_ids[i]
|
|
||||||
gpu_slot = slots[i]
|
|
||||||
# GPU: no layer dimension, CPU: has layer dimension
|
|
||||||
self.k_cache_gpu[gpu_slot].copy_(
|
|
||||||
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
|
|
||||||
)
|
|
||||||
self.v_cache_gpu[gpu_slot].copy_(
|
|
||||||
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
|
|
||||||
)
|
|
||||||
if num_to_load > 0:
|
|
||||||
self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)
|
|
||||||
|
|
||||||
def wait_compute_layer(self) -> None:
|
|
||||||
"""Legacy: Wait for 'compute' region loading."""
|
|
||||||
if self.decode_load_slots:
|
|
||||||
self.wait_slot_layer(self.decode_load_slots[0])
|
|
||||||
|
|
||||||
def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
|
|
||||||
"""
|
|
||||||
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
|
|
||||||
|
|
||||||
Uses second half of decode_load_slots as 'prefetch' region.
|
|
||||||
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
|
|
||||||
"""
|
|
||||||
if not cpu_block_ids:
|
|
||||||
return
|
|
||||||
|
|
||||||
half = max(1, len(self.decode_load_slots) // 2)
|
|
||||||
slots = self.decode_load_slots[half:]
|
|
||||||
if not slots:
|
|
||||||
slots = self.decode_load_slots # Fallback if only 1-2 slots
|
|
||||||
num_to_load = min(len(cpu_block_ids), len(slots))
|
|
||||||
|
|
||||||
with torch.cuda.stream(self.transfer_stream_main):
|
|
||||||
for i in range(num_to_load):
|
|
||||||
cpu_id = cpu_block_ids[i]
|
|
||||||
gpu_slot = slots[i]
|
|
||||||
# GPU: no layer dimension, CPU: has layer dimension
|
|
||||||
self.k_cache_gpu[gpu_slot].copy_(
|
|
||||||
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
|
|
||||||
)
|
|
||||||
self.v_cache_gpu[gpu_slot].copy_(
|
|
||||||
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
|
|
||||||
)
|
|
||||||
if num_to_load > 0:
|
|
||||||
self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)
|
|
||||||
|
|
||||||
def wait_prefetch_layer(self) -> None:
|
|
||||||
"""Legacy: Wait for 'prefetch' region loading."""
|
|
||||||
half = max(1, len(self.decode_load_slots) // 2)
|
|
||||||
slots = self.decode_load_slots[half:]
|
|
||||||
if slots:
|
|
||||||
self.wait_slot_layer(slots[0])
|
|
||||||
elif self.decode_load_slots:
|
|
||||||
self.wait_slot_layer(self.decode_load_slots[0])
|
|
||||||
|
|
||||||
def get_kv_for_compute(
|
|
||||||
self,
|
|
||||||
num_blocks: int,
|
|
||||||
) -> Tuple[Tensor, Tensor]:
|
|
||||||
"""Legacy: Get KV from 'compute' region (first half of decode_load_slots)."""
|
|
||||||
half = max(1, len(self.decode_load_slots) // 2)
|
|
||||||
slots = self.decode_load_slots[:half][:num_blocks]
|
|
||||||
return self.get_kv_for_slots(slots)
|
|
||||||
|
|
||||||
def get_kv_for_prefetch(
|
|
||||||
self,
|
|
||||||
num_blocks: int,
|
|
||||||
) -> Tuple[Tensor, Tensor]:
|
|
||||||
"""Legacy: Get KV from 'prefetch' region (second half of decode_load_slots)."""
|
|
||||||
half = max(1, len(self.decode_load_slots) // 2)
|
|
||||||
slots = self.decode_load_slots[half:]
|
|
||||||
if not slots:
|
|
||||||
slots = self.decode_load_slots
|
|
||||||
slots = slots[:num_blocks]
|
|
||||||
return self.get_kv_for_slots(slots)
|
|
||||||
|
|
||||||
# ========== Debug Hook Interface ==========
|
# ========== Debug Hook Interface ==========
|
||||||
#
|
#
|
||||||
# Minimal generic hook system for debugging.
|
# Minimal generic hook system for debugging.
|
||||||
@@ -1064,3 +665,207 @@ class OffloadEngine:
|
|||||||
if e.__class__.__name__ == 'BdbQuit':
|
if e.__class__.__name__ == 'BdbQuit':
|
||||||
raise
|
raise
|
||||||
logger.warning(f"Debug hook error: {e}")
|
logger.warning(f"Debug hook error: {e}")
|
||||||
|
|
||||||
|
# ========== Cross-layer Pipeline Methods for Decode ==========
|
||||||
|
|
||||||
|
def start_decode_pipeline(self, cpu_block_ids: List[int]) -> None:
|
||||||
|
"""
|
||||||
|
Start cross-layer pipeline for decode.
|
||||||
|
|
||||||
|
Called at the beginning of a decode step to initialize the pipeline.
|
||||||
|
Preloads Layer 0's data into buffer A.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cpu_block_ids: List of CPU block IDs for prefilled blocks
|
||||||
|
"""
|
||||||
|
if not cpu_block_ids:
|
||||||
|
self._pipeline_active = False
|
||||||
|
return
|
||||||
|
|
||||||
|
self._pipeline_active = True
|
||||||
|
self._pipeline_cpu_blocks = cpu_block_ids
|
||||||
|
self._pipeline_num_blocks = len(cpu_block_ids)
|
||||||
|
self._pipeline_current_buffer = 0
|
||||||
|
|
||||||
|
# Preload Layer 0 into buffer A
|
||||||
|
self._load_layer_to_buffer(0, 0) # layer_id=0, buffer_idx=0 (A)
|
||||||
|
|
||||||
|
def get_decode_layer_kv(self, layer_id: int, num_blocks: int) -> Tuple[Tensor, Tensor]:
|
||||||
|
"""
|
||||||
|
Get KV cache for a layer during decode.
|
||||||
|
|
||||||
|
If pipeline is active, returns data from the current buffer.
|
||||||
|
Also triggers preloading of the next layer (if not last layer).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_id: Current layer ID
|
||||||
|
num_blocks: Number of blocks to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(k_cache, v_cache) tensors, shape: [num_blocks, block_size, kv_heads, head_dim]
|
||||||
|
"""
|
||||||
|
if not self._pipeline_active:
|
||||||
|
raise RuntimeError("Decode pipeline not active. Call start_decode_pipeline first.")
|
||||||
|
|
||||||
|
# Wait for current layer's data to be ready
|
||||||
|
self.compute_stream.wait_event(self._pipeline_next_layer_event)
|
||||||
|
|
||||||
|
# Get current buffer
|
||||||
|
if self._pipeline_current_buffer == 0:
|
||||||
|
k = self.layer_k_buffer_a[:num_blocks]
|
||||||
|
v = self.layer_v_buffer_a[:num_blocks]
|
||||||
|
else:
|
||||||
|
k = self.layer_k_buffer_b[:num_blocks]
|
||||||
|
v = self.layer_v_buffer_b[:num_blocks]
|
||||||
|
|
||||||
|
# Trigger preloading of next layer (if not last layer)
|
||||||
|
next_layer_id = layer_id + 1
|
||||||
|
if next_layer_id < self.num_layers:
|
||||||
|
# Use the other buffer for next layer
|
||||||
|
next_buffer_idx = 1 - self._pipeline_current_buffer
|
||||||
|
self._load_layer_to_buffer(next_layer_id, next_buffer_idx)
|
||||||
|
# Switch to next buffer for next layer
|
||||||
|
self._pipeline_current_buffer = next_buffer_idx
|
||||||
|
|
||||||
|
return k, v
|
||||||
|
|
||||||
|
def _load_layer_to_buffer(self, layer_id: int, buffer_idx: int) -> None:
|
||||||
|
"""
|
||||||
|
Async load a layer's prefilled blocks to the specified buffer.
|
||||||
|
|
||||||
|
Uses sgDMA for efficient strided transfer from CPU cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_id: Layer index to load
|
||||||
|
buffer_idx: 0 for buffer A, 1 for buffer B
|
||||||
|
"""
|
||||||
|
num_blocks = self._pipeline_num_blocks
|
||||||
|
cpu_block_ids = self._pipeline_cpu_blocks
|
||||||
|
|
||||||
|
# Select target buffer
|
||||||
|
if buffer_idx == 0:
|
||||||
|
k_buffer = self.layer_k_buffer_a
|
||||||
|
v_buffer = self.layer_v_buffer_a
|
||||||
|
else:
|
||||||
|
k_buffer = self.layer_k_buffer_b
|
||||||
|
v_buffer = self.layer_v_buffer_b
|
||||||
|
|
||||||
|
# Load all blocks for this layer using dedicated stream
|
||||||
|
with torch.cuda.stream(self._pipeline_layer_stream):
|
||||||
|
for i, cpu_block_id in enumerate(cpu_block_ids):
|
||||||
|
# Copy from CPU cache (has layer dimension) to GPU buffer
|
||||||
|
k_buffer[i].copy_(
|
||||||
|
self.k_cache_cpu[layer_id, cpu_block_id],
|
||||||
|
non_blocking=True
|
||||||
|
)
|
||||||
|
v_buffer[i].copy_(
|
||||||
|
self.v_cache_cpu[layer_id, cpu_block_id],
|
||||||
|
non_blocking=True
|
||||||
|
)
|
||||||
|
# Record event when all transfers complete
|
||||||
|
self._pipeline_next_layer_event.record(self._pipeline_layer_stream)
|
||||||
|
|
||||||
|
def end_decode_pipeline(self) -> None:
|
||||||
|
"""
|
||||||
|
End the cross-layer pipeline.
|
||||||
|
|
||||||
|
Called at the end of a decode step to clean up pipeline state.
|
||||||
|
"""
|
||||||
|
if self._pipeline_active:
|
||||||
|
# Ensure all transfers complete before ending
|
||||||
|
self._pipeline_layer_stream.synchronize()
|
||||||
|
self._pipeline_active = False
|
||||||
|
self._pipeline_cpu_blocks = []
|
||||||
|
self._pipeline_num_blocks = 0
|
||||||
|
|
||||||
|
def is_pipeline_active(self) -> bool:
|
||||||
|
"""Check if decode pipeline is currently active."""
|
||||||
|
return self._pipeline_active
|
||||||
|
|
||||||
|
# ========== Per-layer Prefill Buffer Methods ==========
|
||||||
|
# These methods enable async offload during chunked prefill by using
|
||||||
|
# per-layer buffers instead of shared GPU slots.
|
||||||
|
|
||||||
|
def get_prefill_buffer(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||||
|
"""
|
||||||
|
Get prefill buffer for a layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_id: Layer index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(k_buffer, v_buffer), shape: [block_size, kv_heads, head_dim]
|
||||||
|
"""
|
||||||
|
return self.prefill_k_buffer[layer_id], self.prefill_v_buffer[layer_id]
|
||||||
|
|
||||||
|
def get_prefill_buffer_slice(
|
||||||
|
self,
|
||||||
|
layer_id: int,
|
||||||
|
num_tokens: int,
|
||||||
|
) -> Tuple[Tensor, Tensor]:
|
||||||
|
"""
|
||||||
|
Get a slice of prefill buffer for attention computation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_id: Layer index
|
||||||
|
num_tokens: Number of valid tokens in current chunk
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(k, v) with shape [1, num_tokens, kv_heads, head_dim]
|
||||||
|
"""
|
||||||
|
k = self.prefill_k_buffer[layer_id, :num_tokens].unsqueeze(0)
|
||||||
|
v = self.prefill_v_buffer[layer_id, :num_tokens].unsqueeze(0)
|
||||||
|
return k, v
|
||||||
|
|
||||||
|
def offload_prefill_buffer_async(
|
||||||
|
self,
|
||||||
|
layer_id: int,
|
||||||
|
cpu_block_id: int,
|
||||||
|
num_valid_tokens: int = -1,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Async offload prefill buffer to CPU (no waiting required).
|
||||||
|
|
||||||
|
This uses per-layer streams and events to enable fully async offloads.
|
||||||
|
Each layer can offload independently without blocking other layers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_id: Layer index
|
||||||
|
cpu_block_id: Target CPU block ID
|
||||||
|
num_valid_tokens: Number of valid tokens (-1 = use block_size)
|
||||||
|
"""
|
||||||
|
valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
|
||||||
|
|
||||||
|
# Collect sparse policy metadata before offload
|
||||||
|
if self.sparse_policy is not None:
|
||||||
|
k_cache = self.prefill_k_buffer[layer_id]
|
||||||
|
self.sparse_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||||
|
|
||||||
|
# Use per-layer stream for parallel offloads
|
||||||
|
stream = self.prefill_offload_streams[layer_id]
|
||||||
|
|
||||||
|
torch.cuda.nvtx.range_push(f"AsyncPrefillOffload: L{layer_id}->CPU[{cpu_block_id}]")
|
||||||
|
with torch.cuda.stream(stream):
|
||||||
|
# Wait for compute to finish writing to prefill buffer
|
||||||
|
stream.wait_stream(self.compute_stream)
|
||||||
|
|
||||||
|
# Copy from prefill buffer to CPU
|
||||||
|
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
|
||||||
|
self.prefill_k_buffer[layer_id], non_blocking=True
|
||||||
|
)
|
||||||
|
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
|
||||||
|
self.prefill_v_buffer[layer_id], non_blocking=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record completion event
|
||||||
|
self.prefill_offload_events[layer_id].record(stream)
|
||||||
|
torch.cuda.nvtx.range_pop()
|
||||||
|
|
||||||
|
def wait_all_prefill_offloads(self) -> None:
|
||||||
|
"""Wait for all prefill buffer offloads to complete."""
|
||||||
|
for stream in self.prefill_offload_streams:
|
||||||
|
stream.synchronize()
|
||||||
|
|
||||||
|
def wait_prefill_offload(self, layer_id: int) -> None:
|
||||||
|
"""Wait for a specific layer's prefill offload to complete."""
|
||||||
|
self.prefill_offload_events[layer_id].synchronize()
|
||||||
|
|||||||
@@ -99,8 +99,23 @@ class Attention(nn.Module):
|
|||||||
# torch.cuda.synchronize()
|
# torch.cuda.synchronize()
|
||||||
#! =======================================================
|
#! =======================================================
|
||||||
|
|
||||||
if is_chunked_offload:
|
if is_chunked_offload and context.is_prefill:
|
||||||
# Chunked offload mode: use compute_stream for store_kvcache
|
# Chunked prefill mode: write KV to per-layer prefill buffer (not GPU slot)
|
||||||
|
# This enables fully async offloads since each layer has its own buffer.
|
||||||
|
offload_engine = context.kvcache_manager.offload_engine
|
||||||
|
compute_stream = offload_engine.compute_stream
|
||||||
|
|
||||||
|
# Wait for default stream to ensure slot_mapping tensor transfer is complete
|
||||||
|
compute_stream.wait_stream(torch.cuda.default_stream())
|
||||||
|
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
# Write KV to per-layer prefill buffer (contiguous write, no slot_mapping)
|
||||||
|
# k, v shape: [num_tokens, kv_heads, head_dim]
|
||||||
|
num_tokens = k.shape[0]
|
||||||
|
offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
|
||||||
|
offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
|
||||||
|
elif is_chunked_offload:
|
||||||
|
# Chunked decode mode: use compute_stream for store_kvcache
|
||||||
# This ensures proper synchronization with per-layer offload
|
# This ensures proper synchronization with per-layer offload
|
||||||
compute_stream = context.kvcache_manager.offload_engine.compute_stream
|
compute_stream = context.kvcache_manager.offload_engine.compute_stream
|
||||||
if k_cache.numel() and v_cache.numel():
|
if k_cache.numel() and v_cache.numel():
|
||||||
@@ -157,36 +172,36 @@ class Attention(nn.Module):
|
|||||||
context,
|
context,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute attention with unified ring buffer for chunked prefill.
|
Compute attention with per-layer prefill buffer for async offload.
|
||||||
|
|
||||||
Ring buffer design:
|
Optimized design:
|
||||||
- Current chunk's KV is written to ring_slot[chunk_idx % N]
|
- Current chunk's KV is written to per-layer prefill buffer (not GPU slot)
|
||||||
- Previous chunks' KV are loaded from CPU using N-1 available slots
|
- Previous chunks' KV are loaded from CPU using GPU slots
|
||||||
- Pipeline: pre-fill slots, then process with overlapped load/compute
|
- Each layer offloads from its own buffer - no waiting required!
|
||||||
|
|
||||||
For each layer:
|
For each layer:
|
||||||
1. Current chunk's KV is in k_batched, v_batched (just written by model)
|
1. Current chunk's KV is in prefill_buffer[layer_id] (just written by model)
|
||||||
2. Load previous chunks from CPU using available slots (pipeline)
|
2. Load previous chunks from CPU using available slots (pipeline)
|
||||||
3. Compute attention against previous KV (no causal mask)
|
3. Compute attention against previous KV (no causal mask)
|
||||||
4. Compute attention against current KV (causal)
|
4. Compute attention against current KV from prefill buffer (causal)
|
||||||
5. Merge all results using online softmax
|
5. Merge all results using online softmax
|
||||||
|
6. Async offload prefill buffer to CPU (no waiting!)
|
||||||
"""
|
"""
|
||||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
current_chunk_idx = context.current_chunk_idx
|
current_chunk_idx = context.current_chunk_idx
|
||||||
torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
|
torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
|
||||||
|
|
||||||
# q, k, v shape: [total_tokens, num_heads, head_dim]
|
# q shape: [total_tokens, num_heads, head_dim]
|
||||||
# Reshape for flash attention: [batch, seq, heads, dim]
|
|
||||||
q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
|
q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
|
||||||
k_batched = k.unsqueeze(0)
|
num_tokens = k.shape[0]
|
||||||
v_batched = v.unsqueeze(0)
|
|
||||||
|
|
||||||
o_acc = None
|
o_acc = None
|
||||||
lse_acc = None
|
lse_acc = None
|
||||||
|
|
||||||
kvcache_manager = context.kvcache_manager
|
kvcache_manager = context.kvcache_manager
|
||||||
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
|
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
|
||||||
|
offload_engine = kvcache_manager.offload_engine if kvcache_manager is not None else None
|
||||||
|
|
||||||
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
|
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
|
||||||
# Get prefilled CPU blocks (blocks from previous chunks)
|
# Get prefilled CPU blocks (blocks from previous chunks)
|
||||||
@@ -210,11 +225,8 @@ class Attention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if cpu_block_table:
|
if cpu_block_table:
|
||||||
offload_engine = kvcache_manager.offload_engine
|
# Get available load slots (all slots can be used since we use prefill buffer)
|
||||||
|
load_slots = list(range(offload_engine.num_ring_slots))
|
||||||
# Get write slot for current chunk and available load slots
|
|
||||||
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
|
|
||||||
load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
|
|
||||||
pipeline_depth = len(load_slots)
|
pipeline_depth = len(load_slots)
|
||||||
|
|
||||||
if pipeline_depth == 0:
|
if pipeline_depth == 0:
|
||||||
@@ -230,15 +242,14 @@ class Attention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get compute stream for all attention operations
|
# Get compute stream for all attention operations
|
||||||
compute_stream = None
|
compute_stream = offload_engine.compute_stream if offload_engine is not None else None
|
||||||
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
|
|
||||||
compute_stream = kvcache_manager.offload_engine.compute_stream
|
|
||||||
|
|
||||||
# Compute attention against current chunk's KV (with causal mask)
|
# Compute attention against current chunk's KV from prefill buffer (with causal mask)
|
||||||
# Use compute_stream to ensure proper sync with store_kvcache and offload
|
|
||||||
if compute_stream is not None:
|
if compute_stream is not None:
|
||||||
with torch.cuda.stream(compute_stream):
|
with torch.cuda.stream(compute_stream):
|
||||||
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
|
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
|
||||||
|
# Get KV from per-layer prefill buffer
|
||||||
|
k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens)
|
||||||
current_o, current_lse = flash_attn_with_lse(
|
current_o, current_lse = flash_attn_with_lse(
|
||||||
q_batched,
|
q_batched,
|
||||||
k_batched,
|
k_batched,
|
||||||
@@ -249,6 +260,8 @@ class Attention(nn.Module):
|
|||||||
torch.cuda.nvtx.range_pop()
|
torch.cuda.nvtx.range_pop()
|
||||||
else:
|
else:
|
||||||
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
|
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
|
||||||
|
k_batched = k.unsqueeze(0)
|
||||||
|
v_batched = v.unsqueeze(0)
|
||||||
current_o, current_lse = flash_attn_with_lse(
|
current_o, current_lse = flash_attn_with_lse(
|
||||||
q_batched,
|
q_batched,
|
||||||
k_batched,
|
k_batched,
|
||||||
@@ -274,27 +287,17 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
torch.cuda.nvtx.range_pop() # ChunkedPrefill
|
torch.cuda.nvtx.range_pop() # ChunkedPrefill
|
||||||
|
|
||||||
# Per-layer offload: In new GPU cache architecture (no layer dimension),
|
# Per-layer ASYNC offload: offload prefill buffer to CPU
|
||||||
# each layer must offload its KV to CPU before next layer overwrites the GPU slot.
|
# No waiting required! Each layer has its own buffer and stream.
|
||||||
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
|
if offload_engine is not None and seq is not None:
|
||||||
offload_engine = kvcache_manager.offload_engine
|
|
||||||
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
|
|
||||||
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
|
|
||||||
if seq is not None:
|
|
||||||
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
|
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
|
||||||
if current_chunk_idx < len(cpu_block_ids):
|
if current_chunk_idx < len(cpu_block_ids):
|
||||||
cpu_block_id = cpu_block_ids[current_chunk_idx]
|
cpu_block_id = cpu_block_ids[current_chunk_idx]
|
||||||
# k.shape[0] = number of tokens in current chunk
|
# Async offload - no waiting, fully parallel across layers
|
||||||
num_valid_tokens = k.shape[0]
|
offload_engine.offload_prefill_buffer_async(
|
||||||
offload_engine.offload_slot_layer_to_cpu(
|
self.layer_id, cpu_block_id, num_tokens
|
||||||
write_slot, self.layer_id, cpu_block_id, num_valid_tokens
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# CRITICAL: compute_stream must wait for offload to complete
|
|
||||||
# before the next layer's store_kvcache can overwrite the GPU slot.
|
|
||||||
# Without this, Layer N+1's store races with Layer N's offload copy.
|
|
||||||
compute_stream.wait_event(offload_engine.ring_slot_offload_done[write_slot])
|
|
||||||
|
|
||||||
# Sync default stream with compute_stream before returning
|
# Sync default stream with compute_stream before returning
|
||||||
# This ensures the result is ready for the rest of the model (layernorm, MLP)
|
# This ensures the result is ready for the rest of the model (layernorm, MLP)
|
||||||
if compute_stream is not None:
|
if compute_stream is not None:
|
||||||
@@ -479,17 +482,15 @@ class Attention(nn.Module):
|
|||||||
context,
|
context,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute decode attention using ring buffer pipeline (same as prefill).
|
Compute decode attention using cross-layer pipeline.
|
||||||
|
|
||||||
Uses the same loading mechanism as _chunked_prefill_attention:
|
Optimization: Uses double-buffered layer cache to overlap H2D transfer
|
||||||
- Load one block at a time from CPU to GPU slot
|
with computation across layers:
|
||||||
- Compute attention for each block
|
- Layer N computes while Layer N+1's data is being loaded
|
||||||
- Merge results using online softmax
|
- Each layer only waits for its own data, not all layers' data
|
||||||
- Finally merge with decode buffer (accumulated decode tokens)
|
|
||||||
|
|
||||||
This approach is simpler and proven correct (prefill tests pass).
|
This reduces effective latency from O(num_layers * transfer_time) to
|
||||||
The only difference from prefill is the additional decode buffer
|
O(transfer_time + num_layers * compute_time) when transfer < compute.
|
||||||
that stores new tokens generated during decode.
|
|
||||||
"""
|
"""
|
||||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
@@ -533,9 +534,16 @@ class Attention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
offload_engine = kvcache_manager.offload_engine
|
offload_engine = kvcache_manager.offload_engine
|
||||||
load_slots = offload_engine.decode_load_slots # Available slots for loading
|
|
||||||
|
|
||||||
# Use ring buffer pipeline (same as prefill) to load prefilled blocks
|
# Use cross-layer pipeline if active (initialized in model_runner)
|
||||||
|
if offload_engine.is_pipeline_active():
|
||||||
|
o_acc, lse_acc = self._decode_with_layer_pipeline(
|
||||||
|
q_batched, cpu_block_table, offload_engine,
|
||||||
|
block_size, last_block_valid_tokens
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fallback to original ring buffer pipeline
|
||||||
|
load_slots = offload_engine.decode_load_slots
|
||||||
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
|
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
|
||||||
q_batched, cpu_block_table, load_slots, offload_engine,
|
q_batched, cpu_block_table, load_slots, offload_engine,
|
||||||
block_size, last_block_valid_tokens
|
block_size, last_block_valid_tokens
|
||||||
@@ -652,3 +660,62 @@ class Attention(nn.Module):
|
|||||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
||||||
|
|
||||||
return o_acc, lse_acc
|
return o_acc, lse_acc
|
||||||
|
|
||||||
|
def _decode_with_layer_pipeline(
|
||||||
|
self,
|
||||||
|
q_batched: torch.Tensor,
|
||||||
|
cpu_block_table: list,
|
||||||
|
offload_engine,
|
||||||
|
block_size: int,
|
||||||
|
last_block_valid_tokens: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Decode using cross-layer pipeline for optimized H2D transfer.
|
||||||
|
|
||||||
|
This method uses pre-loaded layer buffers instead of loading
|
||||||
|
blocks one by one. The pipeline loads the next layer's data
|
||||||
|
while the current layer computes, achieving transfer/compute overlap.
|
||||||
|
|
||||||
|
The key insight is that each layer needs the SAME blocks but from
|
||||||
|
different layers of CPU cache. By double-buffering and pipelining
|
||||||
|
across layers, we reduce total latency.
|
||||||
|
"""
|
||||||
|
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
|
num_blocks = len(cpu_block_table)
|
||||||
|
if num_blocks == 0:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
compute_stream = offload_engine.compute_stream
|
||||||
|
|
||||||
|
# Get KV from pre-loaded layer buffer (triggers next layer loading)
|
||||||
|
prev_k, prev_v = offload_engine.get_decode_layer_kv(self.layer_id, num_blocks)
|
||||||
|
|
||||||
|
# prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
|
||||||
|
# Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
|
||||||
|
total_tokens = num_blocks * block_size
|
||||||
|
|
||||||
|
# Handle partial last block
|
||||||
|
if last_block_valid_tokens < block_size:
|
||||||
|
# Only use valid tokens from last block
|
||||||
|
actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
|
||||||
|
# Flatten and truncate
|
||||||
|
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
|
||||||
|
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
|
||||||
|
else:
|
||||||
|
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
|
||||||
|
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
|
||||||
|
|
||||||
|
# Add batch dimension: [1, total_tokens, kv_heads, head_dim]
|
||||||
|
prev_k_batched = prev_k_flat.unsqueeze(0)
|
||||||
|
prev_v_batched = prev_v_flat.unsqueeze(0)
|
||||||
|
|
||||||
|
# Compute attention on all prefilled blocks at once
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
o_acc, lse_acc = flash_attn_with_lse(
|
||||||
|
q_batched, prev_k_batched, prev_v_batched,
|
||||||
|
softmax_scale=self.scale,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return o_acc, lse_acc
|
||||||
|
|||||||
Reference in New Issue
Block a user