[refactor] Translate into English, avoid Chinese due to Claude.

This commit is contained in:
Zijie Tian
2025-12-11 00:30:24 +08:00
parent e85c2b4776
commit babfa17354
9 changed files with 297 additions and 187 deletions


@@ -100,16 +100,16 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute attention with three-region GPU buffer for chunked prefill.
Compute attention with three-region GPU buffer for chunked prefill.
For chunked prefill:
1. Load previous KV from CPU using Compute/Prefetch (if any previous chunks)
1. Load previous KV from CPU using Compute/Prefetch region (if any previous chunks)
2. Compute attention against previous KV chunks (no causal mask)
3. Compute attention against current chunk's KV (causal)
4. Merge all results using online softmax
Three-region design guarantees: the current chunk's KV stays in the Compute region while previous KV is loaded from CPU into the Prefetch region,
so the write and load regions never overlap.
Three-region design guarantees: current chunk's KV is in Compute region, previous KV is loaded
from CPU to Prefetch region, so write and load regions never overlap.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
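Step 4 above merges per-chunk attention results with an online softmax. As a point of reference, the following is a minimal sketch of how two partial results computed over disjoint key sets can be combined using their log-sum-exp (LSE) values; the actual merge_attention_outputs in nanovllm.kvcache.chunked_attention may differ in shapes and details, and merge_partial_attention is a hypothetical name used only for illustration.
import torch

def merge_partial_attention(o1, lse1, o2, lse2):
    # Assumed layouts: o_* is [batch, seq, heads, dim], lse_* is [batch, heads, seq].
    # Combined normalizer: log(exp(lse1) + exp(lse2)).
    lse = torch.logaddexp(lse1, lse2)
    # Rescale each partial output by its share of the combined softmax mass.
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # [batch, seq, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return o1 * w1 + o2 * w2, lse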
@@ -122,7 +122,7 @@ class Attention(nn.Module):
o_acc = None
lse_acc = None
# Load previous KV from CPU using Compute/Prefetch
# Load previous KV from CPU using Compute/Prefetch region
# Note: context.offload_engine is actually HybridKVCacheManager
kvcache_manager = context.offload_engine
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
@@ -133,12 +133,12 @@ class Attention(nn.Module):
if cpu_block_table:
offload_engine = kvcache_manager.offload_engine
# Use the Prefetch region to load previous KV (it won't conflict with the current Compute region)
# Use Prefetch region to load previous KV (won't conflict with current Compute region)
prefetch_size = offload_engine.num_prefetch_blocks
num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size
use_compute = True # Alternate between the Compute region and the Prefetch region
use_compute = True # Alternate between Compute region and Prefetch region
# First load previous KV into the Prefetch region
# First load previous KV to Prefetch region
# Only layer 0 triggers the load (loads ALL layers at once)
first_chunk_end = min(prefetch_size, len(cpu_block_table))
first_chunk_ids = cpu_block_table[:first_chunk_end]
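The ceil-division above splits the previous KV's CPU block table into Prefetch-region-sized chunks. A small worked example with made-up sizes:
cpu_block_table = list(range(10))   # 10 previous KV blocks on CPU (illustrative)
prefetch_size = 4                   # Prefetch region holds 4 blocks (illustrative)
num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size  # -> 3
chunks = [cpu_block_table[i * prefetch_size:(i + 1) * prefetch_size]
          for i in range(num_chunks)]  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]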
@@ -157,14 +157,14 @@ class Attention(nn.Module):
next_end = min(next_start + prefetch_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if use_compute:
# Currently using the Prefetch region; the next chunk would load into the Compute region (if there is space)
# Note: the Compute region already holds the current chunk's KV and must not be overwritten
# So we use a simple synchronous strategy here: wait for the current chunk to finish before loading
pass # Simplified version: no double buffering, only the Prefetch region is used
# Currently in Prefetch region, next load to Compute region (if space available)
# Note: Compute region already has current chunk's KV written, cannot overwrite
# So here we use simple sync strategy: wait for current to complete before loading
pass # Simplified version: no double buffering, only use Prefetch region
else:
offload_engine.load_to_prefetch(next_chunk_ids)
# Wait for Prefetch and get KV
# Wait for Prefetch region and get KV
offload_engine.wait_prefetch()
prev_k, prev_v = offload_engine.get_kv_for_prefetch(
self.layer_id, num_blocks_in_chunk
@@ -185,7 +185,7 @@ class Attention(nn.Module):
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
# Load next chunk to Prefetch (if exists)
# Load next chunk to Prefetch region (if exists)
if chunk_idx + 1 < num_chunks and self.layer_id == 0:
next_start = end
next_end = min(next_start + prefetch_size, len(cpu_block_table))
@@ -218,16 +218,16 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute decode attention with three-region GPU buffer.
Compute decode attention with three-region GPU buffer.
All KV is stored on CPU. Uses Compute buffer on GPU:
1. Load chunk to Compute
All KV is stored on CPU. Uses Compute region buffer on GPU:
1. Load chunk to Compute region
2. Compute attention
3. Repeat for all chunks
4. Finally, attend to Decode (slot 0) which contains the new token's KV
4. Finally, attend to Decode region (slot 0) which contains the new token's KV
5. Merge all attention outputs using online softmax (LSE)
Key point: the new token's KV lives in the Decode region (slot 0) and will not be overwritten by loads into the Compute region.
Key: new token's KV is in Decode region (slot 0), won't be overwritten by Compute region loading.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
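Both the prefill and decode paths rely on flash_attn_with_lse returning the attention output together with the per-query log-sum-exp of the scaled scores, which is what makes the later merges possible. Below is a naive, non-flash reference for that contract under an assumed [batch, seq, heads, dim] layout; it is a sketch of the expected behavior, not the actual kernel.
import torch

def naive_attn_with_lse(q, k, v, softmax_scale, causal=False):
    # q: [batch, q_len, heads, dim]; k, v: [batch, kv_len, heads, dim]
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * softmax_scale
    if causal:
        q_len, kv_len = q.shape[1], k.shape[1]
        idx_q = torch.arange(q_len, device=q.device).unsqueeze(1)
        idx_k = torch.arange(kv_len, device=q.device).unsqueeze(0)
        # Queries are the last q_len positions of the kv_len-long sequence.
        scores = scores.masked_fill(idx_k > idx_q + (kv_len - q_len), float("-inf"))
    lse = torch.logsumexp(scores, dim=-1)                          # [batch, heads, q_len]
    o = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)
    return o, lse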
@@ -246,10 +246,10 @@ class Attention(nn.Module):
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
# Get the actual offload_engine for three-region operations
# Get the actual offload_engine for three-region operations
offload_engine = kvcache_manager.offload_engine
# Calculate chunk info using Compute
# Calculate chunk info using Compute region
compute_size = offload_engine.num_compute_blocks
num_chunks = (len(cpu_block_table) + compute_size - 1) // compute_size
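For orientation, a hedged sketch of the three-region GPU buffer split that this loop iterates over; the concrete slot ordering beyond "Decode is slot 0" is an assumption for illustration, not the real offload_engine layout.
num_compute_blocks, num_prefetch_blocks = 4, 4  # illustrative sizes (assumption)
decode_slot = 0                                 # KV of the token(s) being decoded
compute_slots = list(range(1, 1 + num_compute_blocks))              # chunk currently attended to
prefetch_slots = list(range(1 + num_compute_blocks,
                            1 + num_compute_blocks + num_prefetch_blocks))  # next chunk in flight
# CPU-to-GPU loads only ever target compute_slots or prefetch_slots, so the freshly
# written decode KV in decode_slot is never overwritten mid-step.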
@@ -262,12 +262,12 @@ class Attention(nn.Module):
num_blocks_in_chunk = end - start
chunk_ids = cpu_block_table[start:end]
# Load this chunk to Compute
# Load this chunk to Compute region
# Only layer 0 triggers the load (loads ALL layers at once)
if self.layer_id == 0:
offload_engine.load_to_compute(chunk_ids)
# Wait for Compute to be ready and get KV
# Wait for Compute region to be ready and get KV
offload_engine.wait_compute()
k_chunk, v_chunk = offload_engine.get_kv_for_compute(
self.layer_id, num_blocks_in_chunk
@@ -286,21 +286,31 @@ class Attention(nn.Module):
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
# Now attend to Decode (contains the new token's KV)
# This is the token being decoded - only 1 token at position pos_in_block
# Now attend to Decode region (contains accumulated decode tokens)
# When batching offloads, decode slot accumulates multiple tokens
# from decode_start_pos_in_block to decode_pos_in_block (inclusive)
pos_in_block = context.decode_pos_in_block
decode_k, decode_v = offload_engine.get_kv_for_decode_slot(self.layer_id, pos_in_block)
decode_o, decode_lse = flash_attn_with_lse(
q_batched, decode_k, decode_v,
softmax_scale=self.scale,
causal=False,
)
start_pos = context.decode_start_pos_in_block
num_accumulated = pos_in_block - start_pos + 1
# Merge with accumulated
if o_acc is None:
o_acc = decode_o
else:
o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
if num_accumulated > 0:
# Get accumulated KV in decode slot [start_pos : pos_in_block+1]
decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_k = decode_k.unsqueeze(0) # [1, num_tokens, heads, dim]
decode_v = decode_v.unsqueeze(0)
decode_o, decode_lse = flash_attn_with_lse(
q_batched, decode_k, decode_v,
softmax_scale=self.scale,
causal=False,
)
# Merge with accumulated
if o_acc is None:
o_acc = decode_o
else:
o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
if o_acc is None:
raise RuntimeError("Chunked decode attention failed: no KV available")
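For clarity, a worked example of the accumulated-token indexing used above, with made-up positions:
# If batched offloading left decode_start_pos_in_block == 5 and decode_pos_in_block == 7,
# then block positions 5, 6 and 7 are still resident in the decode slot.
start_pos, pos_in_block = 5, 7
num_accumulated = pos_in_block - start_pos + 1    # 3 tokens not yet offloaded
token_slice = slice(start_pos, pos_in_block + 1)  # [5:8] on the slot's token axis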