[WIP] move metadata to GPU.
This commit is contained in:
@@ -88,6 +88,7 @@ class SparsePolicy(ABC):
|
||||
head_dim: int,
|
||||
num_cpu_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize policy resources.
|
||||
@@ -102,6 +103,7 @@ class SparsePolicy(ABC):
|
||||
head_dim: Dimension per head
|
||||
num_cpu_blocks: Number of CPU blocks allocated
|
||||
dtype: Data type for tensors
|
||||
device: Device for metadata storage (GPU recommended for performance)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user