[WEB] Cache the compiled module.

-- Avoid recompiling the module on every inference call; compile once and reuse it.
This commit is contained in:
Prashant Kumar
2022-09-01 08:13:30 -07:00
parent a886cba655
commit 885b0969f5
2 changed files with 33 additions and 19 deletions

View File

@@ -6,6 +6,7 @@ import numpy as np
MAX_SEQUENCE_LENGTH = 512
BATCH_SIZE = 1
COMPILE_MODULE = None
class AlbertModule(torch.nn.Module):
@@ -54,18 +55,23 @@ def top5_possibilities(text, inputs, token_logits):
def albert_maskfill_inf(masked_text):
    """Run masked-language-model inference on *masked_text* and return the
    top-5 token possibilities for the mask position.

    The imported/compiled SHARK module is cached in the module-level
    ``COMPILE_MODULE`` so the expensive import+compile pipeline runs only
    on the first call; later calls reuse the cached module.

    Args:
        masked_text: Input text containing a mask token, in whatever form
            ``preprocess_data`` expects (defined elsewhere in this file).

    Returns:
        Whatever ``top5_possibilities`` returns for the computed logits.
    """
    global COMPILE_MODULE
    inputs = preprocess_data(masked_text)
    # Compile lazily on first use only; `is None` (not `== None`) per PEP 8.
    if COMPILE_MODULE is None:
        print("module compiled")
        mlir_importer = SharkImporter(
            AlbertModule(),
            inputs,
            frontend="torch",
        )
        minilm_mlir, func_name = mlir_importer.import_mlir(
            is_dynamic=False, tracing_required=True
        )
        shark_module = SharkInference(
            minilm_mlir, func_name, mlir_dialect="linalg", device="intel-gpu"
        )
        shark_module.compile()
        COMPILE_MODULE = shark_module
    # Every call (cached or not) runs the forward pass on the cached module.
    token_logits = torch.tensor(COMPILE_MODULE.forward(inputs))
    return top5_possibilities(masked_text, inputs, token_logits)

View File

@@ -7,6 +7,8 @@ from shark.shark_downloader import download_torch_model
################################## Preprocessing inputs and model ############
COMPILE_MODULE = None
def preprocess_image(img):
image = Image.fromarray(img)
@@ -49,13 +51,19 @@ def top3_possibilities(res):
def resnet_inf(numpy_img):
    """Classify *numpy_img* with a ResNet-50 SHARK module and return the
    top-3 class possibilities.

    The downloaded/compiled module is cached in the module-level
    ``COMPILE_MODULE`` so model download and compilation happen only on
    the first call; subsequent calls only run the forward pass.

    Args:
        numpy_img: Image array accepted by ``preprocess_image`` (defined
            elsewhere in this file).

    Returns:
        Whatever ``top3_possibilities`` returns for the model output.
    """
    img = preprocess_image(numpy_img)
    global COMPILE_MODULE
    # Download and compile lazily on first use only; `is None` per PEP 8.
    if COMPILE_MODULE is None:
        mlir_model, func_name, inputs, golden_out = download_torch_model(
            "resnet50"
        )
        shark_module = SharkInference(
            mlir_model, func_name, device="intel-gpu", mlir_dialect="linalg"
        )
        shark_module.compile()
        COMPILE_MODULE = shark_module
    ## Can pass any img or input to the forward module.
    result = COMPILE_MODULE.forward((img.detach().numpy(),))
    # print("The top 3 results obtained via shark_runner is:")
    return top3_possibilities(torch.from_numpy(result))