Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git, synced 2026-04-03 03:00:17 -04:00
(LLaMa-2) Point to int4 + f32 acc .mlir for cpu (#1878)
- fixes some issues with non-system prompt invocation

Co-authored-by: Gaurav Shukla <gauravshukla789@gmail.com>
@@ -432,7 +432,7 @@ class VicunaBase(SharkLLMBase):
                     is_first=is_first,
                 )
             else:
-                token = token.to(torch.int64).reshape([1, 1])
+                token = torch.tensor(token).reshape([1, 1])
                 second_input = (token,) + tuple(past_key_values)
                 output = self.shark_model(
                     "second_vicuna_forward", second_input, send_to_host=False
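The first hunk is the non-system-prompt fix from the commit message: when decoding continues without a system prompt, token can reach this branch as a plain Python int rather than a tensor, and a plain int has no .to() method. A minimal sketch of the failure mode, assuming an int-valued token (the variable name is from the diff; the value is made up):

    import torch

    token = 42  # hypothetical next-token id, a plain Python int

    # Old line: raises AttributeError, since Python ints have no .to() method.
    # token = token.to(torch.int64).reshape([1, 1])

    # New line: torch.tensor() accepts ints (and tensors) and infers int64.
    token = torch.tensor(token).reshape([1, 1])
    print(token.shape, token.dtype)  # torch.Size([1, 1]) torch.int64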
@@ -444,7 +444,7 @@ class VicunaBase(SharkLLMBase):
                 _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
             else:
                 _past_key_values = output[1:]
-                _token = torch.tensor(output[0].to_host())
+                _token = int(output[0].to_host())

             _detok = self.tokenizer.decode(_token, skip_special_tokens=False)
             ret_dict = {
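The second hunk makes the else branch produce the same type as the if branch: a plain Python int instead of a tensor, so tokenizer.decode() always receives a scalar token id. A sketch with a hypothetical stand-in for the device buffer (FakeBuffer is invented for illustration; to_host() mirrors the call in the diff):

    class FakeBuffer:
        """Hypothetical stand-in for the device buffer in output[0]."""

        def to_host(self):
            # Copy device memory back to the host; here just a token id.
            return 1723

    output = [FakeBuffer()]

    # The old line wrapped the result in torch.tensor(); the new line yields
    # a plain int, matching _token = int(torch.argmax(...)) above.
    _token = int(output[0].to_host())
    print(type(_token))  # <class 'int'>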
@@ -1451,6 +1451,8 @@ class UnshardedVicuna(VicunaBase):
         mlir_generated = False
         for suffix in ["mlirbc", "mlir"]:
             self.vicuna_mlir_path = self.get_model_path(suffix)
+            if "cpu" in self.device and "llama2_7b" in self.vicuna_mlir_path.name:
+                self.vicuna_mlir_path = Path("llama2_7b_int4_f32.mlir")
             if not self.vicuna_mlir_path.exists() and self.load_mlir_from_shark_tank:
                 print(
                     f"Looking into gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}"
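The third hunk is the change named in the title: when running on CPU with a llama2_7b artifact, the loader is redirected to llama2_7b_int4_f32.mlir, i.e. int4 weights with f32 accumulation. The override in isolation (the helper name and the example inputs are hypothetical; the condition and target path are from the diff):

    from pathlib import Path

    def pick_mlir_path(device: str, default_path: Path) -> Path:
        # On CPU, prefer the int4-weight / f32-accumulate artifact.
        if "cpu" in device and "llama2_7b" in default_path.name:
            return Path("llama2_7b_int4_f32.mlir")
        return default_path

    print(pick_mlir_path("cpu-task", Path("llama2_7b.mlir")))  # llama2_7b_int4_f32.mlir
    print(pick_mlir_path("vulkan", Path("llama2_7b.mlir")))    # llama2_7b.mlir

Note that the redirect happens before the exists() check, so if the int4 file is absent locally, the shark_tank lookup below now searches for the new filename.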
@@ -1687,7 +1689,7 @@ class UnshardedVicuna(VicunaBase):
         combined_module = save_mlir(
             combined_module,
             model_name="combined_llama",
-            mlir_dialect="tm_tensor"
+            mlir_dialect="tm_tensor",
             dir=self.vicuna_mlir_path,
         )
         del first_module, second_module
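The last hunk restores the comma between keyword arguments; without it, mlir_dialect="tm_tensor" followed by dir=... is a SyntaxError, so this call site could not be parsed at all. A quick check (the expression is a cut-down stand-in for the save_mlir call):

    # The old form fails to parse; the exact message varies by Python version
    # (3.10+ even suggests the missing comma).
    try:
        compile('save_mlir(m, mlir_dialect="tm_tensor" dir=p)', "<hunk>", "eval")
    except SyntaxError as err:
        print(err.msg)  # e.g. "invalid syntax. Perhaps you forgot a comma?"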