[refactor] Implement real chunked prefill mechanism.
@@ -429,37 +429,26 @@ class ModelRunner:
         Run decode with chunked attention when sequence exceeds GPU capacity.
 
         For decode, we need attention over ALL previous tokens. With CPU offload,
-        we load KV chunks and compute attention incrementally.
-        """
-        import sys
-
+        we load KV chunks and compute attention incrementally per-layer.
+
+        Flow:
+        1. Ensure last block is on GPU (for writing new KV token)
+        2. Run model forward - each attention layer:
+           a. Compute attention on GPU blocks
+           b. Load CPU blocks in chunks, compute + merge
+        3. Sample from output
+        """
         # Currently only supporting single sequence for chunked decode
         assert len(seqs) == 1, "Chunked decode only supports single sequence"
         seq = seqs[0]
 
-        total_blocks = len(seq.block_table)
-        print(f"[Chunked Decode] Sequence has {total_blocks} blocks, "
-              f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr)
-
         # Prepare inputs
         input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
 
-        # Compute slot mapping for the new token
-        # Get the last block's GPU slot if it's on GPU, otherwise we need to handle it
-        last_logical_id = seq.block_table[-1]
-        last_block = self.kvcache_manager.logical_blocks[last_logical_id]
-
-        if last_block.location.name == "GPU":
-            slot = last_block.gpu_slot * self.block_size + seq.last_block_num_tokens - 1
-        else:
-            # Last block is on CPU - we need to bring it to GPU for writing the new token
-            # This is a special case - allocate a temporary GPU slot
-            # For simplicity, use a fixed slot (this might conflict, but for decode
-            # we only write 1 token so it should be ok)
-            print(f"[Chunked Decode] Warning: last block on CPU, using temp slot", file=sys.stderr)
-            slot = 0  # Use first slot temporarily
-
+        # Ensure last block is on GPU for writing new KV token
+        last_gpu_slot = self.kvcache_manager.ensure_last_block_on_gpu(seq)
+        slot = last_gpu_slot * self.block_size + seq.last_block_num_tokens - 1
+
         slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
 
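The slot index written for the new token is plain block arithmetic over the paged KV cache. A standalone sketch of that computation (the helper name is ours, not part of the diff):

```python
# Illustrative only: how a flat KV-cache slot index for the newly decoded token
# is derived, assuming a paged cache laid out as [num_slots, block_size, ...].
def flat_slot_for_new_token(gpu_slot: int, block_size: int, tokens_in_last_block: int) -> int:
    # The new token occupies the last filled position of the sequence's last block.
    return gpu_slot * block_size + tokens_in_last_block - 1

# Example: block 3 of a 16-token block size, 5 tokens in the last block -> slot 52.
assert flat_slot_for_new_token(gpu_slot=3, block_size=16, tokens_in_last_block=5) == 52
```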
@@ -468,12 +457,13 @@ class ModelRunner:
             is_prefill=False,  # Decode mode
             slot_mapping=slot_mapping,
             context_lens=context_len,
-            is_chunked_prefill=True,  # Use chunked attention
+            is_chunked_prefill=True,  # Use chunked attention path
             offload_engine=self.kvcache_manager,
             chunked_seq=seq,
         )
 
         # Run model forward pass
+        # Each attention layer will handle chunked KV loading internally
         logits = self.run_model(input_ids, positions, is_prefill=False)
         reset_context()
 
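The attention layers pick up these arguments from the forward context rather than from explicit parameters. A minimal sketch of the fields the chunked-decode path above relies on; this is an illustration only, not the project's actual Context class:

```python
# Hypothetical sketch of a forward-context object carrying the chunked-decode state.
from dataclasses import dataclass
from typing import Any

@dataclass
class DecodeContext:
    is_prefill: bool = False
    is_chunked_prefill: bool = False  # routes attention to the chunked path
    slot_mapping: Any = None          # where the new token's KV is written
    context_lens: Any = None          # total tokens to attend over
    offload_engine: Any = None        # handle to the HybridKVCacheManager
    chunked_seq: Any = None           # the single sequence being decoded
```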
@@ -566,50 +566,151 @@ class HybridKVCacheManager(KVCacheManager):
                 cpu_blocks += 1
         return cpu_blocks > 0 and len(seq.block_table) > self.num_gpu_slots
 
-    def load_all_kv_for_layer(
-        self,
-        seq: Sequence,
-        layer_id: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Load ALL KV for a sequence from both GPU and CPU for a layer.
-
-        Used during chunked decode to compute full attention.
-
-        Returns:
-            (k, v) tensors with shape [1, total_tokens, kv_heads, head_dim]
-        """
-        k_chunks = []
-        v_chunks = []
-
-        for logical_id in seq.block_table:
-            block = self.logical_blocks[logical_id]
-            if block.location == BlockLocation.GPU:
-                # Get from GPU cache
-                k, v = self.offload_engine.get_layer_cache(layer_id)
-                # k, v shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
-                k_block = k[block.gpu_slot]  # [block_size, kv_heads, head_dim]
-                v_block = v[block.gpu_slot]
-                k_chunks.append(k_block)
-                v_chunks.append(v_block)
-            elif block.location == BlockLocation.CPU:
-                # Get from CPU cache
-                k_block, v_block = self.offload_engine.get_cpu_block(layer_id, block.cpu_block_id)
-                # Already [block_size, kv_heads, head_dim]
-                k_chunks.append(k_block.to("cuda", non_blocking=True))
-                v_chunks.append(v_block.to("cuda", non_blocking=True))
-
-        # Concatenate all chunks
-        k_all = torch.cat(k_chunks, dim=0)  # [total_tokens, kv_heads, head_dim]
-        v_all = torch.cat(v_chunks, dim=0)
-
-        # Add batch dimension
-        k_all = k_all.unsqueeze(0)  # [1, total_tokens, kv_heads, head_dim]
-        v_all = v_all.unsqueeze(0)
-
-        return k_all, v_all
+    # ========== Chunked Decode Support ==========
+
+    def get_decode_chunk_info(self, seq: Sequence) -> Tuple[List[int], List[int], int]:
+        """
+        Get information for chunked decode.
+
+        Returns:
+            (cpu_block_ids, cpu_logical_ids, num_chunks)
+            - cpu_block_ids: List of CPU block IDs in sequence order
+            - cpu_logical_ids: Corresponding logical block IDs
+            - num_chunks: Number of chunks needed
+        """
+        cpu_block_ids = []
+        cpu_logical_ids = []
+
+        for logical_id in seq.block_table:
+            block = self.logical_blocks[logical_id]
+            if block.location == BlockLocation.CPU:
+                cpu_block_ids.append(block.cpu_block_id)
+                cpu_logical_ids.append(logical_id)
+
+        # Each chunk uses available GPU slots minus 1 (reserved for write block)
+        usable_slots = self.num_gpu_slots - 1
+        num_chunks = (len(cpu_block_ids) + usable_slots - 1) // usable_slots if usable_slots > 0 else 0
+
+        return cpu_block_ids, cpu_logical_ids, num_chunks
+
+    def load_decode_chunk(
+        self,
+        seq: Sequence,
+        cpu_block_ids: List[int],
+        cpu_logical_ids: List[int],
+        chunk_idx: int,
+    ) -> List[int]:
+        """
+        Load one chunk of CPU blocks to GPU for chunked decode.
+
+        Similar to chunked prefill: uses GPU slots to hold a batch of blocks.
+
+        Args:
+            seq: Sequence being decoded
+            cpu_block_ids: All CPU block IDs for this sequence
+            cpu_logical_ids: Corresponding logical block IDs
+            chunk_idx: Which chunk to load (0-indexed)
+
+        Returns:
+            List of GPU slot IDs where the chunk was loaded
+        """
+        # Keep in sync with get_decode_chunk_info: one slot is reserved for the write block
+        chunk_size = self.num_gpu_slots - 1
+        start = chunk_idx * chunk_size
+        end = min(start + chunk_size, len(cpu_block_ids))
+
+        chunk_cpu_ids = cpu_block_ids[start:end]
+        chunk_logical_ids = cpu_logical_ids[start:end]
+
+        # Use GPU slots 0, 1, 2, ... for this chunk
+        gpu_slots = list(range(len(chunk_cpu_ids)))
+
+        # Load all layers at once using offload_engine
+        self.offload_engine.load_cpu_blocks_to_gpu_slots_all_layers(
+            chunk_cpu_ids, gpu_slots
+        )
+
+        return gpu_slots
+
+    def get_gpu_blocks_for_decode(self, seq: Sequence) -> Tuple[List[int], List[int]]:
+        """
+        Get blocks currently on GPU for this sequence.
+
+        Returns:
+            (gpu_slots, logical_ids) - GPU slot IDs and corresponding logical block IDs
+        """
+        gpu_slots = []
+        logical_ids = []
+
+        for logical_id in seq.block_table:
+            block = self.logical_blocks[logical_id]
+            if block.location == BlockLocation.GPU:
+                gpu_slots.append(block.gpu_slot)
+                logical_ids.append(logical_id)
+
+        return gpu_slots, logical_ids
+
+    def get_kv_for_gpu_slots(
+        self,
+        layer_id: int,
+        gpu_slots: List[int],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get KV tensors for specific GPU slots.
+
+        Args:
+            layer_id: Layer index
+            gpu_slots: List of GPU slot IDs
+
+        Returns:
+            (k, v) tensors with shape [1, num_tokens, kv_heads, head_dim]
+        """
+        k_cache, v_cache = self.offload_engine.get_layer_cache(layer_id)
+        # k_cache, v_cache shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
+
+        k_chunks = [k_cache[slot] for slot in gpu_slots]
+        v_chunks = [v_cache[slot] for slot in gpu_slots]
+
+        # Concatenate and add batch dimension
+        k = torch.cat(k_chunks, dim=0).unsqueeze(0)  # [1, tokens, heads, dim]
+        v = torch.cat(v_chunks, dim=0).unsqueeze(0)
+
+        return k, v
+
+    def ensure_last_block_on_gpu(self, seq: Sequence) -> int:
+        """
+        Ensure the last block is on GPU for writing new KV.
+
+        Uses a RESERVED slot (the last slot) to avoid conflicts with chunked decode,
+        which uses slots 0, 1, 2, ... for loading CPU blocks.
+
+        Returns:
+            GPU slot ID for the last block
+        """
+        last_logical_id = seq.block_table[-1]
+        block = self.logical_blocks[last_logical_id]
+
+        if block.location == BlockLocation.GPU:
+            return block.gpu_slot
+
+        # Use the last slot as the reserved slot for the write block
+        # This avoids conflicts with chunked decode which uses slots 0, 1, 2...
+        reserved_slot = self.num_gpu_slots - 1
+
+        # Load this block to GPU for all layers
+        self.offload_engine.load_cpu_blocks_to_gpu_slots_all_layers(
+            [block.cpu_block_id], [reserved_slot]
+        )
+
+        # Update block state
+        self.free_cpu_blocks.append(block.cpu_block_id)
+        del self.cpu_block_to_logical[block.cpu_block_id]
+
+        self.gpu_slot_to_logical[reserved_slot] = last_logical_id
+        block.location = BlockLocation.GPU
+        block.gpu_slot = reserved_slot
+        block.cpu_block_id = -1
+
+        return reserved_slot
+
     def get_gpu_block_tables(
         self,
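The chunk planning in `get_decode_chunk_info` is ceil division over the usable slots: all GPU slots except the one held back for the write block. A self-contained sketch of the same arithmetic (the helper name is ours):

```python
# Standalone sketch of the chunk-count math used by get_decode_chunk_info.
def plan_decode_chunks(num_cpu_blocks: int, num_gpu_slots: int) -> int:
    usable_slots = num_gpu_slots - 1  # last slot reserved for the write block
    if usable_slots <= 0:
        return 0
    # Ceil division: how many loads of `usable_slots` blocks cover all CPU blocks.
    return (num_cpu_blocks + usable_slots - 1) // usable_slots

# Example: 10 offloaded blocks, 4 GPU slots -> 3 usable slots per chunk -> 4 chunks.
assert plan_decode_chunks(10, 4) == 4
```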
@@ -308,6 +308,112 @@ class OffloadEngine:
             events.append(event)
         return events
 
+    # ========== Chunked Decode: Load CPU blocks to GPU slots ==========
+
+    def load_cpu_blocks_to_gpu_slots(
+        self,
+        layer_id: int,
+        cpu_block_ids: List[int],
+        gpu_slot_ids: List[int],
+    ) -> None:
+        """
+        Load CPU blocks to specific GPU slots for chunked decode.
+
+        Uses the main GPU KV cache slots, not a separate temp buffer.
+        This is the same mechanism as chunked prefill uses.
+
+        Args:
+            layer_id: Layer index
+            cpu_block_ids: List of CPU block IDs to load
+            gpu_slot_ids: List of GPU slot IDs to load into (must be same length)
+        """
+        assert len(cpu_block_ids) == len(gpu_slot_ids)
+
+        stream = self._get_next_stream()
+
+        with torch.cuda.stream(stream):
+            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
+                # Copy from pinned CPU memory to GPU KV cache slot
+                self.k_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.k_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+                self.v_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.v_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+
+        # Wait for transfer to complete
+        stream.synchronize()
+
+    def load_cpu_blocks_to_gpu_slots_async(
+        self,
+        layer_id: int,
+        cpu_block_ids: List[int],
+        gpu_slot_ids: List[int],
+    ) -> torch.cuda.Event:
+        """
+        Async version: load CPU blocks to GPU slots.
+
+        Args:
+            layer_id: Layer index
+            cpu_block_ids: List of CPU block IDs to load
+            gpu_slot_ids: List of GPU slot IDs to load into
+
+        Returns:
+            CUDA event to wait on
+        """
+        assert len(cpu_block_ids) == len(gpu_slot_ids)
+
+        stream = self._get_next_stream()
+        event = torch.cuda.Event()
+
+        with torch.cuda.stream(stream):
+            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
+                self.k_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.k_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+                self.v_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.v_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+            event.record()
+
+        return event
+
+    def load_cpu_blocks_to_gpu_slots_all_layers(
+        self,
+        cpu_block_ids: List[int],
+        gpu_slot_ids: List[int],
+    ) -> None:
+        """
+        Load CPU blocks to GPU slots for ALL layers at once.
+
+        More efficient than per-layer loading when we know the mapping upfront.
+
+        Args:
+            cpu_block_ids: List of CPU block IDs to load
+            gpu_slot_ids: List of GPU slot IDs to load into
+        """
+        assert len(cpu_block_ids) == len(gpu_slot_ids)
+
+        stream = self._get_next_stream()
+
+        with torch.cuda.stream(stream):
+            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
+                # Copy all layers at once
+                self.k_cache_gpu[:, gpu_slot].copy_(
+                    self.k_cache_cpu[:, cpu_block_id],
+                    non_blocking=True
+                )
+                self.v_cache_gpu[:, gpu_slot].copy_(
+                    self.v_cache_cpu[:, cpu_block_id],
+                    non_blocking=True
+                )
+
+        stream.synchronize()
+
     # ========== Synchronization methods ==========
 
     def wait_for_block(self, layer_id: int, gpu_block_id: int) -> None:
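All three loaders follow the same CUDA pattern: stage host-to-device copies on a side stream, then either synchronize or hand back an event for the consumer to wait on. A generic sketch of that pattern, with illustrative names rather than the OffloadEngine API:

```python
# Generic pattern sketch: async copies on a side stream, completion signalled by an event.
import torch

def copy_blocks_async(dst_gpu: torch.Tensor, src_cpu: torch.Tensor,
                      pairs: list[tuple[int, int]],
                      stream: torch.cuda.Stream) -> torch.cuda.Event:
    """Copy src_cpu[cpu_id] into dst_gpu[gpu_slot] for each (cpu_id, gpu_slot) pair."""
    event = torch.cuda.Event()
    with torch.cuda.stream(stream):
        for cpu_id, gpu_slot in pairs:
            # non_blocking copies only overlap with compute if src_cpu is pinned memory
            dst_gpu[gpu_slot].copy_(src_cpu[cpu_id], non_blocking=True)
        event.record()  # recorded on the side stream after all copies are queued
    return event

# Consumer side: make the compute stream wait before reading dst_gpu, e.g.
# torch.cuda.current_stream().wait_event(event)
```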
@@ -174,45 +174,76 @@ class Attention(nn.Module):
         """
         Compute decode attention with KV spread across CPU and GPU.
 
-        For decode with chunked KV:
-        1. Load all KV for this layer from CPU+GPU
-        2. Compute attention (1 query token vs all KV)
-        3. Return output
+        Uses chunked attention similar to chunked prefill:
+        1. Process blocks on GPU first (if any)
+        2. Load CPU blocks in chunks to GPU slots (per-layer)
+        3. Compute attention for each chunk, merge with online softmax
         """
-        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse
+        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
 
         # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
-        # We need to attend to ALL previous tokens
-
-        # Load all KV for this layer
-        if context.offload_engine is not None and self.layer_id >= 0:
-            kvcache_manager = context.offload_engine
-
-            if hasattr(context, 'chunked_seq') and context.chunked_seq is not None:
-                # Load all KV from both GPU and CPU for this layer
-                k_all, v_all = kvcache_manager.load_all_kv_for_layer(
-                    context.chunked_seq,
-                    self.layer_id,
-                )
-
-        if k_all is not None and v_all is not None:
-            # q shape: [batch_size, num_heads, head_dim]
-            # Need: [batch, seqlen, heads, dim]
-            # Insert seqlen dimension at position 1
-            q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]
-
-            # k_all, v_all shape: [1, total_kv_tokens, kv_heads, head_dim]
-            # Compute attention (no causal mask for decode - we want all KV)
-            out, _ = flash_attn_with_lse(
-                q_batched,
-                k_all,
-                v_all,
-                softmax_scale=self.scale,
-                causal=False,  # No causal mask for decode
-            )
-
-            # Output shape: [batch, 1, heads, dim] -> [batch, heads, dim]
-            return out.squeeze(1)
-
-        # Fallback: shouldn't reach here
-        raise RuntimeError("Chunked decode attention failed: no KV available")
+        # Need: [batch, seqlen, heads, dim]
+        q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]
+
+        kvcache_manager = context.offload_engine
+        seq = context.chunked_seq
+
+        o_acc = None
+        lse_acc = None
+
+        # Step 1: Process blocks already on GPU (if any)
+        gpu_slots, _ = kvcache_manager.get_gpu_blocks_for_decode(seq)
+        if gpu_slots:
+            k_gpu, v_gpu = kvcache_manager.get_kv_for_gpu_slots(self.layer_id, gpu_slots)
+            o_gpu, lse_gpu = flash_attn_with_lse(
+                q_batched, k_gpu, v_gpu,
+                softmax_scale=self.scale,
+                causal=False,
+            )
+            o_acc, lse_acc = o_gpu, lse_gpu
+
+        # Step 2: Process CPU blocks in chunks
+        # Get chunk info from kvcache_manager
+        cpu_block_ids, cpu_logical_ids, num_chunks = kvcache_manager.get_decode_chunk_info(seq)
+
+        if num_chunks > 0:
+            # Use num_gpu_slots - 1 to avoid the reserved slot (used for write block)
+            chunk_size = kvcache_manager.num_gpu_slots - 1
+
+            for chunk_idx in range(num_chunks):
+                start = chunk_idx * chunk_size
+                end = min(start + chunk_size, len(cpu_block_ids))
+                chunk_cpu_ids = cpu_block_ids[start:end]
+
+                # Load this chunk to GPU slots 0, 1, 2, ... for THIS LAYER
+                # (slot num_gpu_slots-1 is reserved for write block)
+                gpu_slots_for_chunk = list(range(len(chunk_cpu_ids)))
+                kvcache_manager.offload_engine.load_cpu_blocks_to_gpu_slots(
+                    self.layer_id,
+                    chunk_cpu_ids,
+                    gpu_slots_for_chunk,
+                )
+
+                # Get KV for this chunk
+                k_chunk, v_chunk = kvcache_manager.get_kv_for_gpu_slots(
+                    self.layer_id, gpu_slots_for_chunk
+                )
+
+                # Compute attention for this chunk
+                o_chunk, lse_chunk = flash_attn_with_lse(
+                    q_batched, k_chunk, v_chunk,
+                    softmax_scale=self.scale,
+                    causal=False,
+                )
+
+                # Merge with accumulated
+                if o_acc is None:
+                    o_acc, lse_acc = o_chunk, lse_chunk
+                else:
+                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
+
+        if o_acc is None:
+            raise RuntimeError("Chunked decode attention failed: no KV available")
+
+        # Output shape: [batch, 1, heads, dim] -> [batch, heads, dim]
+        return o_acc.squeeze(1)
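The per-chunk outputs are combined with a log-sum-exp merge (the "online softmax" step). The diff imports `merge_attention_outputs` without showing it; the sketch below is one way such a merge can be written, assuming flash-attn style shapes (`o`: [batch, q_len, heads, dim], `lse`: [batch, heads, q_len]). The project's real signature and layout may differ.

```python
# Sketch of a log-sum-exp merge of two partial attention results over disjoint KV chunks.
import torch

def merge_attention_outputs(o1, lse1, o2, lse2):
    # Each partial output o_i is softmax-normalized within its own chunk with
    # normalizer Z_i = exp(lse_i). The combined output reweights each partial
    # result by Z_i / (Z_1 + Z_2), computed stably in log space.
    lse = torch.logaddexp(lse1, lse2)                          # log(Z_1 + Z_2)
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)   # [batch, q_len, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return o1 * w1 + o2 * w2, lse
```

Because the merge is associative, chunks can be folded in one at a time, which is exactly how the decode loop above accumulates `o_acc` and `lse_acc`.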