Compare commits

...

1 Commits

Author SHA1 Message Date
dan
29923af224 add generate_sharktank for stable_diffusion model defaults 2023-01-25 05:21:55 +00:00
7 changed files with 150 additions and 18 deletions

View File

@@ -19,6 +19,9 @@ import subprocess as sp
import hashlib
import numpy as np
from pathlib import Path
from shark.examples.shark_inference.stable_diffusion import (
model_wrappers as mw,
)
visible_default = tf.config.list_physical_devices("GPU")
try:
@@ -40,6 +43,21 @@ def create_hash(file_name):
return file_hash.hexdigest()
def get_folder_name(
model_basename, device, precision_value, length, version=None
):
return (
model_basename
+ "_"
+ device
+ "_"
+ precision_value
+ "_maxlen_"
+ length
+ "_torch"
)
def save_torch_model(torch_model_list):
from tank.model_utils import (
get_hf_model,
@@ -62,6 +80,69 @@ def save_torch_model(torch_model_list):
model = None
input = None
if model_type == "fx_imported":
from shark.examples.shark_inference.stable_diffusion.stable_args import (
args,
)
args.use_tuned = False
args.import_mlir = True
args.use_tuned = False
args.local_tank_cache = WORKDIR
base_sd_versions = ["v1_4", "v2_1"]
model_variants = [
"stablediffusion",
"anythingv3",
"analogdiffusion",
"openjourney",
]
scheduler_types = [
"PNDM",
"DDIM",
"LMSDiscrete",
"EulerDiscrete",
"DPMSolverMultistep",
"SharkEulerDiscrete",
]
precision_values = ["fp16"]
seq_lengths = [64, 77]
for device in ["vulkan", "cuda"]:
args.device = device
for model_variant in model_variants:
args.variant = model_variant
for base_sd_ver in base_sd_versions:
# model variants not required for non base sd models
if (
base_sd_ver == "v1_4"
and not model_variant == "stablediffusion"
):
continue
else:
args.version = base_sd_ver
for precision_value in precision_values:
args.precision = precision_value
for length in seq_lengths:
model = mw.SharkifyStableDiffusionModel(
model_id="stabilityai/stable-diffusion-2-1-base",
custom_weights="",
precision=precision_value,
max_len=length,
width=512,
height=512,
use_base_vae=False,
debug=True,
sharktank_dir=WORKDIR,
)
args.max_length = length
model_name = f"{args.variant}/{args.version}/{torch_model_name}/{args.precision}/length_{args.max_length}{args.use_tuned}"
torch_model_dir = os.path.join(
WORKDIR, model_name
)
model()
continue
if model_type == "vision":
model, input, _ = get_vision_model(torch_model_name)
elif model_type == "hf":
@@ -243,13 +324,13 @@ if __name__ == "__main__":
if args.torch_model_csv:
save_torch_model(args.torch_model_csv)
if args.tf_model_csv:
save_tf_model(args.tf_model_csv)
# if args.tf_model_csv:
# save_tf_model(args.tf_model_csv)
if args.tflite_model_csv:
save_tflite_model(args.tflite_model_csv)
# if args.tflite_model_csv:
# save_tflite_model(args.tflite_model_csv)
if args.upload:
if True:
git_hash = sp.getoutput("git log -1 --format='%h'") + "/"
print("uploading files to gs://shark_tank/" + git_hash)
os.system(f"gsutil cp -r {WORKDIR}* gs://shark_tank/" + git_hash)

View File

@@ -1,10 +1,13 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from utils import compile_through_fx, get_opt_flags
from resources import base_models, variants
from collections import defaultdict
import torch
import sys
# These shapes are parameter dependent.
@@ -71,6 +74,8 @@ class SharkifyStableDiffusionModel:
width: int = 512,
height: int = 512,
use_base_vae: bool = False,
debug: bool = False,
sharktank_dir: str = "",
):
self.check_params(max_len, width, height)
self.inputs = get_model_configuration(
@@ -88,6 +93,8 @@ class SharkifyStableDiffusionModel:
+ "_"
+ precision
)
self.debug = debug
self.sharktank_dir = sharktank_dir
# We need a better naming convention for the .vmfbs because despite
# using the custom model variant the .vmfb names remain the same and
# it'll always pick up the compiled .vmfb instead of compiling the
@@ -134,12 +141,18 @@ class SharkifyStableDiffusionModel:
inputs = tuple(self.inputs["vae"])
is_f16 = True if self.precision == "fp16" else False
vae_name = "base_vae" if self.base_vae else "vae"
vae_model_name = vae_name + self.model_name
if self.debug:
os.makedirs(
os.path.join(self.sharktank_dir, vae_model_name), exist_ok=True
)
shark_vae = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
model_name=vae_name + self.model_name,
model_name=vae_model_name,
extra_args=get_opt_flags("vae", precision=self.precision),
debug=self.debug,
)
return shark_vae
@@ -172,13 +185,20 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
unet_model_name = "unet" + self.model_name
if self.debug:
os.makedirs(
os.path.join(self.sharktank_dir, unet_model_name),
exist_ok=True,
)
shark_unet = compile_through_fx(
unet,
inputs,
model_name="unet" + self.model_name,
model_name=unet_model_name,
is_f16=is_f16,
f16_input_mask=input_mask,
extra_args=get_opt_flags("unet", precision=self.precision),
debug=self.debug,
)
return shark_unet
@@ -195,12 +215,19 @@ class SharkifyStableDiffusionModel:
return self.text_encoder(input)[0]
clip_model = CLIPText()
clip_model_name = "clip" + self.model_name
if self.debug:
os.makedirs(
os.path.join(self.sharktank_dir, clip_model_name),
exist_ok=True,
)
shark_clip = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
model_name="clip" + self.model_name,
model_name=clip_model_name,
extra_args=get_opt_flags("clip", precision="fp32"),
debug=self.debug,
)
return shark_clip

