[refactor] Refactor the kvcache offload.
This commit is contained in:
@@ -61,8 +61,14 @@ class NanovllmSteppable(SteppableModel):
|
||||
def make_layer_hook(idx):
|
||||
def hook(module, input, output):
|
||||
# Decoder layer returns (hidden_states, residual)
|
||||
hidden_states = output[0] if isinstance(output, tuple) else output
|
||||
self._captured[f"layer_{idx}"] = hidden_states.detach().clone()
|
||||
# hidden_states is MLP output, residual is accumulated residual
|
||||
# To match torch reference, we need hidden_states + residual
|
||||
if isinstance(output, tuple) and len(output) >= 2:
|
||||
hidden_states, residual = output[0], output[1]
|
||||
full_output = hidden_states + residual
|
||||
else:
|
||||
full_output = output
|
||||
self._captured[f"layer_{idx}"] = full_output.detach().clone()
|
||||
return hook
|
||||
|
||||
self._hooks.append(
|
||||
|
||||
Reference in New Issue
Block a user