Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git · synced 2026-04-03 03:00:17 -04:00
Add the shark_downloader for the torch_models. (#184)
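Summary: this commit replaces the hand-rolled SharkImporter flow in the MiniLM example with a new download_torch_model helper. Instead of tracing the Hugging Face model locally, the example now pulls a pre-imported MLIR module, its function name, sample inputs, and golden outputs from the gs://shark_tank bucket and feeds them straight to SharkInference.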
@@ -1,41 +1,24 @@
-import torch
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
 from shark.shark_inference import SharkInference
-from shark.shark_importer import SharkImporter
-
-torch.manual_seed(0)
-tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
+from shark.shark_downloader import download_torch_model
 
 
-class MiniLMSequenceClassification(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.model = AutoModelForSequenceClassification.from_pretrained(
-            "microsoft/MiniLM-L12-H384-uncased",  # The pretrained model.
-            num_labels=2,  # The number of output labels--2 for binary classification.
-            output_attentions=False,  # Whether the model returns attentions weights.
-            output_hidden_states=False,  # Whether the model returns all hidden-states.
-            torchscript=True,
-        )
-
-    def forward(self, x, y, z):
-        return self.model.forward(x, y, z)[0]
-
-
-test_input = torch.randint(2, (1, 128)).to(torch.int32)
-
-mlir_importer = SharkImporter(
-    MiniLMSequenceClassification(),
-    (test_input, test_input, test_input),
-    frontend="torch",
-)
-
-# torch hugging face models needs tracing..
-mlir_importer.import_debug(tracing_required=True)
-
-# shark_module = SharkInference(minilm_mlir, func_name, device="cpu", mlir_dialect="linalg")
-# shark_module.compile()
-# result = shark_module.forward((test_input, test_input, test_input))
-# print("Obtained result", result)
+mlir_model, func_name, inputs, golden_out = download_torch_model(
+    "microsoft/MiniLM-L12-H384-uncased"
+)
+
+# print(golden_out)
+shark_module = SharkInference(
+    mlir_model, func_name, device="cpu", mlir_dialect="linalg"
+)
+shark_module.compile()
+result = shark_module.forward(inputs)
+print("The obtained result via shark is: ", result)
+print("The golden result is:", golden_out)
 
 # Let's generate random inputs, currently supported
 # for static models.
 rand_inputs = shark_module.generate_random_inputs()
 rand_results = shark_module.forward(rand_inputs)
 
 print("Running shark_module with random_inputs is: ", rand_results)
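The helper the example calls is added in the second hunk below. Judging from that code, each model directory in the tank carries four artifacts: a .mlir file named after the model, function_name.npy, inputs.npz, and golden_out.npz. A minimal sketch of unpacking them by hand, assuming a directory already fetched into gen_shark_tank (the paths here are illustrative, not part of the patch):

import os
import numpy as np

# Illustrative path; download_torch_model derives its WORKDIR from __file__.
model_dir = os.path.join("gen_shark_tank", "microsoft_MiniLM-L12-H384-uncased")

# The serialized MLIR module is stored as plain text named after the model.
with open(os.path.join(model_dir, "microsoft_MiniLM-L12-H384-uncased.mlir")) as f:
    mlir_module = f.read()

# function_name.npy presumably wraps the entry-point name in a 0-d numpy
# array; str() unwraps it, mirroring what download_torch_model does.
func_name = str(np.load(os.path.join(model_dir, "function_name.npy")))

# inputs.npz and golden_out.npz are archives of named arrays; iterating an
# NpzFile yields its keys, so the tuple preserves the stored order.
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
input_tuple = tuple(inputs[key] for key in inputs)

print(func_name, [arr.shape for arr in input_tuple])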
@@ -27,6 +27,29 @@ input_type_to_np_dtype = {
     "int8": np.int8,
 }
 
+WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
+
+
+# Downloads the torch model from gs://shark_tank dir.
+def download_torch_model(model_name):
+    model_name = model_name.replace("/", "_")
+    os.makedirs(WORKDIR, exist_ok=True)
+    gs_command = (
+        "gsutil cp -r gs://shark_tank" + "/" + model_name + " " + WORKDIR
+    )
+    if os.system(gs_command) != 0:
+        raise Exception("model not present in the tank. Contact Nod Admin")
+    model_dir = os.path.join(WORKDIR, model_name)
+    with open(os.path.join(model_dir, model_name + ".mlir")) as f:
+        mlir_file = f.read()
+
+    function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
+    inputs = np.load(os.path.join(model_dir, "inputs.npz"))
+    golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
+
+    inputs_tuple = tuple([inputs[key] for key in inputs])
+    golden_out_tuple = tuple([golden_out[key] for key in golden_out])
+    return mlir_file, function_name, inputs_tuple, golden_out_tuple
+
+
 class SharkDownloader:
     def __init__(
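One rough edge worth noting in the new helper: os.system builds the gsutil command by string concatenation, discards stderr, and reports only an exit status. A hardened sketch under the same assumptions (same gsutil CLI and bucket layout; fetch_from_tank is a hypothetical name, not part of this patch):

import os
import subprocess

WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")


def fetch_from_tank(model_name: str) -> str:
    """Hypothetical variant of the download step with real error reporting."""
    safe_name = model_name.replace("/", "_")
    os.makedirs(WORKDIR, exist_ok=True)
    # Argument-list form keeps the model name out of shell interpolation.
    proc = subprocess.run(
        ["gsutil", "cp", "-r", f"gs://shark_tank/{safe_name}", WORKDIR],
        capture_output=True,
        text=True,
    )
    if proc.returncode != 0:
        # Surface gsutil's own diagnostics instead of a bare non-zero status.
        raise RuntimeError(f"Download of {safe_name} failed: {proc.stderr.strip()}")
    return os.path.join(WORKDIR, safe_name)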