[vicuna.py] Keep past key values on device (#1836)
The past key values are only used within the models themselves and can be kept on device. For Vulkan int4, this gives 44 tok/s for the first prompt and settles at around 26 tok/s on a 7900 XTX.
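As a rough illustration of the pattern (a minimal sketch, not the repository's exact API): the first forward pass returns device-resident buffers, the decode loop copies only the logits to the host to pick the next token, and the key/value cache is fed straight back into the next call. `model` stands in for the compiled module wrapper seen in the diff below; `send_to_host=False` and `.to_host()` come from the diff, while `greedy_decode` and `max_new_tokens` are hypothetical names.

    import torch

    def greedy_decode(model, tokenizer, input_ids, max_new_tokens=32):
        # First forward pass: keep every output as a device buffer (no host copy).
        output = model("first_vicuna_forward", (input_ids,), send_to_host=False)
        tokens = []
        for _ in range(max_new_tokens):
            # Only the logits cross to the host, just enough to run argmax in torch.
            logits = torch.tensor(output[0].to_host())
            # The key/value cache (output[1:]) never leaves the device.
            past_key_values = output[1:]
            token = torch.argmax(logits[:, -1, :], dim=1)
            tokens.append(int(token[0]))
            second_input = (token.to(torch.int64).reshape([1, 1]),) + tuple(
                past_key_values
            )
            output = model("second_vicuna_forward", second_input, send_to_host=False)
        return tokenizer.decode(tokens, skip_special_tokens=True)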
@@ -382,8 +382,7 @@ class VicunaBase(SharkLLMBase):
             if sharded:
                 output = self.shark_model.forward(input_ids, is_first=is_first)
             else:
-                output = self.shark_model("first_vicuna_forward", (input_ids,))
-                out_tensor = torch.tensor(output[1:])
+                output = self.shark_model("first_vicuna_forward", (input_ids,), send_to_host=False)
 
         else:
             token = params["token"]
@@ -402,7 +401,7 @@ class VicunaBase(SharkLLMBase):
             token = token.to(torch.int64).reshape([1, 1])
             second_input = (token,) + tuple(past_key_values)
             output = self.shark_model(
-                "second_vicuna_forward", second_input
+                "second_vicuna_forward", second_input, send_to_host=False
             )
 
         if sharded:
@@ -410,8 +409,8 @@ class VicunaBase(SharkLLMBase):
             _past_key_values = output["past_key_values"]
             _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
         else:
-            _logits = torch.tensor(output[0])
-            _past_key_values = torch.tensor(output[1:])
+            _logits = torch.tensor(output[0].to_host())
+            _past_key_values = output[1:]
             _token = torch.argmax(_logits[:, -1, :], dim=1)
 
         _detok = self.tokenizer.decode(_token, skip_special_tokens=False)
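The division of labor is visible in the last hunk: only output[0] (the logits) is copied back with .to_host() so the next token can be chosen with torch.argmax on the host, while output[1:] (the key/value cache) is handed back to second_vicuna_forward as-is, so the cache never makes a per-token host round trip.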