[refactor] Refactor offload code to multi-chunk.

This commit is contained in:
Zijie Tian
2025-12-15 01:13:58 +08:00
parent 5949537faf
commit 1081ab51ea
7 changed files with 36 additions and 233 deletions

View File

@@ -123,8 +123,7 @@ class Attention(nn.Module):
lse_acc = None
# Load previous KV from CPU using Compute/Prefetch region
# Note: context.offload_engine is actually HybridKVCacheManager
kvcache_manager = context.offload_engine
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
@@ -215,7 +214,7 @@ class Attention(nn.Module):
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
kvcache_manager = context.offload_engine
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq
# Get all CPU blocks for this sequence