mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-12 15:28:10 -05:00
Compare commits
1 Commits
diffusers-
...
20230208.5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
498c1dbfe0 |
@@ -5,6 +5,7 @@ import torch
|
||||
import traceback
|
||||
import re
|
||||
import sys
|
||||
import os
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
get_opt_flags,
|
||||
@@ -73,6 +74,9 @@ class SharkifyStableDiffusionModel:
|
||||
batch_size: int = 1,
|
||||
use_base_vae: bool = False,
|
||||
use_tuned: bool = False,
|
||||
debug: bool = False,
|
||||
sharktank_dir: str = "",
|
||||
generate_vmfb: bool = True,
|
||||
):
|
||||
self.check_params(max_len, width, height)
|
||||
self.max_len = max_len
|
||||
@@ -89,7 +93,8 @@ class SharkifyStableDiffusionModel:
|
||||
self.precision = precision
|
||||
self.base_vae = use_base_vae
|
||||
self.model_name = (
|
||||
str(batch_size)
|
||||
"_"
|
||||
+ str(batch_size)
|
||||
+ "_"
|
||||
+ str(max_len)
|
||||
+ "_"
|
||||
@@ -113,6 +118,9 @@ class SharkifyStableDiffusionModel:
|
||||
if model_name[0] == "_":
|
||||
model_name = model_name[1:]
|
||||
self.model_name = self.model_name + "_" + model_name
|
||||
self.debug = debug
|
||||
self.sharktank_dir = sharktank_dir
|
||||
self.generate_vmfb = generate_vmfb
|
||||
|
||||
def check_params(self, max_len, width, height):
|
||||
if not (max_len >= 32 and max_len <= 77):
|
||||
@@ -146,12 +154,19 @@ class SharkifyStableDiffusionModel:
|
||||
inputs = tuple(self.inputs["vae"])
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
vae_name = "base_vae" if self.base_vae else "vae"
|
||||
vae_model_name = vae_name + self.model_name
|
||||
save_dir = os.path.join(self.sharktank_dir, vae_model_name)
|
||||
if self.debug:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
is_f16=is_f16,
|
||||
use_tuned=self.use_tuned,
|
||||
model_name=vae_name + self.model_name,
|
||||
model_name=vae_model_name,
|
||||
debug=self.debug,
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
||||
)
|
||||
return shark_vae
|
||||
@@ -185,13 +200,23 @@ class SharkifyStableDiffusionModel:
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
input_mask = [True, True, True, False]
|
||||
unet_model_name = "unet" + self.model_name
|
||||
save_dir = os.path.join(self.sharktank_dir, unet_model_name)
|
||||
if self.debug:
|
||||
os.makedirs(
|
||||
save_dir,
|
||||
exist_ok=True,
|
||||
)
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
model_name="unet" + self.model_name,
|
||||
model_name=unet_model_name,
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
debug=self.debug,
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
)
|
||||
return shark_unet
|
||||
@@ -209,10 +234,20 @@ class SharkifyStableDiffusionModel:
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText()
|
||||
clip_model_name = "clip" + self.model_name
|
||||
save_dir = os.path.join(self.sharktank_dir, clip_model_name)
|
||||
if self.debug:
|
||||
os.makedirs(
|
||||
save_dir,
|
||||
exist_ok=True,
|
||||
)
|
||||
shark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
tuple(self.inputs["clip"]),
|
||||
model_name="clip" + self.model_name,
|
||||
model_name=clip_model_name,
|
||||
debug=self.debug,
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("clip", precision="fp32"),
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -6,6 +7,13 @@ def path_expand(s):
|
||||
return Path(s).expanduser().resolve()
|
||||
|
||||
|
||||
def is_valid_file(arg):
|
||||
if not os.path.exists(arg):
|
||||
return None
|
||||
else:
|
||||
return arg
|
||||
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ import gc
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from random import randint
|
||||
import tempfile
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
@@ -83,6 +84,9 @@ def compile_through_fx(
|
||||
is_f16=False,
|
||||
f16_input_mask=None,
|
||||
use_tuned=False,
|
||||
save_dir=tempfile.gettempdir(),
|
||||
debug=False,
|
||||
generate_vmfb=True,
|
||||
extra_args=[],
|
||||
):
|
||||
from shark.parser import shark_args
|
||||
@@ -90,10 +94,18 @@ def compile_through_fx(
|
||||
if "cuda" in args.device:
|
||||
shark_args.enable_tf32 = True
|
||||
|
||||
mlir_module, func_name = import_with_fx(
|
||||
model, inputs, is_f16, f16_input_mask
|
||||
(
|
||||
mlir_module,
|
||||
func_name,
|
||||
) = import_with_fx(
|
||||
model=model,
|
||||
inputs=inputs,
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=f16_input_mask,
|
||||
debug=debug,
|
||||
model_name=model_name,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
|
||||
if use_tuned:
|
||||
tuned_model_path = f"{args.annotation_output}/{model_name}_torch.mlir"
|
||||
if not os.path.exists(tuned_model_path):
|
||||
@@ -109,13 +121,13 @@ def compile_through_fx(
|
||||
with open(tuned_model_path, "rb") as f:
|
||||
mlir_module = f.read()
|
||||
f.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
if generate_vmfb:
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
|
||||
@@ -106,12 +106,6 @@ def save_torch_model(torch_model_list):
|
||||
dir=torch_model_dir,
|
||||
model_name=torch_model_name,
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(
|
||||
torch_model_dir, torch_model_name + "_torch" + ".mlir"
|
||||
)
|
||||
)
|
||||
np.save(os.path.join(torch_model_dir, "hash"), np.array(mlir_hash))
|
||||
# Generate torch dynamic models.
|
||||
if is_dynamic:
|
||||
mlir_importer.import_debug(
|
||||
@@ -278,5 +272,8 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
save_torch_model(torch_model_csv)
|
||||
save_torch_model(
|
||||
os.path.join(os.path.dirname(__file__), "tank", "torch_sd_list.csv")
|
||||
)
|
||||
save_tf_model(tf_model_csv)
|
||||
save_tflite_model(tflite_model_csv)
|
||||
|
||||
@@ -4,6 +4,17 @@
|
||||
import sys
|
||||
import tempfile
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
|
||||
def create_hash(file_name):
|
||||
with open(file_name, "rb") as f:
|
||||
file_hash = hashlib.blake2b()
|
||||
while chunk := f.read(2**20):
|
||||
file_hash.update(chunk)
|
||||
|
||||
return file_hash.hexdigest()
|
||||
|
||||
|
||||
# List of the supported frontends.
|
||||
supported_frontends = {
|
||||
@@ -150,11 +161,11 @@ class SharkImporter:
|
||||
np.savez(os.path.join(dir, inputs_name), *inputs)
|
||||
np.savez(os.path.join(dir, outputs_name), *outputs)
|
||||
np.save(os.path.join(dir, func_file_name), np.array(func_name))
|
||||
|
||||
if self.frontend == "torch":
|
||||
with open(os.path.join(dir, model_name_mlir), "wb") as mlir_file:
|
||||
mlir_file.write(mlir_data)
|
||||
|
||||
mlir_hash = create_hash(os.path.join(dir, model_name_mlir))
|
||||
np.save(os.path.join(dir, "hash"), np.array(mlir_hash))
|
||||
return
|
||||
|
||||
def import_debug(
|
||||
|
||||
6
tank/torch_sd_list.csv
Normal file
6
tank/torch_sd_list.csv
Normal file
@@ -0,0 +1,6 @@
|
||||
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
|
||||
stabilityai/stable-diffusion-2-1-base,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
|
||||
stabilityai/stable-diffusion-2-1,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
|
||||
anythingv3/v1_4,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
|
||||
analogdiffusion/v1_4,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
|
||||
openjourney/v1_4",True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
|
||||
|
Can't render this file because it contains an unexpected character in line 6 and column 17.
|
Reference in New Issue
Block a user