Compare commits

...

52 Commits

Author SHA1 Message Date
powderluv
ecfdec12f3 Update requirements.txt 2022-12-25 15:39:20 -08:00
Gaurav Shukla
45af40fd14 [SD][web] Add openjourney and dreamlike in SD web UI
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-26 01:59:36 +05:30
Phaneesh Barwaria
d11cf42501 Add support for dreamlike diffusion (#725)
* Add support for dreamlike diffusion

* model wrapper to support 77 dreamlike

* lint fix
2022-12-26 01:35:17 +05:30
Gaurav Shukla
c3c1e3b055 [SD] Add bucket info in the model_db.json
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-25 20:38:33 +05:30
Gaurav Shukla
7c5e3b1d99 [SD] Fix flags for cuda devices
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-25 19:03:02 +05:30
Gaurav Shukla
ed6cec71e7 [SD] Fix clip inference time
Fix clip inference time by adding default warmup_count to 5.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-25 18:16:53 +05:30
Tobby "GTD-Carthage" Ong
d6bcdd069c - Added missing double linebreak from linting 2022-12-25 12:07:43 +05:30
Tobby "GTD-Carthage" Ong
a26347826d - Revised code to also use get_schedulers function instead 2022-12-25 12:07:43 +05:30
Tobby "GTD-Carthage" Ong
5d1c099b31 [SD] Add Euler Ancestral scheduler as option to WebUI 2022-12-25 12:07:43 +05:30
Gaurav Shukla
220bee1365 [SD][web] Add device support in the SD web UI
1. Now device selection is available through UI.
2. Models reloading will only happen when there will be a change in the
   settings(variant + device).

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-25 01:45:07 +05:30
PhaneeshB
1261074d95 Add tuned models for av3 and ad 2022-12-24 22:56:15 +05:30
Stanley Winata
136021424c [SD] Change default VMA large heap block size for windows perf. (#715)
Windows perform can boost from 2.67s/image to 2.4523s/image.
While Linux stays the same.
2022-12-24 01:40:58 +07:00
PhaneeshB
fee4ba3746 Add openjourney 2022-12-23 23:34:22 +05:30
Gaurav Shukla
a5b70335d4 [SD][web] Add variant support in the web UI
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-23 23:18:27 +05:30
Stanley Winata
5cf4976054 [Vulkan][utils] Add GTX Pascal support. (#709) 2022-12-22 15:24:15 -08:00
PhaneeshB
1aa3255061 Add vaebase for av3 and ad 2022-12-23 04:17:17 +05:30
Daniel Garvey
b01f29f10d add support for clear_all (#691) 2022-12-22 11:25:03 -06:00
Boian Petkantchin
2673abca88 Fix concurrency issue in stress_test for CUDA devices 2022-12-22 08:54:19 -08:00
Gaurav Shukla
7eeb7f0715 [SD] Update all the utilities to make web and CLI codebase closer (#707)
At this point, all the utilities of SD web and CLI are exactly same.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-22 02:49:48 -08:00
powderluv
37262a2479 Remove spurious characters 2022-12-21 19:23:54 -08:00
Gaurav Shukla
de6e304959 [SD] Fix the resource location in shark_sd.spec (#706) 2022-12-21 14:41:56 -08:00
Quinn Dawkins
234475bbc7 Add base_vae entries for variant models (#705) 2022-12-21 14:35:08 -08:00
Quinn Dawkins
abbd9f7cfc [SD] Set unet flags for cuda (#704) 2022-12-21 13:22:04 -08:00
Gaurav Shukla
dfd6ba67b3 [SD] Update SD CLI to use model_db.json
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-22 02:13:04 +05:30
yzhang93
1595254eab Modify model annotation tool to walk through ops by shape (#692) 2022-12-21 10:46:30 -08:00
PhaneeshB
6964c5eeba encapsulate relevant methods in one method 2022-12-21 23:56:17 +05:30
PhaneeshB
2befe771b3 Add support for automatic target triple selection for SD 2022-12-21 22:38:06 +05:30
Prashant Kumar
b133a035a4 Add the download progress bar. 2022-12-21 15:47:33 +05:30
Gaurav Shukla
726c062327 [SD] Update spec files
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-21 14:16:04 +05:30
Gaurav Shukla
9083672de3 [SD][web] Tuned models only for stablediffusion/fp16 and rdna3 cards
Currently tuned models are only available for stablediffusion/fp16 and
rdna3 cards.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-21 14:15:39 +05:30
Quinn Dawkins
cdbaf880af [SD] [web] Add model variants to web 2022-12-21 13:42:22 +05:30
Quinn Dawkins
9434981cdc Add random seed generation for seed = -1 in cli (#689) 2022-12-20 17:15:22 -05:00
Phaneesh Barwaria
8b3706f557 Add Anything v3 and AnalogDiffusion variants of SD (#685)
* base support for anythingv3

* add analogdiffusiont

* Update readme

* keep max len 77 till support for 64 added for variants

* lint fix
2022-12-20 13:08:13 -08:00
Gaurav Shukla
0d5173833d [SD] Add a json file for model names information. (#687)
This commit simplifies the code to identify the model name for a
particular set of flags. This is achieved by introducing a json file
that stores the model names information. The models are uploaded in
gcloud with these names.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-20 11:47:31 -08:00
powderluv
bf1178eb79 roll to build 400 2022-12-20 10:34:31 -08:00
yzhang93
abcd3fa94a [SD] Set model max length 64 as default (#681) 2022-12-19 21:13:04 -08:00
Quinn Dawkins
62aa1614b6 [SD] Add --use_base_vae flag to do conversion to pixel space on cpu (#682) 2022-12-19 21:09:39 -08:00
Quinn Dawkins
7027356126 [SD] Fix warmup for max length 64 (#680) 2022-12-19 21:04:44 -05:00
yzhang93
5ebe13a13d Add Unet len 64 tuned model (#679) 2022-12-19 16:24:08 -08:00
Gaurav Shukla
c3bed9a2b7 [SD][web] Add flag to disable the progress bar animation
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-20 02:50:04 +05:30
yzhang93
f865222882 Update VAE 19dec tuned model (#676) 2022-12-19 12:42:28 -08:00
powderluv
e2fe2e4095 Point to 398 2022-12-19 12:08:30 -08:00
powderluv
0532a95f08 Update stable_diffusion_amd.md 2022-12-19 12:04:42 -08:00
Quinn Dawkins
ff536f6015 [SD] Deduplicate initial noise generation (#677) 2022-12-19 14:38:41 -05:00
Gaurav Shukla
097d0f27bb [SD][web] Add 64 max_length support in SD web
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-20 00:00:58 +05:30
Prashant Kumar
2257f87edf Update opt_params.py 2022-12-19 23:43:30 +05:30
PhaneeshB
a17800da00 Add 64 len f16 untuned mlir 2022-12-19 22:53:17 +05:30
Prashant Kumar
059c1b3a19 Disable vae --use_tuned version. 2022-12-19 22:45:45 +05:30
Stanley Winata
9a36816d27 [SD][CLI] Add a warmup phase (#670) 2022-12-20 00:14:23 +07:00
Gaurav Shukla
7986b9b20b [SD][WEB] Update VAE model and wrapper
This commit updates VAE model which significantly improves performance
by an order of ~300ms.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-19 22:32:05 +05:30
Gaurav Shukla
b2b3a0a62b [SD] Move initial latent generation out of inference time
The initial random latent generation is not taken into account
for total SD inference time.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-12-19 22:32:05 +05:30
Prashant Kumar
3173b7d1d9 Update VAE model and wrapper. 2022-12-19 19:54:50 +05:30
32 changed files with 1839 additions and 841 deletions

View File

@@ -1,6 +1,5 @@
setuptools
wheel
pyinstaller
# SHARK Runner
tqdm
@@ -21,3 +20,6 @@ scipy
ftfy
gradio
altair
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pyinstaller

View File

@@ -36,7 +36,9 @@
" from torchdynamo.optimizations.backends import create_backend\n",
" from torchdynamo.optimizations.subgraph import SubGraph\n",
"except ModuleNotFoundError:\n",
" print(\"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\")\n",
" print(\n",
" \"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\"\n",
" )\n",
" exit()\n",
"\n",
"# torch-mlir imports for compiling\n",
@@ -97,7 +99,9 @@
"\n",
" for node in fx_g.graph.nodes:\n",
" if node.op == \"output\":\n",
" assert len(node.args) == 1, \"Output node must have a single argument\"\n",
" assert (\n",
" len(node.args) == 1\n",
" ), \"Output node must have a single argument\"\n",
" node_arg = node.args[0]\n",
" if isinstance(node_arg, tuple) and len(node_arg) == 1:\n",
" node.args = (node_arg[0],)\n",
@@ -116,8 +120,12 @@
" if len(args) == 1 and isinstance(args[0], list):\n",
" args = args[0]\n",
"\n",
" linalg_module = compile(ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS)\n",
" callable, _ = get_iree_compiled_module(linalg_module, \"cuda\", func_name=\"forward\")\n",
" linalg_module = compile(\n",
" ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS\n",
" )\n",
" callable, _ = get_iree_compiled_module(\n",
" linalg_module, \"cuda\", func_name=\"forward\"\n",
" )\n",
"\n",
" def forward(*inputs):\n",
" return callable(*inputs)\n",
@@ -212,6 +220,7 @@
" assert isinstance(subgraph, SubGraph), \"Model must be a dynamo SubGraph.\"\n",
" return __torch_mlir(subgraph.model, *list(subgraph.example_inputs))\n",
"\n",
"\n",
"@torchdynamo.optimize(\"torch_mlir\")\n",
"def toy_example2(*args):\n",
" a, b = args\n",

View File

@@ -42,3 +42,15 @@ unzip ~/.local/shark_tank/<your unet>/inputs.npz
iree-benchmark-module --module_file=/path/to/output/vmfb --entry_function=forward --function_input=@arr_0.npy --function_input=1xf16 --function_input=@arr_2.npy --function_input=@arr_3.npy --function_input=@arr_4.npy
```
## Using other supported Stable Diffusion variants with SHARK:
Currently we support the following fine-tuned versions of Stable Diffusion:
- [AnythingV3](https://huggingface.co/Linaqruf/anything-v3.0)
- [Analog Diffusion](https://huggingface.co/wavymulder/Analog-Diffusion)
use the flag `--variant=` to specify the model to be used.
```shell
python .\shark\examples\shark_inference\stable_diffusion\main.py --variant=anythingv3 --max_length=77 --prompt="1girl, brown hair, green eyes, colorful, autumn, cumulonimbus clouds, lighting, blue sky, falling leaves, garden"
```

View File

@@ -5,6 +5,7 @@ os.environ["AMD_ENABLE_LLPC"] = "1"
from transformers import CLIPTextModel, CLIPTokenizer
import torch
from PIL import Image
import torchvision.transforms as T
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
@@ -14,8 +15,31 @@ from diffusers import (
)
from tqdm.auto import tqdm
import numpy as np
from random import randint
from stable_args import args
from utils import get_shark_model, set_iree_runtime_flags
# This has to come before importing cache objects
if args.clear_all:
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
from glob import glob
import shutil
vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
home = os.path.expanduser("~")
if os.name == "nt": # Windows
appdata = os.getenv("LOCALAPPDATA")
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
elif os.name == "unix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
from utils import set_init_device_flags
from opt_params import get_unet, get_vae, get_clip
from schedulers import (
SharkEulerDiscreteScheduler,
@@ -49,7 +73,7 @@ if __name__ == "__main__":
neg_prompt = args.negative_prompts
height = 512 # default height of Stable Diffusion
width = 512 # default width of Stable Diffusion
if args.version == "v2.1":
if args.version == "v2_1":
height = 768
width = 768
@@ -58,8 +82,14 @@ if __name__ == "__main__":
# Scale for classifier-free guidance
guidance_scale = torch.tensor(args.guidance_scale).to(torch.float32)
# Handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
seed = args.seed
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(
args.seed
seed
) # Seed generator to create the inital latent noise
# TODO: Add support for batch_size > 1.
@@ -69,10 +99,10 @@ if __name__ == "__main__":
if batch_size != len(neg_prompt):
sys.exit("prompts and negative prompts must be of same length")
set_iree_runtime_flags()
set_init_device_flags()
clip = get_clip()
unet = get_unet()
vae = get_vae()
clip = get_clip()
if args.dump_isa:
dump_isas(args.dispatch_benchmarks_dir)
@@ -82,7 +112,7 @@ if __name__ == "__main__":
subfolder="scheduler",
)
cpu_scheduling = True
if args.version == "v2.1":
if args.version == "v2_1":
tokenizer = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-2-1", subfolder="tokenizer"
)
@@ -92,7 +122,7 @@ if __name__ == "__main__":
subfolder="scheduler",
)
if args.version == "v2.1base":
if args.version == "v2_1base" and args.variant == "stablediffusion":
tokenizer = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
)
@@ -110,6 +140,20 @@ if __name__ == "__main__":
subfolder="scheduler",
)
# create a random initial latent.
latents = torch.randn(
(batch_size, 4, height // 8, width // 8),
generator=generator,
dtype=torch.float32,
).to(dtype)
# Warmup phase to improve performance.
if args.warmup_count >= 1:
vae_warmup_input = torch.clone(latents).detach().numpy()
clip_warmup_input = torch.randint(1, 2, (2, args.max_length))
for i in range(args.warmup_count):
vae.forward((vae_warmup_input,))
clip.forward((clip_warmup_input,))
start = time.time()
text_input = tokenizer(
@@ -135,21 +179,15 @@ if __name__ == "__main__":
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
latents = torch.randn(
(batch_size, 4, height // 8, width // 8),
generator=generator,
dtype=torch.float32,
).to(dtype)
scheduler.set_timesteps(num_inference_steps)
scheduler.is_scale_input_called = True
latents = latents * scheduler.init_noise_sigma
avg_ms = 0
avg_ms = 0
for i, t in tqdm(enumerate(scheduler.timesteps), disable=args.hide_steps):
step_start = time.time()
if args.hide_steps == False:
if not args.hide_steps:
print(f"i = {i} t = {t}", end="")
timestep = torch.tensor([t]).to(dtype).detach().numpy()
latent_model_input = scheduler.scale_model_input(latents, t)
@@ -178,34 +216,38 @@ if __name__ == "__main__":
step_time = time.time() - step_start
avg_ms += step_time
step_ms = int((step_time) * 1000)
if args.hide_steps == False:
if not args.hide_steps:
print(f" ({step_ms}ms)")
avg_ms = 1000 * avg_ms / args.steps
print(f"Average step time: {avg_ms}ms/it")
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
# latents = latents.
if args.use_base_vae:
latents = 1 / 0.18215 * latents
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
image = vae.forward((latents_numpy,))
images = vae.forward((latents_numpy,))
vae_end = time.time()
end_profiling(profile_device)
image = torch.from_numpy(image)
image = image.detach().cpu().permute(0, 2, 3, 1) * 255.0
images = image.numpy().round().astype("uint8")
total_end = time.time()
if args.use_base_vae:
image = torch.from_numpy(images)
image = (image.detach().cpu() * 255.0).numpy()
images = image.round()
end_time = time.time()
avg_ms = 1000 * avg_ms / args.steps
clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
vae_inf_time = (vae_end - vae_start) * 1000
total_time = end_time - start
print(f"\nAverage step time: {avg_ms}ms/it")
print(f"Clip Inference time (ms) = {clip_inf_time:.3f}")
print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
print(f"Total image generation runtime (s): {total_end - start:.4f}")
print(f"\nTotal image generation time: {total_time}sec")
pil_images = [Image.fromarray(image) for image in images]
transform = T.ToPILImage()
pil_images = [
transform(image) for image in torch.from_numpy(images).to(torch.uint8)
]
for i in range(batch_size):
pil_images[i].save(f"{args.prompts[i]}_{i}.jpg")

View File

@@ -5,46 +5,67 @@ from stable_args import args
import torch
model_config = {
"v2.1": "stabilityai/stable-diffusion-2-1",
"v2.1base": "stabilityai/stable-diffusion-2-1-base",
"v1.4": "CompVis/stable-diffusion-v1-4",
"v2_1": "stabilityai/stable-diffusion-2-1",
"v2_1base": "stabilityai/stable-diffusion-2-1-base",
"v1_4": "CompVis/stable-diffusion-v1-4",
}
# clip has 2 variants of max length 77 or 64.
model_clip_max_length = 64 if args.max_length == 64 else 77
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
model_clip_max_length = 77
elif args.variant == "openjourney":
model_clip_max_length = 64
model_variant = {
"stablediffusion": "SD",
"anythingv3": "Linaqruf/anything-v3.0",
"dreamlike": "dreamlike-art/dreamlike-diffusion-1.0",
"openjourney": "prompthero/openjourney",
"analogdiffusion": "wavymulder/Analog-Diffusion",
}
model_input = {
"v2.1": {
"clip": (torch.randint(1, 2, (2, 77)),),
"v2_1": {
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
"vae": (torch.randn(1, 4, 96, 96),),
"unet": (
torch.randn(1, 4, 96, 96), # latents
torch.tensor([1]).to(torch.float32), # timestep
torch.randn(2, 77, 1024), # embedding
torch.randn(2, model_clip_max_length, 1024), # embedding
torch.tensor(1).to(torch.float32), # guidance_scale
),
},
"v2.1base": {
"clip": (torch.randint(1, 2, (2, 77)),),
"v2_1base": {
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
"vae": (torch.randn(1, 4, 64, 64),),
"unet": (
torch.randn(1, 4, 64, 64), # latents
torch.tensor([1]).to(torch.float32), # timestep
torch.randn(2, 77, 1024), # embedding
torch.randn(2, model_clip_max_length, 1024), # embedding
torch.tensor(1).to(torch.float32), # guidance_scale
),
},
"v1.4": {
"clip": (torch.randint(1, 2, (2, 77)),),
"v1_4": {
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
"vae": (torch.randn(1, 4, 64, 64),),
"unet": (
torch.randn(1, 4, 64, 64),
torch.tensor([1]).to(torch.float32), # timestep
torch.randn(2, 77, 768),
torch.randn(2, model_clip_max_length, 768),
torch.tensor(1).to(torch.float32),
),
},
}
# revision param for from_pretrained defaults to "main" => fp32
model_revision = "fp16" if args.precision == "fp16" else "main"
model_revision = {
"stablediffusion": "fp16" if args.precision == "fp16" else "main",
"anythingv3": "diffusers",
"analogdiffusion": "main",
"openjourney": "main",
"dreamlike": "main",
}
def get_clip_mlir(model_name="clip_text", extra_args=[]):
@@ -52,10 +73,25 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14"
)
if args.version != "v1.4":
if args.variant == "stablediffusion":
if args.version != "v1_4":
text_encoder = CLIPTextModel.from_pretrained(
model_config[args.version], subfolder="text_encoder"
)
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
text_encoder = CLIPTextModel.from_pretrained(
model_config[args.version], subfolder="text_encoder"
model_variant[args.variant],
subfolder="text_encoder",
revision=model_revision[args.variant],
)
else:
raise ValueError(f"{args.variant} not yet added")
class CLIPText(torch.nn.Module):
def __init__(self):
@@ -75,31 +111,49 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
return shark_clip
def get_vae_mlir(model_name="vae", extra_args=[]):
class VaeModel(torch.nn.Module):
def get_base_vae_mlir(model_name="vae", extra_args=[]):
class BaseVaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_config[args.version],
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="vae",
revision=model_revision,
revision=model_revision[args.variant],
)
def forward(self, input):
x = self.vae.decode(input, return_dict=False)[0]
return (x / 2 + 0.5).clamp(0, 1)
vae = VaeModel()
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
vae = BaseVaeModel()
if args.variant == "stablediffusion":
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
else:
inputs = model_input[args.version]["vae"]
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
)
else:
inputs = model_input["v1_4"]["vae"]
else:
inputs = model_input[args.version]["vae"]
raise ValueError(f"{args.variant} not yet added")
shark_vae = compile_through_fx(
vae,
@@ -110,25 +164,53 @@ def get_vae_mlir(model_name="vae", extra_args=[]):
return shark_vae
def get_vae_encode_mlir(model_name="vae_encode", extra_args=[]):
class VaeEncodeModel(torch.nn.Module):
def get_vae_mlir(model_name="vae", extra_args=[]):
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_config[args.version],
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="vae",
revision="fp16",
revision=model_revision[args.variant],
)
def forward(self, x):
input = 2 * (x - 0.5)
return self.vae.encode(input, return_dict=False)[0]
def forward(self, input):
input = 1 / 0.18215 * input
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
x = x * 255.0
return x.round()
vae = VaeModel()
if args.variant == "stablediffusion":
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
else:
inputs = model_input[args.version]["vae"]
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
)
else:
inputs = model_input["v1_4"]["vae"]
else:
raise ValueError(f"{args.variant} not yet added")
vae = VaeEncodeModel()
vae = vae.half().cuda()
inputs = tuple(
[inputs.half().cuda() for inputs in model_input[args.version]["vae"]]
)
shark_vae = compile_through_fx(
vae,
inputs,
@@ -143,9 +225,11 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
def __init__(self):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_config[args.version],
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="unet",
revision=model_revision,
revision=model_revision[args.variant],
)
self.in_channels = self.unet.in_channels
self.train(False)
@@ -163,16 +247,35 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
return noise_pred
unet = UnetModel()
if args.precision == "fp16":
unet = unet.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input[args.version]["unet"]
]
)
if args.variant == "stablediffusion":
if args.precision == "fp16":
unet = unet.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input[args.version]["unet"]
]
)
else:
inputs = model_input[args.version]["unet"]
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
if args.precision == "fp16":
unet = unet.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input["v1_4"]["unet"]
]
)
else:
inputs = model_input["v1_4"]["unet"]
else:
inputs = model_input[args.version]["unet"]
raise ValueError(f"{args.variant} is not yet added")
shark_unet = compile_through_fx(
unet,
inputs,

View File

@@ -1,190 +1,111 @@
import sys
from model_wrappers import (
get_base_vae_mlir,
get_vae_mlir,
get_vae_encode_mlir,
get_unet_mlir,
get_clip_mlir,
)
from resources import models_db
from stable_args import args
from utils import get_shark_model
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
BATCH_SIZE = len(args.prompts)
if BATCH_SIZE != 1:
sys.exit("Only batch size 1 is supported.")
# use tuned models only in the case of rdna3 cards.
if not args.iree_vulkan_target_triple:
vulkan_triple_flags = get_vulkan_triple_flag()
if vulkan_triple_flags and "rdna3" not in vulkan_triple_flags:
args.use_tuned = False
elif "rdna3" not in args.iree_vulkan_target_triple:
args.use_tuned = False
if args.use_tuned:
print("Using tuned models for rdna3 card")
def get_unet():
def get_params(bucket_key, model_key):
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
# Tuned model is present for `fp16` precision.
if args.precision == "fp16":
if args.use_tuned:
bucket = "gs://shark_tank/vivian"
if args.version == "v1.4":
model_name = "unet_1dec_fp16_tuned"
if args.version == "v2.1base":
model_name = "unet2base_8dec_fp16_tuned_v2"
return get_shark_model(bucket, model_name, iree_flags)
else:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "unet_8dec_fp16"
if args.version == "v2.1base":
model_name = "unet2base_8dec_fp16"
if args.version == "v2.1":
model_name = "unet2_14dec_fp16"
try:
bucket = models_db[0][bucket_key]
model_name = models_db[1][model_key]
except KeyError:
raise Exception(
f"{bucket}/{model_key} is not present in the models database"
)
return bucket, model_name, iree_flags
def get_unet():
# Tuned model is present only for `fp16` precision.
is_tuned = "/tuned" if args.use_tuned else "/untuned"
bucket_key = f"{args.variant}{is_tuned}"
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}{is_tuned}"
bucket, model_name, iree_flags = get_params(bucket_key, model_key)
if args.use_tuned:
return get_shark_model(bucket, model_name, iree_flags)
else:
if args.precision == "fp16":
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
if args.import_mlir:
return get_unet_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
# Tuned model is not present for `fp32` case.
if args.precision == "fp32":
bucket = "gs://shark_tank/stable_diffusion"
model_name = "unet_1dec_fp32"
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
if args.device == "cuda":
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform"
]
else:
iree_flags += ["--iree-flow-enable-conv-img2col-transform"]
elif args.precision == "fp32":
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
if args.import_mlir:
return get_unet_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
if args.precision == "int8":
bucket = "gs://shark_tank/prashant_nod"
model_name = "unet_int8"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
]
sys.exit("int8 model is currently in maintenance.")
# # TODO: Pass iree_flags to the exported model.
# if args.import_mlir:
# sys.exit(
# "--import_mlir is not supported for the int8 model, try --no-import_mlir flag."
# )
# return get_shark_model(bucket, model_name, iree_flags)
def get_vae():
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if args.precision in ["fp16", "int8"]:
if args.use_tuned:
bucket = "gs://shark_tank/vivian"
if args.version == "v2.1base":
model_name = "vae2base_8dec_fp16_tuned"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
"--iree-flow-enable-conv-winograd-transform",
]
return get_shark_model(bucket, model_name, iree_flags)
else:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_8dec_fp16"
if args.version == "v2.1base":
model_name = "vae2base_8dec_fp16"
if args.version == "v2.1":
model_name = "vae2_14dec_fp16"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
if args.import_mlir:
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
if args.precision == "fp32":
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_1dec_fp32"
# Tuned model is present only for `fp16` precision.
is_tuned = "/tuned" if args.use_tuned else "/untuned"
is_base = "/base" if args.use_base_vae else ""
bucket_key = f"{args.variant}{is_tuned}"
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77{is_tuned}{is_base}"
bucket, model_name, iree_flags = get_params(bucket_key, model_key)
if args.use_tuned:
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
if args.import_mlir:
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
def get_vae_encode():
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
if args.precision in ["fp16", "int8"]:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_encode_1dec_fp16"
if args.version == "v2":
model_name = "vae2_encode_29nov_fp16"
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
"--iree-flow-enable-conv-winograd-transform",
]
if args.import_mlir:
return get_vae_encode_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
if args.precision == "fp32":
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_encode_1dec_fp32"
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
else:
if args.precision == "fp16":
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
elif args.precision == "fp32":
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
if args.import_mlir:
if args.use_base_vae:
return get_base_vae_mlir(model_name, iree_flags)
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
def get_clip():
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
bucket = "gs://shark_tank/stable_diffusion"
model_name = "clip_18dec_fp32"
if args.version == "v2.1base":
model_name = "clip2base_18dec_fp32"
if args.version == "v2.1":
model_name = "clip2_18dec_fp32"
bucket_key = f"{args.variant}/untuned"
model_key = f"{args.variant}/{args.version}/clip/fp32/length_{args.max_length}/untuned"
bucket, model_name, iree_flags = get_params(bucket_key, model_key)
iree_flags += [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops",

View File

@@ -0,0 +1,31 @@
import os
import json
import sys
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
prompt_examples = []
prompts_loc = resource_path("resources/prompts.json")
if os.path.exists(prompts_loc):
with open(prompts_loc, encoding="utf-8") as fopen:
prompt_examples = json.load(fopen)
if not prompt_examples:
print("Unable to fetch prompt examples.")
models_db = []
models_loc = resource_path("resources/model_db.json")
if os.path.exists(models_loc):
with open(models_loc, encoding="utf-8") as fopen:
models_db = json.load(fopen)
if len(models_db) != 2:
sys.exit("Error: Unable to load models database.")

View File

@@ -0,0 +1,68 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/stable_diffusion",
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
"openjourney/tuned":"gs://shark_tank/sd_tuned",
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
},
{
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_1dec_fp16_tuned",
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32",
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet2base_8dec_fp16",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_19dec_v2p1base_fp16_64",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae2base_19dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip2base_18dec_fp32",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_19dec_v2p1base_fp32_64",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet2_14dec_fp16",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae2_19dec_fp16",
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip2_18dec_fp32",
"anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
"anythingv3/v2_1base/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
"anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
"anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
"anythingv3/v2_1base/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
"anythingv3/v2_1base/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
"anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
"anythingv3/v2_1base/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
"anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
"analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
"analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
"analogdiffusion/v2_1base/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
"openjourney/v2_1base/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
"openjourney/v2_1base/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
"openjourney/v2_1base/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
"openjourney/v2_1base/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
"openjourney/v2_1base/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
"openjourney/v2_1base/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
"openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
"dreamlike/v2_1base/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
"dreamlike/v2_1base/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
"dreamlike/v2_1base/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
"dreamlike/v2_1base/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
"dreamlike/v2_1base/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
"dreamlike/v2_1base/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
"dreamlike/v2_1base/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
}
]

View File

@@ -46,8 +46,8 @@ p.add_argument(
p.add_argument(
"--max_length",
type=int,
default=77,
help="max length of the tokenizer output.",
default=64,
help="max length of the tokenizer output, options are 64 and 77.",
)
##############################################################################
@@ -61,7 +61,7 @@ p.add_argument(
p.add_argument(
"--version",
type=str,
default="v2.1base",
default="v2_1base",
help="Specify version of stable diffusion model",
)
@@ -92,11 +92,31 @@ p.add_argument(
p.add_argument(
"--use_tuned",
default=False,
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
)
p.add_argument(
"--use_base_vae",
default=False,
action=argparse.BooleanOptionalAction,
help="Do conversion from the VAE output to pixel space on cpu.",
)
p.add_argument(
"--variant",
default="stablediffusion",
help="We now support multiple vairants of SD finetuned for different dataset. you can use the following anythingv3, ...", # TODO add more once supported
)
p.add_argument(
"--scheduler",
type=str,
default="SharkEulerDiscrete",
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
@@ -117,7 +137,7 @@ p.add_argument(
p.add_argument(
"--vulkan_large_heap_block_size",
default="2147483648",
default="4147483648",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
@@ -134,7 +154,7 @@ p.add_argument(
p.add_argument(
"--use_compiled_scheduler",
default=False,
default=True,
action=argparse.BooleanOptionalAction,
help="use the default scheduler precompiled into the model if available",
)
@@ -173,9 +193,34 @@ p.add_argument(
p.add_argument(
"--hide_steps",
default=False,
default=True,
action=argparse.BooleanOptionalAction,
help="flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="flag setting warmup count for clip and vae [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
)
##############################################################################
### Web UI flags
##############################################################################
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the pregress bar animation during image generation",
)
args = p.parse_args()

View File

@@ -1,26 +1,48 @@
# Stable Diffusion optimized for AMD RDNA2/RDNA3 GPUs
Before you start, please be aware that this is beta software that relies on a special AMD driver. Like all StableDiffusion GUIs published so far, you need some technical expertise to set it up. We apologize in advance if you bump into issues. If that happens, please don't hesitate to ask our Discord community for help! If you still can't get it to work, we're sorry, and please be assured that we (Nod and AMD) are working hard to improve the user experience in coming months.
If it works well for you, please "star" the following GitHub projects... this is one of the best ways to help and spread the word!
* https://github.com/nod-ai/SHARK
* https://github.com/iree-org/iree
## Install the latest AMD Drivers
### RDNA2 Drivers:
### AMD KB Drivers for RDNA2 and RDNA3:
*AMD Software: Adrenalin Edition 22.11.1 for MLIR/IREE Driver Version 22.20.29.09 for Windows® 10 and Windows® 11 (Windows Driver Store Version 31.0.12029.9003)*
First, download this special driver in a folder of your choice. We recommend you keep that driver around since you may need to re-install it later, if Windows Update decides to overwrite it:
https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mlir-iree
Note that if you previously tried Stable Diffusion with a different driver, it may be necessary to clear vulkan cache after changing drivers.
For Windows users this can be done by clearing the contents of `C:\Users\<username>\AppData\Local\AMD\VkCache\`. On Linux the same cache is typically located at `~/.cache/AMD/VkCache/`.
KNOWN ISSUES with this special AMD driver:
* `Windows Update` may (depending how it's configured) automatically install a new official AMD driver that overwrites this IREE-specific driver. If Stable Diffusion used to work, then a few days later, it slows down a lot or produces incorrect results (e.g. black images), this may be the cause. To fix this problem, please check the installed driver's version, and re-install the special driver if needed. (TODO: document how to prevent this `Windows Update` behavior!)
* Some people using this special driver experience mouse pointer accuracy issues, if you use a larger-than-default mouse pointer. The clicked point isn't centered properly. One possible work-around is to reset the pointer size to "1" in "Change pointer size and color".
## Installation
Download the latest Windows SHARK SD binary [here](https://github.com/nod-ai/SHARK/releases/download/20221216.392/shark_sd_20221216_392.exe). Accept if Windows warns of an unsigned .exe.
Download the latest Windows SHARK SD binary [here](https://github.com/nod-ai/SHARK/releases/download/20221220.400/shark_sd_20221220_400.exe) in a folder of your choice. Please read carefully the following notes:
Notes:
* Your browser may warn you about downloading a exe file
* The first run may take about 10-15 minutes when the models are downloaded and compiled. The download could be about 5GB.
* We recommend that you download this EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files. Those contain Vulkan dispatches compiled from MLIR, that can get outdated if you run multiple EXE from the same folder.
* Your browser may warn you about downloading an .exe file
* If you recently updated the driver or this binary (EXE file), we recommend you:
* clear the Vulkan shader cache: For Windows users this can be done by clearing the contents of `C:\Users\<username>\AppData\Local\AMD\VkCache\`. On Linux the same cache is typically located at `~/.cache/AMD/VkCache/`.
* clear the `huggingface` cache. In Windows, this is `C:\Users\<username>\.cache\huggingface`.
#### Access Stable Diffusion on http://localhost:8080/?__theme=dark
## Running
* Open a Command Prompt or Powershell terminal, change folder (`cd`) to the .exe folder. Then run the EXE from the command prompt. That way, if an error occurs, you'll be able to cut-and-paste it to ask for help. (if it always works for you without error, you may simply double-click the EXE to start the web browser)
* The first run may take about 10-15 minutes when the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
* If successful, you will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/?__theme=dark.
## Stopping
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment. The application should stop.
* Please make sure to do the above step before you attempt to update the EXE to a new version.
# Results
<img width="1607" alt="webui" src="https://user-images.githubusercontent.com/74956/204939260-b8308bc2-8dc4-47f6-9ac0-f60b66edab99.png">
@@ -35,6 +57,7 @@ Here are some samples generated:
<details>
<summary>Advanced Installation </summary>
## Setup your Python VirtualEnvironment and Dependencies
### Windows 10/11 Users

View File

@@ -1,10 +1,12 @@
import os
import torch
from shark.shark_inference import SharkInference
from stable_args import args
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
def _compile_module(shark_module, model_name, extra_args=[]):
@@ -82,7 +84,149 @@ def set_iree_runtime_flags():
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
if "vulkan" in args.device:
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
return
def get_all_devices(driver_name):
"""
Inputs: driver_name
Returns a list of all the available devices for a given driver sorted by
the iree path names of the device as in --list_devices option in iree.
"""
from iree.runtime import get_driver
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return device_list_src
def get_device_mapping(driver, key_combination=3):
"""This method ensures consistent device ordering when choosing
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
driver = iree_device_map(driver)
device_list = get_all_devices(driver)
device_map = dict()
def get_output_value(dev_dict):
if key_combination == 1:
return f"{driver}://{dev_dict['path']}"
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
for i, device in enumerate(device_list):
# mapping with index
device_map[f"{driver}://{i}"] = get_output_value(device)
# mapping with full path
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
return device_map
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
try:
device_mapping = device_map[device]
except KeyError:
raise ValueError(f"Device '{device}' is not a valid device.")
return device_mapping
def set_init_device_flags():
if "vulkan" in args.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
args.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
)
elif "cuda" in args.device:
args.device = "cuda"
elif "cpu" in args.device:
args.device = "cpu"
# set max_length based on availability.
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
args.max_length = 77
elif args.variant == "openjourney":
args.max_length = 64
# use tuned models only in the case of stablediffusion/fp16 and rdna3 cards.
if (
args.variant in ["openjourney", "dreamlike"]
or args.precision != "fp16"
or "vulkan" not in args.device
or "rdna3" not in args.iree_vulkan_target_triple
):
args.use_tuned = False
print("Tuned models are currently not supported for this setting.")
elif args.use_base_vae and args.variant != "stablediffusion":
args.use_tuned = False
print("Tuned models are currently not supported for this setting.")
if args.use_tuned:
print("Using tuned models for stablediffusion/fp16 and rdna3 card.")
# Utility to get list of devices available.
def get_available_devices():
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except:
print(f"{driver_name} devices are not available.")
else:
for i, device in enumerate(device_list_dict):
device_list.append(f"{driver_name}://{i} => {device['name']}")
return device_list
set_iree_runtime_flags()
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("cpu")
return available_devices

View File

@@ -23,21 +23,27 @@ import re
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
if "://" in device:
device = device.split("://")[0]
if device == "cpu":
device_uri = device.split("://")
if len(device_uri) > 1:
if device_uri[0] not in ["vulkan"]:
print(
f"Specific device selection only supported for vulkan now."
f"Proceeding with {device} as device."
)
if device_uri[0] == "cpu":
from shark.iree_utils.cpu_utils import get_iree_cpu_args
return get_iree_cpu_args()
if device == "cuda":
if device_uri[0] == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
return get_iree_gpu_args()
if device in ["metal", "vulkan"]:
if device_uri[0] in ["metal", "vulkan"]:
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args(extra_args=extra_args)
if device == "rocm":
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args()

View File

@@ -26,9 +26,10 @@ def get_vulkan_device_name():
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
if len(vulkaninfo_list) > 1:
print(
f"Found {len(vulkaninfo_list)} device names. choosing first one: {vulkaninfo_list[0]}"
)
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing first one: {vulkaninfo_list[0]}")
return vulkaninfo_list[0]
@@ -44,81 +45,83 @@ def get_os_name():
return "linux"
def get_vulkan_triple_flag(extra_args=[]):
if "-iree-vulkan-target-triple=" in " ".join(extra_args):
print(f"Using target triple from command line args")
return None
def get_vulkan_target_triple(device_name):
"""This method provides a target triple str for specified vulkan device.
Args:
device_name (str): name of the hardware device to be used with vulkan
Returns:
str or None: target triple or None if no match found for given name
"""
system_os = get_os_name()
vulkan_device = get_vulkan_device_name()
# Apple Targets
if all(x in vulkan_device for x in ("Apple", "M1")):
print(f"Found {vulkan_device} Device. Using m1-moltenvk-macos")
return "-iree-vulkan-target-triple=m1-moltenvk-macos"
elif all(x in vulkan_device for x in ("Apple", "M2")):
print("Found Apple M2 Device. Using m1-moltenvk-macos")
return "-iree-vulkan-target-triple=m1-moltenvk-macos"
if all(x in device_name for x in ("Apple", "M1")):
triple = "m1-moltenvk-macos"
elif all(x in device_name for x in ("Apple", "M2")):
triple = "m1-moltenvk-macos"
# Nvidia Targets
elif all(x in vulkan_device for x in ("RTX", "2080")):
print(
f"Found {vulkan_device} Device. Using turing-rtx2080-{system_os}"
)
return f"-iree-vulkan-target-triple=turing-rtx2080-{system_os}"
elif all(x in vulkan_device for x in ("A100", "SXM4")):
print(
f"Found {vulkan_device} Device. Using ampere-rtx3080-{system_os}"
)
return f"-iree-vulkan-target-triple=ampere-rtx3080-{system_os}"
elif all(x in vulkan_device for x in ("RTX", "3090")):
print(
f"Found {vulkan_device} Device. Using ampere-rtx3090-{system_os}"
)
return f"-iree-vulkan-target-triple=ampere-rtx3090-{system_os}"
elif all(x in vulkan_device for x in ("RTX", "4090")):
print(
f"Found {vulkan_device} Device. Using ampere-rtx3090-{system_os}"
)
return f"-iree-vulkan-target-triple=ampere-rtx3090-{system_os}"
elif all(x in vulkan_device for x in ("RTX", "4000")):
print(
f"Found {vulkan_device} Device. Using turing-rtx4000-{system_os}"
)
return f"-iree-vulkan-target-triple=turing-rtx4000-{system_os}"
elif all(x in vulkan_device for x in ("RTX", "5000")):
print(
f"Found {vulkan_device} Device. Using turing-rtx5000-{system_os}"
)
return f"-iree-vulkan-target-triple=turing-rtx5000-{system_os}"
elif all(x in vulkan_device for x in ("RTX", "6000")):
print(
f"Found {vulkan_device} Device. Using turing-rtx6000-{system_os}"
)
return f"-iree-vulkan-target-triple=turing-rtx6000-{system_os}"
elif all(x in vulkan_device for x in ("RTX", "8000")):
print(
f"Found {vulkan_device} Device. Using turing-rtx8000-{system_os}"
)
return f"-iree-vulkan-target-triple=turing-rtx8000-{system_os}"
elif all(x in device_name for x in ("RTX", "2080")):
triple = f"turing-rtx2080-{system_os}"
elif all(x in device_name for x in ("A100", "SXM4")):
triple = f"ampere-rtx3080-{system_os}"
elif all(x in device_name for x in ("RTX", "3090")):
triple = f"ampere-rtx3090-{system_os}"
elif all(x in device_name for x in ("RTX", "4090")):
triple = f"ampere-rtx3090-{system_os}"
elif all(x in device_name for x in ("RTX", "4000")):
triple = f"turing-rtx4000-{system_os}"
elif all(x in device_name for x in ("RTX", "5000")):
triple = f"turing-rtx5000-{system_os}"
elif all(x in device_name for x in ("RTX", "6000")):
triple = f"turing-rtx6000-{system_os}"
elif all(x in device_name for x in ("RTX", "8000")):
triple = f"turing-rtx8000-{system_os}"
elif all(x in device_name for x in ("GTX", "1060")):
triple = f"pascal-gtx1060-{system_os}"
elif all(x in device_name for x in ("GTX", "1070")):
triple = f"pascal-gtx1070-{system_os}"
elif all(x in device_name for x in ("GTX", "1080")):
triple = f"pascal-gtx1080-{system_os}"
# Amd Targets
elif all(x in vulkan_device for x in ("AMD", "7900")):
print(f"Found {vulkan_device} Device. Using rdna3-7900-{system_os}")
return f"-iree-vulkan-target-triple=rdna3-7900-{system_os}"
elif any(x in vulkan_device for x in ("AMD", "Radeon")):
print(f"Found AMD device. Using rdna2-unknown-{system_os}")
return f"-iree-vulkan-target-triple=rdna2-unknown-{system_os}"
elif all(x in device_name for x in ("AMD", "7900")):
triple = f"rdna3-7900-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
triple = f"rdna2-unknown-{system_os}"
else:
triple = None
return triple
def get_vulkan_triple_flag(device_name=None, extra_args=[]):
for flag in extra_args:
if "-iree-vulkan-target-triple=" in flag:
print(f"Using target triple {flag.split('=')[1]}")
return None
vulkan_device = (
device_name if device_name is not None else get_vulkan_device_name()
)
triple = get_vulkan_target_triple(vulkan_device)
if triple is not None:
print(
"""Optimized kernel for your target device is not added yet.
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
or pull up an issue."""
f"Found vulkan device {vulkan_device}. Using target triple {triple}"
)
print(f"Target : {vulkan_device}")
return None
return f"-iree-vulkan-target-triple={triple}"
print(
"""Optimized kernel for your target device is not added yet.
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
or pull up an issue."""
)
print(f"Target : {vulkan_device}")
return None
def get_iree_vulkan_args(extra_args=[]):
# vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
vulkan_flag = []
vulkan_triple_flag = get_vulkan_triple_flag(extra_args)
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
if vulkan_triple_flag is not None:
vulkan_flag.append(vulkan_triple_flag)
return vulkan_flag

View File

@@ -22,7 +22,7 @@ from shark.model_annotation import model_annotation
with create_context() as ctx:
module = model_annotation(ctx, input_contents=..., config_path=..., search_op=...)
2. Run model_annotation.py directly
python model_annotation.py path_to_original_mlir path_to_config_file
python model_annotation.py -model path_to_original_mlir -config_path path_to_config_file
"""
import json
@@ -39,21 +39,18 @@ def model_annotation(
*,
input_contents: str,
config_path: str,
search_op: str = "matmul",
search_op: str,
):
if os.path.isfile(input_contents):
with open(input_contents, "rb") as f:
input_contents = f.read()
module = ir.Module.parse(input_contents)
with open(config_path, "r") as f:
data = json.load(f)
configs = data["options"]
configs = load_model_configs(config_path)
# The Python API does not expose a general walk() function, so we just
# do it ourselves.
walk_children(module.operation, configs, 0, search_op)
walk_children(module.operation, configs, search_op)
if not module.operation.verify():
raise RuntimeError("Modified program does not verify!")
@@ -61,15 +58,49 @@ def model_annotation(
return module
def walk_children(
op: ir.Operation, configs: List[Dict], idx: int, search_op: str
):
def load_model_configs(config_path: str):
config = {}
with open(config_path, "r") as f:
for line in f:
data = json.loads(line)
if "identifier" not in data.keys():
continue
if data["identifier"] == "matmul":
matrix_size = [data["m"], data["n"], data["k"]]
elif data["identifier"] == "bmm":
matrix_size = [data["b"], data["m"], data["n"], data["k"]]
elif data["identifier"] == "generic":
matrix_size = [1, data["b"], data["m"], data["n"], data["k"]]
elif data["identifier"] == "conv":
matrix_size = [
data["n"],
data["ih"],
data["iw"],
data["c"],
data["kh"],
data["kw"],
data["f"],
data["oh"],
data["ow"],
data["d"],
data["s"],
data["p"],
]
config[shape_list_to_string(matrix_size)] = data
f.close()
return config
def walk_children(op: ir.Operation, configs: List[Dict], search_op: str):
if search_op == "matmul":
op_names = ["linalg.matmul", "mhlo.dot"]
elif search_op == "bmm":
op_names = ["linalg.batch_matmul", "mhlo.dot_general"]
elif search_op == "conv":
op_names = ["mhlo.convolution", "linalg.conv_2d_nhwc_hwcf"]
elif search_op == "generic":
op_names = ["linalg.generic"]
elif search_op == "all":
op_names = [
"mhlo.dot",
@@ -78,6 +109,7 @@ def walk_children(
"linalg.matmul",
"linalg.batch_matmul",
"linalg.conv_2d_nhwc_hwcf",
"linalg.generic",
]
else:
raise ValueError(f"{search_op} op is not tunable.")
@@ -89,37 +121,167 @@ def walk_children(
# 'operation' and 'name' attributes.
if isinstance(child_op, ir.OpView):
child_op = child_op.operation
if child_op.name in op_names and idx < len(configs):
add_attributes(child_op, configs[idx])
idx = idx + 1
if child_op.name in op_names:
if child_op.name == "linalg.generic":
# This is for generic op that has contractionOpInterface
# which is basically einsum("mk,bkn->bmn")
op_result = str(child_op.results[0])
op_iterator = str(
child_op.attributes["iterator_types"]
)
if len(child_op.operands) != 3:
continue
if "reduction" not in op_iterator:
continue
if (
"arith.addf" not in op_result
or "arith.mulf" not in op_result
):
continue
if "arith.subf" in op_result:
continue
child_op_shape = get_op_shape(child_op, search_op)
if (
child_op_shape in configs.keys()
and configs[child_op_shape]["options"][0] != None
):
add_attributes(
child_op, configs[child_op_shape]["options"][0]
)
print(f"Updated op {child_op}", file=sys.stderr)
walk_children(child_op, configs, idx, search_op)
walk_children(child_op, configs, search_op)
def add_attributes(op: ir.Operation, config: Dict):
(
tile_sizes,
pipeline,
workgroup_size,
split_k,
pipeline_depth,
) = parse_config(config)
def get_op_shape(op: ir.Operation, search_op: str):
shape_list = []
if search_op in ["generic", "all"]:
if op.name in ["linalg.generic"]:
input1 = str(op.operands[0].type)
input2 = str(op.operands[1].type)
m = input1.split("tensor<")[1].split("x")[0]
b = input2.split("tensor<")[1].split("x")[0]
k = input2.split("tensor<")[1].split("x")[1]
n = input2.split("tensor<")[1].split("x")[2]
shape_list = [1, int(b), int(m), int(n), int(k)]
add_compilation_info(
op,
tile_sizes=tile_sizes,
pipeline=pipeline,
workgroup_size=workgroup_size,
pipeline_depth=pipeline_depth,
)
if search_op in ["matmul", "all"]:
if op.name in ["mhlo.dot"]:
op_result = str(op.results[0])
m = op_result.split("tensor<")[1].split("x")[0]
k = op_result.split("tensor<")[1].split("x")[1]
n = op_result.split("tensor<")[2].split("x")[1]
shape_list = [int(m), int(n), int(k)]
elif op.name in ["linalg.matmul"]:
op_result = str(op.results[0]).split("ins(")[1]
m = op_result.split("tensor<")[1].split("x")[0]
k = op_result.split("tensor<")[1].split("x")[1]
n = op_result.split("tensor<")[2].split("x")[1]
shape_list = [int(m), int(n), int(k)]
if split_k:
add_attribute_by_name(op, "iree_flow_split_k", split_k)
if search_op in ["bmm", "all"]:
if op.name in ["mhlo.dot_general"]:
op_result = str(op.results[0])
b = op_result.split("tensor<")[1].split("x")[1]
m = op_result.split("tensor<")[1].split("x")[2]
k = op_result.split("tensor<")[1].split("x")[3]
n = op_result.split("tensor<")[3].split("x")[3]
shape_list = [int(b), int(m), int(n), int(k)]
elif op.name in ["linalg.batch_matmul"]:
op_result = str(op.results[0]).split("ins(")[1]
b = op_result.split("tensor<")[1].split("x")[0]
m = op_result.split("tensor<")[1].split("x")[1]
k = op_result.split("tensor<")[1].split("x")[2]
n = op_result.split("tensor<")[3].split("x")[2]
shape_list = [int(b), int(m), int(n), int(k)]
if search_op in ["conv", "all"]:
if op.name in ["mhlo.convolution"]:
op_result = str(op.results[0])
dilation = (
str(op.attributes["rhs_dilation"])
.split("dense<")[1]
.split(">")[0]
)
stride = (
str(op.attributes["window_strides"])
.split("dense<")[1]
.split(">")[0]
)
pad = (
str(op.attributes["padding"]).split("dense<")[1].split(">")[0]
)
n = op_result.split("tensor<")[1].split("x")[0]
ih = op_result.split("tensor<")[1].split("x")[1]
iw = op_result.split("tensor<")[1].split("x")[2]
c = op_result.split("tensor<")[1].split("x")[3]
kh = op_result.split("tensor<")[2].split("x")[0]
kw = op_result.split("tensor<")[2].split("x")[1]
f = op_result.split("tensor<")[2].split("x")[3]
oh = op_result.split("tensor<")[3].split("x")[1]
ow = op_result.split("tensor<")[3].split("x")[2]
shape_list = [
int(n),
int(ih),
int(iw),
int(c),
int(kh),
int(kw),
int(f),
int(oh),
int(ow),
int(dilation),
int(stride),
int(pad),
]
elif op.name in ["linalg.conv_2d_nhwc_hwcf"]:
op_result = str(op.results[0]).split("ins(")[1]
dilation = (
str(op.attributes["dilations"])
.split("dense<")[1]
.split(">")[0]
)
stride = (
str(op.attributes["strides"]).split("dense<")[1].split(">")[0]
)
pad = 0
n = op_result.split("tensor<")[1].split("x")[0]
ih = op_result.split("tensor<")[1].split("x")[1]
iw = op_result.split("tensor<")[1].split("x")[2]
c = op_result.split("tensor<")[1].split("x")[3]
kh = op_result.split("tensor<")[2].split("x")[0]
kw = op_result.split("tensor<")[2].split("x")[1]
f = op_result.split("tensor<")[2].split("x")[3]
oh = op_result.split("tensor<")[3].split("x")[1]
ow = op_result.split("tensor<")[3].split("x")[2]
shape_list = [
int(n),
int(ih),
int(iw),
int(c),
int(kh),
int(kw),
int(f),
int(oh),
int(ow),
int(dilation),
int(stride),
int(pad),
]
shape_str = shape_list_to_string(shape_list)
return shape_str
def parse_config(config: Dict):
def add_attributes(op: ir.Operation, config: List[Dict]):
# Parse the config file
split_k = None
pipeline_depth = None
store_stage = None
subgroup_size = None
if "GPU" in config["pipeline"]:
pipeline = (
"LLVMGPUMatmulSimt"
@@ -132,6 +294,10 @@ def parse_config(config: Dict):
pipeline_depth = config["pipeline_depth"]
if "split_k" in config.keys():
split_k = config["split_k"]
if "devices" in config.keys():
devices = config["devices"]
if "shard_sizes" in config.keys():
shard_sizes = config["shard_sizes"]
elif "SPIRV" in config["pipeline"]:
pipeline = config["pipeline"]
tile_sizes = [
@@ -139,11 +305,17 @@ def parse_config(config: Dict):
config["parallel_tile_sizes"],
config["reduction_tile_sizes"],
]
workgroup_size = config["work_group_sizes"]
if "vector_tile_sizes" in config.keys():
tile_sizes += [config["vector_tile_sizes"]]
if "window_tile_sizes" in config.keys():
tile_sizes += [config["window_tile_sizes"]]
workgroup_size = config["work_group_sizes"]
if "subgroup_size" in config.keys():
subgroup_size = config["subgroup_size"]
if "pipeline_depth" in config.keys():
pipeline_depth = config["pipeline_depth"]
if "store_stage" in config.keys():
store_stage = config["store_stage"]
else:
# For IREE CPU pipelines
pipeline = config["pipeline"]
@@ -153,40 +325,45 @@ def parse_config(config: Dict):
config["reduction_tile_sizes"],
]
workgroup_size = []
return tile_sizes, pipeline, workgroup_size, split_k, pipeline_depth
def add_compilation_info(
op: ir.Operation,
tile_sizes: List[List[int]],
pipeline: str,
workgroup_size: List[int],
pipeline_depth: int,
):
# We don't have a Python binding for CompilationInfo, so we just parse
# its string form.
if pipeline_depth:
attr = ir.Attribute.parse(
f"#iree_codegen.compilation_info<"
f"lowering_config = <tile_sizes = {repr(tile_sizes)}>, "
f"translation_info = <{pipeline} pipeline_depth = {pipeline_depth}>, "
f"workgroup_size = {repr(workgroup_size)}>"
)
# Add compilation info as an attribute. We don't have a Python binding for CompilationInfo,
# so we just parse its string form.
if pipeline_depth != None:
translation_info = f"{pipeline} pipeline_depth = {pipeline_depth}"
if store_stage != None:
translation_info += f" store_stage = {store_stage}"
else:
attr = ir.Attribute.parse(
f"#iree_codegen.compilation_info<"
f"lowering_config = <tile_sizes = {repr(tile_sizes)}>, "
f"translation_info = <{pipeline}>, "
f"workgroup_size = {repr(workgroup_size)}>"
)
translation_info = f"{pipeline}"
compilation_info = (
f"#iree_codegen.compilation_info<"
f"lowering_config = <tile_sizes = {repr(tile_sizes)}>, "
f"translation_info = <{translation_info}>, "
f"workgroup_size = {repr(workgroup_size)} "
)
if subgroup_size != None:
compilation_info += f", subgroup_size = {subgroup_size}>"
else:
compilation_info += ">"
attr = ir.Attribute.parse(compilation_info)
op.attributes["compilation_info"] = attr
# Add other attributes if required.
if split_k:
add_attribute_by_name(op, "iree_flow_split_k", split_k)
def add_attribute_by_name(op: ir.Operation, name: str, val: int):
attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), val)
op.attributes[name] = attr
def shape_list_to_string(input):
return "x".join([str(d) for d in input])
def create_context() -> ir.Context:
context = ir.Context()
ireec_trans.register_all_dialects(context)
@@ -195,15 +372,48 @@ def create_context() -> ir.Context:
if __name__ == "__main__":
import argparse
from pathlib import Path
def path_expand(s):
return Path(s).expanduser().resolve()
parser = argparse.ArgumentParser()
parser.add_argument(
"-model",
type=path_expand,
default="model.mlir",
help="Path to the input mlir file",
)
parser.add_argument(
"-config_path",
type=path_expand,
default="best_configs.json",
help="Path where stores the op config file",
)
parser.add_argument(
"-output_path",
type=path_expand,
default="tuned_model.mlir",
help="Path to save the annotated mlir file",
)
parser.add_argument(
"-search_op",
type=str,
default="all",
help="Op to be optimized. options are matmul, bmm, conv.",
)
args = parser.parse_args()
with create_context() as ctx:
module = model_annotation(
ctx,
input_contents=sys.argv[1],
config_path=sys.argv[2],
search_op="all",
input_contents=args.model,
config_path=args.config_path,
search_op=args.search_op,
)
mlir_str = str(module)
filename = "tuned_model.mlir"
with open(filename, "w") as f:
with open(args.output_path, "w") as f:
f.write(mlir_str)
print(f"Saved mlir in {filename}.")
print(f"Saved mlir in {args.output_path}.")

View File

@@ -14,6 +14,7 @@
import numpy as np
import os
from tqdm.std import tqdm
import sys
from pathlib import Path
from shark.parser import shark_args
@@ -52,12 +53,18 @@ def download_public_file(
destination_filename = os.path.join(
destination_folder_name, dest_filename
)
blob.download_to_filename(destination_filename)
with open(destination_filename, "wb") as f:
with tqdm.wrapattr(
f, "write", total=blob.size
) as file_obj:
storage_client.download_blob_to_file(blob, file_obj)
else:
continue
destination_filename = os.path.join(destination_folder_name, blob_name)
blob.download_to_filename(destination_filename)
with open(destination_filename, "wb") as f:
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
storage_client.download_blob_to_file(blob, file_obj)
input_type_to_np_dtype = {

View File

@@ -46,20 +46,29 @@ def stress_test_compiled_model(
logging.info(
f"Running stress test {stress_test_index} on device {device}."
)
shark_module = SharkInference(
mlir_module=bytes(), function_name=function_name, device=device
)
shark_module.load_module(shark_module_path)
# All interactions with the module must run in a single thread.
# We are using execution in a sperate thread in order to be able
# to wait with a timeout on the inference operation.
module_executor = ThreadPoolExecutor(1)
shark_module = module_executor.submit(
SharkInference,
mlir_module=bytes(),
function_name=function_name,
device=device,
).result()
module_executor.submit(
shark_module.load_module, shark_module_path
).result()
input_batches = [np.repeat(arr, batch_size, axis=0) for arr in inputs]
golden_output_batches = np.repeat(golden_out, batch_size, axis=0)
report_interval_seconds = 10
start_time = time.time()
previous_report_time = start_time
executor = ThreadPoolExecutor(1)
first_iteration_output = None
for i in range(max_iterations):
inference_task = executor.submit(shark_module.forward, input_batches)
output = inference_task.result(inference_timeout_seconds)
output = module_executor.submit(
shark_module.forward, input_batches
).result(inference_timeout_seconds)
if first_iteration_output is None:
np.testing.assert_array_almost_equal_nulp(
golden_output_batches, output, nulp=tolerance_nulp
@@ -149,14 +158,24 @@ def stress_test(
if device_names is None or device_types is not None:
device_names = [] if device_names is None else device_names
with ProcessPoolExecutor() as executor:
# query_devices needs to run in a separate process,
# because it will interfere with other processes that are forked later.
device_names.extend(
executor.submit(query_devices, device_types).result()
)
device_types_set = list(set(get_device_types(device_names)))
shark_module_paths_set = compile_stress_test_module(
device_types_set, mlir_model, func_name, mlir_dialect
)
with ProcessPoolExecutor() as executor:
# This needs to run in a subprocess because when compiling for CUDA,
# some stuff get intialized and cuInit will fail in a forked process
# later. It should be just compiling, but alas.
shark_module_paths_set = executor.submit(
compile_stress_test_module,
device_types_set,
mlir_model,
func_name,
mlir_dialect,
).result()
device_type_shark_module_path_map = {
device_type: module_path
for device_type, module_path in zip(

View File

@@ -1,40 +1,17 @@
import os
os.environ["AMD_ENABLE_LLPC"] = "1"
# from models.resnet50 import resnet_inf
# from models.albert_maskfill import albert_maskfill_inf
from models.stable_diffusion.main import stable_diff_inf
# from models.diffusion.v_diffusion import vdiff_inf
import gradio as gr
from PIL import Image
import json
import os
import sys
from random import randint
import numpy as np
os.environ["AMD_ENABLE_LLPC"] = "1"
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
prompt_examples = []
prompt_loc = resource_path("prompts.json")
if os.path.exists(prompt_loc):
with open(prompt_loc, encoding="utf-8") as fopen:
prompt_examples = json.load(fopen)
from models.stable_diffusion.resources import resource_path, prompt_examples
from models.stable_diffusion.main import stable_diff_inf
from models.stable_diffusion.stable_args import args
from models.stable_diffusion.utils import get_available_devices
nodlogo_loc = resource_path("logos/nod-logo.png")
sdlogo_loc = resource_path("logos/sd-demo-logo.png")
demo_css = """
.gradio-container {background-color: black}
.container {background-color: black !important; padding-top:20px !important; }
@@ -56,6 +33,7 @@ demo_css = """
footer {display: none !important;}
"""
with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
with gr.Row(elem_id="ui_title"):
@@ -94,24 +72,18 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
lines=1,
elem_id="prompt_box",
)
with gr.Group():
ex = gr.Examples(
label="Examples",
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
with gr.Row():
steps = gr.Slider(1, 100, value=50, step=1, label="Steps")
guidance_scale = gr.Slider(
0,
50,
value=7.5,
step=0.1,
label="Guidance Scale",
variant = gr.Dropdown(
label="Model Variant",
value="stablediffusion",
choices=[
"stablediffusion",
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
],
)
with gr.Row():
scheduler_key = gr.Dropdown(
label="Scheduler",
value="SharkEulerDiscrete",
@@ -121,31 +93,44 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
"LMSDiscrete",
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
"SharkEulerDiscrete",
],
)
with gr.Group():
random_seed = gr.Button("Randomize Seed").style(
full_width=True
)
uint32_info = np.iinfo(np.uint32)
random_val = randint(uint32_info.min, uint32_info.max)
seed = gr.Number(
value=random_val, precision=0, show_label=False
)
u32_min = gr.Number(
value=uint32_info.min, visible=False
)
u32_max = gr.Number(
value=uint32_info.max, visible=False
)
random_seed.click(
None,
inputs=[u32_min, u32_max],
outputs=[seed],
_js="(min,max) => Math.floor(Math.random() * (max - min)) + min",
)
stable_diffusion = gr.Button("Generate Image")
with gr.Row():
steps = gr.Slider(1, 100, value=50, step=1, label="Steps")
guidance_scale = gr.Slider(
0,
50,
value=7.5,
step=0.1,
label="CFG Scale",
)
with gr.Row():
seed = gr.Number(value=-1, precision=0, label="Seed")
available_devices = get_available_devices()
device_key = gr.Dropdown(
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => Math.floor(Math.random() * 4294967295)",
)
stable_diffusion = gr.Button("Generate Image")
with gr.Accordion(label="Prompt Examples!"):
ex = gr.Examples(
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
generated_img = gr.Image(
@@ -166,8 +151,11 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
guidance_scale,
seed,
scheduler_key,
variant,
device_key,
],
outputs=[generated_img, std_output],
show_progress=args.progress_bar,
)
stable_diffusion.click(
stable_diff_inf,
@@ -178,8 +166,11 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
guidance_scale,
seed,
scheduler_key,
variant,
device_key,
],
outputs=[generated_img, std_output],
show_progress=args.progress_bar,
)
shark_web.queue()

View File

@@ -1,77 +0,0 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torchvision')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('diffusers')
datas += copy_metadata('transformers')
datas += collect_data_files('gradio')
datas += collect_data_files('iree')
#datas += copy_metadata('iree')
datas += collect_data_files('shark')
datas += [
( 'prompts.json', '.' ),
( 'logos/*', 'logos' )
]
block_cipher = None
a = Analysis(
['index.py'],
pathex=['.'],
binaries=[],
datas=datas,
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio'],
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
[],
exclude_binaries=True,
name='shark_sd',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)
coll = COLLECT(
exe,
a.binaries,
a.zipfiles,
a.datas,
strip=False,
upx=True,
upx_exclude=[],
name='shark_sd',
)

View File

@@ -5,63 +5,97 @@ from diffusers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
)
from models.stable_diffusion.opt_params import get_unet, get_vae, get_clip
from models.stable_diffusion.utils import set_iree_runtime_flags
from models.stable_diffusion.utils import (
set_init_device_flags,
set_iree_runtime_flags,
)
from models.stable_diffusion.stable_args import args
from models.stable_diffusion.schedulers import (
SharkEulerDiscreteScheduler,
)
# set iree-runtime flags
set_iree_runtime_flags()
model_config = {
"v2": "stabilityai/stable-diffusion-2",
"v2.1base": "stabilityai/stable-diffusion-2-1-base",
"v1.4": "CompVis/stable-diffusion-v1-4",
"v2_1": "stabilityai/stable-diffusion-2-1",
"v2_1base": "stabilityai/stable-diffusion-2-1-base",
"v1_4": "CompVis/stable-diffusion-v1-4",
}
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_config[args.version],
subfolder="scheduler",
)
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
model_config[args.version],
subfolder="scheduler",
)
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
model_config[args.version],
subfolder="scheduler",
)
schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
model_config[args.version],
subfolder="scheduler",
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_config[args.version],
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"] = SharkEulerDiscreteScheduler.from_pretrained(
model_config[args.version],
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
cache_obj = dict()
# cache vae, unet and clip.
(
cache_obj["vae"],
cache_obj["unet"],
cache_obj["clip"],
) = (get_vae(), get_unet(), get_clip())
# cache tokenizer
cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
"openai/clip-vit-large-patch14"
)
if args.version == "v2.1base":
cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
def get_schedulers(version):
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_config[version],
subfolder="scheduler",
)
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
model_config[version],
subfolder="scheduler",
)
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
model_config[version],
subfolder="scheduler",
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
model_config[version],
subfolder="scheduler",
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_config[version],
subfolder="scheduler",
)
schedulers[
"EulerAncestralDiscrete"
] = EulerAncestralDiscreteScheduler.from_pretrained(
model_config[version],
subfolder="scheduler",
)
schedulers[
"SharkEulerDiscrete"
] = SharkEulerDiscreteScheduler.from_pretrained(
model_config[version],
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
return schedulers
def get_tokenizer(version):
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
if version != "v1_4":
tokenizer = CLIPTokenizer.from_pretrained(
model_config[version], subfolder="tokenizer"
)
return tokenizer
class ModelCache:
def __init__(self):
self.device = None
self.variant = None
self.version = None
self.schedulers = None
self.tokenizer = None
def set_models(self, device_key):
if self.device != device_key or self.variant != args.variant:
self.device = device_key
self.variant = args.variant
self.version = args.version
args.device = device_key.split("=>", 1)[0].strip()
args.max_length = 64
args.use_tuned = True
set_init_device_flags()
self.schedulers = get_schedulers(args.version)
self.tokenizer = get_tokenizer(args.version)
self.vae = get_vae()
self.unet = get_unet()
self.clip = get_clip()
model_cache = ModelCache()

View File

Before

Width:  |  Height:  |  Size: 33 KiB

After

Width:  |  Height:  |  Size: 33 KiB

View File

Before

Width:  |  Height:  |  Size: 10 KiB

After

Width:  |  Height:  |  Size: 10 KiB

View File

Before

Width:  |  Height:  |  Size: 5.0 KiB

After

Width:  |  Height:  |  Size: 5.0 KiB

View File

@@ -1,22 +1,68 @@
import torch
import os
from PIL import Image
import torchvision.transforms as T
from tqdm.auto import tqdm
from models.stable_diffusion.cache_objects import (
cache_obj,
schedulers,
)
from models.stable_diffusion.cache_objects import model_cache
from models.stable_diffusion.stable_args import args
from random import randint
import numpy as np
import time
import sys
def set_ui_params(prompt, negative_prompt, steps, guidance_scale, seed):
if args.clear_all:
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
from glob import glob
import shutil
vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
home = os.path.expanduser("~")
if os.name == "nt": # Windows
appdata = os.getenv("LOCALAPPDATA")
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
elif os.name == "unix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
# Helper function to profile the vulkan device.
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
if args.vulkan_debug_utils and "vulkan" in args.device:
import iree
print(f"Profiling and saving to {file_path}.")
vulkan_device = iree.runtime.get_device(args.device)
vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path)
return vulkan_device
return None
def end_profiling(device):
if device:
return device.end_profiling()
def set_ui_params(
prompt,
negative_prompt,
steps,
guidance_scale,
seed,
scheduler_key,
variant,
):
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.steps = steps
args.guidance_scale = guidance_scale
args.guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
args.seed = seed
args.scheduler = scheduler_key
args.variant = variant
def stable_diff_inf(
@@ -26,16 +72,24 @@ def stable_diff_inf(
guidance_scale: float,
seed: int,
scheduler_key: str,
variant: str,
device_key: str,
):
# Handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
set_ui_params(prompt, negative_prompt, steps, guidance_scale, seed)
set_ui_params(
prompt,
negative_prompt,
steps,
guidance_scale,
seed,
scheduler_key,
variant,
)
dtype = torch.float32 if args.precision == "fp32" else torch.half
generator = torch.manual_seed(
args.seed
@@ -44,19 +98,31 @@ def stable_diff_inf(
# set height and width.
height = 512 # default height of Stable Diffusion
width = 512 # default width of Stable Diffusion
if args.version == "v2.1":
if args.version == "v2_1":
height = 768
width = 768
# Initialize vae and unet models.
vae, unet, clip, tokenizer = (
cache_obj["vae"],
cache_obj["unet"],
cache_obj["clip"],
cache_obj["tokenizer"],
)
scheduler = schedulers[scheduler_key]
cpu_scheduling = not scheduler_key.startswith("Shark")
# get all cached data.
model_cache.set_models(device_key)
tokenizer = model_cache.tokenizer
scheduler = model_cache.schedulers[args.scheduler]
vae, unet, clip = model_cache.vae, model_cache.unet, model_cache.clip
cpu_scheduling = not args.scheduler.startswith("Shark")
# create a random initial latent.
latents = torch.randn(
(1, 4, height // 8, width // 8),
generator=generator,
dtype=torch.float32,
).to(dtype)
# Warmup phase to improve performance.
if args.warmup_count >= 1:
vae_warmup_input = torch.clone(latents).detach().numpy()
clip_warmup_input = torch.randint(1, 2, (2, args.max_length))
for i in range(args.warmup_count):
vae.forward((vae_warmup_input,))
clip.forward((clip_warmup_input,))
start = time.time()
text_input = tokenizer(
@@ -82,19 +148,12 @@ def stable_diff_inf(
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
latents = torch.randn(
(1, 4, height // 8, width // 8),
generator=generator,
dtype=torch.float32,
).to(dtype)
scheduler.set_timesteps(args.steps)
scheduler.is_scale_input_called = True
latents = latents * scheduler.init_noise_sigma
avg_ms = 0
out_img = None
for i, t in tqdm(enumerate(scheduler.timesteps)):
step_start = time.time()
@@ -103,6 +162,7 @@ def stable_diff_inf(
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = unet.forward(
(
latent_model_input,
@@ -112,6 +172,7 @@ def stable_diff_inf(
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
@@ -121,37 +182,46 @@ def stable_diff_inf(
step_time = time.time() - step_start
avg_ms += step_time
step_ms = int((step_time) * 1000)
print(f" \nIteration = {i}, Time = {step_ms}ms")
if not args.hide_steps:
print(f" \nIteration = {i}, Time = {step_ms}ms")
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
if args.use_base_vae:
latents = 1 / 0.18215 * latents
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
image = vae.forward((latents_numpy,))
images = vae.forward((latents_numpy,))
vae_end = time.time()
image = torch.from_numpy(image)
image = (image.detach().cpu().permute(0, 2, 3, 1) * 255.0).numpy()
images = image.round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
out_img = pil_images[0]
end_profiling(profile_device)
if args.use_base_vae:
image = torch.from_numpy(images)
image = (image.detach().cpu() * 255.0).numpy()
images = image.round()
end_time = time.time()
avg_ms = 1000 * avg_ms / args.steps
total_time = time.time() - start
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, scheduler={scheduler_key}, seed={args.seed}, size={height}x{width}, version={args.version}"
text_output += "\nAverage step time: {0:.2f}ms/it".format(avg_ms)
print(f"\nAverage step time: {avg_ms}ms/it")
text_output += "\nTotal image generation time: {0:.2f}sec".format(
total_time
)
clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
vae_inf_time = (vae_end - vae_start) * 1000
total_time = end_time - start
print(f"\nAverage step time: {avg_ms}ms/it")
print(f"Clip Inference time (ms) = {clip_inf_time:.3f}")
print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
print(f"\nTotal image generation time: {total_time}sec")
return out_img, text_output
# generate outputs to web.
transform = T.ToPILImage()
pil_images = [
transform(image) for image in torch.from_numpy(images).to(torch.uint8)
]
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nvariant={args.variant}, scheduler={args.scheduler}, device={device_key}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={height}x{width}"
text_output += f"\nAverage step time: {avg_ms:.4f}ms/it"
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
return pil_images[0], text_output

View File

@@ -5,46 +5,67 @@ from models.stable_diffusion.stable_args import args
import torch
model_config = {
"v2.1": "stabilityai/stable-diffusion-2-1",
"v2.1base": "stabilityai/stable-diffusion-2-1-base",
"v1.4": "CompVis/stable-diffusion-v1-4",
"v2_1": "stabilityai/stable-diffusion-2-1",
"v2_1base": "stabilityai/stable-diffusion-2-1-base",
"v1_4": "CompVis/stable-diffusion-v1-4",
}
# clip has 2 variants of max length 77 or 64.
model_clip_max_length = 64 if args.max_length == 64 else 77
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
model_clip_max_length = 77
elif args.variant == "openjourney":
model_clip_max_length = 64
model_variant = {
"stablediffusion": "SD",
"anythingv3": "Linaqruf/anything-v3.0",
"dreamlike": "dreamlike-art/dreamlike-diffusion-1.0",
"openjourney": "prompthero/openjourney",
"analogdiffusion": "wavymulder/Analog-Diffusion",
}
model_input = {
"v2.1": {
"clip": (torch.randint(1, 2, (2, 77)),),
"v2_1": {
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
"vae": (torch.randn(1, 4, 96, 96),),
"unet": (
torch.randn(1, 4, 96, 96), # latents
torch.tensor([1]).to(torch.float32), # timestep
torch.randn(2, 77, 1024), # embedding
torch.randn(2, model_clip_max_length, 1024), # embedding
torch.tensor(1).to(torch.float32), # guidance_scale
),
},
"v2.1base": {
"clip": (torch.randint(1, 2, (2, 77)),),
"v2_1base": {
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
"vae": (torch.randn(1, 4, 64, 64),),
"unet": (
torch.randn(1, 4, 64, 64), # latents
torch.tensor([1]).to(torch.float32), # timestep
torch.randn(2, 77, 1024), # embedding
torch.randn(2, model_clip_max_length, 1024), # embedding
torch.tensor(1).to(torch.float32), # guidance_scale
),
},
"v1.4": {
"clip": (torch.randint(1, 2, (2, 77)),),
"v1_4": {
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
"vae": (torch.randn(1, 4, 64, 64),),
"unet": (
torch.randn(1, 4, 64, 64),
torch.tensor([1]).to(torch.float32), # timestep
torch.randn(2, 77, 768),
torch.randn(2, model_clip_max_length, 768),
torch.tensor(1).to(torch.float32),
),
},
}
# revision param for from_pretrained defaults to "main" => fp32
model_revision = "fp16" if args.precision == "fp16" else "main"
model_revision = {
"stablediffusion": "fp16" if args.precision == "fp16" else "main",
"anythingv3": "diffusers",
"analogdiffusion": "main",
"openjourney": "main",
"dreamlike": "main",
}
def get_clip_mlir(model_name="clip_text", extra_args=[]):
@@ -52,10 +73,25 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14"
)
if args.version != "v1.4":
if args.variant == "stablediffusion":
if args.version != "v1_4":
text_encoder = CLIPTextModel.from_pretrained(
model_config[args.version], subfolder="text_encoder"
)
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
text_encoder = CLIPTextModel.from_pretrained(
model_config[args.version], subfolder="text_encoder"
model_variant[args.variant],
subfolder="text_encoder",
revision=model_revision[args.variant],
)
else:
raise ValueError(f"{args.variant} not yet added")
class CLIPText(torch.nn.Module):
def __init__(self):
@@ -75,31 +111,49 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
return shark_clip
def get_vae_mlir(model_name="vae", extra_args=[]):
class VaeModel(torch.nn.Module):
def get_base_vae_mlir(model_name="vae", extra_args=[]):
class BaseVaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_config[args.version],
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="vae",
revision=model_revision,
revision=model_revision[args.variant],
)
def forward(self, input):
x = self.vae.decode(input, return_dict=False)[0]
return (x / 2 + 0.5).clamp(0, 1)
vae = VaeModel()
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
vae = BaseVaeModel()
if args.variant == "stablediffusion":
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
else:
inputs = model_input[args.version]["vae"]
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
)
else:
inputs = model_input["v1_4"]["vae"]
else:
inputs = model_input[args.version]["vae"]
raise ValueError(f"{args.variant} not yet added")
shark_vae = compile_through_fx(
vae,
@@ -110,25 +164,53 @@ def get_vae_mlir(model_name="vae", extra_args=[]):
return shark_vae
def get_vae_encode_mlir(model_name="vae_encode", extra_args=[]):
class VaeEncodeModel(torch.nn.Module):
def get_vae_mlir(model_name="vae", extra_args=[]):
class VaeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_config[args.version],
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="vae",
revision="fp16",
revision=model_revision[args.variant],
)
def forward(self, x):
input = 2 * (x - 0.5)
return self.vae.encode(input, return_dict=False)[0]
def forward(self, input):
input = 1 / 0.18215 * input
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
x = x * 255.0
return x.round()
vae = VaeModel()
if args.variant == "stablediffusion":
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[
inputs.half().cuda()
for inputs in model_input[args.version]["vae"]
]
)
else:
inputs = model_input[args.version]["vae"]
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
)
else:
inputs = model_input["v1_4"]["vae"]
else:
raise ValueError(f"{args.variant} not yet added")
vae = VaeEncodeModel()
vae = vae.half().cuda()
inputs = tuple(
[inputs.half().cuda() for inputs in model_input[args.version]["vae"]]
)
shark_vae = compile_through_fx(
vae,
inputs,
@@ -143,9 +225,11 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
def __init__(self):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_config[args.version],
model_config[args.version]
if args.variant == "stablediffusion"
else model_variant[args.variant],
subfolder="unet",
revision=model_revision,
revision=model_revision[args.variant],
)
self.in_channels = self.unet.in_channels
self.train(False)
@@ -163,16 +247,35 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
return noise_pred
unet = UnetModel()
if args.precision == "fp16":
unet = unet.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input[args.version]["unet"]
]
)
if args.variant == "stablediffusion":
if args.precision == "fp16":
unet = unet.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input[args.version]["unet"]
]
)
else:
inputs = model_input[args.version]["unet"]
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
if args.precision == "fp16":
unet = unet.half().cuda()
inputs = tuple(
[
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
for inputs in model_input["v1_4"]["unet"]
]
)
else:
inputs = model_input["v1_4"]["unet"]
else:
inputs = model_input[args.version]["unet"]
raise ValueError(f"{args.variant} is not yet added")
shark_unet = compile_through_fx(
unet,
inputs,

View File

@@ -1,191 +1,111 @@
import sys
from models.stable_diffusion.model_wrappers import (
get_base_vae_mlir,
get_vae_mlir,
get_vae_encode_mlir,
get_unet_mlir,
get_clip_mlir,
)
from models.stable_diffusion.resources import models_db
from models.stable_diffusion.stable_args import args
from models.stable_diffusion.utils import get_shark_model
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
BATCH_SIZE = len(args.prompts)
if BATCH_SIZE != 1:
sys.exit("Only batch size 1 is supported.")
# use tuned models only in the case of rdna3 cards.
args.use_tuned = False
if not args.iree_vulkan_target_triple:
vulkan_triple_flags = get_vulkan_triple_flag()
if vulkan_triple_flags and "rdna3" in vulkan_triple_flags:
args.use_tuned = True
elif "rdna3" in args.iree_vulkan_target_triple:
args.use_tuned = True
if args.use_tuned:
print("Using tuned models for rdna3 card")
def get_unet():
def get_params(bucket_key, model_key):
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
# Tuned model is present for `fp16` precision.
if args.precision == "fp16":
if args.use_tuned:
bucket = "gs://shark_tank/vivian"
if args.version == "v1.4":
model_name = "unet_1dec_fp16_tuned"
if args.version == "v2.1base":
model_name = "unet2base_8dec_fp16_tuned_v2"
return get_shark_model(bucket, model_name, iree_flags)
else:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "unet_8dec_fp16"
if args.version == "v2.1base":
model_name = "unet2base_8dec_fp16"
if args.version == "v2.1":
model_name = "unet2_14dec_fp16"
try:
bucket = models_db[0][bucket_key]
model_name = models_db[1][model_key]
except KeyError:
raise Exception(
f"{bucket}/{model_key} is not present in the models database"
)
return bucket, model_name, iree_flags
def get_unet():
# Tuned model is present only for `fp16` precision.
is_tuned = "/tuned" if args.use_tuned else "/untuned"
bucket_key = f"{args.variant}{is_tuned}"
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}{is_tuned}"
bucket, model_name, iree_flags = get_params(bucket_key, model_key)
if args.use_tuned:
return get_shark_model(bucket, model_name, iree_flags)
else:
if args.precision == "fp16":
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
if args.import_mlir:
return get_unet_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
# Tuned model is not present for `fp32` case.
if args.precision == "fp32":
bucket = "gs://shark_tank/stable_diffusion"
model_name = "unet_1dec_fp32"
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
if args.device == "cuda":
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform"
]
else:
iree_flags += ["--iree-flow-enable-conv-img2col-transform"]
elif args.precision == "fp32":
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
if args.import_mlir:
return get_unet_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
if args.precision == "int8":
bucket = "gs://shark_tank/prashant_nod"
model_name = "unet_int8"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
]
sys.exit("int8 model is currently in maintenance.")
# # TODO: Pass iree_flags to the exported model.
# if args.import_mlir:
# sys.exit(
# "--import_mlir is not supported for the int8 model, try --no-import_mlir flag."
# )
# return get_shark_model(bucket, model_name, iree_flags)
def get_vae():
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if args.precision in ["fp16", "int8"]:
if args.use_tuned:
bucket = "gs://shark_tank/vivian"
if args.version == "v2.1base":
model_name = "vae2base_8dec_fp16_tuned"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
"--iree-flow-enable-conv-winograd-transform",
]
return get_shark_model(bucket, model_name, iree_flags)
else:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_8dec_fp16"
if args.version == "v2.1base":
model_name = "vae2base_8dec_fp16"
if args.version == "v2.1":
model_name = "vae2_14dec_fp16"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
if args.import_mlir:
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
if args.precision == "fp32":
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_1dec_fp32"
# Tuned model is present only for `fp16` precision.
is_tuned = "/tuned" if args.use_tuned else "/untuned"
is_base = "/base" if args.use_base_vae else ""
bucket_key = f"{args.variant}{is_tuned}"
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77{is_tuned}{is_base}"
bucket, model_name, iree_flags = get_params(bucket_key, model_key)
if args.use_tuned:
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
if args.import_mlir:
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
def get_vae_encode():
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
if args.precision in ["fp16", "int8"]:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_encode_1dec_fp16"
if args.version == "v2":
model_name = "vae2_encode_29nov_fp16"
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
"--iree-flow-enable-conv-winograd-transform",
]
if args.import_mlir:
return get_vae_encode_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
if args.precision == "fp32":
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_encode_1dec_fp32"
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
else:
if args.precision == "fp16":
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
elif args.precision == "fp32":
iree_flags += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
if args.import_mlir:
if args.use_base_vae:
return get_base_vae_mlir(model_name, iree_flags)
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
def get_clip():
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
bucket = "gs://shark_tank/stable_diffusion"
model_name = "clip_18dec_fp32"
if args.version == "v2.1base":
model_name = "clip2base_18dec_fp32"
if args.version == "v2.1":
model_name = "clip2_18dec_fp32"
bucket_key = f"{args.variant}/untuned"
model_key = f"{args.variant}/{args.version}/clip/fp32/length_{args.max_length}/untuned"
bucket, model_name, iree_flags = get_params(bucket_key, model_key)
iree_flags += [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops",

View File

@@ -0,0 +1,31 @@
import os
import json
import sys
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
prompt_examples = []
prompts_loc = resource_path("resources/prompts.json")
if os.path.exists(prompts_loc):
with open(prompts_loc, encoding="utf-8") as fopen:
prompt_examples = json.load(fopen)
if not prompt_examples:
print("Unable to fetch prompt examples.")
models_db = []
models_loc = resource_path("resources/model_db.json")
if os.path.exists(models_loc):
with open(models_loc, encoding="utf-8") as fopen:
models_db = json.load(fopen)
if len(models_db) != 2:
sys.exit("Error: Unable to load models database.")

View File

@@ -0,0 +1,68 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/stable_diffusion",
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
"openjourney/tuned":"gs://shark_tank/sd_tuned",
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
},
{
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_1dec_fp16_tuned",
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32",
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet2base_8dec_fp16",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_19dec_v2p1base_fp16_64",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae2base_19dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip2base_18dec_fp32",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_19dec_v2p1base_fp32_64",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet2_14dec_fp16",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae2_19dec_fp16",
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip2_18dec_fp32",
"anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
"anythingv3/v2_1base/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
"anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
"anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
"anythingv3/v2_1base/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
"anythingv3/v2_1base/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
"anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
"anythingv3/v2_1base/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
"anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
"analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
"analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
"analogdiffusion/v2_1base/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
"openjourney/v2_1base/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
"openjourney/v2_1base/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
"openjourney/v2_1base/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
"openjourney/v2_1base/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
"openjourney/v2_1base/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
"openjourney/v2_1base/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
"openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
"dreamlike/v2_1base/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
"dreamlike/v2_1base/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
"dreamlike/v2_1base/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
"dreamlike/v2_1base/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
"dreamlike/v2_1base/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
"dreamlike/v2_1base/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
"dreamlike/v2_1base/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
}
]

View File

@@ -0,0 +1,8 @@
[["A high tech solarpunk utopia in the Amazon rainforest"],
["A pikachu fine dining with a view to the Eiffel Tower"],
["A mecha robot in a favela in expressionist style"],
["an insect robot preparing a delicious meal"],
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]]

View File

@@ -46,8 +46,8 @@ p.add_argument(
p.add_argument(
"--max_length",
type=int,
default=77,
help="max length of the tokenizer output.",
default=64,
help="max length of the tokenizer output, options are 64 and 77.",
)
##############################################################################
@@ -61,7 +61,7 @@ p.add_argument(
p.add_argument(
"--version",
type=str,
default="v2.1base",
default="v2_1base",
help="Specify version of stable diffusion model",
)
@@ -92,11 +92,31 @@ p.add_argument(
p.add_argument(
"--use_tuned",
default=False,
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
)
p.add_argument(
"--use_base_vae",
default=False,
action=argparse.BooleanOptionalAction,
help="Do conversion from the VAE output to pixel space on cpu.",
)
p.add_argument(
"--variant",
default="stablediffusion",
help="We now support multiple vairants of SD finetuned for different dataset. you can use the following anythingv3, ...", # TODO add more once supported
)
p.add_argument(
"--scheduler",
type=str,
default="SharkEulerDiscrete",
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
@@ -117,7 +137,7 @@ p.add_argument(
p.add_argument(
"--vulkan_large_heap_block_size",
default="2147483648",
default="4147483648",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
@@ -132,6 +152,13 @@ p.add_argument(
### Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="use the default scheduler precompiled into the model if available",
)
p.add_argument(
"--local_tank_cache",
default="",
@@ -163,4 +190,37 @@ p.add_argument(
action=argparse.BooleanOptionalAction,
help="flag for inserting debug frames between iterations for use with rgp.",
)
p.add_argument(
"--hide_steps",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=5,
help="flag setting warmup count for clip and vae [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
)
##############################################################################
### Web UI flags
##############################################################################
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the pregress bar animation during image generation",
)
args = p.parse_args()

View File

@@ -1,10 +1,12 @@
import os
import torch
from shark.shark_inference import SharkInference
from models.stable_diffusion.stable_args import args
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
def _compile_module(shark_module, model_name, extra_args=[]):
@@ -82,7 +84,149 @@ def set_iree_runtime_flags():
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
if "vulkan" in args.device:
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
return
def get_all_devices(driver_name):
"""
Inputs: driver_name
Returns a list of all the available devices for a given driver sorted by
the iree path names of the device as in --list_devices option in iree.
"""
from iree.runtime import get_driver
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return device_list_src
def get_device_mapping(driver, key_combination=3):
"""This method ensures consistent device ordering when choosing
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
driver = iree_device_map(driver)
device_list = get_all_devices(driver)
device_map = dict()
def get_output_value(dev_dict):
if key_combination == 1:
return f"{driver}://{dev_dict['path']}"
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
for i, device in enumerate(device_list):
# mapping with index
device_map[f"{driver}://{i}"] = get_output_value(device)
# mapping with full path
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
return device_map
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
try:
device_mapping = device_map[device]
except KeyError:
raise ValueError(f"Device '{device}' is not a valid device.")
return device_mapping
def set_init_device_flags():
if "vulkan" in args.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
args.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
)
elif "cuda" in args.device:
args.device = "cuda"
elif "cpu" in args.device:
args.device = "cpu"
# set max_length based on availability.
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
args.max_length = 77
elif args.variant == "openjourney":
args.max_length = 64
# use tuned models only in the case of stablediffusion/fp16 and rdna3 cards.
if (
args.variant in ["openjourney", "dreamlike"]
or args.precision != "fp16"
or "vulkan" not in args.device
or "rdna3" not in args.iree_vulkan_target_triple
):
args.use_tuned = False
print("Tuned models are currently not supported for this setting.")
elif args.use_base_vae and args.variant != "stablediffusion":
args.use_tuned = False
print("Tuned models are currently not supported for this setting.")
if args.use_tuned:
print("Using tuned models for stablediffusion/fp16 and rdna3 card.")
# Utility to get list of devices available.
def get_available_devices():
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except:
print(f"{driver_name} devices are not available.")
else:
for i, device in enumerate(device_list_dict):
device_list.append(f"{driver_name}://{i} => {device['name']}")
return device_list
set_iree_runtime_flags()
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("cpu")
return available_devices

View File

@@ -24,8 +24,9 @@ datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += [
( 'prompts.json', '.' ),
( 'logos/*.png', 'logos' )
( 'models/stable_diffusion/resources/prompts.json', 'resources' ),
( 'models/stable_diffusion/resources/model_db.json', 'resources' ),
( 'models/stable_diffusion/logos/*', 'logos' )
]
binaries = []