Compare commits

..

4 Commits

Author SHA1 Message Date
Stefan Kapusniak
297a209608 Remove workarounds for gradio tempfile bugs (#1548) 2023-06-17 19:50:36 -07:00
gpetters94
b204113563 Add UNet512 (#1504)
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-06-17 03:46:25 -04:00
Chi_Liu
f60ab1f4fa Add Deberta to stablehlo in shark tank (#1545) 2023-06-16 13:24:44 -07:00
Surya Jasper
b203779462 Added Adreno target triples to vulkan_utils (#1543) 2023-06-15 16:42:59 -07:00
19 changed files with 161 additions and 509 deletions

View File

@@ -61,6 +61,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -163,7 +163,7 @@ class SharkifyStableDiffusionModel:
def get_extended_name_for_all_model(self):
model_name = {}
sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
sub_model_list = ["clip", "unet", "unet512", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
index = 0
for model in sub_model_list:
sub_model = model
@@ -415,7 +415,7 @@ class SharkifyStableDiffusionModel:
)
return shark_cnet, cnet_mlir
def get_unet(self):
def get_unet(self, use_large=False):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
super().__init__()
@@ -452,17 +452,27 @@ class SharkifyStableDiffusionModel:
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
if(use_large):
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3])
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet512"])
else:
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
input_mask = [True, True, True, False]
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
model_name = "unet512" if use_large else "unet"
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
@@ -471,13 +481,13 @@ class SharkifyStableDiffusionModel:
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
def get_unet_upscaler(self):
def get_unet_upscaler(self, use_large=False):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
super().__init__()
@@ -502,6 +512,13 @@ class SharkifyStableDiffusionModel:
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
if(use_large):
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3])
input_mask = [True, True, True, False]
shark_unet, unet_mlir = compile_through_fx(
unet,
@@ -579,16 +596,16 @@ class SharkifyStableDiffusionModel:
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict
def compile_unet_variants(self, model):
def compile_unet_variants(self, model, use_large=False):
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler()
return self.get_unet_upscaler(use_large=use_large)
# TODO: Plug the experimental "int8" support at right place.
elif self.use_quantize == "int8":
from apps.stable_diffusion.src.models.opt_params import get_unet
return get_unet()
else:
return self.get_unet()
return self.get_unet(use_large=use_large)
else:
return self.get_controlled_unet()
@@ -616,7 +633,7 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def unet(self):
def unet(self, use_large=False):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
compiled_unet = None
@@ -624,14 +641,14 @@ class SharkifyStableDiffusionModel:
if self.base_model_id != "":
self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
compiled_unet, unet_mlir = self.compile_unet_variants(model)
compiled_unet, unet_mlir = self.compile_unet_variants(model, use_large=use_large)
else:
for model_id in unet_inputs:
self.base_model_id = model_id
self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
try:
compiled_unet, unet_mlir = self.compile_unet_variants(model)
compiled_unet, unet_mlir = self.compile_unet_variants(model, use_large=use_large)
except Exception as e:
print(e)
print("Retrying with a different base model configuration")

View File

@@ -81,6 +81,7 @@ class Text2ImagePipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -112,7 +113,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -57,6 +57,7 @@ class StableDiffusionPipeline:
self.vae = None
self.text_encoder = None
self.unet = None
self.unet_512 = None
self.model_max_length = 77
self.scheduler = scheduler
# TODO: Implement using logging python utility.
@@ -114,6 +115,24 @@ class StableDiffusionPipeline:
del self.unet
self.unet = None
def load_unet_512(self):
if self.unet_512 is not None:
return
if self.import_mlir or self.use_lora:
self.unet_512 = self.sd_model.unet(use_large=True)
else:
try:
self.unet_512 = get_unet(use_large=True)
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.unet_512 = self.sd_model.unet(use_large=True)
def unload_unet_512(self):
del self.unet_512
self.unet_512 = None
def load_vae(self):
if self.vae is not None:
return
@@ -203,7 +222,10 @@ class StableDiffusionPipeline:
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
if text_embeddings.shape[1] <= 64:
self.load_unet()
else:
self.load_unet_512()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
@@ -222,16 +244,28 @@ class StableDiffusionPipeline:
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
if text_embeddings.shape[1] <= 64:
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
else:
noise_pred = self.unet_512(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
@@ -254,6 +288,7 @@ class StableDiffusionPipeline:
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -412,6 +447,11 @@ class StableDiffusionPipeline:
# uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
if text_embeddings.shape[1] > 64:
pad = (0, 0) * (len(text_embeddings.shape) - 2)
pad = pad + (0, 512 - text_embeddings.shape[1])
text_embeddings = torch.nn.functional.pad(text_embeddings, pad)
# SHARK: Report clip inference time
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:

View File

@@ -108,6 +108,13 @@ p.add_argument(
help="max length of the tokenizer output, options are 64 and 77.",
)
p.add_argument(
"--max_embeddings_multiples",
type=int,
default=5,
help="The max multiple length of prompt embeddings compared to the max output length of text encoder.",
)
p.add_argument(
"--strength",
type=float,

View File

@@ -145,7 +145,10 @@ def compile_through_fx(
if use_tuned:
if "vae" in extended_model_name.split("_")[0]:
args.annotation_model = "vae"
if "unet" in model_name.split("_")[0]:
if (
"unet" in model_name.split("_")[0]
or "unet_512" in model_name.split("_")[0]
):
args.annotation_model = "unet"
mlir_module = sd_model_annotation(
mlir_module, extended_model_name, base_model_id

View File

@@ -1,6 +1,7 @@
from multiprocessing import Process, freeze_support
import os
import sys
import shutil
import transformers # ensures inclusion in pysintaller exe generation
from apps.stable_diffusion.src import args, clear_all
import apps.stable_diffusion.web.utils.global_obj as global_obj
@@ -57,15 +58,19 @@ if __name__ == "__main__":
uvicorn.run(app, host="127.0.0.1", port=args.server_port)
sys.exit(0)
import gradio as gr
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
from apps.stable_diffusion.web.utils.gradio_configs import (
clear_gradio_tmp_imgs_folder,
config_gradio_tmp_imgs_folder,
)
config_gradio_tmp_imgs_folder()
import gradio as gr
# Create custom models folders if they don't exist
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
# Clear all gradio tmp images from the last session
clear_gradio_tmp_imgs_folder()
# Create custom models folders if they don't exist
create_custom_models_folders()
def resource_path(relative_path):

View File

@@ -9,9 +9,6 @@ from apps.stable_diffusion.src.utils import (
get_generated_imgs_todays_subdir,
)
from apps.stable_diffusion.web.ui.utils import nodlogo_loc
from apps.stable_diffusion.web.utils.gradio_configs import (
gradio_tmp_galleries_folder,
)
from apps.stable_diffusion.web.utils.metadata import displayable_metadata
# -- Functions for file, directory and image info querying
@@ -63,19 +60,6 @@ def output_subdirs() -> list[str]:
return result_paths
# clear zero length temporary files that gradio 3.22.0 buggily creates
# TODO: remove once gradio is upgraded to or past 3.32.0
def clear_zero_length_temps():
zero_length_temps = [
os.path.join(root, file)
for root, dirs, files in os.walk(gradio_tmp_galleries_folder)
for file in files
if os.path.getsize(os.path.join(root, file)) == 0
]
for file in zero_length_temps:
os.remove(file)
# --- Define UI layout for Gradio
with gr.Blocks() as outputgallery_web:
@@ -105,7 +89,6 @@ with gr.Blocks() as outputgallery_web:
visible=False,
show_label=True,
).style(columns=4)
gallery.DEFAULT_TEMP_DIR = gradio_tmp_galleries_folder
with gr.Column(scale=4):
with gr.Box():
@@ -179,7 +162,6 @@ with gr.Blocks() as outputgallery_web:
# --- Event handlers
def on_clear_gallery():
clear_zero_length_temps()
return [
gr.Gallery.update(
value=[],
@@ -247,7 +229,6 @@ with gr.Blocks() as outputgallery_web:
# only update if the current subdir is the most recent one as new images only go there
if subdir_paths[0] == subdir:
clear_zero_length_temps()
new_images = outputgallery_filenames(subdir)
new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)} - {status}"

View File

@@ -193,6 +193,7 @@ def txt2img_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
seeds.append(img_seed)
total_time = time.time() - start_time

View File

@@ -1,60 +1,54 @@
import os
import shutil
import tempfile
import gradio
from time import time
gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
gradio_tmp_galleries_folder = os.path.join(gradio_tmp_imgs_folder, "galleries")
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
# Clear all gradio tmp images
def clear_gradio_tmp_imgs_folder():
if not os.path.exists(gradio_tmp_imgs_folder):
return
def config_gradio_tmp_imgs_folder():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
# tell gradio to use a directory under shark_tmp for its temporary
# image files unless somewhere else has been set
if "GRADIO_TEMP_DIR" not in os.environ:
os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
# clear all gradio tmp files created by generation galleries
print(
"Clearing gradio temporary image files from a prior run. This may take some time..."
f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
+ "You may change this by setting the GRADIO_TEMP_DIR environment variable."
)
image_files = [
filename
for filename in os.listdir(gradio_tmp_imgs_folder)
if os.path.isfile(os.path.join(gradio_tmp_imgs_folder, filename))
and filename.startswith("tmp")
and filename.endswith(".png")
]
if len(image_files) > 0:
# Clear all gradio tmp images from the last session
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
cleanup_start = time()
for filename in image_files:
os.remove(gradio_tmp_imgs_folder + filename)
print(
f"Clearing generation temporary image files took {time() - cleanup_start:4f} seconds"
"Clearing gradio UI temporary image files from a prior run. This may take some time..."
)
else:
print("no generation temporary files to clear")
# Clear all gradio tmp files created by output galleries
if os.path.exists(gradio_tmp_galleries_folder):
cleanup_start = time()
shutil.rmtree(gradio_tmp_galleries_folder, ignore_errors=True)
shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
print(
f"Clearing output gallery temporary image files took {time() - cleanup_start:4f} seconds"
f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
)
# older SHARK versions had to workaround gradio bugs and stored things differently
else:
print("no output gallery temporary files to clear")
# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder
def save_pil_to_file(pil_image, dir=None):
if not os.path.exists(gradio_tmp_imgs_folder):
os.mkdir(gradio_tmp_imgs_folder)
file_obj = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=gradio_tmp_imgs_folder
)
pil_image.save(file_obj)
return file_obj
# Register save_pil_to_file override
gradio.processing_utils.save_pil_to_file = save_pil_to_file
image_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.startswith("tmp")
and filename.endswith(".png")
]
if len(image_files) > 0:
print(
"Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
)
cleanup_start = time()
for filename in image_files:
os.remove(shark_tmp + filename)
print(
f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
)
else:
print("No temporary images files to clear.")

View File

@@ -90,8 +90,3 @@ def pytest_addoption(parser):
type=int,
help="Batch size for the tested model.",
)
parser.addoption(
"--custom_device",
default=None,
help="Custom device string to run tests with.",
)

View File

@@ -136,7 +136,7 @@ def get_vendor(triple):
return "Intel"
if arch in ["turing", "ampere", "pascal"]:
return "NVIDIA"
if arch == "ardeno":
if arch == "adreno":
return "Qualcomm"
if arch == "cpu":
if product == "swiftshader":

View File

@@ -114,6 +114,11 @@ def get_vulkan_target_triple(device_name):
# Intel Targets
elif any(x in device_name for x in ("A770", "A750")):
triple = f"arc-770-{system_os}"
# Adreno Targets
elif all(x in device_name for x in ("Adreno", "740")):
triple = f"adreno-a740-{system_os}"
else:
triple = None
return triple

View File

@@ -525,6 +525,8 @@ def import_with_fx(
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
torch.ops.aten.native_layer_norm,
torch.ops.aten.masked_fill.Tensor,
torch.ops.aten.masked_fill.Scalar,
]
),
)(*inputs)

