[feat] Added num_gpu_blocks to limit the number of GPU blocks.
This commit is contained in:
@@ -245,5 +245,5 @@ class Attention(nn.Module):
|
||||
if o_acc is None:
|
||||
raise RuntimeError("Chunked decode attention failed: no KV available")
|
||||
|
||||
# Output shape: [batch, 1, heads, dim] -> [batch, heads, dim]
|
||||
return o_acc.squeeze(1)
|
||||
# Output shape: [batch, 1, heads, dim] (same as normal decode)
|
||||
return o_acc
|
||||
|
||||
Reference in New Issue
Block a user