[WIP] Fix attention compute error.
@@ -93,9 +93,9 @@ TEST_CASES = [
    (1, 4, 256, 8, 128),
    (1, 4, 512, 8, 128),
    (1, 8, 512, 8, 128),
    (1, 4, 1024, 8, 128),
    (1, 4, 1024, 32, 128),  # More heads
    (1, 8, 256, 8, 64),  # Smaller head dim
    (1, 32, 1024, 8, 128),
    (1, 32, 1024, 32, 128),  # More heads
    (1, 32, 256, 8, 64),  # Smaller head dim
]

DTYPES = [torch.float16, torch.bfloat16]
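
Each tuple appears to encode one attention shape. Judging by the inline comments, a plausible reading is (batch, query_len, kv_len, num_heads, head_dim): the fourth field grows in the "More heads" cases and the fifth shrinks in the "Smaller head dim" cases. Below is a minimal sketch of how TEST_CASES and DTYPES might drive a correctness check against PyTorch's reference attention; the tuple layout and the kernel entry point attention_compute are assumptions for illustration, not part of this commit.

import pytest
import torch
import torch.nn.functional as F

@pytest.mark.parametrize("case", TEST_CASES)   # TEST_CASES, DTYPES as defined above
@pytest.mark.parametrize("dtype", DTYPES)
def test_attention_matches_reference(case, dtype):
    # Assumed tuple layout; adjust if the real fields differ.
    batch, q_len, kv_len, num_heads, head_dim = case
    torch.manual_seed(0)
    q = torch.randn(batch, num_heads, q_len, head_dim, dtype=dtype)
    k = torch.randn(batch, num_heads, kv_len, head_dim, dtype=dtype)
    v = torch.randn(batch, num_heads, kv_len, head_dim, dtype=dtype)

    # Reference path in float32 to bound accumulated rounding error.
    ref = F.scaled_dot_product_attention(q.float(), k.float(), v.float())

    # Hypothetical kernel under test; substitute the repo's actual entry point.
    out = attention_compute(q, k, v)

    # Half-precision kernels typically need looser tolerances; bf16 loosest.
    atol = 2e-2 if dtype is torch.bfloat16 else 1e-2
    torch.testing.assert_close(out.float(), ref, atol=atol, rtol=1e-2)

Running the full grid (9 shapes x 2 dtypes) would isolate whether the compute error tracks sequence length, head count, head dimension, or dtype.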