View File

@@ -1,133 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend
import torch
from typing import List
def get_sorted_params(named_params):
return [i[1] for i in sorted(named_params.items())]
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
return len(node_arg) == 0
return False
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
torch._dynamo.config.verbose = True
var_id = 0
@make_simple_dynamo_backend
def shark_torchdynamo_backend(
fx_graph: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
):
if _returns_nothing(fx_graph):
return fx_graph
removed_none_indexes = _remove_nones(fx_graph)
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
mlir_module = torch_mlir.compile(
fx_graph, example_inputs, output_type="linalg-on-tensors"
)
from contextlib import redirect_stdout
global var_id
with open(f"linalg_gen_{var_id}.mlir", "w") as f:
with redirect_stdout(f):
print(mlir_module)
print("saving!")
var_id += 1
from shark.shark_inference import SharkInference
import io
bytecode_stream = io.BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
shark_module = SharkInference(
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
)
shark_module.compile()
def compiled_callable(*inputs):
inputs = [x.numpy() for x in inputs]
result = shark_module("forward", inputs)
if was_unwrapped:
result = [
result,
]
if not isinstance(result, list):
result = torch.tensor(x)
else:
result = tuple(torch.tensor(x) for x in result)
result = list(result)
for removed_index in removed_none_indexes:
result.insert(removed_index, None)
result = tuple(result)
return result
return compiled_callable

