Files
nano-vllm/nanovllm/kvcache/sparse/streaming_llm.py
2025-12-22 08:51:02 +08:00

85 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
StreamingLLM sparse attention policy.
Only keeps sink tokens (beginning) + recent tokens (end).
Intermediate context is discarded. This enables infinite-length
generation but loses intermediate context.
Reference: StreamingLLM paper on attention sinks.
"""
from dataclasses import dataclass
from typing import List, Optional

from .policy import SparsePolicy, PolicyContext
@dataclass
class StreamingLLMConfig:
    """Configuration for StreamingLLMPolicy.

    Both counts are expressed in blocks, not tokens.

    Raises:
        ValueError: if either count is negative.
    """

    # Number of blocks at the beginning to always include (attention sinks).
    num_sink_blocks: int = 1
    # Number of most recent blocks to include (sliding window).
    num_recent_blocks: int = 3

    def __post_init__(self) -> None:
        # A negative window size is meaningless; reject it at construction
        # time instead of letting it silently skew the block-selection math.
        if self.num_sink_blocks < 0:
            raise ValueError(
                f"num_sink_blocks must be >= 0, got {self.num_sink_blocks}"
            )
        if self.num_recent_blocks < 0:
            raise ValueError(
                f"num_recent_blocks must be >= 0, got {self.num_recent_blocks}"
            )
class StreamingLLMPolicy(SparsePolicy):
    """
    StreamingLLM pattern: sink tokens + recent tokens only.

    This is the most aggressive sparsity pattern - only keeps a small
    fixed window of context. Suitable for:
    - Very long streaming generation
    - When intermediate context can be safely discarded
    - Maximizing throughput over accuracy

    Pattern visualization (num_sink_blocks=1, num_recent_blocks=3, 9 blocks):
    ```
    Blocks: [0] [1] [2] [3] [4] [5] [6] [7] [8]
             ↑   ×   ×   ×   ×   ×   ↑   ↑   ↑
            sink    (discarded)    recent window
    ```

    Warning: This loses information from intermediate blocks!
    Use only when this trade-off is acceptable.
    """

    def __init__(self, config: Optional[StreamingLLMConfig] = None):
        # Fall back to a default-constructed StreamingLLMConfig when none is given.
        self.config = config if config is not None else StreamingLLMConfig()

    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """
        Select sink blocks + recent blocks only.

        Intermediate blocks are not loaded (effectively discarded).

        Args:
            available_blocks: Candidate block IDs in order; the leading
                entries are treated as sinks, the trailing ones as the
                recent window.
            ctx: Policy context (not used by this policy).

        Returns:
            A new list containing the first ``num_sink_blocks`` and the last
            ``num_recent_blocks`` entries of ``available_blocks``, in their
            original order. When the two windows cover every block, a copy
            of the whole input is returned. The input is never returned by
            reference, so callers may mutate the result safely.
        """
        # Copy up front so every return path hands back a fresh list.
        blocks = list(available_blocks)
        n = len(blocks)
        # Clamp at 0: a negative count selects nothing, exactly as an empty
        # range() would, and keeps the slice arithmetic below well-defined.
        num_sink = max(0, self.config.num_sink_blocks)
        num_recent = max(0, self.config.num_recent_blocks)

        # If total blocks fit in sink + recent, load all.
        if n <= num_sink + num_recent:
            return blocks

        # Here num_sink < n - num_recent, so the two slices are disjoint and
        # their concatenation preserves the original block order. Using
        # n - num_recent (not -num_recent) keeps num_recent == 0 -> [].
        return blocks[:num_sink] + blocks[n - num_recent:]

    def __repr__(self) -> str:
        return (
            f"StreamingLLMPolicy(sink={self.config.num_sink_blocks}, "
            f"recent={self.config.num_recent_blocks})"
        )