[refactor] Refactor offload code to multi-chunk.
@@ -17,6 +17,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
**ModelRunner** (`nanovllm/engine/model_runner.py`):
- Loads model weights, allocates KV cache, captures CUDA graphs
- Rank 0 is main process; ranks 1+ run via `loop()` with shared memory events
- Chunked offload methods: `run_chunked_offload_prefill()`, `run_chunked_offload_decode()`

**Scheduler** (`nanovllm/engine/scheduler.py`):
- Two-phase scheduling: prefill (waiting queue) then decode (running queue)
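To make the control flow above concrete, here is a minimal sketch of one engine step: the scheduler decides whether this step is a prefill (from the waiting queue) or a decode (from the running queue), and the ModelRunner dispatches accordingly. The `engine_step`, `schedule`, and `postprocess` names are illustrative assumptions; only `run(seqs, is_prefill)` appears in the diff below.

```python
# Hypothetical engine-step wiring (sketch only; scheduler method names are assumptions):
def engine_step(scheduler, model_runner):
    seqs, is_prefill = scheduler.schedule()          # prefill batch if waiting seqs exist, else decode
    token_ids = model_runner.run(seqs, is_prefill)   # dispatches to chunked-offload paths when needed
    scheduler.postprocess(seqs, token_ids)           # append tokens, finish or requeue sequences
    return token_ids
```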
@@ -34,7 +35,8 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L

**Global Context** (`nanovllm/utils/context.py`):
- Stores attention metadata via `get_context()`/`set_context()`
- Key fields: `cu_seqlens`, `slot_mapping`, `block_tables`, `chunked_seq`
- Key fields: `cu_seqlens`, `slot_mapping`, `block_tables`, `chunked_seq`, `kvcache_manager`
- `kvcache_manager`: Reference to HybridKVCacheManager for chunked attention (set when `is_chunked_prefill=True`)
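As a rough illustration of how this metadata travels, the sketch below wires a `HybridKVCacheManager` and the current sequence through the global context before a chunked forward pass; attention layers read the same fields back via `get_context()`. The wrapper function is hypothetical and assumes `reset_context` lives alongside the other helpers; the `set_context` keywords mirror the ones added in this commit.

```python
from nanovllm.utils.context import get_context, set_context, reset_context

def forward_one_chunk(model, input_ids, positions, kvcache_manager, seq):
    """Hypothetical wrapper showing the context handshake (not actual engine code)."""
    set_context(
        is_chunked_prefill=True,
        kvcache_manager=kvcache_manager,  # HybridKVCacheManager, read by attention layers
        chunked_seq=seq,                  # sequence whose earlier KV chunks live on CPU
    )
    try:
        return model(input_ids, positions)  # attention internally calls get_context()
    finally:
        reset_context()
```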

## CPU Offload System

@@ -118,9 +120,10 @@ Compute: [C0] [C1] [C2]

Manages both GPU and CPU blocks:
- `allocate()`: Allocate GPU block first, fallback to CPU
- `allocate_cpu_only()`: Force CPU allocation (for chunked prefill)
- `allocate_cpu_only()`: Force CPU allocation (for chunked offload mode)
- `get_all_cpu_blocks(seq)`: Get all CPU block IDs for a sequence
- `get_prefilled_cpu_blocks(seq)`: Get CPU blocks from previous chunks
- `get_write_slot_for_chunked_offload(seq)`: Get GPU slot for writing new KV (returns decode_slot)
- `may_offload()`: Offload GPU blocks to CPU when decode slot fills
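A hedged sketch of how these calls could fit together in one decode step follows; the exact signatures beyond what the list above shows, and the helper function itself, are assumptions.

```python
def decode_step_with_offload(manager, seq):
    """Illustrative only: rough order of manager calls during chunked-offload decode."""
    write_slot = manager.get_write_slot_for_chunked_offload(seq)  # always the decode slot
    history_blocks = manager.get_all_cpu_blocks(seq)              # CPU-resident KV of earlier tokens
    # ... attention streams history_blocks through the Compute/Prefetch regions,
    #     while the new token's KV is written into write_slot ...
    manager.may_offload()  # spill the decode slot back to CPU once it fills up
    return write_slot, history_blocks
```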

### Online Softmax Merge
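Chunked attention computes a partial result over one KV chunk at a time and merges the partials with the log-sum-exp trick. Below is a self-contained sketch of that merge; shapes follow the flash-attention convention of returning an output plus a per-query LSE, and this is an illustration, not the project's exact code. Accumulating chunk after chunk this way is what lets the `lse_acc` running state in the attention changes further down stay numerically equivalent to full attention.

```python
import torch

def merge_attn_outputs(o1, lse1, o2, lse2):
    """Merge attention results computed over two disjoint KV chunks.

    o1, o2:     [batch, heads, q_len, head_dim] partial outputs
    lse1, lse2: [batch, heads, q_len] log-sum-exp of the raw scores per chunk
    """
    lse = torch.logaddexp(lse1, lse2)           # normalizer over both chunks combined
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)    # rescale each partial output
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse

def attn(q, k, v):
    """Reference single-chunk attention that also returns the log-sum-exp."""
    s = q @ k.transpose(-1, -2) / k.shape[-1] ** 0.5
    return torch.softmax(s, dim=-1) @ v, torch.logsumexp(s, dim=-1)

# Self-check: attending over [K1|K2] in one shot equals merging the per-chunk results.
q = torch.randn(1, 2, 1, 8)
k1, v1 = torch.randn(1, 2, 5, 8), torch.randn(1, 2, 5, 8)
k2, v2 = torch.randn(1, 2, 3, 8), torch.randn(1, 2, 3, 8)
o_merged, _ = merge_attn_outputs(*attn(q, k1, v1), *attn(q, k2, v2))
o_full, _ = attn(q, torch.cat([k1, k2], dim=2), torch.cat([v1, v2], dim=2))
assert torch.allclose(o_merged, o_full, atol=1e-5)
```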

@@ -169,17 +169,17 @@ class ModelRunner:
)

if config.enable_cpu_offload:
ping_size = config.num_gpu_kvcache_blocks // 2
tokens_per_ping = ping_size * self.block_size
compute_size = config.num_gpu_kvcache_blocks // 2
tokens_per_chunk = compute_size * self.block_size
logger.info(
f"KV Cache allocated (Ping-Pong mode): "
f"KV Cache allocated (Chunked Offload mode): "
f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
f"CPU={config.num_cpu_kvcache_blocks} blocks ({cpu_memory_mb:.1f}MB), "
f"Total={total_memory_mb:.1f}MB"
)
logger.info(
f"Ping-Pong config: ping_size={ping_size} blocks, "
f"tokens_per_chunk={tokens_per_ping}, "
f"Chunked Offload config: compute_size={compute_size} blocks, "
f"tokens_per_chunk={tokens_per_chunk}, "
f"block_size={self.block_size}"
)
else:
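As a concrete (hypothetical) example of the arithmetic above: with `num_gpu_kvcache_blocks=10` and `block_size=256`, `compute_size = 10 // 2 = 5` and `tokens_per_chunk = 5 * 256 = 1280`, so each chunked-offload prefill step processes up to 1280 tokens on the GPU before the chunk is offloaded.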
@@ -374,14 +374,14 @@ class ModelRunner:
return self.model.compute_logits(graph_vars["outputs"][:bs])

def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
# Check if Ping-Pong mode should be used (all blocks on CPU)
# Check if Chunked Offload mode should be used (all blocks on CPU)
if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'get_all_cpu_blocks'):
use_pingpong = self._should_use_pingpong(seqs, is_prefill)
if use_pingpong:
use_chunked_offload = self._should_use_chunked_offload(seqs, is_prefill)
if use_chunked_offload:
if is_prefill:
return self.run_pingpong_prefill(seqs)
return self.run_chunked_offload_prefill(seqs)
else:
return self.run_pingpong_decode(seqs)
return self.run_chunked_offload_decode(seqs)

# Check if chunked prefill is needed (legacy path)
if is_prefill and hasattr(self, 'kvcache_manager'):
@@ -410,7 +410,7 @@ class ModelRunner:
reset_context()
return token_ids

def _should_use_pingpong(self, seqs: list[Sequence], is_prefill: bool) -> bool:
def _should_use_chunked_offload(self, seqs: list[Sequence], is_prefill: bool) -> bool:
"""
Check if three-region mode should be used.

@@ -553,7 +553,7 @@ class ModelRunner:
slot_mapping=slot_mapping,
context_lens=context_len,
is_chunked_prefill=True, # Use chunked attention path
offload_engine=self.kvcache_manager,
kvcache_manager=self.kvcache_manager,
chunked_seq=seq,
)

@@ -622,13 +622,13 @@ class ModelRunner:
max_seqlen_k=seqlen,
slot_mapping=slot_mapping,
is_chunked_prefill=True,
offload_engine=self.kvcache_manager, # Pass manager for loading previous KV
kvcache_manager=self.kvcache_manager, # Pass manager for loading previous KV
chunked_seq=seq, # Pass sequence for loading previous KV
)

return input_ids, positions

def run_pingpong_prefill(self, seqs: list[Sequence]) -> list[int]:
def run_chunked_offload_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill with three-region GPU buffer (CPU is primary storage).

@@ -679,7 +679,7 @@ class ModelRunner:
gpu_slots = offload_engine.compute_slots[:num_blocks]

# Prepare inputs
input_ids, positions = self._prepare_pingpong_chunk(
input_ids, positions = self._prepare_chunked_offload_chunk(
seq, chunk_start, chunk_end, gpu_slots, start_block_idx
)

@@ -718,7 +718,7 @@ class ModelRunner:

return token_ids

def _prepare_pingpong_chunk(
def _prepare_chunked_offload_chunk(
self,
seq: Sequence,
chunk_start: int,
@@ -726,7 +726,7 @@ class ModelRunner:
gpu_slots: list[int],
start_block_idx: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare inputs for a Ping-Pong prefill chunk."""
"""Prepare inputs for a chunked offload prefill chunk."""
# Input tokens for this chunk
input_ids = seq[chunk_start:chunk_end]
positions = list(range(chunk_start, chunk_end))
@@ -768,13 +768,13 @@ class ModelRunner:
max_seqlen_k=seqlen,
slot_mapping=slot_mapping,
is_chunked_prefill=True,
offload_engine=self.kvcache_manager,
kvcache_manager=self.kvcache_manager,
chunked_seq=seq,
)

return input_ids, positions

def run_pingpong_decode(self, seqs: list[Sequence]) -> list[int]:
def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Run decode with three-region GPU buffer.

@@ -809,7 +809,7 @@ class ModelRunner:
slot_mapping=slot_mapping,
context_lens=context_len,
is_chunked_prefill=True, # Use chunked attention path
offload_engine=self.kvcache_manager,
kvcache_manager=self.kvcache_manager,
chunked_seq=seq,
decode_pos_in_block=pos_in_block,
decode_start_pos_in_block=decode_start_pos,

@@ -336,7 +336,7 @@ class HybridKVCacheManager(KVCacheManager):
"""
Allocate logical blocks for prefill.

In cpu_primary mode (Ping-Pong): All blocks are allocated to CPU.
In cpu_primary mode (Chunked Offload): All blocks are allocated to CPU.
In legacy mode: Blocks are allocated to GPU when possible, overflow to CPU.
"""
assert not seq.block_table, "Sequence already has blocks"
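A minimal sketch of the allocation policy this docstring describes follows; the wrapper function and the argument shapes of the two allocation calls are assumptions, not the real method bodies.

```python
def allocate_for_prefill(manager, seq, cpu_primary: bool):
    """Illustrative only: pick the allocation path described in the docstring above."""
    assert not seq.block_table, "Sequence already has blocks"
    if cpu_primary:
        manager.allocate_cpu_only(seq)   # chunked offload: all prefill blocks go to CPU
    else:
        manager.allocate(seq)            # legacy: GPU when possible, overflow to CPU
```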
@@ -1167,9 +1167,9 @@ class HybridKVCacheManager(KVCacheManager):
return block.cpu_block_id
return -1

def get_write_slot_for_pingpong(self, seq: Sequence) -> int:
def get_write_slot_for_chunked_offload(self, seq: Sequence) -> int:
"""
Get GPU slot for writing new KV during three-region decode.
Get GPU slot for writing new KV during chunked offload decode.

In three-region design, always use Decode region (slot 0) to write new KV.
This avoids conflicts with Compute/Prefetch region loading operations.
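For orientation, here is one way the three regions could partition a small GPU slot pool. The real counts come from `num_compute_blocks`/`num_prefetch_blocks` in `OffloadEngine`, so the even split below is only an assumption.

```python
# Hypothetical layout for a 9-slot GPU pool (slot 0 is reserved for decode writes):
num_gpu_slots = 9
decode_slot = 0                              # new KV for the current decode step
rest = list(range(1, num_gpu_slots))
compute_slots = rest[: len(rest) // 2]       # [1, 2, 3, 4]: chunk being attended now
prefetch_slots = rest[len(rest) // 2 :]      # [5, 6, 7, 8]: next chunk loading asynchronously
# After each chunk, swap_compute_prefetch() exchanges the two regions.
```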

@@ -91,12 +91,6 @@ class OffloadEngine:

self.num_gpu_slots = num_gpu_blocks # alias

# Keep old ping/pong attributes for compatibility (will be removed later)
self.ping_size = self.num_compute_blocks
self.pong_size = self.num_prefetch_blocks
self.ping_slots = self.compute_slots.copy()
self.pong_slots = self.prefetch_slots.copy()

logger.info(f"Three-region GPU Buffer: decode_slot={self.decode_slot}, "
f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}")

@@ -148,13 +142,6 @@ class OffloadEngine:
self.prefetch_ready = torch.cuda.Event()
self.decode_offload_done = torch.cuda.Event()

# Keep old ping/pong events for compatibility (will be removed later)
self.pingpong_stream = self.transfer_stream_main
self.ping_ready = self.compute_ready
self.pong_ready = self.prefetch_ready
self.ping_offload_done = torch.cuda.Event()
self.pong_offload_done = torch.cuda.Event()

# ========== Per-layer events for chunked attention ==========
# Each layer has its own event for synchronization
self.compute_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
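A sketch of how such per-layer events are typically used to overlap transfers with compute; the copy itself is elided and the helper functions are hypothetical, but the record/wait pattern matches the streams and events defined above.

```python
import torch

def start_layer_load(engine, layer_id, cpu_block_ids):
    """Hypothetical helper: copy one layer's chunk on the transfer stream, then record."""
    with torch.cuda.stream(engine.transfer_stream_main):
        # ... copy K/V for `layer_id` from cpu_block_ids into the Compute region slots ...
        engine.compute_ready_per_layer[layer_id].record(engine.transfer_stream_main)

def wait_layer_load(engine, layer_id):
    """Hypothetical helper: block the compute stream until that layer's chunk is resident."""
    engine.compute_stream.wait_event(engine.compute_ready_per_layer[layer_id])
```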
@@ -579,185 +566,9 @@ class OffloadEngine:
f")"
)

# ========== Ping-Pong double buffering methods ==========

def load_to_ping(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Ping buffer.

Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.ping_ready.record(self.pingpong_stream)
return

num_to_load = min(len(cpu_block_ids), self.ping_size)
logger.debug(f"Ping load: CPU{cpu_block_ids[:num_to_load]} -> GPU ping slots {self.ping_slots[:num_to_load]}")

with torch.cuda.stream(self.pingpong_stream):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.ping_slots[i]
# Copy all layers together
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.ping_ready.record(self.pingpong_stream)

def load_to_pong(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Pong buffer.

Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.pong_ready.record(self.pingpong_stream)
return

num_to_load = min(len(cpu_block_ids), self.pong_size)
logger.debug(f"Pong load: CPU{cpu_block_ids[:num_to_load]} -> GPU pong slots {self.pong_slots[:num_to_load]}")

with torch.cuda.stream(self.pingpong_stream):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.pong_slots[i]
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.pong_ready.record(self.pingpong_stream)

def wait_ping(self) -> None:
"""Wait for Ping buffer loading to complete."""
self.compute_stream.wait_event(self.ping_ready)

def wait_pong(self) -> None:
"""Wait for Pong buffer loading to complete."""
self.compute_stream.wait_event(self.pong_ready)

def offload_buffer_to_cpu(
self,
buffer: str,
cpu_block_ids: List[int],
) -> None:
"""
Async offload KV from buffer to CPU.

Args:
buffer: "ping" or "pong"
cpu_block_ids: Target CPU block IDs list
"""
slots = self.ping_slots if buffer == "ping" else self.pong_slots
event = self.ping_offload_done if buffer == "ping" else self.pong_offload_done

if not cpu_block_ids:
event.record(self.pingpong_stream)
return

num_to_offload = min(len(cpu_block_ids), len(slots))
logger.debug(f"{buffer.capitalize()} offload: GPU {slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")

with torch.cuda.stream(self.pingpong_stream):
# Wait for compute to complete
self.pingpong_stream.wait_stream(self.compute_stream)

for i in range(num_to_offload):
gpu_slot = slots[i]
cpu_id = cpu_block_ids[i]
self.k_cache_cpu[:, cpu_id].copy_(
self.k_cache_gpu[:, gpu_slot], non_blocking=True
)
self.v_cache_cpu[:, cpu_id].copy_(
self.v_cache_gpu[:, gpu_slot], non_blocking=True
)
event.record(self.pingpong_stream)

def offload_slot_to_cpu(
self,
gpu_slot: int,
cpu_block_id: int,
) -> None:
"""
Async offload a single GPU slot's KV to CPU.

Args:
gpu_slot: GPU slot ID
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Slot offload: GPU[{gpu_slot}] -> CPU[{cpu_block_id}]")

with torch.cuda.stream(self.pingpong_stream):
self.pingpong_stream.wait_stream(self.compute_stream)
self.k_cache_cpu[:, cpu_block_id].copy_(
self.k_cache_gpu[:, gpu_slot], non_blocking=True
)
self.v_cache_cpu[:, cpu_block_id].copy_(
self.v_cache_gpu[:, gpu_slot], non_blocking=True
)

def wait_ping_offload_done(self) -> None:
"""Wait for Ping buffer offload to complete."""
self.compute_stream.wait_event(self.ping_offload_done)

def wait_pong_offload_done(self) -> None:
"""Wait for Pong buffer offload to complete."""
self.compute_stream.wait_event(self.pong_offload_done)

def wait_all_offload_done(self) -> None:
"""Wait for all offload operations to complete."""
self.pingpong_stream.synchronize()

def get_kv_for_ping_slots(
self,
layer_id: int,
num_slots: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified number of slots in Ping buffer.

Args:
layer_id: Layer ID
num_slots: Number of slots needed

Returns:
(k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
"""
slots = self.ping_slots[:num_slots]
k = self.k_cache_gpu[layer_id, slots] # [num_slots, block_size, heads, dim]
v = self.v_cache_gpu[layer_id, slots]
# Reshape: [num_slots, block_size, heads, dim] -> [1, num_slots*block_size, heads, dim]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v

def get_kv_for_pong_slots(
self,
layer_id: int,
num_slots: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified number of slots in Pong buffer.

Args:
layer_id: Layer ID
num_slots: Number of slots needed

Returns:
(k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
"""
slots = self.pong_slots[:num_slots]
k = self.k_cache_gpu[layer_id, slots]
v = self.v_cache_gpu[layer_id, slots]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
self.transfer_stream_main.synchronize()

def get_kv_for_slots(
self,
@@ -918,8 +729,6 @@ class OffloadEngine:
def swap_compute_prefetch(self) -> None:
"""Swap roles of Compute region and Prefetch region."""
self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
# Also update old ping/pong slots for compatibility
self.ping_slots, self.pong_slots = self.pong_slots, self.ping_slots

def offload_decode_slot(self, cpu_block_id: int) -> None:
"""

@@ -123,8 +123,7 @@ class Attention(nn.Module):
lse_acc = None

# Load previous KV from CPU using Compute/Prefetch region
# Note: context.offload_engine is actually HybridKVCacheManager
kvcache_manager = context.offload_engine
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None

if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
@@ -215,7 +214,7 @@ class Attention(nn.Module):
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]

kvcache_manager = context.offload_engine
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq

# Get all CPU blocks for this sequence

@@ -21,7 +21,7 @@ class Context:
# Current chunk's position offset (for causal mask)
chunk_offset: int = 0
# Reference to kvcache manager for loading previous KV (HybridKVCacheManager)
offload_engine: Any = None
kvcache_manager: Any = None
# Current layer's previous K/V chunks (loaded from CPU)
# Set by model_runner before each layer's forward
prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
@@ -33,14 +33,6 @@ class Context:
# Used when batching decode offloads - we need to attend to all accumulated tokens
decode_start_pos_in_block: int = 0

# ========== Per-layer chunked attention state ==========
# Whether chunked decode/prefill is currently active (for hooks to check)
chunked_decode_active: bool = False
# CPU block IDs for the current chunk being processed
chunked_decode_chunk_ids: List[int] = field(default_factory=list)
# Current chunk index being processed
chunked_decode_current_chunk: int = 0


_CONTEXT = Context()

@@ -61,7 +53,7 @@ def set_context(
is_chunked_prefill=False,
prev_kv_ranges=None,
chunk_offset=0,
offload_engine=None,
kvcache_manager=None,
chunked_seq=None,
decode_pos_in_block=0,
decode_start_pos_in_block=0,
@@ -79,7 +71,7 @@ def set_context(
is_chunked_prefill=is_chunked_prefill,
prev_kv_ranges=prev_kv_ranges or [],
chunk_offset=chunk_offset,
offload_engine=offload_engine,
kvcache_manager=kvcache_manager,
chunked_seq=chunked_seq,
decode_pos_in_block=decode_pos_in_block,
decode_start_pos_in_block=decode_start_pos_in_block,

@@ -71,7 +71,7 @@ def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64, num_p
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")

print(f"=" * 60)
print(f"Chunked Prefill Test (Ping-Pong)")
print(f"Chunked Prefill Test (Chunked Offload)")
print(f"=" * 60)
print(f" target_input_len: ~{input_len} tokens")
print(f" num_gpu_blocks: {num_gpu_blocks}")
@@ -111,7 +111,7 @@ def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128, num_p
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")

print(f"=" * 60)
print(f"Chunked Decode Test (Ping-Pong)")
print(f"Chunked Decode Test (Chunked Offload)")
print(f"=" * 60)
print(f" target_input_len: ~{input_len} tokens")
print(f" output_len: {output_len} tokens")