From 7abddd01ec73f288096ea3bbe3576f8911bcfef9 Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Thu, 5 Oct 2023 22:15:21 -0500 Subject: [PATCH] argmax inside model + brevitas pin (#1872) --- apps/language_models/scripts/vicuna.py | 18 +++++++----------- .../src/model_wrappers/vicuna_model.py | 7 +++++-- requirements.txt | 2 +- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index f3cf596b..9cc0b509 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -442,15 +442,13 @@ class VicunaBase(SharkLLMBase): _past_key_values = output["past_key_values"] _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0]) else: - _logits = torch.tensor(output[0].to_host()) _past_key_values = output[1:] - _token = torch.argmax(_logits[:, -1, :], dim=1) + _token = torch.tensor(output[0].to_host()) _detok = self.tokenizer.decode(_token, skip_special_tokens=False) ret_dict = { "token": _token, "detok": _detok, - "logits": _logits, "past_key_values": _past_key_values, } @@ -1482,11 +1480,6 @@ class UnshardedVicuna(VicunaBase): if not mlir_generated: print(f"[DEBUG] mlir not found") - # Disabling this path of IR generation for now as it is broken. - print("Please check if the mlir file is present at the shark tank. Exiting.") - self.shark_model = None - sys.exit() - return print("[DEBUG] generating mlir on device") # Select a compilation prompt such that the resulting input_ids @@ -1546,6 +1539,9 @@ class UnshardedVicuna(VicunaBase): use_tracing=False, verbose=False, ) + if self.cache_vicunas: + with open(first_model_path[:-5]+"_torch.mlir", "w+") as f: + f.write(str(first_module)) print(f"[DEBUG] converting torch to linalg") run_pipeline_with_repro_report( first_module, @@ -1660,6 +1656,9 @@ class UnshardedVicuna(VicunaBase): verbose=False, ) print(f"[DEBUG] converting torch to linalg") + if self.cache_vicunas: + with open(second_model_path[:-5]+"_torch.mlir", "w+") as f: + f.write(str(second_module)) run_pipeline_with_repro_report( second_module, "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)", @@ -1743,7 +1742,6 @@ class UnshardedVicuna(VicunaBase): prefill_time = time.time() - prefill_st_time token = generated_token_op["token"] - logits = generated_token_op["logits"] pkv = generated_token_op["past_key_values"] detok = generated_token_op["detok"] yield detok, None, prefill_time @@ -1756,7 +1754,6 @@ class UnshardedVicuna(VicunaBase): params = { "token": token, "is_first": False, - "logits": logits, "past_key_values": pkv, "sv": self.shark_model, } @@ -1768,7 +1765,6 @@ class UnshardedVicuna(VicunaBase): decode_time_ms = (time.time() - decode_st_time)*1000 token = generated_token_op["token"] - logits = generated_token_op["logits"] pkv = generated_token_op["past_key_values"] detok = generated_token_op["detok"] diff --git a/apps/language_models/src/model_wrappers/vicuna_model.py b/apps/language_models/src/model_wrappers/vicuna_model.py index e857ef96..fa06e468 100644 --- a/apps/language_models/src/model_wrappers/vicuna_model.py +++ b/apps/language_models/src/model_wrappers/vicuna_model.py @@ -43,7 +43,9 @@ class FirstVicuna(torch.nn.Module): def forward(self, input_ids): op = self.model(input_ids=input_ids, use_cache=True) return_vals = [] - return_vals.append(op.logits) + token = torch.argmax(op.logits[:, -1, :], dim=1) + return_vals.append(token) + temp_past_key_values = op.past_key_values for item in temp_past_key_values: return_vals.append(item[0]) @@ -289,7 +291,8 @@ class SecondVicuna7B(torch.nn.Module): input_ids=token, use_cache=True, past_key_values=past_key_values ) return_vals = [] - return_vals.append(op.logits) + token = torch.argmax(op.logits[:, -1, :], dim=1) + return_vals.append(token) temp_past_key_values = op.past_key_values for item in temp_past_key_values: return_vals.append(item[0]) diff --git a/requirements.txt b/requirements.txt index 3abc8ba6..e294d8fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -47,4 +47,4 @@ pefile pyinstaller # vicuna quantization -brevitas @ git+https://github.com/Xilinx/brevitas.git@dev +brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea