Compare commits


2 Commits

Author      SHA1        Message                                            Date
Ean Garvey  493f776253  add pretrain_models.csv                            2023-06-15 14:52:18 -05:00
Ean Garvey  6a82667778  Add pretraining pytest script with SHARK dynamo.   2023-06-15 12:01:35 -05:00
5 changed files with 413 additions and 4 deletions


@@ -90,3 +90,8 @@ def pytest_addoption(parser):
        type=int,
        help="Batch size for the tested model.",
    )
    parser.addoption(
        "--custom_device",
        default=None,
        help="Custom device string to run tests with.",
    )
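
With this option registered, a run can point the suite at an extra device string beyond the auto-detected ones (the device value below is only an example):

    pytest tank/test_models_dynamo.py --custom_device=vulkan://1
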


@@ -0,0 +1,133 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from typing import List

import torch
import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend


def get_sorted_params(named_params):
    return [i[1] for i in sorted(named_params.items())]


def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
    removed_indexes = []
    for node in fx_g.graph.nodes:
        if node.op == "output":
            assert (
                len(node.args) == 1
            ), "Output node must have a single argument"
            node_arg = node.args[0]
            if isinstance(node_arg, (list, tuple)):
                node_arg = list(node_arg)
                node_args_len = len(node_arg)
                for i in range(node_args_len):
                    curr_index = node_args_len - (i + 1)
                    if node_arg[curr_index] is None:
                        removed_indexes.append(curr_index)
                        node_arg.pop(curr_index)
                node.args = (tuple(node_arg),)
                break
    if len(removed_indexes) > 0:
        fx_g.graph.lint()
        fx_g.graph.eliminate_dead_code()
        fx_g.recompile()
    removed_indexes.sort()
    return removed_indexes
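
# Illustrative example (not part of the original file): for a graph whose
# output node carries args ((t0, None, t1),), _remove_nones rewrites the
# output to ((t0, t1),) and returns [1], so the caller can re-insert None
# at index 1 after running the compiled module.
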
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
    for node in fx_g.graph.nodes:
        if node.op == "output":
            assert (
                len(node.args) == 1
            ), "Output node must have a single argument"
            node_arg = node.args[0]
            if isinstance(node_arg, tuple):
                return len(node_arg) == 0
    return False


def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
    """
    Replace the tuple with its element in functions that return one-element
    tuples. Returns True if an unwrapping took place, and False otherwise.
    """
    unwrapped_tuple = False
    for node in fx_g.graph.nodes:
        if node.op == "output":
            assert (
                len(node.args) == 1
            ), "Output node must have a single argument"
            node_arg = node.args[0]
            if isinstance(node_arg, tuple):
                if len(node_arg) == 1:
                    node.args = (node_arg[0],)
                    unwrapped_tuple = True
                break
    if unwrapped_tuple:
        fx_g.graph.lint()
        fx_g.recompile()
    return unwrapped_tuple
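
# Illustrative example (not part of the original file): an output node with
# args ((t0,),) is rewritten to args (t0,) and the function returns True;
# multi-element tuples are left untouched and it returns False.
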
torch._dynamo.config.verbose = True
var_id = 0


@make_simple_dynamo_backend
def shark_torchdynamo_backend(
    fx_graph: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    device: str = "cpu",  # assumption: SHARK device string ("cpu", "cuda", "vulkan", ...)
):
    if _returns_nothing(fx_graph):
        return fx_graph
    removed_none_indexes = _remove_nones(fx_graph)
    was_unwrapped = _unwrap_single_tuple_return(fx_graph)
    mlir_module = torch_mlir.compile(
        fx_graph, example_inputs, output_type="linalg-on-tensors"
    )

    # Dump the generated linalg IR to a numbered file for inspection.
    from contextlib import redirect_stdout

    global var_id
    with open(f"linalg_gen_{var_id}.mlir", "w") as f:
        with redirect_stdout(f):
            print(mlir_module)
    print("saving!")  # status note, printed outside the redirect so it is not written into the .mlir file
    var_id += 1

    from shark.shark_inference import SharkInference
    import io

    bytecode_stream = io.BytesIO()
    mlir_module.operation.write_bytecode(bytecode_stream)
    bytecode = bytecode_stream.getvalue()
    shark_module = SharkInference(
        mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
    )
    shark_module.compile()

    def compiled_callable(*inputs):
        inputs = [x.numpy() for x in inputs]
        result = shark_module("forward", inputs)
        if was_unwrapped:
            result = [result]
        if not isinstance(result, list):
            result = torch.tensor(result)
        else:
            result = tuple(torch.tensor(x) for x in result)
            result = list(result)
            for removed_index in removed_none_indexes:
                result.insert(removed_index, None)
            result = tuple(result)
        return result

    return compiled_callable
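
For context, a minimal sketch of driving this backend standalone (the toy function is illustrative; the import path and the dynamo.optimize call mirror tank/test_models_dynamo.py below):

    import torch
    import torch._dynamo as dynamo
    from shark.sharkdynamo.shark_backend import shark_torchdynamo_backend

    def f(x):
        # a trivial eager function for dynamo to capture
        return torch.tanh(x) + 1

    compiled = dynamo.optimize(shark_torchdynamo_backend)(f)
    print(compiled(torch.randn(4)))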


@@ -34,6 +34,11 @@ hf_seq2seq_models = [
]


def get_training_model(modelname, import_args):
    if "bert" in modelname:
        return get_bert_pretrain_model(modelname, import_args)


def get_torch_model(modelname, import_args):
    if modelname in vision_models:
        return get_vision_model(modelname, import_args)

@@ -47,14 +52,31 @@ def get_torch_model(modelname, import_args):
        return get_hf_model(modelname, import_args)


##################### Hugging Face BERT PreTraining Models #######################################
def get_bert_pretrain_model(model_name, import_args):
    from transformers import BertForPreTraining
    import copy

    torch.manual_seed(0)
    base_model = BertForPreTraining.from_pretrained(model_name)
    base_model = base_model.train()
    my_config = copy.deepcopy(base_model.config)
    my_config.num_hidden_layers = import_args["num_hidden_layers"]
    my_config.num_attention_heads = import_args["num_attention_heads"]
    my_config.hidden_size = import_args["hidden_size"]
    my_config.vocab_size = import_args["vocab_size"]
    return BertForPreTraining(my_config)
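
# Illustrative import_args (not in the original file), matching the keys read
# above; the values mirror the model_config built in tank/test_models_dynamo.py:
#   {"num_hidden_layers": 1, "num_attention_heads": 1,
#    "hidden_size": 16, "vocab_size": 8192}
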
##################### Hugging Face Image Classification Models ###################################
from transformers import AutoModelForImageClassification
from transformers import AutoFeatureExtractor
from PIL import Image
import requests


def preprocess_input_image(model_name):
    from PIL import Image

    # from datasets import load_dataset
    # dataset = load_dataset("huggingface/cats-image")
    # image1 = dataset["test"]["image"][0]

@@ -88,6 +110,10 @@ class HuggingFaceImageClassification(torch.nn.Module):
def get_hf_img_cls_model(name, import_args):
    from transformers import AutoModelForImageClassification
    from transformers import AutoFeatureExtractor
    import requests

    model = HuggingFaceImageClassification(name)
    # You can use preprocess_input_image to get the test_input, or just a random value.
    test_input = preprocess_input_image(name)

tank/pretrain_models.csv Normal file

@@ -0,0 +1,6 @@
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
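
(The columns are consumed positionally by load_csv_and_convert in tank/test_models_dynamo.py below: model_name, dialect, framework, rtol, atol, out_type, flags, xfail_cpu, xfail_cuda, xfail_vkm, xfail_reason, xfail_other.)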

tank/test_models_dynamo.py Normal file

@@ -0,0 +1,239 @@
from shark.iree_utils._common import (
    check_device_drivers,
    device_driver_info,
    get_supported_device_list,
)
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
from shark.sharkdynamo.shark_backend import shark_torchdynamo_backend
from tank.model_utils import get_training_model

from parameterized import parameterized
import torch
import torch.nn as nn
import torch._dynamo as dynamo
import transformers
import iree.compiler as ireec
import pytest
import unittest
import numpy as np
import tempfile
import os
import sys
import copy
import csv
def load_csv_and_convert(filename, gen=False):
    """
    Take a csv filename and generate a list of config dicts for consumption
    by get_valid_test_params.
    """
    model_configs = []
    with open(filename, "r") as f:
        reader = csv.reader(f, delimiter=",")
        for row in reader:
            # A valid row carries all twelve fields read below.
            if len(row) < 12:
                print("invalid model: " + ",".join(row))
                continue
            model_configs.append(
                {
                    "model_name": row[0],
                    "dialect": row[1],
                    "framework": row[2],
                    "rtol": float(row[3]),
                    "atol": float(row[4]),
                    "out_type": row[5],
                    "flags": row[6],
                    "xfail_cpu": row[7],
                    "xfail_cuda": row[8],
                    "xfail_vkm": row[9],
                    "xfail_reason": row[10],
                    "xfail_other": row[11],
                }
            )
    # This is a pytest workaround: dump the configs to an importable module.
    if gen:
        with open(
            os.path.join(os.path.dirname(__file__), "dict_configs.py"), "w+"
        ) as out:
            out.write("ALL = [\n")
            for c in model_configs:
                out.write(str(c) + ",\n")
            out.write("]")
    return model_configs
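
# Illustrative example (not part of the original file): the first row of
# tank/pretrain_models.csv parses to
#   {"model_name": "bert-base-cased", "dialect": "linalg",
#    "framework": "torch", "rtol": 0.01, "atol": 0.001,
#    "out_type": "default", "flags": "None", "xfail_cpu": "False",
#    "xfail_cuda": "True", "xfail_vkm": "False",
#    "xfail_reason": "", "xfail_other": ""}
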
def get_valid_test_params(custom_device=None):
    """
    Generate a list of all combinations of available devices, the
    static/dynamic flag, and model configs.
    """
    device_list = [
        device
        for device in get_supported_device_list()
        if not check_device_drivers(device)
    ]
    if custom_device:
        device_list.append(custom_device)
    dynamic_list = (True, False)
    # TODO: This is ugly, but for some reason creating the dicts at runtime
    # results in strange pytest failures, so they are round-tripped through
    # a generated module instead.
    load_csv_and_convert(
        os.path.join(os.path.dirname(__file__), "all_models.csv"), True
    )
    from tank.dict_configs import ALL

    config_list = ALL
    param_list = [
        (dynamic, device, config)
        for dynamic in dynamic_list
        for device in device_list
        for config in config_list
    ]
    filtered_param_list = [
        params for params in param_list if is_valid_case(params)
    ]
    return filtered_param_list
def is_valid_case(test_params):
    # test_params is (dynamic, device, config): skip dynamic TF runs, and
    # skip fp16 models on anything but cuda.
    if test_params[0] and test_params[2]["framework"] == "tf":
        return False
    elif "fp16" in test_params[2]["model_name"] and test_params[1] != "cuda":
        return False
    else:
        return True
def shark_test_name_func(testcase_func, param_num, param):
    """
    Generate a test name string that shows the dynamic/static mode and the
    device name; this is ingested by the 'parameterized' package to rename
    the pytest case.
    """
    param_names = []
    for x in param.args:
        if x is True:
            param_names.append("dynamic")
        elif x is False:
            param_names.append("static")
        elif "model" in str(x):
            as_list = str(x).split(" ")
            as_list = [
                parameterized.to_safe_name(x).strip("_") for x in as_list
            ]
            param_names.insert(0, as_list[as_list.index("model_name") + 1])
            param_names.insert(1, as_list[as_list.index("framework") + 1])
            # param_names.append(as_list[3])
        else:
            param_names.append(x)
    return "%s_%s" % (
        testcase_func.__name__,
        parameterized.to_safe_name("_".join(str(x) for x in param_names)),
    )
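
# Illustrative example (not part of the original file): for the parameters
# (True, "cpu", {"model_name": "bert-base-cased", "framework": "torch", ...})
# this yields the test name "test_module_bert_base_cased_torch_dynamic_cpu".
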
class SharkModuleTester:
    def __init__(self, config):
        """config should be a dict containing minimally:
        dialect: (str) name of the input dialect
        framework: (str) one of tf, tflite, pytorch
        model_name: (str) name of the model in the tank ("resnet50")
        rtol/atol: (float) tolerances for golden values
        """
        self.config = config

    def create_module_sharkdynamo(self, dynamic, device):
        model_name = self.config["model_name"]
        model_config = {
            "batch_size": 128,
            "num_hidden_layers": 1,
            "num_attention_heads": 1,
            "hidden_size": 16,
            "vocab_size": 8192,
        }
        net = get_training_model(model_name, model_config)
        in_dim = 128
        out_dim = 8
        input_ids = torch.randint(
            0, 5000, (out_dim, in_dim), dtype=torch.int64
        )
        input_mask = torch.ones([out_dim, in_dim], dtype=torch.int64)
        masked_lm_labels = torch.randint(
            0, 3000, (out_dim, in_dim), dtype=torch.int64
        )
        next_sentence_labels = torch.randint(
            0, 2, (out_dim,), dtype=torch.int64
        )
        segment_ids = torch.randint(0, 2, (out_dim, in_dim), dtype=torch.int64)
        torch.set_grad_enabled(True)
        net.train()
        optimizer = torch.optim.AdamW(net.parameters(), lr=1e-5)

        def train_func(
            input_ids,
            input_mask,
            segment_ids,
            masked_lm_labels,
            next_sentence_labels,
        ):
            loss = net(
                input_ids=input_ids,
                attention_mask=input_mask,
                token_type_ids=segment_ids,
                labels=masked_lm_labels,
                next_sentence_label=next_sentence_labels,
            ).loss
            loss.backward()
            # Apply the accumulated gradients, then reset them for the next
            # iteration.
            optimizer.step()
            optimizer.zero_grad()
            return loss

        torch.manual_seed(0)
        print("compiling with dynamo...")
        dynamo_callable = dynamo.optimize(shark_torchdynamo_backend)(
            train_func
        )
        print("running dynamo-compiled module...")
        res = dynamo_callable(
            input_ids,
            input_mask,
            segment_ids,
            masked_lm_labels,
            next_sentence_labels,
        )
        print("res", res)
        # TODO: add baseline for validation
        # baseline_res =
class SharkModuleTest(unittest.TestCase):
    @pytest.fixture(autouse=True)
    def configure(self, pytestconfig):
        self.pytestconfig = pytestconfig
        # parameterized.expand below is evaluated at class-definition time,
        # so a per-session list built here (honoring --custom_device) does
        # not feed the decorator; it only exercises the option.
        param_list = get_valid_test_params(
            custom_device=pytestconfig.getoption("custom_device")
        )

    param_list = get_valid_test_params()

    @parameterized.expand(param_list, name_func=shark_test_name_func)
    def test_module(self, dynamic, device, config):
        self.module_tester = SharkModuleTester(config)
        self.module_tester.testconfig = self.pytestconfig.args
        safe_name = (
            f"{config['model_name']}_dynamo_pretrain_{dynamic}_{device}"
        )
        self.module_tester.tmp_prefix = safe_name.replace("/", "_")
        tempdir = tempfile.TemporaryDirectory(
            prefix=self.module_tester.tmp_prefix, dir="."
        )
        self.module_tester.temp_dir = tempdir.name
        with ireec.tools.TempFileSaver(tempdir.name):
            self.module_tester.create_module_sharkdynamo(dynamic, device)
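
With the naming scheme from shark_test_name_func, individual cases can be selected by substring (a hypothetical invocation; adjust the pattern to your device list):

    pytest tank/test_models_dynamo.py -k "bert_base_cased_torch_static_cpu"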