[refactor] Refactor offload code to multi-chunk.
This commit is contained in:
@@ -123,8 +123,7 @@ class Attention(nn.Module):
|
||||
lse_acc = None
|
||||
|
||||
# Load previous KV from CPU using Compute/Prefetch region
|
||||
- # Note: context.offload_engine is actually HybridKVCacheManager
|
||||
- kvcache_manager = context.offload_engine
|
||||
+ kvcache_manager = context.kvcache_manager
|
||||
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
|
||||
|
||||
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
|
||||
@@ -215,7 +214,7 @@ class Attention(nn.Module):
|
||||
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
||||
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
||||
|
||||
- kvcache_manager = context.offload_engine
|
||||
+ kvcache_manager = context.kvcache_manager
|
||||
seq = context.chunked_seq
|
||||
|
||||
# Get all CPU blocks for this sequence
|
||||
|
||||
Reference in New Issue
Block a user