📈 feat: add MemoryObserver for GPU-CPU communication tracking

Implement MemoryObserver to track memory transfers between GPU and CPU:
- H2D (Host to Device): CPU → GPU transfers
- D2H (Device to Host): GPU → CPU transfers
- D2D (Device to Device): GPU buffer copies
- Supports prefill/decode phase separation

Integration points in offload_engine.py:
- load_to_slot_layer: H2D with is_prefill parameter
- offload_slot_layer_to_cpu, offload_prefill_buffer_async: D2H
- write_to_prefill_buffer, write_to_decode_buffer: D2D
- load_block_sample_from_cpu, load_block_full_from_cpu: H2D

Add bench_offload.py integration for memory stats printing.

Benchmark results (Llama-3.1-8B, 64K context):
- Full Policy: Prefill H2D 262.13 GB
- XAttention: Prefill H2D 386.62 GB (1.48x)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-01-28 04:06:45 +08:00
parent c16bfcf40f
commit 39d12a0416
8 changed files with 458 additions and 3 deletions

View File

@@ -31,6 +31,8 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
| [`docs/cpu_offload_optimization_strategies.md`](docs/cpu_offload_optimization_strategies.md) | 🚀 OPT: CPU offload 优化策略chunk size、CUDA Graph、前沿研究(InfiniGen/ShadowKV) | | [`docs/cpu_offload_optimization_strategies.md`](docs/cpu_offload_optimization_strategies.md) | 🚀 OPT: CPU offload 优化策略chunk size、CUDA Graph、前沿研究(InfiniGen/ShadowKV) |
| [`docs/gpu_only_xattn_guide.md`](docs/gpu_only_xattn_guide.md) | 🚀 GPU-Only XAttention: 内存预分配、性能分析 (32K +15%, 64K +41%)、CUDA Graph 限制 | | [`docs/gpu_only_xattn_guide.md`](docs/gpu_only_xattn_guide.md) | 🚀 GPU-Only XAttention: 内存预分配、性能分析 (32K +15%, 64K +41%)、CUDA Graph 限制 |
| [`docs/xattn_performance_analysis.md`](docs/xattn_performance_analysis.md) | 📊 XAttention 性能分析: NVTX 标记、block size 影响、estimate vs compute 耗时对比 | | [`docs/xattn_performance_analysis.md`](docs/xattn_performance_analysis.md) | 📊 XAttention 性能分析: NVTX 标记、block size 影响、estimate vs compute 耗时对比 |
| [`docs/observer_architecture.md`](docs/observer_architecture.md) | 📊 Observer 架构: InferenceObserver (TTFT/TPOT)、MemoryObserver (H2D/D2H/D2D) 设计 |
| [`docs/memory_communication_benchmark.md`](docs/memory_communication_benchmark.md) | 📊 通信量测试: Full vs XAttention 通信量对比 (32K/64K)、阶段分离统计 |
## Rules Index ## Rules Index

View File

