argmax inside model + brevitas pin (#1872)

Daniel Garvey
2023-10-05 22:15:21 -05:00
committed by GitHub
parent 2a451fa0c7
commit 7abddd01ec
3 changed files with 13 additions and 14 deletions
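In short: the final-token argmax moves from host Python into the models' forward functions, so the compiled modules return the selected token (followed by past_key_values) instead of raw logits. The host decode path and the generation loop drop their logits handling to match, the torch-level MLIR for each submodel can now be cached to disk when cache_vicunas is set, the previously disabled on-device MLIR generation path is re-enabled, and brevitas is pinned to a fixed commit instead of the moving dev branch.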


@@ -442,15 +442,13 @@ class VicunaBase(SharkLLMBase):
             _past_key_values = output["past_key_values"]
             _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
         else:
-            _logits = torch.tensor(output[0].to_host())
             _past_key_values = output[1:]
-            _token = torch.argmax(_logits[:, -1, :], dim=1)
+            _token = torch.tensor(output[0].to_host())
         _detok = self.tokenizer.decode(_token, skip_special_tokens=False)
         ret_dict = {
             "token": _token,
             "detok": _detok,
-            "logits": _logits,
             "past_key_values": _past_key_values,
         }
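For orientation, a minimal sketch of the decode step after this change, assuming a compiled module whose outputs are device buffers exposing to_host(); the helper name decode_step is illustrative, not the repo's API:

    import torch

    def decode_step(output, tokenizer):
        # The compiled module's first output is now the argmax-selected token,
        # so no logits tensor has to cross the device boundary.
        _token = torch.tensor(output[0].to_host())
        _past_key_values = output[1:]
        _detok = tokenizer.decode(_token, skip_special_tokens=False)
        return {"token": _token, "detok": _detok, "past_key_values": _past_key_values}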
@@ -1482,11 +1480,6 @@ class UnshardedVicuna(VicunaBase):
         if not mlir_generated:
             print(f"[DEBUG] mlir not found")
-            # Disabling this path of IR generation for now as it is broken.
-            print("Please check if the mlir file is present at the shark tank. Exiting.")
-            self.shark_model = None
-            sys.exit()
-            return
             print("[DEBUG] generating mlir on device")
             # Select a compilation prompt such that the resulting input_ids
@@ -1546,6 +1539,9 @@ class UnshardedVicuna(VicunaBase):
             use_tracing=False,
             verbose=False,
         )
+        if self.cache_vicunas:
+            with open(first_model_path[:-5]+"_torch.mlir", "w+") as f:
+                f.write(str(first_module))
         print(f"[DEBUG] converting torch to linalg")
         run_pipeline_with_repro_report(
             first_module,
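Since first_model_path ends in ".mlir", the slice first_model_path[:-5] strips that suffix, so the torch-dialect module is written next to the final artifact as <name>_torch.mlir. A sketch of the matching read side, assuming later runs reuse the cache; load_cached_torch_mlir is a hypothetical helper, not code from this repo:

    import os

    def load_cached_torch_mlir(model_path: str):
        # "foo.mlir" -> "foo_torch.mlir", mirroring the write in the hunk above
        cache_path = model_path[:-5] + "_torch.mlir"
        if os.path.exists(cache_path):
            with open(cache_path) as f:
                return f.read()  # torch-dialect MLIR, as text
        return None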
@@ -1660,6 +1656,9 @@ class UnshardedVicuna(VicunaBase):
             verbose=False,
         )
         print(f"[DEBUG] converting torch to linalg")
+        if self.cache_vicunas:
+            with open(second_model_path[:-5]+"_torch.mlir", "w+") as f:
+                f.write(str(second_module))
         run_pipeline_with_repro_report(
             second_module,
             "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
@@ -1743,7 +1742,6 @@ class UnshardedVicuna(VicunaBase):
         prefill_time = time.time() - prefill_st_time
         token = generated_token_op["token"]
-        logits = generated_token_op["logits"]
         pkv = generated_token_op["past_key_values"]
         detok = generated_token_op["detok"]
         yield detok, None, prefill_time
@@ -1756,7 +1754,6 @@ class UnshardedVicuna(VicunaBase):
         params = {
             "token": token,
             "is_first": False,
-            "logits": logits,
             "past_key_values": pkv,
             "sv": self.shark_model,
         }
@@ -1768,7 +1765,6 @@ class UnshardedVicuna(VicunaBase):
         decode_time_ms = (time.time() - decode_st_time)*1000
         token = generated_token_op["token"]
-        logits = generated_token_op["logits"]
         pkv = generated_token_op["past_key_values"]
         detok = generated_token_op["detok"]
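With logits gone from the returned dict, the generation loop threads only token and past_key_values between steps. A condensed, illustrative version of that loop; key names follow the diff, but generate_new_token and the timing details are a sketch, not the repo's exact code:

    import time

    def generate(self, prompt_ids):
        st = time.time()
        out = self.generate_new_token(
            {"token": prompt_ids, "is_first": True, "sv": self.shark_model}
        )
        token, pkv = out["token"], out["past_key_values"]
        yield out["detok"], None, time.time() - st  # prefill latency, seconds

        while token != self.tokenizer.eos_token_id:
            st = time.time()
            out = self.generate_new_token(
                {
                    "token": token,
                    "is_first": False,
                    "past_key_values": pkv,
                    "sv": self.shark_model,
                }
            )
            token, pkv = out["token"], out["past_key_values"]
            yield out["detok"], None, (time.time() - st) * 1000  # decode ms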


@@ -43,7 +43,9 @@ class FirstVicuna(torch.nn.Module):
     def forward(self, input_ids):
         op = self.model(input_ids=input_ids, use_cache=True)
         return_vals = []
-        return_vals.append(op.logits)
+        token = torch.argmax(op.logits[:, -1, :], dim=1)
+        return_vals.append(token)
         temp_past_key_values = op.past_key_values
         for item in temp_past_key_values:
             return_vals.append(item[0])
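The same pattern in isolation: a thin wrapper whose forward performs greedy selection inside the traced graph, so the compiled module emits token IDs plus the KV cache rather than a batch-by-vocabulary logits tensor. A minimal sketch assuming a Hugging Face-style causal LM; GreedyHead is not a class from this repo:

    import torch

    class GreedyHead(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids):
            op = self.model(input_ids=input_ids, use_cache=True)
            # Greedy argmax over the last position, computed on-device.
            token = torch.argmax(op.logits[:, -1, :], dim=1)
            return_vals = [token]
            for key, value in op.past_key_values:
                return_vals.extend([key, value])
            return tuple(return_vals)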
@@ -289,7 +291,8 @@ class SecondVicuna7B(torch.nn.Module):
             input_ids=token, use_cache=True, past_key_values=past_key_values
         )
         return_vals = []
-        return_vals.append(op.logits)
+        token = torch.argmax(op.logits[:, -1, :], dim=1)
+        return_vals.append(token)
         temp_past_key_values = op.past_key_values
         for item in temp_past_key_values:
             return_vals.append(item[0])
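SecondVicuna7B gets the identical change on the decode side, where past_key_values is also an input: each step consumes the previous token plus cache and emits only the next token plus updated cache, instead of shipping a vocabulary-sized logits tensor back to the host on every iteration.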


@@ -47,4 +47,4 @@ pefile
 pyinstaller
 # vicuna quantization
-brevitas @ git+https://github.com/Xilinx/brevitas.git@dev
+brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
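Replacing the dev branch reference with a full commit SHA makes installs reproducible: pip resolves the same brevitas revision every time, equivalent to

    pip install "brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea"

whereas @dev tracks whatever the branch head happens to be at install time.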