mlir/vmfb path fixes for vic pipeline

This commit is contained in:
PhaneeshB
2023-07-11 19:18:44 +05:30
committed by Phaneesh Barwaria
parent 38e5b62d80
commit 6e8dbf72bd

View File

@@ -879,6 +879,7 @@ class UnshardedVicuna(SharkLLMBase):
weight_group_size=128,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
+print(f"[DEBUG] hf model name: {self.hf_model_path}")
self.max_sequence_length = 256
self.device = device
self.precision = precision
@@ -907,9 +908,9 @@ class UnshardedVicuna(SharkLLMBase):
def get_model_path(self, model_number="first", suffix="mlir"):
safe_device = self.device.split("-")[0]
if suffix == "mlir":
-return Path(f"{model_number}_vicuna_{self.precision}.{suffix}")
+return Path(f"{model_number}_{self.model_name}_{self.precision}.{suffix}")
return Path(
-f"{model_number}_vicuna_{self.precision}_{safe_device}.{suffix}"
+f"{model_number}_{self.model_name}_{self.precision}_{safe_device}.{suffix}"
)
def get_tokenizer(self):
@@ -951,7 +952,7 @@ class UnshardedVicuna(SharkLLMBase):
if self.precision in ["fp32", "fp16", "int8", "int4"]:
# download MLIR from shark_tank
download_public_file(
-f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}",
+f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.first_vicuna_mlir_path.name}",
self.first_vicuna_mlir_path.absolute(),
single_file=True,
)
@@ -993,8 +994,7 @@ class UnshardedVicuna(SharkLLMBase):
ts_graph = import_with_fx(
model,
firstVicunaCompileInput,
-is_f16=self.precision
-== "fp16", # TODO: Remove from import_with_fx args and fix all calls
+is_f16=self.precision == "fp16", # TODO: Remove from import_with_fx args and fix all calls
precision=self.precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
@@ -1122,7 +1122,7 @@ class UnshardedVicuna(SharkLLMBase):
if self.precision in ["fp32", "fp16", "int8", "int4"]:
# download MLIR from shark_tank
download_public_file(
-f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}",
+f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.second_vicuna_mlir_path.name}",
self.second_vicuna_mlir_path.absolute(),
single_file=True,
)