From 3fb72e192ecd36734208fd31d2d27bc4efbd199b Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Tue, 20 Jun 2023 22:34:17 +0530 Subject: [PATCH] Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559) -- It also modifies the mega_test.py script Signed-off-by: Abhishek Varma --- apps/stable_diffusion/src/utils/utils.py | 3 ++- shark/examples/shark_inference/mega_test.py | 14 +++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index 54199f76..85d943d9 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -116,6 +116,7 @@ def compile_through_fx( model_name=None, precision=None, return_mlir=False, + device=None, ): if not return_mlir and model_name is not None: vmfb_path = get_vmfb_path_name(extended_model_name) @@ -157,7 +158,7 @@ def compile_through_fx( shark_module = SharkInference( mlir_module, - device=args.device, + device=args.device if device is None else device, mlir_dialect="tm_tensor", ) if generate_vmfb: diff --git a/shark/examples/shark_inference/mega_test.py b/shark/examples/shark_inference/mega_test.py index de0cc242..efc5e70b 100644 --- a/shark/examples/shark_inference/mega_test.py +++ b/shark/examples/shark_inference/mega_test.py @@ -52,14 +52,22 @@ shark_module, _ = compile_through_fx( base_model_id=None, model_name="mega_shark", precision=None, - return_mlir=False, + return_mlir=True, device="cuda", ) # logits = model(x) + +def print_output_info(output, msg): + print("\n", msg) + print("\n\t", output.shape) + + ans = shark_module("forward", input) -print(type(ans)) -print("Logits : ", ans.shape) +print_output_info(torch.from_numpy(ans), "SHARK's output") + +ans = megaModel.forward(*input) +print_output_info(ans, "ORIGINAL Model's output") # and sample from the logits accordingly # or you can use the generate function