generate sharktank for apps dir (#966)

* merge conflict resolution

* add support to other scripts

---------

Co-authored-by: dan <dan@nod-labs.com>
Author: Daniel Garvey
Date: 2023-03-13 12:54:15 -05:00
Committed by: GitHub
Parent: 2f133e9d5c
Commit: 62b5a9fd49
14 changed files with 178 additions and 69 deletions

File 1 of 14

@@ -166,6 +166,7 @@ def img2img_inf(
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
if use_stencil is not None:
args.use_tuned = False
img2img_obj = StencilPipeline.from_pretrained(
@@ -183,6 +184,7 @@ def img2img_inf(
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
)
else:
img2img_obj = Image2ImagePipeline.from_pretrained(
@@ -199,6 +201,7 @@ def img2img_inf(
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
)
img2img_obj.scheduler = schedulers[scheduler]
@@ -298,6 +301,7 @@ if __name__ == "__main__":
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
debug=args.import_debug if args.import_mlir else False,
)
else:
img2img_obj = Image2ImagePipeline.from_pretrained(
@@ -314,6 +318,7 @@ if __name__ == "__main__":
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
)
start_time = time.time()
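Note: all three entry-point scripts derive the new `debug` kwarg from the same expression. A condensed, standalone restatement (a sketch; `import_mlir` and `import_debug` mirror the argparse flags added later in this commit):

def import_debug_enabled(import_mlir: bool, import_debug: bool) -> bool:
    # Debug artifacts only make sense when the MLIR import path is taken,
    # hence `args.import_debug if args.import_mlir else False` above.
    return import_debug if import_mlir else False

assert import_debug_enabled(True, True) is True
assert import_debug_enabled(False, True) is False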

File 2 of 14

@@ -125,18 +125,20 @@ def inpaint_inf(
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
inpaint_obj = InpaintPipeline.from_pretrained(
- scheduler_obj,
- args.import_mlir,
- args.hf_model_id,
- args.ckpt_loc,
- args.custom_vae,
- args.precision,
- args.max_length,
- args.batch_size,
- args.height,
- args.width,
- args.use_base_vae,
- args.use_tuned,
+ scheduler=scheduler_obj,
+ import_mlir=args.import_mlir,
+ model_id=args.hf_model_id,
+ ckpt_loc=args.ckpt_loc,
+ precision=args.precision,
+ max_length=args.max_length,
+ batch_size=args.batch_size,
+ height=args.height,
+ width=args.width,
+ use_base_vae=args.use_base_vae,
+ use_tuned=args.use_tuned,
+ custom_vae=args.custom_vae,
+ low_cpu_mem_usage=args.low_cpu_mem_usage,
+ debug=args.import_debug if args.import_mlir else False,
)
inpaint_obj.scheduler = schedulers[scheduler]
@@ -213,18 +215,20 @@ if __name__ == "__main__":
mask_image = Image.open(args.mask_path)
inpaint_obj = InpaintPipeline.from_pretrained(
- scheduler_obj,
- args.import_mlir,
- args.hf_model_id,
- args.ckpt_loc,
- args.custom_vae,
- args.precision,
- args.max_length,
- args.batch_size,
- args.height,
- args.width,
- args.use_base_vae,
- args.use_tuned,
+ scheduler=scheduler_obj,
+ import_mlir=args.import_mlir,
+ model_id=args.hf_model_id,
+ ckpt_loc=args.ckpt_loc,
+ precision=args.precision,
+ max_length=args.max_length,
+ batch_size=args.batch_size,
+ height=args.height,
+ width=args.width,
+ use_base_vae=args.use_base_vae,
+ use_tuned=args.use_tuned,
+ custom_vae=args.custom_vae,
+ low_cpu_mem_usage=args.low_cpu_mem_usage,
+ debug=args.import_debug if args.import_mlir else False,
)
for current_batch in range(args.batch_count):
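Note on the positional-to-keyword migration above: `custom_vae` moves from fifth position to after `use_tuned`, so any remaining positional call site would silently rebind arguments. A toy illustration of why keyword calls survive such reorderings (names here are illustrative, not the pipeline API):

def make_pipeline(precision, custom_vae=None):
    return precision, custom_vae

# Positional: meaning depends on parameter order and breaks if it changes.
print(make_pipeline("fp16", "some_custom_vae"))
# Keyword: robust to signature reordering, as in the calls above.
print(make_pipeline(custom_vae="some_custom_vae", precision="fp16"))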

File 3 of 14

@@ -120,19 +120,20 @@ def txt2img_inf(
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
txt2img_obj = Text2ImagePipeline.from_pretrained(
- scheduler_obj,
- args.import_mlir,
- args.hf_model_id,
- args.ckpt_loc,
- args.custom_vae,
- args.precision,
- args.max_length,
- args.batch_size,
- args.height,
- args.width,
- args.use_base_vae,
- args.use_tuned,
+ scheduler=scheduler_obj,
+ import_mlir=args.import_mlir,
+ model_id=args.hf_model_id,
+ ckpt_loc=args.ckpt_loc,
+ precision=args.precision,
+ max_length=args.max_length,
+ batch_size=args.batch_size,
+ height=args.height,
+ width=args.width,
+ use_base_vae=args.use_base_vae,
+ use_tuned=args.use_tuned,
+ custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
+ debug=args.import_debug if args.import_mlir else False,
)
txt2img_obj.scheduler = schedulers[scheduler]
@@ -190,21 +191,21 @@ if __name__ == "__main__":
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
txt2img_obj = Text2ImagePipeline.from_pretrained(
- scheduler_obj,
- args.import_mlir,
- args.hf_model_id,
- args.ckpt_loc,
- args.custom_vae,
- args.precision,
- args.max_length,
- args.batch_size,
- args.height,
- args.width,
- args.use_base_vae,
- args.use_tuned,
+ scheduler=scheduler_obj,
+ import_mlir=args.import_mlir,
+ model_id=args.hf_model_id,
+ ckpt_loc=args.ckpt_loc,
+ precision=args.precision,
+ max_length=args.max_length,
+ batch_size=args.batch_size,
+ height=args.height,
+ width=args.width,
+ use_base_vae=args.use_base_vae,
+ use_tuned=args.use_tuned,
+ custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
+ debug=args.import_debug if args.import_mlir else False,
)
for current_batch in range(args.batch_count):

File 4 of 14

@@ -5,6 +5,7 @@ import torch
import safetensors.torch
import traceback
import sys
import os
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_opt_flags,
@@ -92,8 +93,12 @@ class SharkifyStableDiffusionModel:
use_base_vae: bool = False,
use_tuned: bool = False,
low_cpu_mem_usage: bool = False,
debug: bool = False,
sharktank_dir: str = "",
generate_vmfb: bool = True,
is_inpaint: bool = False,
use_stencil: str = None
):
self.check_params(max_len, width, height)
self.max_len = max_len
@@ -114,7 +119,8 @@ class SharkifyStableDiffusionModel:
self.precision = precision
self.base_vae = use_base_vae
self.model_name = (
- str(batch_size)
+ "_"
+ + str(batch_size)
+ "_"
+ str(max_len)
+ "_"
@@ -124,6 +130,7 @@ class SharkifyStableDiffusionModel:
+ "_"
+ precision
)
print(f'use_tuned? sharkify: {use_tuned}')
self.use_tuned = use_tuned
if use_tuned:
self.model_name = self.model_name + "_tuned"
@@ -132,6 +139,11 @@ class SharkifyStableDiffusionModel:
self.is_inpaint = is_inpaint
self.use_stencil = get_stencil_model_id(use_stencil)
print(self.model_name)
self.debug = debug
self.sharktank_dir = sharktank_dir
self.generate_vmfb = generate_vmfb
def get_extended_name_for_all_model(self, mask_to_fetch):
model_name = {}
sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
@@ -225,12 +237,18 @@ class SharkifyStableDiffusionModel:
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
is_f16 = True if self.precision == "fp16" else False
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
shark_vae = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
model_name=self.model_name["vae"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
)
return shark_vae
@@ -376,6 +394,12 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_unet = compile_through_fx(
unet,
inputs,
@@ -383,6 +407,9 @@ class SharkifyStableDiffusionModel:
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
)
return shark_unet
@@ -401,10 +428,19 @@ class SharkifyStableDiffusionModel:
return self.text_encoder(input)[0]
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_clip = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
model_name=self.model_name["clip"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
)
return shark_clip
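Note: each sub-model compile above follows the same artifact-directory convention. A minimal sketch, assuming `sharktank_dir` and the per-sub-model names behave as in the hunks above (the concrete model name below is a hypothetical stand-in):

import os
import tempfile

sharktank_dir = tempfile.mkdtemp()  # stand-in for the configured tank dir
model_name = {"unet": "unet_1_77_512_512_fp16_tuned"}  # illustrative name
save_dir = os.path.join(sharktank_dir, model_name["unet"])
os.makedirs(save_dir, exist_ok=True)  # mirrors the `if self.debug:` branches
print(save_dir)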

File 5 of 14

@@ -317,6 +317,7 @@ class StableDiffusionPipeline:
use_base_vae: bool,
use_tuned: bool,
low_cpu_mem_usage: bool = False,
debug: bool = False,
use_stencil: str = None,
):
is_inpaint = cls.__name__ in [
@@ -336,6 +337,7 @@ class StableDiffusionPipeline:
use_base_vae=use_base_vae,
use_tuned=use_tuned,
low_cpu_mem_usage=low_cpu_mem_usage,
debug=debug,
is_inpaint=is_inpaint,
use_stencil=use_stencil,
)
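Note: the `cls.__name__` check above lets one classmethod detect which pipeline subclass it was invoked through, so call sites need no extra inpaint argument. In miniature:

class StableDiffusionPipeline:
    @classmethod
    def detect_inpaint(cls) -> bool:
        # Same dispatch idea as `is_inpaint = cls.__name__ in [...]` above.
        return cls.__name__ in ["InpaintPipeline"]

class InpaintPipeline(StableDiffusionPipeline):
    pass

print(StableDiffusionPipeline.detect_inpaint())  # False
print(InpaintPipeline.detect_inpaint())          # True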

File 6 of 14

@@ -90,8 +90,8 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
def _import(self):
scaling_model = ScalingModel()
self.scaling_model = compile_through_fx(
- scaling_model,
- (example_latent, example_sigma),
+ model=scaling_model,
+ inputs=(example_latent, example_sigma),
model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,

File 7 of 14

@@ -18,10 +18,15 @@
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned",
"stablediffusion/v1_4/unet/fp16/length_77/tuned/cuda":"unet_8dec_fp16_cuda_tuned",
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
"stablediffusion/v1_4/unet/fp32/length_64/untuned":"unet_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v1_4/vae/fp32/length_64/untuned":"vae_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",

File 8 of 14

@@ -1,4 +1,5 @@
import argparse
import os
from pathlib import Path
@@ -6,6 +7,13 @@ def path_expand(s):
return Path(s).expanduser().resolve()
def is_valid_file(arg):
if not os.path.exists(arg):
return None
else:
return arg
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
@@ -402,6 +410,12 @@ p.add_argument(
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
)
p.add_argument(
"--import_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="if import_mlir is True, saves mlir via the debug option in shark importer. Does nothing if import_mlir is false (the default)",
)
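Note: `argparse.BooleanOptionalAction` (Python 3.9+) auto-generates a `--no-import_debug` negation. A standalone check of the new flag's semantics:

import argparse

p = argparse.ArgumentParser()
p.add_argument(
    "--import_debug",
    default=False,
    action=argparse.BooleanOptionalAction,
)
assert p.parse_args([]).import_debug is False                 # default
assert p.parse_args(["--import_debug"]).import_debug is True
assert p.parse_args(["--no-import_debug"]).import_debug is False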
##############################################################################
### Web UI flags
##############################################################################

File 9 of 14

@@ -8,6 +8,7 @@ from csv import DictWriter
from pathlib import Path
import numpy as np
from random import randint
import tempfile
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
@@ -90,6 +91,9 @@ def compile_through_fx(
is_f16=False,
f16_input_mask=None,
use_tuned=False,
save_dir=tempfile.gettempdir(),
debug=False,
generate_vmfb=True,
extra_args=[],
):
from shark.parser import shark_args
@@ -97,10 +101,18 @@ def compile_through_fx(
if "cuda" in args.device:
shark_args.enable_tf32 = True
- mlir_module, func_name = import_with_fx(
-     model, inputs, is_f16, f16_input_mask
+ (
+     mlir_module,
+     func_name,
+ ) = import_with_fx(
+     model=model,
+     inputs=inputs,
+     is_f16=is_f16,
+     f16_input_mask=f16_input_mask,
+     debug=debug,
+     model_name=model_name,
+     save_dir=save_dir,
)
if use_tuned:
if "vae" in model_name.split("_")[0]:
args.annotation_model = "vae"
@@ -112,11 +124,19 @@ def compile_through_fx(
mlir_dialect="linalg",
)
- shark_module = SharkInference(
-     mlir_module,
-     device=args.device,
-     mlir_dialect="linalg",
- )
- del mlir_module
- gc.collect()
- return _compile_module(shark_module, model_name, extra_args)
+ if generate_vmfb:
+     shark_module = SharkInference(
+         mlir_module,
+         device=args.device,
+         mlir_dialect="linalg",
+     )
+     del mlir_module
+     gc.collect()
+     return _compile_module(shark_module, model_name, extra_args)
def set_iree_runtime_flags():
vulkan_runtime_flags = [
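Note on the new control flow: with `generate_vmfb=False`, `compile_through_fx` stops after `import_with_fx` (which, given `debug=True`, has already written the MLIR and golden I/O under `save_dir`) and never builds a `SharkInference` module. A toy restatement under those assumptions (names illustrative):

def compile_or_import_only(mlir_module: bytes, generate_vmfb: bool = True):
    if generate_vmfb:
        return f"vmfb compiled from {len(mlir_module)} bytes of MLIR"
    # Artifact-only mode: the importer already saved everything needed.
    return None

print(compile_or_import_only(b"module {}", generate_vmfb=False))  # None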

File 10 of 14

@@ -52,7 +52,7 @@ def save_torch_model(torch_model_list):
tracing_required = False if tracing_required == "False" else True
is_dynamic = False if is_dynamic == "False" else True
print("generating artifacts for: " + torch_model_name)
model = None
input = None
if model_type == "stable_diffusion":
@@ -105,12 +105,6 @@ def save_torch_model(torch_model_list):
dir=torch_model_dir,
model_name=torch_model_name,
)
- mlir_hash = create_hash(
-     os.path.join(
-         torch_model_dir, torch_model_name + "_torch" + ".mlir"
-     )
- )
- np.save(os.path.join(torch_model_dir, "hash"), np.array(mlir_hash))
# Generate torch dynamic models.
if is_dynamic:
mlir_importer.import_debug(
@@ -276,6 +270,9 @@ if __name__ == "__main__":
os.path.dirname(__file__), "tank", "tflite", "tflite_model_list.csv"
)
save_torch_model(
os.path.join(os.path.dirname(__file__), "tank", "torch_sd_list.csv")
)
save_torch_model(torch_model_csv)
save_tf_model(tf_model_csv)
save_tflite_model(tflite_model_csv)

File 11 of 14

@@ -13,3 +13,5 @@ build-backend = "setuptools.build_meta"
[tool.black]
line-length = 79
include = '\.pyi?$'

File 12 of 14

@@ -150,10 +150,14 @@ def download_model(
if not check_dir_exists(
model_dir_name, frontend=frontend, dynamic=dyn_str
):
print(f"Downloading artifacts for model {model_name}...")
print(
f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
)
download_public_file(full_gs_url, model_dir)
elif shark_args.force_update_tank == True:
print(f"Force-updating artifacts for model {model_name}...")
print(
f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
)
download_public_file(full_gs_url, model_dir)
else:
if not _internet_connected():

File 13 of 14

@@ -4,6 +4,17 @@
import sys
import tempfile
import os
import hashlib
def create_hash(file_name):
with open(file_name, "rb") as f:
file_hash = hashlib.blake2b()
while chunk := f.read(2**20):
file_hash.update(chunk)
return file_hash.hexdigest()
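Note: `create_hash` streams the file through BLAKE2b in 1 MiB chunks, so large `.mlir` files never load fully into memory. A usage sketch that persists the digest the same way the importer does below (the file here is a temporary stand-in):

import numpy as np
import tempfile

with tempfile.NamedTemporaryFile(suffix=".mlir", delete=False) as f:
    f.write(b"module {}")  # stand-in MLIR payload
digest = create_hash(f.name)
np.save("hash", np.array(digest))  # mirrors np.save(os.path.join(dir, "hash"), ...)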
# List of the supported frontends.
supported_frontends = {
@@ -140,6 +151,7 @@ class SharkImporter:
outputs_name = "golden_out.npz"
func_file_name = "function_name"
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
print(f"saving {model_name_mlir} to {dir}")
try:
inputs = [x.cpu().detach() for x in inputs]
except AttributeError:
@@ -150,11 +162,11 @@ class SharkImporter:
np.savez(os.path.join(dir, inputs_name), *inputs)
np.savez(os.path.join(dir, outputs_name), *outputs)
np.save(os.path.join(dir, func_file_name), np.array(func_name))
if self.frontend == "torch":
with open(os.path.join(dir, model_name_mlir), "wb") as mlir_file:
mlir_file.write(mlir_data)
mlir_hash = create_hash(os.path.join(dir, model_name_mlir))
np.save(os.path.join(dir, "hash"), np.array(mlir_hash))
return
def import_debug(
@@ -377,7 +389,10 @@ def import_with_fx(
golden_values = None
if debug:
- golden_values = model(*inputs)
+ try:
+     golden_values = model(*inputs)
+ except:
+     golden_values = None
# TODO: Control the decompositions.
fx_g = make_fx(
model,

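Note: the `try`/`except` added around the golden-value capture keeps the import alive when the model cannot run eagerly (one plausible case: an fp16 graph on a CPU-only host). A standalone equivalent that narrows the bare `except:` slightly:

def capture_golden_values(model, inputs):
    try:
        return model(*inputs)
    except Exception:  # the diff uses a bare `except:`; Exception is tighter
        return None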
File 14 of 14: tank/torch_sd_list.csv (new file, 4 additions)

@@ -0,0 +1,4 @@
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
stabilityai/stable-diffusion-2-1-base,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
stabilityai/stable-diffusion-2-1,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
prompthero/openjourney,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A