[claudesquad] update from 'fix-ga-perf-2' on 09 Jan 26 14:08 CST

This commit is contained in:
Zijie Tian
2026-01-09 14:08:12 +08:00
parent 79c4df4a27
commit 47e3e465f0
4 changed files with 628 additions and 278 deletions

View File

@@ -429,7 +429,14 @@ class ModelRunner:
else:
return self.run_layerwise_offload_decode(seqs)
#> Following Code will not use Layer-wise Offload mode
#> Check if contiguous GPU mode should be used (single-seq optimization)
if self._should_use_contiguous_gpu_mode(seqs, is_prefill):
if is_prefill:
return self.run_gpu_only_prefill(seqs)
else:
return self.run_gpu_only_decode(seqs)
#> Following Code uses standard PagedAttention path
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
logits = self.run_model(input_ids, positions, is_prefill)
@@ -437,6 +444,257 @@ class ModelRunner:
reset_context()
return token_ids
def _should_use_contiguous_gpu_mode(self, seqs: list[Sequence], is_prefill: bool) -> bool:
"""
Check if contiguous GPU mode should be used for single-seq optimization.
Conditions:
1. Has kvcache_manager with contiguous cache allocated
2. Not using CPU offload (no offload_engine)
3. Single sequence (batch_size == 1)
4. Has blocks allocated (not warmup)
"""
# Must have kvcache_manager
if not hasattr(self, 'kvcache_manager') or self.kvcache_manager is None:
return False
# Must have contiguous cache
if not hasattr(self.kvcache_manager, 'contiguous_k_cache'):
return False
if self.kvcache_manager.contiguous_k_cache is None:
return False
# Must NOT be offload mode
if hasattr(self.kvcache_manager, 'offload_engine'):
return False
# Single sequence only
if len(seqs) != 1:
return False
# Has blocks allocated (not warmup)
if not seqs[0].block_table:
return False
return True
# ========== Contiguous GPU-only Methods ==========
@torch.inference_mode()
def run_gpu_only_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
GPU-only prefill with contiguous KV cache layout.
Mirrors run_layerwise_offload_prefill() but stores to GPU instead of CPU.
No scatter operations - just contiguous slice assignment.
Key design:
- Process layer-by-layer (not via Attention.forward())
- Store K,V to contiguous GPU cache (same layout as computed K,V)
- Use sparse prefill attention if enabled
"""
assert len(seqs) == 1, "GPU-only layer-wise prefill only supports single sequence"
seq = seqs[0]
num_layers = len(self.model.model.layers)
total_tokens = len(seq)
logger.debug(f"[GPU-only Prefill] Starting: {total_tokens} tokens, {num_layers} layers")
# Get contiguous GPU cache
k_cache = self.kvcache_manager.contiguous_k_cache
v_cache = self.kvcache_manager.contiguous_v_cache
# Prepare inputs
input_ids = torch.tensor(seq[:], dtype=torch.int64, device="cuda")
positions = torch.arange(total_tokens, dtype=torch.int64, device="cuda")
# Import FlashAttention
from flash_attn.flash_attn_interface import flash_attn_varlen_func
cu_seqlens = torch.tensor([0, total_tokens], dtype=torch.int32, device="cuda")
# Embedding
hidden_states = self.model.model.embed_tokens(input_ids)
residual = None
# Layer-by-layer processing
for layer_id in range(num_layers):
layer = self.model.model.layers[layer_id]
# Input LayerNorm
if residual is None:
hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states
else:
hidden_ln, residual = layer.input_layernorm(hidden_states, residual)
# QKV projection
qkv = layer.self_attn.qkv_proj(hidden_ln)
q, k, v = qkv.split([
layer.self_attn.q_size,
layer.self_attn.kv_size,
layer.self_attn.kv_size
], dim=-1)
q = q.view(total_tokens, layer.self_attn.num_heads, layer.self_attn.head_dim)
k = k.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
v = v.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# Q/K norms (Qwen3 specific)
if not layer.self_attn.qkv_bias:
num_tokens = q.shape[0]
q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim))
q = q.view(num_tokens, layer.self_attn.num_heads, layer.self_attn.head_dim)
k = layer.self_attn.k_norm(k.reshape(-1, layer.self_attn.head_dim))
k = k.view(num_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# RoPE
q, k = layer.self_attn.rotary_emb(positions, q, k)
# Sparse or Full attention (uses k, v directly - before store!)
if self.sparse_prefill_policy is not None:
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
q, k, v, layer_id
)
else:
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
# O projection
attn_output = attn_output.view(total_tokens, -1)
hidden_states = layer.self_attn.o_proj(attn_output)
# Store K,V to contiguous GPU cache AFTER attention (same as offload pattern)
k_cache[layer_id, :total_tokens] = k
v_cache[layer_id, :total_tokens] = v
# Post-attention LayerNorm + MLP
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
hidden_states = layer.mlp(hidden_states)
# Final norm
hidden_states, _ = self.model.model.norm(hidden_states, residual)
# Compute logits for last token
logits = self.model.compute_logits(hidden_states[-1:])
# Record prefill length for decode
self.kvcache_manager.contiguous_seq_len = total_tokens
logger.debug(f"[GPU-only Prefill] Complete: {num_layers} layers processed")
# Sample
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
return token_ids
@torch.inference_mode()
def run_gpu_only_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Decode using contiguous GPU KV cache.
Similar to offload decode but simpler - all KV already on GPU.
"""
assert len(seqs) == 1, "GPU-only decode only supports single sequence"
seq = seqs[0]
num_layers = len(self.model.model.layers)
k_cache = self.kvcache_manager.contiguous_k_cache
v_cache = self.kvcache_manager.contiguous_v_cache
context_len = self.kvcache_manager.contiguous_seq_len
# Prepare inputs
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, device="cuda")
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, device="cuda")
from flash_attn.flash_attn_interface import flash_attn_varlen_func
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
# Embedding
hidden_states = self.model.model.embed_tokens(input_ids)
residual = None
for layer_id in range(num_layers):
layer = self.model.model.layers[layer_id]
# Input LayerNorm
if residual is None:
hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states
else:
hidden_ln, residual = layer.input_layernorm(hidden_states, residual)
# QKV projection
qkv = layer.self_attn.qkv_proj(hidden_ln)
q, k_new, v_new = qkv.split([
layer.self_attn.q_size,
layer.self_attn.kv_size,
layer.self_attn.kv_size
], dim=-1)
q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim)
k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
v_new = v_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# Q/K norms
if not layer.self_attn.qkv_bias:
q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim))
q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim)
k_new = layer.self_attn.k_norm(k_new.reshape(-1, layer.self_attn.head_dim))
k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# RoPE
q, k_new = layer.self_attn.rotary_emb(positions, q, k_new)
# Store new K,V to cache
k_cache[layer_id, context_len] = k_new.squeeze(0)
v_cache[layer_id, context_len] = v_new.squeeze(0)
# Full K,V for attention (including new token)
k_full = k_cache[layer_id, :context_len + 1]
v_full = v_cache[layer_id, :context_len + 1]
# Attention
cu_seqlens_k = torch.tensor([0, context_len + 1], dtype=torch.int32, device="cuda")
attn_output = flash_attn_varlen_func(
q, k_full, v_full,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=1,
max_seqlen_k=context_len + 1,
softmax_scale=layer.self_attn.attn.scale,
causal=False, # Single query, no causal needed
)
# O projection
attn_output = attn_output.view(1, -1)
hidden_states = layer.self_attn.o_proj(attn_output)
# Post-attention LayerNorm + MLP
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
hidden_states = layer.mlp(hidden_states)
# Update context length
self.kvcache_manager.contiguous_seq_len = context_len + 1
# Final norm
hidden_states, _ = self.model.model.norm(hidden_states, residual)
# Compute logits
logits = self.model.compute_logits(hidden_states)
# Sample
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
return token_ids
def _should_use_layerwise_offload(self, seqs: list[Sequence], is_prefill: bool) -> bool:
"""
Check if layer-wise offload mode should be used.