mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[WEB] Cache the compiled module.
-- Avoid re-importing and recompiling the module on every inference request.
This commit is contained in:
@@ -6,6 +6,7 @@ import numpy as np
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
COMPILE_MODULE = None
|
||||
|
||||
|
||||
class AlbertModule(torch.nn.Module):
|
||||
@@ -54,18 +55,23 @@ def top5_possibilities(text, inputs, token_logits):
|
||||
|
||||
|
||||
def albert_maskfill_inf(masked_text):
|
||||
global COMPILE_MODULE
|
||||
inputs = preprocess_data(masked_text)
|
||||
mlir_importer = SharkImporter(
|
||||
AlbertModule(),
|
||||
inputs,
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=False, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
token_logits = torch.tensor(shark_module.forward(inputs))
|
||||
if COMPILE_MODULE == None:
|
||||
print("module compiled")
|
||||
mlir_importer = SharkImporter(
|
||||
AlbertModule(),
|
||||
inputs,
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=False, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, mlir_dialect="linalg", device="intel-gpu"
|
||||
)
|
||||
shark_module.compile()
|
||||
COMPILE_MODULE = shark_module
|
||||
|
||||
token_logits = torch.tensor(COMPILE_MODULE.forward(inputs))
|
||||
return top5_possibilities(masked_text, inputs, token_logits)
|
||||
|
||||
@@ -7,6 +7,8 @@ from shark.shark_downloader import download_torch_model
|
||||
|
||||
################################## Preprocessing inputs and model ############
|
||||
|
||||
COMPILE_MODULE = None
|
||||
|
||||
|
||||
def preprocess_image(img):
|
||||
image = Image.fromarray(img)
|
||||
@@ -49,13 +51,19 @@ def top3_possibilities(res):
|
||||
def resnet_inf(numpy_img):
|
||||
img = preprocess_image(numpy_img)
|
||||
## Can pass any img or input to the forward module.
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model(
|
||||
"resnet50"
|
||||
)
|
||||
global COMPILE_MODULE
|
||||
if COMPILE_MODULE == None:
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model(
|
||||
"resnet50"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((img.detach().numpy(),))
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="intel-gpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
COMPILE_MODULE = shark_module
|
||||
|
||||
result = COMPILE_MODULE.forward((img.detach().numpy(),))
|
||||
|
||||
# print("The top 3 results obtained via shark_runner is:")
|
||||
return top3_possibilities(torch.from_numpy(result))
|
||||
|
||||
Reference in New Issue
Block a user