mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
5 Commits
20230123.4
...
20230110.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ea5fa6db6 | ||
|
|
ec7b19d41b | ||
|
|
1fd43d1219 | ||
|
|
b78187635d | ||
|
|
a4d28110b0 |
1
.github/workflows/test-models.yml
vendored
1
.github/workflows/test-models.yml
vendored
@@ -115,7 +115,6 @@ jobs:
|
||||
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k cuda --update_tank
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
|
||||
sh build_tools/stable_diff_main_test.sh
|
||||
|
||||
- name: Validate Vulkan Models (MacOS)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
import argparse
|
||||
import torchvision
|
||||
import numpy as np
|
||||
|
||||
import requests
|
||||
import shutil
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("-n", "--newfile")
|
||||
parser.add_argument(
|
||||
"-g",
|
||||
"--golden_url",
|
||||
default="https://storage.googleapis.com/shark_tank/testdata/cyberpunk_fores_42_0_230119_021148.png",
|
||||
)
|
||||
|
||||
|
||||
def get_image(url, local_filename):
|
||||
res = requests.get(url, stream=True)
|
||||
if res.status_code == 200:
|
||||
with open(local_filename, "wb") as f:
|
||||
shutil.copyfileobj(res.raw, f)
|
||||
return torchvision.io.read_image(local_filename).numpy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
new = torchvision.io.read_image(args.newfile).numpy() / 255.0
|
||||
tempfile_name = os.path.join(os.getcwd(), "golden.png")
|
||||
golden = get_image(args.golden_url, tempfile_name) / 255.0
|
||||
diff = np.abs(new - golden)
|
||||
mean = np.mean(diff)
|
||||
if not mean < 0.2:
|
||||
subprocess.run(
|
||||
["gsutil", "cp", args.newfile, "gs://shark_tank/testdata/builder/"]
|
||||
)
|
||||
raise SystemExit("new and golden not close")
|
||||
else:
|
||||
print("SUCCESS")
|
||||
@@ -1,6 +0,0 @@
|
||||
rm -rf ./test_images
|
||||
mkdir test_images
|
||||
python shark/examples/shark_inference/stable_diffusion/main.py --device=vulkan --output_dir=./test_images --no-load_vmfb --no-use_tuned
|
||||
|
||||
python build_tools/image_comparison.py -n ./test_images/*.png
|
||||
exit $?
|
||||
@@ -1,23 +0,0 @@
|
||||
# Dataset annotation tool
|
||||
|
||||
SHARK annotator for adding or modifying prompts of dataset images
|
||||
|
||||
## Set up
|
||||
|
||||
Activate SHARK Python virtual environment and install additional packages
|
||||
```shell
|
||||
source ../shark.venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Run annotator
|
||||
|
||||
```shell
|
||||
python annotation_tool.py
|
||||
```
|
||||
* Select dataset from `Dataset` dropdown list
|
||||
* Select image from `Image` dropdown list
|
||||
* Image and the existing prompt will be loaded
|
||||
* Add or modify prompt in `Prompt` textbox which will be autosaved
|
||||
* Click `Next` to load the next image, you could also select other images from `Image`
|
||||
* Click `Finish` when finishing annotation or before switching dataset
|
||||
@@ -1,156 +0,0 @@
|
||||
import gradio as gr
|
||||
import json
|
||||
import jsonlines
|
||||
import os
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
from utils import get_datasets
|
||||
|
||||
|
||||
# see https://cloud.google.com/docs/authentication/provide-credentials-adc to authorize
|
||||
gs_url = "gs://shark-datasets/portraits"
|
||||
|
||||
shark_root = Path(__file__).parent.parent
|
||||
demo_css = shark_root.joinpath("web/demo.css").resolve()
|
||||
nodlogo_loc = shark_root.joinpath(
|
||||
"web/models/stable_diffusion/logos/nod-logo.png"
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
|
||||
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=100)
|
||||
|
||||
datasets, images = get_datasets(gs_url)
|
||||
prompt_data = dict()
|
||||
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
# TODO: add multiselect dataset
|
||||
dataset = gr.Dropdown(label="Dataset", choices=datasets)
|
||||
image_name = gr.Dropdown(label="Image", choices=[])
|
||||
|
||||
with gr.Row(elem_id="ui_body", visible=True):
|
||||
# TODO: add ability to search image by typing
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
image = gr.Image(type="filepath").style(height=512)
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
lines=3,
|
||||
)
|
||||
next_image = gr.Button("Next")
|
||||
finish = gr.Button("Finish")
|
||||
|
||||
def filter_datasets(dataset):
|
||||
# TODO: execute finish process when switching dataset
|
||||
if dataset is None:
|
||||
return gr.Dropdown.update(value=None, choices=[])
|
||||
|
||||
# create the dataset dir if doesn't exist and download prompt file
|
||||
dataset_path = str(shark_root) + "/dataset/" + dataset
|
||||
prompt_gs_path = gs_url + "/" + dataset + "/metadata.jsonl"
|
||||
if not os.path.exists(dataset_path):
|
||||
os.mkdir(dataset_path)
|
||||
os.system(f'gsutil cp "{prompt_gs_path}" "{dataset_path}"/')
|
||||
|
||||
# read prompt jsonlines file
|
||||
prompt_data.clear()
|
||||
with jsonlines.open(dataset_path + "/metadata.jsonl") as reader:
|
||||
for line in reader.iter(type=dict, skip_invalid=True):
|
||||
prompt_data[line["file_name"]] = line["text"]
|
||||
|
||||
return gr.Dropdown.update(choices=images[dataset])
|
||||
|
||||
dataset.change(fn=filter_datasets, inputs=dataset, outputs=image_name)
|
||||
|
||||
def display_image(dataset, image_name):
|
||||
if dataset is None or image_name is None:
|
||||
return gr.Image.update(value=None), gr.Textbox.update(value=None)
|
||||
|
||||
# download and load the image
|
||||
# TODO: remove previous image if change image from dropdown
|
||||
img_gs_path = gs_url + "/" + dataset + "/" + image_name
|
||||
img_sub_path = "/".join(image_name.split("/")[:-1])
|
||||
img_dst_path = (
|
||||
str(shark_root) + "/dataset/" + dataset + "/" + img_sub_path + "/"
|
||||
)
|
||||
if not os.path.exists(img_dst_path):
|
||||
os.mkdir(img_dst_path)
|
||||
os.system(f'gsutil cp "{img_gs_path}" "{img_dst_path}"')
|
||||
img = Image.open(img_dst_path + image_name.split("/")[-1])
|
||||
|
||||
return gr.Image.update(value=img), gr.Textbox.update(
|
||||
value=prompt_data[image_name]
|
||||
)
|
||||
|
||||
image_name.change(
|
||||
fn=display_image, inputs=[dataset, image_name], outputs=[image, prompt]
|
||||
)
|
||||
|
||||
def update_prompt(dataset, image_name, prompt):
|
||||
if dataset is None or image_name is None or prompt is None:
|
||||
return
|
||||
|
||||
prompt_data[image_name] = prompt
|
||||
prompt_path = (
|
||||
str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
|
||||
)
|
||||
# write prompt jsonlines file
|
||||
with open(prompt_path, "w") as f:
|
||||
for key, value in prompt_data.items():
|
||||
f.write(json.dumps({"file_name": key, "text": value}))
|
||||
f.write("\n")
|
||||
return
|
||||
|
||||
prompt.change(fn=update_prompt, inputs=[dataset, image_name, prompt])
|
||||
|
||||
def get_next_image(dataset, image_name):
|
||||
if dataset is None or image_name is None:
|
||||
return
|
||||
|
||||
# remove local image
|
||||
img_path = str(shark_root) + "/dataset/" + dataset + "/" + image_name
|
||||
os.system(f'rm "{img_path}"')
|
||||
# get the index for the next image
|
||||
# TODO: finish when get to the end
|
||||
idx = images[dataset].index(image_name)
|
||||
|
||||
return gr.Dropdown.update(value=images[dataset][idx + 1])
|
||||
|
||||
next_image.click(
|
||||
fn=get_next_image, inputs=[dataset, image_name], outputs=image_name
|
||||
)
|
||||
|
||||
def finish_annotation(dataset):
|
||||
if dataset is None:
|
||||
return
|
||||
|
||||
# upload prompt and remove local data
|
||||
dataset_path = str(shark_root) + "/dataset/" + dataset
|
||||
dataset_gs_path = gs_url + "/" + dataset + "/"
|
||||
os.system(
|
||||
f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"'
|
||||
)
|
||||
os.system(f'rm -rf "{dataset_path}"')
|
||||
|
||||
return gr.Dropdown.update(value=None)
|
||||
|
||||
finish.click(fn=finish_annotation, inputs=dataset, outputs=dataset)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
shark_web.launch(
|
||||
share=False,
|
||||
inbrowser=True,
|
||||
server_name="0.0.0.0",
|
||||
server_port=8080,
|
||||
)
|
||||
@@ -1,3 +0,0 @@
|
||||
# SHARK Annotator
|
||||
gradio==3.15.0
|
||||
jsonlines
|
||||
@@ -1,23 +0,0 @@
|
||||
from google.cloud import storage
|
||||
|
||||
|
||||
def get_datasets(gs_url):
|
||||
datasets = set()
|
||||
images = dict()
|
||||
|
||||
storage_client = storage.Client()
|
||||
bucket_name = gs_url.split("/")[2]
|
||||
source_blob_name = "/".join(gs_url.split("/")[3:])
|
||||
blobs = storage_client.list_blobs(bucket_name, prefix=source_blob_name)
|
||||
|
||||
for blob in blobs:
|
||||
dataset_name = blob.name.split("/")[1]
|
||||
datasets.add(dataset_name)
|
||||
file_sub_path = "/".join(blob.name.split("/")[2:])
|
||||
# check if image or jsonl
|
||||
if "/" in file_sub_path:
|
||||
if dataset_name not in images.keys():
|
||||
images[dataset_name] = []
|
||||
images[dataset_name] += [file_sub_path]
|
||||
|
||||
return list(datasets), images
|
||||
@@ -3,8 +3,6 @@
|
||||
|
||||
numpy==1.22.4
|
||||
torchvision
|
||||
pytorch-triton
|
||||
tabulate
|
||||
|
||||
tqdm
|
||||
|
||||
@@ -15,7 +13,7 @@ iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tensorflow==2.10.1
|
||||
tensorflow==2.10
|
||||
keras==2.10
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
|
||||
@@ -1,9 +1,3 @@
|
||||
param([string]$arguments)
|
||||
|
||||
if ($arguments -eq "--update-src"){
|
||||
git pull
|
||||
}
|
||||
|
||||
#Write-Host "Installing python"
|
||||
|
||||
#Start-Process winget install Python.Python.3.10 '/quiet InstallAllUsers=1 PrependPath=1' -wait -NoNewWindow
|
||||
|
||||
@@ -128,7 +128,6 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
|
||||
TV_VERSION=${TV_VER:9:18}
|
||||
$PYTHON -m pip uninstall -y torch torchvision
|
||||
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
|
||||
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch + cu117."
|
||||
|
||||
@@ -4,41 +4,6 @@
|
||||
|
||||
Follow setup instructions in the main [README.md](https://github.com/nod-ai/SHARK#readme) for regular usage.
|
||||
|
||||
|
||||
## Using other supported Stable Diffusion variants with SHARK:
|
||||
|
||||
Currently we support fine-tuned versions of Stable Diffusion such as:
|
||||
- [AnythingV3](https://huggingface.co/Linaqruf/anything-v3.0)
|
||||
- [Analog Diffusion](https://huggingface.co/wavymulder/Analog-Diffusion)
|
||||
|
||||
use the flag `--hf_model_id=` to specify the repo-id of the model to be used.
|
||||
|
||||
```shell
|
||||
python .\shark\examples\shark_inference\stable_diffusion\main.py --hf_model_id="Linaqruf/anything-v3.0" --max_length=77 --prompt="1girl, brown hair, green eyes, colorful, autumn, cumulonimbus clouds, lighting, blue sky, falling leaves, garden"
|
||||
```
|
||||
|
||||
## Run a custom model using a `.ckpt` file:
|
||||
* Install the following by running :-
|
||||
```shell
|
||||
pip install omegaconf safetensors pytorch_lightning
|
||||
```
|
||||
* Download a [.ckpt](https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0-pruned-fp32.ckpt) file in case you don't have a locally generated `.ckpt` file for StableDiffusion.
|
||||
|
||||
* Now pass the above `.ckpt` file to `ckpt_loc` command-line argument using the following (note the `hf_model_id` flag which states what the base model is from which the `.ckpt` model was fined-tuned off of) :-
|
||||
```shell
|
||||
python3.10 main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd" --max_length=64 --import_mlir --ckpt_loc="/path/to/.ckpt/file" --hf_model_id="CompVis/stable-diffusion-v1-4"
|
||||
```
|
||||
* We use a combination of 3 flags to make this feature work : `import_mlir`, `ckpt_loc` and `hf_model_id`, of which `import_mlir` needs to be present. In case `ckpt_loc` is not specified then a [default](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) HuggingFace repo-id is run via `hf_model_id`. So, you need to specify which base model's `.ckpt` you are using via `hf_model_id`.
|
||||
|
||||
* Use custom model `.ckpt` files from [HuggingFace-StableDiffusion](https://huggingface.co/models?other=stable-diffusion) to generate images. And in case you want to use any variants from HuggingFace then add the mapping of the variant to their base model in [variants.json](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/resources/variants.json).
|
||||
|
||||
|
||||
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>Debug Commands</summary>
|
||||
|
||||
## Debug commands and other advanced usage follows.
|
||||
|
||||
```shell
|
||||
@@ -78,4 +43,14 @@ unzip ~/.local/shark_tank/<your unet>/inputs.npz
|
||||
iree-benchmark-module --module_file=/path/to/output/vmfb --entry_function=forward --function_input=@arr_0.npy --function_input=1xf16 --function_input=@arr_2.npy --function_input=@arr_3.npy --function_input=@arr_4.npy
|
||||
```
|
||||
|
||||
</details>
|
||||
## Using other supported Stable Diffusion variants with SHARK:
|
||||
|
||||
Currently we support the following fine-tuned versions of Stable Diffusion:
|
||||
- [AnythingV3](https://huggingface.co/Linaqruf/anything-v3.0)
|
||||
- [Analog Diffusion](https://huggingface.co/wavymulder/Analog-Diffusion)
|
||||
|
||||
use the flag `--variant=` to specify the model to be used.
|
||||
|
||||
```shell
|
||||
python .\shark\examples\shark_inference\stable_diffusion\main.py --variant=anythingv3 --max_length=77 --prompt="1girl, brown hair, green eyes, colorful, autumn, cumulonimbus clouds, lighting, blue sky, falling leaves, garden"
|
||||
```
|
||||
|
||||
@@ -5,6 +5,7 @@ os.environ["AMD_ENABLE_LLPC"] = "1"
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
import torch
|
||||
from PIL import Image
|
||||
import torchvision.transforms as T
|
||||
from diffusers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
@@ -16,11 +17,6 @@ from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from stable_args import args
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from model_wrappers import SharkifyStableDiffusionModel
|
||||
|
||||
# This has to come before importing cache objects
|
||||
if args.clear_all:
|
||||
@@ -42,8 +38,9 @@ if args.clear_all:
|
||||
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
|
||||
|
||||
|
||||
from utils import set_init_device_flags, disk_space_check, preprocessCKPT
|
||||
from utils import set_init_device_flags
|
||||
|
||||
from opt_params import get_unet, get_vae, get_clip
|
||||
from schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
)
|
||||
@@ -74,13 +71,27 @@ if __name__ == "__main__":
|
||||
|
||||
prompt = args.prompts
|
||||
neg_prompt = args.negative_prompts
|
||||
height = args.height
|
||||
width = args.width
|
||||
height = 512 # default height of Stable Diffusion
|
||||
width = 512 # default width of Stable Diffusion
|
||||
if args.version == "v2_1":
|
||||
height = 768
|
||||
width = 768
|
||||
|
||||
num_inference_steps = args.steps # Number of denoising steps
|
||||
|
||||
# Scale for classifier-free guidance
|
||||
guidance_scale = torch.tensor(args.guidance_scale).to(torch.float32)
|
||||
|
||||
# Handle out of range seeds.
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
seed = args.seed
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(
|
||||
seed
|
||||
) # Seed generator to create the inital latent noise
|
||||
|
||||
# TODO: Add support for batch_size > 1.
|
||||
batch_size = len(prompt)
|
||||
if batch_size != 1:
|
||||
@@ -89,28 +100,9 @@ if __name__ == "__main__":
|
||||
sys.exit("prompts and negative prompts must be of same length")
|
||||
|
||||
set_init_device_flags()
|
||||
disk_space_check(Path.cwd())
|
||||
|
||||
if not args.import_mlir:
|
||||
from opt_params import get_unet, get_vae, get_clip
|
||||
|
||||
clip = get_clip()
|
||||
unet = get_unet()
|
||||
vae = get_vae()
|
||||
else:
|
||||
if ".ckpt" in args.ckpt_loc:
|
||||
preprocessCKPT()
|
||||
mlir_import = SharkifyStableDiffusionModel(
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.precision,
|
||||
max_len=args.max_length,
|
||||
height=height,
|
||||
width=width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
)
|
||||
clip, unet, vae = mlir_import()
|
||||
|
||||
clip = get_clip()
|
||||
unet = get_unet()
|
||||
vae = get_vae()
|
||||
if args.dump_isa:
|
||||
dump_isas(args.dispatch_benchmarks_dir)
|
||||
|
||||
@@ -120,7 +112,7 @@ if __name__ == "__main__":
|
||||
subfolder="scheduler",
|
||||
)
|
||||
cpu_scheduling = True
|
||||
if args.hf_model_id == "stabilityai/stable-diffusion-2-1":
|
||||
if args.version == "v2_1":
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1", subfolder="tokenizer"
|
||||
)
|
||||
@@ -130,7 +122,7 @@ if __name__ == "__main__":
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
if args.hf_model_id == "stabilityai/stable-diffusion-2-1-base":
|
||||
if args.version == "v2_1base" and args.variant == "stablediffusion":
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
|
||||
)
|
||||
@@ -147,166 +139,116 @@ if __name__ == "__main__":
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
for run in range(args.runs):
|
||||
# Handle out of range seeds.
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
seed = args.seed
|
||||
if run >= 1 or seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(
|
||||
seed
|
||||
) # Seed generator to create the inital latent noise
|
||||
|
||||
# create a random initial latent.
|
||||
latents = torch.randn(
|
||||
(batch_size, 4, height // 8, width // 8),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
if run == 0:
|
||||
# Warmup phase to improve performance.
|
||||
if args.warmup_count >= 1:
|
||||
vae_warmup_input = torch.clone(latents).detach().numpy()
|
||||
clip_warmup_input = torch.randint(1, 2, (2, args.max_length))
|
||||
for i in range(args.warmup_count):
|
||||
vae("forward", (vae_warmup_input,))
|
||||
clip("forward", (clip_warmup_input,))
|
||||
# create a random initial latent.
|
||||
latents = torch.randn(
|
||||
(batch_size, 4, height // 8, width // 8),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
# Warmup phase to improve performance.
|
||||
if args.warmup_count >= 1:
|
||||
vae_warmup_input = torch.clone(latents).detach().numpy()
|
||||
clip_warmup_input = torch.randint(1, 2, (2, args.max_length))
|
||||
for i in range(args.warmup_count):
|
||||
vae("forward", (vae_warmup_input,))
|
||||
clip("forward", (clip_warmup_input,))
|
||||
|
||||
start = time.time()
|
||||
if run == 0:
|
||||
text_input = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=args.max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
max_length = text_input.input_ids.shape[-1]
|
||||
uncond_input = tokenizer(
|
||||
neg_prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input = torch.cat(
|
||||
[uncond_input.input_ids, text_input.input_ids]
|
||||
)
|
||||
start = time.time()
|
||||
|
||||
clip_inf_start = time.time()
|
||||
text_embeddings = clip("forward", (text_input,))
|
||||
clip_inf_end = time.time()
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
text_input = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=args.max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
max_length = text_input.input_ids.shape[-1]
|
||||
uncond_input = tokenizer(
|
||||
neg_prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
scheduler.is_scale_input_called = True
|
||||
clip_inf_start = time.time()
|
||||
text_embeddings = clip("forward", (text_input,))
|
||||
clip_inf_end = time.time()
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
scheduler.is_scale_input_called = True
|
||||
|
||||
avg_ms = 0
|
||||
for i, t in tqdm(
|
||||
enumerate(scheduler.timesteps), disable=args.hide_steps
|
||||
):
|
||||
step_start = time.time()
|
||||
if not args.hide_steps:
|
||||
print(f"i = {i} t = {t}", end="")
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
latent_model_input = scheduler.scale_model_input(latents, t)
|
||||
if cpu_scheduling:
|
||||
latent_model_input = latent_model_input.detach().numpy()
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
|
||||
noise_pred = unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
end_profiling(profile_device)
|
||||
|
||||
if cpu_scheduling:
|
||||
noise_pred = torch.from_numpy(noise_pred.to_host())
|
||||
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
||||
else:
|
||||
latents = scheduler.step(noise_pred, t, latents)
|
||||
step_time = time.time() - step_start
|
||||
avg_ms += step_time
|
||||
step_ms = int((step_time) * 1000)
|
||||
if not args.hide_steps:
|
||||
print(f" ({step_ms}ms)")
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
if args.use_base_vae:
|
||||
latents = 1 / 0.18215 * latents
|
||||
latents_numpy = latents
|
||||
avg_ms = 0
|
||||
for i, t in tqdm(enumerate(scheduler.timesteps), disable=args.hide_steps):
|
||||
step_start = time.time()
|
||||
if not args.hide_steps:
|
||||
print(f"i = {i} t = {t}", end="")
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
latent_model_input = scheduler.scale_model_input(latents, t)
|
||||
if cpu_scheduling:
|
||||
latents_numpy = latents.detach().numpy()
|
||||
profile_device = start_profiling(file_path="vae.rdc")
|
||||
vae_start = time.time()
|
||||
images = vae("forward", (latents_numpy,))
|
||||
vae_end = time.time()
|
||||
latent_model_input = latent_model_input.detach().numpy()
|
||||
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
|
||||
noise_pred = unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
end_profiling(profile_device)
|
||||
if args.use_base_vae:
|
||||
image = torch.from_numpy(images)
|
||||
image = (image.detach().cpu() * 255.0).numpy()
|
||||
images = image.round()
|
||||
end_time = time.time()
|
||||
|
||||
avg_ms = 1000 * avg_ms / args.steps
|
||||
clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
|
||||
vae_inf_time = (vae_end - vae_start) * 1000
|
||||
total_time = end_time - start
|
||||
|
||||
print(f"\nStats for run {run}:")
|
||||
print(f"Average step time: {avg_ms}ms/it")
|
||||
print(f"Clip Inference time (ms) = {clip_inf_time:.3f}")
|
||||
print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
|
||||
print(f"\nTotal image generation time: {total_time}sec")
|
||||
|
||||
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
|
||||
pil_images = [Image.fromarray(image) for image in images.numpy()]
|
||||
|
||||
if args.output_dir is not None:
|
||||
output_path = Path(args.output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
if cpu_scheduling:
|
||||
noise_pred = torch.from_numpy(noise_pred.to_host())
|
||||
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
||||
else:
|
||||
output_path = Path.cwd()
|
||||
disk_space_check(output_path, lim=5)
|
||||
for i in range(batch_size):
|
||||
json_store = {
|
||||
"prompt": args.prompts[i],
|
||||
"negative prompt": args.negative_prompts[i],
|
||||
"seed": args.seed,
|
||||
"hf_model_id": args.hf_model_id,
|
||||
"precision": args.precision,
|
||||
"steps": args.steps,
|
||||
"guidance_scale": args.guidance_scale,
|
||||
"scheduler": args.scheduler,
|
||||
}
|
||||
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[i][:15])
|
||||
img_name = f"{prompt_slice}_{args.seed}_{run}_{dt.now().strftime('%y%m%d_%H%M%S')}"
|
||||
if args.output_img_format == "jpg":
|
||||
pil_images[i].save(
|
||||
output_path / f"{img_name}.jpg",
|
||||
quality=95,
|
||||
subsampling=0,
|
||||
optimize=True,
|
||||
progressive=True,
|
||||
)
|
||||
else:
|
||||
pil_images[i].save(output_path / f"{img_name}.png", "PNG")
|
||||
if args.output_img_format not in ["png", "jpg"]:
|
||||
print(
|
||||
f"[ERROR] Format {args.output_img_format} is not supported yet."
|
||||
"saving image as png. Supported formats png / jpg"
|
||||
)
|
||||
with open(output_path / f"{img_name}.json", "w") as f:
|
||||
f.write(json.dumps(json_store, indent=4))
|
||||
latents = scheduler.step(noise_pred, t, latents)
|
||||
step_time = time.time() - step_start
|
||||
avg_ms += step_time
|
||||
step_ms = int((step_time) * 1000)
|
||||
if not args.hide_steps:
|
||||
print(f" ({step_ms}ms)")
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
if args.use_base_vae:
|
||||
latents = 1 / 0.18215 * latents
|
||||
latents_numpy = latents
|
||||
if cpu_scheduling:
|
||||
latents_numpy = latents.detach().numpy()
|
||||
profile_device = start_profiling(file_path="vae.rdc")
|
||||
vae_start = time.time()
|
||||
images = vae("forward", (latents_numpy,))
|
||||
vae_end = time.time()
|
||||
end_profiling(profile_device)
|
||||
if args.use_base_vae:
|
||||
image = torch.from_numpy(images)
|
||||
image = (image.detach().cpu() * 255.0).numpy()
|
||||
images = image.round()
|
||||
end_time = time.time()
|
||||
|
||||
avg_ms = 1000 * avg_ms / args.steps
|
||||
clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
|
||||
vae_inf_time = (vae_end - vae_start) * 1000
|
||||
total_time = end_time - start
|
||||
print(f"\nAverage step time: {avg_ms}ms/it")
|
||||
print(f"Clip Inference time (ms) = {clip_inf_time:.3f}")
|
||||
print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
|
||||
print(f"\nTotal image generation time: {total_time}sec")
|
||||
|
||||
transform = T.ToPILImage()
|
||||
pil_images = [
|
||||
transform(image) for image in torch.from_numpy(images).to(torch.uint8)
|
||||
]
|
||||
for i in range(batch_size):
|
||||
pil_images[i].save(f"{args.prompts[i]}_{i}.jpg")
|
||||
|
||||
@@ -1,211 +1,285 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
from utils import compile_through_fx, get_opt_flags
|
||||
from resources import base_models, variants
|
||||
from collections import defaultdict
|
||||
from utils import compile_through_fx
|
||||
from stable_args import args
|
||||
import torch
|
||||
import sys
|
||||
|
||||
model_config = {
|
||||
"v2_1": "stabilityai/stable-diffusion-2-1",
|
||||
"v2_1base": "stabilityai/stable-diffusion-2-1-base",
|
||||
"v1_4": "CompVis/stable-diffusion-v1-4",
|
||||
}
|
||||
|
||||
# clip has 2 variants of max length 77 or 64.
|
||||
model_clip_max_length = 64 if args.max_length == 64 else 77
|
||||
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
|
||||
model_clip_max_length = 77
|
||||
elif args.variant == "openjourney":
|
||||
model_clip_max_length = 64
|
||||
|
||||
model_variant = {
|
||||
"stablediffusion": "SD",
|
||||
"anythingv3": "Linaqruf/anything-v3.0",
|
||||
"dreamlike": "dreamlike-art/dreamlike-diffusion-1.0",
|
||||
"openjourney": "prompthero/openjourney",
|
||||
"analogdiffusion": "wavymulder/Analog-Diffusion",
|
||||
}
|
||||
|
||||
model_input = {
|
||||
"v2_1": {
|
||||
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
|
||||
"vae": (torch.randn(1, 4, 96, 96),),
|
||||
"unet": (
|
||||
torch.randn(1, 4, 96, 96), # latents
|
||||
torch.tensor([1]).to(torch.float32), # timestep
|
||||
torch.randn(2, model_clip_max_length, 1024), # embedding
|
||||
torch.tensor(1).to(torch.float32), # guidance_scale
|
||||
),
|
||||
},
|
||||
"v2_1base": {
|
||||
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
|
||||
"vae": (torch.randn(1, 4, 64, 64),),
|
||||
"unet": (
|
||||
torch.randn(1, 4, 64, 64), # latents
|
||||
torch.tensor([1]).to(torch.float32), # timestep
|
||||
torch.randn(2, model_clip_max_length, 1024), # embedding
|
||||
torch.tensor(1).to(torch.float32), # guidance_scale
|
||||
),
|
||||
},
|
||||
"v1_4": {
|
||||
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
|
||||
"vae": (torch.randn(1, 4, 64, 64),),
|
||||
"unet": (
|
||||
torch.randn(1, 4, 64, 64),
|
||||
torch.tensor([1]).to(torch.float32), # timestep
|
||||
torch.randn(2, model_clip_max_length, 768),
|
||||
torch.tensor(1).to(torch.float32),
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
# revision param for from_pretrained defaults to "main" => fp32
|
||||
model_revision = {
|
||||
"stablediffusion": "fp16" if args.precision == "fp16" else "main",
|
||||
"anythingv3": "diffusers",
|
||||
"analogdiffusion": "main",
|
||||
"openjourney": "main",
|
||||
"dreamlike": "main",
|
||||
}
|
||||
|
||||
|
||||
# These shapes are parameter dependent.
|
||||
def replace_shape_str(shape, max_len, width, height):
|
||||
new_shape = []
|
||||
for i in range(len(shape)):
|
||||
if shape[i] == "max_len":
|
||||
new_shape.append(max_len)
|
||||
elif shape[i] == "height":
|
||||
new_shape.append(height)
|
||||
elif shape[i] == "width":
|
||||
new_shape.append(width)
|
||||
else:
|
||||
new_shape.append(shape[i])
|
||||
return new_shape
|
||||
def get_clip_mlir(model_name="clip_text", extra_args=[]):
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
if args.variant == "stablediffusion":
|
||||
if args.version != "v1_4":
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_config[args.version], subfolder="text_encoder"
|
||||
)
|
||||
|
||||
# Get the input info for various models i.e. "unet", "clip", "vae".
|
||||
def get_input_info(model_info, max_len, width, height):
|
||||
dtype_config = {"f32": torch.float32, "i64": torch.int64}
|
||||
input_map = defaultdict(list)
|
||||
for k in model_info:
|
||||
for inp in model_info[k]:
|
||||
shape = model_info[k][inp]["shape"]
|
||||
dtype = dtype_config[model_info[k][inp]["dtype"]]
|
||||
tensor = None
|
||||
if isinstance(shape, list):
|
||||
clean_shape = replace_shape_str(shape, max_len, width, height)
|
||||
if dtype == torch.int64:
|
||||
tensor = torch.randint(1, 3, tuple(clean_shape))
|
||||
else:
|
||||
tensor = torch.randn(*clean_shape).to(dtype)
|
||||
elif isinstance(shape, int):
|
||||
tensor = torch.tensor(shape).to(dtype)
|
||||
else:
|
||||
sys.exit("shape isn't specified correctly.")
|
||||
input_map[k].append(tensor)
|
||||
|
||||
return input_map
|
||||
|
||||
|
||||
# Returns the model configuration in a dict containing input parameters
|
||||
# for clip, unet and vae respectively.
|
||||
def get_model_configuration(model_id, max_len, width, height):
|
||||
if model_id in base_models:
|
||||
return get_input_info(base_models[model_id], max_len, width, height)
|
||||
elif model_id in variants:
|
||||
return get_input_info(
|
||||
base_models[variants[model_id]], max_len, width, height
|
||||
elif args.variant in [
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]:
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_variant[args.variant],
|
||||
subfolder="text_encoder",
|
||||
revision=model_revision[args.variant],
|
||||
)
|
||||
else:
|
||||
sys.exit(
|
||||
"The model info is not configured, please add the model_configuration in base_model.json if it's a base model, else add it in the variant.json"
|
||||
)
|
||||
raise ValueError(f"{args.variant} not yet added")
|
||||
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.text_encoder = text_encoder
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText()
|
||||
shark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
model_input[args.version]["clip"],
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
|
||||
class SharkifyStableDiffusionModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
custom_weights: str,
|
||||
precision: str,
|
||||
max_len: int = 64,
|
||||
width: int = 512,
|
||||
height: int = 512,
|
||||
use_base_vae: bool = False,
|
||||
):
|
||||
self.check_params(max_len, width, height)
|
||||
self.inputs = get_model_configuration(
|
||||
model_id, max_len, width // 8, height // 8
|
||||
)
|
||||
self.model_id = model_id if custom_weights == "" else custom_weights
|
||||
self.precision = precision
|
||||
self.base_vae = use_base_vae
|
||||
self.model_name = (
|
||||
str(max_len)
|
||||
+ "_"
|
||||
+ str(height)
|
||||
+ "_"
|
||||
+ str(width)
|
||||
+ "_"
|
||||
+ precision
|
||||
)
|
||||
# We need a better naming convention for the .vmfbs because despite
|
||||
# using the custom model variant the .vmfb names remain the same and
|
||||
# it'll always pick up the compiled .vmfb instead of compiling the
|
||||
# custom model.
|
||||
# So, currently, we add `self.model_id` in the `self.model_name` of
|
||||
# .vmfb file.
|
||||
# TODO: Have a better way of naming the vmfbs using self.model_name.
|
||||
import re
|
||||
def get_base_vae_mlir(model_name="vae", extra_args=[]):
|
||||
class BaseVaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_config[args.version]
|
||||
if args.variant == "stablediffusion"
|
||||
else model_variant[args.variant],
|
||||
subfolder="vae",
|
||||
revision=model_revision[args.variant],
|
||||
)
|
||||
|
||||
model_name = re.sub(r"\W+", "_", self.model_id)
|
||||
if model_name[0] == "_":
|
||||
model_name = model_name[1:]
|
||||
self.model_name = self.model_name + "_" + model_name
|
||||
def forward(self, input):
|
||||
x = self.vae.decode(input, return_dict=False)[0]
|
||||
return (x / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
def check_params(self, max_len, width, height):
|
||||
if not (max_len >= 32 and max_len <= 77):
|
||||
sys.exit("please specify max_len in the range [32, 77].")
|
||||
if not (width % 8 == 0 and width >= 384):
|
||||
sys.exit("width should be greater than 384 and multiple of 8")
|
||||
if not (height % 8 == 0 and height >= 384):
|
||||
sys.exit("height should be greater than 384 and multiple of 8")
|
||||
vae = BaseVaeModel()
|
||||
if args.variant == "stablediffusion":
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda()
|
||||
for inputs in model_input[args.version]["vae"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[args.version]["vae"]
|
||||
elif args.variant in [
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]:
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
|
||||
)
|
||||
else:
|
||||
inputs = model_input["v1_4"]["vae"]
|
||||
else:
|
||||
raise ValueError(f"{args.variant} not yet added")
|
||||
|
||||
def get_vae(self):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, base_vae=self.base_vae):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
)
|
||||
self.base_vae = base_vae
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
def forward(self, input):
|
||||
if not self.base_vae:
|
||||
input = 1 / 0.18215 * input
|
||||
x = self.vae.decode(input, return_dict=False)[0]
|
||||
x = (x / 2 + 0.5).clamp(0, 1)
|
||||
if self.base_vae:
|
||||
return x
|
||||
x = x * 255.0
|
||||
return x.round()
|
||||
|
||||
vae = VaeModel()
|
||||
inputs = tuple(self.inputs["vae"])
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
vae_name = "base_vae" if self.base_vae else "vae"
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
is_f16=is_f16,
|
||||
model_name=vae_name + self.model_name,
|
||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
||||
)
|
||||
return shark_vae
|
||||
def get_vae_mlir(model_name="vae", extra_args=[]):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_config[args.version]
|
||||
if args.variant == "stablediffusion"
|
||||
else model_variant[args.variant],
|
||||
subfolder="vae",
|
||||
revision=model_revision[args.variant],
|
||||
)
|
||||
|
||||
def get_unet(self):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
def forward(self, input):
|
||||
input = 1 / 0.18215 * input
|
||||
x = self.vae.decode(input, return_dict=False)[0]
|
||||
x = (x / 2 + 0.5).clamp(0, 1)
|
||||
x = x * 255.0
|
||||
return x.round()
|
||||
|
||||
def forward(
|
||||
self, latent, timestep, text_embedding, guidance_scale
|
||||
):
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
latents = torch.cat([latent] * 2)
|
||||
unet_out = self.unet.forward(
|
||||
latents, timestep, text_embedding, return_dict=False
|
||||
)[0]
|
||||
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
return noise_pred
|
||||
vae = VaeModel()
|
||||
if args.variant == "stablediffusion":
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda()
|
||||
for inputs in model_input[args.version]["vae"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[args.version]["vae"]
|
||||
elif args.variant in [
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]:
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
|
||||
)
|
||||
else:
|
||||
inputs = model_input["v1_4"]["vae"]
|
||||
else:
|
||||
raise ValueError(f"{args.variant} not yet added")
|
||||
|
||||
unet = UnetModel()
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
input_mask = [True, True, True, False]
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
model_name="unet" + self.model_name,
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
)
|
||||
return shark_unet
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
def get_clip(self):
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id):
|
||||
super().__init__()
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="text_encoder",
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
def get_unet_mlir(model_name="unet", extra_args=[]):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_config[args.version]
|
||||
if args.variant == "stablediffusion"
|
||||
else model_variant[args.variant],
|
||||
subfolder="unet",
|
||||
revision=model_revision[args.variant],
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
clip_model = CLIPText()
|
||||
def forward(self, latent, timestep, text_embedding, guidance_scale):
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
latents = torch.cat([latent] * 2)
|
||||
unet_out = self.unet.forward(
|
||||
latents, timestep, text_embedding, return_dict=False
|
||||
)[0]
|
||||
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
return noise_pred
|
||||
|
||||
shark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
tuple(self.inputs["clip"]),
|
||||
model_name="clip" + self.model_name,
|
||||
extra_args=get_opt_flags("clip", precision="fp32"),
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
def __call__(self):
|
||||
compiled_clip = self.get_clip()
|
||||
compiled_unet = self.get_unet()
|
||||
compiled_vae = self.get_vae()
|
||||
return compiled_clip, compiled_unet, compiled_vae
|
||||
unet = UnetModel()
|
||||
if args.variant == "stablediffusion":
|
||||
if args.precision == "fp16":
|
||||
unet = unet.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
|
||||
for inputs in model_input[args.version]["unet"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[args.version]["unet"]
|
||||
elif args.variant in [
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]:
|
||||
if args.precision == "fp16":
|
||||
unet = unet.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
|
||||
for inputs in model_input["v1_4"]["unet"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input["v1_4"]["unet"]
|
||||
else:
|
||||
raise ValueError(f"{args.variant} is not yet added")
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
import sys
|
||||
from model_wrappers import (
|
||||
get_base_vae_mlir,
|
||||
get_vae_mlir,
|
||||
get_unet_mlir,
|
||||
get_clip_mlir,
|
||||
)
|
||||
from resources import models_db
|
||||
from stable_args import args
|
||||
from utils import get_shark_model
|
||||
@@ -7,18 +13,6 @@ BATCH_SIZE = len(args.prompts)
|
||||
if BATCH_SIZE != 1:
|
||||
sys.exit("Only batch size 1 is supported.")
|
||||
|
||||
hf_model_variant_map = {
|
||||
"Linaqruf/anything-v3.0": ["anythingv3", "v2_1base"],
|
||||
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v2_1base"],
|
||||
"prompthero/openjourney": ["openjourney", "v2_1base"],
|
||||
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v2_1base"],
|
||||
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1"],
|
||||
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
|
||||
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
|
||||
}
|
||||
|
||||
variant, version = hf_model_variant_map[args.hf_model_id]
|
||||
|
||||
|
||||
def get_params(bucket_key, model_key, model, is_tuned, precision):
|
||||
iree_flags = []
|
||||
@@ -39,7 +33,7 @@ def get_params(bucket_key, model_key, model, is_tuned, precision):
|
||||
]
|
||||
except KeyError:
|
||||
raise Exception(
|
||||
f"{bucket_key}/{model_key} is not present in the models database"
|
||||
f"{bucket}/{model_key} is not present in the models database"
|
||||
)
|
||||
|
||||
if (
|
||||
@@ -68,16 +62,13 @@ def get_params(bucket_key, model_key, model, is_tuned, precision):
|
||||
def get_unet():
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
if "vulkan" not in args.device and args.use_tuned:
|
||||
bucket_key = f"{variant}/{is_tuned}/{args.device}"
|
||||
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
|
||||
else:
|
||||
bucket_key = f"{variant}/{is_tuned}"
|
||||
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
|
||||
|
||||
bucket_key = f"{args.variant}/{is_tuned}"
|
||||
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "unet", is_tuned, args.precision
|
||||
)
|
||||
if not args.use_tuned and args.import_mlir:
|
||||
return get_unet_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
@@ -85,25 +76,24 @@ def get_vae():
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
is_base = "/base" if args.use_base_vae else ""
|
||||
if "vulkan" not in args.device and args.use_tuned:
|
||||
bucket_key = f"{variant}/{is_tuned}/{args.device}"
|
||||
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}"
|
||||
else:
|
||||
bucket_key = f"{variant}/{is_tuned}"
|
||||
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
|
||||
|
||||
bucket_key = f"{args.variant}/{is_tuned}"
|
||||
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "vae", is_tuned, args.precision
|
||||
)
|
||||
if not args.use_tuned and args.import_mlir:
|
||||
if args.use_base_vae:
|
||||
return get_base_vae_mlir(model_name, iree_flags)
|
||||
return get_vae_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_clip():
|
||||
bucket_key = f"{variant}/untuned"
|
||||
model_key = (
|
||||
f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned"
|
||||
)
|
||||
bucket_key = f"{args.variant}/untuned"
|
||||
model_key = f"{args.variant}/{args.version}/clip/fp32/length_{args.max_length}/untuned"
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "clip", "untuned", "fp32"
|
||||
)
|
||||
if args.import_mlir:
|
||||
return get_clip_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
@@ -11,32 +11,21 @@ def resource_path(relative_path):
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
def get_json_file(path):
|
||||
json_var = []
|
||||
loc_json = resource_path(path)
|
||||
if os.path.exists(loc_json):
|
||||
with open(loc_json, encoding="utf-8") as fopen:
|
||||
json_var = json.load(fopen)
|
||||
prompt_examples = []
|
||||
prompts_loc = resource_path("resources/prompts.json")
|
||||
if os.path.exists(prompts_loc):
|
||||
with open(prompts_loc, encoding="utf-8") as fopen:
|
||||
prompt_examples = json.load(fopen)
|
||||
|
||||
if not json_var:
|
||||
print(f"Unable to fetch {path}")
|
||||
|
||||
return json_var
|
||||
if not prompt_examples:
|
||||
print("Unable to fetch prompt examples.")
|
||||
|
||||
|
||||
# TODO: This shouldn't be called from here, every time the file imports
|
||||
# it will run all the global vars.
|
||||
prompts_examples = get_json_file("resources/prompts.json")
|
||||
models_db = get_json_file("resources/model_db.json")
|
||||
models_db = []
|
||||
models_loc = resource_path("resources/model_db.json")
|
||||
if os.path.exists(models_loc):
|
||||
with open(models_loc, encoding="utf-8") as fopen:
|
||||
models_db = json.load(fopen)
|
||||
|
||||
# The base_model contains the input configuration for the different
|
||||
# models and also helps in providing information for the variants.
|
||||
base_models = get_json_file("resources/base_model.json")
|
||||
|
||||
# The variant contains the mapping from variant to the base configuration
|
||||
# to get the required inputs.
|
||||
# If the input configuration doesn't match it should be registered standalone in the base configuration.
|
||||
variants = get_json_file("resources/variants.json")
|
||||
|
||||
# Contains optimization flags for different models.
|
||||
opt_flags = get_json_file("resources/opt_flags.json")
|
||||
if len(models_db) != 3:
|
||||
sys.exit("Error: Unable to load models database.")
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
{
|
||||
"stabilityai/stable-diffusion-2-1": {
|
||||
"unet": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
1,
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
2,
|
||||
"max_len",
|
||||
1024
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
1,4,"height","width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
2,
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
}
|
||||
},
|
||||
"CompVis/stable-diffusion-v1-4": {
|
||||
"unet": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
1,
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
2,
|
||||
"max_len",
|
||||
768
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
1,4,"height","width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
2,
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
[
|
||||
{
|
||||
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
|
||||
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
|
||||
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
|
||||
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
|
||||
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
|
||||
"openjourney/v1_4":"prompthero/openjourney",
|
||||
"dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
|
||||
},
|
||||
{
|
||||
"stablediffusion/fp16":"fp16",
|
||||
"stablediffusion/fp32":"main",
|
||||
"anythingv3/fp16":"diffusers",
|
||||
"anythingv3/fp32":"diffusers",
|
||||
"analogdiffusion/fp16":"main",
|
||||
"analogdiffusion/fp32":"main",
|
||||
"openjourney/fp16":"main",
|
||||
"openjourney/fp32":"main"
|
||||
}
|
||||
]
|
||||
@@ -2,13 +2,10 @@
|
||||
{
|
||||
"stablediffusion/untuned":"gs://shark_tank/stable_diffusion",
|
||||
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
|
||||
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
|
||||
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
|
||||
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
|
||||
"anythingv3/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
|
||||
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
|
||||
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
|
||||
"analogdiffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
|
||||
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
|
||||
"openjourney/tuned":"gs://shark_tank/sd_tuned",
|
||||
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
|
||||
@@ -16,26 +13,20 @@
|
||||
{
|
||||
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
|
||||
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned",
|
||||
"stablediffusion/v1_4/unet/fp16/length_77/tuned/cuda":"unet_8dec_fp16_cuda_tuned",
|
||||
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
|
||||
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32",
|
||||
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet2base_8dec_fp16",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_19dec_v2p1base_fp16_64",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae2base_19dec_fp16",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"vae2base_19dec_fp16_cuda_tuned",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base/cuda":"vae2base_8dec_fp16_cuda_tuned",
|
||||
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip2base_18dec_fp32",
|
||||
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_19dec_v2p1base_fp32_64",
|
||||
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet2_14dec_fp16",
|
||||
@@ -44,22 +35,18 @@
|
||||
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip2_18dec_fp32",
|
||||
"anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
|
||||
"anythingv3/v2_1base/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
|
||||
"anythingv3/v2_1base/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
|
||||
"anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
|
||||
"anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
|
||||
"anythingv3/v2_1base/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
|
||||
"anythingv3/v2_1base/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
|
||||
"anythingv3/v2_1base/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
|
||||
"anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
|
||||
"anythingv3/v2_1base/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
|
||||
"anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
|
||||
"analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
|
||||
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
|
||||
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
|
||||
"analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
|
||||
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
|
||||
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
|
||||
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
|
||||
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
|
||||
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
|
||||
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
{
|
||||
"unet": {
|
||||
"tuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": []
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": []
|
||||
}
|
||||
},
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32"
|
||||
],
|
||||
"specified_compilation_flags": {
|
||||
"cuda": ["--iree-flow-enable-conv-nchw-to-nhwc-transform"],
|
||||
"default_device": ["--iree-flow-enable-conv-img2col-transform"]
|
||||
}
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"tuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
"--iree-flow-enable-conv-img2col-transform"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
"--iree-flow-enable-conv-img2col-transform"
|
||||
]
|
||||
}
|
||||
},
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
"--iree-flow-enable-conv-img2col-transform"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"tuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
"--iree-flow-enable-padding-linalg-ops"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
"--iree-flow-enable-padding-linalg-ops"
|
||||
]
|
||||
}
|
||||
},
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
"--iree-flow-enable-padding-linalg-ops"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
"--iree-flow-enable-padding-linalg-ops"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
{
|
||||
"runwayml/stable-diffusion-v1-5": "CompVis/stable-diffusion-v1-4",
|
||||
"prompthero/openjourney": "CompVis/stable-diffusion-v1-4",
|
||||
"Linaqruf/anything-v3.0": "CompVis/stable-diffusion-v1-4",
|
||||
"stabilityai/stable-diffusion-2-1-base": "stabilityai/stable-diffusion-2-1",
|
||||
"dreamlike-art/dreamlike-diffusion-1.0": "CompVis/stable-diffusion-v1-4",
|
||||
"eimiss/EimisAnimeDiffusion_1.0v": "CompVis/stable-diffusion-v1-4",
|
||||
"claudfuen/photorealistic-fuen-v1": "CompVis/stable-diffusion-v1-4",
|
||||
"nitrosocke/Nitro-Diffusion": "CompVis/stable-diffusion-v1-4",
|
||||
"stabilityai/stable-diffusion-2-base": "stabilityai/stable-diffusion-2-1",
|
||||
"wavymulder/Analog-Diffusion": "CompVis/stable-diffusion-v1-4",
|
||||
"nitrosocke/redshift-diffusion": "CompVis/stable-diffusion-v1-4",
|
||||
"wavymulder/portraitplus": "CompVis/stable-diffusion-v1-4",
|
||||
"Linaqruf/anything-v3-better-vae": "CompVis/stable-diffusion-v1-4",
|
||||
"nitrosocke/Arcane-Diffusion": "CompVis/stable-diffusion-v1-4",
|
||||
"hakurei/waifu-diffusion": "stabilityai/stable-diffusion-2-1",
|
||||
"lambdalabs/sd-pokemon-diffusers": "CompVis/stable-diffusion-v1-4",
|
||||
"prompthero/openjourney-v2": "CompVis/stable-diffusion-v1-4",
|
||||
"andite/anything-v4.0": "CompVis/stable-diffusion-v1-4"
|
||||
}
|
||||
@@ -17,8 +17,8 @@ SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
|
||||
|
||||
model_input = {
|
||||
"euler": {
|
||||
"latent": torch.randn(1, 4, args.height // 8, args.width // 8),
|
||||
"output": torch.randn(1, 4, args.height // 8, args.width // 8),
|
||||
"latent": torch.randn(1, 4, 64, 64),
|
||||
"output": torch.randn(1, 4, 64, 64),
|
||||
"sigma": torch.tensor(1).to(torch.float32),
|
||||
"dt": torch.tensor(1).to(torch.float32),
|
||||
},
|
||||
@@ -84,8 +84,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
self.scaling_model = compile_through_fx(
|
||||
scaling_model,
|
||||
(example_latent, example_sigma),
|
||||
model_name=f"euler_scale_model_input_{args.height}_{args.width}"
|
||||
+ args.precision,
|
||||
model_name="euler_scale_model_input_" + args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
@@ -93,8 +92,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
self.step_model = compile_through_fx(
|
||||
step_model,
|
||||
(example_output, example_sigma, example_latent, example_dt),
|
||||
model_name=f"euler_step_{args.height}_{args.width}"
|
||||
+ args.precision,
|
||||
model_name="euler_step_" + args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -12,14 +12,11 @@ from opt_params import get_params
|
||||
from utils import set_init_device_flags
|
||||
|
||||
|
||||
set_init_device_flags()
|
||||
device = (
|
||||
args.device if "://" not in args.device else args.device.split("://")[0]
|
||||
)
|
||||
|
||||
# Downloads the model (Unet or VAE fp16) from shark_tank
|
||||
set_init_device_flags()
|
||||
shark_args.local_tank_cache = args.local_tank_cache
|
||||
bucket_key = f"{args.variant}/untuned"
|
||||
use_winograd = True
|
||||
if args.annotation_model == "unet":
|
||||
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/untuned"
|
||||
elif args.annotation_model == "vae":
|
||||
@@ -37,32 +34,29 @@ mlir_model, func_name, inputs, golden_out = download_model(
|
||||
|
||||
# Downloads the tuned config files from shark_tank
|
||||
config_bucket = "gs://shark_tank/sd_tuned/configs/"
|
||||
if args.use_winograd:
|
||||
config_name = f"{args.annotation_model}_winograd_{device}.json"
|
||||
if use_winograd:
|
||||
config_name = f"{args.annotation_model}_winograd.json"
|
||||
full_gs_url = config_bucket + config_name
|
||||
winograd_config_dir = f"{WORKDIR}configs/" + config_name
|
||||
download_public_file(full_gs_url, winograd_config_dir, True)
|
||||
|
||||
if args.annotation_model == "unet" or device == "cuda":
|
||||
if args.annotation_model == "unet":
|
||||
if args.variant in ["anythingv3", "analogdiffusion"]:
|
||||
args.max_length = 77
|
||||
args.version = "v1_4"
|
||||
if args.annotation_model == "vae":
|
||||
args.max_length = 77
|
||||
config_name = f"{args.annotation_model}_{args.version}_{args.precision}_len{args.max_length}_{device}.json"
|
||||
config_name = f"{args.annotation_model}_{args.version}_{args.precision}_len{args.max_length}.json"
|
||||
full_gs_url = config_bucket + config_name
|
||||
lowering_config_dir = f"{WORKDIR}configs/" + config_name
|
||||
download_public_file(full_gs_url, lowering_config_dir, True)
|
||||
|
||||
# Annotate the model with Winograd attribute on selected conv ops
|
||||
if args.use_winograd:
|
||||
if use_winograd:
|
||||
with create_context() as ctx:
|
||||
winograd_model = model_annotation(
|
||||
ctx,
|
||||
input_contents=mlir_model,
|
||||
config_path=winograd_config_dir,
|
||||
search_op="conv",
|
||||
winograd=args.use_winograd,
|
||||
winograd=use_winograd,
|
||||
)
|
||||
with open(
|
||||
f"{args.annotation_output}/{model_name}_tuned_torch.mlir", "w"
|
||||
@@ -70,8 +64,8 @@ if args.use_winograd:
|
||||
f.write(str(winograd_model))
|
||||
|
||||
# For Unet annotate the model with tuned lowering configs
|
||||
if args.annotation_model == "unet" or device == "cuda":
|
||||
if args.use_winograd:
|
||||
if args.annotation_model == "unet":
|
||||
if use_winograd:
|
||||
input_mlir = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
|
||||
dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
|
||||
else:
|
||||
@@ -79,22 +73,11 @@ if args.annotation_model == "unet" or device == "cuda":
|
||||
dump_after = "iree-flow-pad-linalg-ops"
|
||||
|
||||
# Dump IR after padding/img2col/winograd passes
|
||||
device_spec_args = ""
|
||||
if device == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
|
||||
gpu_flags = get_iree_gpu_args()
|
||||
for flag in gpu_flags:
|
||||
device_spec_args += flag + " "
|
||||
elif device == "vulkan":
|
||||
device_spec_args = (
|
||||
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
|
||||
)
|
||||
run_cmd(
|
||||
f"iree-compile {input_mlir} "
|
||||
"--iree-input-type=tm_tensor "
|
||||
f"--iree-hal-target-backends={iree_target_map(device)} "
|
||||
f"{device_spec_args}"
|
||||
f"--iree-hal-target-backends={iree_target_map(args.device)} "
|
||||
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
|
||||
"--iree-stream-resource-index-bits=64 "
|
||||
"--iree-vm-target-index-bits=64 "
|
||||
"--iree-flow-enable-padding-linalg-ops "
|
||||
|
||||
@@ -42,20 +42,6 @@ p.add_argument(
|
||||
help="the seed to use.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=512,
|
||||
help="the height of the output image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
default=512,
|
||||
help="the width of the output image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
@@ -78,13 +64,20 @@ p.add_argument(
|
||||
"--device", type=str, default="vulkan", help="device to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--version",
|
||||
type=str,
|
||||
default="v2_1base",
|
||||
help="Specify version of stable diffusion model",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--precision", type=str, default="fp16", help="precision to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--import_mlir",
|
||||
default=True,
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
|
||||
)
|
||||
@@ -117,6 +110,12 @@ p.add_argument(
|
||||
help="Do conversion from the VAE output to pixel space on cpu.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--variant",
|
||||
default="stablediffusion",
|
||||
help="We now support multiple vairants of SD finetuned for different dataset. you can use the following anythingv3, ...", # TODO add more once supported
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--scheduler",
|
||||
type=str,
|
||||
@@ -124,41 +123,6 @@ p.add_argument(
|
||||
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_img_format",
|
||||
type=str,
|
||||
default="png",
|
||||
help="specify the format in which output image is save. Supported options: jpg / png",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory path to save the output images and json",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--runs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of images to be generated with random seeds in single execution",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--ckpt_loc",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to SD's .ckpt file.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hf_model_id",
|
||||
type=str,
|
||||
default="stabilityai/stable-diffusion-2-1-base",
|
||||
help="The repo-id of hugging face.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
@@ -283,11 +247,4 @@ p.add_argument(
|
||||
help="Options are unet and vae.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_winograd",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Apply Winograd on selected conv ops.",
|
||||
)
|
||||
|
||||
args = p.parse_args()
|
||||
|
||||
@@ -21,7 +21,7 @@ KNOWN ISSUES with this special AMD driver:
|
||||
|
||||
## Installation
|
||||
|
||||
Download the latest Windows SHARK SD binary [455 here](https://storage.googleapis.com/shark-public/windows/shark_sd_20230120_455.exe) in a folder of your choice. If you want nighly builds you can look for them in the github releases page.
|
||||
Download the latest Windows SHARK SD binary [423 here](https://github.com/nod-ai/SHARK/releases/download/20230101.423/shark_sd_20230101_423.exe) in a folder of your choice. If you want nighly builds you can look for them in the github releases page. Please read carefully the following notes:
|
||||
|
||||
Notes:
|
||||
* We recommend that you download this EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files. Those contain Vulkan dispatches compiled from MLIR, that can get outdated if you run multiple EXE from the same folder. You can use `--clean_all` flag once to clean all the old files.
|
||||
@@ -60,8 +60,8 @@ Here are some samples generated:
|
||||
|
||||
|
||||
## Setup your Python VirtualEnvironment and Dependencies
|
||||
<details>
|
||||
<summary> Windows 10/11 Users </summary>
|
||||
|
||||
### Windows 10/11 Users
|
||||
|
||||
* Install the latest Python 3.10.x version from [here](https://www.python.org/downloads/windows/)
|
||||
|
||||
@@ -78,10 +78,8 @@ git clone https://github.com/nod-ai/SHARK.git
|
||||
cd SHARK
|
||||
./setup_venv.ps1 #You can re-run this script to get the latest version
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Linux</summary>
|
||||
### Linux
|
||||
|
||||
```shell
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
@@ -89,49 +87,33 @@ cd SHARK
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
```
|
||||
</details>
|
||||
|
||||
### Run Stable Diffusion on your device - WebUI
|
||||
|
||||
<details>
|
||||
<summary>Windows 10/11 Users</summary>
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\Users\nod\SHARK> cd web
|
||||
(shark.venv) PS C:\Users\nod\SHARK\web> python index.py
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Linux Users</summary>
|
||||
|
||||
#### Linux Users
|
||||
```shell
|
||||
(shark.venv) > cd web
|
||||
(shark.venv) > python index.py
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
### Run Stable Diffusion on your device - Commandline
|
||||
|
||||
<details>
|
||||
<summary>Windows 10/11 Users</summary>
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\g\shark> python .\shark\examples\shark_inference\stable_diffusion\main.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Linux</summary>
|
||||
|
||||
#### Linux
|
||||
```shell
|
||||
python3.10 shark/examples/shark_inference/stable_diffusion/main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
The output on a 6900XT would like:
|
||||
|
||||
@@ -148,10 +130,10 @@ Total image generation runtime (s): 10.390909433364868
|
||||
(shark.venv) PS C:\g\shark>
|
||||
```
|
||||
|
||||
|
||||
For more options to the Stable Diffusion model read [this](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md)
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<details>
|
||||
<summary>Discord link</summary>
|
||||
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
|
||||
</details>
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
You need to pre-create your bot (https://core.telegram.org/bots#how-do-i-create-a-bot)
|
||||
Then create in the directory web file .env
|
||||
In it the record:
|
||||
TG_TOKEN="your_token"
|
||||
specifying your bot's token from previous step.
|
||||
Then run telegram_bot.py with the same parameters that you use when running index.py, for example:
|
||||
python telegram_bot.py --max_length=77 --vulkan_large_heap_block_size=0 --use_base_vae --local_tank_cache h:\shark\TEMP
|
||||
|
||||
Bot commands:
|
||||
/select_model
|
||||
/select_scheduler
|
||||
/set_steps "integer number of steps"
|
||||
/set_guidance_scale "integer number"
|
||||
/set_negative_prompt "negative text"
|
||||
Any other text triggers the creation of an image based on it.
|
||||
@@ -7,9 +7,6 @@ from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
)
|
||||
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
|
||||
from resources import opt_flags
|
||||
import sys
|
||||
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
@@ -49,8 +46,6 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
|
||||
# Set local shark_tank cache directory.
|
||||
shark_args.local_tank_cache = args.local_tank_cache
|
||||
if "cuda" in args.device:
|
||||
shark_args.enable_tf32 = True
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
model_name,
|
||||
@@ -64,18 +59,10 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
|
||||
|
||||
# Converts the torch-module into a shark_module.
|
||||
def compile_through_fx(
|
||||
model,
|
||||
inputs,
|
||||
model_name,
|
||||
is_f16=False,
|
||||
f16_input_mask=None,
|
||||
extra_args=[],
|
||||
):
|
||||
def compile_through_fx(model, inputs, model_name, extra_args=[]):
|
||||
|
||||
mlir_module, func_name = import_with_fx(model, inputs)
|
||||
|
||||
mlir_module, func_name = import_with_fx(
|
||||
model, inputs, is_f16, f16_input_mask
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
@@ -193,49 +180,27 @@ def set_init_device_flags():
|
||||
args.device = "cpu"
|
||||
|
||||
# set max_length based on availability.
|
||||
if args.hf_model_id in [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"dreamlike-art/dreamlike-diffusion-1.0",
|
||||
]:
|
||||
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
|
||||
args.max_length = 77
|
||||
elif args.hf_model_id == "prompthero/openjourney":
|
||||
elif args.variant == "openjourney":
|
||||
args.max_length = 64
|
||||
|
||||
# Use tuned models in the case of stablediffusion/fp16 and rdna3 cards.
|
||||
# use tuned models only in the case of stablediffusion/fp16 and rdna3 cards.
|
||||
if (
|
||||
args.hf_model_id
|
||||
in ["prompthero/openjourney", "dreamlike-art/dreamlike-diffusion-1.0"]
|
||||
args.variant in ["openjourney", "dreamlike"]
|
||||
or args.precision != "fp16"
|
||||
or "vulkan" not in args.device
|
||||
or "rdna3" not in args.iree_vulkan_target_triple
|
||||
):
|
||||
args.use_tuned = False
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
elif args.use_base_vae and args.hf_model_id not in [
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
]:
|
||||
elif args.use_base_vae and args.variant != "stablediffusion":
|
||||
args.use_tuned = False
|
||||
|
||||
# Use tuned model in the case of stablediffusion/fp16 and cuda device sm_80
|
||||
if (
|
||||
args.hf_model_id
|
||||
in [
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"Linaqruf/anything-v3.0",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
]
|
||||
and args.precision == "fp16"
|
||||
and "cuda" in args.device
|
||||
and get_cuda_sm_cc() == "sm_80"
|
||||
):
|
||||
args.use_tuned = True
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
if args.use_tuned:
|
||||
print(f"Using {args.device} tuned models for stablediffusion/fp16.")
|
||||
else:
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
print("Using tuned models for stablediffusion/fp16 and rdna3 card.")
|
||||
|
||||
|
||||
# Utility to get list of devices available.
|
||||
@@ -264,83 +229,3 @@ def get_available_devices():
|
||||
available_devices.extend(cuda_devices)
|
||||
available_devices.append("cpu")
|
||||
return available_devices
|
||||
|
||||
|
||||
def disk_space_check(path, lim=20):
|
||||
from shutil import disk_usage
|
||||
|
||||
du = disk_usage(path)
|
||||
free = du.free / (1024 * 1024 * 1024)
|
||||
if free <= lim:
|
||||
print(f"[WARNING] Only {free:.2f}GB space available in {path}.")
|
||||
|
||||
|
||||
def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags = []
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
|
||||
if "specified_compilation_flags" in opt_flags[model][is_tuned][precision]:
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else args.device.split("://")[0]
|
||||
)
|
||||
if (
|
||||
device
|
||||
not in opt_flags[model][is_tuned][precision][
|
||||
"specified_compilation_flags"
|
||||
]
|
||||
):
|
||||
device = "default_device"
|
||||
iree_flags += opt_flags[model][is_tuned][precision][
|
||||
"specified_compilation_flags"
|
||||
][device]
|
||||
|
||||
return iree_flags
|
||||
|
||||
|
||||
def preprocessCKPT():
|
||||
from pathlib import Path
|
||||
|
||||
path = Path(args.ckpt_loc)
|
||||
diffusers_path = path.parent.absolute()
|
||||
diffusers_directory_name = path.stem
|
||||
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
|
||||
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
|
||||
print(
|
||||
"Created directory : ",
|
||||
diffusers_directory_name,
|
||||
" at -> ",
|
||||
diffusers_path,
|
||||
)
|
||||
path_to_diffusers = complete_path_to_diffusers.as_posix()
|
||||
# TODO: Use the SD to Diffusers CKPT pipeline once it's included in the release.
|
||||
sd_to_diffusers = os.path.join(os.getcwd(), "sd_to_diffusers.py")
|
||||
if not os.path.isfile(sd_to_diffusers):
|
||||
url = "https://raw.githubusercontent.com/huggingface/diffusers/8a3f0c1f7178f4a3d5a5b21ae8c2906f473e240d/scripts/convert_original_stable_diffusion_to_diffusers.py"
|
||||
import requests
|
||||
|
||||
req = requests.get(url)
|
||||
open(sd_to_diffusers, "wb").write(req.content)
|
||||
print("Downloaded SD to Diffusers converter")
|
||||
else:
|
||||
print("SD to Diffusers converter already exists")
|
||||
|
||||
os.system(
|
||||
"python "
|
||||
+ sd_to_diffusers
|
||||
+ " --checkpoint_path="
|
||||
+ args.ckpt_loc
|
||||
+ " --dump_path="
|
||||
+ path_to_diffusers
|
||||
)
|
||||
args.ckpt_loc = path_to_diffusers
|
||||
print("Custom model path is : ", args.ckpt_loc)
|
||||
|
||||
@@ -9,9 +9,9 @@ model_input = {
|
||||
"clip": (torch.randint(1, 2, (1, 77)),),
|
||||
"vae": (torch.randn(1, 4, 128, 128),),
|
||||
"unet": (
|
||||
torch.randn(2, 7, 128, 128), # latents
|
||||
torch.randn(2, 7, 128, 128).half(), # latents
|
||||
torch.tensor([1]).to(torch.float32), # timestep
|
||||
torch.randn(2, 77, 1024), # embedding
|
||||
torch.randn(2, 77, 1024).half(), # embedding
|
||||
torch.randn(2).to(torch.int64), # noise_level
|
||||
),
|
||||
}
|
||||
@@ -72,6 +72,7 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
revision="fp16",
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
@@ -87,13 +88,12 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
|
||||
return unet_out
|
||||
|
||||
unet = UnetModel()
|
||||
f16_input_mask = (True, True, True, False)
|
||||
unet = unet.half().cuda()
|
||||
inputs = tuple([inputs.cuda() for inputs in model_input["unet"]])
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
model_input["unet"],
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
is_f16=True,
|
||||
f16_input_mask=f16_input_mask,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
@@ -59,15 +59,12 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
|
||||
|
||||
# Converts the torch-module into a shark_module.
|
||||
def compile_through_fx(
|
||||
model, inputs, model_name, is_f16=False, f16_input_mask=None, extra_args=[]
|
||||
):
|
||||
def compile_through_fx(model, inputs, model_name, extra_args=[]):
|
||||
|
||||
mlir_module, func_name = import_with_fx(model, inputs)
|
||||
|
||||
mlir_module, func_name = import_with_fx(
|
||||
model, inputs, is_f16, f16_input_mask
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
"hello",
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from torch.nn.utils import _stateless
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
from shark.shark_runner import SharkTrainer
|
||||
|
||||
|
||||
class MiniLMSequenceClassification(torch.nn.Module):
|
||||
@@ -42,7 +42,6 @@ def forward(params, buffers, args):
|
||||
return params, buffers
|
||||
|
||||
|
||||
shark_module = SharkTrainer(mod, inp)
|
||||
shark_module.compile(forward)
|
||||
shark_module = SharkTrainer(mod, inp, custom_inference_fn=forward)
|
||||
|
||||
print(shark_module.train())
|
||||
print(shark_module.forward())
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# All the iree_cpu related functionalities go here.
|
||||
|
||||
import subprocess
|
||||
import platform
|
||||
|
||||
|
||||
def get_cpu_count():
|
||||
@@ -30,16 +29,25 @@ def get_cpu_count():
|
||||
|
||||
# Get the default cpu args.
|
||||
def get_iree_cpu_args():
|
||||
uname = platform.uname()
|
||||
os_name, proc_name = uname.system, uname.machine
|
||||
|
||||
find_triple_cmd = "uname -s -m"
|
||||
os_name, proc_name = (
|
||||
subprocess.run(
|
||||
find_triple_cmd, shell=True, stdout=subprocess.PIPE, check=True
|
||||
)
|
||||
.stdout.decode("utf-8")
|
||||
.split()
|
||||
)
|
||||
if os_name == "Darwin":
|
||||
kernel_version = uname.release
|
||||
find_kernel_version_cmd = "uname -r"
|
||||
kernel_version = subprocess.run(
|
||||
find_kernel_version_cmd,
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
check=True,
|
||||
).stdout.decode("utf-8")
|
||||
target_triple = f"{proc_name}-apple-darwin{kernel_version}"
|
||||
elif os_name == "Linux":
|
||||
target_triple = f"{proc_name}-linux-gnu"
|
||||
elif os_name == "Windows":
|
||||
target_triple = "x86_64-pc-windows-msvc"
|
||||
else:
|
||||
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
|
||||
raise Exception(error_message)
|
||||
|
||||
@@ -1,470 +0,0 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def get_vulkan_target_env(vulkan_target_triple):
|
||||
|
||||
arch, product, os = vulkan_target_triple.split("=")[1].split("-")
|
||||
triple = (arch, product, os)
|
||||
# get version
|
||||
version = get_version(triple=triple)
|
||||
# TODO get revision
|
||||
revision = 120
|
||||
|
||||
# extensions
|
||||
extensions = get_extensions(triple)
|
||||
# get vendor
|
||||
vendor = get_vendor(triple)
|
||||
# get device type
|
||||
device_type = get_device_type(triple)
|
||||
# get capabilities
|
||||
capabilities = get_vulkan_target_capabilities(triple)
|
||||
target_env = f"#vk.target_env<{version}, r({revision}), {extensions}, {vendor}:{device_type}, #vk.caps< {capabilities} >>"
|
||||
return target_env
|
||||
|
||||
|
||||
def get_vulkan_target_env_flag(vulkan_target_triple):
|
||||
|
||||
target_env = get_vulkan_target_env(vulkan_target_triple)
|
||||
target_env_flag = f"--iree-vulkan-target-env={target_env}"
|
||||
return target_env_flag
|
||||
|
||||
|
||||
def get_version(triple):
|
||||
arch, product, os = triple
|
||||
if os in ["android30", "android31"]:
|
||||
return "v1.1"
|
||||
if product in ["android30", "android31"]:
|
||||
return "v1.1"
|
||||
if arch in ["unknown"]:
|
||||
return "v1.1"
|
||||
return "v1.3"
|
||||
|
||||
|
||||
def get_extensions(triple):
|
||||
def make_ext_list(ext_list):
|
||||
res = ""
|
||||
for e in ext_list:
|
||||
res += e + ", "
|
||||
res = f"[{res[:-2]}]"
|
||||
return res
|
||||
|
||||
arch, product, os = triple
|
||||
if arch == "m1":
|
||||
ext = [
|
||||
"VK_KHR_16bit_storage",
|
||||
"VK_KHR_8bit_storage",
|
||||
"VK_KHR_shader_float16_int8",
|
||||
"VK_KHR_storage_buffer_storage_class",
|
||||
"VK_KHR_variable_pointers",
|
||||
]
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
if arch == "valhall":
|
||||
ext = [
|
||||
"VK_KHR_16bit_storage",
|
||||
"VK_KHR_8bit_storage",
|
||||
"VK_KHR_shader_float16_int8",
|
||||
"VK_KHR_spirv_1_4",
|
||||
"VK_KHR_storage_buffer_storage_class",
|
||||
"VK_KHR_variable_pointers",
|
||||
]
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
if arch == "adreno":
|
||||
ext = [
|
||||
"VK_KHR_16bit_storage",
|
||||
"VK_KHR_shader_float16_int8",
|
||||
"VK_KHR_spirv_1_4",
|
||||
"VK_KHR_storage_buffer_storage_class",
|
||||
"VK_KHR_variable_pointers",
|
||||
]
|
||||
if os == "android31":
|
||||
ext.append("VK_KHR_8bit_storage")
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
if get_vendor(triple) == "SwiftShader":
|
||||
ext = ["VK_KHR_storage_buffer_storage_class"]
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
if arch == "unknown":
|
||||
ext = [
|
||||
"VK_KHR_storage_buffer_storage_class",
|
||||
"VK_KHR_variable_pointers",
|
||||
]
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
ext = [
|
||||
"VK_KHR_16bit_storage",
|
||||
"VK_KHR_8bit_storage",
|
||||
"VK_KHR_shader_float16_int8",
|
||||
"VK_KHR_spirv_1_4",
|
||||
"VK_KHR_storage_buffer_storage_class",
|
||||
"VK_KHR_variable_pointers",
|
||||
"VK_EXT_subgroup_size_control",
|
||||
]
|
||||
|
||||
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
|
||||
ext.append("VK_NV_cooperative_matrix")
|
||||
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
|
||||
def get_vendor(triple):
|
||||
|
||||
arch, product, os = triple
|
||||
if arch == "unknown":
|
||||
return "Unknown"
|
||||
if arch in ["rdna1", "rdna2", "rdna3", "rgcn3", "rgcn4", "rgcn5"]:
|
||||
return "AMD"
|
||||
if arch == "valhall":
|
||||
return "ARM"
|
||||
if arch == "m1":
|
||||
return "Apple"
|
||||
if arch in ["turing", "ampere"]:
|
||||
return "NVIDIA"
|
||||
if arch == "ardeno":
|
||||
return "Qualcomm"
|
||||
if arch == "cpu":
|
||||
if product == "swiftshader":
|
||||
return "SwiftShader"
|
||||
return "Unknown"
|
||||
print(f"Vendor for target triple - {triple} not found. Using unknown")
|
||||
return "Unknown"
|
||||
|
||||
|
||||
def get_device_type(triple):
|
||||
arch, product, _ = triple
|
||||
if arch == "unknown":
|
||||
return "Unknown"
|
||||
if arch == "cpu":
|
||||
return "CPU"
|
||||
if arch in ["turing", "ampere"]:
|
||||
return "DiscreteGPU"
|
||||
if arch in ["rdna1", "rdna2", "rdna3", "rgcn3", "rgcn5"]:
|
||||
if product == "ivega10":
|
||||
return "IntegratedGPU"
|
||||
return "DiscreteGPU"
|
||||
if arch in ["m1", "valhall", "adreno"]:
|
||||
return "IntegratedGPU"
|
||||
print(f"Device type for target triple - {triple} not found. Using unknown")
|
||||
return "Unknown"
|
||||
|
||||
|
||||
# get all the capabilities for the device
|
||||
# TODO: make a dataclass for capabilites and init using vulkaninfo
|
||||
def get_vulkan_target_capabilities(triple):
|
||||
def get_subgroup_val(l):
|
||||
return int(sum([subgroup_feature[sgf] for sgf in l]))
|
||||
|
||||
cap = OrderedDict()
|
||||
arch, product, os = triple
|
||||
subgroup_feature = {
|
||||
"Basic": 1,
|
||||
"Vote": 2,
|
||||
"Arithmetic": 4,
|
||||
"Ballot": 8,
|
||||
"Shuffle": 16,
|
||||
"ShuffleRelative": 32,
|
||||
"Clustered": 64,
|
||||
"Quad": 128,
|
||||
"PartitionedNV": 256,
|
||||
}
|
||||
cap["maxComputeSharedMemorySize"] = 16384
|
||||
cap["maxComputeWorkGroupInvocations"] = 128
|
||||
cap["maxComputeWorkGroupSize"] = [128, 128, 64]
|
||||
cap["subgroupSize"] = 32
|
||||
cap["subgroupFeatures"] = ["Basic"]
|
||||
cap["minSubgroupSize"] = None
|
||||
cap["maxSubgroupSize"] = None
|
||||
cap["shaderFloat16"] = False
|
||||
cap["shaderFloat64"] = False
|
||||
cap["shaderInt8"] = False
|
||||
cap["shaderInt16"] = False
|
||||
cap["shaderInt64"] = False
|
||||
cap["storageBuffer16BitAccess"] = False
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = False
|
||||
cap["storageBuffer8BitAccess"] = False
|
||||
cap["storagePushConstant8"] = False
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = False
|
||||
cap["variablePointers"] = False
|
||||
cap["variablePointersStorageBuffer"] = False
|
||||
cap["coopmatCases"] = None
|
||||
|
||||
if arch in ["rdna1", "rdna2", "rdna3"]:
|
||||
|
||||
cap["maxComputeSharedMemorySize"] = 65536
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroupSize"] = 64
|
||||
cap["minSubgroupSize"] = 32
|
||||
cap["maxSubgroupSize"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
if arch == "rdna3":
|
||||
# TODO: Get scope value
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
|
||||
]
|
||||
if product == "rx5700xt":
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["storagePushConstant8"] = False
|
||||
|
||||
elif arch in ["rgcn5", "rgcn4", "rgcn3"]:
|
||||
cap["maxComputeSharedMemorySize"] = 65536
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroupSize"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
cap["minSubgroupSize"] = 64
|
||||
cap["maxSubgroupSize"] = 64
|
||||
|
||||
if arch == "rgcn5":
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = True
|
||||
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = False
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "m1":
|
||||
|
||||
cap["maxComputeSharedMemorySize"] = 32768
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroupSize"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "valhall":
|
||||
cap["maxComputeSharedMemorySize"] = 32768
|
||||
cap["maxComputeWorkGroupInvocations"] = 512
|
||||
cap["maxComputeWorkGroupSize"] = [512, 512, 512]
|
||||
|
||||
cap["subgroupSize"] = 16
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
if os == "android31":
|
||||
cap["subgroupFeatures"].append("Shuffle")
|
||||
cap["subgroupFeatures"].append("ShuffleRelative")
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "cpu":
|
||||
if product == "swiftshader":
|
||||
cap["maxComputeSharedMemorySize"] = 16384
|
||||
cap["subgroupSize"] = 4
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
]
|
||||
|
||||
elif arch in ["ampere", "turing"]:
|
||||
|
||||
cap["maxComputeSharedMemorySize"] = 49152
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroupSize"] = 32
|
||||
cap["minSubgroupSize"] = 32
|
||||
cap["maxSubgroupSize"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, scope = #vk.scope<Subgroup>",
|
||||
]
|
||||
|
||||
elif arch == "adreno":
|
||||
|
||||
cap["maxComputeSharedMemorySize"] = 32768
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 64]
|
||||
|
||||
cap["subgroupSize"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
if os == "andorid31":
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "unknown":
|
||||
cap["subgroupSize"] = 64
|
||||
cap["variablePointers"] = False
|
||||
cap["variablePointersStorageBuffer"] = False
|
||||
else:
|
||||
print(
|
||||
f"Architecture {arch} not matched. Using default vulkan target device capability"
|
||||
)
|
||||
|
||||
def get_comma_sep_str(ele_list):
|
||||
l = ""
|
||||
for ele in ele_list:
|
||||
l += f"{ele}, "
|
||||
l = f"[{l[:-2]}]"
|
||||
return l
|
||||
|
||||
res = ""
|
||||
for k, v in cap.items():
|
||||
|
||||
if v is None or v == False:
|
||||
continue
|
||||
if isinstance(v, bool):
|
||||
res += f"{k} = {'unit' if v == True else None}, "
|
||||
elif isinstance(v, list):
|
||||
if k == "subgroupFeatures":
|
||||
res += f"subgroupFeatures = {get_subgroup_val(v)}: i32, "
|
||||
elif k == "maxComputeWorkGroupSize":
|
||||
res += f"maxComputeWorkGroupSize = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, "
|
||||
elif k == "coopmatCases":
|
||||
cmc = ""
|
||||
for case in v:
|
||||
cmc += f"#vk.coop_matrix_props<{case}>, "
|
||||
res += f"cooperativeMatrixPropertiesNV = [{cmc[:-2]}], "
|
||||
else:
|
||||
res += f"{k} = {get_comma_sep_str(v)}, "
|
||||
else:
|
||||
res += f"{k} = {v}, "
|
||||
res = res[:-2]
|
||||
return res
|
||||
@@ -18,7 +18,6 @@ from os import linesep
|
||||
from shark.iree_utils._common import run_cmd
|
||||
import iree.runtime as ireert
|
||||
from sys import platform
|
||||
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
|
||||
|
||||
def get_vulkan_device_name():
|
||||
@@ -71,8 +70,6 @@ def get_vulkan_target_triple(device_name):
|
||||
triple = f"ampere-rtx3090-{system_os}"
|
||||
elif all(x in device_name for x in ("RTX", "4090")):
|
||||
triple = f"ampere-rtx3090-{system_os}"
|
||||
elif all(x in device_name for x in ("RTX", "4080")):
|
||||
triple = f"ampere-rtx3090-{system_os}"
|
||||
elif all(x in device_name for x in ("RTX", "4000")):
|
||||
triple = f"turing-rtx4000-{system_os}"
|
||||
elif all(x in device_name for x in ("RTX", "5000")):
|
||||
@@ -91,9 +88,7 @@ def get_vulkan_target_triple(device_name):
|
||||
triple = f"pascal-gtx1080-{system_os}"
|
||||
|
||||
# Amd Targets
|
||||
# Linux: Radeon RX 7900 XTX
|
||||
# Windows: AMD Radeon RX 7900 XTX
|
||||
elif all(x in device_name for x in ("RX", "7900")):
|
||||
elif all(x in device_name for x in ("AMD", "7900")):
|
||||
triple = f"rdna3-7900-{system_os}"
|
||||
elif any(x in device_name for x in ("AMD", "Radeon")):
|
||||
triple = f"rdna2-unknown-{system_os}"
|
||||
@@ -102,16 +97,15 @@ def get_vulkan_target_triple(device_name):
|
||||
return triple
|
||||
|
||||
|
||||
def get_vulkan_triple_flag(device_name="", extra_args=[]):
|
||||
def get_vulkan_triple_flag(device_name=None, extra_args=[]):
|
||||
for flag in extra_args:
|
||||
if "-iree-vulkan-target-triple=" in flag:
|
||||
print(f"Using target triple {flag.split('=')[1]}")
|
||||
return None
|
||||
|
||||
if device_name == "" or device_name == [] or device_name is None:
|
||||
vulkan_device = get_vulkan_device_name()
|
||||
else:
|
||||
vulkan_device = device_name
|
||||
vulkan_device = (
|
||||
device_name if device_name is not None else get_vulkan_device_name()
|
||||
)
|
||||
triple = get_vulkan_target_triple(vulkan_device)
|
||||
if triple is not None:
|
||||
print(
|
||||
@@ -128,23 +122,11 @@ def get_vulkan_triple_flag(device_name="", extra_args=[]):
|
||||
|
||||
|
||||
def get_iree_vulkan_args(extra_args=[]):
|
||||
# vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
|
||||
res_vulkan_flag = []
|
||||
vulkan_triple_flag = None
|
||||
for arg in extra_args:
|
||||
if "-iree-vulkan-target-triple=" in arg:
|
||||
print(f"Using target triple {arg} from command line args")
|
||||
vulkan_triple_flag = arg
|
||||
break
|
||||
|
||||
if vulkan_triple_flag is None:
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
|
||||
|
||||
vulkan_flag = []
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
|
||||
if vulkan_triple_flag is not None:
|
||||
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)
|
||||
res_vulkan_flag.append(vulkan_target_env)
|
||||
return res_vulkan_flag
|
||||
vulkan_flag.append(vulkan_triple_flag)
|
||||
return vulkan_flag
|
||||
|
||||
|
||||
def set_iree_vulkan_runtime_flags(flags):
|
||||
|
||||
@@ -23,8 +23,6 @@ from datetime import datetime
|
||||
import time
|
||||
import csv
|
||||
import os
|
||||
import torch
|
||||
import torch._dynamo as dynamo
|
||||
|
||||
|
||||
class OnnxFusionOptions(object):
|
||||
@@ -67,7 +65,6 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
extra_args: list = [],
|
||||
):
|
||||
self.device = shark_args.device if device == "none" else device
|
||||
self.enable_tf32 = shark_args.enable_tf32
|
||||
self.frontend_model = None
|
||||
self.vmfb_file = None
|
||||
self.mlir_dialect = mlir_dialect
|
||||
@@ -110,8 +107,6 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
|
||||
if self.device == "cuda":
|
||||
torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
||||
if self.enable_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
torch_device = torch.device(
|
||||
@@ -119,7 +114,6 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
)
|
||||
HFmodel, input = get_torch_model(modelname)[:2]
|
||||
frontend_model = HFmodel.model
|
||||
frontend_model = dynamo.optimize("inductor")(frontend_model)
|
||||
frontend_model.to(torch_device)
|
||||
input.to(torch_device)
|
||||
|
||||
|
||||
@@ -245,119 +245,8 @@ class SharkImporter:
|
||||
)
|
||||
|
||||
|
||||
def get_f16_inputs(inputs, is_f16, f16_input_mask):
|
||||
|
||||
if is_f16 == False:
|
||||
return inputs
|
||||
if f16_input_mask == None:
|
||||
return tuple([x.half() for x in inputs])
|
||||
|
||||
f16_masked_inputs = []
|
||||
for i in range(len(inputs)):
|
||||
if f16_input_mask[i]:
|
||||
f16_masked_inputs.append(inputs[i].half())
|
||||
else:
|
||||
f16_masked_inputs.append(inputs[i])
|
||||
|
||||
return tuple(f16_masked_inputs)
|
||||
|
||||
|
||||
def transform_fx(fx_g):
|
||||
import torch
|
||||
|
||||
kwargs_dict = {
|
||||
"dtype": torch.float16,
|
||||
"device": torch.device(type="cpu"),
|
||||
"pin_memory": False,
|
||||
}
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.arange,
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
node.kwargs = kwargs_dict
|
||||
# Inputs and outputs of aten.var.mean should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten.var_mean]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[0], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (new_node, node.args[1])
|
||||
if node.name.startswith("getitem"):
|
||||
with fx_g.graph.inserting_before(node):
|
||||
if node.args[0].target in [torch.ops.aten.var_mean]:
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(node,),
|
||||
kwargs={"dtype": torch.float16},
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
new_node.kwargs = {"dtype": torch.float16}
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
# Doesn't replace the None type.
|
||||
def change_fx_graph_return_to_tuple(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
# output nodes always have one argument
|
||||
node_arg = node.args[0]
|
||||
out_nodes = []
|
||||
if isinstance(node_arg, list):
|
||||
# Don't return NoneType elements.
|
||||
for out_node in node_arg:
|
||||
if not isinstance(out_node, type(None)):
|
||||
out_nodes.append(out_node)
|
||||
# If there is a single tensor/element to be returned don't
|
||||
# a tuple for it.
|
||||
if len(out_nodes) == 1:
|
||||
node.args = out_nodes
|
||||
else:
|
||||
node.args = (tuple(out_nodes),)
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return fx_g
|
||||
|
||||
|
||||
def flatten_training_input(inputs):
|
||||
flattened_input = []
|
||||
for i in inputs:
|
||||
if isinstance(i, dict):
|
||||
for value in i.values():
|
||||
flattened_input.append(value.detach())
|
||||
elif isinstance(i, tuple):
|
||||
for value in i:
|
||||
flattened_input.append(value)
|
||||
else:
|
||||
flattened_input.append(i)
|
||||
return tuple(flattened_input)
|
||||
|
||||
|
||||
# Applies fx conversion to the model and imports the mlir.
|
||||
def import_with_fx(
|
||||
model,
|
||||
inputs,
|
||||
is_f16=False,
|
||||
f16_input_mask=None,
|
||||
debug=False,
|
||||
training=False,
|
||||
):
|
||||
def import_with_fx(model, inputs, debug=False):
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
@@ -397,26 +286,16 @@ def import_with_fx(
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
if is_f16:
|
||||
fx_g = fx_g.half()
|
||||
transform_fx(fx_g)
|
||||
fx_g.recompile()
|
||||
|
||||
if training:
|
||||
change_fx_graph_return_to_tuple(fx_g)
|
||||
inputs = flatten_training_input(inputs)
|
||||
|
||||
ts_graph = torch.jit.script(fx_g)
|
||||
inputs = get_f16_inputs(inputs, is_f16, f16_input_mask)
|
||||
mlir_importer = SharkImporter(
|
||||
ts_graph,
|
||||
fx_g,
|
||||
inputs,
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
if debug and not is_f16:
|
||||
if debug:
|
||||
(mlir_module, func_name), _, _ = mlir_importer.import_debug()
|
||||
return mlir_module, func_name
|
||||
|
||||
mlir_module, func_name = mlir_importer.import_mlir()
|
||||
|
||||
return mlir_module, func_name
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
from shark.parser import shark_args
|
||||
from shark.shark_runner import SharkRunner
|
||||
from shark.backward_makefx import MakeFxModule
|
||||
from shark.shark_importer import import_with_fx
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import sys
|
||||
@@ -68,21 +67,23 @@ class SharkTrainer:
|
||||
self.frontend = frontend
|
||||
|
||||
# Training function is needed in the case of torch_fn.
|
||||
def compile(self, training_fn=None, extra_args=[]):
|
||||
def compile(self, training_fn=None):
|
||||
if self.frontend in ["torch", "pytorch"]:
|
||||
packed_inputs = (
|
||||
dict(self.model.named_parameters()),
|
||||
dict(self.model.named_buffers()),
|
||||
tuple(self.input),
|
||||
)
|
||||
mlir_module, func_name = import_with_fx(
|
||||
training_fn, packed_inputs, False, [], training=True
|
||||
aot_module = MakeFxModule(
|
||||
self.model, tuple(self.input), custom_inference_fn=training_fn
|
||||
)
|
||||
aot_module.generate_graph()
|
||||
# Returns the backward graph.
|
||||
training_graph = aot_module.training_graph
|
||||
weights = self.get_torch_params()
|
||||
self.shark_runner = SharkRunner(
|
||||
mlir_module,
|
||||
training_graph,
|
||||
weights + self.input,
|
||||
self.dynamic,
|
||||
self.device,
|
||||
"tm_tensor",
|
||||
extra_args=extra_args,
|
||||
self.jit_trace,
|
||||
self.from_aot,
|
||||
self.frontend,
|
||||
)
|
||||
elif self.frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
self.shark_runner = SharkRunner(
|
||||
@@ -111,8 +112,8 @@ class SharkTrainer:
|
||||
params = [x.numpy() for x in params]
|
||||
print(f"Training started for {num_iters} iterations:")
|
||||
for i in tqdm(range(num_iters)):
|
||||
params = self.shark_runner.run(
|
||||
"forward", params + self.input, self.frontend
|
||||
params = self.shark_runner.forward(
|
||||
params + self.input, self.frontend
|
||||
)
|
||||
|
||||
return params
|
||||
|
||||
67
web/demo.css
67
web/demo.css
@@ -1,67 +0,0 @@
|
||||
.gradio-container {
|
||||
background-color: black
|
||||
}
|
||||
|
||||
.container {
|
||||
background-color: black !important;
|
||||
padding-top: 20px !important;
|
||||
}
|
||||
|
||||
#ui_title {
|
||||
padding: 10px !important;
|
||||
}
|
||||
|
||||
#top_logo {
|
||||
background-color: transparent;
|
||||
border-radius: 0 !important;
|
||||
border: 0;
|
||||
}
|
||||
|
||||
#demo_title {
|
||||
background-color: black;
|
||||
border-radius: 0 !important;
|
||||
border: 0;
|
||||
padding-top: 50px;
|
||||
padding-bottom: 0px;
|
||||
width: 460px !important;
|
||||
}
|
||||
|
||||
#demo_title_outer {
|
||||
border-radius: 0;
|
||||
}
|
||||
|
||||
#prompt_box_outer div:first-child {
|
||||
border-radius: 0 !important
|
||||
}
|
||||
|
||||
#prompt_box textarea {
|
||||
background-color: #1d1d1d !important
|
||||
}
|
||||
|
||||
#prompt_examples {
|
||||
margin: 0 !important
|
||||
}
|
||||
|
||||
#prompt_examples svg {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
.gr-sample-textbox {
|
||||
border-radius: 1rem !important;
|
||||
border-color: rgb(31, 41, 55) !important;
|
||||
border-width: 2px !important;
|
||||
}
|
||||
|
||||
#ui_body {
|
||||
background-color: #111111 !important;
|
||||
padding: 10px !important;
|
||||
border-radius: 0.5em !important;
|
||||
}
|
||||
|
||||
#img_result+div {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
footer {
|
||||
display: none !important;
|
||||
}
|
||||
29
web/index.py
29
web/index.py
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
os.environ["AMD_ENABLE_LLPC"] = "1"
|
||||
import gradio as gr
|
||||
@@ -13,7 +12,26 @@ nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
sdlogo_loc = resource_path("logos/sd-demo-logo.png")
|
||||
|
||||
|
||||
demo_css = Path(__file__).parent.joinpath("demo.css").resolve()
|
||||
demo_css = """
|
||||
.gradio-container {background-color: black}
|
||||
.container {background-color: black !important; padding-top:20px !important; }
|
||||
#ui_title {padding: 10px !important; }
|
||||
#top_logo {background-color: transparent; border-radius: 0 !important; border: 0; }
|
||||
#demo_title {background-color: black; border-radius: 0 !important; border: 0; padding-top: 50px; padding-bottom: 0px; width: 460px !important;}
|
||||
|
||||
#demo_title_outer {border-radius: 0; }
|
||||
#prompt_box_outer div:first-child {border-radius: 0 !important}
|
||||
#prompt_box textarea {background-color:#1d1d1d !important}
|
||||
#prompt_examples {margin:0 !important}
|
||||
#prompt_examples svg {display: none !important;}
|
||||
|
||||
.gr-sample-textbox { border-radius: 1rem !important; border-color: rgb(31,41,55) !important; border-width:2px !important; }
|
||||
#ui_body {background-color: #111111 !important; padding: 10px !important; border-radius: 0.5em !important;}
|
||||
|
||||
#img_result+div {display: none !important;}
|
||||
|
||||
footer {display: none !important;}
|
||||
"""
|
||||
|
||||
|
||||
with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
|
||||
@@ -123,13 +141,6 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
|
||||
lines=4,
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
|
||||
prompt.submit(
|
||||
stable_diff_inf,
|
||||
|
||||
@@ -16,7 +16,6 @@ from models.stable_diffusion.stable_args import args
|
||||
from models.stable_diffusion.schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
)
|
||||
import gc
|
||||
|
||||
|
||||
model_config = {
|
||||
@@ -82,25 +81,16 @@ class ModelCache:
|
||||
self.version = None
|
||||
self.schedulers = None
|
||||
self.tokenizer = None
|
||||
self.vae = None
|
||||
self.clip = None
|
||||
self.unet = None
|
||||
|
||||
def set_models(self, device_key):
|
||||
if self.device != device_key or self.variant != args.variant:
|
||||
self.device = device_key
|
||||
self.variant = args.variant
|
||||
self.version = args.version
|
||||
args.device = device_key.split("=>", 1)[1].strip()
|
||||
args.device = device_key.split("=>", 1)[0].strip()
|
||||
args.max_length = 64
|
||||
args.use_tuned = True
|
||||
set_init_device_flags()
|
||||
del self.schedulers
|
||||
del self.tokenizer
|
||||
del self.vae
|
||||
del self.unet
|
||||
del self.clip
|
||||
gc.collect()
|
||||
self.schedulers = get_schedulers(args.version)
|
||||
self.tokenizer = get_tokenizer(args.version)
|
||||
self.vae = get_vae()
|
||||
|
||||
@@ -5,15 +5,10 @@ import torchvision.transforms as T
|
||||
from tqdm.auto import tqdm
|
||||
from models.stable_diffusion.cache_objects import model_cache
|
||||
from models.stable_diffusion.stable_args import args
|
||||
from models.stable_diffusion.utils import disk_space_check
|
||||
from random import randint
|
||||
import numpy as np
|
||||
import time
|
||||
import sys
|
||||
from datetime import datetime as dt
|
||||
from csv import DictWriter
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
if args.clear_all:
|
||||
@@ -70,55 +65,6 @@ def set_ui_params(
|
||||
args.variant = variant
|
||||
|
||||
|
||||
# save output images and the inputs correspoding to it.
|
||||
def save_output_img(output_img):
|
||||
output_path = args.output_dir if args.output_dir else Path.cwd()
|
||||
disk_space_check(output_path, lim=5)
|
||||
generated_imgs_path = Path(output_path, "generated_imgs")
|
||||
generated_imgs_path.mkdir(parents=True, exist_ok=True)
|
||||
csv_path = Path(generated_imgs_path, "imgs_history.csv")
|
||||
|
||||
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
|
||||
out_img_name = (
|
||||
f"{prompt_slice}_{args.seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
|
||||
)
|
||||
if args.output_img_format == "jpg":
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
output_img.save(
|
||||
out_img_path,
|
||||
quality=95,
|
||||
subsampling=0,
|
||||
optimize=True,
|
||||
progressive=True,
|
||||
)
|
||||
else:
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
|
||||
output_img.save(out_img_path, "PNG")
|
||||
if args.output_img_format not in ["png", "jpg"]:
|
||||
print(
|
||||
f"[ERROR] Format {args.output_img_format} is not supported yet."
|
||||
"saving image as png. Supported formats png / jpg"
|
||||
)
|
||||
|
||||
new_entry = {
|
||||
"VARIANT": args.variant,
|
||||
"VERSION": args.version,
|
||||
"SCHEDULER": args.scheduler,
|
||||
"PROMPT": args.prompts[0],
|
||||
"NEG_PROMPT": args.negative_prompts[0],
|
||||
"SEED": args.seed,
|
||||
"CFG_SCALE": float(args.guidance_scale),
|
||||
"PRECISION": args.precision,
|
||||
"STEPS": args.steps,
|
||||
"OUTPUT": out_img_path,
|
||||
}
|
||||
|
||||
with open(csv_path, "a") as csv_obj:
|
||||
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
|
||||
dictwriter_obj.writerow(new_entry)
|
||||
csv_obj.close()
|
||||
|
||||
|
||||
def stable_diff_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
@@ -157,7 +103,6 @@ def stable_diff_inf(
|
||||
width = 768
|
||||
|
||||
# get all cached data.
|
||||
disk_space_check(Path.cwd())
|
||||
model_cache.set_models(device_key)
|
||||
tokenizer = model_cache.tokenizer
|
||||
scheduler = model_cache.schedulers[args.scheduler]
|
||||
@@ -275,12 +220,9 @@ def stable_diff_inf(
|
||||
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nvariant={args.variant}, version={args.version}, scheduler={args.scheduler}"
|
||||
text_output += f"\ndevice={device_key}"
|
||||
text_output += f"\nvariant={args.variant}, scheduler={args.scheduler}, device={device_key}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={height}x{width}"
|
||||
text_output += f"\nAverage step time: {avg_ms:.4f}ms/it"
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
save_output_img(pil_images[0])
|
||||
|
||||
return pil_images[0], text_output
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
from models.stable_diffusion.utils import compile_through_fx
|
||||
from models.stable_diffusion.resources import models_config
|
||||
from models.stable_diffusion.stable_args import args
|
||||
import torch
|
||||
|
||||
model_config = {
|
||||
"v2_1": "stabilityai/stable-diffusion-2-1",
|
||||
"v2_1base": "stabilityai/stable-diffusion-2-1-base",
|
||||
"v1_4": "CompVis/stable-diffusion-v1-4",
|
||||
}
|
||||
|
||||
# clip has 2 variants of max length 77 or 64.
|
||||
model_clip_max_length = 64 if args.max_length == 64 else 77
|
||||
@@ -13,6 +17,14 @@ if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
|
||||
elif args.variant == "openjourney":
|
||||
model_clip_max_length = 64
|
||||
|
||||
model_variant = {
|
||||
"stablediffusion": "SD",
|
||||
"anythingv3": "Linaqruf/anything-v3.0",
|
||||
"dreamlike": "dreamlike-art/dreamlike-diffusion-1.0",
|
||||
"openjourney": "prompthero/openjourney",
|
||||
"analogdiffusion": "wavymulder/Analog-Diffusion",
|
||||
}
|
||||
|
||||
model_input = {
|
||||
"v2_1": {
|
||||
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
|
||||
@@ -46,34 +58,45 @@ model_input = {
|
||||
},
|
||||
}
|
||||
|
||||
version = args.version if args.variant == "stablediffusion" else "v1_4"
|
||||
|
||||
|
||||
def get_configs():
|
||||
model_id_key = f"{args.variant}/{version}"
|
||||
revision_key = f"{args.variant}/{args.precision}"
|
||||
try:
|
||||
model_id = models_config[0][model_id_key]
|
||||
revision = models_config[1][revision_key]
|
||||
except KeyError:
|
||||
raise Exception(
|
||||
f"No entry for {model_id_key} or {revision_key} in the models configuration"
|
||||
)
|
||||
|
||||
return model_id, revision
|
||||
# revision param for from_pretrained defaults to "main" => fp32
|
||||
model_revision = {
|
||||
"stablediffusion": "fp16" if args.precision == "fp16" else "main",
|
||||
"anythingv3": "diffusers",
|
||||
"analogdiffusion": "main",
|
||||
"openjourney": "main",
|
||||
"dreamlike": "main",
|
||||
}
|
||||
|
||||
|
||||
def get_clip_mlir(model_name="clip_text", extra_args=[]):
|
||||
model_id, revision = get_configs()
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
if args.variant == "stablediffusion":
|
||||
if args.version != "v1_4":
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_config[args.version], subfolder="text_encoder"
|
||||
)
|
||||
|
||||
elif args.variant in [
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]:
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_variant[args.variant],
|
||||
subfolder="text_encoder",
|
||||
revision=model_revision[args.variant],
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"{args.variant} not yet added")
|
||||
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="text_encoder",
|
||||
revision=revision,
|
||||
)
|
||||
self.text_encoder = text_encoder
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
@@ -81,44 +104,23 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
|
||||
clip_model = CLIPText()
|
||||
shark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
model_input[version]["clip"],
|
||||
model_input[args.version]["clip"],
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
|
||||
def get_shark_module(model_key, module, model_name, extra_args):
|
||||
if args.precision == "fp16":
|
||||
module = module.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
|
||||
for inputs in model_input[version][model_key]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[version][model_key]
|
||||
|
||||
shark_module = compile_through_fx(
|
||||
module,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_module
|
||||
|
||||
|
||||
def get_base_vae_mlir(model_name="vae", extra_args=[]):
|
||||
model_id, revision = get_configs()
|
||||
|
||||
class BaseVaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
model_config[args.version]
|
||||
if args.variant == "stablediffusion"
|
||||
else model_variant[args.variant],
|
||||
subfolder="vae",
|
||||
revision=revision,
|
||||
revision=model_revision[args.variant],
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
@@ -126,19 +128,52 @@ def get_base_vae_mlir(model_name="vae", extra_args=[]):
|
||||
return (x / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
vae = BaseVaeModel()
|
||||
return get_shark_module("vae", vae, model_name, extra_args)
|
||||
if args.variant == "stablediffusion":
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda()
|
||||
for inputs in model_input[args.version]["vae"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[args.version]["vae"]
|
||||
elif args.variant in [
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]:
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
|
||||
)
|
||||
else:
|
||||
inputs = model_input["v1_4"]["vae"]
|
||||
else:
|
||||
raise ValueError(f"{args.variant} not yet added")
|
||||
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_vae_mlir(model_name="vae", extra_args=[]):
|
||||
model_id, revision = get_configs()
|
||||
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
model_config[args.version]
|
||||
if args.variant == "stablediffusion"
|
||||
else model_variant[args.variant],
|
||||
subfolder="vae",
|
||||
revision=revision,
|
||||
revision=model_revision[args.variant],
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
@@ -149,19 +184,52 @@ def get_vae_mlir(model_name="vae", extra_args=[]):
|
||||
return x.round()
|
||||
|
||||
vae = VaeModel()
|
||||
return get_shark_module("vae", vae, model_name, extra_args)
|
||||
if args.variant == "stablediffusion":
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda()
|
||||
for inputs in model_input[args.version]["vae"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[args.version]["vae"]
|
||||
elif args.variant in [
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]:
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
|
||||
)
|
||||
else:
|
||||
inputs = model_input["v1_4"]["vae"]
|
||||
else:
|
||||
raise ValueError(f"{args.variant} not yet added")
|
||||
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_unet_mlir(model_name="unet", extra_args=[]):
|
||||
model_id, revision = get_configs()
|
||||
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
model_config[args.version]
|
||||
if args.variant == "stablediffusion"
|
||||
else model_variant[args.variant],
|
||||
subfolder="unet",
|
||||
revision=revision,
|
||||
revision=model_revision[args.variant],
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
@@ -179,4 +247,39 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
|
||||
return noise_pred
|
||||
|
||||
unet = UnetModel()
|
||||
return get_shark_module("unet", unet, model_name, extra_args)
|
||||
if args.variant == "stablediffusion":
|
||||
if args.precision == "fp16":
|
||||
unet = unet.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
|
||||
for inputs in model_input[args.version]["unet"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[args.version]["unet"]
|
||||
elif args.variant in [
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]:
|
||||
if args.precision == "fp16":
|
||||
unet = unet.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
|
||||
for inputs in model_input["v1_4"]["unet"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input["v1_4"]["unet"]
|
||||
else:
|
||||
raise ValueError(f"{args.variant} is not yet added")
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
@@ -33,7 +33,7 @@ def get_params(bucket_key, model_key, model, is_tuned, precision):
|
||||
]
|
||||
except KeyError:
|
||||
raise Exception(
|
||||
f" there is no entry for {model_key} in the models database"
|
||||
f"{bucket}/{model_key} is not present in the models database"
|
||||
)
|
||||
|
||||
if (
|
||||
|
||||
@@ -29,13 +29,3 @@ if os.path.exists(models_loc):
|
||||
|
||||
if len(models_db) != 3:
|
||||
sys.exit("Error: Unable to load models database.")
|
||||
|
||||
|
||||
models_config = []
|
||||
modelconfig_loc = resource_path("resources/model_config.json")
|
||||
if os.path.exists(modelconfig_loc):
|
||||
with open(modelconfig_loc, encoding="utf-8") as fopen:
|
||||
models_config = json.load(fopen)
|
||||
|
||||
if len(models_config) != 2:
|
||||
sys.exit("Error: Unable to load models configuration.")
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
[
|
||||
{
|
||||
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
|
||||
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
|
||||
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
|
||||
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
|
||||
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
|
||||
"openjourney/v1_4":"prompthero/openjourney",
|
||||
"dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
|
||||
},
|
||||
{
|
||||
"stablediffusion/fp16":"fp16",
|
||||
"stablediffusion/fp32":"main",
|
||||
"anythingv3/fp16":"diffusers",
|
||||
"anythingv3/fp32":"diffusers",
|
||||
"analogdiffusion/fp16":"main",
|
||||
"analogdiffusion/fp32":"main",
|
||||
"openjourney/fp16":"main",
|
||||
"openjourney/fp32":"main"
|
||||
}
|
||||
]
|
||||
@@ -12,6 +12,7 @@
|
||||
},
|
||||
{
|
||||
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
|
||||
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_1dec_fp16_tuned",
|
||||
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
|
||||
|
||||
@@ -117,20 +117,6 @@ p.add_argument(
|
||||
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_img_format",
|
||||
type=str,
|
||||
default="png",
|
||||
help="specify the format in which output image is save. Supported options: jpg / png",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory path to save the output images and json",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
|
||||
@@ -180,9 +180,7 @@ def set_init_device_flags():
|
||||
args.device = "cpu"
|
||||
|
||||
# set max_length based on availability.
|
||||
if args.version == "v1_4":
|
||||
args.max_length = 77
|
||||
elif args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
|
||||
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
|
||||
args.max_length = 77
|
||||
elif args.variant == "openjourney":
|
||||
args.max_length = 64
|
||||
@@ -191,7 +189,6 @@ def set_init_device_flags():
|
||||
if (
|
||||
args.variant in ["openjourney", "dreamlike"]
|
||||
or args.precision != "fp16"
|
||||
or args.version == "v1_4"
|
||||
or "vulkan" not in args.device
|
||||
or "rdna3" not in args.iree_vulkan_target_triple
|
||||
):
|
||||
@@ -220,7 +217,7 @@ def get_available_devices():
|
||||
print(f"{driver_name} devices are not available.")
|
||||
else:
|
||||
for i, device in enumerate(device_list_dict):
|
||||
device_list.append(f"{device['name']} => {driver_name}://{i}")
|
||||
device_list.append(f"{driver_name}://{i} => {device['name']}")
|
||||
return device_list
|
||||
|
||||
set_iree_runtime_flags()
|
||||
@@ -230,14 +227,5 @@ def get_available_devices():
|
||||
available_devices.extend(vulkan_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
# available_devices.append("cpu")
|
||||
available_devices.append("cpu")
|
||||
return available_devices
|
||||
|
||||
|
||||
def disk_space_check(path, lim=20):
|
||||
from shutil import disk_usage
|
||||
|
||||
du = disk_usage(path)
|
||||
free = du.free / (1024 * 1024 * 1024)
|
||||
if free <= lim:
|
||||
print(f"[WARNING] Only {free:.2f}GB space available in {path}.")
|
||||
|
||||
@@ -26,10 +26,8 @@ datas += collect_data_files('shark')
|
||||
datas += [
|
||||
( 'models/stable_diffusion/resources/prompts.json', 'resources' ),
|
||||
( 'models/stable_diffusion/resources/model_db.json', 'resources' ),
|
||||
( 'models/stable_diffusion/resources/model_config.json', 'resources' ),
|
||||
( 'models/stable_diffusion/logos/*', 'logos' )
|
||||
]
|
||||
datas += [('demo.css', '.')]
|
||||
|
||||
binaries = []
|
||||
|
||||
|
||||
@@ -1,240 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from models.stable_diffusion.main import stable_diff_inf
|
||||
from models.stable_diffusion.utils import get_available_devices
|
||||
from dotenv import load_dotenv
|
||||
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
|
||||
from telegram import BotCommand
|
||||
from telegram.ext import Application, ApplicationBuilder, CallbackQueryHandler
|
||||
from telegram.ext import ContextTypes, MessageHandler, CommandHandler, filters
|
||||
from io import BytesIO
|
||||
import random
|
||||
|
||||
log = logging.getLogger("TG.Bot")
|
||||
logging.basicConfig()
|
||||
log.warning("Start")
|
||||
load_dotenv()
|
||||
os.environ["AMD_ENABLE_LLPC"] = "0"
|
||||
TG_TOKEN = os.getenv("TG_TOKEN")
|
||||
SELECTED_MODEL = "stablediffusion"
|
||||
SELECTED_SCHEDULER = "EulerAncestralDiscrete"
|
||||
STEPS = 30
|
||||
NEGATIVE_PROMPT = (
|
||||
"Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra"
|
||||
" limbs,Gross proportions,Missing arms,Mutated hands,Long"
|
||||
" neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad"
|
||||
" anatomy,Cloned face,Malformed limbs,Missing legs,Too many"
|
||||
" fingers,blurry, lowres, text, error, cropped, worst quality, low"
|
||||
" quality, jpeg artifacts, out of frame, extra fingers, mutated hands,"
|
||||
" poorly drawn hands, poorly drawn face, bad anatomy, extra limbs, cloned"
|
||||
" face, malformed limbs, missing arms, missing legs, extra arms, extra"
|
||||
" legs, fused fingers, too many fingers"
|
||||
)
|
||||
GUIDANCE_SCALE = 6
|
||||
available_devices = get_available_devices()
|
||||
models_list = [
|
||||
"stablediffusion",
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]
|
||||
sheds_list = [
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"LMSDiscrete",
|
||||
"DPMSolverMultistep",
|
||||
"EulerDiscrete",
|
||||
"EulerAncestralDiscrete",
|
||||
"SharkEulerDiscrete",
|
||||
]
|
||||
|
||||
|
||||
def image_to_bytes(image):
|
||||
bio = BytesIO()
|
||||
bio.name = "image.jpeg"
|
||||
image.save(bio, "JPEG")
|
||||
bio.seek(0)
|
||||
return bio
|
||||
|
||||
|
||||
def get_try_again_markup():
|
||||
keyboard = [[InlineKeyboardButton("Try again", callback_data="TRYAGAIN")]]
|
||||
reply_markup = InlineKeyboardMarkup(keyboard)
|
||||
return reply_markup
|
||||
|
||||
|
||||
def generate_image(prompt):
|
||||
seed = random.randint(1, 10000)
|
||||
log.warning(SELECTED_MODEL)
|
||||
log.warning(STEPS)
|
||||
image, text = stable_diff_inf(
|
||||
prompt=prompt,
|
||||
negative_prompt=NEGATIVE_PROMPT,
|
||||
steps=STEPS,
|
||||
guidance_scale=GUIDANCE_SCALE,
|
||||
seed=seed,
|
||||
scheduler_key=SELECTED_SCHEDULER,
|
||||
variant=SELECTED_MODEL,
|
||||
device_key=available_devices[0],
|
||||
)
|
||||
|
||||
return image, seed
|
||||
|
||||
|
||||
async def generate_and_send_photo(
|
||||
update: Update, context: ContextTypes.DEFAULT_TYPE
|
||||
) -> None:
|
||||
progress_msg = await update.message.reply_text(
|
||||
"Generating image...", reply_to_message_id=update.message.message_id
|
||||
)
|
||||
im, seed = generate_image(prompt=update.message.text)
|
||||
await context.bot.delete_message(
|
||||
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
|
||||
)
|
||||
await context.bot.send_photo(
|
||||
update.effective_user.id,
|
||||
image_to_bytes(im),
|
||||
caption=f'"{update.message.text}" (Seed: {seed})',
|
||||
reply_markup=get_try_again_markup(),
|
||||
reply_to_message_id=update.message.message_id,
|
||||
)
|
||||
|
||||
|
||||
async def button(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
query = update.callback_query
|
||||
if query.data in models_list:
|
||||
global SELECTED_MODEL
|
||||
SELECTED_MODEL = query.data
|
||||
await query.answer()
|
||||
await query.edit_message_text(text=f"Selected model: {query.data}")
|
||||
return
|
||||
if query.data in sheds_list:
|
||||
global SELECTED_SCHEDULER
|
||||
SELECTED_SCHEDULER = query.data
|
||||
await query.answer()
|
||||
await query.edit_message_text(text=f"Selected scheduler: {query.data}")
|
||||
return
|
||||
replied_message = query.message.reply_to_message
|
||||
await query.answer()
|
||||
progress_msg = await query.message.reply_text(
|
||||
"Generating image...", reply_to_message_id=replied_message.message_id
|
||||
)
|
||||
|
||||
if query.data == "TRYAGAIN":
|
||||
prompt = replied_message.text
|
||||
im, seed = generate_image(prompt)
|
||||
|
||||
await context.bot.delete_message(
|
||||
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
|
||||
)
|
||||
await context.bot.send_photo(
|
||||
update.effective_user.id,
|
||||
image_to_bytes(im),
|
||||
caption=f'"{prompt}" (Seed: {seed})',
|
||||
reply_markup=get_try_again_markup(),
|
||||
reply_to_message_id=replied_message.message_id,
|
||||
)
|
||||
|
||||
|
||||
async def select_model_handler(update, context):
|
||||
text = "Select model"
|
||||
keyboard = []
|
||||
for model in models_list:
|
||||
keyboard.append(
|
||||
[
|
||||
InlineKeyboardButton(text=model, callback_data=model),
|
||||
]
|
||||
)
|
||||
markup = InlineKeyboardMarkup(keyboard)
|
||||
await update.message.reply_text(text=text, reply_markup=markup)
|
||||
|
||||
|
||||
async def select_scheduler_handler(update, context):
|
||||
text = "Select schedule"
|
||||
keyboard = []
|
||||
for shed in sheds_list:
|
||||
keyboard.append(
|
||||
[
|
||||
InlineKeyboardButton(text=shed, callback_data=shed),
|
||||
]
|
||||
)
|
||||
markup = InlineKeyboardMarkup(keyboard)
|
||||
await update.message.reply_text(text=text, reply_markup=markup)
|
||||
|
||||
|
||||
async def set_steps_handler(update, context):
|
||||
input_mex = update.message.text
|
||||
log.warning(input_mex)
|
||||
try:
|
||||
input_args = input_mex.split("/set_steps ")[1]
|
||||
global STEPS
|
||||
STEPS = int(input_args)
|
||||
except Exception:
|
||||
input_args = (
|
||||
"Invalid parameter for command. Correct command looks like\n"
|
||||
" /set_steps 30"
|
||||
)
|
||||
await update.message.reply_text(input_args)
|
||||
|
||||
|
||||
async def set_negative_prompt_handler(update, context):
|
||||
input_mex = update.message.text
|
||||
log.warning(input_mex)
|
||||
try:
|
||||
input_args = input_mex.split("/set_negative_prompt ")[1]
|
||||
global NEGATIVE_PROMPT
|
||||
NEGATIVE_PROMPT = input_args
|
||||
except Exception:
|
||||
input_args = (
|
||||
"Invalid parameter for command. Correct command looks like\n"
|
||||
" /set_negative_prompt ugly, bad art, mutated"
|
||||
)
|
||||
await update.message.reply_text(input_args)
|
||||
|
||||
|
||||
async def set_guidance_scale_handler(update, context):
|
||||
input_mex = update.message.text
|
||||
log.warning(input_mex)
|
||||
try:
|
||||
input_args = input_mex.split("/set_guidance_scale ")[1]
|
||||
global GUIDANCE_SCALE
|
||||
GUIDANCE_SCALE = int(input_args)
|
||||
except Exception:
|
||||
input_args = (
|
||||
"Invalid parameter for command. Correct command looks like\n"
|
||||
" /set_guidance_scale 7"
|
||||
)
|
||||
await update.message.reply_text(input_args)
|
||||
|
||||
|
||||
async def setup_bot_commands(application: Application) -> None:
|
||||
await application.bot.set_my_commands(
|
||||
[
|
||||
BotCommand("select_model", "to select model"),
|
||||
BotCommand("select_scheduler", "to select scheduler"),
|
||||
BotCommand("set_steps", "to set steps"),
|
||||
BotCommand("set_guidance_scale", "to set guidance scale"),
|
||||
BotCommand("set_negative_prompt", "to set negative prompt"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
app = (
|
||||
ApplicationBuilder().token(TG_TOKEN).post_init(setup_bot_commands).build()
|
||||
)
|
||||
app.add_handler(CommandHandler("select_model", select_model_handler))
|
||||
app.add_handler(CommandHandler("select_scheduler", select_scheduler_handler))
|
||||
app.add_handler(CommandHandler("set_steps", set_steps_handler))
|
||||
app.add_handler(
|
||||
CommandHandler("set_guidance_scale", set_guidance_scale_handler)
|
||||
)
|
||||
app.add_handler(
|
||||
CommandHandler("set_negative_prompt", set_negative_prompt_handler)
|
||||
)
|
||||
app.add_handler(
|
||||
MessageHandler(filters.TEXT & ~filters.COMMAND, generate_and_send_photo)
|
||||
)
|
||||
app.add_handler(CallbackQueryHandler(button))
|
||||
log.warning("Start bot")
|
||||
app.run_polling()
|
||||
Reference in New Issue
Block a user