mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-12 15:28:10 -05:00
Compare commits
2 Commits
main
...
ean-dynamo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
493f776253 | ||
|
|
6a82667778 |
@@ -90,3 +90,8 @@ def pytest_addoption(parser):
|
||||
type=int,
|
||||
help="Batch size for the tested model.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--custom_device",
|
||||
default=None,
|
||||
help="Custom device string to run tests with.",
|
||||
)
|
||||
|
||||
133
shark/sharkdynamo/shark_backend.py
Normal file
133
shark/sharkdynamo/shark_backend.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir.dynamo import make_simple_dynamo_backend
|
||||
import torch
|
||||
from typing import List
|
||||
|
||||
|
||||
def get_sorted_params(named_params):
|
||||
return [i[1] for i in sorted(named_params.items())]
|
||||
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
|
||||
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
return len(node_arg) == 0
|
||||
return False
|
||||
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
|
||||
torch._dynamo.config.verbose = True
|
||||
var_id = 0
|
||||
|
||||
|
||||
@make_simple_dynamo_backend
|
||||
def shark_torchdynamo_backend(
|
||||
fx_graph: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor],
|
||||
):
|
||||
if _returns_nothing(fx_graph):
|
||||
return fx_graph
|
||||
removed_none_indexes = _remove_nones(fx_graph)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
|
||||
mlir_module = torch_mlir.compile(
|
||||
fx_graph, example_inputs, output_type="linalg-on-tensors"
|
||||
)
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
global var_id
|
||||
with open(f"linalg_gen_{var_id}.mlir", "w") as f:
|
||||
with redirect_stdout(f):
|
||||
print(mlir_module)
|
||||
print("saving!")
|
||||
var_id += 1
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
import io
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
def compiled_callable(*inputs):
|
||||
inputs = [x.numpy() for x in inputs]
|
||||
result = shark_module("forward", inputs)
|
||||
if was_unwrapped:
|
||||
result = [
|
||||
result,
|
||||
]
|
||||
if not isinstance(result, list):
|
||||
result = torch.tensor(x)
|
||||
else:
|
||||
result = tuple(torch.tensor(x) for x in result)
|
||||
result = list(result)
|
||||
for removed_index in removed_none_indexes:
|
||||
result.insert(removed_index, None)
|
||||
result = tuple(result)
|
||||
return result
|
||||
|
||||
return compiled_callable
|
||||
@@ -34,6 +34,11 @@ hf_seq2seq_models = [
|
||||
]
|
||||
|
||||
|
||||
def get_training_model(modelname, import_args):
|
||||
if "bert" in modelname:
|
||||
return get_bert_pretrain_model(modelname, import_args)
|
||||
|
||||
|
||||
def get_torch_model(modelname, import_args):
|
||||
if modelname in vision_models:
|
||||
return get_vision_model(modelname, import_args)
|
||||
@@ -47,14 +52,31 @@ def get_torch_model(modelname, import_args):
|
||||
return get_hf_model(modelname, import_args)
|
||||
|
||||
|
||||
##################### Hugging Face BERT PreTraining Models #######################################
|
||||
|
||||
|
||||
def get_bert_pretrain_model(model_name, import_args):
|
||||
from transformers import BertForPreTraining
|
||||
import copy
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_model = BertForPreTraining.from_pretrained(model_name)
|
||||
base_model = base_model.train()
|
||||
my_config = copy.deepcopy(base_model.config)
|
||||
my_config.num_hidden_layers = import_args["num_hidden_layers"]
|
||||
my_config.num_attention_heads = import_args["num_attention_heads"]
|
||||
my_config.hidden_size = import_args["hidden_size"]
|
||||
my_config.vocab_size = import_args["vocab_size"]
|
||||
|
||||
return BertForPreTraining(my_config)
|
||||
|
||||
|
||||
##################### Hugging Face Image Classification Models ###################################
|
||||
from transformers import AutoModelForImageClassification
|
||||
from transformers import AutoFeatureExtractor
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
|
||||
def preprocess_input_image(model_name):
|
||||
from PIL import Image
|
||||
|
||||
# from datasets import load_dataset
|
||||
# dataset = load_dataset("huggingface/cats-image")
|
||||
# image1 = dataset["test"]["image"][0]
|
||||
@@ -88,6 +110,10 @@ class HuggingFaceImageClassification(torch.nn.Module):
|
||||
|
||||
|
||||
def get_hf_img_cls_model(name, import_args):
|
||||
from transformers import AutoModelForImageClassification
|
||||
from transformers import AutoFeatureExtractor
|
||||
import requests
|
||||
|
||||
model = HuggingFaceImageClassification(name)
|
||||
# you can use preprocess_input_image to get the test_input or just random value.
|
||||
test_input = preprocess_input_image(name)
|
||||
|
||||
6
tank/pretrain_models.csv
Normal file
6
tank/pretrain_models.csv
Normal file
@@ -0,0 +1,6 @@
|
||||
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
|
||||
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
|
||||
|
239
tank/test_models_dynamo.py
Normal file
239
tank/test_models_dynamo.py
Normal file
@@ -0,0 +1,239 @@
|
||||
from shark.iree_utils._common import (
|
||||
check_device_drivers,
|
||||
device_driver_info,
|
||||
get_supported_device_list,
|
||||
)
|
||||
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
|
||||
from shark.sharkdynamo.shark_backend import shark_torchdynamo_backend
|
||||
from tank.model_utils import get_training_model
|
||||
from parameterized import parameterized
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch._dynamo as dynamo
|
||||
import transformers
|
||||
import iree.compiler as ireec
|
||||
import pytest
|
||||
import unittest
|
||||
import numpy as np
|
||||
import tempfile
|
||||
import os
|
||||
import sys
|
||||
import copy
|
||||
import csv
|
||||
|
||||
|
||||
def load_csv_and_convert(filename, gen=False):
|
||||
"""
|
||||
takes in a csv filename and generates a dict for consumption by get_valid_test_params
|
||||
"""
|
||||
model_configs = []
|
||||
with open(filename, "r+") as f:
|
||||
reader = csv.reader(f, delimiter=",")
|
||||
for row in reader:
|
||||
if len(row) < 5:
|
||||
print("invalid model: " + row)
|
||||
continue
|
||||
model_configs.append(
|
||||
{
|
||||
"model_name": row[0],
|
||||
"dialect": row[1],
|
||||
"framework": row[2],
|
||||
"rtol": float(row[3]),
|
||||
"atol": float(row[4]),
|
||||
"out_type": row[5],
|
||||
"flags": row[6],
|
||||
"xfail_cpu": row[7],
|
||||
"xfail_cuda": row[8],
|
||||
"xfail_vkm": row[9],
|
||||
"xfail_reason": row[10],
|
||||
"xfail_other": row[11],
|
||||
}
|
||||
)
|
||||
# This is a pytest workaround
|
||||
if gen:
|
||||
with open(
|
||||
os.path.join(os.path.dirname(__file__), "dict_configs.py"), "w+"
|
||||
) as out:
|
||||
out.write("ALL = [\n")
|
||||
for c in model_configs:
|
||||
out.write(str(c) + ",\n")
|
||||
out.write("]")
|
||||
return model_configs
|
||||
|
||||
|
||||
def get_valid_test_params(custom_device=None):
|
||||
"""
|
||||
Generate a list of all combinations of available devices and static/dynamic flag.
|
||||
"""
|
||||
device_list = [
|
||||
device
|
||||
for device in get_supported_device_list()
|
||||
if not check_device_drivers(device)
|
||||
]
|
||||
if custom_device:
|
||||
device_list.append(custom_device)
|
||||
dynamic_list = (True, False)
|
||||
# TODO: This is soooo ugly, but for some reason creating the dict at runtime
|
||||
# results in strange pytest failures.
|
||||
load_csv_and_convert(
|
||||
os.path.join(os.path.dirname(__file__), "all_models.csv"), True
|
||||
)
|
||||
from tank.dict_configs import ALL
|
||||
|
||||
config_list = ALL
|
||||
|
||||
param_list = [
|
||||
(dynamic, device, config)
|
||||
for dynamic in dynamic_list
|
||||
for device in device_list
|
||||
for config in config_list
|
||||
]
|
||||
|
||||
filtered_param_list = [
|
||||
params for params in param_list if is_valid_case(params)
|
||||
]
|
||||
|
||||
return filtered_param_list
|
||||
|
||||
|
||||
def is_valid_case(test_params):
|
||||
if test_params[0] == True and test_params[2]["framework"] == "tf":
|
||||
return False
|
||||
elif "fp16" in test_params[2]["model_name"] and test_params[1] != "cuda":
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def shark_test_name_func(testcase_func, param_num, param):
|
||||
"""
|
||||
Generate function name string which shows dynamic/static and device name.
|
||||
this will be ingested by 'parameterized' package to rename the pytest.
|
||||
"""
|
||||
param_names = []
|
||||
for x in param.args:
|
||||
if x == True:
|
||||
param_names.append("dynamic")
|
||||
elif x == False:
|
||||
param_names.append("static")
|
||||
elif "model" in str(x):
|
||||
as_list = str(x).split(" ")
|
||||
as_list = [
|
||||
parameterized.to_safe_name(x).strip("_") for x in as_list
|
||||
]
|
||||
param_names.insert(0, as_list[as_list.index("model_name") + 1])
|
||||
param_names.insert(1, as_list[as_list.index("framework") + 1])
|
||||
# param_names.append(as_list[3])
|
||||
|
||||
else:
|
||||
param_names.append(x)
|
||||
return "%s_%s" % (
|
||||
testcase_func.__name__,
|
||||
parameterized.to_safe_name("_".join(str(x) for x in param_names)),
|
||||
)
|
||||
|
||||
|
||||
class SharkModuleTester:
|
||||
def __init__(self, config):
|
||||
"""config should be a dict containing minimally:
|
||||
dialect: (str) name of input dialect
|
||||
framework: (str) one of tf, tflite, pytorch
|
||||
model_name: (str) name of the model in the tank ("resnet50")
|
||||
rtol/atol: (float) tolerances for golden values
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
def create_module_sharkdynamo(self, dynamic, device):
|
||||
model_name = self.config["model_name"]
|
||||
model_config = {
|
||||
"batch_size": 128,
|
||||
"num_hidden_layers": 1,
|
||||
"num_attention_heads": 1,
|
||||
"hidden_size": 16,
|
||||
"vocab_size": 8192,
|
||||
}
|
||||
net = get_training_model(model_name, model_config)
|
||||
|
||||
in_dim = 128
|
||||
out_dim = 8
|
||||
|
||||
input_ids = torch.randint(
|
||||
0, 5000, (out_dim, in_dim), dtype=torch.int64
|
||||
)
|
||||
input_mask = torch.ones([out_dim, in_dim], dtype=torch.int64)
|
||||
masked_lm_labels = torch.randint(
|
||||
0, 3000, (out_dim, in_dim), dtype=torch.int64
|
||||
)
|
||||
next_sentence_labels = torch.randint(
|
||||
0, 2, (out_dim,), dtype=torch.int64
|
||||
)
|
||||
segment_ids = torch.randint(0, 2, (out_dim, in_dim), dtype=torch.int64)
|
||||
|
||||
torch.set_grad_enabled(True)
|
||||
net.train()
|
||||
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-5)
|
||||
|
||||
def train_func(
|
||||
input_ids,
|
||||
input_mask,
|
||||
segment_ids,
|
||||
masked_lm_labels,
|
||||
next_sentence_labels,
|
||||
):
|
||||
loss = net(
|
||||
input_ids=input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=segment_ids,
|
||||
labels=masked_lm_labels,
|
||||
next_sentence_label=next_sentence_labels,
|
||||
).loss
|
||||
loss.backward()
|
||||
optimizer.zero_grad()
|
||||
optimizer.step()
|
||||
return loss
|
||||
|
||||
torch.manual_seed(0)
|
||||
print("compiling with dynamo...")
|
||||
dynamo_callable = dynamo.optimize(shark_torchdynamo_backend)(
|
||||
train_func
|
||||
)
|
||||
print("running dynamo-compiled module...")
|
||||
res = dynamo_callable(
|
||||
input_ids,
|
||||
input_mask,
|
||||
segment_ids,
|
||||
masked_lm_labels,
|
||||
next_sentence_labels,
|
||||
)
|
||||
print("res", res)
|
||||
|
||||
# TODO: add baseline for validation
|
||||
# baseline_res =
|
||||
|
||||
|
||||
class SharkModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.pytestconfig = pytestconfig
|
||||
param_list = get_valid_test_params(
|
||||
custom_device=pytestconfig.getoption("custom_device")
|
||||
)
|
||||
|
||||
param_list = get_valid_test_params()
|
||||
|
||||
@parameterized.expand(param_list, name_func=shark_test_name_func)
|
||||
def test_module(self, dynamic, device, config):
|
||||
self.module_tester = SharkModuleTester(config)
|
||||
self.module_tester.testconfig = self.pytestconfig.args
|
||||
safe_name = (
|
||||
f"{config['model_name']}_dynamo_pretrain_{dynamic}_{device}"
|
||||
)
|
||||
self.module_tester.tmp_prefix = safe_name.replace("/", "_")
|
||||
|
||||
tempdir = tempfile.TemporaryDirectory(
|
||||
prefix=self.module_tester.tmp_prefix, dir="."
|
||||
)
|
||||
self.module_tester.temp_dir = tempdir.name
|
||||
|
||||
with ireec.tools.TempFileSaver(tempdir.name):
|
||||
self.module_tester.create_module_sharkdynamo(dynamic, device)
|
||||
Reference in New Issue
Block a user