Fix precision handling for fp16

This commit is contained in:
PhaneeshB
2023-07-10 19:03:44 +05:30
committed by Phaneesh Barwaria
parent a517e217b0
commit be417f0bf4

View File

@@ -977,6 +977,8 @@ class UnshardedVicuna(SharkLLMBase):
ts_graph = import_with_fx(
model,
firstVicunaCompileInput,
is_f16=self.precision
== "fp16", # TODO: Remove from import_with_fx args and fix all calls
precision=self.precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
@@ -1137,6 +1139,7 @@ class UnshardedVicuna(SharkLLMBase):
ts_graph = import_with_fx(
model,
secondVicunaCompileInput,
is_f16=self.precision == "fp16",
precision=self.precision,
f16_input_mask=[False] + [True] * 64,
mlir_type="torchscript",