[refactor] Translate into English; avoid Chinese text introduced by Claude.
@@ -22,7 +22,7 @@ class Config:
    offload_policy: str = "lru"  # "lru", "fifo", or full class path
    num_transfer_streams: int = 4  # Number of CUDA streams for async transfers
    num_gpu_blocks: int = -1  # User-specified GPU blocks count, -1 = auto (use max available)
    num_prefetch_blocks: int = 2  # Number of prefetch blocks for the three-region GPU buffer design

    # Computed fields for offload (set in __post_init__ or by ModelRunner)
    num_gpu_kvcache_blocks: int = -1
@@ -123,9 +123,9 @@ class ModelRunner:
num_gpu_blocks = max_gpu_blocks

if config.enable_cpu_offload:
    # Three-region design: CPU is primary storage, GPU is the working buffer
    # CPU blocks = all blocks needed to support max_model_len (stores complete KV for one max sequence)
    # GPU blocks = three-region working buffer (user-specified or auto)
    num_cpu_blocks = (config.max_model_len + self.block_size - 1) // self.block_size

config.num_gpu_kvcache_blocks = num_gpu_blocks
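For illustration only, a minimal sketch of the sizing arithmetic above; the helper name and example numbers are hypothetical, but the ceil-division mirrors the line shown in the hunk:

def plan_kv_blocks(max_model_len: int, block_size: int, num_gpu_blocks: int, max_gpu_blocks: int):
    """Hypothetical helper: CPU holds a full max-length sequence, GPU holds a small working buffer."""
    # CPU must hold every block of one maximum-length sequence (ceil division).
    num_cpu_blocks = (max_model_len + block_size - 1) // block_size
    # GPU buffer: user-specified, or whatever fits when set to -1 (auto).
    gpu_blocks = max_gpu_blocks if num_gpu_blocks == -1 else min(num_gpu_blocks, max_gpu_blocks)
    return num_cpu_blocks, gpu_blocks

# Example: a 32k-token context with 16-token blocks -> 2048 CPU blocks,
# while the GPU working buffer might be only 8 blocks.
print(plan_kv_blocks(32768, 16, -1, 8))  # (2048, 8)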
@@ -412,12 +412,12 @@ class ModelRunner:

def _should_use_pingpong(self, seqs: list[Sequence], is_prefill: bool) -> bool:
    """
    Check if three-region mode should be used.

    Use three-region mode when:
    - CPU offload is enabled
    - There are blocks on CPU (either allocated there or offloaded)
    - The sequence exceeds the GPU Compute region capacity
    """
    if not hasattr(self.kvcache_manager, 'offload_engine'):
        return False
@@ -429,10 +429,10 @@ class ModelRunner:
# Check if any blocks are on CPU
cpu_blocks, _ = self.kvcache_manager.get_all_cpu_blocks(seq)
if cpu_blocks:
    # Has CPU blocks - use three-region mode
    return True

# Check if the sequence needs more blocks than the GPU Compute region can hold
compute_size = self.kvcache_manager.offload_engine.num_compute_blocks
if seq.num_blocks > compute_size:
    # Needs chunked processing
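A quick worked example of the check above, with made-up numbers:

# Hypothetical numbers: an 800-token sequence with 16-token blocks needs 50 logical blocks.
seq_num_blocks = (800 + 16 - 1) // 16          # 50
num_compute_blocks = 6                          # size of the GPU Compute region
needs_three_region = seq_num_blocks > num_compute_blocks   # True -> chunked processing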
@@ -630,17 +630,17 @@ class ModelRunner:

def run_pingpong_prefill(self, seqs: list[Sequence]) -> list[int]:
    """
    Run prefill with the three-region GPU buffer (CPU is primary storage).

    Flow:
    1. All blocks are allocated to CPU (primary storage)
    2. Process tokens in chunks using the Compute region of the GPU buffer
    3. After each chunk, offload from the Compute region to CPU
    4. The Prefetch region is used to load previous KV (if any)
    """
    import sys

    assert len(seqs) == 1, "Three-region prefill only supports single sequence"
    seq = seqs[0]

    offload_engine = self.kvcache_manager.offload_engine
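To make the four-step flow above concrete, a hedged sketch of the driver loop; the engine methods mirror names that appear later in this diff (compute_slots, offload_compute_to_cpu, wait_all_offload_done), while the function itself and its arguments are illustrative:

def chunked_prefill_sketch(total_tokens, block_size, compute_size, offload_engine, cpu_block_ids):
    """Illustrative driver: process the prompt in Compute-region-sized chunks, offloading each to CPU."""
    tokens_per_chunk = compute_size * block_size
    chunk_start = 0
    while chunk_start < total_tokens:
        chunk_end = min(chunk_start + tokens_per_chunk, total_tokens)
        start_block = chunk_start // block_size
        end_block = (chunk_end + block_size - 1) // block_size
        gpu_slots = offload_engine.compute_slots[: end_block - start_block]
        # ... run the model on tokens[chunk_start:chunk_end], writing KV into gpu_slots ...
        # Copy the freshly written KV back to its CPU blocks (async, on a transfer stream).
        offload_engine.offload_compute_to_cpu(cpu_block_ids[start_block:end_block])
        chunk_start = chunk_end
    offload_engine.wait_all_offload_done()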
@@ -648,7 +648,7 @@ class ModelRunner:
tokens_per_chunk = compute_size * self.block_size

total_tokens = len(seq)
print(f"[Three-region Prefill] Starting: {total_tokens} tokens, "
      f"compute_size={compute_size} blocks, chunk={tokens_per_chunk} tokens",
      file=sys.stderr)

@@ -670,12 +670,12 @@ class ModelRunner:
end_block_idx = (chunk_end + self.block_size - 1) // self.block_size
num_blocks = end_block_idx - start_block_idx

print(f"[Three-region Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
      f"blocks {start_block_idx}-{end_block_idx-1}, "
      f"compute_slots={offload_engine.compute_slots[:num_blocks]}",
      file=sys.stderr)

# Get GPU slots for this chunk (using the Compute region)
gpu_slots = offload_engine.compute_slots[:num_blocks]

# Prepare inputs
@@ -695,7 +695,7 @@ class ModelRunner:
logical_id = seq.block_table[i]
self.kvcache_manager.prefilled_blocks.add(logical_id)

# Offload this chunk from the Compute region to CPU (async)
chunk_cpu_blocks = cpu_block_ids[start_block_idx:end_block_idx]
offload_engine.offload_compute_to_cpu(chunk_cpu_blocks)

@@ -707,7 +707,7 @@ class ModelRunner:
# Wait for all offloads to complete
offload_engine.wait_all_offload_done()

print(f"[Three-region Prefill] Complete: {chunk_num} chunks", file=sys.stderr)

# Sample from last logits
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
@@ -776,14 +776,15 @@ class ModelRunner:

def run_pingpong_decode(self, seqs: list[Sequence]) -> list[int]:
    """
    Run decode with the three-region GPU buffer.

    All KV is on CPU. Uses the Decode region to write new KV, and the Compute/Prefetch regions to load KV chunks.
-   New token's KV is written to the Decode region (slot 0), then offloaded to CPU.
+   New token's KV is written to the Decode region (slot 0), then offloaded to CPU only when the block is full.

    Key: the Decode region is never overwritten by Compute/Prefetch; it is dedicated to writing new KV.
+   Optimization: batch offloads - only offload when the block is full, attend to all accumulated tokens.
    """
    assert len(seqs) == 1, "Three-region decode only supports single sequence"
    seq = seqs[0]

    offload_engine = self.kvcache_manager.offload_engine
@@ -792,13 +793,16 @@ class ModelRunner:
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)

# Use the Decode region (slot 0) to write the new KV
decode_slot = offload_engine.decode_slot  # = 0
pos_in_block = (len(seq) - 1) % self.block_size
slot = decode_slot * self.block_size + pos_in_block
slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

+# Get decode start position for accumulated token tracking
+decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)

# Set up context for chunked decode
set_context(
    is_prefill=False,
@@ -808,17 +812,22 @@ class ModelRunner:
    offload_engine=self.kvcache_manager,
    chunked_seq=seq,
    decode_pos_in_block=pos_in_block,
+   decode_start_pos_in_block=decode_start_pos,
)

# Run model forward pass
logits = self.run_model(input_ids, positions, is_prefill=False)
reset_context()

-# Offload new KV from the Decode region to CPU
-last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
-if last_cpu_block >= 0:
-    offload_engine.offload_decode_slot(last_cpu_block)
-    offload_engine.wait_all_offload_done()
+# Only offload when the block is full (pos_in_block == block_size - 1)
+# This avoids unnecessary offloading on every decode step
+if pos_in_block == self.block_size - 1:
+    last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
+    if last_cpu_block >= 0:
+        offload_engine.offload_decode_slot(last_cpu_block)
+        offload_engine.wait_all_offload_done()
+        # Reset decode start position for the next block
+        self.kvcache_manager.reset_decode_start_pos(seq)

# Sample
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
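An illustrative trace of the batched-offload condition introduced here, with hypothetical sizes:

# Hypothetical trace with block_size = 4. Prefill ends with 6 tokens, so the last block
# holds positions 0-1 and decode starts writing at position 2 of the decode slot.
block_size = 4
decode_start_pos = 6 % block_size                 # 2
for step, seq_len in enumerate(range(7, 11)):     # decode produces tokens 7..10
    pos_in_block = (seq_len - 1) % block_size
    offload_now = pos_in_block == block_size - 1
    print(step, pos_in_block, offload_now)
# step 0: pos 2, keep accumulating; step 1: pos 3, block full -> offload and reset;
# step 2: pos 0 in a fresh block; step 3: pos 1, keep accumulating.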
@@ -95,16 +95,16 @@ class HybridKVCacheManager(KVCacheManager):
    num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage)
    block_size: Tokens per block
    policy: Eviction policy (default: LRU)
    cpu_primary: If True, use CPU as primary storage with the three-region GPU buffer.
                 If False, use GPU as primary with CPU as overflow (legacy mode).
    num_prefetch_blocks: Number of prefetch blocks for the three-region GPU buffer design
"""
self._block_size = block_size
self.num_gpu_slots = num_gpu_slots
self.num_cpu_blocks = num_cpu_blocks
self.total_blocks = num_gpu_slots + num_cpu_blocks
self.cpu_primary = cpu_primary  # Three-region mode flag
self.num_prefetch_blocks = num_prefetch_blocks  # Three-region design parameter

# Eviction policy
self.policy = policy or LRUPolicy()
@@ -138,6 +138,10 @@ class HybridKVCacheManager(KVCacheManager):
# Track blocks that have been prefilled (KV written) for chunked prefill
self.prefilled_blocks: Set[int] = set()  # logical_ids

+# Track decode starting position within the block (for the batched-offload optimization)
+# Key: sequence id, Value: starting position where decode began in the current block
+self._decode_start_pos: Dict[int, int] = {}
+
@property
def block_size(self) -> int:
    return self._block_size
@@ -337,11 +341,11 @@ class HybridKVCacheManager(KVCacheManager):
"""
assert not seq.block_table, "Sequence already has blocks"

# Three-region mode: all blocks are allocated to CPU
if self.cpu_primary:
    return self.allocate_cpu_only(seq)

# Legacy mode: GPU as primary, CPU as overflow
h = -1
cache_miss = False

@@ -467,7 +471,7 @@ class HybridKVCacheManager(KVCacheManager):
block.token_ids = []

if self.cpu_primary:
    # Three-region mode: the new block is allocated to CPU
    if not self.free_cpu_blocks:
        raise RuntimeError("No free CPU blocks for decode")
    cpu_block_id = self.free_cpu_blocks.popleft()
@@ -476,7 +480,7 @@ class HybridKVCacheManager(KVCacheManager):
    block.gpu_slot = -1
    self.cpu_block_to_logical[cpu_block_id] = logical_id
else:
    # Legacy mode: the new block is allocated to GPU
    gpu_slot = self._allocate_gpu_slot()
    block.location = BlockLocation.GPU
    block.gpu_slot = gpu_slot
@@ -1021,22 +1025,22 @@ class HybridKVCacheManager(KVCacheManager):
    break
return pos

# ========== Three-region double-buffering support ==========

def allocate_cpu_only(self, seq: Sequence) -> None:
    """
    Allocate CPU blocks for a sequence (used in three-region mode).

    Unlike allocate(), all blocks here are allocated to CPU;
    the GPU is used only as a working buffer.

    Args:
        seq: Sequence to allocate
    """
    assert not seq.block_table, "Sequence already has blocks"

    for i in range(seq.num_blocks):
        # Allocate a CPU block
        if not self.free_cpu_blocks:
            raise RuntimeError(
                f"No free CPU blocks. Need {seq.num_blocks}, "
@@ -1045,7 +1049,7 @@ class HybridKVCacheManager(KVCacheManager):

cpu_block_id = self.free_cpu_blocks.popleft()

# Allocate a logical block
logical_id = self.free_logical_ids.popleft()
block = self.logical_blocks[logical_id]
block.ref_count = 1
@@ -1058,13 +1062,13 @@ class HybridKVCacheManager(KVCacheManager):

def get_cpu_block_table(self, seq: Sequence) -> List[int]:
    """
    Get the CPU block ID list for a sequence.

    Args:
        seq: Sequence

    Returns:
        List of CPU block IDs in sequence order
    """
    cpu_blocks = []
    for logical_id in seq.block_table:
@@ -1072,20 +1076,20 @@ class HybridKVCacheManager(KVCacheManager):
if block.location == BlockLocation.CPU:
    cpu_blocks.append(block.cpu_block_id)
else:
    # If the block is on GPU, it should still have a corresponding CPU block
    # In three-region mode, all data ultimately resides on CPU
    raise RuntimeError(
        f"Block {logical_id} not on CPU (location={block.location}). "
        f"In three-region mode, all blocks should be on CPU."
    )
return cpu_blocks

def get_all_cpu_blocks(self, seq: Sequence) -> Tuple[List[int], List[int]]:
    """
    Get all CPU blocks and their logical IDs for a sequence.

    Args:
        seq: Sequence

    Returns:
        (cpu_block_ids, logical_ids)
@@ -1101,13 +1105,13 @@ class HybridKVCacheManager(KVCacheManager):

def allocate_next_cpu_block(self, seq: Sequence) -> int:
    """
    Allocate the next CPU block for a sequence (for a new token during decode).

    Args:
        seq: Sequence

    Returns:
        Newly allocated CPU block ID
    """
    if not self.free_cpu_blocks:
        raise RuntimeError("No free CPU blocks")
@@ -1128,15 +1132,15 @@ class HybridKVCacheManager(KVCacheManager):

def get_last_cpu_block(self, seq: Sequence) -> int:
    """
    Get the CPU block ID of the last block in a sequence.

    Returns -1 if the last block is not on CPU.

    Args:
        seq: Sequence

    Returns:
        CPU block ID, or -1 if not on CPU
    """
    if not seq.block_table:
        return -1
@@ -1150,19 +1154,65 @@ class HybridKVCacheManager(KVCacheManager):

def get_write_slot_for_pingpong(self, seq: Sequence) -> int:
    """
    Get the GPU slot for writing new KV during three-region decode.

    The three-region design always uses the Decode region (slot 0) to write new KV.
    This avoids conflicts with Compute/Prefetch region loading operations.

    Args:
        seq: Sequence

    Returns:
        GPU slot ID (always decode_slot = 0)
    """
    return self.offload_engine.decode_slot

+def get_decode_start_pos(self, seq: Sequence) -> int:
+    """
+    Get the starting position within the block where decode tokens began.
+
+    This is used for the batched-offload optimization - we need to attend to all
+    accumulated tokens in the decode slot, not just the current one.
+
+    Args:
+        seq: Sequence
+
+    Returns:
+        Starting position within the block (0 to block_size-1)
+    """
+    seq_id = id(seq)
+    if seq_id not in self._decode_start_pos:
+        # First decode step - compute the starting position
+        # After prefill, the last block has some tokens filled
+        # Decode starts at the next position
+        prefill_len = len(seq) - 1  # Current len includes the new decode token
+        self._decode_start_pos[seq_id] = prefill_len % self._block_size
+    return self._decode_start_pos[seq_id]
+
+def reset_decode_start_pos(self, seq: Sequence) -> None:
+    """
+    Reset the decode start position for a sequence.
+
+    Called when a block is full and offloaded - the next decode starts at position 0.
+
+    Args:
+        seq: Sequence
+    """
+    seq_id = id(seq)
+    self._decode_start_pos[seq_id] = 0
+
+def clear_decode_tracking(self, seq: Sequence) -> None:
+    """
+    Clear decode position tracking for a sequence.
+
+    Called when the sequence is deallocated.
+
+    Args:
+        seq: Sequence
+    """
+    seq_id = id(seq)
+    self._decode_start_pos.pop(seq_id, None)

def __repr__(self) -> str:
    return (
        f"HybridKVCacheManager(\n"
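A short, hypothetical usage sketch of the tracking helpers added above; manager and seq are stand-ins:

# Hypothetical driver showing the intended lifecycle of the helpers above.
start = manager.get_decode_start_pos(seq)        # e.g. 5: decode resumed mid-block
# ... run decode steps; when pos_in_block reaches block_size - 1 ...
manager.reset_decode_start_pos(seq)              # next block starts accumulating at 0
# ... when the sequence finishes and is deallocated ...
manager.clear_decode_tracking(seq)               # drop the per-sequence entry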
@@ -65,44 +65,44 @@ class OffloadEngine:
self.kv_dim = num_kv_heads * head_dim
self.block_numel = block_size * self.kv_dim

# ========== Three-region GPU buffer configuration ==========
# Constraint checks
assert num_gpu_blocks >= 3, \
    f"Need at least 3 GPU blocks: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}"
assert num_prefetch_blocks >= 1, \
    f"Need at least 1 prefetch block, got {num_prefetch_blocks}"
assert num_gpu_blocks >= 1 + 1 + num_prefetch_blocks, \
    f"Insufficient GPU blocks: need 1(decode) + 1(compute) + {num_prefetch_blocks}(prefetch), got {num_gpu_blocks}"

# Three-region configuration
# Decode region: [0] - a single fixed block for writing new KV
self.decode_slot = 0

# Compute region: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1]
compute_start = 1
compute_end = num_gpu_blocks - num_prefetch_blocks
self.compute_slots = list(range(compute_start, compute_end))
self.num_compute_blocks = len(self.compute_slots)

# Prefetch region: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1]
prefetch_start = compute_end
self.prefetch_slots = list(range(prefetch_start, num_gpu_blocks))
self.num_prefetch_blocks = num_prefetch_blocks

self.num_gpu_slots = num_gpu_blocks  # alias

# Keep the old ping/pong attributes for compatibility (to be removed later)
self.ping_size = self.num_compute_blocks
self.pong_size = self.num_prefetch_blocks
self.ping_slots = self.compute_slots.copy()
self.pong_slots = self.prefetch_slots.copy()

logger.info(f"Three-region GPU Buffer: decode_slot={self.decode_slot}, "
            f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}")

# ========== Fixed-address GPU KV cache ==========
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
# Use zeros initialization to avoid uninitialized-memory issues
self.k_cache_gpu = torch.zeros(
    num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
    dtype=dtype, device="cuda"
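As a standalone illustration of the partitioning above (numbers are hypothetical):

def partition_gpu_slots(num_gpu_blocks: int, num_prefetch_blocks: int):
    """Split the fixed GPU buffer into decode / compute / prefetch regions."""
    assert num_gpu_blocks >= 2 + num_prefetch_blocks
    decode_slot = 0
    compute_slots = list(range(1, num_gpu_blocks - num_prefetch_blocks))
    prefetch_slots = list(range(num_gpu_blocks - num_prefetch_blocks, num_gpu_blocks))
    return decode_slot, compute_slots, prefetch_slots

# Example: 8 GPU blocks with 2 prefetch blocks
# -> decode slot 0, compute slots [1..5], prefetch slots [6, 7]
print(partition_gpu_slots(8, 2))  # (0, [1, 2, 3, 4, 5], [6, 7])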
@@ -140,15 +140,15 @@ class OffloadEngine:
self.compute_stream = torch.cuda.current_stream()
self._stream_idx = 0

# ========== Three-region dedicated stream and events ==========
self.transfer_stream_main = torch.cuda.Stream()  # Main transfer stream

# Sync events - three-region loading completion
self.compute_ready = torch.cuda.Event()
self.prefetch_ready = torch.cuda.Event()
self.decode_offload_done = torch.cuda.Event()

# Keep the old ping/pong events for compatibility (to be removed later)
self.pingpong_stream = self.transfer_stream_main
self.ping_ready = self.compute_ready
self.pong_ready = self.prefetch_ready
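A minimal sketch of the stream/event pattern these members support, assuming a CUDA-enabled PyTorch build; tensor names and shapes are illustrative:

import torch

if torch.cuda.is_available():
    compute_stream = torch.cuda.current_stream()
    transfer_stream = torch.cuda.Stream()
    ready = torch.cuda.Event()

    cpu_block = torch.zeros(16, 8, 64, pin_memory=True)   # pinned host source block
    gpu_block = torch.empty(16, 8, 64, device="cuda")      # fixed-address device destination

    with torch.cuda.stream(transfer_stream):
        gpu_block.copy_(cpu_block, non_blocking=True)       # async H2D copy
        ready.record(transfer_stream)                        # mark completion

    compute_stream.wait_event(ready)   # kernels launched after this point see the new KV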
@@ -568,20 +568,20 @@ class OffloadEngine:
f" kv_heads={self.num_kv_heads},\n"
f" head_dim={self.head_dim},\n"
f" dtype={self.dtype},\n"
f" three_regions: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n"
f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
f")"
)

# ========== Ping-Pong double-buffering methods ==========

def load_to_ping(self, cpu_block_ids: List[int]) -> None:
    """
    Asynchronously load CPU blocks into the Ping buffer.

    Args:
        cpu_block_ids: List of CPU block IDs to load
    """
    if not cpu_block_ids:
        self.ping_ready.record(self.pingpong_stream)
@@ -594,7 +594,7 @@ class OffloadEngine:
for i in range(num_to_load):
    cpu_id = cpu_block_ids[i]
    gpu_slot = self.ping_slots[i]
    # Copy all layers together
    self.k_cache_gpu[:, gpu_slot].copy_(
        self.k_cache_cpu[:, cpu_id], non_blocking=True
    )
@@ -605,10 +605,10 @@ class OffloadEngine:

def load_to_pong(self, cpu_block_ids: List[int]) -> None:
    """
    Asynchronously load CPU blocks into the Pong buffer.

    Args:
        cpu_block_ids: List of CPU block IDs to load
    """
    if not cpu_block_ids:
        self.pong_ready.record(self.pingpong_stream)
@@ -630,11 +630,11 @@ class OffloadEngine:
self.pong_ready.record(self.pingpong_stream)

def wait_ping(self) -> None:
    """Wait for the Ping buffer load to complete."""
    self.compute_stream.wait_event(self.ping_ready)

def wait_pong(self) -> None:
    """Wait for the Pong buffer load to complete."""
    self.compute_stream.wait_event(self.pong_ready)

def offload_buffer_to_cpu(
@@ -643,11 +643,11 @@ class OffloadEngine:
    cpu_block_ids: List[int],
) -> None:
    """
    Asynchronously offload KV from a buffer to CPU.

    Args:
        buffer: "ping" or "pong"
        cpu_block_ids: List of target CPU block IDs
    """
    slots = self.ping_slots if buffer == "ping" else self.pong_slots
    event = self.ping_offload_done if buffer == "ping" else self.pong_offload_done
@@ -660,7 +660,7 @@ class OffloadEngine:
logger.debug(f"{buffer.capitalize()} offload: GPU {slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")

with torch.cuda.stream(self.pingpong_stream):
    # Wait for compute to complete
    self.pingpong_stream.wait_stream(self.compute_stream)

    for i in range(num_to_offload):
@@ -680,11 +680,11 @@ class OffloadEngine:
    cpu_block_id: int,
) -> None:
    """
    Asynchronously offload a single GPU slot's KV to CPU.

    Args:
        gpu_slot: GPU slot ID
        cpu_block_id: Target CPU block ID
    """
    logger.debug(f"Slot offload: GPU[{gpu_slot}] -> CPU[{cpu_block_id}]")

@@ -698,15 +698,15 @@ class OffloadEngine:
)

def wait_ping_offload_done(self) -> None:
    """Wait for the Ping buffer offload to complete."""
    self.compute_stream.wait_event(self.ping_offload_done)

def wait_pong_offload_done(self) -> None:
    """Wait for the Pong buffer offload to complete."""
    self.compute_stream.wait_event(self.pong_offload_done)

def wait_all_offload_done(self) -> None:
    """Wait for all offload operations to complete."""
    self.pingpong_stream.synchronize()

def get_kv_for_ping_slots(
@@ -715,14 +715,14 @@ class OffloadEngine:
    num_slots: int,
) -> Tuple[Tensor, Tensor]:
    """
    Get the KV for the specified number of slots in the Ping buffer.

    Args:
        layer_id: Layer ID
        num_slots: Number of slots needed

    Returns:
        (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
    """
    slots = self.ping_slots[:num_slots]
    k = self.k_cache_gpu[layer_id, slots]  # [num_slots, block_size, heads, dim]
@@ -738,14 +738,14 @@ class OffloadEngine:
    num_slots: int,
) -> Tuple[Tensor, Tensor]:
    """
    Get the KV for the specified number of slots in the Pong buffer.

    Args:
        layer_id: Layer ID
        num_slots: Number of slots needed

    Returns:
        (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
    """
    slots = self.pong_slots[:num_slots]
    k = self.k_cache_gpu[layer_id, slots]
@@ -760,14 +760,14 @@ class OffloadEngine:
    gpu_slots: List[int],
) -> Tuple[Tensor, Tensor]:
    """
    Get the KV for the specified GPU slots.

    Args:
        layer_id: Layer ID
        gpu_slots: List of GPU slot IDs

    Returns:
        (k_cache, v_cache), shape: [1, len(slots) * block_size, kv_heads, head_dim]
    """
    if not gpu_slots:
        return None, None
@@ -777,14 +777,14 @@ class OffloadEngine:
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v

# ========== Three-region GPU buffer methods ==========

def load_to_compute(self, cpu_block_ids: List[int]) -> None:
    """
    Asynchronously load CPU blocks into the Compute region.

    Args:
        cpu_block_ids: List of CPU block IDs to load
    """
    if not cpu_block_ids:
        self.compute_ready.record(self.transfer_stream_main)
@@ -797,7 +797,7 @@ class OffloadEngine:
for i in range(num_to_load):
    cpu_id = cpu_block_ids[i]
    gpu_slot = self.compute_slots[i]
    # Copy all layers together
    self.k_cache_gpu[:, gpu_slot].copy_(
        self.k_cache_cpu[:, cpu_id], non_blocking=True
    )
@@ -808,10 +808,10 @@ class OffloadEngine:

def load_to_prefetch(self, cpu_block_ids: List[int]) -> None:
    """
    Asynchronously load CPU blocks into the Prefetch region.

    Args:
        cpu_block_ids: List of CPU block IDs to load
    """
    if not cpu_block_ids:
        self.prefetch_ready.record(self.transfer_stream_main)
@@ -833,25 +833,25 @@ class OffloadEngine:
self.prefetch_ready.record(self.transfer_stream_main)

def wait_compute(self) -> None:
    """Wait for the Compute region load to complete."""
    self.compute_stream.wait_event(self.compute_ready)

def wait_prefetch(self) -> None:
    """Wait for the Prefetch region load to complete."""
    self.compute_stream.wait_event(self.prefetch_ready)

def swap_compute_prefetch(self) -> None:
    """Swap the roles of the Compute region and the Prefetch region."""
    self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
    # Also update the old ping/pong slots for compatibility
    self.ping_slots, self.pong_slots = self.pong_slots, self.ping_slots

def offload_decode_slot(self, cpu_block_id: int) -> None:
    """
    Offload the KV in the Decode region to CPU.

    Args:
        cpu_block_id: Target CPU block ID
    """
    logger.debug(f"Decode offload: GPU[{self.decode_slot}] -> CPU[{cpu_block_id}]")

@@ -866,7 +866,7 @@ class OffloadEngine:
self.decode_offload_done.record(self.transfer_stream_main)

def wait_decode_offload(self) -> None:
    """Wait for the Decode region offload to complete."""
    self.compute_stream.wait_event(self.decode_offload_done)

def get_kv_for_compute(
@@ -875,14 +875,14 @@ class OffloadEngine:
    num_blocks: int,
) -> Tuple[Tensor, Tensor]:
    """
    Get the KV for the specified number of blocks in the Compute region.

    Args:
        layer_id: Layer ID
        num_blocks: Number of blocks needed

    Returns:
        (k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
    """
    slots = self.compute_slots[:num_blocks]
    k = self.k_cache_gpu[layer_id, slots]  # [num_blocks, block_size, heads, dim]
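For clarity, a small sketch of the gather-and-flatten step these getters perform; the dummy cache shapes follow the comments above:

import torch

# Dummy cache: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
num_blocks, block_size, kv_heads, head_dim = 4, 16, 8, 64
k_cache_gpu = torch.randn(2, 8, block_size, kv_heads, head_dim)
slots = [1, 2, 3, 4]                                  # Compute-region slots for this chunk
k = k_cache_gpu[0, slots]                             # [num_blocks, block_size, kv_heads, head_dim]
k = k.reshape(1, -1, kv_heads, head_dim)              # [1, num_blocks * block_size, kv_heads, head_dim]
assert k.shape == (1, num_blocks * block_size, kv_heads, head_dim)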
@@ -898,14 +898,14 @@ class OffloadEngine:
    num_blocks: int,
) -> Tuple[Tensor, Tensor]:
    """
    Get the KV for the specified number of blocks in the Prefetch region.

    Args:
        layer_id: Layer ID
        num_blocks: Number of blocks needed

    Returns:
        (k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
    """
    slots = self.prefetch_slots[:num_blocks]
    k = self.k_cache_gpu[layer_id, slots]
@@ -920,14 +920,14 @@ class OffloadEngine:
    pos_in_block: int,
) -> Tuple[Tensor, Tensor]:
    """
    Get the KV at a specified position in the Decode region (the new token during decode).

    Args:
        layer_id: Layer ID
        pos_in_block: Token position within the block (0 to block_size-1)

    Returns:
        (k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
    """
    k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]  # [1, heads, dim]
    v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
@@ -935,12 +935,36 @@ class OffloadEngine:
v = v.unsqueeze(0)
return k, v

+def get_kv_for_decode_slot_accumulated(
+    self,
+    layer_id: int,
+    num_tokens: int,
+) -> Tuple[Tensor, Tensor]:
+    """
+    Get the accumulated KV in the Decode region (all tokens from position 0 to num_tokens-1).
+
+    Used when batching decode offloads - attend to all accumulated tokens,
+    not just the current one.
+
+    Args:
+        layer_id: Layer ID
+        num_tokens: Number of accumulated tokens (1 to block_size)
+
+    Returns:
+        (k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim]
+    """
+    k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens]  # [num_tokens, heads, dim]
+    v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens]
+    k = k.unsqueeze(0)  # [1, num_tokens, heads, dim]
+    v = v.unsqueeze(0)
+    return k, v
+
def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None:
    """
    Offload KV from the Compute region to CPU.

    Args:
        cpu_block_ids: List of target CPU block IDs
    """
    if not cpu_block_ids:
        return
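A hedged sketch of how a caller might slice the accumulated decode-slot KV, mirroring the range arithmetic used by the decode attention path later in this diff; all values are hypothetical:

import torch

# Decode began at position 5 of the current block; the model is now writing position 9.
kv_heads, head_dim, block_size = 8, 64, 16
start_pos, pos_in_block = 5, 9
k_slot = torch.randn(block_size, kv_heads, head_dim)      # one layer's decode slot: [block_size, heads, dim]
v_slot = torch.randn(block_size, kv_heads, head_dim)
k_acc = k_slot[start_pos:pos_in_block + 1].unsqueeze(0)   # [1, num_accumulated, heads, dim]
v_acc = v_slot[start_pos:pos_in_block + 1].unsqueeze(0)
assert k_acc.shape[1] == pos_in_block - start_pos + 1      # 5 accumulated decode tokens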
@@ -949,7 +973,7 @@ class OffloadEngine:
logger.debug(f"Compute offload: GPU {self.compute_slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")

with torch.cuda.stream(self.transfer_stream_main):
    # Wait for compute to complete
    self.transfer_stream_main.wait_stream(self.compute_stream)

    for i in range(num_to_offload):
@@ -100,16 +100,16 @@ class Attention(nn.Module):
    context,
) -> torch.Tensor:
    """
    Compute attention with the three-region GPU buffer for chunked prefill.

    For chunked prefill:
    1. Load previous KV from CPU using the Compute/Prefetch regions (if any previous chunks)
    2. Compute attention against previous KV chunks (no causal mask)
    3. Compute attention against the current chunk's KV (causal)
    4. Merge all results using online softmax

    The three-region design guarantees that the current chunk's KV is in the Compute region while previous KV
    is loaded from CPU into the Prefetch region, so write and load regions never overlap.
    """
    from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

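The "merge all results using online softmax" step relies on log-sum-exp bookkeeping. A minimal, library-agnostic sketch of that merge (tensor layouts are assumptions; the project's merge_attention_outputs may differ):

import torch

def merge_lse(o1, lse1, o2, lse2):
    """Numerically stable merge of two partial attention results.

    o*: [..., q_len, head_dim], lse*: [..., q_len] (log-sum-exp of the attention scores).
    """
    lse = torch.logaddexp(lse1, lse2)            # combined normalizer
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)     # weight of the first partial result
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)     # weight of the second partial result
    return o1 * w1 + o2 * w2, lse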
@@ -122,7 +122,7 @@ class Attention(nn.Module):
o_acc = None
lse_acc = None

# Load previous KV from CPU using the Compute/Prefetch regions
# Note: context.offload_engine is actually HybridKVCacheManager
kvcache_manager = context.offload_engine
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
@@ -133,12 +133,12 @@ class Attention(nn.Module):

if cpu_block_table:
    offload_engine = kvcache_manager.offload_engine
    # Use the Prefetch region to load previous KV (it won't conflict with the current Compute region)
    prefetch_size = offload_engine.num_prefetch_blocks
    num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size
    use_compute = True  # Alternate between the Compute region and the Prefetch region

    # First, load previous KV into the Prefetch region
    # Only layer 0 triggers the load (loads ALL layers at once)
    first_chunk_end = min(prefetch_size, len(cpu_block_table))
    first_chunk_ids = cpu_block_table[:first_chunk_end]
@@ -157,14 +157,14 @@ class Attention(nn.Module):
next_end = min(next_start + prefetch_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if use_compute:
    # Currently in the Prefetch region; the next chunk would load into the Compute region (if space is available)
    # Note: the Compute region already holds the current chunk's KV and cannot be overwritten
    # So we use a simple sync strategy here: wait for the current chunk to complete before loading
    pass  # Simplified version: no double buffering, only the Prefetch region is used
else:
    offload_engine.load_to_prefetch(next_chunk_ids)

# Wait for the Prefetch region and get the KV
offload_engine.wait_prefetch()
prev_k, prev_v = offload_engine.get_kv_for_prefetch(
    self.layer_id, num_blocks_in_chunk
@@ -185,7 +185,7 @@ class Attention(nn.Module):
else:
    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

# Load the next chunk into the Prefetch region (if one exists)
if chunk_idx + 1 < num_chunks and self.layer_id == 0:
    next_start = end
    next_end = min(next_start + prefetch_size, len(cpu_block_table))
@@ -218,16 +218,16 @@ class Attention(nn.Module):
    context,
) -> torch.Tensor:
    """
    Compute decode attention with the three-region GPU buffer.

    All KV is stored on CPU. Uses the Compute region buffer on the GPU:
    1. Load a chunk into the Compute region
    2. Compute attention
    3. Repeat for all chunks
    4. Finally, attend to the Decode region (slot 0), which contains the new token's KV
    5. Merge all attention outputs using online softmax (LSE)

    Key: the new token's KV is in the Decode region (slot 0) and won't be overwritten by Compute region loading.
    """
    from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

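A compact sketch of the loop described in this docstring; flash_attn_with_lse and merge_attention_outputs are the helpers imported in the hunk, while the driver function and its arguments are illustrative:

def chunked_decode_attention(q, cpu_block_table, compute_size, engine, layer_id, scale,
                             flash_attn_with_lse, merge_attention_outputs):
    """Illustrative: attend one decode query against CPU-resident KV, one Compute-region chunk at a time."""
    o_acc, lse_acc = None, None
    num_chunks = (len(cpu_block_table) + compute_size - 1) // compute_size
    for chunk_idx in range(num_chunks):
        start = chunk_idx * compute_size
        end = min(start + compute_size, len(cpu_block_table))
        if layer_id == 0:                       # one load serves all layers
            engine.load_to_compute(cpu_block_table[start:end])
        engine.wait_compute()
        k, v = engine.get_kv_for_compute(layer_id, end - start)
        o, lse = flash_attn_with_lse(q, k, v, softmax_scale=scale, causal=False)
        o_acc, lse_acc = (o, lse) if o_acc is None else merge_attention_outputs(o_acc, lse_acc, o, lse)
    return o_acc, lse_acc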
@@ -246,10 +246,10 @@ class Attention(nn.Module):
if not cpu_block_table:
    raise RuntimeError("Chunked decode attention failed: no CPU blocks available")

# Get the actual offload_engine for three-region operations
offload_engine = kvcache_manager.offload_engine

# Calculate chunk info using the Compute region
compute_size = offload_engine.num_compute_blocks
num_chunks = (len(cpu_block_table) + compute_size - 1) // compute_size

@@ -262,12 +262,12 @@ class Attention(nn.Module):
num_blocks_in_chunk = end - start
chunk_ids = cpu_block_table[start:end]

# Load this chunk into the Compute region
# Only layer 0 triggers the load (loads ALL layers at once)
if self.layer_id == 0:
    offload_engine.load_to_compute(chunk_ids)

# Wait for the Compute region to be ready and get the KV
offload_engine.wait_compute()
k_chunk, v_chunk = offload_engine.get_kv_for_compute(
    self.layer_id, num_blocks_in_chunk
@@ -286,21 +286,31 @@ class Attention(nn.Module):
 else:
 o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)

-# Now attend to Decode区 (contains the new token's KV)
-# This is the token being decoded - only 1 token at position pos_in_block
+# Now attend to Decode region (contains accumulated decode tokens)
+# When batching offloads, decode slot accumulates multiple tokens
+# from decode_start_pos_in_block to decode_pos_in_block (inclusive)
 pos_in_block = context.decode_pos_in_block
-decode_k, decode_v = offload_engine.get_kv_for_decode_slot(self.layer_id, pos_in_block)
-decode_o, decode_lse = flash_attn_with_lse(
-q_batched, decode_k, decode_v,
-softmax_scale=self.scale,
-causal=False,
-)
+start_pos = context.decode_start_pos_in_block
+num_accumulated = pos_in_block - start_pos + 1

-# Merge with accumulated
-if o_acc is None:
-o_acc = decode_o
-else:
-o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
+if num_accumulated > 0:
+# Get accumulated KV in decode slot [start_pos : pos_in_block+1]
+decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
+decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
+decode_k = decode_k.unsqueeze(0)  # [1, num_tokens, heads, dim]
+decode_v = decode_v.unsqueeze(0)
+
+decode_o, decode_lse = flash_attn_with_lse(
+q_batched, decode_k, decode_v,
+softmax_scale=self.scale,
+causal=False,
+)
+
+# Merge with accumulated
+if o_acc is None:
+o_acc = decode_o
+else:
+o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)

 if o_acc is None:
 raise RuntimeError("Chunked decode attention failed: no KV available")

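To make the accumulated-token window concrete, here is a standalone sketch of the slicing done in the new branch; the decode-slot shape below is an assumption for illustration, not the engine's actual cache layout.

    import torch

    # Assumed per-layer decode-slot layout: [block_size, num_kv_heads, head_dim].
    block_size, num_kv_heads, head_dim = 16, 8, 128
    k_slot = torch.randn(block_size, num_kv_heads, head_dim)
    v_slot = torch.randn(block_size, num_kv_heads, head_dim)

    decode_start_pos_in_block = 3   # first token written since the last batched offload
    decode_pos_in_block = 6         # token written in the current step
    num_accumulated = decode_pos_in_block - decode_start_pos_in_block + 1  # 4 tokens

    # Attend over every accumulated token, not just the newest one.
    decode_k = k_slot[decode_start_pos_in_block:decode_pos_in_block + 1].unsqueeze(0)  # [1, 4, 8, 128]
    decode_v = v_slot[decode_start_pos_in_block:decode_pos_in_block + 1].unsqueeze(0)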
@@ -18,6 +18,8 @@ class RMSNorm(nn.Module):
 self,
 x: torch.Tensor,
 ) -> torch.Tensor:
+# Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch
+# Callers should reshape 3D tensors to 2D before calling
 orig_dtype = x.dtype
 x = x.float()
 var = x.pow(2).mean(dim=-1, keepdim=True)
@@ -31,6 +33,7 @@ class RMSNorm(nn.Module):
 x: torch.Tensor,
 residual: torch.Tensor,
 ) -> tuple[torch.Tensor, torch.Tensor]:
+# Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch
 orig_dtype = x.dtype
 x = x.float().add_(residual.float())
 residual = x.to(orig_dtype)

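The new comments encode a calling convention: RMSNorm inputs stay rank-2 so torch.compile, which guards on tensor rank, does not re-specialize when callers alternate between 2D and 3D activations. A minimal sketch of a compiled norm honoring that contract (the weight handling and eps here are assumptions, not the module's exact code):

    import torch

    @torch.compile
    def rms_norm_2d(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # x is always [N, D]; a fixed rank keeps torch.compile from re-specializing.
        orig_dtype = x.dtype
        x = x.float()
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(var + eps)
        return (x * weight.float()).to(orig_dtype)

    # Callers flatten to 2D first, then restore the original shape.
    h = torch.randn(4, 32, 128, dtype=torch.float16)   # e.g. [num_tokens, heads, head_dim]
    w = torch.ones(128, dtype=torch.float16)
    out = rms_norm_2d(h.reshape(-1, 128), w).view(4, 32, 128)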
@@ -79,8 +79,12 @@ class Qwen3Attention(nn.Module):
 k = k.view(-1, self.num_kv_heads, self.head_dim)
 v = v.view(-1, self.num_kv_heads, self.head_dim)
 if not self.qkv_bias:
-q = self.q_norm(q)
-k = self.k_norm(k)
+# Reshape to 2D before RMSNorm to avoid torch.compile recompilation
+# q: [num_tokens, num_heads, head_dim] -> [num_tokens * num_heads, head_dim]
+# After norm, reshape back to 3D
+num_tokens = q.shape[0]
+q = self.q_norm(q.reshape(-1, self.head_dim)).view(num_tokens, self.num_heads, self.head_dim)
+k = self.k_norm(k.reshape(-1, self.head_dim)).view(num_tokens, self.num_kv_heads, self.head_dim)
 q, k = self.rotary_emb(positions, q, k)
 o = self.attn(q, k, v)
 output = self.o_proj(o.flatten(1, -1))

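Flattening to 2D and viewing back does not change what gets normalized, since RMSNorm reduces over head_dim only. A small self-contained check of that equivalence (weightless norm, hypothetical shapes):

    import torch

    def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Weightless RMSNorm over the last dimension (weight omitted for brevity).
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

    num_tokens, num_heads, head_dim = 5, 4, 8
    q = torch.randn(num_tokens, num_heads, head_dim)

    # Per-head norm on the flattened 2D view ...
    q_2d = rms_norm(q.reshape(-1, head_dim)).view(num_tokens, num_heads, head_dim)
    # ... matches norming the 3D tensor directly, since both reduce over head_dim only.
    q_3d = rms_norm(q)
    assert torch.allclose(q_2d, q_3d, atol=1e-6)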
@@ -27,8 +27,11 @@ class Context:
 prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
 # Current sequence being processed (for chunked prefill to load KV)
 chunked_seq: Any = None
-# Position within block for decode (used for reading from Decode区)
+# Position within block for decode (used for reading from Decode region)
 decode_pos_in_block: int = 0
+# Starting position within block where decode tokens began (for accumulated token tracking)
+# Used when batching decode offloads - we need to attend to all accumulated tokens
+decode_start_pos_in_block: int = 0


 _CONTEXT = Context()
@@ -53,6 +56,7 @@ def set_context(
 offload_engine=None,
 chunked_seq=None,
 decode_pos_in_block=0,
+decode_start_pos_in_block=0,
 ):
 global _CONTEXT
 _CONTEXT = Context(
@@ -70,6 +74,7 @@ def set_context(
 offload_engine=offload_engine,
 chunked_seq=chunked_seq,
 decode_pos_in_block=decode_pos_in_block,
+decode_start_pos_in_block=decode_start_pos_in_block,
 )

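A tiny illustrative mirror of the two decode-position fields and the window they define (hypothetical class, not part of the codebase):

    from dataclasses import dataclass

    @dataclass
    class DecodeWindow:
        # Mirrors the two Context fields added above (illustrative only).
        decode_pos_in_block: int = 0
        decode_start_pos_in_block: int = 0

        @property
        def num_accumulated(self) -> int:
            # Tokens sitting in the decode slot since the last batched offload.
            return self.decode_pos_in_block - self.decode_start_pos_in_block + 1

    w = DecodeWindow(decode_pos_in_block=6, decode_start_pos_in_block=3)
    assert w.num_accumulated == 4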
@@ -66,7 +66,7 @@ Attention mechanisms allow models to focus on relevant parts of the input.
 return "".join(prompt_parts)


-def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64):
+def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64, num_prefetch_blocks=2):
 """Test chunked prefill with limited GPU blocks."""
 path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")

@@ -75,15 +75,17 @@ def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64):
 print(f"=" * 60)
 print(f" target_input_len: ~{input_len} tokens")
 print(f" num_gpu_blocks: {num_gpu_blocks}")
+print(f" num_prefetch_blocks: {num_prefetch_blocks}")
 print()

 llm = LLM(
 path,
 enforce_eager=False,
-max_model_len=16 * 1024,
+max_model_len=128 * 1024,
-max_num_batched_tokens=16 * 1024,
+max_num_batched_tokens=128 * 1024,
 enable_cpu_offload=True,
 num_gpu_blocks=num_gpu_blocks,
+num_prefetch_blocks=num_prefetch_blocks,
 )
 print()

@@ -104,7 +106,7 @@ def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64):
 return outputs


-def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128):
+def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128, num_prefetch_blocks=2):
 """Test chunked decode with limited GPU blocks."""
 path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")

@@ -114,15 +116,17 @@ def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128):
 print(f" target_input_len: ~{input_len} tokens")
 print(f" output_len: {output_len} tokens")
 print(f" num_gpu_blocks: {num_gpu_blocks}")
+print(f" num_prefetch_blocks: {num_prefetch_blocks}")
 print()

 llm = LLM(
 path,
 enforce_eager=False,
-max_model_len=16 * 1024,
+max_model_len=128 * 1024,
-max_num_batched_tokens=16 * 1024,
+max_num_batched_tokens=128 * 1024,
 enable_cpu_offload=True,
 num_gpu_blocks=num_gpu_blocks,
+num_prefetch_blocks=num_prefetch_blocks,
 )
 print()

@@ -144,9 +148,10 @@ def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128):


 if __name__ == "__main__":
-# Parse arguments
+# Parse arguments: num_gpu_blocks input_len output_len [num_prefetch_blocks]
 num_gpu_blocks = int(sys.argv[1]) if len(sys.argv) > 1 else 10
 input_len = int(sys.argv[2]) if len(sys.argv) > 2 else 2048
 output_len = int(sys.argv[3]) if len(sys.argv) > 3 else 64
+num_prefetch_blocks = int(sys.argv[4]) if len(sys.argv) > 4 else 2

-test_chunked_prefill(num_gpu_blocks, input_len, output_len)
+test_chunked_prefill(num_gpu_blocks, input_len, output_len, num_prefetch_blocks)
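A quick usage sketch for the new optional argument, assuming the two test functions are importable; the module name here is hypothetical.

    # Hypothetical module name; adjust to the actual test file.
    from test_offload import test_chunked_prefill, test_chunked_decode

    # Equivalent to passing "10 8192 64 4" on the command line.
    test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64, num_prefetch_blocks=4)
    test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128, num_prefetch_blocks=4)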