View File

@@ -34,11 +34,6 @@ hf_seq2seq_models = [
]
def get_training_model(modelname, import_args):
if "bert" in modelname:
return get_bert_pretrain_model(modelname, import_args)
def get_torch_model(modelname, import_args):
if modelname in vision_models:
return get_vision_model(modelname, import_args)
@@ -52,31 +47,14 @@ def get_torch_model(modelname, import_args):
return get_hf_model(modelname, import_args)
##################### Hugging Face BERT PreTraining Models #######################################
def get_bert_pretrain_model(model_name, import_args):
from transformers import BertForPreTraining
import copy
torch.manual_seed(0)
base_model = BertForPreTraining.from_pretrained(model_name)
base_model = base_model.train()
my_config = copy.deepcopy(base_model.config)
my_config.num_hidden_layers = import_args["num_hidden_layers"]
my_config.num_attention_heads = import_args["num_attention_heads"]
my_config.hidden_size = import_args["hidden_size"]
my_config.vocab_size = import_args["vocab_size"]
return BertForPreTraining(my_config)
##################### Hugging Face Image Classification Models ###################################
from transformers import AutoModelForImageClassification
from transformers import AutoFeatureExtractor
from PIL import Image
import requests
def preprocess_input_image(model_name):
from PIL import Image
# from datasets import load_dataset
# dataset = load_dataset("huggingface/cats-image")
# image1 = dataset["test"]["image"][0]
@@ -110,10 +88,6 @@ class HuggingFaceImageClassification(torch.nn.Module):
def get_hf_img_cls_model(name, import_args):
from transformers import AutoModelForImageClassification
from transformers import AutoFeatureExtractor
import requests
model = HuggingFaceImageClassification(name)
# you can use preprocess_input_image to get the test_input or just random value.
test_input = preprocess_input_image(name)

