mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-09 13:57:54 -05:00
argmax inside model + brevitas pin (#1872)
This commit is contained in:
@@ -442,15 +442,13 @@ class VicunaBase(SharkLLMBase):
|
||||
_past_key_values = output["past_key_values"]
|
||||
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
|
||||
else:
|
||||
_logits = torch.tensor(output[0].to_host())
|
||||
_past_key_values = output[1:]
|
||||
_token = torch.argmax(_logits[:, -1, :], dim=1)
|
||||
_token = torch.tensor(output[0].to_host())
|
||||
|
||||
_detok = self.tokenizer.decode(_token, skip_special_tokens=False)
|
||||
ret_dict = {
|
||||
"token": _token,
|
||||
"detok": _detok,
|
||||
"logits": _logits,
|
||||
"past_key_values": _past_key_values,
|
||||
}
|
||||
|
||||
@@ -1482,11 +1480,6 @@ class UnshardedVicuna(VicunaBase):
|
||||
|
||||
if not mlir_generated:
|
||||
print(f"[DEBUG] mlir not found")
|
||||
# Disabling this path of IR generation for now as it is broken.
|
||||
print("Please check if the mlir file is present at the shark tank. Exiting.")
|
||||
self.shark_model = None
|
||||
sys.exit()
|
||||
return
|
||||
|
||||
print("[DEBUG] generating mlir on device")
|
||||
# Select a compilation prompt such that the resulting input_ids
|
||||
@@ -1546,6 +1539,9 @@ class UnshardedVicuna(VicunaBase):
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
if self.cache_vicunas:
|
||||
with open(first_model_path[:-5]+"_torch.mlir", "w+") as f:
|
||||
f.write(str(first_module))
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
first_module,
|
||||
@@ -1660,6 +1656,9 @@ class UnshardedVicuna(VicunaBase):
|
||||
verbose=False,
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
if self.cache_vicunas:
|
||||
with open(second_model_path[:-5]+"_torch.mlir", "w+") as f:
|
||||
f.write(str(second_module))
|
||||
run_pipeline_with_repro_report(
|
||||
second_module,
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
@@ -1743,7 +1742,6 @@ class UnshardedVicuna(VicunaBase):
|
||||
prefill_time = time.time() - prefill_st_time
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["past_key_values"]
|
||||
detok = generated_token_op["detok"]
|
||||
yield detok, None, prefill_time
|
||||
@@ -1756,7 +1754,6 @@ class UnshardedVicuna(VicunaBase):
|
||||
params = {
|
||||
"token": token,
|
||||
"is_first": False,
|
||||
"logits": logits,
|
||||
"past_key_values": pkv,
|
||||
"sv": self.shark_model,
|
||||
}
|
||||
@@ -1768,7 +1765,6 @@ class UnshardedVicuna(VicunaBase):
|
||||
decode_time_ms = (time.time() - decode_st_time)*1000
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["past_key_values"]
|
||||
detok = generated_token_op["detok"]
|
||||
|
||||
|
||||
@@ -43,7 +43,9 @@ class FirstVicuna(torch.nn.Module):
|
||||
def forward(self, input_ids):
|
||||
op = self.model(input_ids=input_ids, use_cache=True)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
token = torch.argmax(op.logits[:, -1, :], dim=1)
|
||||
return_vals.append(token)
|
||||
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
@@ -289,7 +291,8 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
input_ids=token, use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
token = torch.argmax(op.logits[:, -1, :], dim=1)
|
||||
return_vals.append(token)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
|
||||
@@ -47,4 +47,4 @@ pefile
|
||||
pyinstaller
|
||||
|
||||
# vicuna quantization
|
||||
brevitas @ git+https://github.com/Xilinx/brevitas.git@dev
|
||||
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
|
||||
|
||||
Reference in New Issue
Block a user