Mirror of https://github.com/nod-ai/SHARK-Studio.git, synced 2026-04-20 03:00:34 -04:00

Compare commits: ean-dynamo...20230617.7 (4 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 297a209608 |  |
|  | b204113563 |  |
|  | f60ab1f4fa |  |
|  | b203779462 |  |
@@ -61,6 +61,7 @@ def main():
             dtype,
             args.use_base_vae,
             cpu_scheduling,
+            args.max_embeddings_multiples,
         )
         total_time = time.time() - start_time
         text_output = f"prompt={args.prompts}"
@@ -163,7 +163,7 @@ class SharkifyStableDiffusionModel:

     def get_extended_name_for_all_model(self):
         model_name = {}
-        sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
+        sub_model_list = ["clip", "unet", "unet512", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
         index = 0
         for model in sub_model_list:
             sub_model = model
@@ -415,7 +415,7 @@ class SharkifyStableDiffusionModel:
         )
         return shark_cnet, cnet_mlir

-    def get_unet(self):
+    def get_unet(self, use_large=False):
         class UnetModel(torch.nn.Module):
             def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
                 super().__init__()
@@ -452,17 +452,27 @@ class SharkifyStableDiffusionModel:
         unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
         is_f16 = True if self.precision == "fp16" else False
         inputs = tuple(self.inputs["unet"])
+        if(use_large):
+            pad = (0, 0) * (len(inputs[2].shape) - 2)
+            pad = pad + (0, 512 - inputs[2].shape[1])
+            inputs = (inputs[0],
+                      inputs[1],
+                      torch.nn.functional.pad(inputs[2], pad),
+                      inputs[3])
+            save_dir = os.path.join(self.sharktank_dir, self.model_name["unet512"])
+        else:
+            save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
         input_mask = [True, True, True, False]
-        save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
         if self.debug:
             os.makedirs(
                 save_dir,
                 exist_ok=True,
             )
+        model_name = "unet512" if use_large else "unet"
         shark_unet, unet_mlir = compile_through_fx(
             unet,
             inputs,
-            extended_model_name=self.model_name["unet"],
+            extended_model_name=self.model_name[model_name],
             is_f16=is_f16,
             f16_input_mask=input_mask,
             use_tuned=self.use_tuned,
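For reference, a minimal standalone sketch (not part of the diff) of what that pad tuple computes, assuming a (batch, seq_len, hidden) tensor: `torch.nn.functional.pad` consumes (before, after) pairs starting from the last dimension, so the `(0, 0)` pairs leave the trailing dimensions untouched and the final pair right-pads dimension 1 up to length 512.

```python
import torch
import torch.nn.functional as F

emb = torch.zeros(2, 77, 768)          # illustrative (batch, seq_len, hidden)
pad = (0, 0) * (len(emb.shape) - 2)    # skip the trailing (hidden) dimension
pad = pad + (0, 512 - emb.shape[1])    # right-pad dim 1 from 77 to 512
padded = F.pad(emb, pad)               # zero padding by default
print(padded.shape)                    # torch.Size([2, 512, 768])
```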
@@ -471,13 +481,13 @@ class SharkifyStableDiffusionModel:
             save_dir=save_dir,
             extra_args=get_opt_flags("unet", precision=self.precision),
             base_model_id=self.base_model_id,
-            model_name="unet",
+            model_name=model_name,
             precision=self.precision,
             return_mlir=self.return_mlir,
         )
         return shark_unet, unet_mlir

-    def get_unet_upscaler(self):
+    def get_unet_upscaler(self, use_large=False):
         class UnetModel(torch.nn.Module):
             def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
                 super().__init__()
@@ -502,6 +512,13 @@ class SharkifyStableDiffusionModel:
         unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
         is_f16 = True if self.precision == "fp16" else False
         inputs = tuple(self.inputs["unet"])
+        if(use_large):
+            pad = (0, 0) * (len(inputs[2].shape) - 2)
+            pad = pad + (0, 512 - inputs[2].shape[1])
+            inputs = (inputs[0],
+                      inputs[1],
+                      torch.nn.functional.pad(inputs[2], pad),
+                      inputs[3])
         input_mask = [True, True, True, False]
         shark_unet, unet_mlir = compile_through_fx(
             unet,
@@ -579,16 +596,16 @@ class SharkifyStableDiffusionModel:
         vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
         return vae_dict

-    def compile_unet_variants(self, model):
+    def compile_unet_variants(self, model, use_large=False):
         if model == "unet":
             if self.is_upscaler:
-                return self.get_unet_upscaler()
+                return self.get_unet_upscaler(use_large=use_large)
             # TODO: Plug the experimental "int8" support at right place.
             elif self.use_quantize == "int8":
                 from apps.stable_diffusion.src.models.opt_params import get_unet
                 return get_unet()
             else:
-                return self.get_unet()
+                return self.get_unet(use_large=use_large)
         else:
             return self.get_controlled_unet()
@@ -616,7 +633,7 @@ class SharkifyStableDiffusionModel:
         except Exception as e:
             sys.exit(e)

-    def unet(self):
+    def unet(self, use_large=False):
         try:
             model = "stencil_unet" if self.use_stencil is not None else "unet"
             compiled_unet = None
@@ -624,14 +641,14 @@ class SharkifyStableDiffusionModel:

             if self.base_model_id != "":
                 self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
-                compiled_unet, unet_mlir = self.compile_unet_variants(model)
+                compiled_unet, unet_mlir = self.compile_unet_variants(model, use_large=use_large)
             else:
                 for model_id in unet_inputs:
                     self.base_model_id = model_id
                     self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])

                     try:
-                        compiled_unet, unet_mlir = self.compile_unet_variants(model)
+                        compiled_unet, unet_mlir = self.compile_unet_variants(model, use_large=use_large)
                     except Exception as e:
                         print(e)
                         print("Retrying with a different base model configuration")
@@ -81,6 +81,7 @@ class Text2ImagePipeline(StableDiffusionPipeline):
         dtype,
         use_base_vae,
         cpu_scheduling,
+        max_embeddings_multiples,
     ):
         # prompts and negative prompts must be a list.
         if isinstance(prompts, str):
@@ -112,7 +113,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):

         # Get text embeddings with weight emphasis from prompts
         text_embeddings = self.encode_prompts_weight(
-            prompts, neg_prompts, max_length
+            prompts,
+            neg_prompts,
+            max_length,
+            max_embeddings_multiples=max_embeddings_multiples,
         )

         # guidance scale as a float32 tensor.
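Roughly, max_embeddings_multiples caps how many encoder-length chunks a long prompt may be split into before the chunk embeddings are concatenated; the real encode_prompts_weight lives in the pipeline utilities and also applies per-token weights. A simplified sketch of that idea (the helper and encode_fn are illustrative, not the project's API):

```python
import torch

def encode_long_prompt(encode_fn, token_ids, max_length=77, max_multiples=5):
    # Split an over-long token sequence into encoder-sized chunks, encode
    # each one, then concatenate along the sequence dimension.
    chunks = [
        token_ids[i : i + max_length]
        for i in range(0, len(token_ids), max_length)
    ][:max_multiples]
    return torch.cat([encode_fn(chunk) for chunk in chunks], dim=1)
```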
@@ -57,6 +57,7 @@ class StableDiffusionPipeline:
         self.vae = None
         self.text_encoder = None
         self.unet = None
+        self.unet_512 = None
         self.model_max_length = 77
         self.scheduler = scheduler
         # TODO: Implement using logging python utility.
@@ -114,6 +115,24 @@ class StableDiffusionPipeline:
         del self.unet
         self.unet = None

+    def load_unet_512(self):
+        if self.unet_512 is not None:
+            return
+
+        if self.import_mlir or self.use_lora:
+            self.unet_512 = self.sd_model.unet(use_large=True)
+        else:
+            try:
+                self.unet_512 = get_unet(use_large=True)
+            except Exception as e:
+                print(e)
+                print("download pipeline failed, falling back to import_mlir")
+                self.unet_512 = self.sd_model.unet(use_large=True)
+
+    def unload_unet_512(self):
+        del self.unet_512
+        self.unet_512 = None
+
     def load_vae(self):
         if self.vae is not None:
             return
@@ -203,7 +222,10 @@ class StableDiffusionPipeline:
         latent_history = [latents]
         text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
         text_embeddings_numpy = text_embeddings.detach().numpy()
-        self.load_unet()
+        if text_embeddings.shape[1] <= 64:
+            self.load_unet()
+        else:
+            self.load_unet_512()
         for i, t in tqdm(enumerate(total_timesteps)):
             step_start_time = time.time()
             timestep = torch.tensor([t]).to(dtype).detach().numpy()
@@ -222,16 +244,28 @@ class StableDiffusionPipeline:

             # Profiling Unet.
             profile_device = start_profiling(file_path="unet.rdc")
-            noise_pred = self.unet(
-                "forward",
-                (
-                    latent_model_input,
-                    timestep,
-                    text_embeddings_numpy,
-                    guidance_scale,
-                ),
-                send_to_host=False,
-            )
+            if text_embeddings.shape[1] <= 64:
+                noise_pred = self.unet(
+                    "forward",
+                    (
+                        latent_model_input,
+                        timestep,
+                        text_embeddings_numpy,
+                        guidance_scale,
+                    ),
+                    send_to_host=False,
+                )
+            else:
+                noise_pred = self.unet_512(
+                    "forward",
+                    (
+                        latent_model_input,
+                        timestep,
+                        text_embeddings_numpy,
+                        guidance_scale,
+                    ),
+                    send_to_host=False,
+                )
             end_profiling(profile_device)

             if cpu_scheduling:
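Condensed, the new control flow keeps two compiled UNets and routes each denoising step by the (padded) embedding length; a sketch of the dispatch (attribute names as in this change, the helper itself is hypothetical):

```python
def pick_unet(pipeline, text_embeddings):
    # Prompts that fit in 64 tokens keep the standard compiled UNet; longer
    # prompts were zero-padded to 512 by the CLIP stage and use unet_512.
    if text_embeddings.shape[1] <= 64:
        return pipeline.unet
    return pipeline.unet_512
```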
@@ -254,6 +288,7 @@ class StableDiffusionPipeline:

         if self.ondemand:
             self.unload_unet()
+            self.unload_unet_512()
         avg_step_time = step_time_sum / len(total_timesteps)
         self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -412,6 +447,11 @@ class StableDiffusionPipeline:
         # uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
         text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

+        if text_embeddings.shape[1] > 64:
+            pad = (0, 0) * (len(text_embeddings.shape) - 2)
+            pad = pad + (0, 512 - text_embeddings.shape[1])
+            text_embeddings = torch.nn.functional.pad(text_embeddings, pad)
+
         # SHARK: Report clip inference time
         clip_inf_time = (time.time() - clip_inf_start) * 1000
         if self.ondemand:
@@ -108,6 +108,13 @@ p.add_argument(
     help="max length of the tokenizer output, options are 64 and 77.",
 )

+p.add_argument(
+    "--max_embeddings_multiples",
+    type=int,
+    default=5,
+    help="The max multiple length of prompt embeddings compared to the max output length of text encoder.",
+)
+
 p.add_argument(
     "--strength",
     type=float,
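A quick check of how the new flag parses (standalone sketch with a fresh parser; the project's parser object happens to share the name `p`):

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument("--max_embeddings_multiples", type=int, default=5)
args = p.parse_args(["--max_embeddings_multiples", "3"])
assert args.max_embeddings_multiples == 3   # defaults to 5 when omitted
```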
@@ -145,7 +145,10 @@ def compile_through_fx(
     if use_tuned:
         if "vae" in extended_model_name.split("_")[0]:
             args.annotation_model = "vae"
-        if "unet" in model_name.split("_")[0]:
+        if (
+            "unet" in model_name.split("_")[0]
+            or "unet_512" in model_name.split("_")[0]
+        ):
             args.annotation_model = "unet"
         mlir_module = sd_model_annotation(
             mlir_module, extended_model_name, base_model_id
@@ -1,6 +1,7 @@
 from multiprocessing import Process, freeze_support
 import os
 import sys
+import shutil
 import transformers  # ensures inclusion in pysintaller exe generation
 from apps.stable_diffusion.src import args, clear_all
 import apps.stable_diffusion.web.utils.global_obj as global_obj
@@ -57,15 +58,19 @@ if __name__ == "__main__":
         uvicorn.run(app, host="127.0.0.1", port=args.server_port)
         sys.exit(0)

-    import gradio as gr
+    # Setup to use shark_tmp for gradio's temporary image files and clear any
+    # existing temporary images there if they exist. Then we can import gradio.
+    # It has to be in this order or gradio ignores what we've set up.
     from apps.stable_diffusion.web.utils.gradio_configs import (
-        clear_gradio_tmp_imgs_folder,
+        config_gradio_tmp_imgs_folder,
     )

+    config_gradio_tmp_imgs_folder()
+    import gradio as gr
+
     # Create custom models folders if they don't exist
     from apps.stable_diffusion.web.ui.utils import create_custom_models_folders

-    # Clear all gradio tmp images from the last session
-    clear_gradio_tmp_imgs_folder()
-    # Create custom models folders if they don't exist
     create_custom_models_folders()

     def resource_path(relative_path):
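The ordering constraint the new comments describe boils down to this (illustrative sketch; the exact path is hypothetical): gradio picks up GRADIO_TEMP_DIR around import/first use, so the variable has to be set before `import gradio` runs.

```python
import os

# Must happen BEFORE importing gradio, or the setting is ignored.
os.environ.setdefault(
    "GRADIO_TEMP_DIR", os.path.join(os.getcwd(), "shark_tmp", "gradio")
)

import gradio as gr  # now uses the directory configured above
```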
@@ -9,9 +9,6 @@ from apps.stable_diffusion.src.utils import (
     get_generated_imgs_todays_subdir,
 )
 from apps.stable_diffusion.web.ui.utils import nodlogo_loc
-from apps.stable_diffusion.web.utils.gradio_configs import (
-    gradio_tmp_galleries_folder,
-)
 from apps.stable_diffusion.web.utils.metadata import displayable_metadata

 # -- Functions for file, directory and image info querying
@@ -63,19 +60,6 @@ def output_subdirs() -> list[str]:
     return result_paths


-# clear zero length temporary files that gradio 3.22.0 buggily creates
-# TODO: remove once gradio is upgraded to or past 3.32.0
-def clear_zero_length_temps():
-    zero_length_temps = [
-        os.path.join(root, file)
-        for root, dirs, files in os.walk(gradio_tmp_galleries_folder)
-        for file in files
-        if os.path.getsize(os.path.join(root, file)) == 0
-    ]
-    for file in zero_length_temps:
-        os.remove(file)
-
-
 # --- Define UI layout for Gradio

 with gr.Blocks() as outputgallery_web:
@@ -105,7 +89,6 @@ with gr.Blocks() as outputgallery_web:
                 visible=False,
                 show_label=True,
             ).style(columns=4)
-            gallery.DEFAULT_TEMP_DIR = gradio_tmp_galleries_folder

         with gr.Column(scale=4):
             with gr.Box():
@@ -179,7 +162,6 @@ with gr.Blocks() as outputgallery_web:
     # --- Event handlers

     def on_clear_gallery():
-        clear_zero_length_temps()
         return [
             gr.Gallery.update(
                 value=[],
@@ -247,7 +229,6 @@ with gr.Blocks() as outputgallery_web:

         # only update if the current subdir is the most recent one as new images only go there
         if subdir_paths[0] == subdir:
-            clear_zero_length_temps()
             new_images = outputgallery_filenames(subdir)
             new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)} - {status}"
@@ -193,6 +193,7 @@ def txt2img_inf(
             dtype,
             args.use_base_vae,
             cpu_scheduling,
+            args.max_embeddings_multiples,
         )
         seeds.append(img_seed)
         total_time = time.time() - start_time
@@ -1,60 +1,54 @@
 import os
 import shutil
-import tempfile
-import gradio
 from time import time

-gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
-gradio_tmp_galleries_folder = os.path.join(gradio_tmp_imgs_folder, "galleries")
+shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")


-# Clear all gradio tmp images
-def clear_gradio_tmp_imgs_folder():
-    if not os.path.exists(gradio_tmp_imgs_folder):
-        return
+def config_gradio_tmp_imgs_folder():
+    # create shark_tmp if it does not exist
+    if not os.path.exists(shark_tmp):
+        os.mkdir(shark_tmp)
+
+    # tell gradio to use a directory under shark_tmp for its temporary
+    # image files unless somewhere else has been set
+    if "GRADIO_TEMP_DIR" not in os.environ:
+        os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
+
-    # clear all gradio tmp files created by generation galleries
     print(
-        "Clearing gradio temporary image files from a prior run. This may take some time..."
+        f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
+        + "You may change this by setting the GRADIO_TEMP_DIR environment variable."
     )
-    image_files = [
-        filename
-        for filename in os.listdir(gradio_tmp_imgs_folder)
-        if os.path.isfile(os.path.join(gradio_tmp_imgs_folder, filename))
-        and filename.startswith("tmp")
-        and filename.endswith(".png")
-    ]
-    if len(image_files) > 0:

+    # Clear all gradio tmp images from the last session
+    if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
         cleanup_start = time()
-        for filename in image_files:
-            os.remove(gradio_tmp_imgs_folder + filename)
         print(
-            f"Clearing generation temporary image files took {time() - cleanup_start:4f} seconds"
+            "Clearing gradio UI temporary image files from a prior run. This may take some time..."
         )
-    else:
-        print("no generation temporary files to clear")
-
-    # Clear all gradio tmp files created by output galleries
-    if os.path.exists(gradio_tmp_galleries_folder):
-        cleanup_start = time()
-        shutil.rmtree(gradio_tmp_galleries_folder, ignore_errors=True)
+        shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
         print(
-            f"Clearing output gallery temporary image files took {time() - cleanup_start:4f} seconds"
+            f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
         )

+    # older SHARK versions had to workaround gradio bugs and stored things differently
-    else:
-        print("no output gallery temporary files to clear")
-
-
-# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder
-def save_pil_to_file(pil_image, dir=None):
-    if not os.path.exists(gradio_tmp_imgs_folder):
-        os.mkdir(gradio_tmp_imgs_folder)
-    file_obj = tempfile.NamedTemporaryFile(
-        delete=False, suffix=".png", dir=gradio_tmp_imgs_folder
-    )
-    pil_image.save(file_obj)
-    return file_obj
-
-
-# Register save_pil_to_file override
-gradio.processing_utils.save_pil_to_file = save_pil_to_file
+    image_files = [
+        filename
+        for filename in os.listdir(shark_tmp)
+        if os.path.isfile(os.path.join(shark_tmp, filename))
+        and filename.startswith("tmp")
+        and filename.endswith(".png")
+    ]
+    if len(image_files) > 0:
+        print(
+            "Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
+        )
+        cleanup_start = time()
+        for filename in image_files:
+            os.remove(shark_tmp + filename)
+        print(
+            f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
+        )
+    else:
+        print("No temporary images files to clear.")
@@ -90,8 +90,3 @@ def pytest_addoption(parser):
         type=int,
         help="Batch size for the tested model.",
     )
-    parser.addoption(
-        "--custom_device",
-        default=None,
-        help="Custom device string to run tests with.",
-    )
@@ -136,7 +136,7 @@ def get_vendor(triple):
         return "Intel"
     if arch in ["turing", "ampere", "pascal"]:
         return "NVIDIA"
-    if arch == "ardeno":
+    if arch == "adreno":
         return "Qualcomm"
     if arch == "cpu":
         if product == "swiftshader":
@@ -114,6 +114,11 @@ def get_vulkan_target_triple(device_name):
     # Intel Targets
     elif any(x in device_name for x in ("A770", "A750")):
         triple = f"arc-770-{system_os}"

+    # Adreno Targets
+    elif all(x in device_name for x in ("Adreno", "740")):
+        triple = f"adreno-a740-{system_os}"
+
     else:
         triple = None
     return triple
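To illustrate the matching used above (the device-name string here is made up): `any()` fires when at least one marker substring is present, while the new Adreno branch requires both markers via `all()`.

```python
device_name = "Qualcomm(TM) Adreno(TM) 740"   # hypothetical Vulkan device name

print(all(x in device_name for x in ("Adreno", "740")))  # True  -> adreno-a740 triple
print(any(x in device_name for x in ("A770", "A750")))   # False -> not an Intel Arc
```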
@@ -525,6 +525,8 @@ def import_with_fx(
                 torch.ops.aten.split.Tensor,
                 torch.ops.aten.split_with_sizes,
                 torch.ops.aten.native_layer_norm,
+                torch.ops.aten.masked_fill.Tensor,
+                torch.ops.aten.masked_fill.Scalar,
             ]
         ),
     )(*inputs)
@@ -1,133 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend
import torch
from typing import List


def get_sorted_params(named_params):
    return [i[1] for i in sorted(named_params.items())]


def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
    removed_indexes = []
    for node in fx_g.graph.nodes:
        if node.op == "output":
            assert (
                len(node.args) == 1
            ), "Output node must have a single argument"
            node_arg = node.args[0]
            if isinstance(node_arg, (list, tuple)):
                node_arg = list(node_arg)
                node_args_len = len(node_arg)
                for i in range(node_args_len):
                    curr_index = node_args_len - (i + 1)
                    if node_arg[curr_index] is None:
                        removed_indexes.append(curr_index)
                        node_arg.pop(curr_index)
                node.args = (tuple(node_arg),)
                break

    if len(removed_indexes) > 0:
        fx_g.graph.lint()
        fx_g.graph.eliminate_dead_code()
        fx_g.recompile()
    removed_indexes.sort()
    return removed_indexes


def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
    for node in fx_g.graph.nodes:
        if node.op == "output":
            assert (
                len(node.args) == 1
            ), "Output node must have a single argument"
            node_arg = node.args[0]
            if isinstance(node_arg, tuple):
                return len(node_arg) == 0
    return False


def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
    """
    Replace tuple with tuple element in functions that return one-element tuples.
    Returns true if an unwrapping took place, and false otherwise.
    """
    unwrapped_tuple = False
    for node in fx_g.graph.nodes:
        if node.op == "output":
            assert (
                len(node.args) == 1
            ), "Output node must have a single argument"
            node_arg = node.args[0]
            if isinstance(node_arg, tuple):
                if len(node_arg) == 1:
                    node.args = (node_arg[0],)
                    unwrapped_tuple = True
                break

    if unwrapped_tuple:
        fx_g.graph.lint()
        fx_g.recompile()
    return unwrapped_tuple


torch._dynamo.config.verbose = True
var_id = 0


@make_simple_dynamo_backend
def shark_torchdynamo_backend(
    fx_graph: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
):
    if _returns_nothing(fx_graph):
        return fx_graph
    removed_none_indexes = _remove_nones(fx_graph)
    was_unwrapped = _unwrap_single_tuple_return(fx_graph)
    mlir_module = torch_mlir.compile(
        fx_graph, example_inputs, output_type="linalg-on-tensors"
    )
    from contextlib import redirect_stdout

    global var_id
    with open(f"linalg_gen_{var_id}.mlir", "w") as f:
        with redirect_stdout(f):
            print(mlir_module)
    print("saving!")
    var_id += 1

    from shark.shark_inference import SharkInference
    import io

    bytecode_stream = io.BytesIO()
    mlir_module.operation.write_bytecode(bytecode_stream)
    bytecode = bytecode_stream.getvalue()

    shark_module = SharkInference(
        mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
    )
    shark_module.compile()

    def compiled_callable(*inputs):
        inputs = [x.numpy() for x in inputs]
        result = shark_module("forward", inputs)
        if was_unwrapped:
            result = [
                result,
            ]
        if not isinstance(result, list):
            result = torch.tensor(x)
        else:
            result = tuple(torch.tensor(x) for x in result)
            result = list(result)
        for removed_index in removed_none_indexes:
            result.insert(removed_index, None)
        result = tuple(result)
        return result

    return compiled_callable
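For context on what the deleted shark_torchdynamo_backend plugged into: a TorchDynamo backend is just a callable that receives the captured FX graph plus example inputs and returns something callable. A minimal sketch (backend and function names here are illustrative, not from the repo):

```python
from typing import List

import torch

def inspect_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    gm.graph.print_tabular()  # show the captured FX graph
    return gm.forward         # run the captured graph eagerly, unmodified

@torch.compile(backend=inspect_backend)
def f(x):
    return torch.relu(x) + 1

f(torch.randn(4))  # triggers capture and backend invocation on first call
```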
@@ -34,11 +34,6 @@ hf_seq2seq_models = [
 ]


-def get_training_model(modelname, import_args):
-    if "bert" in modelname:
-        return get_bert_pretrain_model(modelname, import_args)
-
-
 def get_torch_model(modelname, import_args):
     if modelname in vision_models:
         return get_vision_model(modelname, import_args)
@@ -52,31 +47,14 @@ def get_torch_model(modelname, import_args):
     return get_hf_model(modelname, import_args)


-##################### Hugging Face BERT PreTraining Models #######################################
-
-
-def get_bert_pretrain_model(model_name, import_args):
-    from transformers import BertForPreTraining
-    import copy
-
-    torch.manual_seed(0)
-    base_model = BertForPreTraining.from_pretrained(model_name)
-    base_model = base_model.train()
-    my_config = copy.deepcopy(base_model.config)
-    my_config.num_hidden_layers = import_args["num_hidden_layers"]
-    my_config.num_attention_heads = import_args["num_attention_heads"]
-    my_config.hidden_size = import_args["hidden_size"]
-    my_config.vocab_size = import_args["vocab_size"]
-
-    return BertForPreTraining(my_config)
-
-
 ##################### Hugging Face Image Classification Models ###################################
+from transformers import AutoModelForImageClassification
+from transformers import AutoFeatureExtractor
+from PIL import Image
+import requests


 def preprocess_input_image(model_name):
-    from PIL import Image
-
     # from datasets import load_dataset
     # dataset = load_dataset("huggingface/cats-image")
     # image1 = dataset["test"]["image"][0]
@@ -110,10 +88,6 @@ class HuggingFaceImageClassification(torch.nn.Module):


 def get_hf_img_cls_model(name, import_args):
-    from transformers import AutoModelForImageClassification
-    from transformers import AutoFeatureExtractor
-    import requests
-
     model = HuggingFaceImageClassification(name)
     # you can use preprocess_input_image to get the test_input or just random value.
     test_input = preprocess_input_image(name)
@@ -1,6 +0,0 @@
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
@@ -1,239 +0,0 @@
from shark.iree_utils._common import (
    check_device_drivers,
    device_driver_info,
    get_supported_device_list,
)
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
from shark.sharkdynamo.shark_backend import shark_torchdynamo_backend
from tank.model_utils import get_training_model
from parameterized import parameterized
import torch
import torch.nn as nn
import torch._dynamo as dynamo
import transformers
import iree.compiler as ireec
import pytest
import unittest
import numpy as np
import tempfile
import os
import sys
import copy
import csv


def load_csv_and_convert(filename, gen=False):
    """
    takes in a csv filename and generates a dict for consumption by get_valid_test_params
    """
    model_configs = []
    with open(filename, "r+") as f:
        reader = csv.reader(f, delimiter=",")
        for row in reader:
            if len(row) < 5:
                print("invalid model: " + row)
                continue
            model_configs.append(
                {
                    "model_name": row[0],
                    "dialect": row[1],
                    "framework": row[2],
                    "rtol": float(row[3]),
                    "atol": float(row[4]),
                    "out_type": row[5],
                    "flags": row[6],
                    "xfail_cpu": row[7],
                    "xfail_cuda": row[8],
                    "xfail_vkm": row[9],
                    "xfail_reason": row[10],
                    "xfail_other": row[11],
                }
            )
    # This is a pytest workaround
    if gen:
        with open(
            os.path.join(os.path.dirname(__file__), "dict_configs.py"), "w+"
        ) as out:
            out.write("ALL = [\n")
            for c in model_configs:
                out.write(str(c) + ",\n")
            out.write("]")
    return model_configs


def get_valid_test_params(custom_device=None):
    """
    Generate a list of all combinations of available devices and static/dynamic flag.
    """
    device_list = [
        device
        for device in get_supported_device_list()
        if not check_device_drivers(device)
    ]
    if custom_device:
        device_list.append(custom_device)
    dynamic_list = (True, False)
    # TODO: This is soooo ugly, but for some reason creating the dict at runtime
    # results in strange pytest failures.
    load_csv_and_convert(
        os.path.join(os.path.dirname(__file__), "all_models.csv"), True
    )
    from tank.dict_configs import ALL

    config_list = ALL

    param_list = [
        (dynamic, device, config)
        for dynamic in dynamic_list
        for device in device_list
        for config in config_list
    ]

    filtered_param_list = [
        params for params in param_list if is_valid_case(params)
    ]

    return filtered_param_list


def is_valid_case(test_params):
    if test_params[0] == True and test_params[2]["framework"] == "tf":
        return False
    elif "fp16" in test_params[2]["model_name"] and test_params[1] != "cuda":
        return False
    else:
        return True


def shark_test_name_func(testcase_func, param_num, param):
    """
    Generate function name string which shows dynamic/static and device name.
    this will be ingested by 'parameterized' package to rename the pytest.
    """
    param_names = []
    for x in param.args:
        if x == True:
            param_names.append("dynamic")
        elif x == False:
            param_names.append("static")
        elif "model" in str(x):
            as_list = str(x).split(" ")
            as_list = [
                parameterized.to_safe_name(x).strip("_") for x in as_list
            ]
            param_names.insert(0, as_list[as_list.index("model_name") + 1])
            param_names.insert(1, as_list[as_list.index("framework") + 1])
            # param_names.append(as_list[3])

        else:
            param_names.append(x)
    return "%s_%s" % (
        testcase_func.__name__,
        parameterized.to_safe_name("_".join(str(x) for x in param_names)),
    )


class SharkModuleTester:
    def __init__(self, config):
        """config should be a dict containing minimally:
        dialect: (str) name of input dialect
        framework: (str) one of tf, tflite, pytorch
        model_name: (str) name of the model in the tank ("resnet50")
        rtol/atol: (float) tolerances for golden values
        """
        self.config = config

    def create_module_sharkdynamo(self, dynamic, device):
        model_name = self.config["model_name"]
        model_config = {
            "batch_size": 128,
            "num_hidden_layers": 1,
            "num_attention_heads": 1,
            "hidden_size": 16,
            "vocab_size": 8192,
        }
        net = get_training_model(model_name, model_config)

        in_dim = 128
        out_dim = 8

        input_ids = torch.randint(
            0, 5000, (out_dim, in_dim), dtype=torch.int64
        )
        input_mask = torch.ones([out_dim, in_dim], dtype=torch.int64)
        masked_lm_labels = torch.randint(
            0, 3000, (out_dim, in_dim), dtype=torch.int64
        )
        next_sentence_labels = torch.randint(
            0, 2, (out_dim,), dtype=torch.int64
        )
        segment_ids = torch.randint(0, 2, (out_dim, in_dim), dtype=torch.int64)

        torch.set_grad_enabled(True)
        net.train()
        optimizer = torch.optim.AdamW(net.parameters(), lr=1e-5)

        def train_func(
            input_ids,
            input_mask,
            segment_ids,
            masked_lm_labels,
            next_sentence_labels,
        ):
            loss = net(
                input_ids=input_ids,
                attention_mask=input_mask,
                token_type_ids=segment_ids,
                labels=masked_lm_labels,
                next_sentence_label=next_sentence_labels,
            ).loss
            loss.backward()
            optimizer.zero_grad()
            optimizer.step()
            return loss

        torch.manual_seed(0)
        print("compiling with dynamo...")
        dynamo_callable = dynamo.optimize(shark_torchdynamo_backend)(
            train_func
        )
        print("running dynamo-compiled module...")
        res = dynamo_callable(
            input_ids,
            input_mask,
            segment_ids,
            masked_lm_labels,
            next_sentence_labels,
        )
        print("res", res)

        # TODO: add baseline for validation
        # baseline_res =


class SharkModuleTest(unittest.TestCase):
    @pytest.fixture(autouse=True)
    def configure(self, pytestconfig):
        self.pytestconfig = pytestconfig
        param_list = get_valid_test_params(
            custom_device=pytestconfig.getoption("custom_device")
        )

    param_list = get_valid_test_params()

    @parameterized.expand(param_list, name_func=shark_test_name_func)
    def test_module(self, dynamic, device, config):
        self.module_tester = SharkModuleTester(config)
        self.module_tester.testconfig = self.pytestconfig.args
        safe_name = (
            f"{config['model_name']}_dynamo_pretrain_{dynamic}_{device}"
        )
        self.module_tester.tmp_prefix = safe_name.replace("/", "_")

        tempdir = tempfile.TemporaryDirectory(
            prefix=self.module_tester.tmp_prefix, dir="."
        )
        self.module_tester.temp_dir = tempdir.name

        with ireec.tools.TempFileSaver(tempdir.name):
            self.module_tester.create_module_sharkdynamo(dynamic, device)
@@ -24,4 +24,5 @@ bert-large-uncased,True,hf,True,linalg,False,330M,"nlp;bert-variant;transformer-
 bert-base-uncased,True,hf,False,stablehlo,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
 gpt2,True,hf_causallm,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
 facebook/opt-125m,True,hf,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
 distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
+microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"