Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git

[APPS-SD] Fix a few bugs and bring it up to speed with SD CLI (#908)
.gitignore (vendored)
@@ -170,6 +170,5 @@ tank/dict_configs.py
 cache_models/
 onnx_models/

 #web logging
 web/logs/
 web/stored_results/stable_diffusion/
 # Generated images
 generated_imgs/
@@ -41,6 +41,12 @@ if args.clear_all:
     for vmfb in vmfbs:
         if os.path.exists(vmfb):
             os.remove(vmfb)
+    # Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
+    # TODO: Remove this once we have better weight update logic.
+    inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
+    for yaml in inference_yaml:
+        if os.path.exists(yaml):
+            os.remove(yaml)
     home = os.path.expanduser("~")
     if os.name == "nt":  # Windows
         appdata = os.getenv("LOCALAPPDATA")
@@ -6,4 +6,6 @@ from apps.stable_diffusion.src.models.opt_params import (
     get_unet,
     get_clip,
     get_tokenizer,
+    get_params,
+    get_variant_version,
 )
@@ -2,14 +2,15 @@ from diffusers import AutoencoderKL, UNet2DConditionModel
 from transformers import CLIPTextModel
 from collections import defaultdict
 import torch
-import sys
+import traceback
+import re
+import os, sys, functools, operator
 from apps.stable_diffusion.src.utils import (
     compile_through_fx,
     get_opt_flags,
     base_models,
     args,
+    get_vmfb_path_name,
 )
@@ -68,6 +69,7 @@ class SharkifyStableDiffusionModel:
         height: int = 512,
         batch_size: int = 1,
         use_base_vae: bool = False,
+        use_tuned: bool = False,
     ):
         self.check_params(max_len, width, height)
         self.max_len = max_len
@@ -88,6 +90,7 @@ class SharkifyStableDiffusionModel:
             + "_"
             + precision
         )
+        self.use_tuned = use_tuned
         # We need a better naming convention for the .vmfbs because despite
         # using the custom model variant the .vmfb names remain the same and
         # it'll always pick up the compiled .vmfb instead of compiling the
@@ -95,6 +98,7 @@ class SharkifyStableDiffusionModel:
         # So, currently, we add `self.model_id` in the `self.model_name` of
         # .vmfb file.
         # TODO: Have a better way of naming the vmfbs using self.model_name.
         import re

         model_name = re.sub(r"\W+", "_", self.model_id)
         if model_name[0] == "_":
@@ -137,6 +141,7 @@ class SharkifyStableDiffusionModel:
             vae,
             inputs,
             is_f16=is_f16,
+            use_tuned=self.use_tuned,
             model_name=vae_name + self.model_name,
             extra_args=get_opt_flags("vae", precision=self.precision),
         )
@@ -177,6 +182,7 @@ class SharkifyStableDiffusionModel:
             model_name="unet" + self.model_name,
             is_f16=is_f16,
             f16_input_mask=input_mask,
+            use_tuned=self.use_tuned,
             extra_args=get_opt_flags("unet", precision=self.precision),
         )
         return shark_unet
@@ -194,7 +200,6 @@ class SharkifyStableDiffusionModel:
             return self.text_encoder(input)[0]

         clip_model = CLIPText()

         shark_clip = compile_through_fx(
             clip_model,
             tuple(self.inputs["clip"]),
@@ -204,6 +209,11 @@ class SharkifyStableDiffusionModel:
         return shark_clip

     def __call__(self):
+        model_name = ["clip", "base_vae" if self.base_vae else "vae", "unet"]
+        vmfb_path = [
+            get_vmfb_path_name(model + self.model_name)[0]
+            for model in model_name
+        ]
         for model_id in base_models:
             self.inputs = get_input_info(
                 base_models[model_id],
@@ -213,12 +223,22 @@ class SharkifyStableDiffusionModel:
                 self.batch_size,
             )
-            compiled_clip = self.get_clip()
-            compiled_unet = self.get_unet()
-            compiled_vae = self.get_vae()
+            try:
+                compiled_unet = self.get_unet()
+                compiled_vae = self.get_vae()
+                compiled_clip = self.get_clip()
+            except Exception as e:
+                if args.enable_stack_trace:
+                    traceback.print_exc()
+                vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
+                all_vmfb_present = functools.reduce(
+                    operator.__and__, vmfb_present
+                )
+                # We need to delete vmfbs only if some of the models were compiled.
+                if not all_vmfb_present:
+                    for i in range(len(vmfb_path)):
+                        if vmfb_present[i]:
+                            os.remove(vmfb_path[i])
+                            print("Deleted: ", vmfb_path[i])
+                print("Retrying with a different base model configuration")
+                continue
         # This is done just because in main.py we are basing the choice of tokenizer and scheduler
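
Aside: the `functools.reduce(operator.__and__, vmfb_present)` check in this hunk is equivalent to the builtin `all(vmfb_present)` for any non-empty boolean list. A self-contained check (the paths below are hypothetical examples, not names from the repo):

import functools
import operator
import os

vmfb_path = ["clip_model.vmfb", "vae_model.vmfb", "unet_model.vmfb"]
vmfb_present = [os.path.isfile(p) for p in vmfb_path]

# Both spellings agree on every boolean list.
assert functools.reduce(operator.__and__, vmfb_present) == all(vmfb_present)
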
@@ -14,6 +14,10 @@ hf_model_variant_map = {
 }


+def get_variant_version(hf_model_id):
+    return hf_model_variant_map[hf_model_id]
+
+
 def get_params(bucket_key, model_key, model, is_tuned, precision):
     iree_flags = []
     if len(args.iree_vulkan_target_triple) > 0:
@@ -60,7 +64,7 @@ def get_params(bucket_key, model_key, model, is_tuned, precision):


 def get_unet():
-    variant, version = hf_model_variant_map[args.hf_model_id]
+    variant, version = get_variant_version(args.hf_model_id)
     # Tuned model is present only for `fp16` precision.
     is_tuned = "tuned" if args.use_tuned else "untuned"
     if "vulkan" not in args.device and args.use_tuned:
@@ -77,7 +81,7 @@ def get_unet():


 def get_vae():
-    variant, version = hf_model_variant_map[args.hf_model_id]
+    variant, version = get_variant_version(args.hf_model_id)
     # Tuned model is present only for `fp16` precision.
     is_tuned = "tuned" if args.use_tuned else "untuned"
     is_base = "/base" if args.use_base_vae else ""
@@ -95,7 +99,7 @@ def get_vae():


 def get_clip():
-    variant, version = hf_model_variant_map[args.hf_model_id]
+    variant, version = get_variant_version(args.hf_model_id)
     bucket_key = f"{variant}/untuned"
     model_key = (
         f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned"
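
For reference, these helpers assume `hf_model_variant_map` maps a Hugging Face model id to a `(variant, version)` pair, judging by the unpacking in `get_unet` above. A sketch with hypothetical entries (the real map sits at the top of opt_params.py and is not shown in this diff):

# Hypothetical entries for illustration only.
hf_model_variant_map = {
    "stabilityai/stable-diffusion-2-1": ("stablediffusion", "v2_1"),
    "prompthero/openjourney": ("openjourney", "v1_4"),
}

def get_variant_version(hf_model_id):
    return hf_model_variant_map[hf_model_id]

variant, version = get_variant_version("prompthero/openjourney")
print(variant, version)  # openjourney v1_4
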
@@ -185,10 +185,12 @@ class StableDiffusionPipeline:
         width: int,
         use_base_vae: bool,
     ):
         init_kwargs = None
         if import_mlir:
-            if ckpt_loc:
-                preprocessCKPT()
+            if ckpt_loc != "":
+                assert ckpt_loc.lower().endswith(
+                    (".ckpt", ".safetensors")
+                ), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
+                ckpt_loc = preprocessCKPT()
             mlir_import = SharkifyStableDiffusionModel(
                 model_id,
                 ckpt_loc,
@@ -9,8 +9,10 @@ from apps.stable_diffusion.src.utils.resources import (
     opt_flags,
     resource_path,
 )
+from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
 from apps.stable_diffusion.src.utils.stable_args import args
 from apps.stable_diffusion.src.utils.utils import (
+    get_vmfb_path_name,
     get_shark_model,
     compile_through_fx,
     set_iree_runtime_flags,
@@ -1,95 +1,101 @@
{
    "unet": {
        "tuned": {
            "fp16": {
                "default_compilation_flags": []
            },
            "fp32": {
                "default_compilation_flags": []
            }
        },
        "untuned": {
            "fp16": {
                "default_compilation_flags": [
                    "--iree-flow-enable-padding-linalg-ops",
                    "--iree-flow-linalg-ops-padding-size=32"
                ],
                "specified_compilation_flags": {
                    "cuda": ["--iree-flow-enable-conv-nchw-to-nhwc-transform"],
                    "default_device": ["--iree-flow-enable-conv-img2col-transform"]
                }
            },
            "fp32": {
                "default_compilation_flags": [
                    "--iree-flow-enable-conv-nchw-to-nhwc-transform",
                    "--iree-flow-enable-padding-linalg-ops",
                    "--iree-flow-linalg-ops-padding-size=16"
                ]
            }
        }
    },
    "vae": {
        "tuned": {
            "fp16": {
                "default_compilation_flags": [
                    "--iree-flow-enable-padding-linalg-ops",
                    "--iree-flow-linalg-ops-padding-size=32",
                    "--iree-flow-enable-conv-img2col-transform"
                ]
            },
            "fp32": {
                "default_compilation_flags": [
                    "--iree-flow-enable-padding-linalg-ops",
                    "--iree-flow-linalg-ops-padding-size=32",
                    "--iree-flow-enable-conv-img2col-transform"
                ]
            }
        },
        "untuned": {
            "fp16": {
                "default_compilation_flags": [
                    "--iree-flow-enable-padding-linalg-ops",
                    "--iree-flow-linalg-ops-padding-size=32",
                    "--iree-flow-enable-conv-img2col-transform"
                ]
            },
            "fp32": {
                "default_compilation_flags": [
                    "--iree-flow-enable-conv-nchw-to-nhwc-transform",
                    "--iree-flow-enable-padding-linalg-ops",
                    "--iree-flow-linalg-ops-padding-size=16"
                ]
            }
        }
    },
    "clip": {
        "tuned": {
            "fp16": {
                "default_compilation_flags": [
                    "--iree-flow-linalg-ops-padding-size=16",
                    "--iree-flow-enable-padding-linalg-ops"
                ]
            },
            "fp32": {
                "default_compilation_flags": [
                    "--iree-flow-linalg-ops-padding-size=16",
                    "--iree-flow-enable-padding-linalg-ops"
                ]
            }
        },
        "untuned": {
            "fp16": {
                "default_compilation_flags": [
                    "--iree-flow-linalg-ops-padding-size=16",
                    "--iree-flow-enable-padding-linalg-ops"
                ]
            },
            "fp32": {
                "default_compilation_flags": [
                    "--iree-flow-linalg-ops-padding-size=16",
                    "--iree-flow-enable-padding-linalg-ops"
                ]
            }
        }
    }
}
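
For context, the `get_opt_flags` hunks further down read this resource as `opt_flags[model][is_tuned][precision]`, merging the default flags with any device-specific ones. A minimal sketch of that lookup shape, assuming the JSON above has been saved locally as opt_flags.json (in the repo it is loaded through apps.stable_diffusion.src.utils.resources); `lookup_flags` is a hypothetical helper, not the repo's function:

import json

with open("opt_flags.json") as f:
    opt_flags = json.load(f)

def lookup_flags(model, is_tuned, precision, device):
    entry = opt_flags[model][is_tuned][precision]
    flags = list(entry.get("default_compilation_flags", []))
    # Devices without their own entry fall back to "default_device".
    specified = entry.get("specified_compilation_flags", {})
    flags += specified.get(device, specified.get("default_device", []))
    return flags

print(lookup_flags("unet", "untuned", "fp16", "cuda"))
# ['--iree-flow-enable-padding-linalg-ops',
#  '--iree-flow-linalg-ops-padding-size=32',
#  '--iree-flow-enable-conv-nchw-to-nhwc-transform']
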
apps/stable_diffusion/src/utils/sd_annotation.py (new file, 206 lines)
@@ -0,0 +1,206 @@
import os
from shark.model_annotation import model_annotation, create_context
from shark.iree_utils._common import iree_target_map, run_cmd
from shark.shark_downloader import (
    download_model,
    download_public_file,
    WORKDIR,
)
from shark.parser import shark_args
from apps.stable_diffusion.src.utils.stable_args import args


def get_device():
    device = (
        args.device
        if "://" not in args.device
        else args.device.split("://")[0]
    )
    return device


# Download the model (Unet or VAE fp16) from shark_tank
def load_model_from_tank():
    from apps.stable_diffusion.src.models import (
        get_params,
        get_variant_version,
    )

    version, variant = get_variant_version(args.hf_model_id)

    shark_args.local_tank_cache = args.local_tank_cache
    bucket_key = f"{variant}/untuned"
    if args.annotation_model == "unet":
        model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/untuned"
    elif args.annotation_model == "vae":
        is_base = "/base" if args.use_base_vae else ""
        model_key = f"{variant}/{version}/vae/{args.precision}/length_77/untuned{is_base}"

    bucket, model_name, iree_flags = get_params(
        bucket_key, model_key, args.annotation_model, "untuned", args.precision
    )
    mlir_model, func_name, inputs, golden_out = download_model(
        model_name,
        tank_url=bucket,
        frontend="torch",
    )
    return mlir_model, model_name


# Download the tuned config files from shark_tank
def load_winograd_configs():
    device = get_device()
    config_bucket = "gs://shark_tank/sd_tuned/configs/"
    config_name = f"{args.annotation_model}_winograd_{device}.json"
    full_gs_url = config_bucket + config_name
    winograd_config_dir = f"{WORKDIR}configs/" + config_name
    print("Loading Winograd config file from ", winograd_config_dir)
    download_public_file(full_gs_url, winograd_config_dir, True)
    return winograd_config_dir


def load_lower_configs():
    from apps.stable_diffusion.src.models import get_variant_version

    version, variant = get_variant_version(args.hf_model_id)

    config_bucket = "gs://shark_tank/sd_tuned/configs/"
    config_version = version
    if variant in ["anythingv3", "analogdiffusion"]:
        args.max_length = 77
        config_version = "v1_4"
    if args.annotation_model == "vae":
        args.max_length = 77
    device = get_device()
    config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{args.max_length}_{device}.json"
    full_gs_url = config_bucket + config_name
    lowering_config_dir = f"{WORKDIR}configs/" + config_name
    print("Loading lowering config file from ", lowering_config_dir)
    download_public_file(full_gs_url, lowering_config_dir, True)
    return lowering_config_dir


# Annotate the model with Winograd attribute on selected conv ops
def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
    if model_name.split("_")[-1] != "tuned":
        out_file_path = (
            f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
        )
    else:
        out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"

    with create_context() as ctx:
        winograd_model = model_annotation(
            ctx,
            input_contents=input_mlir,
            config_path=winograd_config_dir,
            search_op="conv",
            winograd=True,
        )
        with open(out_file_path, "w") as f:
            f.write(str(winograd_model))
            f.close()
    return winograd_model, out_file_path


# For Unet annotate the model with tuned lowering configs
def annotate_with_lower_configs(
    input_mlir, lowering_config_dir, model_name, use_winograd
):
    if use_winograd:
        dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
    else:
        dump_after = "iree-flow-pad-linalg-ops"

    # Dump IR after padding/img2col/winograd passes
    device_spec_args = ""
    device = get_device()
    if device == "cuda":
        from shark.iree_utils.gpu_utils import get_iree_gpu_args

        gpu_flags = get_iree_gpu_args()
        for flag in gpu_flags:
            device_spec_args += flag + " "
    elif device == "vulkan":
        device_spec_args = (
            f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
        )
    print("Applying tuned configs on", model_name)

    run_cmd(
        f"iree-compile {input_mlir} "
        "--iree-input-type=tm_tensor "
        f"--iree-hal-target-backends={iree_target_map(device)} "
        f"{device_spec_args}"
        "--iree-stream-resource-index-bits=64 "
        "--iree-vm-target-index-bits=64 "
        "--iree-flow-enable-padding-linalg-ops "
        "--iree-flow-linalg-ops-padding-size=32 "
        "--iree-flow-enable-conv-img2col-transform "
        f"--mlir-print-ir-after={dump_after} "
        "--compile-to=flow "
        f"2>{args.annotation_output}/dump_after_winograd.mlir "
    )

    # Annotate the model with lowering configs in the config file
    with create_context() as ctx:
        tuned_model = model_annotation(
            ctx,
            input_contents=f"{args.annotation_output}/dump_after_winograd.mlir",
            config_path=lowering_config_dir,
            search_op="all",
        )

    # Remove the intermediate mlir and save the final annotated model
    os.remove(f"{args.annotation_output}/dump_after_winograd.mlir")
    if model_name.split("_")[-1] != "tuned":
        out_file_path = (
            f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
        )
    else:
        out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
    with open(out_file_path, "w") as f:
        f.write(str(tuned_model))
        f.close()
    return tuned_model, out_file_path


def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
    device = get_device()
    if args.annotation_model == "unet" and device == "vulkan":
        use_winograd = True
        winograd_config_dir = load_winograd_configs()
        winograd_model, model_path = annotate_with_winograd(
            mlir_model, winograd_config_dir, model_name
        )
        lowering_config_dir = load_lower_configs()
        tuned_model, output_path = annotate_with_lower_configs(
            model_path, lowering_config_dir, model_name, use_winograd
        )
    elif args.annotation_model == "vae" and device == "vulkan":
        use_winograd = True
        winograd_config_dir = load_winograd_configs()
        tuned_model, output_path = annotate_with_winograd(
            mlir_model, winograd_config_dir, model_name
        )
    else:
        use_winograd = False
        if model_from_tank:
            mlir_model = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir"
        else:
            # Just use this function to convert bytecode to string
            orig_model, model_path = annotate_with_winograd(
                mlir_model, "", model_name
            )
            mlir_model = model_path
        lowering_config_dir = load_lower_configs()
        tuned_model, output_path = annotate_with_lower_configs(
            mlir_model, lowering_config_dir, model_name, use_winograd
        )
    print(f"Saved the annotated mlir in {output_path}.")
    return tuned_model, output_path


if __name__ == "__main__":
    mlir_model, model_name = load_model_from_tank()
    sd_model_annotation(mlir_model, model_name, model_from_tank=True)
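
Since the new module doubles as a script (see the __main__ block), annotation can also be driven programmatically. A hedged sketch of such a driver, assuming the repo root is on PYTHONPATH and that the stable_args parser exposes the fields used above; the assigned values are illustrative only:

from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.sd_annotation import (
    load_model_from_tank,
    sd_model_annotation,
)

# Field values are illustrative; the real defaults come from the CLI parser.
args.annotation_model = "unet"
args.precision = "fp16"
args.device = "vulkan"
args.annotation_output = "/tmp/annotated"  # hypothetical output directory

mlir_model, model_name = load_model_from_tank()
tuned_model, output_path = sd_model_annotation(
    mlir_model, model_name, model_from_tank=True
)
print(output_path)
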
@@ -23,7 +23,7 @@ p.add_argument(
 )

 p.add_argument(
-    "--negative-prompts",
+    "--negative_prompts",
     nargs="+",
     default=[""],
     help="text you don't want to see in the generated image.",
@@ -1,5 +1,6 @@
 import os
 import torch
+import gc
 from pathlib import Path
 from shark.shark_inference import SharkInference
 from shark.shark_importer import import_with_fx
 from shark.iree_utils.vulkan_utils import (
@@ -9,21 +10,27 @@ from shark.iree_utils.vulkan_utils import (
 from shark.iree_utils.gpu_utils import get_cuda_sm_cc
 from apps.stable_diffusion.src.utils.stable_args import args
 from apps.stable_diffusion.src.utils.resources import opt_flags
+from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
 import sys
 from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
     load_pipeline_from_original_stable_diffusion_ckpt,
 )


+def get_vmfb_path_name(model_name):
+    device = (
+        args.device
+        if "://" not in args.device
+        else "-".join(args.device.split("://"))
+    )
+    extended_name = "{}_{}".format(model_name, device)
+    vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
+    return [vmfb_path, extended_name]
+
+
 def _compile_module(shark_module, model_name, extra_args=[]):
     if args.load_vmfb or args.save_vmfb:
-        device = (
-            args.device
-            if "://" not in args.device
-            else "-".join(args.device.split("://"))
-        )
-        extended_name = "{}_{}".format(model_name, device)
-        vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
+        [vmfb_path, extended_name] = get_vmfb_path_name(model_name)
         if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
             print(f"loading existing vmfb from: {vmfb_path}")
             shark_module.load_module(vmfb_path, extra_args=extra_args)
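
The extracted `get_vmfb_path_name` normalizes device URIs such as `vulkan://0` into filename-safe suffixes before building the .vmfb path. A self-contained sketch of that string handling (`vmfb_name` below is a hypothetical stand-in, not the repo's function):

def vmfb_name(model_name, device):
    # Mirrors the "://" normalization in get_vmfb_path_name above.
    device = device if "://" not in device else "-".join(device.split("://"))
    return "{}_{}.vmfb".format(model_name, device)

print(vmfb_name("unet_sd21", "vulkan://0"))  # unet_sd21_vulkan-0.vmfb
print(vmfb_name("clip_sd21", "cpu"))         # clip_sd21_cpu.vmfb
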
@@ -73,17 +80,40 @@ def compile_through_fx(
     model_name,
     is_f16=False,
     f16_input_mask=None,
+    use_tuned=False,
     extra_args=[],
 ):
     from shark.parser import shark_args

     if "cuda" in args.device:
         shark_args.enable_tf32 = True

     mlir_module, func_name = import_with_fx(
         model, inputs, is_f16, f16_input_mask
     )

+    if use_tuned:
+        model_name = model_name + "_tuned"
+        tuned_model_path = f"{args.annotation_output}/{model_name}_torch.mlir"
+        if not os.path.exists(tuned_model_path):
+            if "vae" in model_name.split("_")[0]:
+                args.annotation_model = "vae"
+
+            tuned_model, tuned_model_path = sd_model_annotation(
+                mlir_module, model_name
+            )
+            del mlir_module, tuned_model
+            gc.collect()
+
+        with open(tuned_model_path, "rb") as f:
+            mlir_module = f.read()
+            f.close()
+
     shark_module = SharkInference(
         mlir_module,
         device=args.device,
         mlir_dialect="linalg",
     )

     return _compile_module(shark_module, model_name, extra_args)
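
With the new `use_tuned` path, `compile_through_fx` either compiles the FX import directly or swaps in the annotated MLIR from `args.annotation_output` first. A hedged call-site sketch with a toy module (the module, shapes, and name are illustrative; the real call sites are in the model wrapper hunks above):

import torch
from apps.stable_diffusion.src.utils import compile_through_fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2

# use_tuned=True would detour through sd_model_annotation before compiling.
shark_toy = compile_through_fx(
    Toy(),
    (torch.randn(1, 4),),
    model_name="toy_model",
    use_tuned=False,
)
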
@@ -203,11 +233,15 @@ def set_init_device_flags():
     elif args.hf_model_id == "prompthero/openjourney":
         args.max_length = 64

-    # Use tuned models in the case of a specific setting.
+    # Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
     if (
-        args.precision != "fp16"
+        args.hf_model_id
+        in ["prompthero/openjourney", "dreamlike-art/dreamlike-diffusion-1.0"]
+        or args.precision != "fp16"
         or args.height != 512
         or args.width != 512
         or args.batch_size != 1
         or ("vulkan" not in args.device and "cuda" not in args.device)
     ):
         args.use_tuned = False
@@ -217,7 +251,12 @@ def set_init_device_flags():
     ):
         args.use_tuned = False

-    elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80", "sm_89"]:
+    elif "cuda" in args.device and get_cuda_sm_cc() not in [
+        "sm_80",
+        "sm_84",
+        "sm_86",
+        "sm_89",
+    ]:
         args.use_tuned = False

     elif args.use_base_vae and args.hf_model_id not in [
@@ -296,6 +335,11 @@ def get_opt_flags(model, precision="fp16"):
     if sys.platform == "darwin":
         iree_flags.append("-iree-stream-fuse-binding=false")

+    if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
+        iree_flags += opt_flags[model][is_tuned][precision][
+            "default_compilation_flags"
+        ]
+
     if "specified_compilation_flags" in opt_flags[model][is_tuned][precision]:
         device = (
             args.device
@@ -312,13 +356,10 @@ def get_opt_flags(model, precision="fp16"):
         iree_flags += opt_flags[model][is_tuned][precision][
             "specified_compilation_flags"
         ][device]

     return iree_flags


 def preprocessCKPT():
-    from pathlib import Path
-
     path = Path(args.ckpt_loc)
     diffusers_path = path.parent.absolute()
     diffusers_directory_name = path.stem
@@ -347,5 +388,5 @@ def preprocessCKPT():
     )
     pipe.save_pretrained(path_to_diffusers)
     print("Loading complete")
-    args.ckpt_loc = path_to_diffusers
-    print("Custom model path is : ", args.ckpt_loc)
+    print("Custom model path is : ", path_to_diffusers)
+    return path_to_diffusers