[WIP] move metadata to GPU.

This commit is contained in:
Zijie Tian
2026-01-06 23:32:32 +08:00
parent edb5273e34
commit 0e691f2d85
3 changed files with 35 additions and 16 deletions

View File

@@ -88,6 +88,7 @@ class SparsePolicy(ABC):
head_dim: int,
num_cpu_blocks: int,
dtype: torch.dtype,
device: torch.device = None,
) -> None:
"""
Initialize policy resources.
@@ -102,6 +103,7 @@ class SparsePolicy(ABC):
head_dim: Dimension per head
num_cpu_blocks: Number of CPU blocks allocated
dtype: Data type for tensors
device: Device for metadata storage (GPU recommended for performance)
"""
pass