View File

@@ -1,6 +1,6 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/stable_diffusion",
"stablediffusion/untuned":"gs://shark_tank/latest",
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",

View File

@@ -1,7 +1,8 @@
import os
import tempfile
import torch
from shark.shark_inference import SharkInference
from stable_args import args
from shark.examples.shark_inference.stable_diffusion.stable_args import args
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
@@ -71,10 +72,19 @@ def compile_through_fx(
is_f16=False,
f16_input_mask=None,
extra_args=[],
save_dir=tempfile.gettempdir(),
debug=False,
):
mlir_module, func_name = import_with_fx(
model, inputs, is_f16, f16_input_mask
save_dir = os.path.join(args.local_tank_cache, model_name)
print("SAVE DIR: " + save_dir)
mlir_module, func_name, = import_with_fx(
model=model,
inputs=inputs,
is_f16=is_f16,
f16_input_mask=f16_input_mask,
debug=debug,
model_name=model_name,
save_dir=save_dir,
)
shark_module = SharkInference(
mlir_module,

View File

@@ -130,6 +130,7 @@ class SharkImporter:
):
import numpy as np
print("dir in save data:" + dir)
inputs_name = "inputs.npz"
outputs_name = "golden_out.npz"
func_file_name = "function_name"
@@ -158,6 +159,7 @@ class SharkImporter:
func_name="forward",
dir=tempfile.gettempdir(),
model_name="model",
golden_values=None,
):
if self.inputs == None:
print(
@@ -177,7 +179,11 @@ class SharkImporter:
if self.frontend in ["torch", "pytorch"]:
import torch
golden_out = self.module(*self.inputs)
golden_out = None
if golden_values is not None:
golden_out = golden_values
else:
golden_out = self.module(*self.inputs)
if torch.is_tensor(golden_out):
golden_out = tuple(
golden_out.detach().cpu().numpy(),
@@ -357,11 +363,16 @@ def import_with_fx(
f16_input_mask=None,
debug=False,
training=False,
save_dir=tempfile.gettempdir(),
model_name="model",
):
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
golden_values = None
if debug:
golden_values = model(*inputs)
# TODO: Control the decompositions.
fx_g = make_fx(
model,
@@ -414,8 +425,10 @@ def import_with_fx(
frontend="torch",
)
if debug and not is_f16:
(mlir_module, func_name), _, _ = mlir_importer.import_debug()
if debug: # and not is_f16:
(mlir_module, func_name), _, _ = mlir_importer.import_debug(
dir=save_dir, model_name=model_name, golden_values=golden_values
)
return mlir_module, func_name
mlir_module, func_name = mlir_importer.import_mlir()

View File

@@ -70,7 +70,7 @@ def get_torch_mlir_module(
module,
input,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=jit_trace,
use_tracing=True,
ignore_traced_shapes=ignore_traced_shapes,
)
bytecode_stream = io.BytesIO()

View File

@@ -18,3 +18,4 @@ nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encod
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
stabilityai/stable-diffusion-2-1-base, True,fx_imported,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
1 model_name use_tracing model_type dynamic param_count tags notes
18 mnasnet1_0 False vision True - cnn, torchvision, mobile, architecture-search Outperforms other mobile CNNs on Accuracy vs. Latency
19 resnet50_fp16 False vision True 23M cnn,image-classification,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
20 bert-base-uncased_fp16 True fp16 False 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
21 stabilityai/stable-diffusion-2-1-base True fx_imported False ??M stable diffusion 2.1 base, LLM, Text to image N/A