mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
fix precision for fp16
This commit is contained in:
committed by
Phaneesh Barwaria
parent
a517e217b0
commit
be417f0bf4
@@ -977,6 +977,8 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
firstVicunaCompileInput,
|
||||
is_f16=self.precision
|
||||
== "fp16", # TODO: Remove from import_with_fx args and fix all calls
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
@@ -1137,6 +1139,7 @@ class UnshardedVicuna(SharkLLMBase):
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
secondVicunaCompileInput,
|
||||
is_f16=self.precision == "fp16",
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False] + [True] * 64,
|
||||
mlir_type="torchscript",
|
||||
|
||||
Reference in New Issue
Block a user