View File

@@ -1,6 +0,0 @@
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
1 bert-base-cased linalg torch 1e-2 1e-3 default None False True False
2 bert-base-uncased linalg torch 1e-2 1e-3 default None False True False
3 bert-base-uncased_fp16 linalg torch 1e-1 1e-1 default None True True True
4 bert-large-uncased linalg torch 1e-2 1e-3 default None False True False
5 microsoft/MiniLM-L12-H384-uncased linalg torch 1e-2 1e-3 default None False True False
6 google/mobilebert-uncased linalg torch 1e-2 1e-3 default None False True False https://github.com/nod-ai/SHARK/issues/344 macos

View File

@@ -1,239 +0,0 @@
from shark.iree_utils._common import (
check_device_drivers,
device_driver_info,
get_supported_device_list,
)
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
from shark.sharkdynamo.shark_backend import shark_torchdynamo_backend
from tank.model_utils import get_training_model
from parameterized import parameterized
import torch
import torch.nn as nn
import torch._dynamo as dynamo
import transformers
import iree.compiler as ireec
import pytest
import unittest
import numpy as np
import tempfile
import os
import sys
import copy
import csv
def load_csv_and_convert(filename, gen=False):
"""
takes in a csv filename and generates a dict for consumption by get_valid_test_params
"""
model_configs = []
with open(filename, "r+") as f:
reader = csv.reader(f, delimiter=",")
for row in reader:
if len(row) < 5:
print("invalid model: " + row)
continue
model_configs.append(
{
"model_name": row[0],
"dialect": row[1],
"framework": row[2],
"rtol": float(row[3]),
"atol": float(row[4]),
"out_type": row[5],
"flags": row[6],
"xfail_cpu": row[7],
"xfail_cuda": row[8],
"xfail_vkm": row[9],
"xfail_reason": row[10],
"xfail_other": row[11],
}
)
# This is a pytest workaround
if gen:
with open(
os.path.join(os.path.dirname(__file__), "dict_configs.py"), "w+"
) as out:
out.write("ALL = [\n")
for c in model_configs:
out.write(str(c) + ",\n")
out.write("]")
return model_configs
def get_valid_test_params(custom_device=None):
"""
Generate a list of all combinations of available devices and static/dynamic flag.
"""
device_list = [
device
for device in get_supported_device_list()
if not check_device_drivers(device)
]
if custom_device:
device_list.append(custom_device)
dynamic_list = (True, False)
# TODO: This is soooo ugly, but for some reason creating the dict at runtime
# results in strange pytest failures.
load_csv_and_convert(
os.path.join(os.path.dirname(__file__), "all_models.csv"), True
)
from tank.dict_configs import ALL
config_list = ALL
param_list = [
(dynamic, device, config)
for dynamic in dynamic_list
for device in device_list
for config in config_list
]
filtered_param_list = [
params for params in param_list if is_valid_case(params)
]
return filtered_param_list
def is_valid_case(test_params):
if test_params[0] == True and test_params[2]["framework"] == "tf":
return False
elif "fp16" in test_params[2]["model_name"] and test_params[1] != "cuda":
return False
else:
return True
def shark_test_name_func(testcase_func, param_num, param):
"""
Generate function name string which shows dynamic/static and device name.
this will be ingested by 'parameterized' package to rename the pytest.
"""
param_names = []
for x in param.args:
if x == True:
param_names.append("dynamic")
elif x == False:
param_names.append("static")
elif "model" in str(x):
as_list = str(x).split(" ")
as_list = [
parameterized.to_safe_name(x).strip("_") for x in as_list
]
param_names.insert(0, as_list[as_list.index("model_name") + 1])
param_names.insert(1, as_list[as_list.index("framework") + 1])
# param_names.append(as_list[3])
else:
param_names.append(x)
return "%s_%s" % (
testcase_func.__name__,
parameterized.to_safe_name("_".join(str(x) for x in param_names)),
)
class SharkModuleTester:
def __init__(self, config):
"""config should be a dict containing minimally:
dialect: (str) name of input dialect
framework: (str) one of tf, tflite, pytorch
model_name: (str) name of the model in the tank ("resnet50")
rtol/atol: (float) tolerances for golden values
"""
self.config = config
def create_module_sharkdynamo(self, dynamic, device):
model_name = self.config["model_name"]
model_config = {
"batch_size": 128,
"num_hidden_layers": 1,
"num_attention_heads": 1,
"hidden_size": 16,
"vocab_size": 8192,
}
net = get_training_model(model_name, model_config)
in_dim = 128
out_dim = 8
input_ids = torch.randint(
0, 5000, (out_dim, in_dim), dtype=torch.int64
)
input_mask = torch.ones([out_dim, in_dim], dtype=torch.int64)
masked_lm_labels = torch.randint(
0, 3000, (out_dim, in_dim), dtype=torch.int64
)
next_sentence_labels = torch.randint(
0, 2, (out_dim,), dtype=torch.int64
)
segment_ids = torch.randint(0, 2, (out_dim, in_dim), dtype=torch.int64)
torch.set_grad_enabled(True)
net.train()
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-5)
def train_func(
input_ids,
input_mask,
segment_ids,
masked_lm_labels,
next_sentence_labels,
):
loss = net(
input_ids=input_ids,
attention_mask=input_mask,
token_type_ids=segment_ids,
labels=masked_lm_labels,
next_sentence_label=next_sentence_labels,
).loss
loss.backward()
optimizer.zero_grad()
optimizer.step()
return loss
torch.manual_seed(0)
print("compiling with dynamo...")
dynamo_callable = dynamo.optimize(shark_torchdynamo_backend)(
train_func
)
print("running dynamo-compiled module...")
res = dynamo_callable(
input_ids,
input_mask,
segment_ids,
masked_lm_labels,
next_sentence_labels,
)
print("res", res)
# TODO: add baseline for validation
# baseline_res =
class SharkModuleTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def configure(self, pytestconfig):
self.pytestconfig = pytestconfig
param_list = get_valid_test_params(
custom_device=pytestconfig.getoption("custom_device")
)
param_list = get_valid_test_params()
@parameterized.expand(param_list, name_func=shark_test_name_func)
def test_module(self, dynamic, device, config):
self.module_tester = SharkModuleTester(config)
self.module_tester.testconfig = self.pytestconfig.args
safe_name = (
f"{config['model_name']}_dynamo_pretrain_{dynamic}_{device}"
)
self.module_tester.tmp_prefix = safe_name.replace("/", "_")
tempdir = tempfile.TemporaryDirectory(
prefix=self.module_tester.tmp_prefix, dir="."
)
self.module_tester.temp_dir = tempdir.name
with ireec.tools.TempFileSaver(tempdir.name):
self.module_tester.create_module_sharkdynamo(dynamic, device)

View File

@@ -24,4 +24,5 @@ bert-large-uncased,True,hf,True,linalg,False,330M,"nlp;bert-variant;transformer-
bert-base-uncased,True,hf,False,stablehlo,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
gpt2,True,hf_causallm,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
facebook/opt-125m,True,hf,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
1 model_name use_tracing model_type dynamic mlir_type decompose param_count tags notes
24 bert-base-uncased True hf False stablehlo False 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
25 gpt2 True hf_causallm False stablehlo True 125M nlp;transformer-encoder -
26 facebook/opt-125m True hf False stablehlo True 125M nlp;transformer-encoder -
27 distilgpt2 True hf False stablehlo True 88M nlp;transformer-encoder -
28 microsoft/deberta-v3-base True hf False stablehlo True 88M nlp;transformer-encoder -