[WIP] FIXED decode and prefill NEEDLE test.

This commit is contained in:
Zijie Tian
2026-01-05 01:51:46 +08:00
parent e897380127
commit d623043a3c
3 changed files with 204 additions and 28 deletions

View File

@@ -35,7 +35,10 @@ class ModelRunner:
self.model = Qwen3ForCausalLM(hf_config)
load_model(self.model, config.model)
self.sampler = GreedySampler()
self.warmup_model()
#> Disable warmup for debugging
# self.warmup_model()
self.allocate_kv_cache()
if not self.enforce_eager:
self.capture_cudagraph()
@@ -194,7 +197,7 @@ class ModelRunner:
f"block_size={self.block_size}"
)
# Bind layer caches to attention modules and set layer_id
#> Bind layer caches to attention modules and set layer_id
layer_id = 0
for module in self.model.modules():
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):

View File

@@ -23,15 +23,19 @@ parser = argparse.ArgumentParser()
parser.add_argument("--enable-offload", action="store_true", help="Enable CPU offload")
parser.add_argument("--input-len", type=int, default=1024 * 12, help="Input sequence length")
parser.add_argument("--model-path", type=str, default="~/models/Qwen3-0.6B/", help="Model path")
parser.add_argument("--num-gpu-blocks", type=int, default=6, help="Number of GPU blocks (ring buffer slots)")
parser.add_argument("--block-size", type=int, default=1024, help="KV cache block size")
args = parser.parse_args()
# Config
MODEL_PATH = os.path.expanduser(args.model_path)
INPUT_LEN = args.input_len
ENABLE_OFFLOAD = args.enable_offload
NUM_GPU_BLOCKS = args.num_gpu_blocks
BLOCK_SIZE = args.block_size
DTYPE = torch.float16
print(f"Config: input_len={INPUT_LEN}, enable_offload={ENABLE_OFFLOAD}")
print(f"Config: input_len={INPUT_LEN}, enable_offload={ENABLE_OFFLOAD}, num_gpu_blocks={NUM_GPU_BLOCKS}, block_size={BLOCK_SIZE}")
# Storage for captured tensors
nanovllm_outputs = {}
@@ -41,6 +45,9 @@ nanovllm_proj_inputs = {}
torch_proj_inputs = {}
# ============================================================
# Hook functions for non-offload mode (overwrite)
# ============================================================
def make_nanovllm_hook(layer_id: int, storage: dict):
def hook(module, inputs, output):
attn_output = output[0] if isinstance(output, tuple) else output
@@ -70,6 +77,70 @@ def make_proj_input_hook(layer_id: int, storage: dict):
return hook
# ============================================================
# Hook functions for offload mode (accumulate Q and I, overwrite O)
# ============================================================
def make_accumulating_q_hook(layer_id: int, storage: dict):
"""Accumulate Q from each chunk for offload mode."""
def hook(module, inputs):
q = inputs[0].detach().clone()
if layer_id not in storage:
storage[layer_id] = []
storage[layer_id].append(q)
return hook
def make_accumulating_input_hook(layer_id: int, storage: dict):
"""Accumulate input hidden states from each chunk for offload mode."""
def hook(module, inputs):
hidden = inputs[0].detach().clone()
if layer_id not in storage:
storage[layer_id] = []
storage[layer_id].append(hidden)
return hook
def make_overwrite_output_hook(layer_id: int, storage: dict):
"""Overwrite output (keep only last chunk) for offload mode."""
def hook(module, inputs, output):
attn_output = output[0] if isinstance(output, tuple) else output
if attn_output.dim() == 2:
attn_output = attn_output.unsqueeze(0)
storage[layer_id] = attn_output.detach().clone()
return hook
# ============================================================
# CPU KV cache access for offload mode
# ============================================================
def get_nanovllm_kv_from_cpu(llm, seq, num_layers):
"""Get complete K, V cache from CPU side after all chunks finish."""
offload_engine = llm.model_runner.kvcache_manager.offload_engine
kvcache_manager = llm.model_runner.kvcache_manager
# CRITICAL: Synchronize all CUDA operations before reading CPU memory
# The D2H copy runs on transfer_stream_main and may still be in progress
torch.cuda.synchronize()
cpu_block_ids = kvcache_manager.get_cpu_block_table(seq)
kv_per_layer = {}
for layer_id in range(num_layers):
k_blocks = []
v_blocks = []
for cpu_block_id in cpu_block_ids:
k_block, v_block = offload_engine.get_cpu_block(layer_id, cpu_block_id)
k_blocks.append(k_block)
v_blocks.append(v_block)
# Concatenate all blocks: [total_tokens, kv_heads, head_dim]
k_full = torch.cat(k_blocks, dim=0)[:seq.num_tokens]
v_full = torch.cat(v_blocks, dim=0)[:seq.num_tokens]
kv_per_layer[layer_id] = {"k": k_full, "v": v_full}
return kv_per_layer
def make_torch_hook(layer_id: int, storage: dict):
def hook(module, inputs, output):
storage[layer_id] = output[0].detach().clone()
@@ -101,15 +172,18 @@ def compute_qkv_sims(nano_qkv: dict, torch_qkv: dict, num_kv_groups: int):
# Load models
# ============================================================
print("Loading nanovllm model...")
llm = LLM(
MODEL_PATH,
llm_kwargs = dict(
enforce_eager=True,
max_model_len=32768,
gpu_memory_utilization=0.2,
max_num_batched_tokens=32768,
enable_cpu_offload=ENABLE_OFFLOAD,
dtype="float16",
kvcache_block_size=BLOCK_SIZE,
)
if ENABLE_OFFLOAD:
llm_kwargs["num_gpu_blocks"] = NUM_GPU_BLOCKS
llm = LLM(MODEL_PATH, **llm_kwargs)
num_heads = llm.model_runner.model.model.layers[0].self_attn.num_heads
num_kv_heads = llm.model_runner.model.model.layers[0].self_attn.num_kv_heads
@@ -133,10 +207,20 @@ print(f"Input shape: {input_ids.shape}")
# Register hooks
# ============================================================
nanovllm_hooks = []
nanovllm_q_accum = {} # For offload mode: accumulated Q from all chunks
nanovllm_i_accum = {} # For offload mode: accumulated I from all chunks
for layer_idx, layer in enumerate(llm.model_runner.model.model.layers):
nanovllm_hooks.append(layer.self_attn.register_forward_hook(make_nanovllm_hook(layer_idx, nanovllm_outputs)))
nanovllm_hooks.append(layer.self_attn.attn.register_forward_pre_hook(make_nanovllm_qkv_hook(layer_idx, nanovllm_qkv)))
nanovllm_hooks.append(layer.self_attn.qkv_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, nanovllm_proj_inputs)))
if ENABLE_OFFLOAD:
# Offload mode: accumulate Q and I, overwrite O
nanovllm_hooks.append(layer.self_attn.register_forward_hook(make_overwrite_output_hook(layer_idx, nanovllm_outputs)))
nanovllm_hooks.append(layer.self_attn.attn.register_forward_pre_hook(make_accumulating_q_hook(layer_idx, nanovllm_q_accum)))
nanovllm_hooks.append(layer.self_attn.qkv_proj.register_forward_pre_hook(make_accumulating_input_hook(layer_idx, nanovllm_i_accum)))
else:
# Non-offload mode: overwrite all
nanovllm_hooks.append(layer.self_attn.register_forward_hook(make_nanovllm_hook(layer_idx, nanovllm_outputs)))
nanovllm_hooks.append(layer.self_attn.attn.register_forward_pre_hook(make_nanovllm_qkv_hook(layer_idx, nanovllm_qkv)))
nanovllm_hooks.append(layer.self_attn.qkv_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, nanovllm_proj_inputs)))
torch_hooks = []
for layer_idx, layer in enumerate(torch_model.model.layers):
@@ -147,7 +231,36 @@ for layer_idx, layer in enumerate(torch_model.model.layers):
# Run inference
# ============================================================
print("Running nanovllm inference...")
nanovllm_result = llm.generate([input_ids[0].tolist()], SamplingParams(temperature=0.01, max_tokens=1), use_tqdm=False)
if ENABLE_OFFLOAD:
# Manual execution to capture KV cache before deallocation
# Use max_tokens=2 so sequence doesn't finish immediately after prefill
llm.add_request(input_ids[0].tolist(), SamplingParams(temperature=0.01, max_tokens=2))
# Run prefill step (this calls run_chunked_offload_prefill internally)
output, num_tokens = llm.step()
print(f"[Offload] Prefill done: {num_tokens} tokens")
# Now seq is in running queue, KV cache is in CPU
seq = llm.scheduler.running[0]
print(f"[Offload] Sequence: {seq}")
# Get KV cache from CPU BEFORE decode step deallocates it
nanovllm_kv_cpu = get_nanovllm_kv_from_cpu(llm, seq, num_layers)
print(f"[Offload] Retrieved KV cache from CPU for {seq.num_tokens} tokens")
# IMPORTANT: Save outputs NOW before decode step overwrites them
# nanovllm_outputs contains prefill attention outputs at this point
nanovllm_outputs_prefill = {k: v.clone() for k, v in nanovllm_outputs.items()}
# Complete remaining steps (decode)
while not llm.is_finished():
llm.step()
# Use prefill outputs for comparison
nanovllm_outputs = nanovllm_outputs_prefill
else:
nanovllm_result = llm.generate([input_ids[0].tolist()], SamplingParams(temperature=0.01, max_tokens=1), use_tqdm=False)
print("Running torch inference...")
with torch.no_grad():
@@ -164,24 +277,72 @@ all_passed = True
threshold = 0.999 # Cosine similarity threshold
for layer_idx in range(num_layers):
# Input similarity
nano_in = nanovllm_proj_inputs[layer_idx]
torch_in = torch_proj_inputs[layer_idx]
if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel():
torch_in = torch_in.view(nano_in.shape)
i_sim = cosine_sim(nano_in, torch_in)
if ENABLE_OFFLOAD:
# ============================================================
# Offload mode: use accumulated Q/I and CPU-side K/V
# Only compare prompt tokens (INPUT_LEN), exclude generated tokens
# ============================================================
# I: concatenate accumulated chunks, trim to prompt length
i_chunks = nanovllm_i_accum[layer_idx]
nano_in = torch.cat(i_chunks, dim=0)[:INPUT_LEN]
if nano_in.dim() == 2:
nano_in = nano_in.unsqueeze(0)
torch_in = torch_proj_inputs[layer_idx]
if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel():
torch_in = torch_in.view(nano_in.shape)
i_sim = cosine_sim(nano_in, torch_in)
# QKV similarities
q_sim, k_sim, v_sim = compute_qkv_sims(nanovllm_qkv[layer_idx], torch_qkv_outputs[layer_idx], num_kv_groups)
# Q: concatenate accumulated chunks, trim to prompt length
q_chunks = nanovllm_q_accum[layer_idx]
nano_q = torch.cat(q_chunks, dim=0)[:INPUT_LEN]
torch_q = torch_qkv_outputs[layer_idx]["q"].squeeze(0).transpose(0, 1)
q_sim = cosine_sim(nano_q, torch_q)
# O similarity
nano_out = nanovllm_outputs[layer_idx]
torch_out = torch_outputs[layer_idx]
if nano_out.shape != torch_out.shape and nano_out.numel() == torch_out.numel():
torch_out = torch_out.view(nano_out.shape)
o_sim = cosine_sim(nano_out, torch_out)
# K, V: from CPU cache, trim to prompt length and move to GPU
nano_k = nanovllm_kv_cpu[layer_idx]["k"][:INPUT_LEN].cuda()
torch_k = torch_qkv_outputs[layer_idx]["k"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
k_sim = cosine_sim(nano_k, torch_k)
nano_v = nanovllm_kv_cpu[layer_idx]["v"][:INPUT_LEN].cuda()
torch_v = torch_qkv_outputs[layer_idx]["v"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
v_sim = cosine_sim(nano_v, torch_v)
# O: compare attention outputs directly
# For single-chunk case (input_len <= block_size), shapes should match
# For multi-chunk case, nano_out is the last chunk only
nano_out = nanovllm_outputs[layer_idx]
torch_out = torch_outputs[layer_idx]
if nano_out.numel() == torch_out.numel():
# Single chunk or shapes match - compare directly
o_sim = cosine_sim(nano_out, torch_out)
else:
# Multi-chunk case: compare last chunk with corresponding torch slice
last_chunk_len = nano_out.shape[1] if nano_out.dim() == 3 else nano_out.shape[0]
torch_out_slice = torch_out[:, -last_chunk_len:, :] if torch_out.dim() == 3 else torch_out[-last_chunk_len:, :]
o_sim = cosine_sim(nano_out, torch_out_slice)
else:
# ============================================================
# Non-offload mode: original logic
# ============================================================
# Input similarity
nano_in = nanovllm_proj_inputs[layer_idx]
torch_in = torch_proj_inputs[layer_idx]
if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel():
torch_in = torch_in.view(nano_in.shape)
i_sim = cosine_sim(nano_in, torch_in)
# QKV similarities
q_sim, k_sim, v_sim = compute_qkv_sims(nanovllm_qkv[layer_idx], torch_qkv_outputs[layer_idx], num_kv_groups)
# O similarity
nano_out = nanovllm_outputs[layer_idx]
torch_out = torch_outputs[layer_idx]
if nano_out.shape != torch_out.shape and nano_out.numel() == torch_out.numel():
torch_out = torch_out.view(nano_out.shape)
o_sim = cosine_sim(nano_out, torch_out)
# Check pass/fail
passed = all(s >= threshold for s in [i_sim, q_sim, k_sim, v_sim, o_sim])
@@ -197,7 +358,8 @@ for hook in nanovllm_hooks + torch_hooks:
hook.remove()
print("=" * 70)
mode_str = " [offload]" if ENABLE_OFFLOAD else ""
if all_passed:
print("test_align: PASSED (cosine_sim >= 0.999)")
print(f"test_align{mode_str}: PASSED (cosine_sim >= 0.999)")
else:
print("test_align: FAILED (* = cosine_sim < 0.999)")
print(f"test_align{mode_str}: FAILED (* = cosine_sim < 0.999)")

View File

@@ -24,6 +24,7 @@ def run_needle_test(
max_model_len: int,
input_len: int,
num_gpu_blocks: int = 4,
block_size: int = 1024,
needle_position: float = 0.5,
needle_value: str = "7492",
max_new_tokens: int = 32,
@@ -38,6 +39,7 @@ def run_needle_test(
max_model_len: Maximum model context length
input_len: Target input sequence length
num_gpu_blocks: Number of GPU blocks for offload
block_size: KV cache block size
needle_position: Where to place needle (0.0-1.0)
needle_value: The secret value to find
max_new_tokens: Maximum tokens to generate
@@ -54,6 +56,7 @@ def run_needle_test(
print(f"Model: {model_path}")
print(f"Max model len: {max_model_len}")
print(f"Input length: {input_len}")
print(f"Block size: {block_size}")
print(f"Needle position: {needle_position:.0%}")
print(f"Needle value: {needle_value}")
print(f"CPU offload: {enable_cpu_offload}")
@@ -65,6 +68,7 @@ def run_needle_test(
"max_model_len": max_model_len,
"max_num_batched_tokens": max_model_len,
"enable_cpu_offload": enable_cpu_offload,
"kvcache_block_size": block_size,
}
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
@@ -119,7 +123,7 @@ if __name__ == "__main__":
parser.add_argument(
"--max-model-len",
type=int,
default=32 * 1024,
default=36 * 1024,
help="Maximum model context length"
)
parser.add_argument(
@@ -134,6 +138,12 @@ if __name__ == "__main__":
default=2,
help="Number of GPU blocks for CPU offload"
)
parser.add_argument(
"--block-size",
type=int,
default=1024,
help="KV cache block size"
)
parser.add_argument(
"--needle-position",
type=float,
@@ -164,6 +174,7 @@ if __name__ == "__main__":
max_model_len=args.max_model_len,
input_len=args.input_len,
num_gpu_blocks=args.num_gpu_blocks,
block_size=args.block_size,
needle_position=args.needle_position,
needle_value=args.needle_value,
max_new_tokens=args.max_new_tokens,