Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git
Fix vicuna script (#1745)
@@ -419,44 +419,6 @@ class VicunaBase(SharkLLMBase):
         return ret_dict
 
-    def generate_new_token(self, params):
-        is_first = params["is_first"]
-        if is_first:
-            prompt = params["prompt"]
-            input_ids = self.tokenizer(prompt).input_ids
-            # crop input_ids
-            # input_ids = input_ids[len(input_ids) - 20 :]
-            ############
-            input_id_len = len(input_ids)
-            input_ids = torch.tensor(input_ids)
-            input_ids = input_ids.reshape([1, input_id_len])
-            output = self.shark_model.forward(input_ids, is_first=is_first)
-        else:
-            token = params["token"]
-            past_key_values = params["past_key_values"]
-            input_ids = [token]
-            input_id_len = len(input_ids)
-            input_ids = torch.tensor(input_ids)
-            input_ids = input_ids.reshape([1, input_id_len])
-            output = self.shark_model.forward(
-                input_ids, past_key_values=past_key_values, is_first=is_first
-            )
-
-        _logits = output["logits"]
-        _past_key_values = output["past_key_values"]
-        _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
-        _detok = self.tokenizer.decode(_token)
-
-        ret_dict = {
-            "token": _token,
-            "detok": _detok,
-            "past_key_values": _past_key_values,
-        }
-
-        print(f" token : {_token} | detok : {_detok}")
-
-        return ret_dict
-
 
 class ShardedVicuna(VicunaBase):
     # Class representing Sharded Vicuna Model
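For reference, the helper deleted above implements one step of greedy decoding with a KV cache: the first call runs the full tokenized prompt, later calls feed only the newest token together with the cached past_key_values, and the next token is the argmax over the final position's logits. Below is a minimal standalone sketch of the same pattern using Hugging Face transformers in place of the SHARK runtime; the gpt2 checkpoint and the 5-step loop are illustrative assumptions, not part of this commit.

# Sketch of the decode step the deleted helper implemented: greedy argmax
# over the last position's logits, with a KV cache carried between calls.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def generate_new_token(params):
    if params["is_first"]:
        # First step: run the full tokenized prompt; no cache exists yet.
        input_ids = tokenizer(params["prompt"], return_tensors="pt").input_ids
        output = model(input_ids, use_cache=True)
    else:
        # Later steps: feed only the newest token plus the cached keys/values.
        input_ids = torch.tensor([[params["token"]]])
        output = model(
            input_ids,
            past_key_values=params["past_key_values"],
            use_cache=True,
        )
    token = int(torch.argmax(output.logits[:, -1, :], dim=1)[0])
    return {
        "token": token,
        "detok": tokenizer.decode(token),
        "past_key_values": output.past_key_values,
    }

# Greedy loop: seed with the prompt, then feed each new token back in.
step = generate_new_token({"is_first": True, "prompt": "The capital of France is"})
for _ in range(5):
    step = generate_new_token(
        {
            "is_first": False,
            "token": step["token"],
            "past_key_values": step["past_key_values"],
        }
    )
    print(step["detok"], end="")
print()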
@@ -976,7 +938,7 @@ class ShardedVicuna(VicunaBase):
                     "--iree-codegen-check-ir-before-llvm-conversion=false",
                     "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                     "--iree-opt-const-expr-hoisting=False",
-                    "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
+                    "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
                 ],
             )
             module.load_module(vmfb_path)
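This hunk, like the two identical ones below, adds a trailing comma after the last IREE flag. The change is behavior-preserving as the list stands, but it guards against a classic Python pitfall: without the comma, a flag appended later on the next line is silently glued to the previous element by implicit string concatenation, yielding one malformed argument instead of two flags. A small illustration, where --some-new-flag=1 is a made-up flag used only for the demo:

flags = [
    "--iree-opt-const-expr-hoisting=False",
    # Missing trailing comma: the next line concatenates instead of appending.
    "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
    "--some-new-flag=1",  # hypothetical flag added later
]
print(len(flags))   # 2, not 3
print(flags[-1])    # ...fold-elements=9223372036854775807--some-new-flag=1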
@@ -1044,7 +1006,7 @@ class ShardedVicuna(VicunaBase):
                     "--iree-codegen-check-ir-before-llvm-conversion=false",
                     "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                     "--iree-opt-const-expr-hoisting=False",
-                    "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
+                    "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
                 ],
             )
             module.load_module(vmfb_path)
@@ -1640,7 +1602,7 @@ class UnshardedVicuna(VicunaBase):
                     "--iree-codegen-check-ir-before-llvm-conversion=false",
                     "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                     "--iree-opt-const-expr-hoisting=False",
-                    "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
+                    "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
                 ],
             )
             print("Saved vic vmfb at ", str(path))
@@ -1792,8 +1754,9 @@ if __name__ == "__main__":
                 system_message,
                 history,
                 model=model_list[args.model_name],
-                device=args.device,
+                devices=args.device,
                 precision=args.precision,
                 config_file=None,
+                cli=args.cli,
             )
         )[0]
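In the final hunk, the call site is updated to match the callee: the device keyword becomes devices, and cli=args.cli is passed through. A hedged sketch of why the rename is load-bearing, with a simplified stand-in signature (not the real one); Python rejects unexpected keyword arguments outright:

# Hypothetical, simplified stand-in for the real signature; only the
# keyword names matter for this illustration.
def chat(system_message, history, model, devices, precision, config_file, cli=False):
    return f"model={model} devices={devices} cli={cli}"

try:
    # Old call site: 'device' is not an accepted keyword.
    chat("sys", [], model="vicuna", device="cpu", precision="fp16", config_file=None)
except TypeError as err:
    print(err)  # chat() got an unexpected keyword argument 'device'

# Fixed call site, mirroring the hunk above.
print(chat("sys", [], model="vicuna", devices="cpu", precision="fp16",
           config_file=None, cli=True))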