Compare commits

..

1 Commits

Author SHA1 Message Date
dan
2e4403f5ad add generate_sharktank for stable_diffusion model defaults 2023-01-25 20:37:33 +00:00
37 changed files with 481 additions and 744 deletions

View File

@@ -10,14 +10,14 @@ on:
jobs:
windows-build:
runs-on: 7950X
runs-on: windows-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
@@ -52,10 +52,6 @@ jobs:
./setup_venv.ps1
pyinstaller web/shark_sd.spec
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
pyinstaller .\shark\examples\shark_inference\stable_diffusion\shark_sd_cli.spec
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
# GHA windows VM OOMs so disable for now

View File

@@ -100,9 +100,9 @@ jobs:
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="${GITHUB_WORKSPACE}/shark_tmp/shark_cache" -k cpu
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k cpu --update_tank
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
@@ -112,11 +112,10 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="${GITHUB_WORKSPACE}/shark_tmp/shark_cache" -k cuda
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k cuda --update_tank
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
# python build_tools/stable_diffusion_testing.py --device=cuda
sh build_tools/stable_diff_main_test.sh
- name: Validate Vulkan Models (MacOS)
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
@@ -127,7 +126,7 @@ jobs:
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" -k vulkan --update_tank
pytest -s --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" tank/test_models.py -k vulkan --update_tank
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os != 'MacStudio'
@@ -135,5 +134,4 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="${GITHUB_WORKSPACE}/shark_tmp/shark_cache" -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
pytest --benchmark --ci --ci_sha=${SHORT_SHA} -s --local_tank_cache="/data/anush/shark_cache" tank/test_models.py -k vulkan --update_tank

View File

@@ -83,15 +83,19 @@ python3.10 shark/examples/shark_inference/stable_diffusion/main.py --precision=f
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
The output on a 7900XTX would like:
The output on a 6900XT would like:
```shell
Stats for run 0:
Average step time: 47.19188690185547ms/it
Clip Inference time (ms) = 109.531
VAE Inference time (ms): 78.590
Total image generation time: 2.5788655281066895sec
44it [00:08, 5.14it/s]i = 44 t = 120 (191ms)
45it [00:08, 5.15it/s]i = 45 t = 100 (191ms)
46it [00:08, 5.16it/s]i = 46 t = 80 (191ms)
47it [00:09, 5.16it/s]i = 47 t = 60 (193ms)
48it [00:09, 5.15it/s]i = 48 t = 40 (195ms)
49it [00:09, 5.12it/s]i = 49 t = 20 (196ms)
50it [00:09, 5.14it/s]
Average step time: 192.8154182434082ms/it
Total image generation runtime (s): 10.390909433364868
(shark.venv) PS C:\g\shark>
```
Here are some samples generated:

View File

@@ -1,5 +1,5 @@
import argparse
from PIL import Image
import torchvision
import numpy as np
import requests
@@ -22,24 +22,20 @@ def get_image(url, local_filename):
if res.status_code == 200:
with open(local_filename, "wb") as f:
shutil.copyfileobj(res.raw, f)
def compare_images(new_filename, golden_filename):
new = np.array(Image.open(new_filename)) / 255.0
golden = np.array(Image.open(golden_filename)) / 255.0
diff = np.abs(new - golden)
mean = np.mean(diff)
if mean > 0.01:
subprocess.run(
["gsutil", "cp", new_filename, "gs://shark_tank/testdata/builder/"]
)
raise SystemExit("new and golden not close")
else:
print("SUCCESS")
return torchvision.io.read_image(local_filename).numpy()
if __name__ == "__main__":
args = parser.parse_args()
new = torchvision.io.read_image(args.newfile).numpy() / 255.0
tempfile_name = os.path.join(os.getcwd(), "golden.png")
get_image(args.golden_url, tempfile_name)
compare_images(args.newfile, tempfile_name)
golden = get_image(args.golden_url, tempfile_name) / 255.0
diff = np.abs(new - golden)
mean = np.mean(diff)
if not mean < 0.2:
subprocess.run(
["gsutil", "cp", args.newfile, "gs://shark_tank/testdata/builder/"]
)
raise SystemExit("new and golden not close")
else:
print("SUCCESS")

View File

