set optional vmfb download (#1667)

This commit is contained in:
Daniel Garvey
2023-07-18 12:57:28 -05:00
committed by GitHub
parent 8c317e4809
commit 8927cb0a2c

View File

@@ -101,6 +101,7 @@ parser.add_argument(
default=128,
help="Group size for per_group weight quantization. Default: 128.",
)
# Opt-in switch, off by default, for downloading vmfb files from sharktank.
# BooleanOptionalAction also registers the negated --no-download_vmfb form.
parser.add_argument(
    "--download_vmfb",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="download vmfb from sharktank, system dependent, YMMV",
)
def brevitasmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
@@ -875,12 +876,14 @@ class UnshardedVicuna(SharkLLMBase):
load_mlir_from_shark_tank=True,
low_device_memory=False,
weight_group_size=128,
download_vmfb=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
print(f"[DEBUG] hf model name: {self.hf_model_path}")
self.max_sequence_length = 256
self.device = device
self.precision = precision
self.download_vmfb = download_vmfb
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
self.first_vicuna_mlir_path = first_vicuna_mlir_path
@@ -1299,7 +1302,7 @@ class UnshardedVicuna(SharkLLMBase):
):
if (self.device == "cuda" and self.precision == "fp16") or (
self.device in ["cpu-sync", "cpu-task"]
and self.precision == "int8"
and self.precision == "int8" and self.download_vmfb
):
download_public_file(
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
@@ -1321,7 +1324,7 @@ class UnshardedVicuna(SharkLLMBase):
):
if (self.device == "cuda" and self.precision == "fp16") or (
self.device in ["cpu-sync", "cpu-task"]
and self.precision == "int8"
and self.precision == "int8" and self.download_vmfb
):
download_public_file(
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
@@ -1549,6 +1552,7 @@ if __name__ == "__main__":
second_vicuna_vmfb_path=second_vic_vmfb_path,
load_mlir_from_shark_tank=args.load_mlir_from_shark_tank,
weight_group_size=args.weight_group_size,
download_vmfb=args.download_vmfb,
)
else:
if args.config is not None: