Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git (synced 2026-02-19 11:56:43 -05:00)
Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559)
-- It also modifies the mega_test.py script.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
@@ -116,6 +116,7 @@ def compile_through_fx(
     model_name=None,
     precision=None,
     return_mlir=False,
+    device=None,
 ):
     if not return_mlir and model_name is not None:
         vmfb_path = get_vmfb_path_name(extended_model_name)
@@ -157,7 +158,7 @@ def compile_through_fx(
 
     shark_module = SharkInference(
         mlir_module,
-        device=args.device,
+        device=args.device if device is None else device,
         mlir_dialect="tm_tensor",
     )
     if generate_vmfb:
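The substance of the change above is a device fallback: an explicit device argument passed to compile_through_fx takes precedence, and the global args.device is only consulted when none is given. A minimal, self-contained sketch of that pattern follows; argparse.Namespace stands in for SHARK's global args object here, so this is an illustration rather than the repository code.

# Illustration of the fallback introduced in the diff above: an explicit
# `device` wins, otherwise the global args.device is used. Namespace is only
# a stand-in for SHARK's command-line args object.
from argparse import Namespace

args = Namespace(device="cpu")  # global default, e.g. parsed from the CLI

def resolve_device(device=None):
    # Same expression as in the patched SharkInference(...) call.
    return args.device if device is None else device

print(resolve_device())        # "cpu"  -> falls back to args.device
print(resolve_device("cuda"))  # "cuda" -> explicit per-call override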
@@ -52,14 +52,22 @@ shark_module, _ = compile_through_fx(
     base_model_id=None,
     model_name="mega_shark",
     precision=None,
-    return_mlir=False,
+    return_mlir=True,
+    device="cuda",
 )
 # logits = model(x)
 
+def print_output_info(output, msg):
+    print("\n", msg)
+    print("\n\t", output.shape)
+
 ans = shark_module("forward", input)
+print(type(ans))
 print("Logits : ", ans.shape)
+print_output_info(torch.from_numpy(ans), "SHARK's output")
 
 ans = megaModel.forward(*input)
+print_output_info(ans, "ORIGINAL Model's output")
 
 # and sample from the logits accordingly
 # or you can use the generate function
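The updated test only prints the shapes of the SHARK output and the original model's output. If the values were also to be checked, a small helper along the following lines could be used; it is not part of the patch, and the tolerance values are placeholder assumptions.

# Not part of the patch: a sketch for comparing the SHARK output (the numpy
# array returned by shark_module("forward", input)) against the reference
# output from megaModel.forward(*input). Tolerances are placeholder values.
import numpy as np
import torch

def compare_outputs(shark_out: np.ndarray, ref_out: torch.Tensor,
                    rtol: float = 1e-2, atol: float = 1e-2) -> None:
    shark_t = torch.from_numpy(shark_out)
    ref_t = ref_out.detach().cpu().to(shark_t.dtype)
    max_err = (shark_t - ref_t).abs().max().item()
    print("max abs error:", max_err)
    print("allclose:", torch.allclose(shark_t, ref_t, rtol=rtol, atol=atol))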