Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559)

-- It also modifies the mega_test.py script

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Authored by Abhishek Varma on 2023-06-20 22:34:17 +05:30; committed by GitHub
parent 855435ee24
commit 3fb72e192e
2 changed files with 13 additions and 4 deletions
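
For context: before this patch, compile_through_fx always bound the inference device to the global args.device, so MEGABYTE and MiniGPT4 could not request different devices through the same compile API. Below is a minimal caller-side sketch of the behaviour after the patch; the argument list is abbreviated and every name other than compile_through_fx, model_name, return_mlir, and device is illustrative.

    # Caller-side sketch (abbreviated arguments, illustrative model names).
    megabyte_module, _ = compile_through_fx(
        model_name="mega_shark",
        return_mlir=True,
        device="cuda",   # explicit device, as in the mega_test.py diff below
    )
    minigpt4_module, _ = compile_through_fx(
        model_name="minigpt4",  # illustrative name
        return_mlir=True,
        # device omitted -> compile_through_fx falls back to args.device
    )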


@@ -116,6 +116,7 @@ def compile_through_fx(
     model_name=None,
     precision=None,
     return_mlir=False,
+    device=None,
 ):
     if not return_mlir and model_name is not None:
         vmfb_path = get_vmfb_path_name(extended_model_name)
@@ -157,7 +158,7 @@ def compile_through_fx(
     shark_module = SharkInference(
         mlir_module,
-        device=args.device,
+        device=args.device if device is None else device,
         mlir_dialect="tm_tensor",
     )
     if generate_vmfb:
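
The device handling above follows a simple precedence rule: an explicitly passed device wins, and only when it is None does compile_through_fx fall back to the CLI-level args.device. A self-contained sketch of that rule; the helper name resolve_device is not part of the patch.

    # Sketch of the precedence used in the SharkInference(...) call above.
    def resolve_device(explicit_device, default_device):
        # An explicitly passed device overrides the global/CLI default.
        return default_device if explicit_device is None else explicit_device

    assert resolve_device(None, "cpu") == "cpu"      # fall back to args.device
    assert resolve_device("cuda", "cpu") == "cuda"   # caller's choice wins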

mega_test.py

@@ -52,14 +52,22 @@ shark_module, _ = compile_through_fx(
     base_model_id=None,
     model_name="mega_shark",
     precision=None,
-    return_mlir=False,
+    return_mlir=True,
+    device="cuda",
 )
 # logits = model(x)
+def print_output_info(output, msg):
+    print("\n", msg)
+    print("\n\t", output.shape)
 ans = shark_module("forward", input)
-print(type(ans))
-print("Logits : ", ans.shape)
+print_output_info(torch.from_numpy(ans), "SHARK's output")
+ans = megaModel.forward(*input)
+print_output_info(ans, "ORIGINAL Model's output")
 # and sample from the logits accordingly
 # or you can use the generate function
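
As a follow-up to the shape printout above, a hedged sketch of turning the comparison into a numerical check; it assumes input, shark_module, and megaModel are defined as in mega_test.py and that both forwards return a single logits tensor/array.

    # Numerical comparison sketch (reuses the objects from the test script).
    import torch

    shark_logits = torch.from_numpy(shark_module("forward", input))
    torch_logits = megaModel.forward(*input)

    print("max abs diff:", (shark_logits - torch_logits).abs().max().item())
    print("allclose (atol=1e-2):",
          torch.allclose(shark_logits, torch_logits, atol=1e-2))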