@@ -3,6 +3,14 @@ import time
from random import randint, seed from random import randint, seed
from nanovllm import LLM, SamplingParams from nanovllm import LLM, SamplingParams
from nanovllm.utils.observer import InferenceObserver from nanovllm.utils.observer import InferenceObserver
from nanovllm.utils.memory_observer import MemoryObserver
def print_memory_stats():
"""Print MemoryObserver communication statistics"""
fmt = MemoryObserver._fmt_bytes
print(f"[Memory] Prefill H2D: {fmt(MemoryObserver.prefill_h2d_bytes)}, D2H: {fmt(MemoryObserver.prefill_d2h_bytes)}")
print(f" Decode H2D: {fmt(MemoryObserver.decode_h2d_bytes)}, D2H: {fmt(MemoryObserver.decode_d2h_bytes)}")
def bench_decode(llm, num_seqs, input_len, output_len): def bench_decode(llm, num_seqs, input_len, output_len):
@@ -26,6 +34,7 @@ def bench_decode(llm, num_seqs, input_len, output_len):
print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s") print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s")
print(f" TTFT: {ttft_ms:.2f}ms, TPOT: {tpot_ms:.2f}ms") print(f" TTFT: {ttft_ms:.2f}ms, TPOT: {tpot_ms:.2f}ms")
print(f" Decode Throughput: {decode_throughput:.2f} tok/s (from observer)") print(f" Decode Throughput: {decode_throughput:.2f} tok/s (from observer)")
print_memory_stats()
def bench_prefill(llm, num_seqs, input_len): def bench_prefill(llm, num_seqs, input_len):
@@ -51,6 +60,7 @@ def bench_prefill(llm, num_seqs, input_len):
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len})") print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len})")
print(f" External Time: {t:.2f}s, Throughput: {throughput_external:.2f}tok/s") print(f" External Time: {t:.2f}s, Throughput: {throughput_external:.2f}tok/s")
print(f" Observer TTFT: {ttft_ms:.2f}ms, Throughput: {throughput_observer:.2f}tok/s") print(f" Observer TTFT: {ttft_ms:.2f}ms, Throughput: {throughput_observer:.2f}tok/s")
print_memory_stats()
def main(): def main():
@@ -88,6 +98,9 @@ def main():
path = os.path.expanduser(args.model) path = os.path.expanduser(args.model)
max_len = args.max_len max_len = args.max_len
# Enable MemoryObserver for communication stats
MemoryObserver._enabled = True
# Setup policy configuration # Setup policy configuration
if args.enable_quest: if args.enable_quest:
sparse_policy = SparsePolicyType.QUEST sparse_policy = SparsePolicyType.QUEST

View File

@@ -0,0 +1,82 @@
# Memory Communication Benchmark
GPU-CPU 通信量测试结果,对比 Full Policy 和 XAttention BSA Policy。
## 测试环境
- **模型**: Llama-3.1-8B-Instruct
- **GPU**: RTX 3090 (24GB)
- **配置**: `num_gpu_blocks=4`, `block_size=1024`, `enable_cpu_offload=True`
- **XAttention 参数**: `threshold=0.95`, `stride=8`
## 32K 上下文测试结果
| 指标 | Full Policy | XAttention | 比率 |
|------|-------------|------------|------|
| **Prefill H2D** | 66.57 GB | 111.12 GB | **1.67x** |
| Prefill D2H | 4.29 GB | 4.29 GB | 1.00x |
| TTFT | 8473 ms | 10367 ms | 1.22x |
### XAttention Block Selection (32K)
| 指标 | 数值 |
|------|------|
| 可用 blocks | 465 |
| 选中 blocks | 374 |
| 选择密度 | 80.4% |
## 64K 上下文测试结果
| 指标 | Full Policy | XAttention | 比率 |
|------|-------------|------------|------|
| **Prefill H2D** | 262.13 GB | 386.62 GB | **1.48x** |
| Prefill D2H | 8.46 GB | 8.46 GB | 1.00x |
| Decode H2D (32 tokens) | 262.13 GB | 262.13 GB | 1.00x |
| TTFT | 27081 ms | 33634 ms | 1.24x |
## 通信量比率对比
| 上下文长度 | XAttn/Full Prefill H2D 比率 |
|------------|----------------------------|
| 32K | 1.67x |
| 64K | 1.48x |
### 分析
1. **XAttention 通信量增加原因**
- Estimate 阶段:加载 **100%** 历史 blocks用于 attention score 估计)
- Compute 阶段:加载 **选中的** blocks约 70-80%
- 理论比率:`1 + selection_density`
2. **64K 比率更低的原因**
- 更长上下文时attention 分布更稀疏
- XAttention 的 block 选择更有效(选中比例更低)
- First/last block 强制包含的影响相对减小
3. **Decode 阶段通信量相同**
- XAttention 仅支持 prefill 阶段
- Decode 阶段 fallback 到 Full Policy
## 测试命令
```bash
# 32K Full Policy
python bench_offload.py --max-len 32768 --input-len 32000
# 32K XAttention
python bench_offload.py --max-len 32768 --input-len 32000 --enable-xattn
# 64K Full Policy
python bench_offload.py --max-len 65536 --input-len 64000
# 64K XAttention
python bench_offload.py --max-len 65536 --input-len 64000 --enable-xattn
# 包含 decode 测试
python bench_offload.py --max-len 65536 --input-len 64000 --bench-decode --output-len 32
```
## 相关文档
- [`observer_architecture.md`](observer_architecture.md) - Observer 架构设计
- [`xattn_bsa_policy_design.md`](xattn_bsa_policy_design.md) - XAttention BSA 算法设计

