[WIP] FIXED decode and prefill NEEDLE test.

This commit is contained in:
Zijie Tian
2026-01-05 01:51:46 +08:00
parent e897380127
commit d623043a3c
3 changed files with 204 additions and 28 deletions

View File

@@ -35,7 +35,10 @@ class ModelRunner:
self.model = Qwen3ForCausalLM(hf_config)
load_model(self.model, config.model)
self.sampler = GreedySampler()
self.warmup_model()
#> Disable warmup for debugging
# self.warmup_model()
self.allocate_kv_cache()
if not self.enforce_eager:
self.capture_cudagraph()
@@ -194,7 +197,7 @@ class ModelRunner:
f"block_size={self.block_size}"
)
# Bind layer caches to attention modules and set layer_id
#> Bind layer caches to attention modules and set layer_id
layer_id = 0
for module in self.model.modules():
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):