Files
nano-vllm/nanovllm/utils/memory_observer.py
Zijie Tian 39d12a0416 📈 feat: add MemoryObserver for GPU-CPU communication tracking
Implement MemoryObserver to track memory transfers between GPU and CPU:
- H2D (Host to Device): CPU → GPU transfers
- D2H (Device to Host): GPU → CPU transfers
- D2D (Device to Device): GPU buffer copies
- Supports prefill/decode phase separation

Integration points in offload_engine.py:
- load_to_slot_layer: H2D with is_prefill parameter
- offload_slot_layer_to_cpu, offload_prefill_buffer_async: D2H
- write_to_prefill_buffer, write_to_decode_buffer: D2D
- load_block_sample_from_cpu, load_block_full_from_cpu: H2D

Add bench_offload.py integration for memory stats printing.

Benchmark results (Llama-3.1-8B, 64K context):
- Full Policy: Prefill H2D 262.13 GB
- XAttention: Prefill H2D 386.62 GB (1.48x)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-28 04:06:45 +08:00

134 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
MemoryObserver - observer that records memory-transfer statistics.

Tracks the volume of data moved between GPU and CPU:
- H2D (Host to Device): CPU → GPU
- D2H (Device to Host): GPU → CPU
- D2D (Device to Device): GPU → GPU (buffer copy)
"""
from nanovllm.utils.observer import Observer
class MemoryObserver(Observer):
    """
    Observer that counts bytes and operations moved between GPU and CPU.

    Transfer categories:
    - H2D (Host to Device): CPU → GPU
    - D2H (Device to Host): GPU → CPU
    - D2D (Device to Device): GPU → GPU (buffer copy)

    Recording sites (all in offload_engine.py):
    - H2D: load_to_slot_layer(), load_block_sample_from_cpu(), load_block_full_from_cpu()
    - D2H: offload_slot_layer_to_cpu(), offload_prefill_buffer_async()
    - D2D: write_to_prefill_buffer(), write_to_decode_buffer()
    - Reset: llm_engine.py:generate(), reset together with InferenceObserver

    All counters are class attributes and every method is a classmethod, so
    the statistics are shared by every user of the class. The ``record_*``
    methods are no-ops unless ``_enabled`` is True.
    """

    # Disabled by default; must be enabled explicitly (presumably via helpers
    # on the Observer base class -- NOTE(review): confirm).
    _enabled: bool = False

    # H2D totals
    h2d_bytes: int = 0
    h2d_count: int = 0
    # D2H totals
    d2h_bytes: int = 0
    d2h_count: int = 0
    # D2D totals (not split by phase)
    d2d_bytes: int = 0
    d2d_count: int = 0
    # Per-phase byte totals (H2D and D2H only)
    prefill_h2d_bytes: int = 0
    prefill_d2h_bytes: int = 0
    decode_h2d_bytes: int = 0
    decode_d2h_bytes: int = 0

    @classmethod
    def record_h2d(cls, num_bytes: int, is_prefill: bool = True) -> None:
        """Record one host-to-device (CPU → GPU) transfer.

        Args:
            num_bytes: Size of the transfer in bytes.
            is_prefill: Attribute the bytes to the prefill phase when True
                (the default), otherwise to the decode phase.
        """
        if not cls._enabled:
            return
        cls.h2d_bytes += num_bytes
        cls.h2d_count += 1
        if is_prefill:
            cls.prefill_h2d_bytes += num_bytes
        else:
            cls.decode_h2d_bytes += num_bytes

    @classmethod
    def record_d2h(cls, num_bytes: int, is_prefill: bool = True) -> None:
        """Record one device-to-host (GPU → CPU) transfer.

        Args:
            num_bytes: Size of the transfer in bytes.
            is_prefill: Attribute the bytes to the prefill phase when True
                (the default), otherwise to the decode phase.
        """
        if not cls._enabled:
            return
        cls.d2h_bytes += num_bytes
        cls.d2h_count += 1
        if is_prefill:
            cls.prefill_d2h_bytes += num_bytes
        else:
            cls.decode_d2h_bytes += num_bytes

    @classmethod
    def record_d2d(cls, num_bytes: int) -> None:
        """Record one device-to-device (GPU → GPU) buffer copy.

        D2D traffic is only tracked in total; it is not split by phase.

        Args:
            num_bytes: Size of the copy in bytes.
        """
        if not cls._enabled:
            return
        cls.d2d_bytes += num_bytes
        cls.d2d_count += 1

    @classmethod
    def complete_reset(cls) -> None:
        """Reset every byte and operation counter to zero.

        The ``_enabled`` flag is left unchanged.
        """
        cls.h2d_bytes = cls.h2d_count = 0
        cls.d2h_bytes = cls.d2h_count = 0
        cls.d2d_bytes = cls.d2d_count = 0
        cls.prefill_h2d_bytes = cls.prefill_d2h_bytes = 0
        cls.decode_h2d_bytes = cls.decode_d2h_bytes = 0

    @classmethod
    def get_summary(cls) -> dict:
        """Return the current statistics as a nested dict.

        Layout::

            {
                "total":   {"h2d_bytes", "h2d_count", "d2h_bytes",
                            "d2h_count", "d2d_bytes", "d2d_count"},
                "prefill": {"h2d_bytes", "d2h_bytes"},
                "decode":  {"h2d_bytes", "d2h_bytes"},
            }
        """
        return {
            "total": {
                "h2d_bytes": cls.h2d_bytes,
                "h2d_count": cls.h2d_count,
                "d2h_bytes": cls.d2h_bytes,
                "d2h_count": cls.d2h_count,
                "d2d_bytes": cls.d2d_bytes,
                "d2d_count": cls.d2d_count,
            },
            "prefill": {
                "h2d_bytes": cls.prefill_h2d_bytes,
                "d2h_bytes": cls.prefill_d2h_bytes,
            },
            "decode": {
                "h2d_bytes": cls.decode_h2d_bytes,
                "d2h_bytes": cls.decode_d2h_bytes,
            },
        }

    @classmethod
    def _fmt_bytes(cls, b: int) -> str:
        """Format a byte count using decimal (SI) units with two decimals.

        Counts of at least 1 KB (1e3) are shown in KB/MB/GB. Smaller counts
        are shown as a plain integer number of bytes.
        """
        units = ((1e9, "GB"), (1e6, "MB"), (1e3, "KB"))
        for i, (scale, unit) in enumerate(units):
            if b >= scale:
                value = b / scale
                # Two-decimal rounding can push a value up to 1000.00
                # (e.g. 999_999_999 B -> "1000.00 MB"). Show such values in
                # the next larger unit instead ("1.00 GB").
                if i > 0 and round(value, 2) >= 1000:
                    scale, unit = units[i - 1]
                    value = b / scale
                return f"{value:.2f} {unit}"
        return f"{b} B"

    @classmethod
    def print_summary(cls) -> None:
        """Print a human-readable summary of all counters to stdout.

        Printing happens regardless of ``_enabled``.
        """
        fmt = cls._fmt_bytes
        total = cls.h2d_bytes + cls.d2h_bytes + cls.d2d_bytes
        print(f"[MemoryObserver] Total: {fmt(total)}")
        print(f"  H2D: {fmt(cls.h2d_bytes)} ({cls.h2d_count} ops)")
        print(f"  D2H: {fmt(cls.d2h_bytes)} ({cls.d2h_count} ops)")
        print(f"  D2D: {fmt(cls.d2d_bytes)} ({cls.d2d_count} ops)")
        print(f"  Prefill - H2D: {fmt(cls.prefill_h2d_bytes)}, D2H: {fmt(cls.prefill_d2h_bytes)}")
        print(f"  Decode  - H2D: {fmt(cls.decode_h2d_bytes)}, D2H: {fmt(cls.decode_d2h_bytes)}")