Compare commits

...

1 Commits

Author SHA1 Message Date
dan
498c1dbfe0 generate sharktank for apps dir
also adds support for the sub-models
2023-02-08 21:32:50 +00:00
6 changed files with 91 additions and 22 deletions

View File

@@ -5,6 +5,7 @@ import torch
import traceback
import re
import sys
import os
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_opt_flags,
@@ -73,6 +74,9 @@ class SharkifyStableDiffusionModel:
batch_size: int = 1,
use_base_vae: bool = False,
use_tuned: bool = False,
debug: bool = False,
sharktank_dir: str = "",
generate_vmfb: bool = True,
):
self.check_params(max_len, width, height)
self.max_len = max_len
@@ -89,7 +93,8 @@ class SharkifyStableDiffusionModel:
self.precision = precision
self.base_vae = use_base_vae
self.model_name = (
str(batch_size)
"_"
+ str(batch_size)
+ "_"
+ str(max_len)
+ "_"
@@ -113,6 +118,9 @@ class SharkifyStableDiffusionModel:
if model_name[0] == "_":
model_name = model_name[1:]
self.model_name = self.model_name + "_" + model_name
self.debug = debug
self.sharktank_dir = sharktank_dir
self.generate_vmfb = generate_vmfb
def check_params(self, max_len, width, height):
if not (max_len >= 32 and max_len <= 77):
@@ -146,12 +154,19 @@ class SharkifyStableDiffusionModel:
inputs = tuple(self.inputs["vae"])
is_f16 = True if self.precision == "fp16" else False
vae_name = "base_vae" if self.base_vae else "vae"
vae_model_name = vae_name + self.model_name
save_dir = os.path.join(self.sharktank_dir, vae_model_name)
if self.debug:
os.makedirs(save_dir, exist_ok=True)
shark_vae = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
model_name=vae_name + self.model_name,
model_name=vae_model_name,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
)
return shark_vae
@@ -185,13 +200,23 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
unet_model_name = "unet" + self.model_name
save_dir = os.path.join(self.sharktank_dir, unet_model_name)
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_unet = compile_through_fx(
unet,
inputs,
model_name="unet" + self.model_name,
model_name=unet_model_name,
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
)
return shark_unet
@@ -209,10 +234,20 @@ class SharkifyStableDiffusionModel:
return self.text_encoder(input)[0]
clip_model = CLIPText()
clip_model_name = "clip" + self.model_name
save_dir = os.path.join(self.sharktank_dir, clip_model_name)
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_clip = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
model_name="clip" + self.model_name,
model_name=clip_model_name,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
)
return shark_clip

View File

@@ -1,4 +1,5 @@
import argparse
import os
from pathlib import Path
@@ -6,6 +7,13 @@ def path_expand(s):
return Path(s).expanduser().resolve()
def is_valid_file(arg):
if not os.path.exists(arg):
return None
else:
return arg
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

View File

@@ -3,6 +3,7 @@ import gc
from pathlib import Path
import numpy as np
from random import randint
import tempfile
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
@@ -83,6 +84,9 @@ def compile_through_fx(
is_f16=False,
f16_input_mask=None,
use_tuned=False,
save_dir=tempfile.gettempdir(),
debug=False,
generate_vmfb=True,
extra_args=[],
):
from shark.parser import shark_args
@@ -90,10 +94,18 @@ def compile_through_fx(
if "cuda" in args.device:
shark_args.enable_tf32 = True
mlir_module, func_name = import_with_fx(
model, inputs, is_f16, f16_input_mask
(
mlir_module,
func_name,
) = import_with_fx(
model=model,
inputs=inputs,
is_f16=is_f16,
f16_input_mask=f16_input_mask,
debug=debug,
model_name=model_name,
save_dir=save_dir,
)
if use_tuned:
tuned_model_path = f"{args.annotation_output}/{model_name}_torch.mlir"
if not os.path.exists(tuned_model_path):
@@ -109,13 +121,13 @@ def compile_through_fx(
with open(tuned_model_path, "rb") as f:
mlir_module = f.read()
f.close()
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
)
return _compile_module(shark_module, model_name, extra_args)
if generate_vmfb:
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
)
return _compile_module(shark_module, model_name, extra_args)
def set_iree_runtime_flags():

View File

@@ -106,12 +106,6 @@ def save_torch_model(torch_model_list):
dir=torch_model_dir,
model_name=torch_model_name,
)
mlir_hash = create_hash(
os.path.join(
torch_model_dir, torch_model_name + "_torch" + ".mlir"
)
)
np.save(os.path.join(torch_model_dir, "hash"), np.array(mlir_hash))
# Generate torch dynamic models.
if is_dynamic:
mlir_importer.import_debug(
@@ -278,5 +272,8 @@ if __name__ == "__main__":
)
save_torch_model(torch_model_csv)
save_torch_model(
os.path.join(os.path.dirname(__file__), "tank", "torch_sd_list.csv")
)
save_tf_model(tf_model_csv)
save_tflite_model(tflite_model_csv)

View File

@@ -4,6 +4,17 @@
import sys
import tempfile
import os
import hashlib
def create_hash(file_name):
with open(file_name, "rb") as f:
file_hash = hashlib.blake2b()
while chunk := f.read(2**20):
file_hash.update(chunk)
return file_hash.hexdigest()
# List of the supported frontends.
supported_frontends = {
@@ -150,11 +161,11 @@ class SharkImporter:
np.savez(os.path.join(dir, inputs_name), *inputs)
np.savez(os.path.join(dir, outputs_name), *outputs)
np.save(os.path.join(dir, func_file_name), np.array(func_name))
if self.frontend == "torch":
with open(os.path.join(dir, model_name_mlir), "wb") as mlir_file:
mlir_file.write(mlir_data)
mlir_hash = create_hash(os.path.join(dir, model_name_mlir))
np.save(os.path.join(dir, "hash"), np.array(mlir_hash))
return
def import_debug(

6
tank/torch_sd_list.csv Normal file
View File

@@ -0,0 +1,6 @@
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
stabilityai/stable-diffusion-2-1-base,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
stabilityai/stable-diffusion-2-1,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
anythingv3/v1_4,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
analogdiffusion/v1_4,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
openjourney/v1_4",True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
Can't render this file because it contains an unexpected character in line 6 and column 17.