@@ -0,0 +1,6 @@
rm -rf ./test_images
mkdir test_images
python shark/examples/shark_inference/stable_diffusion/main.py --device=vulkan --output_dir=./test_images --no-load_vmfb --no-use_tuned
python build_tools/image_comparison.py -n ./test_images/*.png
exit $?

View File

@@ -1,77 +0,0 @@
import os
import subprocess
from shark.examples.shark_inference.stable_diffusion.resources import (
get_json_file,
)
from shark.shark_downloader import download_public_file
from image_comparison import compare_images
import argparse
from glob import glob
import shutil
model_config_dicts = get_json_file(
os.path.join(
os.getcwd(),
"shark/examples/shark_inference/stable_diffusion/resources/model_config.json",
)
)
def test_loop(device="vulkan", beta=False, extra_flags=[]):
# Get golden values from tank
shutil.rmtree("./test_images", ignore_errors=True)
os.mkdir("./test_images")
os.mkdir("./test_images/golden")
hf_model_names = model_config_dicts[0].values()
tuned_options = ["--no-use_tuned"] #'use_tuned']
devices = ["vulkan"]
if beta:
extra_flags.append("--beta_models=True")
for model_name in hf_model_names:
for use_tune in tuned_options:
command = [
"python",
"shark/examples/shark_inference/stable_diffusion/main.py",
"--device=" + device,
"--output_dir=./test_images/" + model_name,
"--hf_model_id=" + model_name,
use_tune,
]
command += extra_flags
generated_image = not subprocess.call(
command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)
if generated_image:
os.makedirs(
"./test_images/golden/" + model_name, exist_ok=True
)
download_public_file(
"gs://shark_tank/testdata/golden/" + model_name,
"./test_images/golden/" + model_name,
)
comparison = [
"python",
"build_tools/image_comparison.py",
"--golden_url=gs://shark_tank/testdata/golden/"
+ model_name
+ "/*.png",
"--newfile=./test_images/" + model_name + "/*.png",
]
test_file = glob("./test_images/" + model_name + "/*.png")[0]
golden_path = "./test_images/golden/" + model_name + "/*.png"
golden_file = glob(golden_path)[0]
compare_images(test_file, golden_file)
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--device", default="vulkan")
parser.add_argument(
"-b", "--beta", action=argparse.BooleanOptionalAction, default=False
)
if __name__ == "__main__":
args = parser.parse_args()
print(args)
test_loop(args.device, args.beta, [])

View File

@@ -16,7 +16,7 @@ pip install -r requirements.txt
python annotation_tool.py
```
<img width="1280" alt="annotator" src="https://user-images.githubusercontent.com/49575973/214521137-7ef6ae10-7cd8-46e6-b270-b6c0445157f1.png">
<img width="1308" alt="annotator" src="https://user-images.githubusercontent.com/49575973/214191759-24cc5fe6-cd53-4099-87f6-707068f8888d.png">
* Select a dataset from `Dataset` dropdown list
* Select an image from `Image` dropdown list

View File

@@ -2,12 +2,15 @@ import gradio as gr
import json
import jsonlines
import os
from args import args
from pathlib import Path
from PIL import Image
from utils import get_datasets
# TODO: pass gs_url as a command line flag
# see https://cloud.google.com/docs/authentication/provide-credentials-adc to authorize
gs_url = "gs://shark-datasets/portraits"
shark_root = Path(__file__).parent.parent
demo_css = shark_root.joinpath("web/demo.css").resolve()
nodlogo_loc = shark_root.joinpath(
@@ -27,15 +30,15 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
elem_id="top_logo",
).style(width=150, height=100)
datasets, images, ds_w_prompts = get_datasets(args.gs_url)
datasets, images = get_datasets(gs_url)
prompt_data = dict()
with gr.Row(elem_id="ui_body"):
# TODO: add multiselect dataset, there is a gradio version conflict
# TODO: add multiselect dataset
dataset = gr.Dropdown(label="Dataset", choices=datasets)
image_name = gr.Dropdown(label="Image", choices=[])
with gr.Row(elem_id="ui_body"):
with gr.Row(elem_id="ui_body", visible=True):
# TODO: add ability to search image by typing
with gr.Column(scale=1, min_width=600):
image = gr.Image(type="filepath").style(height=512)
@@ -58,26 +61,27 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
finish = gr.Button("Finish")
def filter_datasets(dataset):
# TODO: execute finish process when switching dataset
if dataset is None:
return gr.Dropdown.update(value=None, choices=[])
# create the dataset dir if doesn't exist and download prompt file
dataset_path = str(shark_root) + "/dataset/" + dataset
# TODO: check if metadata.jsonl exists
prompt_gs_path = gs_url + "/" + dataset + "/metadata.jsonl"
if not os.path.exists(dataset_path):
os.mkdir(dataset_path)
os.system(f'gsutil cp "{prompt_gs_path}" "{dataset_path}"/')
# read prompt jsonlines file
prompt_data.clear()
if dataset in ds_w_prompts:
prompt_gs_path = args.gs_url + "/" + dataset + "/metadata.jsonl"
os.system(f'gsutil cp "{prompt_gs_path}" "{dataset_path}"/')
with jsonlines.open(dataset_path + "/metadata.jsonl") as reader:
for line in reader.iter(type=dict, skip_invalid=True):
prompt_data[line["file_name"]] = (
[line["text"]]
if type(line["text"]) is str
else line["text"]
)
with jsonlines.open(dataset_path + "/metadata.jsonl") as reader:
for line in reader.iter(type=dict, skip_invalid=True):
prompt_data[line["file_name"]] = (
[line["text"]]
if type(line["text"]) is str
else line["text"]
)
return gr.Dropdown.update(choices=images[dataset])
@@ -88,7 +92,8 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
return gr.Image.update(value=None), gr.Dropdown.update(value=None)
# download and load the image
img_gs_path = args.gs_url + "/" + dataset + "/" + image_name
# TODO: remove previous image if change image from dropdown
img_gs_path = gs_url + "/" + dataset + "/" + image_name
img_sub_path = "/".join(image_name.split("/")[:-1])
img_dst_path = (
str(shark_root) + "/dataset/" + dataset + "/" + img_sub_path + "/"
@@ -98,8 +103,6 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
os.system(f'gsutil cp "{img_gs_path}" "{img_dst_path}"')
img = Image.open(img_dst_path + image_name.split("/")[-1])
if image_name not in prompt_data.keys():
prompt_data[image_name] = []
prompt_choices = ["Add new"]
prompt_choices += prompt_data[image_name]
return gr.Image.update(value=img), gr.Dropdown.update(
@@ -141,8 +144,6 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
# write prompt jsonlines file
with open(prompt_path, "w") as f:
for key, value in prompt_data.items():
if not value:
continue
v = value if len(value) > 1 else value[0]
f.write(json.dumps({"file_name": key, "text": v}))
f.write("\n")
@@ -170,8 +171,6 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
# write prompt jsonlines file
with open(prompt_path, "w") as f:
for key, value in prompt_data.items():
if not value:
continue
v = value if len(value) > 1 else value[0]
f.write(json.dumps({"file_name": key, "text": v}))
f.write("\n")
@@ -228,7 +227,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
# upload prompt and remove local data
dataset_path = str(shark_root) + "/dataset/" + dataset
dataset_gs_path = args.gs_url + "/" + dataset + "/"
dataset_gs_path = gs_url + "/" + dataset + "/"
os.system(
f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"'
)
@@ -241,8 +240,8 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
if __name__ == "__main__":
shark_web.launch(
share=args.share,
share=False,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
server_port=8080,
)

View File

@@ -1,34 +0,0 @@
import argparse
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
### Dataset Annotator flags
##############################################################################
p.add_argument(
"--gs_url",
type=str,
required=True,
help="URL to datasets in GS bucket",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
)
##############################################################################
args = p.parse_args()

View File

@@ -4,7 +4,6 @@ from google.cloud import storage
def get_datasets(gs_url):
datasets = set()
images = dict()
ds_w_prompts = []
storage_client = storage.Client()
bucket_name = gs_url.split("/")[2]
@@ -13,17 +12,12 @@ def get_datasets(gs_url):
for blob in blobs:
dataset_name = blob.name.split("/")[1]
if dataset_name == "":
continue
datasets.add(dataset_name)
if dataset_name not in images.keys():
images[dataset_name] = []
# check if image or jsonl
file_sub_path = "/".join(blob.name.split("/")[2:])
# check if image or jsonl
if "/" in file_sub_path:
if dataset_name not in images.keys():
images[dataset_name] = []
images[dataset_name] += [file_sub_path]
elif "metadata.jsonl" in file_sub_path:
ds_w_prompts.append(dataset_name)
return list(datasets), images, ds_w_prompts
return list(datasets), images

View File

@@ -14,10 +14,24 @@ import csv
import argparse
from shark.shark_importer import SharkImporter
from shark.parser import shark_args
import tensorflow as tf
import subprocess as sp
import hashlib
import numpy as np
from pathlib import Path
from shark.examples.shark_inference.stable_diffusion import (
model_wrappers as mw,
)
visible_default = tf.config.list_physical_devices("GPU")
try:
tf.config.set_visible_devices([], "GPU")
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
assert device.device_type != "GPU"
except:
# Invalid device or cannot modify virtual devices once initialized.
pass
def create_hash(file_name):
@@ -29,6 +43,21 @@ def create_hash(file_name):
return file_hash.hexdigest()
def get_folder_name(
model_basename, device, precision_value, length, version=None
):
return (
model_basename
+ "_"
+ device
+ "_"
+ precision_value
+ "_maxlen_"
+ length
+ "_torch"
)
def save_torch_model(torch_model_list):
from tank.model_utils import (
get_hf_model,
@@ -51,6 +80,69 @@ def save_torch_model(torch_model_list):
model = None
input = None
if model_type == "fx_imported":
from shark.examples.shark_inference.stable_diffusion.stable_args import (
args,
)
args.use_tuned = False
args.import_mlir = True
args.use_tuned = False
args.local_tank_cache = WORKDIR
base_sd_versions = ["v1_4", "v2_1"]
model_variants = [
"stablediffusion",
"anythingv3",
"analogdiffusion",
"openjourney",
]
scheduler_types = [
"PNDM",
"DDIM",
"LMSDiscrete",
"EulerDiscrete",
"DPMSolverMultistep",
"SharkEulerDiscrete",
]
precision_values = ["fp16"]
seq_lengths = [64, 77]
for device in ["vulkan", "cuda"]:
args.device = device
for model_variant in model_variants:
args.variant = model_variant
for base_sd_ver in base_sd_versions:
# model variants not required for non base sd models
if (
base_sd_ver == "v1_4"
and not model_variant == "stablediffusion"
):
continue
else:
args.version = base_sd_ver
for precision_value in precision_values:
args.precision = precision_value
for length in seq_lengths:
model = mw.SharkifyStableDiffusionModel(
model_id="stabilityai/stable-diffusion-2-1-base",
custom_weights="",
precision=precision_value,
max_len=length,
width=512,
height=512,
use_base_vae=False,
debug=True,
sharktank_dir=WORKDIR,
)
args.max_length = length
model_name = f"{args.variant}/{args.version}/{torch_model_name}/{args.precision}/length_{args.max_length}{args.use_tuned}"
torch_model_dir = os.path.join(
WORKDIR, model_name
)
model()
continue
if model_type == "vision":
model, input, _ = get_vision_model(torch_model_name)
elif model_type == "hf":
@@ -99,17 +191,6 @@ def save_tf_model(tf_model_list):
get_keras_model,
get_TFhf_model,
)
import tensorflow as tf
visible_default = tf.config.list_physical_devices("GPU")
try:
tf.config.set_visible_devices([], "GPU")
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
assert device.device_type != "GPU"
except:
# Invalid device or cannot modify virtual devices once initialized.
pass
with open(tf_model_list) as csvfile:
tf_reader = csv.reader(csvfile, delimiter=",")
@@ -243,13 +324,13 @@ if __name__ == "__main__":
if args.torch_model_csv:
save_torch_model(args.torch_model_csv)
if args.tf_model_csv:
save_tf_model(args.tf_model_csv)
# if args.tf_model_csv:
# save_tf_model(args.tf_model_csv)
if args.tflite_model_csv:
save_tflite_model(args.tflite_model_csv)
# if args.tflite_model_csv:
# save_tflite_model(args.tflite_model_csv)
if args.upload:
if True:
git_hash = sp.getoutput("git log -1 --format='%h'") + "/"
print("uploading files to gs://shark_tank/" + git_hash)
os.system(f"gsutil cp -r {WORKDIR}* gs://shark_tank/" + git_hash)

View File

@@ -10,7 +10,6 @@ google-cloud-storage
# Testing
pytest
pytest-xdist
pytest-forked
Pillow
parameterized
@@ -21,9 +20,6 @@ scipy
ftfy
gradio
altair
omegaconf
safetensors
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
pyinstaller

View File

@@ -14,51 +14,26 @@ Currently we support fine-tuned versions of Stable Diffusion such as:
use the flag `--hf_model_id=` to specify the repo-id of the model to be used.
```shell
python .\shark\examples\shark_inference\stable_diffusion\main.py --hf_model_id="Linaqruf/anything-v3.0" --max_length=77 --prompt="1girl, brown hair, green eyes, colorful, autumn, cumulonimbus clouds, lighting, blue sky, falling leaves, garden" --no-use_tuned
python .\shark\examples\shark_inference\stable_diffusion\main.py --hf_model_id="Linaqruf/anything-v3.0" --max_length=77 --prompt="1girl, brown hair, green eyes, colorful, autumn, cumulonimbus clouds, lighting, blue sky, falling leaves, garden"
```
## Run a custom model using a `.ckpt` / `.safetensors` checkpoint file:
* Ensure you don't have any `.yaml` file at the root directory of SHARK - best would be to ensure you're on the latest `main` branch and use `--clear_all` the first time you're running the command for inference.
* Install `pytorch_lightning` by running :-
## Run a custom model using a `.ckpt` file:
* Install the following by running :-
```shell
pip install pytorch_lightning
pip install omegaconf safetensors pytorch_lightning
```
NOTE: This is needed to process [ckpt file of runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.ckpt).
* Download a [.ckpt](https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0-pruned-fp32.ckpt) file in case you don't have a locally generated `.ckpt` file for StableDiffusion.
* Now pass the above `.ckpt` file to `ckpt_loc` command-line argument using the following :-
* Now pass the above `.ckpt` file to `ckpt_loc` command-line argument using the following (note the `hf_model_id` flag which states what the base model is from which the `.ckpt` model was fined-tuned off of) :-
```shell
python3.10 main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd" --max_length=64 --import_mlir --ckpt_loc="/path/to/.ckpt/file" --no-use_tuned
python3.10 main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd" --max_length=64 --import_mlir --ckpt_loc="/path/to/.ckpt/file" --hf_model_id="CompVis/stable-diffusion-v1-4"
```
* We use a combination of 2 flags to make this feature work : `import_mlir` and `ckpt_loc`.
* In case `ckpt_loc` is NOT specified then a [default](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) HuggingFace repo-id is run via `hf_model_id`. So, two ways to use `import_mlir` :-
- With `hf_model_id` to run HuggingFace's StableDiffusion variants.
- With `ckpt_loc` to run a StableDiffusion variant with a `.ckpt` or `.safetensors` checkpoint file
* We use a combination of 3 flags to make this feature work : `import_mlir`, `ckpt_loc` and `hf_model_id`, of which `import_mlir` needs to be present. In case `ckpt_loc` is not specified then a [default](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) HuggingFace repo-id is run via `hf_model_id`. So, you need to specify which base model's `.ckpt` you are using via `hf_model_id`.
* Use custom model `.ckpt` files from [HuggingFace-StableDiffusion](https://huggingface.co/models?other=stable-diffusion) to generate images. And in case you want to use any variants from HuggingFace then add the mapping of the variant to their base model in [variants.json](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/resources/variants.json).
* Use custom model `.ckpt` files from [HuggingFace-StableDiffusion](https://huggingface.co/models?other=stable-diffusion) to generate images.
* You may also try out [.safetensors file of Protogen x3.4 of civitai.com](https://civitai.com/models/3666/protogen-x34-photorealism-official-release) and provide the `.safetensors` path to `ckpt_loc` flag.
* NOTE: Ensure that the `.ckpt` or `.safetensors` file are part of the path passed to `ckpt_loc` flag. Eg: `--ckpt_loc="/path/to/checkpoint/file/name_of_checkpoint.ckpt` OR `--ckpt_loc="/path/to/checkpoint/file/name_of_checkpoint.safetensors`. Also ensure that you're using `--no-use_tuned` flag in your run command.
## Running the model for a `batch_size` and for a set of `runs`:
We currently support batch size in the range `[1, 3]`.
You can specify batch size using `batch_size` flag (defaults to `1`) and the number of times you want to run the model using `runs` flag (defaults to `1`).
In total, you'll be able to generate `batch_size * runs` number of images.
- Usage 1: Using the same prompt -
```shell
python3.10 main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd" --max_length=64 --import_mlir --hf_model_id="runwayml/stable-diffusion-v1-5" --batch_size=3 --no-use_tuned
```
The example above generates `3` different images in total with the same prompt `tajmahal, oil on canvas, sunflowers, 4k, uhd`.
- Usage 2: Using different prompts -
```shell
python3.10 main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd" --max_length=64 --import_mlir --hf_model_id="runwayml/stable-diffusion-v1-5" --batch_size=3 -p="batman riding a horse, oil on canvas, 4k, uhd" -p="superman riding a horse, oil on canvas, 4k, uhd" --no-use_tuned
```
The example above generates `1` image for each different prompt, thus generating `3` images in total.
- Usage 3: Using `runs` -
```shell
python3.10 main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd" --max_length=64 --import_mlir --hf_model_id="runwayml/stable-diffusion-v1-5" --batch_size=2 --runs=3 --no-use_tuned
```
The example above generates `6` different images in total, `2` images for each `runs`.
</details>
<details>

View File

@@ -1,11 +1,6 @@
import os
import sys
if "AMD_ENABLE_LLPC" not in os.environ:
os.environ["AMD_ENABLE_LLPC"] = "1"
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
os.environ["AMD_ENABLE_LLPC"] = "1"
from transformers import CLIPTextModel, CLIPTokenizer
import torch
@@ -37,12 +32,6 @@ if args.clear_all:
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
# TODO: Remove this once we have better weight updation logic.
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
for yaml in inference_yaml:
if os.path.exists(yaml):
os.remove(yaml)
home = os.path.expanduser("~")
if os.name == "nt": # Windows
appdata = os.getenv("LOCALAPPDATA")
@@ -59,6 +48,7 @@ from schedulers import (
SharkEulerDiscreteScheduler,
)
import time
import sys
from shark.iree_utils.compile_utils import dump_isas
# Helper function to profile the vulkan device.
@@ -82,10 +72,6 @@ if __name__ == "__main__":
dtype = torch.float32 if args.precision == "fp32" else torch.half
# Make it as default prompt
if len(args.prompts) == 0:
args.prompts = ["cyberpunk forest by Salvador Dali"]
prompt = args.prompts
neg_prompt = args.negative_prompts
height = args.height
@@ -95,20 +81,12 @@ if __name__ == "__main__":
# Scale for classifier-free guidance
guidance_scale = torch.tensor(args.guidance_scale).to(torch.float32)
batch_size = args.batch_size
prompt = prompt * batch_size if len(prompt) == 1 else prompt
len_of_prompt = len(prompt)
assert (
len_of_prompt == batch_size
), f"no. of prompts ({len_of_prompt}) is not equal to batch_size ({batch_size})"
print("Running StableDiffusion with the following config :-")
print(f"Batch size : {batch_size}")
print(f"Prompts : {prompt}")
print(f"Runs : {args.runs}")
# Try to make neg_prompt equal to batch_size by appending blank strings.
for i in range(batch_size - len(neg_prompt)):
neg_prompt.append("")
# TODO: Add support for batch_size > 1.
batch_size = len(prompt)
if batch_size != 1:
sys.exit("More than one prompt is not supported yet.")
if batch_size != len(neg_prompt):
sys.exit("prompts and negative prompts must be of same length")
set_init_device_flags()
disk_space_check(Path.cwd())
@@ -120,21 +98,16 @@ if __name__ == "__main__":
unet = get_unet()
vae = get_vae()
else:
if args.ckpt_loc != "":
assert args.ckpt_loc.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
if ".ckpt" in args.ckpt_loc:
preprocessCKPT()
mlir_import = SharkifyStableDiffusionModel(
args.hf_model_id,
args.ckpt_loc,
args.precision,
max_len=args.max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
)
clip, unet, vae = mlir_import()
@@ -309,17 +282,17 @@ if __name__ == "__main__":
disk_space_check(output_path, lim=5)
for i in range(batch_size):
json_store = {
"prompt": prompt[i],
"prompt": args.prompts[i],
"negative prompt": args.negative_prompts[i],
"seed": seed,
"seed": args.seed,
"hf_model_id": args.hf_model_id,
"precision": args.precision,
"steps": args.steps,
"guidance_scale": args.guidance_scale,
"scheduler": args.scheduler,
}
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", prompt[i][:15])
img_name = f"{prompt_slice}_{seed}_{run}_{i}_{dt.now().strftime('%y%m%d_%H%M%S')}"
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[i][:15])
img_name = f"{prompt_slice}_{args.seed}_{run}_{dt.now().strftime('%y%m%d_%H%M%S')}"
if args.output_img_format == "jpg":
pil_images[i].save(
output_path / f"{img_name}.jpg",

View File

@@ -1,14 +1,17 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from utils import compile_through_fx, get_opt_flags
from resources import base_models
from resources import base_models, variants
from collections import defaultdict
import torch
import sys
# These shapes are parameter dependent.
def replace_shape_str(shape, max_len, width, height, batch_size):
def replace_shape_str(shape, max_len, width, height):
new_shape = []
for i in range(len(shape)):
if shape[i] == "max_len":
@@ -17,17 +20,13 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape.append(height)
elif shape[i] == "width":
new_shape.append(width)
elif isinstance(shape[i], str):
if "batch_size" in shape[i]:
mul_val = int(shape[i].split("*")[0])
new_shape.append(batch_size * mul_val)
else:
new_shape.append(shape[i])
return new_shape
# Get the input info for various models i.e. "unet", "clip", "vae".
def get_input_info(model_info, max_len, width, height, batch_size):
def get_input_info(model_info, max_len, width, height):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = defaultdict(list)
for k in model_info:
@@ -36,9 +35,7 @@ def get_input_info(model_info, max_len, width, height, batch_size):
dtype = dtype_config[model_info[k][inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, max_len, width, height, batch_size
)
clean_shape = replace_shape_str(shape, max_len, width, height)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
else:
@@ -48,9 +45,25 @@ def get_input_info(model_info, max_len, width, height, batch_size):
else:
sys.exit("shape isn't specified correctly.")
input_map[k].append(tensor)
return input_map
# Returns the model configuration in a dict containing input parameters
# for clip, unet and vae respectively.
def get_model_configuration(model_id, max_len, width, height):
if model_id in base_models:
return get_input_info(base_models[model_id], max_len, width, height)
elif model_id in variants:
return get_input_info(
base_models[variants[model_id]], max_len, width, height
)
else:
sys.exit(
"The model info is not configured, please add the model_configuration in base_model.json if it's a base model, else add it in the variant.json"
)
class SharkifyStableDiffusionModel:
def __init__(
self,
@@ -60,22 +73,19 @@ class SharkifyStableDiffusionModel:
max_len: int = 64,
width: int = 512,
height: int = 512,
batch_size: int = 1,
use_base_vae: bool = False,
use_tuned: bool = False,
debug: bool = False,
sharktank_dir: str = "",
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.inputs = get_model_configuration(
model_id, max_len, width // 8, height // 8
)
self.model_id = model_id if custom_weights == "" else custom_weights
self.precision = precision
self.base_vae = use_base_vae
self.model_name = (
str(batch_size)
+ "_"
+ str(max_len)
str(max_len)
+ "_"
+ str(height)
+ "_"
@@ -83,7 +93,8 @@ class SharkifyStableDiffusionModel:
+ "_"
+ precision
)
self.use_tuned = use_tuned
self.debug = debug
self.sharktank_dir = sharktank_dir
# We need a better naming convention for the .vmfbs because despite
# using the custom model variant the .vmfb names remain the same and
# it'll always pick up the compiled .vmfb instead of compiling the
@@ -130,13 +141,18 @@ class SharkifyStableDiffusionModel:
inputs = tuple(self.inputs["vae"])
is_f16 = True if self.precision == "fp16" else False
vae_name = "base_vae" if self.base_vae else "vae"
vae_model_name = vae_name + self.model_name
if self.debug:
os.makedirs(
os.path.join(self.sharktank_dir, vae_model_name), exist_ok=True
)
shark_vae = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
model_name=vae_name + self.model_name,
use_tuned=self.use_tuned,
model_name=vae_model_name,
extra_args=get_opt_flags("vae", precision=self.precision),
debug=self.debug,
)
return shark_vae
@@ -169,14 +185,20 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
unet_model_name = "unet" + self.model_name
if self.debug:
os.makedirs(
os.path.join(self.sharktank_dir, unet_model_name),
exist_ok=True,
)
shark_unet = compile_through_fx(
unet,
inputs,
model_name="unet" + self.model_name,
model_name=unet_model_name,
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
debug=self.debug,
)
return shark_unet
@@ -193,59 +215,24 @@ class SharkifyStableDiffusionModel:
return self.text_encoder(input)[0]
clip_model = CLIPText()
clip_model_name = "clip" + self.model_name
if self.debug:
os.makedirs(
os.path.join(self.sharktank_dir, clip_model_name),
exist_ok=True,
)
shark_clip = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
model_name="clip" + self.model_name,
model_name=clip_model_name,
extra_args=get_opt_flags("clip", precision="fp32"),
debug=self.debug,
)
return shark_clip
def __call__(self):
from utils import get_vmfb_path_name
from stable_args import args
import traceback, functools, operator, os
model_name = ["clip", "base_vae" if self.base_vae else "vae", "unet"]
vmfb_path = [
get_vmfb_path_name(model + self.model_name)[0]
for model in model_name
]
for model_id in base_models:
self.inputs = get_input_info(
base_models[model_id],
self.max_len,
self.width,
self.height,
self.batch_size,
)
try:
compiled_unet = self.get_unet()
compiled_vae = self.get_vae()
compiled_clip = self.get_clip()
except Exception as e:
if args.enable_stack_trace:
traceback.print_exc()
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
all_vmfb_present = functools.reduce(
operator.__and__, vmfb_present
)
# We need to delete vmfbs only if some of the models were compiled.
if not all_vmfb_present:
for i in range(len(vmfb_path)):
if vmfb_present[i]:
os.remove(vmfb_path[i])
print("Deleted: ", vmfb_path[i])
print("Retrying with a different base model configuration")
continue
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
return compiled_clip, compiled_unet, compiled_vae
sys.exit(
"Cannot compile the model. Please use `enable_stack_trace` and create an issue at https://github.com/nod-ai/SHARK/issues"
)
compiled_clip = self.get_clip()
compiled_unet = self.get_unet()
compiled_vae = self.get_vae()
return compiled_clip, compiled_unet, compiled_vae

View File

@@ -33,5 +33,10 @@ models_db = get_json_file("resources/model_db.json")
# models and also helps in providing information for the variants.
base_models = get_json_file("resources/base_model.json")
# The variant contains the mapping from variant to the base configuration
# to get the required inputs.
# If the input configuration doesn't match it should be registered standalone in the base configuration.
variants = get_json_file("resources/variants.json")
# Contains optimization flags for different models.
opt_flags = get_json_file("resources/opt_flags.json")

View File

@@ -3,7 +3,7 @@
"unet": {
"latents": {
"shape": [
"1*batch_size",
1,
4,
"height",
"width"
@@ -18,7 +18,7 @@
},
"embedding": {
"shape": [
"2*batch_size",
2,
"max_len",
1024
],
@@ -32,7 +32,7 @@
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
1,4,"height","width"
],
"dtype":"f32"
}
@@ -40,7 +40,7 @@
"clip": {
"token" : {
"shape" : [
"2*batch_size",
2,
"max_len"
],
"dtype":"i64"
@@ -51,7 +51,7 @@
"unet": {
"latents": {
"shape": [
"1*batch_size",
1,
4,
"height",
"width"
@@ -66,7 +66,7 @@
},
"embedding": {
"shape": [
"2*batch_size",
2,
"max_len",
768
],
@@ -80,7 +80,7 @@
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
1,4,"height","width"
],
"dtype":"f32"
}
@@ -88,7 +88,7 @@
"clip": {
"token" : {
"shape" : [
"2*batch_size",
2,
"max_len"
],
"dtype":"i64"

View File

@@ -1,6 +1,6 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/stable_diffusion",
"stablediffusion/untuned":"gs://shark_tank/latest",
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",

View File

@@ -31,24 +31,18 @@
"vae": {
"tuned": {
"fp16": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": ["--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"]
}
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
},
"fp32": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
}
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
}
},
"untuned": {

View File

@@ -0,0 +1,20 @@
{
"runwayml/stable-diffusion-v1-5": "CompVis/stable-diffusion-v1-4",
"prompthero/openjourney": "CompVis/stable-diffusion-v1-4",
"Linaqruf/anything-v3.0": "CompVis/stable-diffusion-v1-4",
"stabilityai/stable-diffusion-2-1-base": "stabilityai/stable-diffusion-2-1",
"dreamlike-art/dreamlike-diffusion-1.0": "CompVis/stable-diffusion-v1-4",
"eimiss/EimisAnimeDiffusion_1.0v": "CompVis/stable-diffusion-v1-4",
"claudfuen/photorealistic-fuen-v1": "CompVis/stable-diffusion-v1-4",
"nitrosocke/Nitro-Diffusion": "CompVis/stable-diffusion-v1-4",
"stabilityai/stable-diffusion-2-base": "stabilityai/stable-diffusion-2-1",
"wavymulder/Analog-Diffusion": "CompVis/stable-diffusion-v1-4",
"nitrosocke/redshift-diffusion": "CompVis/stable-diffusion-v1-4",
"wavymulder/portraitplus": "CompVis/stable-diffusion-v1-4",
"Linaqruf/anything-v3-better-vae": "CompVis/stable-diffusion-v1-4",
"nitrosocke/Arcane-Diffusion": "CompVis/stable-diffusion-v1-4",
"hakurei/waifu-diffusion": "stabilityai/stable-diffusion-2-1",
"lambdalabs/sd-pokemon-diffusers": "CompVis/stable-diffusion-v1-4",
"prompthero/openjourney-v2": "CompVis/stable-diffusion-v1-4",
"andite/anything-v4.0": "CompVis/stable-diffusion-v1-4"
}

View File

@@ -15,19 +15,10 @@ import torch
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
BATCH_SIZE = len(args.prompts)
if len(args.prompts) == 0:
BATCH_SIZE = 1
model_input = {
"euler": {
"latent": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"output": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"latent": torch.randn(1, 4, args.height // 8, args.width // 8),
"output": torch.randn(1, 4, args.height // 8, args.width // 8),
"sigma": torch.tensor(1).to(torch.float32),
"dt": torch.tensor(1).to(torch.float32),
},
@@ -93,7 +84,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
self.scaling_model = compile_through_fx(
scaling_model,
(example_latent, example_sigma),
model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}"
model_name=f"euler_scale_model_input_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
@@ -102,7 +93,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
self.step_model = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}"
model_name=f"euler_step_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)

View File

@@ -1,6 +1,6 @@
import os
from shark.model_annotation import model_annotation, create_context
from shark.iree_utils._common import iree_target_map, run_cmd
from shark.iree_utils._common import run_cmd, iree_target_map
from shark.shark_downloader import (
download_model,
download_public_file,
@@ -8,95 +8,74 @@ from shark.shark_downloader import (
)
from shark.parser import shark_args
from stable_args import args
from opt_params import get_params
from utils import set_init_device_flags
set_init_device_flags()
device = (
args.device if "://" not in args.device else args.device.split("://")[0]
)
# Downloads the model (Unet or VAE fp16) from shark_tank
shark_args.local_tank_cache = args.local_tank_cache
bucket_key = f"{args.variant}/untuned"
if args.annotation_model == "unet":
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/untuned"
elif args.annotation_model == "vae":
is_base = "/base" if args.use_base_vae else ""
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/untuned{is_base}"
# Download the model (Unet or VAE fp16) from shark_tank
def load_model_from_tank():
from opt_params import get_params, version, variant
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, args.annotation_model, "untuned", args.precision
)
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=bucket,
frontend="torch",
)
shark_args.local_tank_cache = args.local_tank_cache
bucket_key = f"{variant}/untuned"
if args.annotation_model == "unet":
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/untuned"
elif args.annotation_model == "vae":
is_base = "/base" if args.use_base_vae else ""
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/untuned{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, args.annotation_model, "untuned", args.precision
)
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=bucket,
frontend="torch",
)
return mlir_model, model_name
# Download the tuned config files from shark_tank
def load_winograd_configs():
config_bucket = "gs://shark_tank/sd_tuned/configs/"
# Downloads the tuned config files from shark_tank
config_bucket = "gs://shark_tank/sd_tuned/configs/"
if args.use_winograd:
config_name = f"{args.annotation_model}_winograd_{device}.json"
full_gs_url = config_bucket + config_name
winograd_config_dir = f"{WORKDIR}configs/" + config_name
print("Loading Winograd config file from ", winograd_config_dir)
download_public_file(full_gs_url, winograd_config_dir, True)
return winograd_config_dir
def load_lower_configs():
from opt_params import version, variant
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_version = version
if variant in ["anythingv3", "analogdiffusion"]:
if args.annotation_model == "unet" or device == "cuda":
if args.variant in ["anythingv3", "analogdiffusion"]:
args.max_length = 77
config_version = "v1_4"
args.version = "v1_4"
if args.annotation_model == "vae":
args.max_length = 77
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{args.max_length}_{device}.json"
config_name = f"{args.annotation_model}_{args.version}_{args.precision}_len{args.max_length}_{device}.json"
full_gs_url = config_bucket + config_name
lowering_config_dir = f"{WORKDIR}configs/" + config_name
print("Loading lowering config file from ", lowering_config_dir)
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir
# Annotate the model with Winograd attribute on selected conv ops
def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
if args.use_winograd:
with create_context() as ctx:
winograd_model = model_annotation(
ctx,
input_contents=input_mlir,
input_contents=mlir_model,
config_path=winograd_config_dir,
search_op="conv",
winograd=True,
winograd=args.use_winograd,
)
with open(out_file_path, "w") as f:
with open(
f"{args.annotation_output}/{model_name}_tuned_torch.mlir", "w"
) as f:
f.write(str(winograd_model))
f.close()
return winograd_model, out_file_path
# For Unet annotate the model with tuned lowering configs
def annotate_with_lower_configs(
input_mlir, lowering_config_dir, model_name, use_winograd
):
if use_winograd:
if args.annotation_model == "unet" or device == "cuda":
if args.use_winograd:
input_mlir = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
else:
input_mlir = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir"
dump_after = "iree-flow-pad-linalg-ops"
# Dump IR after padding/img2col/winograd passes
@@ -111,8 +90,6 @@ def annotate_with_lower_configs(
device_spec_args = (
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
)
print("Applying tuned configs on", model_name)
run_cmd(
f"iree-compile {input_mlir} "
"--iree-input-type=tm_tensor "
@@ -139,53 +116,7 @@ def annotate_with_lower_configs(
# Remove the intermediate mlir and save the final annotated model
os.remove(f"{args.annotation_output}/dump_after_winograd.mlir")
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
with open(out_file_path, "w") as f:
output_path = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
with open(output_path, "w") as f:
f.write(str(tuned_model))
f.close()
return tuned_model, out_file_path
def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
if args.annotation_model == "unet" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
winograd_model, model_path = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
lowering_config_dir = load_lower_configs()
tuned_model, output_path = annotate_with_lower_configs(
model_path, lowering_config_dir, model_name, use_winograd
)
elif args.annotation_model == "vae" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
tuned_model, output_path = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
else:
use_winograd = False
if model_from_tank:
mlir_model = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir"
else:
# Just use this function to convert bytecode to string
orig_model, model_path = annotate_with_winograd(
mlir_model, "", model_name
)
mlir_model = model_path
lowering_config_dir = load_lower_configs()
tuned_model, output_path = annotate_with_lower_configs(
mlir_model, lowering_config_dir, model_name, use_winograd
)
print(f"Saved the annotated mlir in {output_path}.")
return tuned_model, output_path
if __name__ == "__main__":
mlir_model, model_name = load_model_from_tank()
sd_model_annotation(mlir_model, model_name, model_from_tank=True)

View File

@@ -1,76 +0,0 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torchvision')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('diffusers')
datas += copy_metadata('transformers')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += [
( 'resources/prompts.json', 'resources'),
( 'resources/model_db.json', 'resources'),
( 'resources/base_model.json', 'resources'),
( 'resources/opt_flags.json', 'resources'),
]
binaries = []
block_cipher = None
a = Analysis(
['main.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core' ],
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_sd_cli',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -15,10 +15,9 @@ p = argparse.ArgumentParser(
##############################################################################
p.add_argument(
"-p",
"--prompts",
action="append",
default=[],
nargs="+",
default=["cyberpunk forest by Salvador Dali"],
help="text of which images to be generated.",
)
@@ -43,14 +42,6 @@ p.add_argument(
help="the seed to use.",
)
p.add_argument(
"--batch_size",
type=int,
default=1,
choices=range(1, 4),
help="the number of inferences to be made in a single `run`.",
)
p.add_argument(
"--height",
type=int,
@@ -93,7 +84,7 @@ p.add_argument(
p.add_argument(
"--import_mlir",
default=False,
default=True,
action=argparse.BooleanOptionalAction,
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
)
@@ -168,13 +159,6 @@ p.add_argument(
help="The repo-id of hugging face.",
)
p.add_argument(
"--enable_stack_trace",
default=False,
action=argparse.BooleanOptionalAction,
help="Enable showing the stack trace when retrying the base model configuration",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
@@ -299,4 +283,29 @@ p.add_argument(
help="Options are unet and vae.",
)
p.add_argument(
"--use_winograd",
default=False,
action=argparse.BooleanOptionalAction,
help="Apply Winograd on selected conv ops.",
)
##############################################################################
### CI generation tags
##############################################################################
# TODO: remove from here once argparse is not required by half of sd
p.add_argument(
"--upload",
default=True,
action="store_true",
help="used for generate_sharktank.py to upload models",
)
p.add_argument(
"--ci_tank_dir",
default=True,
action="store_true",
help="used for CI generation purposes only.",
)
args = p.parse_args()

View File

@@ -12,23 +12,22 @@ If it works well for you, please "star" the following GitHub projects... this is
*AMD Software: Adrenalin Edition 22.11.1 for MLIR/IREE Driver Version 22.20.29.09 for Windows® 10 and Windows® 11 (Windows Driver Store Version 31.0.12029.9003)*
First, for RDNA2 users, download this special driver in a folder of your choice. We recommend you keep the installation files around, since you may need to re-install it later, if Windows Update decides to overwrite it:
First, download this special driver in a folder of your choice. We recommend you keep that driver around since you may need to re-install it later, if Windows Update decides to overwrite it:
https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mlir-iree
For RDNA3, the latest driver 23.1.2 supports MLIR/IREE as well: https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-1-2-kb
KNOWN ISSUES with this special AMD driver:
* `Windows Update` may (depending how it's configured) automatically install a new official AMD driver that overwrites this IREE-specific driver. If Stable Diffusion used to work, then a few days later, it slows down a lot or produces incorrect results (e.g. black images), this may be the cause. To fix this problem, please check the installed driver version, and re-install the special driver if needed. (TODO: document how to prevent this `Windows Update` behavior!)
* Some people using this special driver experience mouse pointer accuracy issues, especially if using a larger-than-default mouse pointer. The clicked point isn't centered properly. One possible work-around is to reset the pointer size to "1" in "Change pointer size and color".
* `Windows Update` may (depending how it's configured) automatically install a new official AMD driver that overwrites this IREE-specific driver. If Stable Diffusion used to work, then a few days later, it slows down a lot or produces incorrect results (e.g. black images), this may be the cause. To fix this problem, please check the installed driver's version, and re-install the special driver if needed. (TODO: document how to prevent this `Windows Update` behavior!)
* Some people using this special driver experience mouse pointer accuracy issues, if you use a larger-than-default mouse pointer. The clicked point isn't centered properly. One possible work-around is to reset the pointer size to "1" in "Change pointer size and color".
## Installation
Download the latest Windows SHARK SD binary [469 here](https://github.com/nod-ai/SHARK/releases/download/20230124.469/shark_sd_20230124_469.exe) in a folder of your choice. If you want nighly builds, you can look for them on the GitHub releases page.
Download the latest Windows SHARK SD binary [455 here](https://storage.googleapis.com/shark-public/windows/shark_sd_20230120_455.exe) in a folder of your choice. If you want nighly builds you can look for them in the github releases page.
Notes:
* We recommend that you download this EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files. Those contain Vulkan dispatches compiled from MLIR which can be outdated if you run a new EXE from the same folder. You can use `--clean_all` flag once to clean all the old files.
* We recommend that you download this EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files. Those contain Vulkan dispatches compiled from MLIR, that can get outdated if you run multiple EXE from the same folder. You can use `--clean_all` flag once to clean all the old files.
* Your browser may warn you about downloading an .exe file
* If you recently updated the driver or this binary (EXE file), we recommend you:
* clear all the local artifacts with `--clear_all` OR
* clear all the local artifacts with `--clean_all` OR
* clear the Vulkan shader cache: For Windows users this can be done by clearing the contents of `C:\Users\%username%\AppData\Local\AMD\VkCache\`. On Linux the same cache is typically located at `~/.cache/AMD/VkCache/`.
* clear the `huggingface` cache. In Windows, this is `C:\Users\%username%\.cache\huggingface`.
@@ -60,7 +59,7 @@ Here are some samples generated:
<summary>Advanced Installation </summary>
## Setup your Python Virtual Environment and Dependencies
## Setup your Python VirtualEnvironment and Dependencies
<details>
<summary> Windows 10/11 Users </summary>
@@ -134,15 +133,19 @@ python3.10 shark/examples/shark_inference/stable_diffusion/main.py --precision=f
</details>
The output on a 7900XTX would like:
The output on a 6900XT would like:
```shell
Stats for run 0:
Average step time: 47.19188690185547ms/it
Clip Inference time (ms) = 109.531
VAE Inference time (ms): 78.590
Total image generation time: 2.5788655281066895sec
44it [00:08, 5.14it/s]i = 44 t = 120 (191ms)
45it [00:08, 5.15it/s]i = 45 t = 100 (191ms)
46it [00:08, 5.16it/s]i = 46 t = 80 (191ms)
47it [00:09, 5.16it/s]i = 47 t = 60 (193ms)
48it [00:09, 5.15it/s]i = 48 t = 40 (195ms)
49it [00:09, 5.12it/s]i = 49 t = 20 (196ms)
50it [00:09, 5.14it/s]
Average step time: 192.8154182434082ms/it
Total image generation runtime (s): 10.390909433364868
(shark.venv) PS C:\g\shark>
```
For more options to the Stable Diffusion model read [this](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md)

View File

@@ -1,7 +1,8 @@
import os
import gc
import tempfile
import torch
from shark.shark_inference import SharkInference
from stable_args import args
from shark.examples.shark_inference.stable_diffusion.stable_args import args
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
@@ -9,27 +10,18 @@ from shark.iree_utils.vulkan_utils import (
)
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from resources import opt_flags
from sd_annotation import sd_model_annotation
import sys
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
load_pipeline_from_original_stable_diffusion_ckpt,
)
def get_vmfb_path_name(model_name):
device = (
args.device
if "://" not in args.device
else "-".join(args.device.split("://"))
)
extended_name = "{}_{}".format(model_name, device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
return [vmfb_path, extended_name]
def _compile_module(shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
[vmfb_path, extended_name] = get_vmfb_path_name(model_name)
device = (
args.device
if "://" not in args.device
else "-".join(args.device.split("://"))
)
extended_name = "{}_{}".format(model_name, device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
@@ -79,36 +71,21 @@ def compile_through_fx(
model_name,
is_f16=False,
f16_input_mask=None,
use_tuned=False,
extra_args=[],
save_dir=tempfile.gettempdir(),
debug=False,
):
from shark.parser import shark_args
if "cuda" in args.device:
shark_args.enable_tf32 = True
mlir_module, func_name = import_with_fx(
model, inputs, is_f16, f16_input_mask
save_dir = os.path.join(args.local_tank_cache, model_name)
print("SAVE DIR: " + save_dir)
mlir_module, func_name, = import_with_fx(
model=model,
inputs=inputs,
is_f16=is_f16,
f16_input_mask=f16_input_mask,
debug=debug,
model_name=model_name,
save_dir=save_dir,
)
if use_tuned:
model_name = model_name + "_tuned"
tuned_model_path = f"{args.annotation_output}/{model_name}_torch.mlir"
if not os.path.exists(tuned_model_path):
if "vae" in model_name.split("_")[0]:
args.annotation_model = "vae"
tuned_model, tuned_model_path = sd_model_annotation(
mlir_module, model_name
)
del mlir_module, tuned_model
gc.collect()
with open(tuned_model_path, "rb") as f:
mlir_module = f.read()
f.close()
shark_module = SharkInference(
mlir_module,
device=args.device,
@@ -235,38 +212,36 @@ def set_init_device_flags():
elif args.hf_model_id == "prompthero/openjourney":
args.max_length = 64
# Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
# Use tuned models in the case of stablediffusion/fp16 and rdna3 cards.
if (
args.hf_model_id
in ["prompthero/openjourney", "dreamlike-art/dreamlike-diffusion-1.0"]
or args.precision != "fp16"
or args.height != 512
or args.width != 512
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
or "vulkan" not in args.device
or "rdna3" not in args.iree_vulkan_target_triple
):
args.use_tuned = False
elif (
"vulkan" in args.device
and "rdna3" not in args.iree_vulkan_target_triple
):
args.use_tuned = False
elif "cuda" in args.device and get_cuda_sm_cc() not in [
"sm_80",
"sm_84",
"sm_86",
"sm_89",
]:
args.use_tuned = False
elif args.use_base_vae and args.hf_model_id not in [
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]:
args.use_tuned = False
# Use tuned model in the case of stablediffusion/fp16 and cuda device sm_80
if (
args.hf_model_id
in [
"stabilityai/stable-diffusion-2-1-base",
"Linaqruf/anything-v3.0",
"wavymulder/Analog-Diffusion",
]
and args.precision == "fp16"
and "cuda" in args.device
and get_cuda_sm_cc() == "sm_80"
):
args.use_tuned = True
if args.use_tuned:
print(f"Using {args.device} tuned models for stablediffusion/fp16.")
else:
@@ -322,11 +297,6 @@ def get_opt_flags(model, precision="fp16"):
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
iree_flags += opt_flags[model][is_tuned][precision][
"default_compilation_flags"
]
if "specified_compilation_flags" in opt_flags[model][is_tuned][precision]:
device = (
args.device
@@ -343,6 +313,7 @@ def get_opt_flags(model, precision="fp16"):
iree_flags += opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
][device]
return iree_flags
@@ -361,21 +332,25 @@ def preprocessCKPT():
diffusers_path,
)
path_to_diffusers = complete_path_to_diffusers.as_posix()
from_safetensors = (
True if args.ckpt_loc.lower().endswith(".safetensors") else False
# TODO: Use the SD to Diffusers CKPT pipeline once it's included in the release.
sd_to_diffusers = os.path.join(os.getcwd(), "sd_to_diffusers.py")
if not os.path.isfile(sd_to_diffusers):
url = "https://raw.githubusercontent.com/huggingface/diffusers/8a3f0c1f7178f4a3d5a5b21ae8c2906f473e240d/scripts/convert_original_stable_diffusion_to_diffusers.py"
import requests
req = requests.get(url)
open(sd_to_diffusers, "wb").write(req.content)
print("Downloaded SD to Diffusers converter")
else:
print("SD to Diffusers converter already exists")
os.system(
"python "
+ sd_to_diffusers
+ " --checkpoint_path="
+ args.ckpt_loc
+ " --dump_path="
+ path_to_diffusers
)
# EMA weights usually yield higher quality images for inference but non-EMA weights have
# been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
# weight extraction or not.
extract_ema = False
print("Loading pipeline from original stable diffusion checkpoint")
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=args.ckpt_loc,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
)
pipe.save_pretrained(path_to_diffusers)
print("Loading complete")
args.ckpt_loc = path_to_diffusers
print("Custom model path is : ", args.ckpt_loc)

View File

@@ -276,19 +276,9 @@ def compile_module_to_flatbuffer(
return flatbuffer_blob
def get_iree_module(flatbuffer_blob, device, device_idx=None):
def get_iree_module(flatbuffer_blob, device):
# Returns the compiled module and the configs.
if device_idx is not None:
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"]
)
# haldevice = haldriver.create_default_device()
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
config = get_iree_runtime_config(device)
vm_module = ireert.VmModule.from_flatbuffer(
config.vm_instance, flatbuffer_blob
)
@@ -304,21 +294,20 @@ def get_iree_compiled_module(
frontend: str = "torch",
model_config_path: str = None,
extra_args: list = [],
device_idx: int = None,
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, frontend, model_config_path, extra_args
)
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
return get_iree_module(flatbuffer_blob, device)
def load_flatbuffer(flatbuffer_path: str, device: str, device_idx: int = None):
def load_flatbuffer(flatbuffer_path: str, device: str):
with open(os.path.join(flatbuffer_path), "rb") as f:
flatbuffer_blob = f.read()
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
return get_iree_module(flatbuffer_blob, device)
def export_iree_module_to_vmfb(

View File

@@ -47,9 +47,6 @@ def model_annotation(
input_contents = f.read()
module = ir.Module.parse(input_contents)
if config_path == "":
return module
if winograd:
with open(config_path, "r") as f:
data = json.load(f)
@@ -165,6 +162,7 @@ def walk_children(
add_attributes(
child_op, configs[child_op_shape]["options"][0]
)
print(f"Updated op {child_op}", file=sys.stderr)
walk_children(child_op, configs, search_op, winograd)
@@ -396,6 +394,7 @@ def add_winograd_attribute(op: ir.Operation, config: List):
op.attributes["iree_winograd_conv"] = ir.IntegerAttr.get(
ir.IntegerType.get_signless(64), 1
)
print("Apply Winograd on selected conv op: ", op)
def add_attribute_by_name(op: ir.Operation, name: str, val: int):

View File

@@ -55,7 +55,6 @@ class SharkImporter:
inputs: tuple = (),
frontend: str = "torch",
raw_model_file: str = "",
return_str: bool = False,
):
self.module = module
self.inputs = None if len(inputs) == 0 else inputs
@@ -66,7 +65,6 @@ class SharkImporter:
)
sys.exit(1)
self.raw_model_file = raw_model_file
self.return_str = return_str
# NOTE: The default function for torch is "forward" and tf-lite is "main".
@@ -74,11 +72,7 @@ class SharkImporter:
from shark.torch_mlir_utils import get_torch_mlir_module
return get_torch_mlir_module(
self.module,
self.inputs,
is_dynamic,
tracing_required,
self.return_str,
self.module, self.inputs, is_dynamic, tracing_required
)
def _tf_mlir(self, func_name, save_dir="./shark_tmp/"):
@@ -136,6 +130,7 @@ class SharkImporter:
):
import numpy as np
print("dir in save data:" + dir)
inputs_name = "inputs.npz"
outputs_name = "golden_out.npz"
func_file_name = "function_name"
@@ -164,6 +159,7 @@ class SharkImporter:
func_name="forward",
dir=tempfile.gettempdir(),
model_name="model",
golden_values=None,
):
if self.inputs == None:
print(
@@ -183,7 +179,11 @@ class SharkImporter:
if self.frontend in ["torch", "pytorch"]:
import torch
golden_out = self.module(*self.inputs)
golden_out = None
if golden_values is not None:
golden_out = golden_values
else:
golden_out = self.module(*self.inputs)
if torch.is_tensor(golden_out):
golden_out = tuple(
golden_out.detach().cpu().numpy(),
@@ -363,12 +363,16 @@ def import_with_fx(
f16_input_mask=None,
debug=False,
training=False,
return_str=False,
save_dir=tempfile.gettempdir(),
model_name="model",
):
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
golden_values = None
if debug:
golden_values = model(*inputs)
# TODO: Control the decompositions.
fx_g = make_fx(
model,
@@ -419,11 +423,12 @@ def import_with_fx(
ts_graph,
inputs,
frontend="torch",
return_str=return_str,
)
if debug and not is_f16:
(mlir_module, func_name), _, _ = mlir_importer.import_debug()
if debug: # and not is_f16:
(mlir_module, func_name), _, _ = mlir_importer.import_debug(
dir=save_dir, model_name=model_name, golden_values=golden_values
)
return mlir_module, func_name
mlir_module, func_name = mlir_importer.import_mlir()

View File

@@ -69,13 +69,11 @@ class SharkInference:
is_benchmark: bool = False,
dispatch_benchmark: str = None,
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
device_idx: int = None,
):
self.mlir_module = mlir_module
self.device = shark_args.device if device == "none" else device
self.mlir_dialect = mlir_dialect
self.is_benchmark = is_benchmark
self.device_idx = device_idx
self.dispatch_benchmarks = (
shark_args.dispatch_benchmarks
if dispatch_benchmark is None
@@ -122,7 +120,6 @@ class SharkInference:
self.device,
self.mlir_dialect,
extra_args=extra_args,
device_idx=self.device_idx,
)
if self.dispatch_benchmarks is not None:
@@ -208,6 +205,5 @@ class SharkInference:
) = load_flatbuffer(
path,
self.device,
self.device_idx,
)
return

View File

@@ -64,13 +64,11 @@ class SharkRunner:
mlir_dialect: str = "linalg",
extra_args: list = [],
compile_vmfb: bool = True,
device_idx: int = None,
):
self.mlir_module = mlir_module
self.device = shark_args.device if device == "none" else device
self.mlir_dialect = mlir_dialect
self.extra_args = extra_args
self.device_idx = device_idx
if check_device_drivers(self.device):
print(device_driver_info(self.device))
@@ -86,7 +84,6 @@ class SharkRunner:
self.device,
self.mlir_dialect,
extra_args=self.extra_args,
device_idx=self.device_idx,
)
def run(self, function_name, inputs: tuple, send_to_host=False):

View File

@@ -56,7 +56,6 @@ def get_torch_mlir_module(
input: tuple,
dynamic: bool,
jit_trace: bool,
return_str: bool = False,
):
"""Get the MLIR's linalg-on-tensors module from the torchscipt module."""
ignore_traced_shapes = False
@@ -71,11 +70,9 @@ def get_torch_mlir_module(
module,
input,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=jit_trace,
use_tracing=True,
ignore_traced_shapes=ignore_traced_shapes,
)
if return_str:
return mlir_module.operation.get_asm()
bytecode_stream = io.BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()

View File

@@ -14,7 +14,7 @@ microsoft/MiniLM-L12-H384-uncased,mhlo,tf,1e-2,1e-3,tf_hf,None,True,False,False,
microsoft/layoutlm-base-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False,""
microsoft/mpnet-base,mhlo,tf,1e-2,1e-2,default,None,False,False,False,""
albert-base-v2,linalg,torch,1e-2,1e-3,default,None,True,True,True,"issue with aten.tanh in torch-mlir"
alexnet,linalg,torch,1e-2,1e-3,default,None,True,False,True,"https://github.com/nod-ai/SHARK/issues/879"
alexnet,linalg,torch,1e-2,1e-3,default,None,False,False,True,"Assertion Error: Zeros Output"
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,False,False,""
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,""
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,False,True,""
1 resnet50 mhlo tf 1e-2 1e-3 default nhcw-nhwc False False True Vulkan Numerical Error: mostly conv
14 microsoft/layoutlm-base-uncased mhlo tf 1e-2 1e-3 default None False False False
15 microsoft/mpnet-base mhlo tf 1e-2 1e-2 default None False False False
16 albert-base-v2 linalg torch 1e-2 1e-3 default None True True True issue with aten.tanh in torch-mlir
17 alexnet linalg torch 1e-2 1e-3 default None True False False True https://github.com/nod-ai/SHARK/issues/879 Assertion Error: Zeros Output
18 bert-base-cased linalg torch 1e-2 1e-3 default None False False False
19 bert-base-uncased linalg torch 1e-2 1e-3 default None False False False
20 bert-base-uncased_fp16 linalg torch 1e-1 1e-1 default None True False True

View File

@@ -177,11 +177,26 @@ class SharkModuleTester:
if self.ci == True:
self.upload_repro()
if self.benchmark == True:
# p = multiprocessing.Process(
# target=self.benchmark_module,
# args=(shark_module, inputs, dynamic, device),
# )
# p.start()
# p.join()
self.benchmark_module(shark_module, inputs, dynamic, device)
print(msg)
pytest.xfail(reason="Numerics Issue, awaiting triage.")
pytest.xfail(reason="Numerics Issue")
if self.benchmark == True:
# We must create a new process each time we benchmark a model to allow
# for Tensorflow to release GPU resources. Using the same process to
# benchmark multiple models leads to OOM.
# p = multiprocessing.Process(
# target=self.benchmark_module,
# args=(shark_module, inputs, dynamic, device),
# )
# p.start()
# p.join()
self.benchmark_module(shark_module, inputs, dynamic, device)
if self.save_repro == True:

View File

@@ -18,3 +18,4 @@ nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encod
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
stabilityai/stable-diffusion-2-1-base, True,fx_imported,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
1 model_name use_tracing model_type dynamic param_count tags notes
18 mnasnet1_0 False vision True - cnn, torchvision, mobile, architecture-search Outperforms other mobile CNNs on Accuracy vs. Latency
19 resnet50_fp16 False vision True 23M cnn,image-classification,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
20 bert-base-uncased_fp16 True fp16 False 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
21 stabilityai/stable-diffusion-2-1-base True fx_imported False ??M stable diffusion 2.1 base, LLM, Text to image N/A

View File

@@ -1,13 +1,7 @@
import os
import sys
from pathlib import Path
if "AMD_ENABLE_LLPC" not in os.environ:
os.environ["AMD_ENABLE_LLPC"] = "1"
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
os.environ["AMD_ENABLE_LLPC"] = "1"
import gradio as gr
from PIL import Image
from models.stable_diffusion.resources import resource_path, prompt_examples

View File

@@ -19,8 +19,6 @@ datas += copy_metadata('torchvision')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('diffusers')
datas += copy_metadata('transformers')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('gradio')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')