[WIP] Need to modify communication.
@@ -8,6 +8,7 @@ Key design principles for CUDA Graph compatibility:
"""

import torch
import torch.cuda.nvtx
from torch import Tensor
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
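The only change in this hunk is the new torch.cuda.nvtx import; the ranges pushed and popped in the hunks below show up as named, nested spans on the CUDA timeline when the process is profiled (for example with Nsight Systems). A minimal sketch of the pattern, with an illustrative function and range name that are not part of this commit:

import torch
import torch.cuda.nvtx

def profiled_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Named NVTX ranges nest like a stack and appear as labelled spans in the profiler.
    torch.cuda.nvtx.range_push("matmul")   # open a range
    out = a @ b
    torch.cuda.nvtx.range_pop()            # close the most recently opened range
    return out

if __name__ == "__main__":
    if torch.cuda.is_available():
        x = torch.randn(1024, 1024, device="cuda")
        profiled_matmul(x, x)
        torch.cuda.synchronize()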
@@ -660,6 +661,7 @@ class OffloadEngine:
        """
        logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")

        torch.cuda.nvtx.range_push(f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]")
        with torch.cuda.stream(self.transfer_stream_main):
            # Wait for previous compute on this slot to complete before overwriting.
            # This prevents a data race: the transfer must not start until attention has finished reading.
@@ -672,6 +674,7 @@ class OffloadEngine:
                self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
            )
            self.ring_slot_ready[slot_idx][layer_id].record(self.transfer_stream_main)
        torch.cuda.nvtx.range_pop()

    def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:
        """
@@ -718,6 +721,7 @@ class OffloadEngine:
        """
        logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]")

        torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
        with torch.cuda.stream(self.transfer_stream_main):
            self.transfer_stream_main.wait_stream(self.compute_stream)
            self.k_cache_cpu[:, cpu_block_id].copy_(
@@ -727,6 +731,7 @@ class OffloadEngine:
                self.v_cache_gpu[:, slot_idx], non_blocking=True
            )
            self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)
        torch.cuda.nvtx.range_pop()

    def wait_slot_offload(self, slot_idx: int) -> None:
        """Wait for slot offload to complete."""

@@ -1,5 +1,6 @@
import logging
import torch
import torch.cuda.nvtx
from torch import nn
import triton
import triton.language as tl
@@ -117,6 +118,9 @@ class Attention(nn.Module):
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        current_chunk_idx = context.current_chunk_idx
        torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")

        # q, k, v shape: [total_tokens, num_heads, head_dim]
        # Reshape for flash attention: [batch, seq, heads, dim]
        q_batched = q.unsqueeze(0)  # [1, total_tokens, heads, dim]
@@ -128,7 +132,6 @@ class Attention(nn.Module):

        kvcache_manager = context.kvcache_manager
        seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
        current_chunk_idx = context.current_chunk_idx

        if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
            # Get prefilled CPU blocks (blocks from previous chunks)
@@ -170,6 +173,7 @@ class Attention(nn.Module):
            )

        # Compute attention against current chunk's KV (with causal mask)
        torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
        current_o, current_lse = flash_attn_with_lse(
            q_batched,
            k_batched,
@@ -177,13 +181,17 @@ class Attention(nn.Module):
            softmax_scale=self.scale,
            causal=True,
        )
        torch.cuda.nvtx.range_pop()

        # Merge with accumulated
        if o_acc is None:
            final_o = current_o
        else:
            torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
            final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
            torch.cuda.nvtx.range_pop()

        torch.cuda.nvtx.range_pop()  # ChunkedPrefill
        # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
        return final_o.squeeze(0)
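In the merge step above, merge_attention_outputs folds the current chunk's attention output into the accumulator from previous chunks using the log-sum-exp (LSE) statistics returned by flash_attn_with_lse. One standard way to implement such a merge is shown below as a plain-PyTorch reference; the lse layout of [batch, seq_q, heads] is an assumption, and the real helper in nanovllm.kvcache.chunked_attention may differ in details:

import torch
from torch import Tensor
from typing import Tuple

def merge_attention_outputs_ref(
    o1: Tensor, lse1: Tensor,   # o*: [batch, seq_q, heads, head_dim]
    o2: Tensor, lse2: Tensor,   # lse*: [batch, seq_q, heads], log-sum-exp of the raw scores
) -> Tuple[Tensor, Tensor]:
    """Exactly merge two partial softmax-attention results computed over disjoint key sets."""
    # Combined normalizer in log space; logaddexp keeps this numerically stable.
    lse = torch.logaddexp(lse1, lse2)
    # Each partial output is normalized by its own sum of exponentials, so reweighting
    # by exp(lse_i - lse) renormalizes it against the union of both key sets.
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse

This is the same renormalization trick chunked and flash attention use to process keys block by block without materializing the full score matrix.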
@@ -287,6 +295,8 @@ class Attention(nn.Module):
        offload_engine.load_to_slot_layer(slot_A, self.layer_id, cpu_block_table[0])

        for block_idx in range(num_blocks):
            torch.cuda.nvtx.range_push(f"PipelineBlock: L{self.layer_id} B{block_idx}")

            # Alternate between slot_A and slot_B
            current_slot = slot_A if block_idx % 2 == 0 else slot_B
            next_slot = slot_B if block_idx % 2 == 0 else slot_A
@@ -300,12 +310,14 @@ class Attention(nn.Module):
            offload_engine.load_to_slot_layer(next_slot, self.layer_id, cpu_block_table[block_idx + 1])

            # Compute attention on current slot's data
            torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
            prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
            prev_o, prev_lse = flash_attn_with_lse(
                q_batched, prev_k, prev_v,
                softmax_scale=self.scale,
                causal=False,
            )
            torch.cuda.nvtx.range_pop()

            # Record compute done - this allows the next round to safely load into this slot
            offload_engine.record_slot_compute_done(current_slot, self.layer_id)
@@ -316,6 +328,8 @@ class Attention(nn.Module):
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

            torch.cuda.nvtx.range_pop()  # PipelineBlock

        return o_acc, lse_acc
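The loop above is a classic two-slot (ping-pong) pipeline: while attention runs on the block resident in current_slot, the next CPU block is prefetched into next_slot on the transfer stream, and the per-slot events keep the two from stepping on each other. Stripped of the attention specifics, the schedule looks like the sketch below; loader follows the illustrative RingSlotLoader interface from the earlier sketches and process is a stand-in for the per-block attention call:

def pingpong_pipeline(loader, cpu_blocks, process):
    """Overlap the H2D prefetch of block i+1 with compute on block i using two slots."""
    slot_a, slot_b = 0, 1
    outputs = []

    loader.load(slot_a, cpu_blocks[0])                  # prime the pipeline
    for i in range(len(cpu_blocks)):
        cur = slot_a if i % 2 == 0 else slot_b
        nxt = slot_b if i % 2 == 0 else slot_a

        if i + 1 < len(cpu_blocks):
            loader.load(nxt, cpu_blocks[i + 1])         # prefetch on the transfer stream

        loader.wait_ready(cur)                          # compute stream waits for the copy
        outputs.append(process(cur))                    # e.g. flash_attn_with_lse on that slot
        loader.mark_compute_done(cur)                   # allow the next load to reuse this slot
    return outputs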

    def _chunked_decode_attention(