Add the shark_downloader for the torch_models. (#184)

This commit is contained in:
Prashant Kumar
2022-07-15 02:11:43 +05:30
committed by GitHub
parent 8434c67d96
commit 1191f53c9d
2 changed files with 40 additions and 34 deletions

View File

@@ -1,41 +1,24 @@
# NOTE(review): this span is a unified view of a commit diff — removed and
# added lines are interleaved (SharkImporter is the old path, being replaced
# by download_torch_model), so this is not a coherent module header.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
# Fix the RNG seed so traced inputs / golden outputs are reproducible.
torch.manual_seed(0)
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
from shark.shark_downloader import download_torch_model
class MiniLMSequenceClassification(torch.nn.Module):
    """Thin nn.Module wrapper around the pretrained MiniLM sequence
    classifier that returns only the logits tensor from the HF output."""

    def __init__(self):
        super().__init__()
        # Configuration mirrors the upstream example: binary classification,
        # no attentions/hidden-states in the output, torchscript-friendly.
        hf_kwargs = {
            "num_labels": 2,
            "output_attentions": False,
            "output_hidden_states": False,
            "torchscript": True,
        }
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "microsoft/MiniLM-L12-H384-uncased", **hf_kwargs
        )

    def forward(self, x, y, z):
        # The HF model returns a tuple; element 0 holds the logits.
        outputs = self.model.forward(x, y, z)
        return outputs[0]
# NOTE(review): the following span interleaves both sides of a diff with the
# +/- markers stripped — the SharkImporter( call opened three lines below is
# never closed because the replacement download_torch_model lines were spliced
# into the middle of it. This text is not runnable as-is.
test_input = torch.randint(2, (1, 128)).to(torch.int32)
mlir_importer = SharkImporter(
MiniLMSequenceClassification(),
(test_input, test_input, test_input),
frontend="torch",
mlir_model, func_name, inputs, golden_out = download_torch_model(
"microsoft/MiniLM-L12-H384-uncased"
)
# torch hugging face models needs tracing..
mlir_importer.import_debug(tracing_required=True)
# print(golden_out)
# Compile the downloaded MLIR for CPU via the linalg dialect, then run it on
# the shipped inputs and compare (by eye, via prints) against the golden output.
shark_module = SharkInference(
mlir_model, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward(inputs)
print("The obtained result via shark is: ", result)
print("The golden result is:", golden_out)
# shark_module = SharkInference(minilm_mlir, func_name, device="cpu", mlir_dialect="linalg")
# shark_module.compile()
# result = shark_module.forward((test_input, test_input, test_input))
# print("Obtained result", result)
# Let's generate random inputs, currently supported
# for static models.
rand_inputs = shark_module.generate_random_inputs()
rand_results = shark_module.forward(rand_inputs)
print("Running shark_module with random_inputs is: ", rand_results)

View File

@@ -27,6 +27,29 @@ input_type_to_np_dtype = {
"int8": np.int8,
}
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
def download_torch_model(model_name):
    """Download a pre-generated torch model artifact from gs://shark_tank.

    Copies the model's directory into WORKDIR (created on demand) with
    ``gsutil`` and loads the serialized artifacts it contains.

    Args:
        model_name: Model identifier; any "/" (Hugging-Face style id) is
            replaced with "_" to match the bucket's directory layout.

    Returns:
        Tuple ``(mlir_module_text, function_name, inputs, golden_out)`` where
        ``inputs`` and ``golden_out`` are tuples of numpy arrays loaded from
        the downloaded ``.npz`` archives.

    Raises:
        Exception: if the gsutil copy fails (model missing from the tank).
    """
    model_name = model_name.replace("/", "_")
    os.makedirs(WORKDIR, exist_ok=True)
    # NOTE(review): command is run through the shell; acceptable for trusted,
    # developer-supplied names, but would misbehave on names with spaces.
    gs_command = f"gsutil cp -r gs://shark_tank/{model_name} {WORKDIR}"
    if os.system(gs_command) != 0:
        # Include the failing model name so the error is actionable.
        raise Exception(
            f"model '{model_name}' not present in the tank. Contact Nod Admin"
        )
    model_dir = os.path.join(WORKDIR, model_name)
    with open(os.path.join(model_dir, model_name + ".mlir")) as f:
        mlir_file = f.read()
    function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
    inputs = np.load(os.path.join(model_dir, "inputs.npz"))
    golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
    # Iterating an NpzFile yields its array names; preserve archive order.
    inputs_tuple = tuple(inputs[key] for key in inputs)
    golden_out_tuple = tuple(golden_out[key] for key in golden_out)
    return mlir_file, function_name, inputs_tuple, golden_out_tuple
class SharkDownloader:
def __init__(