fix llama2-70b rewrite tensor dim

This commit is contained in:
Phaneesh Barwaria
2023-09-01 14:02:59 +05:30
parent 4c3d8a0a7f
commit 1ccafa1fc1

View File

@@ -1369,7 +1369,7 @@ class UnshardedVicuna(VicunaBase):
if "llama2_13b" in self.model_name:
pkv_tensor_shape = "tensor<1x40x?x128x"
elif "llama2_70b" in self.model_name:
pkv_tensor_shape = "tensor<1x60x?x128x"
pkv_tensor_shape = "tensor<1x8x?x128x"
else:
pkv_tensor_shape = "tensor<1x32x?x128x"
if self.precision in ["fp16", "int4", "int8"]: