[SD] Reorganize the stable diffusion model. (#806)

The Stable Diffusion codebase has been reorganized to make it more
modular, so that the same scripts can serve both the web UI and the CLI
instead of duplicating the whole codebase.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Gaurav Shukla
2023-02-01 04:12:41 +05:30
committed by GitHub
parent e9c744ee5d
commit c124b76328
35 changed files with 2602 additions and 21 deletions
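
As context for the file list below, the new txt2img script and the web entry point (txt2img_inf) now drive the same shared package rather than two copies of the pipeline. A minimal sketch of that shared usage, condensed from the files added in this commit:

import torch
from apps.stable_diffusion.src import (
    args,
    Text2ImagePipeline,
    get_schedulers,
    set_init_device_flags,
)

# Resolve device/model flags once, pick a scheduler, and build the pipeline
# from either imported MLIR or prebuilt shark_tank artifacts.
set_init_device_flags()
scheduler = get_schedulers(args.hf_model_id)[args.scheduler]
pipeline = Text2ImagePipeline.from_pretrained(
    scheduler, args.import_mlir, args.hf_model_id, args.ckpt_loc,
    args.precision, args.max_length, args.batch_size, args.height,
    args.width, args.use_base_vae,
)
images = pipeline.generate_images(
    args.prompts, args.negative_prompts, args.batch_size, args.height,
    args.width, args.steps, args.guidance_scale, args.seed, args.max_length,
    torch.half if args.precision == "fp16" else torch.float32,
    args.use_base_vae, not args.scheduler.startswith("Shark"),
)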

apps/__init__.py Normal file

View File

@@ -0,0 +1,98 @@
{
"stabilityai/stable-diffusion-2-1": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"CompVis/stable-diffusion-v1-4": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
}
}
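
The shape entries above use symbolic dimensions ("1*batch_size", "2*batch_size", "max_len", "height", "width") that are substituted with concrete values at compile time (see the replace_shape_str helper added further down; note that "height" and "width" here are latent-space sizes, i.e. image size // 8). A minimal sketch of that substitution:

def resolve_shape(shape, max_len, width, height, batch_size):
    # Mirror of replace_shape_str: swap the symbolic names used in
    # base_model.json for concrete values; "N*batch_size" multiplies
    # the batch size by N.
    resolved = []
    for dim in shape:
        if dim == "max_len":
            resolved.append(max_len)
        elif dim == "height":
            resolved.append(height)
        elif dim == "width":
            resolved.append(width)
        elif isinstance(dim, str) and "batch_size" in dim:
            resolved.append(int(dim.split("*")[0]) * batch_size)
        else:
            resolved.append(dim)
    return resolved

# e.g. the unet "embedding" input of stable-diffusion-2-1 with batch_size=1
# and max_len=64 resolves to [2, 64, 1024].
print(resolve_shape(["2*batch_size", "max_len", 1024],
                    max_len=64, width=64, height=64, batch_size=1))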

View File

@@ -0,0 +1,177 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/stable_diffusion",
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
"anythingv3/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
"analogdiffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
"openjourney/tuned":"gs://shark_tank/sd_tuned",
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
},
{
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned",
"stablediffusion/v1_4/unet/fp16/length_77/tuned/cuda":"unet_8dec_fp16_cuda_tuned",
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32",
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet2base_8dec_fp16",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_19dec_v2p1base_fp16_64",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae2base_19dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"vae2base_19dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base/cuda":"vae2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip2base_18dec_fp32",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_19dec_v2p1base_fp32_64",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet2_14dec_fp16",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae2_19dec_fp16",
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip2_18dec_fp32",
"anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
"anythingv3/v2_1base/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
"anythingv3/v2_1base/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
"anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
"anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
"anythingv3/v2_1base/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
"anythingv3/v2_1base/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
"anythingv3/v2_1base/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
"anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
"anythingv3/v2_1base/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
"anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
"analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
"analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
"analogdiffusion/v2_1base/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
"openjourney/v2_1base/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
"openjourney/v2_1base/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
"openjourney/v2_1base/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
"openjourney/v2_1base/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
"openjourney/v2_1base/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
"openjourney/v2_1base/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
"openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
"dreamlike/v2_1base/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
"dreamlike/v2_1base/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
"dreamlike/v2_1base/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
"dreamlike/v2_1base/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
"dreamlike/v2_1base/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
"dreamlike/v2_1base/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
"dreamlike/v2_1base/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
},
{
"unet": {
"tuned": {
"fp16": {
"default_compilation_flags": []
},
"fp32": {
"default_compilation_flags": []
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32"
],
"specified_compilation_flags": {
"cuda": ["--iree-flow-enable-conv-nchw-to-nhwc-transform"],
"default_device": ["--iree-flow-enable-conv-img2col-transform"]
}
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16"
]
}
}
},
"vae": {
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16"
]
}
}
},
"clip": {
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
}
}
}
}
]
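
The three objects in this file are consumed as a GCS bucket map, an artifact-name map, and a per-model compilation-flag map. A rough sketch of the lookup performed by the get_params helper added later in this commit (the local open() path below is illustrative; the commit loads model_db.json through resource_path() in resources.py):

import json

with open("model_db.json") as f:  # illustrative path
    models_db = json.load(f)

# Keys are assembled from the HF variant, base version, submodel, precision,
# token length, and tuned/untuned status (see opt_params.py below).
bucket_key = "stablediffusion/untuned"
model_key = "stablediffusion/v2_1base/unet/fp16/length_64/untuned"

bucket = models_db[0][bucket_key]    # -> gs://shark_tank/stable_diffusion
artifact = models_db[1][model_key]   # -> unet_19dec_v2p1base_fp16_64
flags = models_db[2]["unet"]["untuned"]["fp16"]["default_compilation_flags"]
print(bucket, artifact, flags)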

View File

@@ -0,0 +1,95 @@
{
"unet": {
"tuned": {
"fp16": {
"default_compilation_flags": []
},
"fp32": {
"default_compilation_flags": []
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32"
],
"specified_compilation_flags": {
"cuda": ["--iree-flow-enable-conv-nchw-to-nhwc-transform"],
"default_device": ["--iree-flow-enable-conv-img2col-transform"]
}
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16"
]
}
}
},
"vae": {
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16"
]
}
}
},
"clip": {
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
}
}
}
}

View File

@@ -0,0 +1,8 @@
[["A high tech solarpunk utopia in the Amazon rainforest"],
["A pikachu fine dining with a view to the Eiffel Tower"],
["A mecha robot in a favela in expressionist style"],
["an insect robot preparing a delicious meal"],
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]]

View File

@@ -0,0 +1 @@
from .txt2img import txt2img_inf

View File

@@ -0,0 +1,240 @@
import os
os.environ["AMD_ENABLE_LLPC"] = "1"
import torch
import re
import time
from pathlib import Path
from datetime import datetime as dt
from dataclasses import dataclass
from csv import DictWriter
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
get_schedulers,
set_init_device_flags,
)
@dataclass
class Config:
model_id: str
ckpt_loc: str
precision: str
batch_size: int
max_length: int
height: int
width: int
device: str
# This has to come before importing cache objects
if args.clear_all:
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
from glob import glob
import shutil
vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
home = os.path.expanduser("~")
if os.name == "nt": # Windows
appdata = os.getenv("LOCALAPPDATA")
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
elif os.name == "posix":  # Linux / macOS
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"), ignore_errors=True)
shutil.rmtree(os.path.join(home, ".local/shark_tank"), ignore_errors=True)
# Save output images and the inputs corresponding to them.
def save_output_img(output_img):
output_path = args.output_dir if args.output_dir else Path.cwd()
generated_imgs_path = Path(output_path, "generated_imgs")
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
out_img_name = (
f"{prompt_slice}_{args.seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
)
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
new_entry = {
"VARIANT": args.hf_model_id,
"SCHEDULER": args.scheduler,
"PROMPT": args.prompts[0],
"NEG_PROMPT": args.negative_prompts[0],
"SEED": args.seed,
"CFG_SCALE": args.guidance_scale,
"PRECISION": args.precision,
"STEPS": args.steps,
"HEIGHT": args.height,
"WIDTH": args.width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
}
with open(csv_path, "a") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
dictwriter_obj.writerow(new_entry)
csv_obj.close()
txt2img_obj = None
config_obj = None
schedulers = None
# Exposed to UI.
def txt2img_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_size: int,
scheduler: str,
model_id: str,
custom_model_id: str,
ckpt_file_obj,
precision: str,
device: str,
max_length: int,
):
global txt2img_obj
global config_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.scheduler = scheduler
args.hf_model_id = custom_model_id if custom_model_id else model_id
args.ckpt_loc = ckpt_file_obj.name if ckpt_file_obj else ""
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
)
if config_obj != new_config_obj:
config_obj = new_config_obj
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.use_tuned = True
args.import_mlir = False
set_init_device_flags()
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
)
txt2img_obj.scheduler = schedulers[scheduler]
start_time = time.time()
txt2img_obj.log = ""
generated_imgs = txt2img_obj.generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
save_output_img(generated_imgs[0])
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
return generated_imgs, text_output
if __name__ == "__main__":
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
)
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
args.negative_prompts,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
args.seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0])
print(text_output)

View File

@@ -0,0 +1,8 @@
from .utils import (
args,
set_init_device_flags,
prompt_examples,
get_available_devices,
)
from .pipelines import Text2ImagePipeline
from .schedulers import get_schedulers

View File

@@ -0,0 +1,2 @@
from .model_wrappers import SharkifyStableDiffusionModel
from .opt_params import get_vae, get_unet, get_clip, get_tokenizer

View File

@@ -0,0 +1,229 @@
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from collections import defaultdict
import torch
import sys
import traceback
import re
from ..utils import compile_through_fx, get_opt_flags, base_models, args
# These shapes are parameter dependent.
def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape = []
for i in range(len(shape)):
if shape[i] == "max_len":
new_shape.append(max_len)
elif shape[i] == "height":
new_shape.append(height)
elif shape[i] == "width":
new_shape.append(width)
elif isinstance(shape[i], str):
if "batch_size" in shape[i]:
mul_val = int(shape[i].split("*")[0])
new_shape.append(batch_size * mul_val)
else:
new_shape.append(shape[i])
return new_shape
# Get the input info for various models i.e. "unet", "clip", "vae".
def get_input_info(model_info, max_len, width, height, batch_size):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = defaultdict(list)
for k in model_info:
for inp in model_info[k]:
shape = model_info[k][inp]["shape"]
dtype = dtype_config[model_info[k][inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, max_len, width, height, batch_size
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
else:
tensor = torch.randn(*clean_shape).to(dtype)
elif isinstance(shape, int):
tensor = torch.tensor(shape).to(dtype)
else:
sys.exit("shape isn't specified correctly.")
input_map[k].append(tensor)
return input_map
class SharkifyStableDiffusionModel:
def __init__(
self,
model_id: str,
custom_weights: str,
precision: str,
max_len: int = 64,
width: int = 512,
height: int = 512,
batch_size: int = 1,
use_base_vae: bool = False,
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.model_id = model_id if custom_weights == "" else custom_weights
self.precision = precision
self.base_vae = use_base_vae
self.model_name = (
str(batch_size)
+ "_"
+ str(max_len)
+ "_"
+ str(height)
+ "_"
+ str(width)
+ "_"
+ precision
)
# We need a better naming convention for the .vmfbs because despite
# using the custom model variant the .vmfb names remain the same and
# it'll always pick up the compiled .vmfb instead of compiling the
# custom model.
# So, currently, we add `self.model_id` in the `self.model_name` of
# .vmfb file.
# TODO: Have a better way of naming the vmfbs using self.model_name.
model_name = re.sub(r"\W+", "_", self.model_id)
if model_name[0] == "_":
model_name = model_name[1:]
self.model_name = self.model_name + "_" + model_name
def check_params(self, max_len, width, height):
if not (max_len >= 32 and max_len <= 77):
sys.exit("please specify max_len in the range [32, 77].")
if not (width % 8 == 0 and width >= 384):
sys.exit("width should be at least 384 and a multiple of 8")
if not (height % 8 == 0 and height >= 384):
sys.exit("height should be at least 384 and a multiple of 8")
def get_vae(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, base_vae=self.base_vae):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
)
self.base_vae = base_vae
def forward(self, input):
if not self.base_vae:
input = 1 / 0.18215 * input
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
if self.base_vae:
return x
x = x * 255.0
return x.round()
vae = VaeModel()
inputs = tuple(self.inputs["vae"])
is_f16 = True if self.precision == "fp16" else False
vae_name = "base_vae" if self.base_vae else "vae"
shark_vae = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
model_name=vae_name + self.model_name,
extra_args=get_opt_flags("vae", precision=self.precision),
)
return shark_vae
def get_unet(self):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(
self, latent, timestep, text_embedding, guidance_scale
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents, timestep, text_embedding, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel()
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
shark_unet = compile_through_fx(
unet,
inputs,
model_name="unet" + self.model_name,
is_f16=is_f16,
f16_input_mask=input_mask,
extra_args=get_opt_flags("unet", precision=self.precision),
)
return shark_unet
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
)
def forward(self, input):
return self.text_encoder(input)[0]
clip_model = CLIPText()
shark_clip = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
model_name="clip" + self.model_name,
extra_args=get_opt_flags("clip", precision="fp32"),
)
return shark_clip
def __call__(self):
for model_id in base_models:
self.inputs = get_input_info(
base_models[model_id],
self.max_len,
self.width,
self.height,
self.batch_size,
)
try:
compiled_clip = self.get_clip()
compiled_unet = self.get_unet()
compiled_vae = self.get_vae()
except Exception as e:
if args.enable_stack_trace:
traceback.print_exc()
print("Retrying with a different base model configuration")
continue
# This is done because main.py bases the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since we no longer maintain a 1:1 mapping between
# variants and base models, and instead rely on retrying to find a working
# input configuration, the resolved base model id is written back into
# `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
return compiled_clip, compiled_unet, compiled_vae
sys.exit(
"Cannot compile the model. Please use `enable_stack_trace` and create an issue at https://github.com/nod-ai/SHARK/issues"
)

View File

@@ -0,0 +1,113 @@
import sys
from transformers import CLIPTokenizer
from ..utils import models_db, args, get_shark_model
hf_model_variant_map = {
"Linaqruf/anything-v3.0": ["anythingv3", "v2_1base"],
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v2_1base"],
"prompthero/openjourney": ["openjourney", "v2_1base"],
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v2_1base"],
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1"],
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
}
def get_params(bucket_key, model_key, model, is_tuned, precision):
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
try:
bucket = models_db[0][bucket_key]
model_name = models_db[1][model_key]
iree_flags += models_db[2][model][is_tuned][precision][
"default_compilation_flags"
]
except KeyError:
raise Exception(
f"{bucket_key}/{model_key} is not present in the models database"
)
if (
"specified_compilation_flags"
in models_db[2][model][is_tuned][precision]
):
device = (
args.device
if "://" not in args.device
else args.device.split("://")[0]
)
if (
device
not in models_db[2][model][is_tuned][precision][
"specified_compilation_flags"
]
):
device = "default_device"
iree_flags += models_db[2][model][is_tuned][precision][
"specified_compilation_flags"
][device]
return bucket, model_name, iree_flags
def get_unet():
variant, version = hf_model_variant_map[args.hf_model_id]
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "unet", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_vae():
variant, version = hf_model_variant_map[args.hf_model_id]
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
is_base = "/base" if args.use_base_vae else ""
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "vae", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_clip():
variant, version = hf_model_variant_map[args.hf_model_id]
bucket_key = f"{variant}/untuned"
model_key = (
f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned"
)
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "clip", "untuned", "fp32"
)
return get_shark_model(bucket, model_name, iree_flags)
def get_tokenizer():
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id, subfolder="tokenizer"
)
return tokenizer

View File

@@ -0,0 +1 @@
from .pipeline_shark_stable_diffusion_txt2img import Text2ImagePipeline

View File

@@ -0,0 +1,132 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from ..schedulers import SharkEulerDiscreteScheduler
from .pipeline_shark_stable_diffusion_utils import StableDiffusionPipeline
class Text2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def generate_images(
self,
prompts,
neg_prompts,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# Seed the generator to create the initial latent noise. Also handle out-of-range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
)
# Img latents -> PIL images
all_imgs = []
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
return all_imgs

View File

@@ -0,0 +1,206 @@
import torch
from transformers import CLIPTokenizer
import torchvision.transforms as T
from tqdm.auto import tqdm
import time
from typing import Union
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from shark.shark_inference import SharkInference
from ..schedulers import SharkEulerDiscreteScheduler
from ..models import (
SharkifyStableDiffusionModel,
get_vae,
get_clip,
get_unet,
get_tokenizer,
)
from ..utils import start_profiling, end_profiling, preprocessCKPT
class StableDiffusionPipeline:
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
):
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.unet = unet
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
text_input = self.tokenizer(
prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
# Get unconditional embeddings as well
uncond_input = self.tokenizer(
neg_prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
clip_inf_start = time.time()
text_embeddings = self.text_encoder("forward", (text_input,))
clip_inf_time = (time.time() - clip_inf_start) * 1000
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
if use_base_vae:
latents = 1 / 0.18215 * latents
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
images = self.vae("forward", (latents_numpy,))
vae_inf_time = (time.time() - vae_start) * 1000
end_profiling(profile_device)
self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"
if use_base_vae:
images = torch.from_numpy(images)
images = (images.detach().cpu() * 255.0).numpy()
images = images.round()
transform = T.ToPILImage()
pil_images = [
transform(image)
for image in torch.from_numpy(images).to(torch.uint8)
]
return pil_images
def produce_img_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
latent_model_input = self.scheduler.scale_model_input(latents, t)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
@classmethod
def from_pretrained(
cls,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
import_mlir: bool,
model_id: str,
ckpt_loc: str,
precision: str,
max_length: int,
batch_size: int,
height: int,
width: int,
use_base_vae: bool,
):
init_kwargs = None
if import_mlir:
if ckpt_loc:
preprocessCKPT()
mlir_import = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
)
clip, unet, vae = mlir_import()
return cls(vae, clip, get_tokenizer(), unet, scheduler)
return cls(
get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler
)

View File

@@ -0,0 +1,2 @@
from .sd_schedulers import get_schedulers
from .shark_eulerdiscrete import SharkEulerDiscreteScheduler

View File

@@ -0,0 +1,51 @@
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
)
from .shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
def get_schedulers(model_id):
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"EulerAncestralDiscrete"
] = EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"SharkEulerDiscrete"
] = SharkEulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
return schedulers

View File

@@ -0,0 +1,139 @@
import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
)
from diffusers.configuration_utils import register_to_config
from ..utils import compile_through_fx, get_shark_model, args
import torch
class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
):
super().__init__(
num_train_timesteps,
beta_start,
beta_end,
beta_schedule,
trained_betas,
prediction_type,
)
def compile(self):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
BATCH_SIZE = args.batch_size
model_input = {
"euler": {
"latent": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"output": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"sigma": torch.tensor(1).to(torch.float32),
"dt": torch.tensor(1).to(torch.float32),
},
}
example_latent = model_input["euler"]["latent"]
example_output = model_input["euler"]["output"]
if args.precision == "fp16":
example_latent = example_latent.half()
example_output = example_output.half()
example_sigma = model_input["euler"]["sigma"]
example_dt = model_input["euler"]["dt"]
class ScalingModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, latent, sigma):
return latent / ((sigma**2 + 1) ** 0.5)
class SchedulerStepModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, noise_pred, sigma, latent, dt):
pred_original_sample = latent - sigma * noise_pred
derivative = (latent - pred_original_sample) / sigma
return latent + derivative * dt
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if args.import_mlir:
scaling_model = ScalingModel()
self.scaling_model = compile_through_fx(
scaling_model,
(example_latent, example_sigma),
model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
step_model = SchedulerStepModel()
self.step_model = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
else:
self.scaling_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_scale_model_input_" + args.precision,
iree_flags,
)
self.step_model = get_shark_model(
SCHEDULER_BUCKET, "euler_step_" + args.precision, iree_flags
)
def scale_model_input(self, sample, timestep):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
return self.scaling_model(
"forward",
(
sample,
sigma,
),
send_to_host=False,
)
def step(self, noise_pred, timestep, latent):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
dt = self.sigmas[step_index + 1] - sigma
return self.step_model(
"forward",
(
noise_pred,
sigma,
latent,
dt,
),
send_to_host=False,
)

View File

@@ -0,0 +1,19 @@
from .profiler import start_profiling, end_profiling
from .resources import (
prompt_examples,
models_db,
base_models,
opt_flags,
resource_path,
)
from .stable_args import args
from .utils import (
get_shark_model,
compile_through_fx,
set_iree_runtime_flags,
map_device_to_name_path,
set_init_device_flags,
get_available_devices,
get_opt_flags,
preprocessCKPT,
)

View File

@@ -0,0 +1,17 @@
from .stable_args import args
# Helper function to profile the vulkan device.
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
if args.vulkan_debug_utils and "vulkan" in args.device:
import iree
print(f"Profiling and saving to {file_path}.")
vulkan_device = iree.runtime.get_device(args.device)
vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path)
return vulkan_device
return None
def end_profiling(device):
if device:
return device.end_profiling()

View File

@@ -0,0 +1,37 @@
import os
import json
import sys
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
def get_json_file(path):
json_var = []
loc_json = resource_path(path)
if os.path.exists(loc_json):
with open(loc_json, encoding="utf-8") as fopen:
json_var = json.load(fopen)
if not json_var:
print(f"Unable to fetch {path}")
return json_var
# TODO: These shouldn't be loaded at import time; every time this file is
# imported, all of these globals are re-evaluated.
prompt_examples = get_json_file("../../resources/prompts.json")
models_db = get_json_file("../../resources/model_db.json")
# The base_model contains the input configuration for the different
# models and also helps in providing information for the variants.
base_models = get_json_file("../../resources/base_model.json")
# Contains optimization flags for different models.
opt_flags = get_json_file("../../resources/opt_flags.json")

View File

@@ -0,0 +1,323 @@
import argparse
from pathlib import Path
def path_expand(s):
return Path(s).expanduser().resolve()
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
### Stable Diffusion Params
##############################################################################
p.add_argument(
"-p",
"--prompts",
action="append",
default=[],
help="text of which images to be generated.",
)
p.add_argument(
"--negative-prompts",
nargs="+",
default=[""],
help="text you don't want to see in the generated image.",
)
p.add_argument(
"--steps",
type=int,
default=50,
help="the no. of steps to do the sampling.",
)
p.add_argument(
"--seed",
type=int,
default=42,
help="the seed to use.",
)
p.add_argument(
"--batch_size",
type=int,
default=1,
choices=range(1, 4),
help="the number of inferences to be made in a single `run`.",
)
p.add_argument(
"--height",
type=int,
default=512,
help="the height of the output image.",
)
p.add_argument(
"--width",
type=int,
default=512,
help="the width of the output image.",
)
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
help="the value to be used for guidance scaling.",
)
p.add_argument(
"--max_length",
type=int,
default=64,
help="max length of the tokenizer output, options are 64 and 77.",
)
##############################################################################
### Model Config and Usage Params
##############################################################################
p.add_argument(
"--device", type=str, default="vulkan", help="device to run the model."
)
p.add_argument(
"--precision", type=str, default="fp16", help="precision to run the model."
)
p.add_argument(
"--import_mlir",
default=False,
action=argparse.BooleanOptionalAction,
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
)
p.add_argument(
"--load_vmfb",
default=True,
action=argparse.BooleanOptionalAction,
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
)
p.add_argument(
"--save_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="saves the compiled flatbuffer to the local directory",
)
p.add_argument(
"--use_tuned",
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
)
p.add_argument(
"--use_base_vae",
default=False,
action=argparse.BooleanOptionalAction,
help="Do conversion from the VAE output to pixel space on cpu.",
)
p.add_argument(
"--scheduler",
type=str,
default="SharkEulerDiscrete",
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="specify the format in which output image is save. Supported options: jpg / png",
)
p.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory path to save the output images and json",
)
p.add_argument(
"--runs",
type=int,
default=1,
help="number of images to be generated with random seeds in single execution",
)
p.add_argument(
"--ckpt_loc",
type=str,
default="",
help="Path to SD's .ckpt file.",
)
p.add_argument(
"--hf_model_id",
type=str,
default="stabilityai/stable-diffusion-2-1-base",
help="The repo-id of hugging face.",
)
p.add_argument(
"--enable_stack_trace",
default=False,
action=argparse.BooleanOptionalAction,
help="Enable showing the stack trace when retrying the base model configuration",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree-vulkan-target-triple",
type=str,
default="",
help="Specify target triple for vulkan",
)
p.add_argument(
"--vulkan_debug_utils",
default=False,
action=argparse.BooleanOptionalAction,
help="Profiles vulkan device and collects the .rdc info",
)
p.add_argument(
"--vulkan_large_heap_block_size",
default="4147483648",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
p.add_argument(
"--vulkan_validation_layers",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for disabling vulkan validation layers when benchmarking",
)
##############################################################################
### Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="use the default scheduler precompiled into the model if available",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help='dispatches to return benchmark data on. Use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for inserting debug frames between iterations for use with rgp.",
)
p.add_argument(
"--hide_steps",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="flag setting warmup count for clip and vae [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
)
##############################################################################
### Web UI flags
##############################################################################
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the pregress bar animation during image generation",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
)
##############################################################################
### SD model auto-annotation flags
##############################################################################
p.add_argument(
"--annotation_output",
type=path_expand,
default="./",
help="Directory to save the annotated mlir file",
)
p.add_argument(
"--annotation_model",
type=str,
default="unet",
help="Options are unet and vae.",
)
p.add_argument(
"--use_winograd",
default=False,
action=argparse.BooleanOptionalAction,
help="Apply Winograd on selected conv ops.",
)
args = p.parse_args()

View File

@@ -0,0 +1,353 @@
import os
import torch
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from .stable_args import args
from .resources import opt_flags
import sys
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
load_pipeline_from_original_stable_diffusion_ckpt,
)
def _compile_module(shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
device = (
args.device
if "://" not in args.device
else "-".join(args.device.split("://"))
)
extended_name = "{}_{}".format(model_name, device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
else:
if args.save_vmfb:
print("Saving to {}".format(vmfb_path))
else:
print(
"No vmfb found. Compiling and saving to {}".format(
vmfb_path
)
)
path = shark_module.save_module(
os.getcwd(), extended_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
shark_module.compile(extra_args)
return shark_module
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
from shark.shark_downloader import download_model
from shark.parser import shark_args
# Set local shark_tank cache directory.
shark_args.local_tank_cache = args.local_tank_cache
if "cuda" in args.device:
shark_args.enable_tf32 = True
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=tank_url,
frontend="torch",
)
shark_module = SharkInference(
mlir_model, device=args.device, mlir_dialect="linalg"
)
return _compile_module(shark_module, model_name, extra_args)
# Converts the torch-module into a shark_module.
def compile_through_fx(
model,
inputs,
model_name,
is_f16=False,
f16_input_mask=None,
extra_args=[],
):
mlir_module, func_name = import_with_fx(
model, inputs, is_f16, f16_input_mask
)
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
)
return _compile_module(shark_module, model_name, extra_args)
def set_iree_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
]
if args.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
def get_all_devices(driver_name):
"""
Inputs: driver_name
Returns a list of all the available devices for a given driver sorted by
the iree path names of the device as in --list_devices option in iree.
"""
from iree.runtime import get_driver
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return device_list_src
def get_device_mapping(driver, key_combination=3):
"""This method ensures consistent device ordering when choosing
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
driver = iree_device_map(driver)
device_list = get_all_devices(driver)
device_map = dict()
def get_output_value(dev_dict):
if key_combination == 1:
return f"{driver}://{dev_dict['path']}"
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
for i, device in enumerate(device_list):
# mapping with index
device_map[f"{driver}://{i}"] = get_output_value(device)
# mapping with full path
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
return device_map
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
Args:
device (str): user-selected execution device.
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Raises:
ValueError: if the device string does not map to an available device.
Returns:
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
try:
device_mapping = device_map[device]
except KeyError:
raise ValueError(f"Device '{device}' is not a valid device.")
return device_mapping
def set_init_device_flags():
if "vulkan" in args.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
args.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
)
elif "cuda" in args.device:
args.device = "cuda"
elif "cpu" in args.device:
args.device = "cpu"
# set max_length based on availability.
if args.hf_model_id in [
"Linaqruf/anything-v3.0",
"wavymulder/Analog-Diffusion",
"dreamlike-art/dreamlike-diffusion-1.0",
]:
args.max_length = 77
elif args.hf_model_id == "prompthero/openjourney":
args.max_length = 64
# Use tuned models in the case of a specific setting.
if (
args.hf_model_id
in ["prompthero/openjourney", "dreamlike-art/dreamlike-diffusion-1.0"]
or args.precision != "fp16"
):
args.use_tuned = False
elif (
"vulkan" in args.device
and "rdna3" not in args.iree_vulkan_target_triple
):
args.use_tuned = False
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80", "sm_89"]:
args.use_tuned = False
elif args.use_base_vae and args.hf_model_id not in [
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]:
args.use_tuned = False
if args.use_tuned:
print(f"Using tuned models for {args.hf_model_id}/fp16/{args.device}.")
else:
print("Tuned models are currently not supported for this setting.")
# set import_mlir to True for unuploaded models.
if args.hf_model_id not in [
"Linaqruf/anything-v3.0",
"dreamlike-art/dreamlike-diffusion-1.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]:
args.import_mlir = True
if args.height != 512 or args.width != 512 or args.batch_size != 1:
args.import_mlir = True
# Utility to get list of devices available.
def get_available_devices():
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except Exception:
print(f"{driver_name} devices are not available.")
else:
for i, device in enumerate(device_list_dict):
device_list.append(f"{device['name']} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("cpu")
return available_devices
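A short sketch of consuming the list returned above, e.g. to seed a device dropdown; it mirrors the import the web UI uses later in this commit:

from apps.stable_diffusion.src import get_available_devices

devices = get_available_devices()
# Entries read "<device name> => <driver>://<index>"; "cpu" is always appended last,
# so the first entry is a GPU when one was found and "cpu" otherwise.
print(devices[0])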
def disk_space_check(path, lim=20):
from shutil import disk_usage
du = disk_usage(path)
free = du.free / (1024 * 1024 * 1024)
if free <= lim:
print(f"[WARNING] Only {free:.2f}GB space available in {path}.")
def get_opt_flags(model, precision="fp16"):
iree_flags = []
is_tuned = "tuned" if args.use_tuned else "untuned"
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if "specified_compilation_flags" in opt_flags[model][is_tuned][precision]:
device = (
args.device
if "://" not in args.device
else args.device.split("://")[0]
)
if (
device
not in opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
]
):
device = "default_device"
iree_flags += opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
][device]
return iree_flags
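For reference, the lookup above expects opt_flags to be nested as model → tuned/untuned → precision → "specified_compilation_flags" → device. The sketch below only illustrates that shape; all keys and the flag value are placeholders, not the repository's actual configuration:

# Hypothetical opt_flags layout matching the lookup in get_opt_flags (values are placeholders).
opt_flags = {
    "unet": {
        "untuned": {
            "fp16": {
                "specified_compilation_flags": {
                    # used when the active device has no entry of its own
                    "default_device": [],
                    "vulkan": ["--placeholder-flag=example"],
                }
            }
        }
    }
}

With such a table, get_opt_flags("unet", precision="fp16") appends the Vulkan target-triple flag (when set), the MoltenVK workaround on macOS, and then the device-specific entries.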
def preprocessCKPT():
"""Convert the checkpoint at args.ckpt_loc (.ckpt/.safetensors) into a diffusers directory and point args.ckpt_loc at it."""
from pathlib import Path
path = Path(args.ckpt_loc)
diffusers_path = path.parent.absolute()
diffusers_directory_name = path.stem
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
print(f"Created directory {diffusers_directory_name} at {diffusers_path}")
path_to_diffusers = complete_path_to_diffusers.as_posix()
from_safetensors = args.ckpt_loc.lower().endswith(".safetensors")
# EMA weights usually yield higher quality images for inference but non-EMA weights have
# been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
# weight extraction or not.
extract_ema = False
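# Hypothetical sketch of the --ema option the TODO above describes; it is not part of this commit:
#   p.add_argument("--ema", default=False, action=argparse.BooleanOptionalAction,
#                  help="extract EMA weights from the checkpoint")
# with which the assignment above would become: extract_ema = args.ema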
print("Loading pipeline from original stable diffusion checkpoint")
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=args.ckpt_loc,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
)
pipe.save_pretrained(path_to_diffusers)
print("Loading complete")
args.ckpt_loc = path_to_diffusers
print("Custom model path is : ", args.ckpt_loc)

View File

@@ -0,0 +1,67 @@
.gradio-container {
background-color: black
}
.container {
background-color: black !important;
padding-top: 20px !important;
}
#ui_title {
padding: 10px !important;
}
#top_logo {
background-color: transparent;
border-radius: 0 !important;
border: 0;
}
#demo_title {
background-color: black;
border-radius: 0 !important;
border: 0;
padding-top: 50px;
padding-bottom: 0px;
width: 460px !important;
}
#demo_title_outer {
border-radius: 0;
}
#prompt_box_outer div:first-child {
border-radius: 0 !important
}
#prompt_box textarea {
background-color: #1d1d1d !important
}
#prompt_examples {
margin: 0 !important
}
#prompt_examples svg {
display: none !important;
}
.gr-sample-textbox {
border-radius: 1rem !important;
border-color: rgb(31, 41, 55) !important;
border-width: 2px !important;
}
#ui_body {
background-color: #111111 !important;
padding: 10px !important;
border-radius: 0.5em !important;
}
#img_result+div {
display: none !important;
}
footer {
display: none !important;
}

View File

@@ -0,0 +1,249 @@
import os
import sys
from pathlib import Path
if "AMD_ENABLE_LLPC" not in os.environ:
os.environ["AMD_ENABLE_LLPC"] = "1"
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
import gradio as gr
from PIL import Image
from apps.stable_diffusion.src import (
prompt_examples,
args,
get_available_devices,
)
from apps.stable_diffusion.scripts import txt2img_inf
nodlogo_loc = resource_path("logos/nod-logo.png")
sdlogo_loc = resource_path("logos/sd-demo-logo.png")
demo_css = resource_path("css/sd_dark_theme.css")
with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
logo2 = Image.open(sdlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=100)
with gr.Column(scale=5, elem_id="demo_title_outer"):
gr.Image(
value=logo2,
show_label=False,
interactive=False,
elem_id="demo_title",
).style(width=150, height=100)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Group():
model_id = gr.Dropdown(
label="Model ID",
value="stabilityai/stable-diffusion-2-1-base",
choices=[
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
],
)
custom_model_id = gr.Textbox(
placeholder="check here: https://huggingface.co/models eg. runwayml/stable-diffusion-v1-5",
value="",
label="HuggingFace Model ID",
)
with gr.Group():
ckpt_loc = gr.File(
label="Upload checkpoint",
file_types=[".ckpt", ".safetensors"],
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value="cyberpunk forest by Salvador Dali",
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value="trees, green",
lines=1,
elem_id="prompt_box",
)
with gr.Accordion(label="Advance Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
label="Scheduler",
value="SharkEulerDiscrete",
choices=[
"DDIM",
"PNDM",
"LMSDiscrete",
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
"SharkEulerDiscrete",
],
)
batch_size = gr.Slider(
1, 4, value=1, step=1, label="Number of Images"
)
with gr.Row():
height = gr.Slider(
384, 786, value=512, step=8, label="Height"
)
width = gr.Slider(
384, 786, value=512, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value="fp16",
choices=[
"fp16",
"fp32",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=64,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=50, step=1, label="Steps"
)
guidance_scale = gr.Slider(
0,
50,
value=7.5,
step=0.1,
label="CFG Scale",
)
with gr.Row():
seed = gr.Number(value=-1, precision=0, label="Seed")
available_devices = get_available_devices()
device = gr.Dropdown(
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => Math.floor(Math.random() * 4294967295)",
)
stable_diffusion = gr.Button("Generate Image")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2], height="auto")
std_output = gr.Textbox(
value="Nothing to show.",
lines=4,
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
prompt.submit(
txt2img_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_size,
scheduler,
model_id,
custom_model_id,
ckpt_loc,
precision,
device,
max_length,
],
outputs=[gallery, std_output],
show_progress=args.progress_bar,
)
stable_diffusion.click(
txt2img_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_size,
scheduler,
model_id,
custom_model_id,
ckpt_loc,
precision,
device,
max_length,
],
outputs=[gallery, std_output],
show_progress=args.progress_bar,
)
shark_web.queue()
shark_web.launch(
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
)

Binary file not shown (added image, 33 KiB)

Binary file not shown (added image, 10 KiB)

Binary file not shown (added image, 5.0 KiB)

View File

@@ -1,21 +1,21 @@
[
{
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
"openjourney/v1_4":"prompthero/openjourney",
"dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
},
{
"stablediffusion/fp16":"fp16",
"stablediffusion/fp32":"main",
"anythingv3/fp16":"diffusers",
"anythingv3/fp32":"diffusers",
"analogdiffusion/fp16":"main",
"analogdiffusion/fp32":"main",
"openjourney/fp16":"main",
"openjourney/fp32":"main"
}
]
{
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
"openjourney/v1_4":"prompthero/openjourney",
"dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
},
{
"stablediffusion/fp16":"fp16",
"stablediffusion/fp32":"main",
"anythingv3/fp16":"diffusers",
"anythingv3/fp32":"diffusers",
"analogdiffusion/fp16":"main",
"analogdiffusion/fp32":"main",
"openjourney/fp16":"main",
"openjourney/fp32":"main"
}
]

View File

@@ -294,6 +294,20 @@ p.add_argument(
help="flag for removing the pregress bar animation during image generation",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
)
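A self-contained sketch of how the two new flags parse; it mirrors the definitions above (argparse.BooleanOptionalAction requires Python 3.9+):

import argparse

p = argparse.ArgumentParser()
p.add_argument("--share", default=False, action=argparse.BooleanOptionalAction,
               help="flag for generating a public URL")
p.add_argument("--server_port", type=int, default=8080,
               help="flag for setting server port")

print(p.parse_args([]))                                    # share=False, server_port=8080
print(p.parse_args(["--share", "--server_port", "7860"]))  # share=True, server_port=7860
print(p.parse_args(["--no-share"]))                        # BooleanOptionalAction also adds --no-share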
##############################################################################
### SD model auto-annotation flags
##############################################################################

View File

@@ -321,7 +321,7 @@ def get_available_devices():
print(f"{driver_name} devices are not available.")
else:
for i, device in enumerate(device_list_dict):
device_list.append(f"{driver_name}://{i} => {device['name']}")
device_list.append(f"{device['name']} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()