View File

@@ -0,0 +1,194 @@
# Observer Architecture
nanovllm 的 Observer 架构用于统计推理过程中的关键指标采用类变量class variable模式实现全局状态管理。
## 架构概览
```
Observer (基类)
├── InferenceObserver - 推理时间指标 (TTFT, TPOT)
└── MemoryObserver - 内存传输统计 (H2D, D2H, D2D)
```
## 设计原则
### 1. 类变量模式
所有 Observer 使用类变量(而非实例变量)存储状态:
```python
class Observer:
"""Observer 基类"""
_enabled: bool = True # 类变量,控制是否启用
class InferenceObserver(Observer):
ttft: int = 0 # 类变量,全局共享
tpot: int = 0
ttft_start: int = 0
tpot_start: int = 0
```
**优点**
- 无需实例化,任何地方都可以直接访问
- 避免跨模块传递 observer 实例
- 适合全局统计场景
### 2. 启用/禁用控制
每个 Observer 可独立启用/禁用:
```python
# 启用 MemoryObserver
MemoryObserver._enabled = True
# 禁用后record_* 方法不会记录
MemoryObserver._enabled = False
```
### 3. 阶段分离
MemoryObserver 支持 prefill/decode 阶段分离统计:
```python
@classmethod
def record_h2d(cls, num_bytes: int, is_prefill: bool = True) -> None:
if not cls._enabled:
return
cls.h2d_bytes += num_bytes
cls.h2d_count += 1
if is_prefill:
cls.prefill_h2d_bytes += num_bytes
else:
cls.decode_h2d_bytes += num_bytes
```
## Observer 实现
### InferenceObserver
**位置**: `nanovllm/utils/observer.py`
**统计指标**
| 指标 | 说明 | 单位 |
|------|------|------|
| `ttft` | Time To First Token | 纳秒 |
| `tpot` | Time Per Output Token | 纳秒 |
| `ttft_start` | TTFT 计时开始点 | 纳秒 |
| `tpot_start` | TPOT 计时开始点 | 纳秒 |
**统计位置**
| 位置 | 代码 | 说明 |
|------|------|------|
| `scheduler.py:add()` | `InferenceObserver.ttft_start = perf_counter_ns()` | 开始计时 |
| `llm_engine.py:step()` | `InferenceObserver.ttft = ... - ttft_start` | Prefill 完成后计算 TTFT |
| `llm_engine.py:step()` | `InferenceObserver.tpot = ... - tpot_start` | Decode 时计算 TPOT |
### MemoryObserver
**位置**: `nanovllm/utils/memory_observer.py`
**统计指标**
| 指标 | 说明 |
|------|------|
| `h2d_bytes` / `h2d_count` | Host to Device 传输量/次数 |
| `d2h_bytes` / `d2h_count` | Device to Host 传输量/次数 |
| `d2d_bytes` / `d2d_count` | Device to Device 复制量/次数 |
| `prefill_h2d_bytes` / `prefill_d2h_bytes` | Prefill 阶段 H2D/D2H |
| `decode_h2d_bytes` / `decode_d2h_bytes` | Decode 阶段 H2D/D2H |
**统计位置** (均在 `offload_engine.py`)
| 方法 | 传输类型 | 说明 |
|------|----------|------|
| `load_to_slot_layer()` | H2D | 从 CPU 加载 block 到 GPU slot |
| `load_block_sample_from_cpu()` | H2D | 采样加载Quest |
| `load_block_full_from_cpu()` | H2D | 完整加载 block |
| `offload_slot_layer_to_cpu()` | D2H | GPU slot 卸载到 CPU |
| `offload_prefill_buffer_async()` | D2H | Prefill buffer 异步卸载 |
| `write_to_prefill_buffer()` | D2D | 写入 prefill buffer |
| `write_to_decode_buffer()` | D2D | 写入 decode buffer |
**重置位置**
| 位置 | 代码 |
|------|------|
| `llm_engine.py:generate()` | `MemoryObserver.complete_reset()` |
| `llm_engine.py:generate()` | `InferenceObserver.complete_reset()` |
## 使用示例
### 1. 启用并统计
```python
from nanovllm.utils.memory_observer import MemoryObserver
# 启用统计
MemoryObserver._enabled = True
# 运行推理
outputs = llm.generate(prompts, sampling_params)
# 获取结果
print(f"Prefill H2D: {MemoryObserver.prefill_h2d_bytes / 1e9:.2f} GB")
print(f"Decode H2D: {MemoryObserver.decode_h2d_bytes / 1e9:.2f} GB")
# 或使用 print_summary
MemoryObserver.print_summary()
```
### 2. 在 bench_offload.py 中
```python
from nanovllm.utils.memory_observer import MemoryObserver
# 启用
MemoryObserver._enabled = True
# benchmark 结束后打印
def print_memory_stats():
fmt = MemoryObserver._fmt_bytes
print(f"[Memory] Prefill H2D: {fmt(MemoryObserver.prefill_h2d_bytes)}")
print(f" Decode H2D: {fmt(MemoryObserver.decode_h2d_bytes)}")
```
### 3. 获取结构化数据
```python
summary = MemoryObserver.get_summary()
# {
# "total": {"h2d_bytes": ..., "d2h_bytes": ..., "d2d_bytes": ...},
# "prefill": {"h2d_bytes": ..., "d2h_bytes": ...},
# "decode": {"h2d_bytes": ..., "d2h_bytes": ...}
# }
```
## 添加新 Observer
1. 继承 `Observer` 基类
2. 定义类变量存储统计数据
3. 实现 `record_*` 方法(需检查 `_enabled`
4. 实现 `complete_reset()` 方法
5. 在相关代码位置添加 `record_*` 调用
6.`llm_engine.py:generate()` 中添加 reset 调用
```python
from nanovllm.utils.observer import Observer
class MyObserver(Observer):
_enabled: bool = False
my_metric: int = 0
@classmethod
def record_event(cls, value: int) -> None:
if not cls._enabled:
return
cls.my_metric += value
@classmethod
def complete_reset(cls) -> None:
cls.my_metric = 0
```
## 相关文档
- [`memory_communication_benchmark.md`](memory_communication_benchmark.md) - 通信量测试结果
- [`architecture_guide.md`](architecture_guide.md) - 整体架构指南

View File

@@ -11,6 +11,7 @@ from nanovllm.engine.sequence import Sequence
from nanovllm.engine.scheduler import Scheduler from nanovllm.engine.scheduler import Scheduler
from nanovllm.engine.model_runner import ModelRunner from nanovllm.engine.model_runner import ModelRunner
from nanovllm.utils.observer import InferenceObserver from nanovllm.utils.observer import InferenceObserver
from nanovllm.utils.memory_observer import MemoryObserver
class LLMEngine: class LLMEngine:
@@ -95,6 +96,7 @@ class LLMEngine:
debug_enabled = log_level.upper() == 'DEBUG' debug_enabled = log_level.upper() == 'DEBUG'
InferenceObserver.complete_reset() InferenceObserver.complete_reset()
MemoryObserver.complete_reset()
if use_tqdm: if use_tqdm:
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True) pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
if not isinstance(sampling_params, list): if not isinstance(sampling_params, list):

