mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
fix llama2-70b rewrite tensor dim
This commit is contained in:
@@ -1369,7 +1369,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
if "llama2_13b" in self.model_name:
|
||||
pkv_tensor_shape = "tensor<1x40x?x128x"
|
||||
elif "llama2_70b" in self.model_name:
|
||||
- pkv_tensor_shape = "tensor<1x60x?x128x"
|
||||
+ pkv_tensor_shape = "tensor<1x8x?x128x"
|
||||
else:
|
||||
pkv_tensor_shape = "tensor<1x32x?x128x"
|
||||
if self.precision in ["fp16", "int4", "int8"]:
|
||||
|
||||
Reference in New Issue
Block a user