[refactor] Refactor the kvcache offload.

This commit is contained in:
Zijie Tian
2026-01-04 19:37:03 +08:00
parent 00ed17c640
commit 772313db8f
3 changed files with 224 additions and 57 deletions

View File

@@ -61,8 +61,14 @@ class NanovllmSteppable(SteppableModel):
def make_layer_hook(idx):
def hook(module, input, output):
# Decoder layer returns (hidden_states, residual)
hidden_states = output[0] if isinstance(output, tuple) else output
self._captured[f"layer_{idx}"] = hidden_states.detach().clone()
# hidden_states is MLP output, residual is accumulated residual
# To match torch reference, we need hidden_states + residual
if isinstance(output, tuple) and len(output) >= 2:
hidden_states, residual = output[0], output[1]
full_output = hidden_states + residual
else:
full_output = output
self._captured[f"layer_{idx}"] = full_output.detach().clone()
return hook
self._hooks.append(