[WIP] NEED to modify communication.

This commit is contained in:
Zijie Tian
2025-12-24 21:57:51 +08:00
parent 782437c486
commit 6ec1b23982
9 changed files with 462 additions and 2 deletions

View File

@@ -1,5 +1,6 @@
import logging
import torch
import torch.cuda.nvtx
from torch import nn
import triton
import triton.language as tl
@@ -117,6 +118,9 @@ class Attention(nn.Module):
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
current_chunk_idx = context.current_chunk_idx
torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
# q, k, v shape: [total_tokens, num_heads, head_dim]
# Reshape for flash attention: [batch, seq, heads, dim]
q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
@@ -128,7 +132,6 @@ class Attention(nn.Module):
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
current_chunk_idx = context.current_chunk_idx
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
# Get prefilled CPU blocks (blocks from previous chunks)
@@ -170,6 +173,7 @@ class Attention(nn.Module):
)
# Compute attention against current chunk's KV (with causal mask)
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
@@ -177,13 +181,17 @@ class Attention(nn.Module):
softmax_scale=self.scale,
causal=True,
)
torch.cuda.nvtx.range_pop()
# Merge with accumulated
if o_acc is None:
final_o = current_o
else:
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop() # ChunkedPrefill
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)
@@ -287,6 +295,8 @@ class Attention(nn.Module):
offload_engine.load_to_slot_layer(slot_A, self.layer_id, cpu_block_table[0])
for block_idx in range(num_blocks):
torch.cuda.nvtx.range_push(f"PipelineBlock: L{self.layer_id} B{block_idx}")
# Alternate between slot_A and slot_B
current_slot = slot_A if block_idx % 2 == 0 else slot_B
next_slot = slot_B if block_idx % 2 == 0 else slot_A
@@ -300,12 +310,14 @@ class Attention(nn.Module):
offload_engine.load_to_slot_layer(next_slot, self.layer_id, cpu_block_table[block_idx + 1])
# Compute attention on current slot's data
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
torch.cuda.nvtx.range_pop()
# Record compute done - this allows the next round to safely load into this slot
offload_engine.record_slot_compute_done(current_slot, self.layer_id)
@@ -316,6 +328,8 @@ class Attention(nn.Module):
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
torch.cuda.nvtx.range_pop() # PipelineBlock
return o_acc, lse_acc
def _chunked_decode_attention(