Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559)

-- It also modifies the mega_test.py script

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
This commit is contained in:
Abhishek Varma
2023-06-20 22:34:17 +05:30
committed by GitHub
parent 855435ee24
commit 3fb72e192e
2 changed files with 13 additions and 4 deletions

View File

@@ -116,6 +116,7 @@ def compile_through_fx(
model_name=None,
precision=None,
return_mlir=False,
device=None,
):
if not return_mlir and model_name is not None:
vmfb_path = get_vmfb_path_name(extended_model_name)
@@ -157,7 +158,7 @@ def compile_through_fx(
shark_module = SharkInference(
mlir_module,
device=args.device,
device=args.device if device is None else device,
mlir_dialect="tm_tensor",
)
if generate_vmfb: