add support for passing iree flags for LLMs

Repository: https://github.com/nod-ai/AMD-SHARK-Studio.git
Commit 4f61d69d86 (parent 531d447768), committed by Phaneesh Barwaria
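The change threads one new command-line option, --iree_vulkan_target_triple, through the LLM drivers so that a user-supplied flag reaches the IREE compiler invocations. Once merged, a run can pin the Vulkan target, for example: python apps/language_models/scripts/vicuna.py --iree_vulkan_target_triple=rdna2-unknown-linux (the script path and triple value here are illustrative, not taken from this diff).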
@@ -123,7 +123,12 @@ parser.add_argument(
     action=argparse.BooleanOptionalAction,
     help="For debugging purposes, creates a first_{precision}.mlir and second_{precision}.mlir and stores on disk",
 )
+parser.add_argument(
+    "--iree_vulkan_target_triple",
+    type=str,
+    default="",
+    help="Specify target triple for vulkan.",
+)
 
 # fmt: off
 def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
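The new option defaults to the empty string, so nothing is forwarded unless the user sets it. A standalone sketch of the round trip from command-line option to forwarded compiler flag; the parser below is a toy, and the triple value is illustrative:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--iree_vulkan_target_triple",
    type=str,
    default="",
    help="Specify target triple for vulkan.",
)

# Parse a sample command line and translate the option into the IREE flag,
# exactly as the __main__ hunk further down does.
args = parser.parse_args(["--iree_vulkan_target_triple=rdna2-unknown-linux"])
_extra_args = []
if args.iree_vulkan_target_triple != "":
    _extra_args.append(
        f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
    )
print(_extra_args)  # ['-iree-vulkan-target-triple=rdna2-unknown-linux']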
@@ -160,11 +165,13 @@ class VicunaBase(SharkLLMBase):
         max_num_tokens=512,
         device="cpu",
         precision="int8",
+        extra_args_cmd=[],
     ) -> None:
         super().__init__(model_name, hf_model_path, max_num_tokens)
         self.max_sequence_length = 256
         self.device = device
         self.precision = precision
+        self.extra_args = extra_args_cmd
 
     def get_tokenizer(self):
         # Retrieve the tokenizer from Huggingface
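One detail worth flagging: extra_args_cmd=[] is a mutable default argument. The diff only reads and stores it, so the classic aliasing pitfall stays dormant, but any future code that appends to self.extra_args would leak flags into every instance constructed without the argument. A minimal sketch of the safer idiom; this is an editorial suggestion, not part of the diff:

class VicunaBaseSketch:
    # A None default means no single list object is shared across
    # instances that omit the argument.
    def __init__(self, extra_args_cmd=None) -> None:
        self.extra_args = list(extra_args_cmd) if extra_args_cmd else []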
@@ -433,8 +440,9 @@ class ShardedVicuna(VicunaBase):
         config_json=None,
         weight_group_size=128,
         compressed=False,
+        extra_args_cmd=[],
     ) -> None:
-        super().__init__(model_name, hf_model_path, max_num_tokens)
+        super().__init__(model_name, hf_model_path, max_num_tokens, extra_args_cmd=extra_args_cmd)
         self.max_sequence_length = 256
         self.device = device
         self.precision = precision
@@ -940,7 +948,7 @@ class ShardedVicuna(VicunaBase):
                     "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                     "--iree-opt-const-expr-hoisting=False",
                     "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
-                ],
+                ] + self.extra_args,
             )
             module.load_module(vmfb_path)
             modules.append(module)
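Appending self.extra_args after the hard-coded flags lets the user-supplied flags ride along with every compile on this path; the same one-line change recurs in the hunks at 1016 and 1613 below. A sketch of how the final flag list comes together, with an illustrative extra flag:

# The hard-coded flags are the ones visible in the hunk; the extra flag
# value is illustrative.
base_flags = [
    "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
    "--iree-opt-const-expr-hoisting=False",
    "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
]
extra_args = ["-iree-vulkan-target-triple=rdna2-unknown-linux"]
print(base_flags + extra_args)  # user flags land after the defaults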
@@ -1008,7 +1016,7 @@ class ShardedVicuna(VicunaBase):
                     "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                     "--iree-opt-const-expr-hoisting=False",
                     "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
-                ],
+                ] + self.extra_args,
             )
             module.load_module(vmfb_path)
             modules.append(module)
@@ -1220,8 +1228,9 @@ class UnshardedVicuna(VicunaBase):
         weight_group_size=128,
         download_vmfb=False,
         cache_vicunas=False,
+        extra_args_cmd=[],
     ) -> None:
-        super().__init__(model_name, hf_model_path, max_num_tokens)
+        super().__init__(model_name, hf_model_path, max_num_tokens, extra_args_cmd=extra_args_cmd)
         if "llama2" in self.model_name and hf_auth_token == None:
             raise ValueError(
                 "HF auth token required. Pass it using --hf_auth_token flag."
@@ -1604,7 +1613,7 @@ class UnshardedVicuna(VicunaBase):
                 "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                 "--iree-opt-const-expr-hoisting=False",
                 "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
-            ],
+            ] + self.extra_args,
         )
         print("Saved vic vmfb at ", str(path))
         shark_module.load_module(path)
@@ -1683,6 +1692,13 @@ class UnshardedVicuna(VicunaBase):
 if __name__ == "__main__":
     args, unknown = parser.parse_known_args()
 
+    _extra_args = []
+    # vulkan target triple
+    if args.iree_vulkan_target_triple != "":
+        _extra_args.append(
+            f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
+        )
+
     vic = None
     if not args.sharded:
         vic_mlir_path = (
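The translation from argparse option to IREE flag is open-coded here and repeated in chat() further down. A generalized sketch, editorial and not part of this diff, that would make further passthrough flags one-line additions:

import argparse

# Map argparse destinations to IREE flag names. Only the Vulkan triple
# exists in this commit; the table shape is the suggestion.
FLAG_MAP = {
    "iree_vulkan_target_triple": "-iree-vulkan-target-triple",
}

def build_extra_args(args: argparse.Namespace) -> list:
    extra = []
    for dest, flag in FLAG_MAP.items():
        value = getattr(args, dest, "")
        if value != "":
            extra.append(f"{flag}={value}")
    return extra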
@@ -1706,6 +1722,7 @@ if __name__ == "__main__":
|
||||
weight_group_size=args.weight_group_size,
|
||||
download_vmfb=args.download_vmfb,
|
||||
cache_vicunas=args.cache_vicunas,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
else:
|
||||
if args.config is not None:
|
||||
@@ -1720,6 +1737,7 @@ if __name__ == "__main__":
|
||||
precision=args.precision,
|
||||
config_json=config_json,
|
||||
weight_group_size=args.weight_group_size,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
if args.model_name == "vicuna":
|
||||
system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
|
||||
@@ -180,6 +180,15 @@ def chat(
         print("unrecognized device")
 
     max_toks = 128 if model_name == "codegen" else 512
+
+    # get iree flags that need to be overridden, from commandline args
+    _extra_args = []
+    # vulkan target triple
+    if args.iree_vulkan_target_triple != "":
+        _extra_args.append(
+            f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
+        )
+
     if model_name == "vicuna4":
         vicuna_model = ShardedVicuna(
             model_name,
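This hunk and the two that follow patch the chat() entry point, a second file in the commit. They duplicate the _extra_args construction from __main__ above; with a shared helper along the lines of the build_extra_args sketch, both call sites would reduce to _extra_args = build_extra_args(args).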
@@ -188,6 +197,7 @@ def chat(
             precision=precision,
             max_num_tokens=max_toks,
             compressed=True,
+            extra_args_cmd=_extra_args,
         )
     else:
         # if config_file is None:
@@ -198,6 +208,7 @@ def chat(
             device=device,
             precision=precision,
             max_num_tokens=max_toks,
+            extra_args_cmd=_extra_args,
         )
         # else:
         # if config_file is not None: