(LLaMa-2) Point to int4 + f32 acc .mlir for cpu (#1878)

- fixes some issues with non-system prompt invocation

Co-authored-by: Gaurav Shukla <gauravshukla789@gmail.com>
This commit is contained in:
Ean Garvey
2023-10-09 14:37:35 -05:00
committed by GitHub
parent 9f0a421764
commit 3b825579a7

View File

@@ -432,7 +432,7 @@ class VicunaBase(SharkLLMBase):
is_first=is_first,
)
else:
token = token.to(torch.int64).reshape([1, 1])
token = torch.tensor(token).reshape([1, 1])
second_input = (token,) + tuple(past_key_values)
output = self.shark_model(
"second_vicuna_forward", second_input, send_to_host=False
@@ -444,7 +444,7 @@ class VicunaBase(SharkLLMBase):
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
else:
_past_key_values = output[1:]
_token = torch.tensor(output[0].to_host())
_token = int(output[0].to_host())
_detok = self.tokenizer.decode(_token, skip_special_tokens=False)
ret_dict = {
@@ -1451,6 +1451,8 @@ class UnshardedVicuna(VicunaBase):
mlir_generated = False
for suffix in ["mlirbc", "mlir"]:
self.vicuna_mlir_path = self.get_model_path(suffix)
if "cpu" in self.device and "llama2_7b" in self.vicuna_mlir_path.name:
self.vicuna_mlir_path = Path("llama2_7b_int4_f32.mlir")
if not self.vicuna_mlir_path.exists() and self.load_mlir_from_shark_tank:
print(
f"Looking into gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}"
@@ -1687,7 +1689,7 @@ class UnshardedVicuna(VicunaBase):
combined_module = save_mlir(
combined_module,
model_name="combined_llama",
mlir_dialect="tm_tensor"
mlir_dialect="tm_tensor",
dir=self.vicuna_mlir_path,
)
del first_module, second_module