mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
set optional vmfb download (#1667)
This commit is contained in:
@@ -101,6 +101,7 @@ parser.add_argument(
|
||||
default=128,
|
||||
help="Group size for per_group weight quantization. Default: 128.",
|
||||
)
|
||||
# Opt-in switch for fetching a prebuilt vmfb; BooleanOptionalAction also
# registers the matching --no-download_vmfb form. Off unless requested.
parser.add_argument(
    "--download_vmfb",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="download vmfb from sharktank, system dependent, YMMV",
)
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
@@ -875,12 +876,14 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
load_mlir_from_shark_tank=True,
|
||||
low_device_memory=False,
|
||||
weight_group_size=128,
|
||||
download_vmfb=False,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
print(f"[DEBUG] hf model name: {self.hf_model_path}")
|
||||
self.max_sequence_length = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.download_vmfb = download_vmfb
|
||||
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
|
||||
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
|
||||
self.first_vicuna_mlir_path = first_vicuna_mlir_path
|
||||
@@ -1299,7 +1302,7 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
):
|
||||
if (self.device == "cuda" and self.precision == "fp16") or (
|
||||
self.device in ["cpu-sync", "cpu-task"]
|
||||
and self.precision == "int8"
|
||||
and self.precision == "int8" and self.download_vmfb
|
||||
):
|
||||
download_public_file(
|
||||
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
|
||||
@@ -1321,7 +1324,7 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
):
|
||||
if (self.device == "cuda" and self.precision == "fp16") or (
|
||||
self.device in ["cpu-sync", "cpu-task"]
|
||||
and self.precision == "int8"
|
||||
and self.precision == "int8" and self.download_vmfb
|
||||
):
|
||||
download_public_file(
|
||||
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
|
||||
@@ -1549,6 +1552,7 @@ if __name__ == "__main__":
|
||||
second_vicuna_vmfb_path=second_vic_vmfb_path,
|
||||
load_mlir_from_shark_tank=args.load_mlir_from_shark_tank,
|
||||
weight_group_size=args.weight_group_size,
|
||||
download_vmfb=args.download_vmfb,
|
||||
)
|
||||
else:
|
||||
if args.config is not None:
|
||||
|
||||
Reference in New Issue
Block a user