mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
23 Commits
20230106.4
...
20230113.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ee0009d4b8 | ||
|
|
9d851c3346 | ||
|
|
5d117af8ae | ||
|
|
bb41c2d15e | ||
|
|
eba138ee4a | ||
|
|
3b2bbb74f8 | ||
|
|
dbc0f81211 | ||
|
|
d0b613d22e | ||
|
|
72f29b67d5 | ||
|
|
9570045cc3 | ||
|
|
e4efdb5cbb | ||
|
|
187f0fa70c | ||
|
|
472185c3e4 | ||
|
|
f94a571773 | ||
|
|
183e447d35 | ||
|
|
12f844d93a | ||
|
|
47a119a37f | ||
|
|
ee56559b9a | ||
|
|
00e594deea | ||
|
|
6ad9b213b9 | ||
|
|
e4375e8195 | ||
|
|
487bf8e29b | ||
|
|
fea1694e74 |
@@ -1,5 +1,5 @@
|
||||
#!/bin/bash
|
||||
|
||||
IMPORTER=1 ./setup_venv.sh
|
||||
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
|
||||
source $GITHUB_WORKSPACE/shark.venv/bin/activate
|
||||
python generate_sharktank.py --upload=False --ci_tank_dir=True
|
||||
|
||||
@@ -41,9 +41,12 @@ def create_hash(file_name):
|
||||
|
||||
|
||||
def save_torch_model(torch_model_list):
|
||||
from tank.model_utils import get_hf_model
|
||||
from tank.model_utils import get_vision_model
|
||||
from tank.model_utils import get_hf_img_cls_model
|
||||
from tank.model_utils import (
|
||||
get_hf_model,
|
||||
get_vision_model,
|
||||
get_hf_img_cls_model,
|
||||
get_fp16_model,
|
||||
)
|
||||
|
||||
with open(torch_model_list) as csvfile:
|
||||
torch_reader = csv.reader(csvfile, delimiter=",")
|
||||
@@ -65,7 +68,8 @@ def save_torch_model(torch_model_list):
|
||||
model, input, _ = get_hf_model(torch_model_name)
|
||||
elif model_type == "hf_img_cls":
|
||||
model, input, _ = get_hf_img_cls_model(torch_model_name)
|
||||
|
||||
elif model_type == "fp16":
|
||||
model, input, _ = get_fp16_model(torch_model_name)
|
||||
torch_model_name = torch_model_name.replace("/", "_")
|
||||
torch_model_dir = os.path.join(
|
||||
WORKDIR, str(torch_model_name) + "_torch"
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
numpy==1.22.4
|
||||
torchvision
|
||||
torchtriton
|
||||
tabulate
|
||||
|
||||
tqdm
|
||||
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
param([string]$arguments)
|
||||
|
||||
if ($arguments -eq "--update-src"){
|
||||
git pull
|
||||
}
|
||||
|
||||
#Write-Host "Installing python"
|
||||
|
||||
#Start-Process winget install Python.Python.3.10 '/quiet InstallAllUsers=1 PrependPath=1' -wait -NoNewWindow
|
||||
|
||||
@@ -123,8 +123,12 @@ fi
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/
|
||||
|
||||
if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
T_VER=$($PYTHON -m pip show torch | grep Version)
|
||||
TORCH_VERSION=${T_VER:9:17}
|
||||
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
|
||||
TV_VERSION=${TV_VER:9:18}
|
||||
$PYTHON -m pip uninstall -y torch torchvision
|
||||
$PYTHON -m pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu117
|
||||
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch + cu117."
|
||||
else
|
||||
|
||||
@@ -17,6 +17,10 @@ from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from stable_args import args
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
# This has to come before importing cache objects
|
||||
if args.clear_all:
|
||||
@@ -250,5 +254,27 @@ if __name__ == "__main__":
|
||||
pil_images = [
|
||||
transform(image) for image in torch.from_numpy(images).to(torch.uint8)
|
||||
]
|
||||
|
||||
if args.output_dir is not None:
|
||||
output_path = Path(args.output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
output_path = Path.cwd()
|
||||
for i in range(batch_size):
|
||||
pil_images[i].save(f"{args.prompts[i]}_{i}.jpg")
|
||||
json_store = {
|
||||
"prompt": args.prompts[i],
|
||||
"negative prompt": args.negative_prompts[i],
|
||||
"seed": args.seed,
|
||||
"variant": args.variant,
|
||||
"precision": args.precision,
|
||||
"steps": args.steps,
|
||||
"guidance_scale": args.guidance_scale,
|
||||
"scheduler": args.scheduler,
|
||||
}
|
||||
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[i][:15])
|
||||
img_name = f"{prompt_slice}_{args.seed}_{i}_{dt.now().strftime('%y%m%d_%H%M%S')}"
|
||||
pil_images[i].save(
|
||||
output_path / f"{img_name}.jpg", quality=95, subsampling=0
|
||||
)
|
||||
with open(output_path / f"{img_name}.json", "w") as f:
|
||||
f.write(json.dumps(json_store, indent=4))
|
||||
|
||||
@@ -33,7 +33,7 @@ def get_params(bucket_key, model_key, model, is_tuned, precision):
|
||||
]
|
||||
except KeyError:
|
||||
raise Exception(
|
||||
f"{bucket}/{model_key} is not present in the models database"
|
||||
f"{bucket_key}/{model_key} is not present in the models database"
|
||||
)
|
||||
|
||||
if (
|
||||
@@ -62,8 +62,13 @@ def get_params(bucket_key, model_key, model, is_tuned, precision):
|
||||
def get_unet():
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
bucket_key = f"{args.variant}/{is_tuned}"
|
||||
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
|
||||
if "vulkan" not in args.device and args.use_tuned:
|
||||
bucket_key = f"{args.variant}/{is_tuned}/{args.device}"
|
||||
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
|
||||
else:
|
||||
bucket_key = f"{args.variant}/{is_tuned}"
|
||||
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
|
||||
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "unet", is_tuned, args.precision
|
||||
)
|
||||
@@ -76,8 +81,13 @@ def get_vae():
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
is_base = "/base" if args.use_base_vae else ""
|
||||
bucket_key = f"{args.variant}/{is_tuned}"
|
||||
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
|
||||
if "vulkan" not in args.device and args.use_tuned:
|
||||
bucket_key = f"{args.variant}/{is_tuned}/{args.device}"
|
||||
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}"
|
||||
else:
|
||||
bucket_key = f"{args.variant}/{is_tuned}"
|
||||
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
|
||||
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "vae", is_tuned, args.precision
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
{
|
||||
"stablediffusion/untuned":"gs://shark_tank/stable_diffusion",
|
||||
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
|
||||
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
|
||||
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
|
||||
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
|
||||
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
|
||||
@@ -23,10 +24,13 @@
|
||||
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_19dec_v2p1base_fp16_64",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae2base_19dec_fp16",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"vae2base_19dec_fp16_cuda_tuned",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
|
||||
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base/cuda":"vae2base_8dec_fp16_cuda_tuned",
|
||||
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip2base_18dec_fp32",
|
||||
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_19dec_v2p1base_fp32_64",
|
||||
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet2_14dec_fp16",
|
||||
|
||||
@@ -12,11 +12,14 @@ from opt_params import get_params
|
||||
from utils import set_init_device_flags
|
||||
|
||||
|
||||
# Downloads the model (Unet or VAE fp16) from shark_tank
|
||||
set_init_device_flags()
|
||||
device = (
|
||||
args.device if "://" not in args.device else args.device.split("://")[0]
|
||||
)
|
||||
|
||||
# Downloads the model (Unet or VAE fp16) from shark_tank
|
||||
shark_args.local_tank_cache = args.local_tank_cache
|
||||
bucket_key = f"{args.variant}/untuned"
|
||||
use_winograd = True
|
||||
if args.annotation_model == "unet":
|
||||
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/untuned"
|
||||
elif args.annotation_model == "vae":
|
||||
@@ -34,29 +37,32 @@ mlir_model, func_name, inputs, golden_out = download_model(
|
||||
|
||||
# Downloads the tuned config files from shark_tank
|
||||
config_bucket = "gs://shark_tank/sd_tuned/configs/"
|
||||
if use_winograd:
|
||||
config_name = f"{args.annotation_model}_winograd.json"
|
||||
if args.use_winograd:
|
||||
config_name = f"{args.annotation_model}_winograd_{device}.json"
|
||||
full_gs_url = config_bucket + config_name
|
||||
winograd_config_dir = f"{WORKDIR}configs/" + config_name
|
||||
download_public_file(full_gs_url, winograd_config_dir, True)
|
||||
|
||||
if args.annotation_model == "unet":
|
||||
if args.variant in ["anythingv3", "analogdiffusion"]:
|
||||
if args.annotation_model == "unet" or device == "cuda":
|
||||
if (
|
||||
args.variant in ["anythingv3", "analogdiffusion"]
|
||||
or args.annotation_model == "vae"
|
||||
):
|
||||
args.max_length = 77
|
||||
config_name = f"{args.annotation_model}_{args.version}_{args.precision}_len{args.max_length}.json"
|
||||
config_name = f"{args.annotation_model}_{args.version}_{args.precision}_len{args.max_length}_{device}.json"
|
||||
full_gs_url = config_bucket + config_name
|
||||
lowering_config_dir = f"{WORKDIR}configs/" + config_name
|
||||
download_public_file(full_gs_url, lowering_config_dir, True)
|
||||
|
||||
# Annotate the model with Winograd attribute on selected conv ops
|
||||
if use_winograd:
|
||||
if args.use_winograd:
|
||||
with create_context() as ctx:
|
||||
winograd_model = model_annotation(
|
||||
ctx,
|
||||
input_contents=mlir_model,
|
||||
config_path=winograd_config_dir,
|
||||
search_op="conv",
|
||||
winograd=use_winograd,
|
||||
winograd=args.use_winograd,
|
||||
)
|
||||
with open(
|
||||
f"{args.annotation_output}/{model_name}_tuned_torch.mlir", "w"
|
||||
@@ -64,8 +70,8 @@ if use_winograd:
|
||||
f.write(str(winograd_model))
|
||||
|
||||
# For Unet annotate the model with tuned lowering configs
|
||||
if args.annotation_model == "unet":
|
||||
if use_winograd:
|
||||
if args.annotation_model == "unet" or device == "cuda":
|
||||
if args.use_winograd:
|
||||
input_mlir = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
|
||||
dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
|
||||
else:
|
||||
@@ -73,11 +79,22 @@ if args.annotation_model == "unet":
|
||||
dump_after = "iree-flow-pad-linalg-ops"
|
||||
|
||||
# Dump IR after padding/img2col/winograd passes
|
||||
device_spec_args = ""
|
||||
if device == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
|
||||
gpu_flags = get_iree_gpu_args()
|
||||
for flag in gpu_flags:
|
||||
device_spec_args += flag + " "
|
||||
elif device == "vulkan":
|
||||
device_spec_args = (
|
||||
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
|
||||
)
|
||||
run_cmd(
|
||||
f"iree-compile {input_mlir} "
|
||||
"--iree-input-type=tm_tensor "
|
||||
f"--iree-hal-target-backends={iree_target_map(args.device)} "
|
||||
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
|
||||
f"--iree-hal-target-backends={iree_target_map(device)} "
|
||||
f"{device_spec_args}"
|
||||
"--iree-stream-resource-index-bits=64 "
|
||||
"--iree-vm-target-index-bits=64 "
|
||||
"--iree-flow-enable-padding-linalg-ops "
|
||||
|
||||
@@ -123,6 +123,12 @@ p.add_argument(
|
||||
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory path to save the output images and json",
|
||||
)
|
||||
##############################################################################
|
||||
### IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
@@ -247,4 +253,11 @@ p.add_argument(
|
||||
help="Options are unet and vae.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_winograd",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Apply Winograd on selected conv ops.",
|
||||
)
|
||||
|
||||
args = p.parse_args()
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
You need to pre-create your bot (https://core.telegram.org/bots#how-do-i-create-a-bot)
|
||||
Then create a file named .env in the web directory.
|
||||
Add the following record to it:
|
||||
TG_TOKEN="your_token"
|
||||
specifying your bot's token from previous step.
|
||||
Then run telegram_bot.py with the same parameters that you use when running index.py, for example:
|
||||
python telegram_bot.py --max_length=77 --vulkan_large_heap_block_size=0 --use_base_vae --local_tank_cache h:\shark\TEMP
|
||||
|
||||
Bot commands:
|
||||
/select_model
|
||||
/select_scheduler
|
||||
/set_steps "integer number of steps"
|
||||
/set_guidance_scale "integer number"
|
||||
/set_negative_prompt "negative text"
|
||||
Any other text triggers the creation of an image based on it.
|
||||
@@ -7,6 +7,7 @@ from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
)
|
||||
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
|
||||
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
@@ -46,6 +47,8 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
|
||||
# Set local shark_tank cache directory.
|
||||
shark_args.local_tank_cache = args.local_tank_cache
|
||||
if "cuda" in args.device:
|
||||
shark_args.enable_tf32 = True
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
model_name,
|
||||
@@ -185,7 +188,7 @@ def set_init_device_flags():
|
||||
elif args.variant == "openjourney":
|
||||
args.max_length = 64
|
||||
|
||||
# use tuned models only in the case of stablediffusion/fp16 and rdna3 cards.
|
||||
# Use tuned models in the case of stablediffusion/fp16 and rdna3 cards.
|
||||
if (
|
||||
args.variant in ["openjourney", "dreamlike"]
|
||||
or args.precision != "fp16"
|
||||
@@ -193,14 +196,24 @@ def set_init_device_flags():
|
||||
or "rdna3" not in args.iree_vulkan_target_triple
|
||||
):
|
||||
args.use_tuned = False
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
elif args.use_base_vae and args.variant != "stablediffusion":
|
||||
args.use_tuned = False
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
# Use tuned model in the case of stablediffusion/fp16 and cuda device sm_80
|
||||
if (
|
||||
args.variant == "stablediffusion"
|
||||
and args.precision == "fp16"
|
||||
and "cuda" in args.device
|
||||
and get_cuda_sm_cc() == "sm_80"
|
||||
and args.version == "v2_1base"
|
||||
):
|
||||
args.use_tuned = True
|
||||
|
||||
if args.use_tuned:
|
||||
print("Using tuned models for stablediffusion/fp16 and rdna3 card.")
|
||||
print(f"Using {args.device} tuned models for stablediffusion/fp16.")
|
||||
else:
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
|
||||
# Utility to get list of devices available.
|
||||
|
||||
@@ -64,7 +64,7 @@ def compile_through_fx(model, inputs, model_name, extra_args=[]):
|
||||
mlir_module, func_name = import_with_fx(model, inputs)
|
||||
|
||||
shark_module = SharkInference(
|
||||
"hello",
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# All the iree_cpu related functionalities go here.
|
||||
|
||||
import subprocess
|
||||
import platform
|
||||
|
||||
|
||||
def get_cpu_count():
|
||||
@@ -29,25 +30,16 @@ def get_cpu_count():
|
||||
|
||||
# Get the default cpu args.
|
||||
def get_iree_cpu_args():
|
||||
find_triple_cmd = "uname -s -m"
|
||||
os_name, proc_name = (
|
||||
subprocess.run(
|
||||
find_triple_cmd, shell=True, stdout=subprocess.PIPE, check=True
|
||||
)
|
||||
.stdout.decode("utf-8")
|
||||
.split()
|
||||
)
|
||||
uname = platform.uname()
|
||||
os_name, proc_name = uname.system, uname.machine
|
||||
|
||||
if os_name == "Darwin":
|
||||
find_kernel_version_cmd = "uname -r"
|
||||
kernel_version = subprocess.run(
|
||||
find_kernel_version_cmd,
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
check=True,
|
||||
).stdout.decode("utf-8")
|
||||
kernel_version = uname.release
|
||||
target_triple = f"{proc_name}-apple-darwin{kernel_version}"
|
||||
elif os_name == "Linux":
|
||||
target_triple = f"{proc_name}-linux-gnu"
|
||||
elif os_name == "Windows":
|
||||
target_triple = "x86_64-pc-windows-msvc"
|
||||
else:
|
||||
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
|
||||
raise Exception(error_message)
|
||||
|
||||
470
shark/iree_utils/vulkan_target_env_utils.py
Normal file
470
shark/iree_utils/vulkan_target_env_utils.py
Normal file
@@ -0,0 +1,470 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def get_vulkan_target_env(vulkan_target_triple):
    """Build the `#vk.target_env<...>` attribute string for a target triple.

    `vulkan_target_triple` must contain an "=": the segment following the
    first "=" is split on "-" into exactly three parts (arch, product, os).
    Each part of the attribute is then produced by the matching helper in
    this module.
    """
    triple_text = vulkan_target_triple.split("=")[1]
    arch, product, os_name = triple_text.split("-")
    triple = (arch, product, os_name)

    # The helpers are called in a fixed order because some of them print a
    # diagnostic when the triple is not recognised.
    version = get_version(triple=triple)
    # TODO get revision
    revision = 120
    extensions = get_extensions(triple)
    vendor = get_vendor(triple)
    device_type = get_device_type(triple)
    capabilities = get_vulkan_target_capabilities(triple)

    return (
        f"#vk.target_env<{version}, r({revision}), {extensions}, "
        f"{vendor}:{device_type}, #vk.caps< {capabilities} >>"
    )
|
||||
|
||||
|
||||
def get_vulkan_target_env_flag(vulkan_target_triple):
    """Return the `--iree-vulkan-target-env=...` flag for a target triple.

    The value of the flag is whatever `get_vulkan_target_env` produces for
    the same argument.
    """
    return f"--iree-vulkan-target-env={get_vulkan_target_env(vulkan_target_triple)}"
|
||||
|
||||
|
||||
def get_version(triple):
    """Return the Vulkan version string to report for `triple`.

    `triple` is an (arch, product, os) sequence. "v1.1" is returned when the
    os or the product is "android30"/"android31", or when the arch is
    "unknown"; every other triple gets "v1.3".
    """
    arch, product, os_name = triple
    android_targets = ["android30", "android31"]
    uses_v1_1 = (
        os_name in android_targets
        or product in android_targets
        or arch in ["unknown"]
    )
    return "v1.1" if uses_v1_1 else "v1.3"
|
||||
|
||||
|
||||
def get_extensions(triple):
    """Return the bracketed, comma-separated extension list for `triple`.

    `triple` is an (arch, product, os) sequence. The result is a string such
    as "[VK_KHR_16bit_storage, VK_KHR_8bit_storage]" ready to be embedded in
    the `#vk.target_env` attribute.
    """

    def make_ext_list(ext_list):
        # Render the names as "[a, b, c]"; an empty list renders as "[]".
        return f"[{', '.join(ext_list)}]"

    arch, product, os_name = triple

    if arch == "m1":
        return make_ext_list(
            ext_list=[
                "VK_KHR_16bit_storage",
                "VK_KHR_8bit_storage",
                "VK_KHR_shader_float16_int8",
                "VK_KHR_storage_buffer_storage_class",
                "VK_KHR_variable_pointers",
            ]
        )

    if arch == "valhall":
        return make_ext_list(
            ext_list=[
                "VK_KHR_16bit_storage",
                "VK_KHR_8bit_storage",
                "VK_KHR_shader_float16_int8",
                "VK_KHR_spirv_1_4",
                "VK_KHR_storage_buffer_storage_class",
                "VK_KHR_variable_pointers",
            ]
        )

    if arch == "adreno":
        adreno_ext = [
            "VK_KHR_16bit_storage",
            "VK_KHR_shader_float16_int8",
            "VK_KHR_spirv_1_4",
            "VK_KHR_storage_buffer_storage_class",
            "VK_KHR_variable_pointers",
        ]
        # 8-bit storage is only listed for the android31 OS.
        if os_name == "android31":
            adreno_ext.append("VK_KHR_8bit_storage")
        return make_ext_list(ext_list=adreno_ext)

    if get_vendor(triple) == "SwiftShader":
        return make_ext_list(ext_list=["VK_KHR_storage_buffer_storage_class"])

    if arch == "unknown":
        return make_ext_list(
            ext_list=[
                "VK_KHR_storage_buffer_storage_class",
                "VK_KHR_variable_pointers",
            ]
        )

    # Default set for every remaining architecture.
    default_ext = [
        "VK_KHR_16bit_storage",
        "VK_KHR_8bit_storage",
        "VK_KHR_shader_float16_int8",
        "VK_KHR_spirv_1_4",
        "VK_KHR_storage_buffer_storage_class",
        "VK_KHR_variable_pointers",
        "VK_EXT_subgroup_size_control",
    ]

    # Cooperative matrix is added for NVIDIA vendors and for rdna3.
    if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
        default_ext.append("VK_NV_cooperative_matrix")

    return make_ext_list(ext_list=default_ext)
|
||||
|
||||
|
||||
def get_vendor(triple):
    """Return the vendor name to report for `triple`.

    `triple` is an (arch, product, os) sequence. Known architectures map to
    a fixed vendor string; the "cpu" arch is "SwiftShader" only when the
    product is "swiftshader". Any unrecognised arch prints a notice and
    yields "Unknown".
    """
    arch, product, _ = triple

    if arch == "cpu":
        # Only the swiftshader product has a named vendor on the cpu arch.
        return "SwiftShader" if product == "swiftshader" else "Unknown"

    vendor_by_arch = {
        "unknown": "Unknown",
        "rdna1": "AMD",
        "rdna2": "AMD",
        "rdna3": "AMD",
        "rgcn3": "AMD",
        "rgcn4": "AMD",
        "rgcn5": "AMD",
        "valhall": "ARM",
        "m1": "Apple",
        "turing": "NVIDIA",
        "ampere": "NVIDIA",
        # Spelled "adreno" to match the other helpers in this module; the
        # previous "ardeno" spelling could never match a real triple.
        "adreno": "Qualcomm",
    }
    vendor = vendor_by_arch.get(arch)
    if vendor is None:
        print(f"Vendor for target triple - {triple} not found. Using unknown")
        return "Unknown"
    return vendor
|
||||
|
||||
|
||||
def get_device_type(triple):
    """Return the device-type name to report for `triple`.

    `triple` is an (arch, product, os) sequence. The result is one of
    "Unknown", "CPU", "DiscreteGPU" or "IntegratedGPU". An unrecognised arch
    prints a notice and yields "Unknown".
    """
    arch, product, _ = triple

    if arch == "unknown":
        return "Unknown"
    if arch == "cpu":
        return "CPU"
    if arch in ("turing", "ampere"):
        return "DiscreteGPU"
    # "rgcn4" is included so this list agrees with the AMD architectures
    # handled by get_vendor and get_vulkan_target_capabilities.
    if arch in ("rdna1", "rdna2", "rdna3", "rgcn3", "rgcn4", "rgcn5"):
        # The ivega10 product is the one AMD part reported as integrated.
        return "IntegratedGPU" if product == "ivega10" else "DiscreteGPU"
    if arch in ("m1", "valhall", "adreno"):
        return "IntegratedGPU"

    print(f"Device type for target triple - {triple} not found. Using unknown")
    return "Unknown"
|
||||
|
||||
|
||||
# get all the capabilities for the device
|
||||
# TODO: make a dataclass for capabilities and init using vulkaninfo
|
||||
def get_vulkan_target_capabilities(triple):
|
||||
def get_subgroup_val(l):
|
||||
return int(sum([subgroup_feature[sgf] for sgf in l]))
|
||||
|
||||
cap = OrderedDict()
|
||||
arch, product, os = triple
|
||||
subgroup_feature = {
|
||||
"Basic": 1,
|
||||
"Vote": 2,
|
||||
"Arithmetic": 4,
|
||||
"Ballot": 8,
|
||||
"Shuffle": 16,
|
||||
"ShuffleRelative": 32,
|
||||
"Clustered": 64,
|
||||
"Quad": 128,
|
||||
"PartitionedNV": 256,
|
||||
}
|
||||
cap["maxComputeSharedMemorySize"] = 16384
|
||||
cap["maxComputeWorkGroupInvocations"] = 128
|
||||
cap["maxComputeWorkGroupSize"] = [128, 128, 64]
|
||||
cap["subgroupSize"] = 32
|
||||
cap["subgroupFeatures"] = ["Basic"]
|
||||
cap["minSubgroupSize"] = None
|
||||
cap["maxSubgroupSize"] = None
|
||||
cap["shaderFloat16"] = False
|
||||
cap["shaderFloat64"] = False
|
||||
cap["shaderInt8"] = False
|
||||
cap["shaderInt16"] = False
|
||||
cap["shaderInt64"] = False
|
||||
cap["storageBuffer16BitAccess"] = False
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = False
|
||||
cap["storageBuffer8BitAccess"] = False
|
||||
cap["storagePushConstant8"] = False
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = False
|
||||
cap["variablePointers"] = False
|
||||
cap["variablePointersStorageBuffer"] = False
|
||||
cap["coopmatCases"] = None
|
||||
|
||||
if arch in ["rdna1", "rdna2", "rdna3"]:
|
||||
|
||||
cap["maxComputeSharedMemorySize"] = 65536
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroupSize"] = 64
|
||||
cap["minSubgroupSize"] = 32
|
||||
cap["maxSubgroupSize"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
if arch == "rdna3":
|
||||
# TODO: Get scope value
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
|
||||
]
|
||||
if product == "rx5700xt":
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["storagePushConstant8"] = False
|
||||
|
||||
elif arch in ["rgcn5", "rgcn4", "rgcn3"]:
|
||||
cap["maxComputeSharedMemorySize"] = 65536
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroupSize"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
cap["minSubgroupSize"] = 64
|
||||
cap["maxSubgroupSize"] = 64
|
||||
|
||||
if arch == "rgcn5":
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = True
|
||||
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = False
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "m1":
|
||||
|
||||
cap["maxComputeSharedMemorySize"] = 32768
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroupSize"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "valhall":
|
||||
cap["maxComputeSharedMemorySize"] = 32768
|
||||
cap["maxComputeWorkGroupInvocations"] = 512
|
||||
cap["maxComputeWorkGroupSize"] = [512, 512, 512]
|
||||
|
||||
cap["subgroupSize"] = 16
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
if os == "android31":
|
||||
cap["subgroupFeatures"].append("Shuffle")
|
||||
cap["subgroupFeatures"].append("ShuffleRelative")
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "cpu":
|
||||
if product == "swiftshader":
|
||||
cap["maxComputeSharedMemorySize"] = 16384
|
||||
cap["subgroupSize"] = 4
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
]
|
||||
|
||||
elif arch in ["ampere", "turing"]:
|
||||
|
||||
cap["maxComputeSharedMemorySize"] = 49152
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroupSize"] = 32
|
||||
cap["minSubgroupSize"] = 32
|
||||
cap["maxSubgroupSize"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, scope = #vk.scope<Subgroup>",
|
||||
]
|
||||
|
||||
elif arch == "adreno":
|
||||
|
||||
cap["maxComputeSharedMemorySize"] = 32768
|
||||
cap["maxComputeWorkGroupInvocations"] = 1024
|
||||
cap["maxComputeWorkGroupSize"] = [1024, 1024, 64]
|
||||
|
||||
cap["subgroupSize"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
if os == "andorid31":
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "unknown":
|
||||
cap["subgroupSize"] = 64
|
||||
cap["variablePointers"] = False
|
||||
cap["variablePointersStorageBuffer"] = False
|
||||
else:
|
||||
print(
|
||||
f"Architecture {arch} not matched. Using default vulkan target device capability"
|
||||
)
|
||||
|
||||
def get_comma_sep_str(ele_list):
    """Render ``ele_list`` as a bracketed, comma-separated string.

    Each element is formatted with its default string formatting and the
    pieces are joined with ``", "``, e.g. ``[1024, 1024, 64]``. An empty
    input produces ``"[]"``.
    """
    # str.join avoids the repeated string concatenation and the manual
    # trimming of the trailing ", " separator.
    joined = ", ".join(f"{ele}" for ele in ele_list)
    return f"[{joined}]"
|
||||
|
||||
res = ""
|
||||
for k, v in cap.items():
|
||||
|
||||
if v is None or v == False:
|
||||
continue
|
||||
if isinstance(v, bool):
|
||||
res += f"{k} = {'unit' if v == True else None}, "
|
||||
elif isinstance(v, list):
|
||||
if k == "subgroupFeatures":
|
||||
res += f"subgroupFeatures = {get_subgroup_val(v)}: i32, "
|
||||
elif k == "maxComputeWorkGroupSize":
|
||||
res += f"maxComputeWorkGroupSize = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, "
|
||||
elif k == "coopmatCases":
|
||||
cmc = ""
|
||||
for case in v:
|
||||
cmc += f"#vk.coop_matrix_props<{case}>, "
|
||||
res += f"cooperativeMatrixPropertiesNV = [{cmc[:-2]}], "
|
||||
else:
|
||||
res += f"{k} = {get_comma_sep_str(v)}, "
|
||||
else:
|
||||
res += f"{k} = {v}, "
|
||||
res = res[:-2]
|
||||
return res
|
||||
@@ -18,6 +18,7 @@ from os import linesep
|
||||
from shark.iree_utils._common import run_cmd
|
||||
import iree.runtime as ireert
|
||||
from sys import platform
|
||||
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
|
||||
|
||||
def get_vulkan_device_name():
|
||||
@@ -97,15 +98,16 @@ def get_vulkan_target_triple(device_name):
|
||||
return triple
|
||||
|
||||
|
||||
def get_vulkan_triple_flag(device_name=None, extra_args=[]):
|
||||
def get_vulkan_triple_flag(device_name="", extra_args=[]):
|
||||
for flag in extra_args:
|
||||
if "-iree-vulkan-target-triple=" in flag:
|
||||
print(f"Using target triple {flag.split('=')[1]}")
|
||||
return None
|
||||
|
||||
vulkan_device = (
|
||||
device_name if device_name is not None else get_vulkan_device_name()
|
||||
)
|
||||
if device_name == "" or device_name == [] or device_name is None:
|
||||
vulkan_device = get_vulkan_device_name()
|
||||
else:
|
||||
vulkan_device = device_name
|
||||
triple = get_vulkan_target_triple(vulkan_device)
|
||||
if triple is not None:
|
||||
print(
|
||||
@@ -122,11 +124,23 @@ def get_vulkan_triple_flag(device_name=None, extra_args=[]):
|
||||
|
||||
|
||||
def get_iree_vulkan_args(extra_args=[]):
|
||||
vulkan_flag = []
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
|
||||
# vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
|
||||
res_vulkan_flag = []
|
||||
vulkan_triple_flag = None
|
||||
for arg in extra_args:
|
||||
if "-iree-vulkan-target-triple=" in arg:
|
||||
print(f"Using target triple {arg} from command line args")
|
||||
vulkan_triple_flag = arg
|
||||
break
|
||||
|
||||
if vulkan_triple_flag is None:
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
|
||||
|
||||
if vulkan_triple_flag is not None:
|
||||
vulkan_flag.append(vulkan_triple_flag)
|
||||
return vulkan_flag
|
||||
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)
|
||||
res_vulkan_flag.append(vulkan_target_env)
|
||||
return res_vulkan_flag
|
||||
|
||||
|
||||
def set_iree_vulkan_runtime_flags(flags):
|
||||
|
||||
@@ -23,6 +23,8 @@ from datetime import datetime
|
||||
import time
|
||||
import csv
|
||||
import os
|
||||
import torch
|
||||
import torch._dynamo as dynamo
|
||||
|
||||
|
||||
class OnnxFusionOptions(object):
|
||||
@@ -65,6 +67,7 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
extra_args: list = [],
|
||||
):
|
||||
self.device = shark_args.device if device == "none" else device
|
||||
self.enable_tf32 = shark_args.enable_tf32
|
||||
self.frontend_model = None
|
||||
self.vmfb_file = None
|
||||
self.mlir_dialect = mlir_dialect
|
||||
@@ -107,6 +110,8 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
|
||||
if self.device == "cuda":
|
||||
torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
||||
if self.enable_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
torch_device = torch.device(
|
||||
@@ -114,6 +119,7 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
)
|
||||
HFmodel, input = get_torch_model(modelname)[:2]
|
||||
frontend_model = HFmodel.model
|
||||
frontend_model = dynamo.optimize("inductor")(frontend_model)
|
||||
frontend_model.to(torch_device)
|
||||
input.to(torch_device)
|
||||
|
||||
@@ -333,7 +339,10 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
else:
|
||||
bench_result["shape_type"] = "static"
|
||||
bench_result["device"] = device_str
|
||||
bench_result["data_type"] = inputs[0].dtype
|
||||
if "fp16" in modelname:
|
||||
bench_result["data_type"] = "float16"
|
||||
else:
|
||||
bench_result["data_type"] = inputs[0].dtype
|
||||
for e in engines:
|
||||
(
|
||||
bench_result["param_count"],
|
||||
|
||||
@@ -169,9 +169,12 @@ def download_model(
|
||||
os.path.join(model_dir, "upstream_hash.npy"),
|
||||
single_file=True,
|
||||
)
|
||||
upstream_hash = str(
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
try:
|
||||
upstream_hash = str(
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
except FileNotFoundError:
|
||||
upstream_hash = None
|
||||
if local_hash != upstream_hash:
|
||||
print(
|
||||
"Hash does not match upstream in gs://shark_tank/latest. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
|
||||
|
||||
@@ -17,6 +17,7 @@ albert-base-v2,linalg,torch,1e-2,1e-3,default,None,True,True,True,"issue with at
|
||||
alexnet,linalg,torch,1e-2,1e-3,default,None,False,False,True,"Assertion Error: Zeros Output"
|
||||
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,False,False,""
|
||||
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,""
|
||||
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,False,True,""
|
||||
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile."
|
||||
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311"
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390"
|
||||
@@ -28,6 +29,7 @@ nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github
|
||||
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,True,"Vulkan Numerical Error (mostly conv)"
|
||||
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,True,""
|
||||
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,True,"Vulkan Numerical Error (mostly conv)"
|
||||
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc,True,False,True,""
|
||||
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,True,"https://github.com/nod-ai/SHARK/issues/388"
|
||||
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,True,"Vulkan Numerical Error (mostly conv)"
|
||||
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,True,"https://github.com/nod-ai/SHARK/issues/575"
|
||||
|
||||
|
@@ -2,12 +2,14 @@ model_name, use_tracing, dynamic, param_count, tags, notes
|
||||
microsoft/MiniLM-L12-H384-uncased,True,True,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
|
||||
albert-base-v2,True,True,11M,"nlp;bert-variant;transformer-encoder","12 layers; 128 embedding dim; 768 hidden dim; 12 attention heads; Smaller than BERTbase (11M params vs 109M params); Uses weight sharing to reduce # params but computational cost is similar to BERT."
|
||||
bert-base-uncased,True,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-base-uncased_fp16,True,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-base-cased,True,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
distilbert-base-uncased,True,True,66M,"nlp;bert-variant;transformer-encoder","Smaller and faster than BERT with 97percent retained accuracy."
|
||||
google/mobilebert-uncased,True,True,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
|
||||
alexnet,False,True,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
|
||||
resnet18,False,True,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
|
||||
resnet50,False,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
resnet50_fp16,False,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
resnet101,False,True,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
squeezenet1_0,False,True,1.25M,"cnn,image-classification,mobile,parallel-layers","Parallel conv2d (1x1 conv to compress -> (3x3 expand | 1x1 expand) -> concat)"
|
||||
wide_resnet50_2,False,True,69M,"cnn,image-classification,residuals,resnet-variant","Resnet variant where model depth is decreased and width is increased."
|
||||
|
||||
|
@@ -12,6 +12,7 @@ vision_models = [
|
||||
"resnet101",
|
||||
"resnet18",
|
||||
"resnet50",
|
||||
"resnet50_fp16",
|
||||
"squeezenet1_0",
|
||||
"wide_resnet50_2",
|
||||
"mobilenet_v3_small",
|
||||
@@ -31,6 +32,8 @@ def get_torch_model(modelname):
|
||||
return get_vision_model(modelname)
|
||||
elif modelname in hf_img_cls_models:
|
||||
return get_hf_img_cls_model(modelname)
|
||||
elif "fp16" in modelname:
|
||||
return get_fp16_model(modelname)
|
||||
else:
|
||||
return get_hf_model(modelname)
|
||||
|
||||
@@ -114,7 +117,6 @@ class HuggingFaceLanguage(torch.nn.Module):
|
||||
def get_hf_model(name):
|
||||
from transformers import (
|
||||
BertTokenizer,
|
||||
TFBertModel,
|
||||
)
|
||||
|
||||
model = HuggingFaceLanguage(name)
|
||||
@@ -146,6 +148,7 @@ def get_vision_model(torch_model):
|
||||
"alexnet": models.alexnet(weights="DEFAULT"),
|
||||
"resnet18": models.resnet18(weights="DEFAULT"),
|
||||
"resnet50": models.resnet50(weights="DEFAULT"),
|
||||
"resnet50_fp16": models.resnet50(weights="DEFAULT"),
|
||||
"resnet101": models.resnet101(weights="DEFAULT"),
|
||||
"squeezenet1_0": models.squeezenet1_0(weights="DEFAULT"),
|
||||
"wide_resnet50_2": models.wide_resnet50_2(weights="DEFAULT"),
|
||||
@@ -153,10 +156,26 @@ def get_vision_model(torch_model):
|
||||
"mnasnet1_0": models.mnasnet1_0(weights="DEFAULT"),
|
||||
}
|
||||
if isinstance(torch_model, str):
|
||||
fp16_model = None
|
||||
if "fp16" in torch_model:
|
||||
fp16_model = True
|
||||
torch_model = vision_models_dict[torch_model]
|
||||
model = VisionModule(torch_model)
|
||||
test_input = torch.randn(1, 3, 224, 224)
|
||||
actual_out = model(test_input)
|
||||
if fp16_model is not None:
|
||||
test_input_fp16 = test_input.to(
|
||||
device=torch.device("cuda"), dtype=torch.half
|
||||
)
|
||||
model_fp16 = model.half()
|
||||
model_fp16.eval()
|
||||
model_fp16.to("cuda")
|
||||
actual_out_fp16 = model_fp16(test_input_fp16)
|
||||
model, test_input, actual_out = (
|
||||
model_fp16,
|
||||
test_input_fp16,
|
||||
actual_out_fp16,
|
||||
)
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
@@ -164,6 +183,49 @@ def get_vision_model(torch_model):
|
||||
|
||||
####################### Other PyTorch HF Models ###############################
|
||||
|
||||
|
||||
class BertHalfPrecisionModel(torch.nn.Module):
    """Wrap a Hugging Face masked-LM checkpoint loaded in float16 on CUDA.

    The checkpoint named by ``hf_model_name`` is loaded through
    ``AutoModelForMaskedLM`` with ``torch_dtype=torch.float16`` and moved
    to the ``"cuda"`` device. ``forward`` returns only the first element of
    the wrapped model's output.
    """

    def __init__(self, hf_model_name):
        super().__init__()
        from transformers import AutoModelForMaskedLM

        load_options = dict(
            num_labels=2,  # The number of output labels--2 for binary classification.
            output_attentions=False,  # Whether the model returns attentions weights.
            output_hidden_states=False,  # Whether the model returns all hidden-states.
            torchscript=True,
            torch_dtype=torch.float16,
        )
        pretrained = AutoModelForMaskedLM.from_pretrained(
            hf_model_name,  # The pretrained model.
            **load_options,
        )
        self.model = pretrained.to("cuda")

    def forward(self, tokens):
        outputs = self.model.forward(tokens)
        return outputs[0]
|
||||
|
||||
|
||||
def get_fp16_model(torch_model):
    """Build the half-precision model for ``torch_model`` with a sample input.

    The ``"_fp16"`` marker is removed from ``torch_model`` to obtain the
    checkpoint name used for both the model wrapper and the tokenizer. A
    fixed sentence is tokenized (truncated, ``max_length=128``) and its
    ``input_ids`` are moved to ``"cuda"`` to serve as the sample input.

    Returns:
        A ``(model, sample_input, sample_output)`` tuple, where
        ``sample_output`` is computed under ``torch.no_grad()``.
    """
    from transformers import AutoTokenizer

    base_name = torch_model.replace("_fp16", "")
    wrapper = BertHalfPrecisionModel(base_name)
    tokenizer = AutoTokenizer.from_pretrained(base_name)

    encoded = tokenizer(
        "Replace me by any text you like.",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    sample_ids = encoded.input_ids.to("cuda")

    half_model = wrapper.half()
    half_model.eval()
    with torch.no_grad():
        reference_out = half_model(sample_ids)

    return half_model, sample_ids, reference_out
|
||||
|
||||
|
||||
# Utility function for comparing two tensors (torch).
|
||||
def compare_tensors(torch_tensor, numpy_tensor, rtol=1e-02, atol=1e-03):
|
||||
# torch_to_numpy = torch_tensor.detach().numpy()
|
||||
|
||||
@@ -16,3 +16,5 @@ facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,22M,"image-class
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
|
||||
nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encoder",SegFormer
|
||||
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
|
||||
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
|
||||
|
@@ -16,6 +16,7 @@ from models.stable_diffusion.stable_args import args
|
||||
from models.stable_diffusion.schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
)
|
||||
import gc
|
||||
|
||||
|
||||
model_config = {
|
||||
@@ -81,16 +82,25 @@ class ModelCache:
|
||||
self.version = None
|
||||
self.schedulers = None
|
||||
self.tokenizer = None
|
||||
self.vae = None
|
||||
self.clip = None
|
||||
self.unet = None
|
||||
|
||||
def set_models(self, device_key):
|
||||
if self.device != device_key or self.variant != args.variant:
|
||||
self.device = device_key
|
||||
self.variant = args.variant
|
||||
self.version = args.version
|
||||
args.device = device_key.split("=>", 1)[0].strip()
|
||||
args.device = device_key.split("=>", 1)[1].strip()
|
||||
args.max_length = 64
|
||||
args.use_tuned = True
|
||||
set_init_device_flags()
|
||||
del self.schedulers
|
||||
del self.tokenizer
|
||||
del self.vae
|
||||
del self.unet
|
||||
del self.clip
|
||||
gc.collect()
|
||||
self.schedulers = get_schedulers(args.version)
|
||||
self.tokenizer = get_tokenizer(args.version)
|
||||
self.vae = get_vae()
|
||||
|
||||
@@ -9,6 +9,10 @@ from random import randint
|
||||
import numpy as np
|
||||
import time
|
||||
import sys
|
||||
from datetime import datetime as dt
|
||||
from csv import DictWriter
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
if args.clear_all:
|
||||
@@ -65,6 +69,39 @@ def set_ui_params(
|
||||
args.variant = variant
|
||||
|
||||
|
||||
# Save the output image and the inputs corresponding to it.
def save_output_img(output_img):
    """Save ``output_img`` as a JPEG and log its generation settings to a CSV.

    The image is written to ``<output_dir>/generated_imgs`` (``output_dir``
    falls back to the current working directory when ``args.output_dir`` is
    not set). The file name combines a sanitized 15-character prefix of the
    first prompt, the seed, and a timestamp. One row describing the run is
    appended to ``<output_dir>/imgs_details.csv``.

    Args:
        output_img: Image object exposing ``save(path, quality=...,
            subsampling=...)`` -- presumably a PIL image; confirm with caller.
    """
    output_path = args.output_dir if args.output_dir else Path.cwd()
    generated_imgs_path = Path(output_path, "generated_imgs")
    generated_imgs_path.mkdir(parents=True, exist_ok=True)
    csv_path = Path(output_path, "imgs_details.csv")

    # Replace anything that is not alphanumeric so the prompt prefix is
    # safe to use inside a file name.
    prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
    out_img_name = (
        f"{prompt_slice}_{args.seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
    )
    out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
    output_img.save(out_img_path, quality=95, subsampling=0)

    new_entry = {
        "VARIANT": args.variant,
        "VERSION": args.version,
        "SCHEDULER": args.scheduler,
        "PROMPT": args.prompts[0],
        "NEG_PROMPT": args.negative_prompts[0],
        "SEED": args.seed,
        "CFG_SCALE": args.guidance_scale,
        "PRECISION": args.precision,
        "STEPS": args.steps,
        "OUTPUT": out_img_path,
    }

    # Only a brand-new (or empty) log gets a header row; existing logs keep
    # their current layout.
    write_header = not csv_path.exists() or csv_path.stat().st_size == 0

    # newline="" is required by the csv module; without it every row is
    # followed by a blank line on Windows. The context manager closes the
    # file, so no explicit close() is needed.
    with open(csv_path, "a", newline="") as csv_obj:
        dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
        if write_header:
            dictwriter_obj.writeheader()
        dictwriter_obj.writerow(new_entry)
|
||||
|
||||
|
||||
def stable_diff_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
@@ -220,9 +257,12 @@ def stable_diff_inf(
|
||||
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nvariant={args.variant}, scheduler={args.scheduler}, device={device_key}"
|
||||
text_output += f"\nvariant={args.variant}, version={args.version}, scheduler={args.scheduler}"
|
||||
text_output += f"\ndevice={device_key}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={height}x{width}"
|
||||
text_output += f"\nAverage step time: {avg_ms:.4f}ms/it"
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
save_output_img(pil_images[0])
|
||||
|
||||
return pil_images[0], text_output
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
from models.stable_diffusion.utils import compile_through_fx
|
||||
from models.stable_diffusion.resources import models_config
|
||||
from models.stable_diffusion.stable_args import args
|
||||
import torch
|
||||
|
||||
model_config = {
|
||||
"v2_1": "stabilityai/stable-diffusion-2-1",
|
||||
"v2_1base": "stabilityai/stable-diffusion-2-1-base",
|
||||
"v1_4": "CompVis/stable-diffusion-v1-4",
|
||||
}
|
||||
|
||||
# clip has 2 variants of max length 77 or 64.
|
||||
model_clip_max_length = 64 if args.max_length == 64 else 77
|
||||
@@ -17,14 +13,6 @@ if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
|
||||
elif args.variant == "openjourney":
|
||||
model_clip_max_length = 64
|
||||
|
||||
model_variant = {
|
||||
"stablediffusion": "SD",
|
||||
"anythingv3": "Linaqruf/anything-v3.0",
|
||||
"dreamlike": "dreamlike-art/dreamlike-diffusion-1.0",
|
||||
"openjourney": "prompthero/openjourney",
|
||||
"analogdiffusion": "wavymulder/Analog-Diffusion",
|
||||
}
|
||||
|
||||
model_input = {
|
||||
"v2_1": {
|
||||
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
|
||||
@@ -58,45 +46,34 @@ model_input = {
|
||||
},
|
||||
}
|
||||
|
||||
# revision param for from_pretrained defaults to "main" => fp32
|
||||
model_revision = {
|
||||
"stablediffusion": "fp16" if args.precision == "fp16" else "main",
|
||||
"anythingv3": "diffusers",
|
||||
"analogdiffusion": "main",
|
||||
"openjourney": "main",
|
||||
"dreamlike": "main",
|
||||
}
|
||||
version = args.version if args.variant == "stablediffusion" else "v1_4"
|
||||
|
||||
|
||||
def get_configs():
    """Look up the model id and revision for the current variant.

    The model id is read from the first table of ``models_config`` under
    ``"<variant>/<version>"`` and the revision from the second table under
    ``"<variant>/<precision>"``.

    Returns:
        A ``(model_id, revision)`` tuple.

    Raises:
        Exception: If either key is missing from the models configuration.
    """
    variant = args.variant
    id_lookup = f"{variant}/{version}"
    rev_lookup = f"{variant}/{args.precision}"

    try:
        resolved = (
            models_config[0][id_lookup],
            models_config[1][rev_lookup],
        )
    except KeyError:
        raise Exception(
            f"No entry for {id_lookup} or {rev_lookup} in the models configuration"
        )

    return resolved
|
||||
|
||||
|
||||
def get_clip_mlir(model_name="clip_text", extra_args=[]):
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
if args.variant == "stablediffusion":
|
||||
if args.version != "v1_4":
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_config[args.version], subfolder="text_encoder"
|
||||
)
|
||||
|
||||
elif args.variant in [
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]:
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_variant[args.variant],
|
||||
subfolder="text_encoder",
|
||||
revision=model_revision[args.variant],
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"{args.variant} not yet added")
|
||||
model_id, revision = get_configs()
|
||||
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.text_encoder = text_encoder
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="text_encoder",
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
@@ -104,23 +81,44 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
|
||||
clip_model = CLIPText()
|
||||
shark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
model_input[args.version]["clip"],
|
||||
model_input[version]["clip"],
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
|
||||
def get_shark_module(model_key, module, model_name, extra_args):
    """Compile ``module`` through ``compile_through_fx`` with its sample inputs.

    Sample inputs are taken from ``model_input[version][model_key]``. When
    ``args.precision`` is ``"fp16"`` the module is converted with
    ``.half().cuda()`` and so is every sample input that has a non-empty
    shape; zero-dimensional inputs are passed through unchanged.

    Returns:
        Whatever ``compile_through_fx`` returns for the prepared module.
    """
    use_fp16 = args.precision == "fp16"
    if use_fp16:
        module = module.half().cuda()

    sample_inputs = model_input[version][model_key]
    if use_fp16:
        sample_inputs = tuple(
            tensor if len(tensor.shape) == 0 else tensor.half().cuda()
            for tensor in sample_inputs
        )

    return compile_through_fx(
        module,
        sample_inputs,
        model_name=model_name,
        extra_args=extra_args,
    )
|
||||
|
||||
|
||||
def get_base_vae_mlir(model_name="vae", extra_args=[]):
|
||||
model_id, revision = get_configs()
|
||||
|
||||
class BaseVaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_config[args.version]
|
||||
if args.variant == "stablediffusion"
|
||||
else model_variant[args.variant],
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
revision=model_revision[args.variant],
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
@@ -128,52 +126,19 @@ def get_base_vae_mlir(model_name="vae", extra_args=[]):
|
||||
return (x / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
vae = BaseVaeModel()
|
||||
if args.variant == "stablediffusion":
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda()
|
||||
for inputs in model_input[args.version]["vae"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[args.version]["vae"]
|
||||
elif args.variant in [
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]:
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
|
||||
)
|
||||
else:
|
||||
inputs = model_input["v1_4"]["vae"]
|
||||
else:
|
||||
raise ValueError(f"{args.variant} not yet added")
|
||||
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_vae
|
||||
return get_shark_module("vae", vae, model_name, extra_args)
|
||||
|
||||
|
||||
def get_vae_mlir(model_name="vae", extra_args=[]):
|
||||
model_id, revision = get_configs()
|
||||
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_config[args.version]
|
||||
if args.variant == "stablediffusion"
|
||||
else model_variant[args.variant],
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
revision=model_revision[args.variant],
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
@@ -184,52 +149,19 @@ def get_vae_mlir(model_name="vae", extra_args=[]):
|
||||
return x.round()
|
||||
|
||||
vae = VaeModel()
|
||||
if args.variant == "stablediffusion":
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda()
|
||||
for inputs in model_input[args.version]["vae"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[args.version]["vae"]
|
||||
elif args.variant in [
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]:
|
||||
if args.precision == "fp16":
|
||||
vae = vae.half().cuda()
|
||||
inputs = tuple(
|
||||
[inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
|
||||
)
|
||||
else:
|
||||
inputs = model_input["v1_4"]["vae"]
|
||||
else:
|
||||
raise ValueError(f"{args.variant} not yet added")
|
||||
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_vae
|
||||
return get_shark_module("vae", vae, model_name, extra_args)
|
||||
|
||||
|
||||
def get_unet_mlir(model_name="unet", extra_args=[]):
|
||||
model_id, revision = get_configs()
|
||||
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_config[args.version]
|
||||
if args.variant == "stablediffusion"
|
||||
else model_variant[args.variant],
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
revision=model_revision[args.variant],
|
||||
revision=revision,
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
@@ -247,39 +179,4 @@ def get_unet_mlir(model_name="unet", extra_args=[]):
|
||||
return noise_pred
|
||||
|
||||
unet = UnetModel()
|
||||
if args.variant == "stablediffusion":
|
||||
if args.precision == "fp16":
|
||||
unet = unet.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
|
||||
for inputs in model_input[args.version]["unet"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input[args.version]["unet"]
|
||||
elif args.variant in [
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]:
|
||||
if args.precision == "fp16":
|
||||
unet = unet.half().cuda()
|
||||
inputs = tuple(
|
||||
[
|
||||
inputs.half().cuda() if len(inputs.shape) != 0 else inputs
|
||||
for inputs in model_input["v1_4"]["unet"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
inputs = model_input["v1_4"]["unet"]
|
||||
else:
|
||||
raise ValueError(f"{args.variant} is not yet added")
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_unet
|
||||
return get_shark_module("unet", unet, model_name, extra_args)
|
||||
|
||||
@@ -33,7 +33,7 @@ def get_params(bucket_key, model_key, model, is_tuned, precision):
|
||||
]
|
||||
except KeyError:
|
||||
raise Exception(
|
||||
f"{bucket}/{model_key} is not present in the models database"
|
||||
f" there is no entry for {model_key} in the models database"
|
||||
)
|
||||
|
||||
if (
|
||||
|
||||
@@ -29,3 +29,13 @@ if os.path.exists(models_loc):
|
||||
|
||||
if len(models_db) != 3:
|
||||
sys.exit("Error: Unable to load models database.")
|
||||
|
||||
|
||||
models_config = []
|
||||
modelconfig_loc = resource_path("resources/model_config.json")
|
||||
if os.path.exists(modelconfig_loc):
|
||||
with open(modelconfig_loc, encoding="utf-8") as fopen:
|
||||
models_config = json.load(fopen)
|
||||
|
||||
if len(models_config) != 2:
|
||||
sys.exit("Error: Unable to load models configuration.")
|
||||
|
||||
21
web/models/stable_diffusion/resources/model_config.json
Normal file
21
web/models/stable_diffusion/resources/model_config.json
Normal file
@@ -0,0 +1,21 @@
|
||||
[
|
||||
{
|
||||
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
|
||||
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
|
||||
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
|
||||
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
|
||||
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
|
||||
"openjourney/v1_4":"prompthero/openjourney",
|
||||
"dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
|
||||
},
|
||||
{
|
||||
"stablediffusion/fp16":"fp16",
|
||||
"stablediffusion/fp32":"main",
|
||||
"anythingv3/fp16":"diffusers",
|
||||
"anythingv3/fp32":"diffusers",
|
||||
"analogdiffusion/fp16":"main",
|
||||
"analogdiffusion/fp32":"main",
|
||||
"openjourney/fp16":"main",
|
||||
"openjourney/fp32":"main",
"dreamlike/fp16":"main",
"dreamlike/fp32":"main"
|
||||
}
|
||||
]
|
||||
@@ -12,7 +12,6 @@
|
||||
},
|
||||
{
|
||||
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
|
||||
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_1dec_fp16_tuned",
|
||||
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
|
||||
|
||||
@@ -117,6 +117,13 @@ p.add_argument(
|
||||
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory path to save the output images and json",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
|
||||
@@ -180,7 +180,9 @@ def set_init_device_flags():
|
||||
args.device = "cpu"
|
||||
|
||||
# set max_length based on availability.
|
||||
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
|
||||
if args.version == "v1_4":
|
||||
args.max_length = 77
|
||||
elif args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
|
||||
args.max_length = 77
|
||||
elif args.variant == "openjourney":
|
||||
args.max_length = 64
|
||||
@@ -189,6 +191,7 @@ def set_init_device_flags():
|
||||
if (
|
||||
args.variant in ["openjourney", "dreamlike"]
|
||||
or args.precision != "fp16"
|
||||
or args.version == "v1_4"
|
||||
or "vulkan" not in args.device
|
||||
or "rdna3" not in args.iree_vulkan_target_triple
|
||||
):
|
||||
@@ -217,7 +220,7 @@ def get_available_devices():
|
||||
print(f"{driver_name} devices are not available.")
|
||||
else:
|
||||
for i, device in enumerate(device_list_dict):
|
||||
device_list.append(f"{driver_name}://{i} => {device['name']}")
|
||||
device_list.append(f"{device['name']} => {driver_name}://{i}")
|
||||
return device_list
|
||||
|
||||
set_iree_runtime_flags()
|
||||
@@ -227,5 +230,5 @@ def get_available_devices():
|
||||
available_devices.extend(vulkan_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
available_devices.append("cpu")
|
||||
# available_devices.append("cpu")
|
||||
return available_devices
|
||||
|
||||
@@ -26,6 +26,7 @@ datas += collect_data_files('shark')
|
||||
datas += [
|
||||
( 'models/stable_diffusion/resources/prompts.json', 'resources' ),
|
||||
( 'models/stable_diffusion/resources/model_db.json', 'resources' ),
|
||||
( 'models/stable_diffusion/resources/model_config.json', 'resources' ),
|
||||
( 'models/stable_diffusion/logos/*', 'logos' )
|
||||
]
|
||||
|
||||
|
||||
240
web/telegram_bot.py
Normal file
240
web/telegram_bot.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import logging
|
||||
import os
|
||||
from models.stable_diffusion.main import stable_diff_inf
|
||||
from models.stable_diffusion.utils import get_available_devices
|
||||
from dotenv import load_dotenv
|
||||
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
|
||||
from telegram import BotCommand
|
||||
from telegram.ext import Application, ApplicationBuilder, CallbackQueryHandler
|
||||
from telegram.ext import ContextTypes, MessageHandler, CommandHandler, filters
|
||||
from io import BytesIO
|
||||
import random
|
||||
|
||||
# --- Logging ---------------------------------------------------------------
logging.basicConfig()
log = logging.getLogger("TG.Bot")
log.warning("Start")

# --- Environment -----------------------------------------------------------
# Values from a local .env file are loaded before the token is read below.
load_dotenv()
# NOTE(review): presumably switches off AMD's LLPC shader compiler in the
# Vulkan driver -- confirm against the driver documentation.
os.environ["AMD_ENABLE_LLPC"] = "0"
TG_TOKEN = os.getenv("TG_TOKEN")

# --- Selectable options ----------------------------------------------------
# These names double as the callback data of the inline keyboard buttons.
models_list = [
    "stablediffusion",
    "anythingv3",
    "analogdiffusion",
    "openjourney",
    "dreamlike",
]
sheds_list = [
    "DDIM",
    "PNDM",
    "LMSDiscrete",
    "DPMSolverMultistep",
    "EulerDiscrete",
    "EulerAncestralDiscrete",
    "SharkEulerDiscrete",
]

# --- Mutable generation settings -------------------------------------------
# The command and callback handlers of this module rebind these globals.
SELECTED_MODEL = "stablediffusion"
SELECTED_SCHEDULER = "EulerAncestralDiscrete"
STEPS = 30
GUIDANCE_SCALE = 6
NEGATIVE_PROMPT = (
    "Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra"
    " limbs,Gross proportions,Missing arms,Mutated hands,Long"
    " neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad"
    " anatomy,Cloned face,Malformed limbs,Missing legs,Too many"
    " fingers,blurry, lowres, text, error, cropped, worst quality, low"
    " quality, jpeg artifacts, out of frame, extra fingers, mutated hands,"
    " poorly drawn hands, poorly drawn face, bad anatomy, extra limbs, cloned"
    " face, malformed limbs, missing arms, missing legs, extra arms, extra"
    " legs, fused fingers, too many fingers"
)

# Devices are enumerated once at import time; image generation uses the
# first entry of this list.
available_devices = get_available_devices()
|
||||
|
||||
|
||||
def image_to_bytes(image):
    """Encode ``image`` as JPEG into an in-memory buffer.

    Args:
        image: Object exposing ``save(fp, format)``.
            Presumably a PIL image -- confirm against ``stable_diff_inf``.

    Returns:
        A ``BytesIO`` named ``"image.jpeg"`` holding the encoded data and
        rewound to position 0 so it can be read from the start.
    """
    buffer = BytesIO()
    # The name is attached before encoding, exactly as the buffer is later
    # handed to the Telegram client as a file-like object.
    buffer.name = "image.jpeg"
    image.save(buffer, "JPEG")
    buffer.seek(0)
    return buffer
|
||||
|
||||
|
||||
def get_try_again_markup():
    """Return an inline keyboard consisting of a single "Try again" button.

    The button carries the callback data ``"TRYAGAIN"``.
    """
    retry_button = InlineKeyboardButton("Try again", callback_data="TRYAGAIN")
    return InlineKeyboardMarkup([[retry_button]])
|
||||
|
||||
|
||||
def generate_image(prompt):
    """Run Stable Diffusion inference for ``prompt`` with the current settings.

    A seed is drawn uniformly from 1..10000 and the module-level settings
    (model, scheduler, steps, guidance scale, negative prompt) are passed to
    ``stable_diff_inf`` together with the first available device.

    Args:
        prompt: Text prompt forwarded to ``stable_diff_inf``.

    Returns:
        A ``(image, seed)`` tuple: the first value returned by
        ``stable_diff_inf`` and the seed that was used.
    """
    seed = random.randint(1, 10000)
    log.warning(SELECTED_MODEL)
    log.warning(STEPS)

    inference_kwargs = {
        "prompt": prompt,
        "negative_prompt": NEGATIVE_PROMPT,
        "steps": STEPS,
        "guidance_scale": GUIDANCE_SCALE,
        "seed": seed,
        "scheduler_key": SELECTED_SCHEDULER,
        "variant": SELECTED_MODEL,
        "device_key": available_devices[0],
    }
    # The second element of the result is not used by the bot.
    image, _ = stable_diff_inf(**inference_kwargs)
    return image, seed
|
||||
|
||||
|
||||
async def generate_and_send_photo(
    update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
    """Generate an image for an incoming text message and reply with it.

    A temporary "Generating image..." message is posted while the image is
    produced; it is removed before the photo is sent. The photo is captioned
    with the prompt and the seed and carries a "Try again" button.

    Args:
        update: Incoming update whose message text is used as the prompt.
        context: Handler context; its bot is used to delete the progress
            message and to send the photo.
    """
    # effective_message is used instead of update.message because
    # MessageHandler also dispatches edited messages, for which
    # update.message is None.
    message = update.effective_message
    prompt = message.text

    progress_msg = await message.reply_text(
        "Generating image...", reply_to_message_id=message.message_id
    )
    # NOTE: generate_image is a regular (non-async) function, so this handler
    # does not yield to the event loop while the image is being generated.
    im, seed = generate_image(prompt=prompt)
    await context.bot.delete_message(
        chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
    )
    # Send into the chat the prompt came from: reply_to_message_id refers to
    # a message in that chat, which differs from the user's id in groups.
    await context.bot.send_photo(
        message.chat_id,
        image_to_bytes(im),
        caption=f'"{prompt}" (Seed: {seed})',
        reply_markup=get_try_again_markup(),
        reply_to_message_id=message.message_id,
    )
|
||||
|
||||
|
||||
async def button(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Handle every inline-keyboard callback received by the bot.

    The callback data is interpreted in this order:

    * a name from ``models_list`` becomes the new ``SELECTED_MODEL``;
    * a name from ``sheds_list`` becomes the new ``SELECTED_SCHEDULER``;
    * ``"TRYAGAIN"`` generates a new image for the text of the message that
      the clicked message replies to.

    Any other callback data, or a "Try again" click whose original prompt
    message is no longer available, is acknowledged and otherwise ignored.

    Args:
        update: Incoming update carrying the callback query.
        context: Handler context; its bot is used to delete the progress
            message and to send the photo.
    """
    global SELECTED_MODEL, SELECTED_SCHEDULER

    query = update.callback_query
    choice = query.data

    if choice in models_list:
        SELECTED_MODEL = choice
        await query.answer()
        await query.edit_message_text(text=f"Selected model: {choice}")
        return

    if choice in sheds_list:
        SELECTED_SCHEDULER = choice
        await query.answer()
        await query.edit_message_text(text=f"Selected scheduler: {choice}")
        return

    await query.answer()

    # Only "TRYAGAIN" triggers a generation. Returning here keeps unknown
    # callback data from reaching send_photo with no image to send.
    if choice != "TRYAGAIN":
        return

    button_message = query.message
    replied_message = (
        button_message.reply_to_message if button_message is not None else None
    )
    # The prompt lives in the message the photo replied to; without it there
    # is nothing to regenerate.
    if replied_message is None or not replied_message.text:
        return

    prompt = replied_message.text
    progress_msg = await button_message.reply_text(
        "Generating image...", reply_to_message_id=replied_message.message_id
    )
    # NOTE: generate_image is a regular (non-async) function, so this handler
    # does not yield to the event loop while the image is being generated.
    im, seed = generate_image(prompt)

    await context.bot.delete_message(
        chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
    )
    # Send into the chat holding the replied-to message so that
    # reply_to_message_id is valid there (it differs from the user id in
    # group chats).
    await context.bot.send_photo(
        button_message.chat_id,
        image_to_bytes(im),
        caption=f'"{prompt}" (Seed: {seed})',
        reply_markup=get_try_again_markup(),
        reply_to_message_id=replied_message.message_id,
    )
|
||||
|
||||
|
||||
async def select_model_handler(update, context):
    """Reply to ``/select_model`` with one inline button per known model.

    Each button shows a name from ``models_list`` and uses that same name as
    its callback data.

    Args:
        update: Incoming update for the command message.
        context: Handler context (unused).
    """
    rows = [
        [InlineKeyboardButton(text=model, callback_data=model)]
        for model in models_list
    ]
    # effective_message also covers edited command messages, for which
    # update.message is None.
    await update.effective_message.reply_text(
        text="Select model", reply_markup=InlineKeyboardMarkup(rows)
    )
|
||||
|
||||
|
||||
async def select_scheduler_handler(update, context):
    """Reply to ``/select_scheduler`` with one inline button per scheduler.

    Each button shows a name from ``sheds_list`` and uses that same name as
    its callback data.

    Args:
        update: Incoming update for the command message.
        context: Handler context (unused).
    """
    rows = [
        [InlineKeyboardButton(text=shed, callback_data=shed)]
        for shed in sheds_list
    ]
    # effective_message also covers edited command messages, for which
    # update.message is None.
    await update.effective_message.reply_text(
        text="Select schedule", reply_markup=InlineKeyboardMarkup(rows)
    )
|
||||
|
||||
|
||||
async def set_steps_handler(update, context):
    """Handle ``/set_steps <n>`` by updating the global ``STEPS``.

    On success the accepted argument is echoed back; if the argument is
    missing, not an integer, or smaller than 1, ``STEPS`` is left unchanged
    and a usage hint is sent instead.

    Args:
        update: Incoming update for the command message.
        context: Handler context (unused).
    """
    global STEPS

    # effective_message also covers edited command messages, for which
    # update.message is None.
    message = update.effective_message
    input_mex = message.text
    log.warning(input_mex)

    # Split on whitespace rather than on the literal command text so that
    # the "/set_steps@<botname> 30" form is accepted as well.
    pieces = input_mex.split(maxsplit=1)
    try:
        steps = int(pieces[1])
        if steps < 1:
            raise ValueError("steps must be a positive integer")
    except (IndexError, ValueError):
        reply = (
            "Invalid parameter for command. Correct command looks like\n"
            " /set_steps 30"
        )
    else:
        STEPS = steps
        reply = pieces[1]
    await message.reply_text(reply)
|
||||
|
||||
|
||||
async def set_negative_prompt_handler(update, context):
    """Handle ``/set_negative_prompt <text>`` by updating ``NEGATIVE_PROMPT``.

    On success the new negative prompt is echoed back; if no text follows
    the command, ``NEGATIVE_PROMPT`` is left unchanged and a usage hint is
    sent instead.

    Args:
        update: Incoming update for the command message.
        context: Handler context (unused).
    """
    global NEGATIVE_PROMPT

    # effective_message also covers edited command messages, for which
    # update.message is None.
    message = update.effective_message
    input_mex = message.text
    log.warning(input_mex)

    # Split on whitespace rather than on the literal command text so that
    # the "/set_negative_prompt@<botname> ..." form is accepted as well.
    pieces = input_mex.split(maxsplit=1)
    if len(pieces) == 2:
        NEGATIVE_PROMPT = pieces[1]
        reply = pieces[1]
    else:
        reply = (
            "Invalid parameter for command. Correct command looks like\n"
            " /set_negative_prompt ugly, bad art, mutated"
        )
    await message.reply_text(reply)
|
||||
|
||||
|
||||
async def set_guidance_scale_handler(update, context):
    """Handle ``/set_guidance_scale <n>`` by updating ``GUIDANCE_SCALE``.

    On success the accepted argument is echoed back; if the argument is
    missing or not an integer, ``GUIDANCE_SCALE`` is left unchanged and a
    usage hint is sent instead.

    Args:
        update: Incoming update for the command message.
        context: Handler context (unused).
    """
    global GUIDANCE_SCALE

    # effective_message also covers edited command messages, for which
    # update.message is None.
    message = update.effective_message
    input_mex = message.text
    log.warning(input_mex)

    # Split on whitespace rather than on the literal command text so that
    # the "/set_guidance_scale@<botname> 7" form is accepted as well.
    pieces = input_mex.split(maxsplit=1)
    try:
        scale = int(pieces[1])
    except (IndexError, ValueError):
        reply = (
            "Invalid parameter for command. Correct command looks like\n"
            " /set_guidance_scale 7"
        )
    else:
        GUIDANCE_SCALE = scale
        reply = pieces[1]
    await message.reply_text(reply)
|
||||
|
||||
|
||||
async def setup_bot_commands(application: Application) -> None:
    """Publish the bot's command list through ``set_my_commands``.

    This coroutine is handed to the application builder as its ``post_init``
    hook.

    Args:
        application: The application whose bot receives the command list.
    """
    menu = (
        ("select_model", "to select model"),
        ("select_scheduler", "to select scheduler"),
        ("set_steps", "to set steps"),
        ("set_guidance_scale", "to set guidance scale"),
        ("set_negative_prompt", "to set negative prompt"),
    )
    await application.bot.set_my_commands(
        [BotCommand(name, description) for name, description in menu]
    )
|
||||
|
||||
|
||||
# Build the application; setup_bot_commands runs as the post-init hook.
_builder = ApplicationBuilder().token(TG_TOKEN)
app = _builder.post_init(setup_bot_commands).build()

# Slash commands, registered in the same order as before.
for _command, _callback in (
    ("select_model", select_model_handler),
    ("select_scheduler", select_scheduler_handler),
    ("set_steps", set_steps_handler),
    ("set_guidance_scale", set_guidance_scale_handler),
    ("set_negative_prompt", set_negative_prompt_handler),
):
    app.add_handler(CommandHandler(_command, _callback))

# Any non-command text is treated as a prompt.
_prompt_filter = filters.TEXT & ~filters.COMMAND
app.add_handler(MessageHandler(_prompt_filter, generate_and_send_photo))

# Inline-keyboard clicks (model / scheduler selection and "Try again").
app.add_handler(CallbackQueryHandler(button))

log.warning("Start bot")
app.run_polling()
|
||||
Reference in New Issue
Block a user