mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
mlir/vmfb path fixes for vic pipeline
This commit is contained in:
committed by
Phaneesh Barwaria
parent
38e5b62d80
commit
6e8dbf72bd
@@ -879,6 +879,7 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
weight_group_size=128,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
print(f"[DEBUG] hf model name: {self.hf_model_path}")
|
||||
self.max_sequence_length = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
@@ -907,9 +908,9 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
def get_model_path(self, model_number="first", suffix="mlir"):
|
||||
safe_device = self.device.split("-")[0]
|
||||
if suffix == "mlir":
|
||||
-return Path(f"{model_number}_vicuna_{self.precision}.{suffix}")
|
||||
+return Path(f"{model_number}_{self.model_name}_{self.precision}.{suffix}")
|
||||
return Path(
|
||||
-f"{model_number}_vicuna_{self.precision}_{safe_device}.{suffix}"
|
||||
+f"{model_number}_{self.model_name}_{self.precision}_{safe_device}.{suffix}"
|
||||
)
|
||||
|
||||
def get_tokenizer(self):
|
||||
@@ -951,7 +952,7 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
if self.precision in ["fp32", "fp16", "int8", "int4"]:
|
||||
# download MLIR from shark_tank
|
||||
download_public_file(
|
||||
-f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}",
|
||||
+f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.first_vicuna_mlir_path.name}",
|
||||
self.first_vicuna_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
@@ -993,8 +994,7 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
firstVicunaCompileInput,
|
||||
-is_f16=self.precision
|
||||
-== "fp16", # TODO: Remove from import_with_fx args and fix all calls
|
||||
+is_f16=self.precision == "fp16", # TODO: Remove from import_with_fx args and fix all calls
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
@@ -1122,7 +1122,7 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
if self.precision in ["fp32", "fp16", "int8", "int4"]:
|
||||
# download MLIR from shark_tank
|
||||
download_public_file(
|
||||
-f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}",
|
||||
+f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.second_vicuna_mlir_path.name}",
|
||||
self.second_vicuna_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user