View File

@@ -17,6 +17,7 @@ from dataclasses import dataclass
from nanovllm.kvcache.kernels import gathered_copy_kv from nanovllm.kvcache.kernels import gathered_copy_kv
from nanovllm.comm import memcpy_2d_async from nanovllm.comm import memcpy_2d_async
from nanovllm.utils.logger import get_logger from nanovllm.utils.logger import get_logger
from nanovllm.utils.memory_observer import MemoryObserver
# Import for type hints only (avoid circular import) # Import for type hints only (avoid circular import)
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -376,7 +377,8 @@ class OffloadEngine:
self.ring_slot_compute_done[slot_idx].record() self.ring_slot_compute_done[slot_idx].record()
def load_to_slot_layer( def load_to_slot_layer(
self, slot_idx: int, layer_id: int, cpu_block_id: int, chunk_idx: int = -1 self, slot_idx: int, layer_id: int, cpu_block_id: int, chunk_idx: int = -1,
is_prefill: bool = True,
) -> None: ) -> None:
""" """
Async load a single CPU block to a ring buffer slot for one layer. Async load a single CPU block to a ring buffer slot for one layer.
@@ -393,6 +395,7 @@ class OffloadEngine:
layer_id: Layer index to load (for CPU cache indexing) layer_id: Layer index to load (for CPU cache indexing)
cpu_block_id: Source CPU block ID cpu_block_id: Source CPU block ID
chunk_idx: Optional chunk index for NVTX labeling (-1 means not specified) chunk_idx: Optional chunk index for NVTX labeling (-1 means not specified)
is_prefill: True if in prefill phase, False if in decode phase (for MemoryObserver)
""" """
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]") logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
@@ -425,6 +428,9 @@ class OffloadEngine:
self.ring_slot_ready[slot_idx].record(stream) self.ring_slot_ready[slot_idx].record(stream)
nvtx.pop_range() nvtx.pop_range()
# Record H2D transfer: K + V = 2 * block_bytes
MemoryObserver.record_h2d(2 * self.gpu_block_bytes, is_prefill=is_prefill)
def wait_slot_layer(self, slot_idx: int) -> None: def wait_slot_layer(self, slot_idx: int) -> None:
""" """
Wait for a slot's loading to complete. Wait for a slot's loading to complete.
@@ -499,6 +505,9 @@ class OffloadEngine:
self.ring_slot_offload_done[slot_idx].record(self.transfer_stream_main) self.ring_slot_offload_done[slot_idx].record(self.transfer_stream_main)
nvtx.pop_range() nvtx.pop_range()
# Record D2H transfer: K + V = 2 * block_bytes
MemoryObserver.record_d2h(2 * self.gpu_block_bytes, is_prefill=is_prefill)
# ----- KV access methods for ring buffer ----- # ----- KV access methods for ring buffer -----
def get_kv_for_slot(self, slot_idx: int) -> Tuple[Tensor, Tensor]: def get_kv_for_slot(self, slot_idx: int) -> Tuple[Tensor, Tensor]:
@@ -745,6 +754,10 @@ class OffloadEngine:
self.prefill_v_buffer[layer_id, :num_tokens].copy_(v) self.prefill_v_buffer[layer_id, :num_tokens].copy_(v)
torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop()
# Record D2D transfer: K + V
transfer_bytes = 2 * k.numel() * k.element_size()
MemoryObserver.record_d2d(transfer_bytes)
def write_to_decode_buffer( def write_to_decode_buffer(
self, self,
layer_id: int, layer_id: int,
@@ -768,6 +781,10 @@ class OffloadEngine:
self.decode_v_buffer[layer_id, pos_in_block].copy_(v) self.decode_v_buffer[layer_id, pos_in_block].copy_(v)
torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop()
# Record D2D transfer: K + V (single token)
transfer_bytes = 2 * k.numel() * k.element_size()
MemoryObserver.record_d2d(transfer_bytes)
def offload_prefill_buffer_async( def offload_prefill_buffer_async(
self, self,
layer_id: int, layer_id: int,
@@ -813,6 +830,9 @@ class OffloadEngine:
self.prefill_offload_events[layer_id].record(stream) self.prefill_offload_events[layer_id].record(stream)
nvtx.pop_range() nvtx.pop_range()
# Record D2H transfer: K + V = 2 * block_bytes
MemoryObserver.record_d2h(2 * self.gpu_block_bytes, is_prefill=True)
def wait_all_prefill_offloads(self) -> None: def wait_all_prefill_offloads(self) -> None:
"""Wait for all prefill buffer offloads to complete.""" """Wait for all prefill buffer offloads to complete."""
for stream in self.prefill_offload_streams: for stream in self.prefill_offload_streams:
@@ -851,6 +871,11 @@ class OffloadEngine:
v_sample = self.v_cache_cpu[ v_sample = self.v_cache_cpu[
layer_id, cpu_block_id, :num_samples layer_id, cpu_block_id, :num_samples
].clone().cuda() ].clone().cuda()
# Record H2D transfer: K + V samples
transfer_bytes = 2 * k_sample.numel() * k_sample.element_size()
MemoryObserver.record_h2d(transfer_bytes, is_prefill=True)
return k_sample, v_sample return k_sample, v_sample
def load_block_full_from_cpu( def load_block_full_from_cpu(
@@ -877,4 +902,8 @@ class OffloadEngine:
v_full = self.v_cache_cpu[ v_full = self.v_cache_cpu[
layer_id, cpu_block_id layer_id, cpu_block_id
].clone().cuda() ].clone().cuda()
# Record H2D transfer: K + V full block
MemoryObserver.record_h2d(2 * self.gpu_block_bytes, is_prefill=True)
return k_full, v_full return k_full, v_full

View File

@@ -422,7 +422,7 @@ class FullAttentionPolicy(SparsePolicy):
num_preload = min(num_slots, num_blocks) num_preload = min(num_slots, num_blocks)
for i in range(num_preload): for i in range(num_preload):
cpu_block_id = cpu_block_table[i] cpu_block_id = cpu_block_table[i]
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id) offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id, is_prefill=False)
# Phase 2: Process blocks with pipeline # Phase 2: Process blocks with pipeline
for block_idx in range(num_blocks): for block_idx in range(num_blocks):
@@ -456,7 +456,7 @@ class FullAttentionPolicy(SparsePolicy):
next_block_idx = block_idx + num_slots next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks: if next_block_idx < num_blocks:
next_cpu_block_id = cpu_block_table[next_block_idx] next_cpu_block_id = cpu_block_table[next_block_idx]
offload_engine.load_to_slot_layer(current_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id) offload_engine.load_to_slot_layer(current_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id, is_prefill=False)
# Merge with accumulated # Merge with accumulated
with torch.cuda.stream(compute_stream): with torch.cuda.stream(compute_stream):

View File

@@ -0,0 +1,133 @@
"""
MemoryObserver - 内存传输统计 Observer。
统计 GPU-CPU 间的数据传输量:
- H2D (Host to Device): CPU → GPU
- D2H (Device to Host): GPU → CPU
- D2D (Device to Device): GPU → GPU (buffer copy)
"""
from nanovllm.utils.observer import Observer
class MemoryObserver(Observer):
"""
内存传输 Observer统计 GPU-CPU 间的数据传输量。
统计类型:
- H2D (Host to Device): CPU → GPU
- D2H (Device to Host): GPU → CPU
- D2D (Device to Device): GPU → GPU (buffer copy)
统计位置(均在 offload_engine.py
- H2D: load_to_slot_layer(), load_block_sample_from_cpu(), load_block_full_from_cpu()
- D2H: offload_slot_layer_to_cpu(), offload_prefill_buffer_async()
- D2D: write_to_prefill_buffer(), write_to_decode_buffer()
- 重置: llm_engine.py:generate() - 与 InferenceObserver 一起重置
"""
_enabled: bool = False # 默认禁用,需要显式启用
# H2D 统计
h2d_bytes: int = 0
h2d_count: int = 0
# D2H 统计
d2h_bytes: int = 0
d2h_count: int = 0
# D2D 统计
d2d_bytes: int = 0
d2d_count: int = 0
# 按阶段统计
prefill_h2d_bytes: int = 0
prefill_d2h_bytes: int = 0
decode_h2d_bytes: int = 0
decode_d2h_bytes: int = 0
@classmethod
def record_h2d(cls, num_bytes: int, is_prefill: bool = True) -> None:
"""记录 H2D 传输"""
if not cls._enabled:
return
cls.h2d_bytes += num_bytes
cls.h2d_count += 1
if is_prefill:
cls.prefill_h2d_bytes += num_bytes
else:
cls.decode_h2d_bytes += num_bytes
@classmethod
def record_d2h(cls, num_bytes: int, is_prefill: bool = True) -> None:
"""记录 D2H 传输"""
if not cls._enabled:
return
cls.d2h_bytes += num_bytes
cls.d2h_count += 1
if is_prefill:
cls.prefill_d2h_bytes += num_bytes
else:
cls.decode_d2h_bytes += num_bytes
@classmethod
def record_d2d(cls, num_bytes: int) -> None:
"""记录 D2D 传输"""
if not cls._enabled:
return
cls.d2d_bytes += num_bytes
cls.d2d_count += 1
@classmethod
def complete_reset(cls) -> None:
"""重置所有统计"""
cls.h2d_bytes = cls.h2d_count = 0
cls.d2h_bytes = cls.d2h_count = 0
cls.d2d_bytes = cls.d2d_count = 0
cls.prefill_h2d_bytes = cls.prefill_d2h_bytes = 0
cls.decode_h2d_bytes = cls.decode_d2h_bytes = 0
@classmethod
def get_summary(cls) -> dict:
"""返回统计摘要"""
return {
"total": {
"h2d_bytes": cls.h2d_bytes,
"h2d_count": cls.h2d_count,
"d2h_bytes": cls.d2h_bytes,
"d2h_count": cls.d2h_count,
"d2d_bytes": cls.d2d_bytes,
"d2d_count": cls.d2d_count,
},
"prefill": {
"h2d_bytes": cls.prefill_h2d_bytes,
"d2h_bytes": cls.prefill_d2h_bytes,
},
"decode": {
"h2d_bytes": cls.decode_h2d_bytes,
"d2h_bytes": cls.decode_d2h_bytes,
},
}
@classmethod
def _fmt_bytes(cls, b: int) -> str:
"""格式化字节数"""
if b >= 1e9:
return f"{b/1e9:.2f} GB"
if b >= 1e6:
return f"{b/1e6:.2f} MB"
if b >= 1e3:
return f"{b/1e3:.2f} KB"
return f"{b} B"
@classmethod
def print_summary(cls) -> None:
"""打印人类可读的摘要"""
fmt = cls._fmt_bytes
total = cls.h2d_bytes + cls.d2h_bytes + cls.d2d_bytes
print(f"[MemoryObserver] Total: {fmt(total)}")
print(f" H2D: {fmt(cls.h2d_bytes)} ({cls.h2d_count} ops)")
print(f" D2H: {fmt(cls.d2h_bytes)} ({cls.d2h_count} ops)")
print(f" D2D: {fmt(cls.d2d_bytes)} ({cls.d2d_count} ops)")
print(f" Prefill - H2D: {fmt(cls.prefill_h2d_bytes)}, D2H: {fmt(cls.prefill_d2h_bytes)}")
print(f" Decode - H2D: {fmt(cls.decode_h2d_bytes)}, D2H: {fmt(cls.decode_d2h_bytes)}")