[vicuna.py] Keep past key values on device (#1836)

The past key values are only consumed by the models themselves, so they can
be kept on device instead of being copied back to the host after each forward
call. For Vulkan int4 on a 7900 XTX, this gives 44 tok/s (for the first
prompt) and settles at around 26 tok/s.
This commit is contained in:
Quinn Dawkins
2023-09-19 18:17:41 -04:00
committed by GitHub
parent 79267931c1
commit ded74d09cd

View File

@@ -382,8 +382,7 @@ class VicunaBase(SharkLLMBase):
if sharded:
output = self.shark_model.forward(input_ids, is_first=is_first)
else:
output = self.shark_model("first_vicuna_forward", (input_ids,))
out_tensor = torch.tensor(output[1:])
output = self.shark_model("first_vicuna_forward", (input_ids,), send_to_host=False)
else:
token = params["token"]
@@ -402,7 +401,7 @@ class VicunaBase(SharkLLMBase):
token = token.to(torch.int64).reshape([1, 1])
second_input = (token,) + tuple(past_key_values)
output = self.shark_model(
"second_vicuna_forward", second_input
"second_vicuna_forward", second_input, send_to_host=False
)
if sharded:
@@ -410,8 +409,8 @@ class VicunaBase(SharkLLMBase):
_past_key_values = output["past_key_values"]
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
else:
_logits = torch.tensor(output[0])
_past_key_values = torch.tensor(output[1:])
_logits = torch.tensor(output[0].to_host())
_past_key_values = output[1:]
_token = torch.argmax(_logits[:, -1, :], dim=1)
_detok = self.tokenizer.decode(_token, skip_special_tokens=False)