[claudesquad] update from 'fix-bug-2' on 09 Jan 26 16:05 CST
This commit is contained in:
@@ -45,14 +45,7 @@ class ModelRunner:
|
||||
self.allocate_kv_cache()
|
||||
if not self.enforce_eager:
|
||||
if config.enable_cpu_offload:
|
||||
# TODO: Implement capture_offload_cudagraph() for offload mode
|
||||
# For now, offload mode uses eager execution
|
||||
# The standard capture_cudagraph() cannot be used because:
|
||||
# - It captures the PagedAttention decode path via Attention.forward()
|
||||
# - In offload mode, Attention.k_cache/v_cache are empty (KV is in ring buffer)
|
||||
# - The refactored offload decode now uses Attention.forward() with ring buffer
|
||||
# - Need specialized graph capture that sets up ring buffer correctly
|
||||
pass
|
||||
self.capture_offload_cudagraph()
|
||||
else:
|
||||
self.capture_cudagraph()
|
||||
torch.set_default_device("cpu")
|
||||
@@ -74,7 +67,10 @@ class ModelRunner:
|
||||
if self.rank == 0:
|
||||
self.shm.unlink()
|
||||
if not self.enforce_eager:
|
||||
del self.graphs, self.graph_pool
|
||||
if hasattr(self, 'graphs'):
|
||||
del self.graphs, self.graph_pool
|
||||
if hasattr(self, 'offload_graphs'):
|
||||
del self.offload_graphs, self.offload_graph_pool
|
||||
# torch.cuda.synchronize()
|
||||
dist.destroy_process_group()
|
||||
|
||||
@@ -858,6 +854,7 @@ class ModelRunner:
|
||||
- Uses standard Attention.forward() path (not bypassing)
|
||||
- Per-layer decode buffer for accumulating new tokens
|
||||
- Async block offload when decode buffer is full
|
||||
- Uses CUDA graphs when available (not enforce_eager)
|
||||
"""
|
||||
assert len(seqs) == 1, "Layer-wise offload only supports single sequence"
|
||||
seq = seqs[0]
|
||||
@@ -867,9 +864,20 @@ class ModelRunner:
|
||||
num_layers = len(self.model.model.layers)
|
||||
num_buffers = offload_engine.num_kv_buffers
|
||||
|
||||
# Check if using CUDA graphs
|
||||
use_cuda_graph = not self.enforce_eager and hasattr(self, 'offload_graphs')
|
||||
|
||||
# Prepare inputs
|
||||
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, device="cuda")
|
||||
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, device="cuda")
|
||||
if use_cuda_graph:
|
||||
# Use fixed-address tensors for graph replay
|
||||
graph_vars = self.offload_graph_vars
|
||||
graph_vars["input_ids"][0] = seq.last_token
|
||||
graph_vars["positions"][0] = len(seq) - 1
|
||||
input_ids = graph_vars["input_ids"]
|
||||
positions = graph_vars["positions"]
|
||||
else:
|
||||
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, device="cuda")
|
||||
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, device="cuda")
|
||||
|
||||
# Get prefilled CPU blocks and compute valid tokens per block
|
||||
cpu_block_table = self.kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||
@@ -898,8 +906,14 @@ class ModelRunner:
|
||||
context_len = total_prefill_tokens + num_prev_decode_tokens
|
||||
|
||||
# Context setup for Attention.forward() - contiguous mode (no block tables)
|
||||
slot_mapping = torch.tensor([context_len], dtype=torch.int32, device="cuda")
|
||||
context_lens = torch.tensor([context_len + 1], dtype=torch.int32, device="cuda")
|
||||
if use_cuda_graph:
|
||||
graph_vars["slot_mapping"][0] = context_len
|
||||
graph_vars["context_lens"][0] = context_len + 1
|
||||
slot_mapping = graph_vars["slot_mapping"]
|
||||
context_lens = graph_vars["context_lens"]
|
||||
else:
|
||||
slot_mapping = torch.tensor([context_len], dtype=torch.int32, device="cuda")
|
||||
context_lens = torch.tensor([context_len + 1], dtype=torch.int32, device="cuda")
|
||||
|
||||
# Phase 1: Preload first N layers to ring buffer (fill pipeline)
|
||||
num_preload = min(num_buffers, num_layers)
|
||||
@@ -910,8 +924,14 @@ class ModelRunner:
|
||||
|
||||
# Step 1: Embedding (on compute stream)
|
||||
with torch.cuda.stream(compute_stream):
|
||||
hidden_states = self.model.model.embed_tokens(input_ids)
|
||||
residual = None
|
||||
if use_cuda_graph:
|
||||
# Copy embedding output to graph's hidden_states
|
||||
embedded = self.model.model.embed_tokens(input_ids)
|
||||
graph_vars["hidden_states"].copy_(embedded)
|
||||
graph_vars["residual"].zero_() # Reset residual for first layer
|
||||
else:
|
||||
hidden_states = self.model.model.embed_tokens(input_ids)
|
||||
residual = None
|
||||
|
||||
# Phase 2: Layer-by-layer processing with ring buffer pipeline
|
||||
for layer_id in range(num_layers):
|
||||
@@ -947,12 +967,22 @@ class ModelRunner:
|
||||
block_tables=None, # Contiguous mode, no block tables
|
||||
)
|
||||
|
||||
# 2e. Forward through layer using standard path
|
||||
# This calls Qwen3Attention.forward() -> Attention.forward()
|
||||
# Attention.forward() will:
|
||||
# - Store new K,V to ring buffer via store_kvcache
|
||||
# - Compute attention via flash_attn_with_kvcache
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
if use_cuda_graph:
|
||||
# 2e. Replay CUDA graph for this layer
|
||||
self.offload_graphs[layer_id].replay()
|
||||
# Synchronize to ensure graph completes before next operation
|
||||
torch.cuda.current_stream().synchronize()
|
||||
# Copy outputs to inputs for next layer
|
||||
if layer_id < num_layers - 1:
|
||||
graph_vars["hidden_states"].copy_(graph_vars["layer_outputs"])
|
||||
graph_vars["residual"].copy_(graph_vars["layer_residual"])
|
||||
else:
|
||||
# 2e. Forward through layer using standard path (eager mode)
|
||||
# This calls Qwen3Attention.forward() -> Attention.forward()
|
||||
# Attention.forward() will:
|
||||
# - Store new K,V to ring buffer via store_kvcache
|
||||
# - Compute attention via flash_attn_with_kvcache
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
|
||||
# 2f. Copy new token's KV from ring buffer to decode buffer (for persistence)
|
||||
# The new token was stored at position context_len in ring buffer
|
||||
@@ -972,7 +1002,12 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
# Step 3: Final norm
|
||||
hidden_states, _ = self.model.model.norm(hidden_states, residual)
|
||||
if use_cuda_graph:
|
||||
hidden_states, _ = self.model.model.norm(
|
||||
graph_vars["layer_outputs"], graph_vars["layer_residual"]
|
||||
)
|
||||
else:
|
||||
hidden_states, _ = self.model.model.norm(hidden_states, residual)
|
||||
|
||||
# Step 4: Compute logits
|
||||
logits = self.model.compute_logits(hidden_states)
|
||||
@@ -1036,3 +1071,94 @@ class ModelRunner:
|
||||
block_tables=block_tables,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
def capture_offload_cudagraph(self):
    """
    Record one CUDA graph per decoder layer for offload-mode decode.

    Every layer is captured separately against a shared set of
    fixed-address tensors (batch size 1). Before a layer is recorded, its
    attention module's ``k_cache`` / ``v_cache`` are pointed at slot
    ``layer_id % num_buffers`` of the offload engine's ring buffer and the
    context is set to decode mode without block tables. The layer is first
    run once eagerly as a warmup and then recorded; only the layer forward
    pass and the copies into the output buffers are part of the graph.

    Side effects:
        - ``self.offload_graphs``: dict mapping layer_id -> captured graph.
        - ``self.offload_graph_pool``: memory pool of the first captured
          graph, handed to every later capture.
        - ``self.offload_graph_vars``: dict of the fixed-address tensors
          (inputs, per-layer state and per-layer outputs) used for replay.
        - Each layer's attention ``k_cache`` / ``v_cache`` is left pointing
          at its ring-buffer slot.
    """
    offload_engine = self.kvcache_manager.offload_engine
    decoder_layers = self.model.model.layers
    num_layers = len(decoder_layers)
    num_buffers = offload_engine.num_kv_buffers
    hf_config = self.config.hf_config

    logger.info(f"Capturing offload CUDA graphs: {num_layers} layers, {num_buffers} buffers")

    # Shared construction options for the fixed-address tensors.
    cuda_i64 = dict(dtype=torch.int64, device="cuda")
    cuda_i32 = dict(dtype=torch.int32, device="cuda")
    hidden_shape = (1, hf_config.hidden_size)
    hidden_opts = dict(dtype=hf_config.torch_dtype, device="cuda")

    # Graph inputs: a single token, since offload decode runs batch_size=1.
    input_ids = torch.zeros(1, **cuda_i64)
    positions = torch.zeros(1, **cuda_i64)
    slot_mapping = torch.zeros(1, **cuda_i32)
    # Starts at 1 so the attention call sees at least one valid position.
    context_lens = torch.ones(1, **cuda_i32)

    # Per-layer input state; hidden_states is drawn before residual.
    hidden_states = torch.randn(hidden_shape, **hidden_opts)
    residual = torch.randn(hidden_shape, **hidden_opts)

    # Per-layer output state written by every warmup run and captured graph.
    layer_outputs = torch.zeros(hidden_shape, **hidden_opts)
    layer_residual = torch.zeros(hidden_shape, **hidden_opts)

    self.offload_graphs = {}
    self.offload_graph_pool = None

    def _forward_into_outputs(decoder_layer):
        # Run one layer on the fixed inputs and store its results in the
        # fixed output buffers (used for both warmup and capture).
        new_hidden, new_residual = decoder_layer(positions, hidden_states, residual)
        layer_outputs.copy_(new_hidden)
        layer_residual.copy_(new_residual)

    for layer_id, layer in enumerate(decoder_layers):
        # Ring buffer mapping: each layer reuses slot layer_id % num_buffers.
        ring_slot = layer_id % num_buffers
        slot_view = slice(ring_slot, ring_slot + 1)

        # Point this layer's attention cache at its ring-buffer slot so the
        # captured graph sees a fixed cache address.
        attn_module = layer.self_attn.attn
        attn_module.k_cache = offload_engine.layer_k_cache[slot_view]
        attn_module.v_cache = offload_engine.layer_v_cache[slot_view]

        # Decode-mode context in contiguous mode (no block tables).
        set_context(
            is_prefill=False,
            slot_mapping=slot_mapping,
            context_lens=context_lens,
            block_tables=None,
        )

        # Eager warmup pass, fully finished before capture begins.
        _forward_into_outputs(layer)
        torch.cuda.synchronize()

        # Record the same computation on the same tensors.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, self.offload_graph_pool):
            _forward_into_outputs(layer)

        # The first graph's pool is reused by all subsequent captures.
        if self.offload_graph_pool is None:
            self.offload_graph_pool = graph.pool()

        self.offload_graphs[layer_id] = graph
        reset_context()

        # Feed this layer's outputs forward as the next layer's inputs.
        hidden_states.copy_(layer_outputs)
        residual.copy_(layer_residual)

    # Keep the fixed-address tensors reachable for graph replay.
    self.offload_graph_vars = {
        "input_ids": input_ids,
        "positions": positions,
        "slot_mapping": slot_mapping,
        "context_lens": context_lens,
        "hidden_states": hidden_states,
        "residual": residual,
        "layer_outputs": layer_outputs,
        "layer_residual": layer_residual,
    }

    logger.info(f"Captured {num_layers} offload CUDA graphs")
|
||||
|
||||
Reference in New Issue
Block a user