[WIP] NEED to modify communication.

This commit is contained in:
Zijie Tian
2025-12-24 21:57:51 +08:00
parent 782437c486
commit 6ec1b23982
9 changed files with 462 additions and 2 deletions

View File

@@ -8,6 +8,7 @@ Key design principles for CUDA Graph compatibility:
"""
import torch
import torch.cuda.nvtx
from torch import Tensor
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
@@ -660,6 +661,7 @@ class OffloadEngine:
"""
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
torch.cuda.nvtx.range_push(f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]")
with torch.cuda.stream(self.transfer_stream_main):
# Wait for previous compute on this slot to complete before overwriting
# This prevents data race: transfer must not start until attention finishes reading
@@ -672,6 +674,7 @@ class OffloadEngine:
self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.ring_slot_ready[slot_idx][layer_id].record(self.transfer_stream_main)
torch.cuda.nvtx.range_pop()
def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:
"""
@@ -718,6 +721,7 @@ class OffloadEngine:
"""
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]")
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
self.k_cache_cpu[:, cpu_block_id].copy_(
@@ -727,6 +731,7 @@ class OffloadEngine:
self.v_cache_gpu[:, slot_idx], non_blocking=True
)
self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)
torch.cuda.nvtx.range_pop()
def wait_slot_offload(self, slot_idx: int) -> None:
"""Wait for slot offload to complete."""

View File

@@ -1,5 +1,6 @@
import logging
import torch
import torch.cuda.nvtx
from torch import nn
import triton
import triton.language as tl
@@ -117,6 +118,9 @@ class Attention(nn.Module):
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
current_chunk_idx = context.current_chunk_idx
torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
# q, k, v shape: [total_tokens, num_heads, head_dim]
# Reshape for flash attention: [batch, seq, heads, dim]
q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
@@ -128,7 +132,6 @@ class Attention(nn.Module):
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
current_chunk_idx = context.current_chunk_idx
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
# Get prefilled CPU blocks (blocks from previous chunks)
@@ -170,6 +173,7 @@ class Attention(nn.Module):
)
# Compute attention against current chunk's KV (with causal mask)
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
@@ -177,13 +181,17 @@ class Attention(nn.Module):
softmax_scale=self.scale,
causal=True,
)
torch.cuda.nvtx.range_pop()
# Merge with accumulated
if o_acc is None:
final_o = current_o
else:
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop() # ChunkedPrefill
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)
@@ -287,6 +295,8 @@ class Attention(nn.Module):
offload_engine.load_to_slot_layer(slot_A, self.layer_id, cpu_block_table[0])
for block_idx in range(num_blocks):
torch.cuda.nvtx.range_push(f"PipelineBlock: L{self.layer_id} B{block_idx}")
# Alternate between slot_A and slot_B
current_slot = slot_A if block_idx % 2 == 0 else slot_B
next_slot = slot_B if block_idx % 2 == 0 else slot_A
@@ -300,12 +310,14 @@ class Attention(nn.Module):
offload_engine.load_to_slot_layer(next_slot, self.layer_id, cpu_block_table[block_idx + 1])
# Compute attention on current slot's data
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
torch.cuda.nvtx.range_pop()
# Record compute done - this allows the next round to safely load into this slot
offload_engine.record_slot_compute_done(current_slot, self.layer_id)
@@ -316,6 +328,8 @@ class Attention(nn.Module):
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
torch.cuda.nvtx.range_pop() # PipelineBlock
return o_acc, lse_acc
def _chunked_decode_attention(