mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
24 Commits
20230320.6
...
20230328.6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
594c6b8ea2 | ||
|
|
96b1560da5 | ||
|
|
0ef6a0e234 | ||
|
|
641d535f44 | ||
|
|
5bb7846227 | ||
|
|
8f84258fb8 | ||
|
|
7619e76bbd | ||
|
|
9267eadbfa | ||
|
|
431132b8ee | ||
|
|
fb35e13e7a | ||
|
|
17a67897d1 | ||
|
|
da449b73aa | ||
|
|
0b0526699a | ||
|
|
4fac46f7bb | ||
|
|
49925950f1 | ||
|
|
807947c0c8 | ||
|
|
593428bda4 | ||
|
|
cede9b4fec | ||
|
|
c2360303f0 | ||
|
|
420366c1b8 | ||
|
|
d31bae488c | ||
|
|
c23fcf3748 | ||
|
|
7dbbb1726a | ||
|
|
8b8cc7fd33 |
4
.github/workflows/test-models.yml
vendored
4
.github/workflows/test-models.yml
vendored
@@ -120,7 +120,7 @@ jobs:
|
||||
if: matrix.suite == 'cuda'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
|
||||
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
|
||||
@@ -158,4 +158,6 @@ jobs:
|
||||
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
python process_skipfiles.py
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd.spec
|
||||
python build_tools/stable_diffusion_testing.py --device=vulkan
|
||||
|
||||
@@ -114,12 +114,12 @@ source shark.venv/bin/activate
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\txt2img.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
|
||||
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
|
||||
```
|
||||
|
||||
#### Linux / macOS Users
|
||||
```shell
|
||||
python3.11 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
|
||||
python3.11 apps/stable_diffusion/scripts/main.py --app=txt2img --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
|
||||
```
|
||||
|
||||
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
|
||||
|
||||
@@ -2,6 +2,7 @@ import sys
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Image2ImagePipeline,
|
||||
@@ -15,8 +16,6 @@ from apps.stable_diffusion.src import (
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
schedulers = None
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
@@ -25,15 +24,15 @@ init_import_mlir = args.import_mlir
|
||||
|
||||
# For stencil, the input image can be of any size but we need to ensure that
|
||||
# it conforms with our model contraints :-
|
||||
# Both width and height should be > 384 and multiple of 8.
|
||||
# Both width and height should be in the range of [128, 768] and multiple of 8.
|
||||
# This utility function performs the transformation on the input image while
|
||||
# also maintaining the aspect ratio before sending it to the stencil pipeline.
|
||||
def resize_stencil(image: Image.Image):
|
||||
width, height = image.size
|
||||
aspect_ratio = width / height
|
||||
min_size = min(width, height)
|
||||
if min_size < 384:
|
||||
n_size = 384
|
||||
if min_size < 128:
|
||||
n_size = 128
|
||||
if width == min_size:
|
||||
width = n_size
|
||||
height = n_size / aspect_ratio
|
||||
@@ -46,6 +45,22 @@ def resize_stencil(image: Image.Image):
|
||||
n_height = height // 8
|
||||
n_width *= 8
|
||||
n_height *= 8
|
||||
|
||||
min_size = min(width, height)
|
||||
if min_size > 768:
|
||||
n_size = 768
|
||||
if width == min_size:
|
||||
height = n_size
|
||||
width = n_size * aspect_ratio
|
||||
else:
|
||||
width = n_size
|
||||
height = n_size / aspect_ratio
|
||||
width = int(width)
|
||||
height = int(height)
|
||||
n_width = width // 8
|
||||
n_height = height // 8
|
||||
n_width *= 8
|
||||
n_height *= 8
|
||||
new_image = image.resize((n_width, n_height))
|
||||
return new_image, n_width, n_height
|
||||
|
||||
@@ -77,6 +92,7 @@ def img2img_inf(
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
@@ -84,8 +100,6 @@ def img2img_inf(
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
global schedulers
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
@@ -100,10 +114,6 @@ def img2img_inf(
|
||||
image = init_image.convert("RGB")
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
@@ -118,14 +128,9 @@ def img2img_inf(
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
use_lora = ""
|
||||
if lora_weights == "None" and not lora_hf_id:
|
||||
use_lora = ""
|
||||
elif not lora_hf_id:
|
||||
use_lora = lora_weights
|
||||
else:
|
||||
use_lora = lora_hf_id
|
||||
args.use_lora = use_lora
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
@@ -159,7 +164,7 @@ def img2img_inf(
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=use_lora,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=use_stencil,
|
||||
)
|
||||
if (
|
||||
@@ -183,8 +188,8 @@ def img2img_inf(
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(args.scheduler)
|
||||
|
||||
if use_stencil is not None:
|
||||
args.use_tuned = False
|
||||
@@ -205,7 +210,7 @@ def img2img_inf(
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_stencil=use_stencil,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=use_lora,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -225,11 +230,11 @@ def img2img_inf(
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=use_lora,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_schedulers(schedulers[scheduler])
|
||||
global_obj.set_sd_scheduler(args.scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
@@ -274,7 +279,7 @@ def img2img_inf(
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
@@ -380,3 +385,7 @@ if __name__ == "__main__":
|
||||
extra_info = {"STRENGTH": args.strength}
|
||||
save_output_img(generated_imgs[0], seed, extra_info)
|
||||
print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
InpaintPipeline,
|
||||
@@ -13,8 +14,6 @@ from apps.stable_diffusion.src import (
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
schedulers = None
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
@@ -48,6 +47,7 @@ def inpaint_inf(
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
@@ -55,8 +55,6 @@ def inpaint_inf(
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
global schedulers
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
@@ -66,10 +64,6 @@ def inpaint_inf(
|
||||
args.mask_path = "not none"
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
@@ -84,14 +78,9 @@ def inpaint_inf(
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
use_lora = ""
|
||||
if lora_weights == "None" and not lora_hf_id:
|
||||
use_lora = ""
|
||||
elif not lora_hf_id:
|
||||
use_lora = lora_weights
|
||||
else:
|
||||
use_lora = lora_hf_id
|
||||
args.use_lora = use_lora
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
@@ -108,7 +97,7 @@ def inpaint_inf(
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=use_lora,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
)
|
||||
if (
|
||||
@@ -133,14 +122,15 @@ def inpaint_inf(
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
InpaintPipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
custom_vae=args.custom_vae,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
@@ -148,14 +138,13 @@ def inpaint_inf(
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=use_lora,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_schedulers(schedulers[scheduler])
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
@@ -202,7 +191,7 @@ def inpaint_inf(
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
@@ -232,6 +221,7 @@ if __name__ == "__main__":
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
custom_vae=args.custom_vae,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
@@ -239,7 +229,6 @@ if __name__ == "__main__":
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
@@ -285,3 +274,7 @@ if __name__ == "__main__":
|
||||
|
||||
save_output_img(generated_imgs[0], seed)
|
||||
print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
19
apps/stable_diffusion/scripts/main.py
Normal file
19
apps/stable_diffusion/scripts/main.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.scripts import (
|
||||
img2img,
|
||||
txt2img,
|
||||
# inpaint,
|
||||
# outpaint,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.app == "txt2img":
|
||||
txt2img.main()
|
||||
elif args.app == "img2img":
|
||||
img2img.main()
|
||||
# elif args.app == "inpaint":
|
||||
# inpaint.main()
|
||||
# elif args.app == "outpaint":
|
||||
# outpaint.main()
|
||||
else:
|
||||
print(f"args.app value is {args.app} but this isn't supported")
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
OutpaintPipeline,
|
||||
@@ -13,8 +14,6 @@ from apps.stable_diffusion.src import (
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
schedulers = None
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
@@ -51,6 +50,7 @@ def outpaint_inf(
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
@@ -58,8 +58,6 @@ def outpaint_inf(
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
global schedulers
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
@@ -68,10 +66,6 @@ def outpaint_inf(
|
||||
args.img_path = "not none"
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
@@ -86,14 +80,9 @@ def outpaint_inf(
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
use_lora = ""
|
||||
if lora_weights == "None" and not lora_hf_id:
|
||||
use_lora = ""
|
||||
elif not lora_hf_id:
|
||||
use_lora = lora_weights
|
||||
else:
|
||||
use_lora = lora_hf_id
|
||||
args.use_lora = use_lora
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
@@ -110,7 +99,7 @@ def outpaint_inf(
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=use_lora,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
)
|
||||
if (
|
||||
@@ -135,8 +124,8 @@ def outpaint_inf(
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
OutpaintPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
@@ -151,11 +140,11 @@ def outpaint_inf(
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
use_lora=use_lora,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_schedulers(schedulers[scheduler])
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
@@ -211,7 +200,7 @@ def outpaint_inf(
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
@@ -310,3 +299,7 @@ if __name__ == "__main__":
|
||||
}
|
||||
save_output_img(generated_imgs[0], seed, extra_info)
|
||||
print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -159,9 +159,6 @@ class LoraDataset(Dataset):
|
||||
return example
|
||||
|
||||
|
||||
schedulers = None
|
||||
|
||||
|
||||
########## Setting up the model ##########
|
||||
def lora_train(
|
||||
prompt: str,
|
||||
@@ -187,8 +184,6 @@ def lora_train(
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
global schedulers
|
||||
|
||||
print(
|
||||
"Note LoRA training is not compatible with the latest torch-mlir branch"
|
||||
)
|
||||
@@ -227,7 +222,12 @@ def lora_train(
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.device = device
|
||||
device_str = device.split("=>", 1)[1].strip().split("://")
|
||||
if len(device_str) > 1:
|
||||
device_str = device_str[0] + ":" + device_str[1]
|
||||
else:
|
||||
device_str = device_str[0]
|
||||
args.device = device_str
|
||||
|
||||
# Load the Stable Diffusion model
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import transformers
|
||||
import time
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
@@ -11,7 +12,6 @@ from apps.stable_diffusion.src import (
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
schedulers = None
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
@@ -43,6 +43,7 @@ def txt2img_inf(
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
@@ -50,8 +51,6 @@ def txt2img_inf(
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
global schedulers
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
@@ -59,10 +58,6 @@ def txt2img_inf(
|
||||
args.scheduler = scheduler
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
@@ -80,14 +75,9 @@ def txt2img_inf(
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
use_lora = ""
|
||||
if lora_weights == "None" and not lora_hf_id:
|
||||
use_lora = ""
|
||||
elif not lora_hf_id:
|
||||
use_lora = lora_weights
|
||||
else:
|
||||
use_lora = lora_hf_id
|
||||
args.use_lora = use_lora
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
@@ -101,7 +91,7 @@ def txt2img_inf(
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=use_lora,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
)
|
||||
if (
|
||||
@@ -127,8 +117,8 @@ def txt2img_inf(
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
Text2ImagePipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
@@ -145,11 +135,11 @@ def txt2img_inf(
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=use_lora,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_schedulers(schedulers[scheduler])
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
@@ -190,7 +180,7 @@ def txt2img_inf(
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
@@ -200,7 +190,6 @@ if __name__ == "__main__":
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
use_lora = args.use_lora
|
||||
txt2img_obj = Text2ImagePipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
@@ -216,7 +205,8 @@ if __name__ == "__main__":
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=use_lora,
|
||||
use_lora=args.use_lora,
|
||||
use_quantize=args.use_quantize,
|
||||
)
|
||||
|
||||
for current_batch in range(args.batch_count):
|
||||
@@ -256,3 +246,7 @@ if __name__ == "__main__":
|
||||
|
||||
save_output_img(generated_imgs[0], seed)
|
||||
print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
UpscalerPipeline,
|
||||
@@ -12,8 +13,6 @@ from apps.stable_diffusion.src import (
|
||||
)
|
||||
|
||||
|
||||
schedulers = None
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
@@ -41,15 +40,16 @@ def upscaler_inf(
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
global schedulers
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
@@ -62,10 +62,6 @@ def upscaler_inf(
|
||||
image = init_image.convert("RGB").resize((height, width))
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
@@ -83,6 +79,10 @@ def upscaler_inf(
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
args.height = 128
|
||||
@@ -97,7 +97,7 @@ def upscaler_inf(
|
||||
args.height,
|
||||
args.width,
|
||||
device,
|
||||
use_lora=None,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
)
|
||||
if (
|
||||
@@ -118,8 +118,8 @@ def upscaler_inf(
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
UpscalerPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
@@ -135,11 +135,14 @@ def upscaler_inf(
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_schedulers(schedulers[scheduler])
|
||||
global_obj.get_sd_obj().low_res_scheduler = schedulers["DDPM"]
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
global_obj.get_sd_obj().low_res_scheduler = global_obj.get_scheduler(
|
||||
"DDPM"
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
@@ -232,6 +235,7 @@ if __name__ == "__main__":
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_lora=args.use_lora,
|
||||
ddpm_scheduler=schedulers["DDPM"],
|
||||
)
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['scripts/txt2img.py'],
|
||||
['scripts/main.py'],
|
||||
pathex=['.'],
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
|
||||
@@ -18,6 +18,7 @@ from apps.stable_diffusion.src.utils import (
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
get_stencil_model_id,
|
||||
update_lora_weight,
|
||||
)
|
||||
|
||||
|
||||
@@ -99,7 +100,8 @@ class SharkifyStableDiffusionModel:
|
||||
is_inpaint: bool = False,
|
||||
is_upscaler: bool = False,
|
||||
use_stencil: str = None,
|
||||
use_lora: str = ""
|
||||
use_lora: str = "",
|
||||
use_quantize: str = None,
|
||||
):
|
||||
self.check_params(max_len, width, height)
|
||||
self.max_len = max_len
|
||||
@@ -107,6 +109,7 @@ class SharkifyStableDiffusionModel:
|
||||
self.width = width // 8
|
||||
self.batch_size = batch_size
|
||||
self.custom_weights = custom_weights
|
||||
self.use_quantize = use_quantize
|
||||
if custom_weights != "":
|
||||
assert custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
@@ -200,6 +203,7 @@ class SharkifyStableDiffusionModel:
|
||||
use_tuned=self.use_tuned,
|
||||
model_name=self.model_name["vae_encode"],
|
||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_vae_encode
|
||||
|
||||
@@ -255,6 +259,7 @@ class SharkifyStableDiffusionModel:
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
@@ -281,13 +286,14 @@ class SharkifyStableDiffusionModel:
|
||||
use_tuned=self.use_tuned,
|
||||
model_name=self.model_name["vae"],
|
||||
extra_args=get_opt_flags("vae", precision="fp32"),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
def get_controlled_unet(self):
|
||||
class ControlledUnetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, model_id=self.model_id, low_cpu_mem_usage=False
|
||||
self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
|
||||
):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
@@ -295,6 +301,8 @@ class SharkifyStableDiffusionModel:
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(self.unet, use_lora, "unet")
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
@@ -333,6 +341,7 @@ class SharkifyStableDiffusionModel:
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_controlled_unet
|
||||
|
||||
@@ -386,6 +395,7 @@ class SharkifyStableDiffusionModel:
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_cnet
|
||||
|
||||
@@ -399,7 +409,7 @@ class SharkifyStableDiffusionModel:
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if use_lora != "":
|
||||
self.unet.load_attn_procs(use_lora)
|
||||
update_lora_weight(self.unet, use_lora, "unet")
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
if(args.attention_slicing is not None and args.attention_slicing != "none"):
|
||||
@@ -444,6 +454,7 @@ class SharkifyStableDiffusionModel:
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
@@ -481,18 +492,21 @@ class SharkifyStableDiffusionModel:
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
def get_clip(self):
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
|
||||
super().__init__()
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="text_encoder",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(self.text_encoder, use_lora, "text_encoder")
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
@@ -512,6 +526,7 @@ class SharkifyStableDiffusionModel:
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("clip", precision="fp32"),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
@@ -539,6 +554,7 @@ class SharkifyStableDiffusionModel:
|
||||
# Compiles Clip, Unet and Vae with `base_model_id` as defining their input
|
||||
# configiration.
|
||||
def compile_all(self, base_model_id, need_vae_encode, need_stencil):
|
||||
self.base_model_id = base_model_id
|
||||
self.inputs = get_input_info(
|
||||
base_models[base_model_id],
|
||||
self.max_len,
|
||||
@@ -556,7 +572,12 @@ class SharkifyStableDiffusionModel:
|
||||
compiled_controlnet = self.get_control_net()
|
||||
compiled_controlled_unet = self.get_controlled_unet()
|
||||
else:
|
||||
compiled_unet = self.get_unet()
|
||||
# TODO: Plug the experimental "int8" support at right place.
|
||||
if self.use_quantize == "int8":
|
||||
from apps.stable_diffusion.src.models.opt_params import get_unet
|
||||
compiled_unet = get_unet()
|
||||
else:
|
||||
compiled_unet = self.get_unet()
|
||||
if self.custom_vae != "":
|
||||
print("Plugging in custom Vae")
|
||||
compiled_vae = self.get_vae()
|
||||
|
||||
@@ -20,6 +20,15 @@ hf_model_variant_map = {
|
||||
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
|
||||
}
|
||||
|
||||
# TODO: Add the quantized model as a part model_db.json.
|
||||
# This is currently in experimental phase.
|
||||
def get_quantize_model():
|
||||
bucket_key = "gs://shark_tank/prashant_nod"
|
||||
model_key = "unet_int8"
|
||||
iree_flags = get_opt_flags("unet", precision="fp16")
|
||||
if args.height != 512 and args.width != 512 and args.max_length != 77:
|
||||
sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
|
||||
return bucket_key, model_key, iree_flags
|
||||
|
||||
def get_variant_version(hf_model_id):
|
||||
return hf_model_variant_map[hf_model_id]
|
||||
@@ -41,6 +50,12 @@ def get_unet():
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
|
||||
# TODO: Get the quantize model from model_db.json
|
||||
if args.use_quantize == "int8":
|
||||
bk, mk, flags = get_quantize_model()
|
||||
return get_shark_model(bk, mk, flags)
|
||||
|
||||
if "vulkan" not in args.device and args.use_tuned:
|
||||
bucket_key = f"{variant}/{is_tuned}/{args.device}"
|
||||
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
|
||||
|
||||
@@ -321,6 +321,7 @@ class StableDiffusionPipeline:
|
||||
use_stencil: str = None,
|
||||
use_lora: str = "",
|
||||
ddpm_scheduler: DDPMScheduler = None,
|
||||
use_quantize=None,
|
||||
):
|
||||
is_inpaint = cls.__name__ in [
|
||||
"InpaintPipeline",
|
||||
@@ -349,6 +350,7 @@ class StableDiffusionPipeline:
|
||||
is_upscaler=is_upscaler,
|
||||
use_stencil=use_stencil,
|
||||
use_lora=use_lora,
|
||||
use_quantize=use_quantize,
|
||||
)
|
||||
if cls.__name__ in [
|
||||
"Image2ImagePipeline",
|
||||
|
||||
@@ -33,4 +33,5 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
get_generation_text_info,
|
||||
update_lora_weight,
|
||||
)
|
||||
|
||||
@@ -76,18 +76,19 @@ def load_winograd_configs():
|
||||
return winograd_config_dir
|
||||
|
||||
|
||||
def load_lower_configs():
|
||||
def load_lower_configs(base_model_id=None):
|
||||
from apps.stable_diffusion.src.models import get_variant_version
|
||||
from apps.stable_diffusion.src.utils.utils import (
|
||||
fetch_and_update_base_model_id,
|
||||
)
|
||||
|
||||
if args.ckpt_loc != "":
|
||||
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
|
||||
else:
|
||||
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
|
||||
if base_model_id == "":
|
||||
base_model_id = args.hf_model_id
|
||||
if not base_model_id:
|
||||
if args.ckpt_loc != "":
|
||||
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
|
||||
else:
|
||||
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
|
||||
if base_model_id == "":
|
||||
base_model_id = args.hf_model_id
|
||||
|
||||
variant, version = get_variant_version(base_model_id)
|
||||
|
||||
@@ -114,7 +115,14 @@ def load_lower_configs():
|
||||
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
|
||||
else:
|
||||
if not spec or spec in ["rdna3", "sm_80"]:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
|
||||
if (
|
||||
version in ["v2_1", "v2_1base"]
|
||||
and args.height == 768
|
||||
and args.width == 768
|
||||
):
|
||||
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
|
||||
|
||||
@@ -212,7 +220,7 @@ def annotate_with_lower_configs(
|
||||
return bytecode
|
||||
|
||||
|
||||
def sd_model_annotation(mlir_model, model_name):
|
||||
def sd_model_annotation(mlir_model, model_name, base_model_id=None):
|
||||
device = get_device()
|
||||
if args.annotation_model == "unet" and device == "vulkan":
|
||||
use_winograd = True
|
||||
@@ -220,7 +228,7 @@ def sd_model_annotation(mlir_model, model_name):
|
||||
winograd_model = annotate_with_winograd(
|
||||
mlir_model, winograd_config_dir, model_name
|
||||
)
|
||||
lowering_config_dir = load_lower_configs()
|
||||
lowering_config_dir = load_lower_configs(base_model_id)
|
||||
tuned_model = annotate_with_lower_configs(
|
||||
winograd_model, lowering_config_dir, model_name, use_winograd
|
||||
)
|
||||
@@ -232,7 +240,7 @@ def sd_model_annotation(mlir_model, model_name):
|
||||
)
|
||||
else:
|
||||
use_winograd = False
|
||||
lowering_config_dir = load_lower_configs()
|
||||
lowering_config_dir = load_lower_configs(base_model_id)
|
||||
tuned_model = annotate_with_lower_configs(
|
||||
mlir_model, lowering_config_dir, model_name, use_winograd
|
||||
)
|
||||
|
||||
@@ -22,6 +22,12 @@ p = argparse.ArgumentParser(
|
||||
### Stable Diffusion Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"-a",
|
||||
"--app",
|
||||
default="txt2img",
|
||||
help="which app to use, one of: txt2img, img2img, outpaint, inpaint",
|
||||
)
|
||||
p.add_argument(
|
||||
"-p",
|
||||
"--prompts",
|
||||
@@ -340,6 +346,14 @@ p.add_argument(
|
||||
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_quantize",
|
||||
type=str,
|
||||
default="none",
|
||||
help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
|
||||
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
|
||||
@@ -9,6 +9,8 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
from random import randint
|
||||
import tempfile
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
@@ -21,7 +23,7 @@ from apps.stable_diffusion.src.utils.resources import opt_flags
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
import sys
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
load_pipeline_from_original_stable_diffusion_ckpt,
|
||||
download_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
|
||||
|
||||
@@ -78,7 +80,7 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
frontend="torch",
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, device=args.device, mlir_dialect="linalg"
|
||||
mlir_model, device=args.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
@@ -95,6 +97,7 @@ def compile_through_fx(
|
||||
debug=False,
|
||||
generate_vmfb=True,
|
||||
extra_args=[],
|
||||
base_model_id=None,
|
||||
):
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -116,19 +119,21 @@ def compile_through_fx(
|
||||
if use_tuned:
|
||||
if "vae" in model_name.split("_")[0]:
|
||||
args.annotation_model = "vae"
|
||||
mlir_module = sd_model_annotation(mlir_module, model_name)
|
||||
mlir_module = sd_model_annotation(
|
||||
mlir_module, model_name, base_model_id
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
|
||||
if generate_vmfb:
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
del mlir_module
|
||||
gc.collect()
|
||||
@@ -264,8 +269,9 @@ def set_init_device_flags():
|
||||
|
||||
if (
|
||||
args.precision != "fp16"
|
||||
or args.height != 512
|
||||
or args.width != 512
|
||||
or args.height not in [512, 768]
|
||||
or (args.height == 512 and args.width != 512)
|
||||
or (args.height == 768 and args.width != 768)
|
||||
or args.batch_size != 1
|
||||
or ("vulkan" not in args.device and "cuda" not in args.device)
|
||||
):
|
||||
@@ -299,6 +305,20 @@ def set_init_device_flags():
|
||||
]:
|
||||
args.use_tuned = False
|
||||
|
||||
elif (
|
||||
args.height == 768
|
||||
and args.width == 768
|
||||
and (
|
||||
base_model_id
|
||||
not in [
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
]
|
||||
or "rdna3" not in args.iree_vulkan_target_triple
|
||||
)
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
if args.use_tuned:
|
||||
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
|
||||
else:
|
||||
@@ -368,7 +388,7 @@ def get_available_devices():
|
||||
available_devices.extend(vulkan_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
available_devices.append("cpu")
|
||||
available_devices.append("device => cpu")
|
||||
return available_devices
|
||||
|
||||
|
||||
@@ -454,7 +474,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
"Loading diffusers' pipeline from original stable diffusion checkpoint"
|
||||
)
|
||||
num_in_channels = 9 if is_inpaint else 4
|
||||
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path=custom_weights,
|
||||
extract_ema=extract_ema,
|
||||
from_safetensors=from_safetensors,
|
||||
@@ -464,6 +484,115 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
print("Loading complete")
|
||||
|
||||
|
||||
def processLoRA(model, use_lora, splitting_prefix):
|
||||
state_dict = ""
|
||||
if ".safetensors" in use_lora:
|
||||
state_dict = load_file(use_lora)
|
||||
else:
|
||||
state_dict = torch.load(use_lora)
|
||||
alpha = 0.75
|
||||
visited = []
|
||||
|
||||
# directly update weight in model
|
||||
process_unet = "te" not in splitting_prefix
|
||||
for key in state_dict:
|
||||
if ".alpha" in key or key in visited:
|
||||
continue
|
||||
|
||||
curr_layer = model
|
||||
if ("text" not in key and process_unet) or (
|
||||
"text" in key and not process_unet
|
||||
):
|
||||
layer_infos = (
|
||||
key.split(".")[0].split(splitting_prefix)[-1].split("_")
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
# find the target layer
|
||||
temp_name = layer_infos.pop(0)
|
||||
while len(layer_infos) > -1:
|
||||
try:
|
||||
curr_layer = curr_layer.__getattr__(temp_name)
|
||||
if len(layer_infos) > 0:
|
||||
temp_name = layer_infos.pop(0)
|
||||
elif len(layer_infos) == 0:
|
||||
break
|
||||
except Exception:
|
||||
if len(temp_name) > 0:
|
||||
temp_name += "_" + layer_infos.pop(0)
|
||||
else:
|
||||
temp_name = layer_infos.pop(0)
|
||||
|
||||
pair_keys = []
|
||||
if "lora_down" in key:
|
||||
pair_keys.append(key.replace("lora_down", "lora_up"))
|
||||
pair_keys.append(key)
|
||||
else:
|
||||
pair_keys.append(key)
|
||||
pair_keys.append(key.replace("lora_up", "lora_down"))
|
||||
|
||||
# update weight
|
||||
if len(state_dict[pair_keys[0]].shape) == 4:
|
||||
weight_up = (
|
||||
state_dict[pair_keys[0]]
|
||||
.squeeze(3)
|
||||
.squeeze(2)
|
||||
.to(torch.float32)
|
||||
)
|
||||
weight_down = (
|
||||
state_dict[pair_keys[1]]
|
||||
.squeeze(3)
|
||||
.squeeze(2)
|
||||
.to(torch.float32)
|
||||
)
|
||||
curr_layer.weight.data += alpha * torch.mm(
|
||||
weight_up, weight_down
|
||||
).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_up = state_dict[pair_keys[0]].to(torch.float32)
|
||||
weight_down = state_dict[pair_keys[1]].to(torch.float32)
|
||||
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
|
||||
# update visited list
|
||||
for item in pair_keys:
|
||||
visited.append(item)
|
||||
return model
|
||||
|
||||
|
||||
def update_lora_weight_for_unet(unet, use_lora):
|
||||
extensions = [".bin", ".safetensors", ".pt"]
|
||||
if not any([extension in use_lora for extension in extensions]):
|
||||
# We assume if it is a HF ID with standalone LoRA weights.
|
||||
unet.load_attn_procs(use_lora)
|
||||
return unet
|
||||
|
||||
main_file_name = get_path_stem(use_lora)
|
||||
if ".bin" in use_lora:
|
||||
main_file_name += ".bin"
|
||||
elif ".safetensors" in use_lora:
|
||||
main_file_name += ".safetensors"
|
||||
elif ".pt" in use_lora:
|
||||
main_file_name += ".pt"
|
||||
else:
|
||||
sys.exit("Only .bin and .safetensors format for LoRA is supported")
|
||||
|
||||
try:
|
||||
dir_name = os.path.dirname(use_lora)
|
||||
unet.load_attn_procs(dir_name, weight_name=main_file_name)
|
||||
return unet
|
||||
except:
|
||||
return processLoRA(unet, use_lora, "lora_unet_")
|
||||
|
||||
|
||||
def update_lora_weight(model, use_lora, model_name):
|
||||
if "unet" in model_name:
|
||||
return update_lora_weight_for_unet(model, use_lora)
|
||||
try:
|
||||
return processLoRA(model, use_lora, "lora_te_")
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def load_vmfb(vmfb_path, model, precision):
|
||||
model = "vae" if "base_vae" in model or "vae_encode" in model else model
|
||||
model = "unet" if "stencil" in model else model
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
import transformers
|
||||
|
||||
if sys.platform == "darwin":
|
||||
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
|
||||
@@ -66,7 +67,7 @@ from apps.stable_diffusion.web.ui import (
|
||||
)
|
||||
|
||||
# init global sd pipeline and config
|
||||
global_obj.init()
|
||||
global_obj._init()
|
||||
|
||||
|
||||
def register_button_click(button, selectedid, inputs, outputs):
|
||||
@@ -94,6 +95,8 @@ with gr.Blocks(
|
||||
outpaint_web.render()
|
||||
with gr.TabItem(label="Upscaler", id=4):
|
||||
upscaler_web.render()
|
||||
|
||||
with gr.Tabs(visible=False) as experimental_tabs:
|
||||
with gr.TabItem(label="LoRA Training", id=5):
|
||||
lora_train_web.render()
|
||||
|
||||
|
||||
@@ -178,6 +178,7 @@ footer {
|
||||
|
||||
/* Hide "remove buttons" from ui dropdowns */
|
||||
#custom_model .token-remove.remove-all,
|
||||
#lora_weights .token-remove.remove-all,
|
||||
#scheduler .token-remove.remove-all,
|
||||
#device .token-remove.remove-all,
|
||||
#stencil_model .token-remove.remove-all {
|
||||
|
||||
@@ -77,10 +77,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files(),
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
|
||||
@@ -72,10 +72,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files(),
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
|
||||
@@ -69,10 +69,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files(),
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
|
||||
@@ -73,10 +73,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files(),
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
|
||||
@@ -65,6 +65,21 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
label="Input Image", type="pil"
|
||||
).style(height=300)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
@@ -226,6 +241,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
],
|
||||
outputs=[upscaler_gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
|
||||
@@ -74,25 +74,50 @@ def resource_path(relative_path):
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
def get_custom_model_path():
|
||||
return Path(args.ckpt_dir) if args.ckpt_dir else Path(Path.cwd(), "models")
|
||||
def get_custom_model_path(model="models"):
|
||||
prefix = Path(args.ckpt_dir) if args.ckpt_dir else Path(Path.cwd())
|
||||
match model:
|
||||
case "models":
|
||||
return Path(prefix, "models")
|
||||
case "vae":
|
||||
return Path(prefix, "models/vae")
|
||||
case "lora":
|
||||
return Path(prefix, "models/lora")
|
||||
case _:
|
||||
return ""
|
||||
|
||||
|
||||
def get_custom_model_pathfile(custom_model_name):
|
||||
return os.path.join(get_custom_model_path(), custom_model_name)
|
||||
def get_custom_model_pathfile(custom_model_name, model="models"):
|
||||
return os.path.join(get_custom_model_path(model), custom_model_name)
|
||||
|
||||
|
||||
def get_custom_model_files():
|
||||
def get_custom_model_files(model="models"):
|
||||
ckpt_files = []
|
||||
for extn in custom_model_filetypes:
|
||||
file_types = custom_model_filetypes
|
||||
if model == "lora":
|
||||
file_types = custom_model_filetypes + ("*.pt", "*.bin")
|
||||
for extn in file_types:
|
||||
files = [
|
||||
os.path.basename(x)
|
||||
for x in glob.glob(os.path.join(get_custom_model_path(), extn))
|
||||
for x in glob.glob(
|
||||
os.path.join(get_custom_model_path(model), extn)
|
||||
)
|
||||
]
|
||||
ckpt_files.extend(files)
|
||||
return sorted(ckpt_files, key=str.casefold)
|
||||
|
||||
|
||||
def get_custom_vae_or_lora_weights(weights, hf_id, model):
|
||||
use_weight = ""
|
||||
if weights == "None" and not hf_id:
|
||||
use_weight = ""
|
||||
elif not hf_id:
|
||||
use_weight = get_custom_model_pathfile(weights, model)
|
||||
else:
|
||||
use_weight = hf_id
|
||||
return use_weight
|
||||
|
||||
|
||||
def cancel_sd():
|
||||
# Try catch it, as gc can delete global_obj.sd_obj while switching model
|
||||
try:
|
||||
|
||||
@@ -8,49 +8,64 @@ Also we could avoid memory leak when switching models by clearing the cache.
|
||||
"""
|
||||
|
||||
|
||||
def init():
|
||||
global sd_obj
|
||||
global config_obj
|
||||
sd_obj = None
|
||||
config_obj = None
|
||||
def _init():
|
||||
global _sd_obj
|
||||
global _config_obj
|
||||
global _schedulers
|
||||
_sd_obj = None
|
||||
_config_obj = None
|
||||
_schedulers = None
|
||||
|
||||
|
||||
def set_sd_obj(value):
|
||||
global sd_obj
|
||||
sd_obj = value
|
||||
global _sd_obj
|
||||
_sd_obj = value
|
||||
|
||||
|
||||
def set_cfg_obj(value):
|
||||
global config_obj
|
||||
config_obj = value
|
||||
|
||||
|
||||
def set_schedulers(value):
|
||||
global sd_obj
|
||||
sd_obj.scheduler = value
|
||||
|
||||
|
||||
def get_sd_obj():
|
||||
return sd_obj
|
||||
|
||||
|
||||
def get_cfg_obj():
|
||||
return config_obj
|
||||
def set_sd_scheduler(key):
|
||||
global _sd_obj
|
||||
_sd_obj.scheduler = _schedulers[key]
|
||||
|
||||
|
||||
def set_sd_status(value):
|
||||
global sd_obj
|
||||
sd_obj.status = value
|
||||
global _sd_obj
|
||||
_sd_obj.status = value
|
||||
|
||||
|
||||
def set_cfg_obj(value):
|
||||
global _config_obj
|
||||
_config_obj = value
|
||||
|
||||
|
||||
def set_schedulers(value):
|
||||
global _schedulers
|
||||
_schedulers = value
|
||||
|
||||
|
||||
def get_sd_obj():
|
||||
return _sd_obj
|
||||
|
||||
|
||||
def get_sd_status():
|
||||
global sd_obj
|
||||
return sd_obj.status
|
||||
return _sd_obj.status
|
||||
|
||||
|
||||
def get_cfg_obj():
|
||||
return _config_obj
|
||||
|
||||
|
||||
def get_scheduler(key):
|
||||
return _schedulers[key]
|
||||
|
||||
|
||||
def clear_cache():
|
||||
global sd_obj
|
||||
global config_obj
|
||||
del sd_obj
|
||||
del config_obj
|
||||
global _sd_obj
|
||||
global _config_obj
|
||||
global _schedulers
|
||||
del _sd_obj
|
||||
del _config_obj
|
||||
del _schedulers
|
||||
gc.collect()
|
||||
_sd_obj = None
|
||||
_config_obj = None
|
||||
_schedulers = None
|
||||
|
||||
@@ -87,11 +87,22 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"dreamlike-art/dreamlike-diffusion-1.0",
|
||||
]
|
||||
counter = 0
|
||||
for import_opt in import_options:
|
||||
for model_name in hf_model_names:
|
||||
if model_name in to_skip:
|
||||
continue
|
||||
for use_tune in tuned_options:
|
||||
if (
|
||||
model_name == "stabilityai/stable-diffusion-2-1"
|
||||
and use_tune == tuned_options[0]
|
||||
):
|
||||
continue
|
||||
elif (
|
||||
model_name == "stabilityai/stable-diffusion-2-1-base"
|
||||
and use_tune == tuned_options[1]
|
||||
):
|
||||
continue
|
||||
command = (
|
||||
[
|
||||
executable, # executable is the python from the venv used to run this
|
||||
@@ -174,9 +185,23 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
else:
|
||||
print(command)
|
||||
print("failed to generate image for this configuration")
|
||||
if "2_1_base" in model_name:
|
||||
with open(dumpfile_name, "r+") as f:
|
||||
output = f.readlines()
|
||||
print("\n".join(output))
|
||||
if model_name == "CompVis/stable-diffusion-v1-4":
|
||||
print("failed a known successful model.")
|
||||
exit(1)
|
||||
if os.name == "nt":
|
||||
counter += 1
|
||||
if counter % 2 == 0:
|
||||
extra_flags.append(
|
||||
"--iree_vulkan_target_triple=rdna2-unknown-windows"
|
||||
)
|
||||
else:
|
||||
if counter != 1:
|
||||
extra_flags.remove(
|
||||
"--iree_vulkan_target_triple=rdna2-unknown-windows"
|
||||
)
|
||||
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
|
||||
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
|
||||
f.write(header)
|
||||
|
||||
@@ -70,3 +70,9 @@ def pytest_addoption(parser):
|
||||
default="./temp_dispatch_benchmarks",
|
||||
help="Directory in which dispatch benchmarks are saved.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--batchsize",
|
||||
default=1,
|
||||
type=int,
|
||||
help="Batch size for the tested model.",
|
||||
)
|
||||
|
||||
@@ -6,36 +6,16 @@ from distutils.sysconfig import get_python_lib
|
||||
import fileinput
|
||||
from pathlib import Path
|
||||
|
||||
# Diffusers 0.13.1 fails with transformers __init.py errros in BLIP. So remove it for now until we fork it
|
||||
pix2pix_init = Path(get_python_lib() + "/diffusers/__init__.py")
|
||||
for line in fileinput.input(pix2pix_init, inplace=True):
|
||||
if "Pix2Pix" in line:
|
||||
if not line.startswith("#"):
|
||||
print(f"#{line}", end="")
|
||||
else:
|
||||
print(f"{line[1:]}", end="")
|
||||
else:
|
||||
print(line, end="")
|
||||
pix2pix_init = Path(get_python_lib() + "/diffusers/pipelines/__init__.py")
|
||||
for line in fileinput.input(pix2pix_init, inplace=True):
|
||||
if "Pix2Pix" in line:
|
||||
if not line.startswith("#"):
|
||||
print(f"#{line}", end="")
|
||||
else:
|
||||
print(f"{line[1:]}", end="")
|
||||
else:
|
||||
print(line, end="")
|
||||
pix2pix_init = Path(
|
||||
get_python_lib() + "/diffusers/pipelines/stable_diffusion/__init__.py"
|
||||
# Temorary workaround for transformers/__init__.py.
|
||||
path_to_tranformers_hook = Path(
|
||||
get_python_lib()
|
||||
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
|
||||
)
|
||||
for line in fileinput.input(pix2pix_init, inplace=True):
|
||||
if "StableDiffusionPix2PixZeroPipeline" in line:
|
||||
if not line.startswith("#"):
|
||||
print(f"#{line}", end="")
|
||||
else:
|
||||
print(f"{line[1:]}", end="")
|
||||
else:
|
||||
print(line, end="")
|
||||
if path_to_tranformers_hook.is_file():
|
||||
pass
|
||||
else:
|
||||
with open(path_to_tranformers_hook, "w") as f:
|
||||
f.write("module_collection_mode = 'pyz+py'")
|
||||
|
||||
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
--pre
|
||||
|
||||
numpy>1.22.4
|
||||
torchvision
|
||||
pytorch-triton
|
||||
torchvision==0.16.0.dev20230322
|
||||
tabulate
|
||||
|
||||
tqdm
|
||||
@@ -15,8 +15,8 @@ iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tf-nightly
|
||||
keras>=2.10
|
||||
tensorflow>2.11
|
||||
keras
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
|
||||
@@ -16,7 +16,7 @@ parameterized
|
||||
|
||||
# Add transformers, diffusers and scipy since it most commonly used
|
||||
transformers
|
||||
diffusers
|
||||
diffusers @ git+https://github.com/huggingface/diffusers@main
|
||||
scipy
|
||||
ftfy
|
||||
gradio
|
||||
|
||||
@@ -45,7 +45,7 @@ if ($arguments -eq "--force"){
|
||||
Remove-Item .\shark.venv -Force -Recurse
|
||||
if (Test-Path .\shark.venv\) {
|
||||
Write-Host 'could not remove .\shark-venv - please try running ".\setup_venv.ps1 --force" again!'
|
||||
break
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -78,12 +78,12 @@ if (!($PyVer.length -ne 0)) {$p} # return Python --version String if py.exe is u
|
||||
if (!($PyVer -like "*3.11*") -and !($p -like "*3.11*")) # if 3.11 is not in any list
|
||||
{
|
||||
Write-Host "Please install Python 3.11 and try again"
|
||||
break
|
||||
exit 34
|
||||
}
|
||||
|
||||
Write-Host "Installing Build Dependencies"
|
||||
# make sure we really use 3.11 from list, even if it's not the default.
|
||||
if (!($PyVer.length -ne 0)) {py -3.11 -m venv .\shark.venv\}
|
||||
if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark.venv\}
|
||||
else {python -m venv .\shark.venv\}
|
||||
.\shark.venv\Scripts\activate
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
@@ -129,11 +129,11 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
TV_VERSION=${TV_VER:9:18}
|
||||
$PYTHON -m pip uninstall -y torch torchvision
|
||||
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
|
||||
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl
|
||||
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu118/torch-${TORCH_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu118/torchvision-${TV_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch + cu117."
|
||||
echo "Successfully Installed torch + cu118."
|
||||
else
|
||||
echo "Could not install torch + cu117." >&2
|
||||
echo "Could not install torch + cu118." >&2
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
|
||||
# Get the iree-compiler arguments given frontend.
|
||||
def get_iree_frontend_args(frontend):
|
||||
if frontend in ["torch", "pytorch", "linalg"]:
|
||||
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
|
||||
return ["--iree-llvmcpu-target-cpu-features=host"]
|
||||
elif frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
return [
|
||||
|
||||
@@ -30,11 +30,10 @@ def get_iree_gpu_args():
|
||||
in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86", "sm_89"]
|
||||
) and (shark_args.enable_tf32 == True):
|
||||
return [
|
||||
"--iree-hal-cuda-disable-loop-nounroll-wa",
|
||||
f"--iree-hal-cuda-llvm-target-arch={sm_arch}",
|
||||
]
|
||||
else:
|
||||
return ["--iree-hal-cuda-disable-loop-nounroll-wa"]
|
||||
return []
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
|
||||
@@ -78,6 +78,7 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
self.vmfb_file = None
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.extra_args = extra_args
|
||||
self.import_args = {}
|
||||
SharkRunner.__init__(
|
||||
self,
|
||||
mlir_module,
|
||||
@@ -112,24 +113,31 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
|
||||
def benchmark_torch(self, modelname):
|
||||
import torch
|
||||
import torch._dynamo as dynamo
|
||||
from tank.model_utils import get_torch_model
|
||||
|
||||
if self.device == "cuda":
|
||||
torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
||||
if self.enable_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
print(
|
||||
"Currently disabled TensorFloat32 calculations in pytorch benchmarks."
|
||||
)
|
||||
# torch.backends.cuda.matmul.allow_tf32 = True
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
torch_device = torch.device(
|
||||
"cuda:0" if self.device == "cuda" else "cpu"
|
||||
)
|
||||
HFmodel, input = get_torch_model(modelname)[:2]
|
||||
HFmodel, input = get_torch_model(modelname, self.import_args)[:2]
|
||||
frontend_model = HFmodel.model
|
||||
frontend_model.to(torch_device)
|
||||
input.to(torch_device)
|
||||
|
||||
# frontend_model = torch.compile(frontend_model, mode="max-autotune", backend="inductor")
|
||||
try:
|
||||
frontend_model = torch.compile(
|
||||
frontend_model, mode="max-autotune", backend="inductor"
|
||||
)
|
||||
except RuntimeError:
|
||||
frontend_model = HFmodel.model
|
||||
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
frontend_model.forward(input)
|
||||
@@ -178,7 +186,7 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
model,
|
||||
input,
|
||||
) = get_tf_model(
|
||||
modelname
|
||||
modelname, self.import_args
|
||||
)[:2]
|
||||
frontend_model = model
|
||||
|
||||
@@ -338,11 +346,19 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
return comp_str
|
||||
|
||||
def benchmark_all_csv(
|
||||
self, inputs: tuple, modelname, dynamic, device_str, frontend
|
||||
self,
|
||||
inputs: tuple,
|
||||
modelname,
|
||||
dynamic,
|
||||
device_str,
|
||||
frontend,
|
||||
import_args,
|
||||
):
|
||||
self.setup_cl(inputs)
|
||||
self.import_args = import_args
|
||||
field_names = [
|
||||
"model",
|
||||
"batch_size",
|
||||
"engine",
|
||||
"dialect",
|
||||
"device",
|
||||
@@ -375,6 +391,7 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
writer = csv.DictWriter(f, fieldnames=field_names)
|
||||
bench_info = {}
|
||||
bench_info["model"] = modelname
|
||||
bench_info["batch_size"] = str(import_args["batch_size"])
|
||||
bench_info["dialect"] = self.mlir_dialect
|
||||
bench_info["iterations"] = shark_args.num_iterations
|
||||
if dynamic == True:
|
||||
|
||||
@@ -139,11 +139,21 @@ def download_model(
|
||||
tank_url="gs://shark_tank/latest",
|
||||
frontend=None,
|
||||
tuned=None,
|
||||
import_args={"batch_size": "1"},
|
||||
):
|
||||
model_name = model_name.replace("/", "_")
|
||||
dyn_str = "_dynamic" if dynamic else ""
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
model_dir_name = model_name + "_" + frontend
|
||||
if import_args["batch_size"] != 1:
|
||||
model_dir_name = (
|
||||
model_name
|
||||
+ "_"
|
||||
+ frontend
|
||||
+ "_BS"
|
||||
+ str(import_args["batch_size"])
|
||||
)
|
||||
else:
|
||||
model_dir_name = model_name + "_" + frontend
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
|
||||
|
||||
@@ -201,7 +211,9 @@ def download_model(
|
||||
from tank.generate_sharktank import gen_shark_files
|
||||
|
||||
tank_dir = WORKDIR
|
||||
gen_shark_files(model_name, frontend, tank_dir)
|
||||
gen_shark_files(model_name, frontend, tank_dir, import_args)
|
||||
with open(filename, mode="rb") as f:
|
||||
mlir_file = f.read()
|
||||
|
||||
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
|
||||
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
|
||||
|
||||
@@ -35,18 +35,12 @@ squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","mac
|
||||
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
|
||||
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/1243",""
|
||||
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,False,False,"Torchvision imports issue",""
|
||||
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,"",""
|
||||
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,"",""
|
||||
gpt2,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.",""
|
||||
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported",""
|
||||
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
|
||||
|
@@ -70,7 +70,7 @@ if __name__ == "__main__":
|
||||
backend_config = "dylib"
|
||||
# backend = "cuda"
|
||||
# backend_config = "cuda"
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
|
||||
flatbuffer_blob = compile_str(
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
|
||||
@@ -146,7 +146,6 @@ if __name__ == "__main__":
|
||||
backend_config = "cuda"
|
||||
args = [
|
||||
"--iree-cuda-llvm-target-arch=sm_80",
|
||||
"--iree-hal-cuda-disable-loop-nounroll-wa",
|
||||
"--iree-enable-fusion-with-reduction-ops",
|
||||
]
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ if __name__ == "__main__":
|
||||
backend_config = "dylib"
|
||||
# backend = "cuda"
|
||||
# backend_config = "cuda"
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
|
||||
flatbuffer_blob = compile_str(
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
|
||||
@@ -86,7 +86,7 @@ if __name__ == "__main__":
|
||||
backend_config = "dylib"
|
||||
# backend = "cuda"
|
||||
# backend_config = "cuda"
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
|
||||
flatbuffer_blob = compile_str(
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
|
||||
@@ -33,7 +33,7 @@ def create_hash(file_name):
|
||||
return file_hash.hexdigest()
|
||||
|
||||
|
||||
def save_torch_model(torch_model_list, local_tank_cache):
|
||||
def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
from tank.model_utils import (
|
||||
get_hf_model,
|
||||
get_hf_seq2seq_model,
|
||||
@@ -59,7 +59,6 @@ def save_torch_model(torch_model_list, local_tank_cache):
|
||||
if model_type == "stable_diffusion":
|
||||
args.use_tuned = False
|
||||
args.import_mlir = True
|
||||
args.use_tuned = False
|
||||
args.local_tank_cache = local_tank_cache
|
||||
|
||||
precision_values = ["fp16"]
|
||||
@@ -75,6 +74,7 @@ def save_torch_model(torch_model_list, local_tank_cache):
|
||||
width=512,
|
||||
height=512,
|
||||
use_base_vae=False,
|
||||
custom_vae="",
|
||||
debug=True,
|
||||
sharktank_dir=local_tank_cache,
|
||||
generate_vmfb=False,
|
||||
@@ -82,19 +82,33 @@ def save_torch_model(torch_model_list, local_tank_cache):
|
||||
model()
|
||||
continue
|
||||
if model_type == "vision":
|
||||
model, input, _ = get_vision_model(torch_model_name)
|
||||
model, input, _ = get_vision_model(
|
||||
torch_model_name, import_args
|
||||
)
|
||||
elif model_type == "hf":
|
||||
model, input, _ = get_hf_model(torch_model_name)
|
||||
model, input, _ = get_hf_model(torch_model_name, import_args)
|
||||
elif model_type == "hf_seq2seq":
|
||||
model, input, _ = get_hf_seq2seq_model(torch_model_name)
|
||||
model, input, _ = get_hf_seq2seq_model(
|
||||
torch_model_name, import_args
|
||||
)
|
||||
elif model_type == "hf_img_cls":
|
||||
model, input, _ = get_hf_img_cls_model(torch_model_name)
|
||||
model, input, _ = get_hf_img_cls_model(
|
||||
torch_model_name, import_args
|
||||
)
|
||||
elif model_type == "fp16":
|
||||
model, input, _ = get_fp16_model(torch_model_name)
|
||||
model, input, _ = get_fp16_model(torch_model_name, import_args)
|
||||
torch_model_name = torch_model_name.replace("/", "_")
|
||||
torch_model_dir = os.path.join(
|
||||
local_tank_cache, str(torch_model_name) + "_torch"
|
||||
)
|
||||
if import_args["batch_size"] != 1:
|
||||
torch_model_dir = os.path.join(
|
||||
local_tank_cache,
|
||||
str(torch_model_name)
|
||||
+ "_torch"
|
||||
+ f"_BS{str(import_args['batch_size'])}",
|
||||
)
|
||||
else:
|
||||
torch_model_dir = os.path.join(
|
||||
local_tank_cache, str(torch_model_name) + "_torch"
|
||||
)
|
||||
os.makedirs(torch_model_dir, exist_ok=True)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
@@ -118,7 +132,7 @@ def save_torch_model(torch_model_list, local_tank_cache):
|
||||
)
|
||||
|
||||
|
||||
def save_tf_model(tf_model_list, local_tank_cache):
|
||||
def save_tf_model(tf_model_list, local_tank_cache, import_args):
|
||||
from tank.model_utils_tf import (
|
||||
get_causal_image_model,
|
||||
get_masked_lm_model,
|
||||
@@ -150,20 +164,38 @@ def save_tf_model(tf_model_list, local_tank_cache):
|
||||
input = None
|
||||
print(f"Generating artifacts for model {tf_model_name}")
|
||||
if model_type == "hf":
|
||||
model, input, _ = get_causal_lm_model(tf_model_name)
|
||||
model, input, _ = get_masked_lm_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
elif model_type == "img":
|
||||
model, input, _ = get_causal_image_model(tf_model_name)
|
||||
model, input, _ = get_causal_image_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
elif model_type == "keras":
|
||||
model, input, _ = get_keras_model(tf_model_name)
|
||||
model, input, _ = get_keras_model(tf_model_name, import_args)
|
||||
elif model_type == "TFhf":
|
||||
model, input, _ = get_TFhf_model(tf_model_name)
|
||||
model, input, _ = get_TFhf_model(tf_model_name, import_args)
|
||||
elif model_type == "tfhf_seq2seq":
|
||||
model, input, _ = get_tfhf_seq2seq_model(tf_model_name)
|
||||
model, input, _ = get_tfhf_seq2seq_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
elif model_type == "hf_causallm":
|
||||
model, input, _ = get_causal_lm_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
|
||||
tf_model_name = tf_model_name.replace("/", "_")
|
||||
tf_model_dir = os.path.join(
|
||||
local_tank_cache, str(tf_model_name) + "_tf"
|
||||
)
|
||||
if import_args["batch_size"] != 1:
|
||||
tf_model_dir = os.path.join(
|
||||
local_tank_cache,
|
||||
str(tf_model_name)
|
||||
+ "_tf"
|
||||
+ f"_BS{str(import_args['batch_size'])}",
|
||||
)
|
||||
else:
|
||||
tf_model_dir = os.path.join(
|
||||
local_tank_cache, str(tf_model_name) + "_tf"
|
||||
)
|
||||
os.makedirs(tf_model_dir, exist_ok=True)
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
@@ -175,13 +207,9 @@ def save_tf_model(tf_model_list, local_tank_cache):
|
||||
dir=tf_model_dir,
|
||||
model_name=tf_model_name,
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(tf_model_dir, tf_model_name + "_tf" + ".mlir")
|
||||
)
|
||||
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
|
||||
|
||||
|
||||
def save_tflite_model(tflite_model_list, local_tank_cache):
|
||||
def save_tflite_model(tflite_model_list, local_tank_cache, import_args):
|
||||
from shark.tflite_utils import TFLitePreprocessor
|
||||
|
||||
with open(tflite_model_list) as csvfile:
|
||||
@@ -198,13 +226,13 @@ def save_tflite_model(tflite_model_list, local_tank_cache):
|
||||
os.makedirs(tflite_model_name_dir, exist_ok=True)
|
||||
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
|
||||
|
||||
# Preprocess to get SharkImporter input args
|
||||
# Preprocess to get SharkImporter input import_args
|
||||
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
|
||||
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
|
||||
inputs = tflite_preprocessor.get_inputs()
|
||||
tflite_interpreter = tflite_preprocessor.get_interpreter()
|
||||
|
||||
# Use SharkImporter to get SharkInference input args
|
||||
# Use SharkImporter to get SharkInference input import_args
|
||||
my_shark_importer = SharkImporter(
|
||||
module=tflite_interpreter,
|
||||
inputs=inputs,
|
||||
@@ -228,43 +256,69 @@ def save_tflite_model(tflite_model_list, local_tank_cache):
|
||||
)
|
||||
|
||||
|
||||
def gen_shark_files(modelname, frontend, tank_dir):
|
||||
def check_requirements(frontend):
|
||||
import importlib
|
||||
|
||||
has_pkgs = False
|
||||
if frontend == "torch":
|
||||
tv_spec = importlib.util.find_spec("torchvision")
|
||||
has_pkgs = tv_spec is not None
|
||||
|
||||
elif frontend in ["tensorflow", "tf"]:
|
||||
tf_spec = importlib.util.find_spec("tensorflow")
|
||||
has_pkgs = tf_spec is not None
|
||||
|
||||
return has_pkgs
|
||||
|
||||
|
||||
class NoImportException(Exception):
|
||||
"Raised when requirements are not met for OTF model artifact generation."
|
||||
pass
|
||||
|
||||
|
||||
def gen_shark_files(modelname, frontend, tank_dir, importer_args):
|
||||
# If a model's artifacts are requested by shark_downloader but they don't exist in the cloud, we call this function to generate the artifacts on-the-fly.
|
||||
# TODO: Add TFlite support.
|
||||
import tempfile
|
||||
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
|
||||
custom_model_csv = tempfile.NamedTemporaryFile(
|
||||
dir=os.path.dirname(__file__),
|
||||
delete=True,
|
||||
)
|
||||
# Create a temporary .csv with only the desired entry.
|
||||
if frontend == "tf":
|
||||
with open(tf_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
if row[0] == modelname:
|
||||
target = row
|
||||
with open(custom_model_csv.name, mode="w") as trg:
|
||||
writer = csv.writer(trg)
|
||||
writer.writerow(["modelname", "src"])
|
||||
writer.writerow(target)
|
||||
save_tf_model(custom_model_csv.name, tank_dir)
|
||||
import_args = importer_args
|
||||
if check_requirements(frontend):
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tf_model_list.csv"
|
||||
)
|
||||
custom_model_csv = tempfile.NamedTemporaryFile(
|
||||
dir=os.path.dirname(__file__),
|
||||
delete=True,
|
||||
)
|
||||
# Create a temporary .csv with only the desired entry.
|
||||
if frontend == "tf":
|
||||
with open(tf_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
if row[0] == modelname:
|
||||
target = row
|
||||
with open(custom_model_csv.name, mode="w") as trg:
|
||||
writer = csv.writer(trg)
|
||||
writer.writerow(["modelname", "src"])
|
||||
writer.writerow(target)
|
||||
save_tf_model(custom_model_csv.name, tank_dir, import_args)
|
||||
|
||||
if frontend == "torch":
|
||||
with open(torch_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
if row[0] == modelname:
|
||||
target = row
|
||||
with open(custom_model_csv.name, mode="w") as trg:
|
||||
writer = csv.writer(trg)
|
||||
writer.writerow(["modelname", "src"])
|
||||
writer.writerow(target)
|
||||
save_torch_model(custom_model_csv.name, tank_dir)
|
||||
elif frontend == "torch":
|
||||
with open(torch_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
if row[0] == modelname:
|
||||
target = row
|
||||
with open(custom_model_csv.name, mode="w") as trg:
|
||||
writer = csv.writer(trg)
|
||||
writer.writerow(["modelname", "src"])
|
||||
writer.writerow(target)
|
||||
save_torch_model(custom_model_csv.name, tank_dir, import_args)
|
||||
else:
|
||||
raise NoImportException
|
||||
|
||||
|
||||
# Validates whether the file is present or not.
|
||||
@@ -276,7 +330,7 @@ def is_valid_file(arg):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Note, all of these flags are overridden by the import of args from stable_args.py, flags are duplicated temporarily to preserve functionality
|
||||
# Note, all of these flags are overridden by the import of import_args from stable_args.py, flags are duplicated temporarily to preserve functionality
|
||||
# parser = argparse.ArgumentParser()
|
||||
# parser.add_argument(
|
||||
# "--torch_model_csv",
|
||||
@@ -304,8 +358,11 @@ if __name__ == "__main__":
|
||||
# )
|
||||
# parser.add_argument("--upload", type=bool, default=False)
|
||||
|
||||
# old_args = parser.parse_args()
|
||||
|
||||
# old_import_args = parser.parse_import_args()
|
||||
import_args = {
|
||||
"batch_size": "1",
|
||||
}
|
||||
print(import_args)
|
||||
home = str(Path.home())
|
||||
WORKDIR = os.path.join(os.path.dirname(__file__), "..", "gen_shark_tank")
|
||||
torch_model_csv = os.path.join(
|
||||
@@ -319,7 +376,8 @@ if __name__ == "__main__":
|
||||
save_torch_model(
|
||||
os.path.join(os.path.dirname(__file__), "torch_sd_list.csv"),
|
||||
WORKDIR,
|
||||
import_args,
|
||||
)
|
||||
save_torch_model(torch_model_csv, WORKDIR)
|
||||
save_tf_model(tf_model_csv, WORKDIR)
|
||||
save_tflite_model(tflite_model_csv, WORKDIR)
|
||||
save_torch_model(torch_model_csv, WORKDIR, import_args)
|
||||
save_tf_model(tf_model_csv, WORKDIR, import_args)
|
||||
save_tflite_model(tflite_model_csv, WORKDIR, import_args)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
@@ -35,17 +34,17 @@ hf_seq2seq_models = [
|
||||
]
|
||||
|
||||
|
||||
def get_torch_model(modelname):
|
||||
def get_torch_model(modelname, import_args):
|
||||
if modelname in vision_models:
|
||||
return get_vision_model(modelname)
|
||||
return get_vision_model(modelname, import_args)
|
||||
elif modelname in hf_img_cls_models:
|
||||
return get_hf_img_cls_model(modelname)
|
||||
return get_hf_img_cls_model(modelname, import_args)
|
||||
elif modelname in hf_seq2seq_models:
|
||||
return get_hf_seq2seq_model(modelname)
|
||||
return get_hf_seq2seq_model(modelname, import_args)
|
||||
elif "fp16" in modelname:
|
||||
return get_fp16_model(modelname)
|
||||
return get_fp16_model(modelname, import_args)
|
||||
else:
|
||||
return get_hf_model(modelname)
|
||||
return get_hf_model(modelname, import_args)
|
||||
|
||||
|
||||
##################### Hugging Face Image Classification Models ###################################
|
||||
@@ -88,14 +87,14 @@ class HuggingFaceImageClassification(torch.nn.Module):
|
||||
return self.model.forward(inputs)[0]
|
||||
|
||||
|
||||
def get_hf_img_cls_model(name):
|
||||
def get_hf_img_cls_model(name, import_args):
|
||||
model = HuggingFaceImageClassification(name)
|
||||
# you can use preprocess_input_image to get the test_input or just random value.
|
||||
test_input = preprocess_input_image(name)
|
||||
# test_input = torch.FloatTensor(1, 3, 224, 224).uniform_(-1, 1)
|
||||
# print("test_input.shape: ", test_input.shape)
|
||||
# test_input.shape: torch.Size([1, 3, 224, 224])
|
||||
test_input = test_input.repeat(BATCH_SIZE, 1, 1, 1)
|
||||
test_input = test_input.repeat(import_args["batch_size"], 1, 1, 1)
|
||||
actual_out = model(test_input)
|
||||
# print("actual_out.shape: ", actual_out.shape)
|
||||
# actual_out.shape: torch.Size([1, 1000])
|
||||
@@ -125,14 +124,13 @@ class HuggingFaceLanguage(torch.nn.Module):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
def get_hf_model(name):
|
||||
def get_hf_model(name, import_args):
|
||||
from transformers import (
|
||||
BertTokenizer,
|
||||
)
|
||||
|
||||
model = HuggingFaceLanguage(name)
|
||||
# TODO: Currently the test input is set to (1,128)
|
||||
test_input = torch.randint(2, (BATCH_SIZE, 128))
|
||||
test_input = torch.randint(2, (import_args["batch_size"], 128))
|
||||
actual_out = model(test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
@@ -165,7 +163,7 @@ class HFSeq2SeqLanguageModel(torch.nn.Module):
|
||||
)[0]
|
||||
|
||||
|
||||
def get_hf_seq2seq_model(name):
|
||||
def get_hf_seq2seq_model(name, import_args):
|
||||
m = HFSeq2SeqLanguageModel(name)
|
||||
encoded_input_ids = m.preprocess_input(
|
||||
"Studies have been shown that owning a dog is good for you"
|
||||
@@ -193,7 +191,7 @@ class VisionModule(torch.nn.Module):
|
||||
return self.model.forward(input)
|
||||
|
||||
|
||||
def get_vision_model(torch_model):
|
||||
def get_vision_model(torch_model, import_args):
|
||||
import torchvision.models as models
|
||||
|
||||
default_image_size = (224, 224)
|
||||
@@ -239,7 +237,7 @@ def get_vision_model(torch_model):
|
||||
fp16_model = True
|
||||
torch_model, input_image_size = vision_models_dict[torch_model]
|
||||
model = VisionModule(torch_model)
|
||||
test_input = torch.randn(BATCH_SIZE, 3, 224, 224)
|
||||
test_input = torch.randn(import_args["batch_size"], 3, *input_image_size)
|
||||
actual_out = model(test_input)
|
||||
if fp16_model is not None:
|
||||
test_input_fp16 = test_input.to(
|
||||
@@ -280,14 +278,14 @@ class BertHalfPrecisionModel(torch.nn.Module):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
def get_fp16_model(torch_model):
|
||||
def get_fp16_model(torch_model, import_args):
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
modelname = torch_model.replace("_fp16", "")
|
||||
model = BertHalfPrecisionModel(modelname)
|
||||
tokenizer = AutoTokenizer.from_pretrained(modelname)
|
||||
text = "Replace me by any text you like."
|
||||
text = [text] * BATCH_SIZE
|
||||
text = [text] * import_args["batch_size"]
|
||||
test_input_fp16 = tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
BertTokenizer,
|
||||
TFBertModel,
|
||||
)
|
||||
|
||||
BATCH_SIZE = 1
|
||||
|
||||
@@ -52,19 +47,19 @@ img_models = [
|
||||
]
|
||||
|
||||
|
||||
def get_tf_model(name):
|
||||
def get_tf_model(name, import_args):
|
||||
if name in keras_models:
|
||||
return get_keras_model(name)
|
||||
return get_keras_model(name, import_args)
|
||||
elif name in maskedlm_models:
|
||||
return get_masked_lm_model(name)
|
||||
return get_masked_lm_model(name, import_args)
|
||||
elif name in causallm_models:
|
||||
return get_causal_lm_model(name)
|
||||
return get_causal_lm_model(name, import_args)
|
||||
elif name in tfhf_models:
|
||||
return get_TFhf_model(name)
|
||||
return get_TFhf_model(name, import_args)
|
||||
elif name in img_models:
|
||||
return get_causal_image_model(name)
|
||||
return get_causal_image_model(name, import_args)
|
||||
elif name in tfhf_seq2seq_models:
|
||||
return get_tfhf_seq2seq_model(name)
|
||||
return get_tfhf_seq2seq_model(name, import_args)
|
||||
else:
|
||||
raise Exception(
|
||||
"TF model not found! Please check that the modelname has been input correctly."
|
||||
@@ -72,6 +67,12 @@ def get_tf_model(name):
|
||||
|
||||
|
||||
##################### Tensorflow Hugging Face Bert Models ###################################
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
BertTokenizer,
|
||||
TFBertModel,
|
||||
)
|
||||
|
||||
BERT_MAX_SEQUENCE_LENGTH = 128
|
||||
|
||||
# Create a set of 2-dimensional inputs
|
||||
@@ -104,7 +105,7 @@ class TFHuggingFaceLanguage(tf.Module):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
def get_TFhf_model(name):
|
||||
def get_TFhf_model(name, import_args):
|
||||
model = TFHuggingFaceLanguage(name)
|
||||
tokenizer = BertTokenizer.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
@@ -166,7 +167,6 @@ def preprocess_input(
|
||||
|
||||
##################### Tensorflow Hugging Face Masked LM Models ###################################
|
||||
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
|
||||
import tensorflow as tf
|
||||
|
||||
MASKED_LM_MAX_SEQUENCE_LENGTH = 128
|
||||
|
||||
@@ -196,7 +196,9 @@ class MaskedLM(tf.Module):
|
||||
return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
def get_masked_lm_model(hf_name, text="Hello, this is the default text."):
|
||||
def get_masked_lm_model(
|
||||
hf_name, import_args, text="Hello, this is the default text."
|
||||
):
|
||||
model = MaskedLM(hf_name)
|
||||
encoded_input = preprocess_input(
|
||||
hf_name, MASKED_LM_MAX_SEQUENCE_LENGTH, text
|
||||
@@ -251,7 +253,9 @@ class CausalLM(tf.Module):
|
||||
return self.model.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
|
||||
def get_causal_lm_model(
|
||||
hf_name, import_args, text="Hello, this is the default text."
|
||||
):
|
||||
model = CausalLM(hf_name)
|
||||
batched_text = [text] * BATCH_SIZE
|
||||
encoded_input = model.preprocess_input(batched_text)
|
||||
@@ -306,7 +310,7 @@ class TFHFSeq2SeqLanguageModel(tf.Module):
|
||||
return self.model.predict(input_ids, decoder_input_ids)
|
||||
|
||||
|
||||
def get_tfhf_seq2seq_model(name):
|
||||
def get_tfhf_seq2seq_model(name, import_args):
|
||||
m = TFHFSeq2SeqLanguageModel(name)
|
||||
text = "Studies have been shown that owning a dog is good for you"
|
||||
batched_text = [text] * BATCH_SIZE
|
||||
@@ -442,7 +446,7 @@ def load_image(path_to_image, width, height, channels):
|
||||
return image
|
||||
|
||||
|
||||
def get_keras_model(modelname):
|
||||
def get_keras_model(modelname, import_args):
|
||||
if modelname == "efficientnet-v2-s":
|
||||
model = EfficientNetV2SModule()
|
||||
elif modelname == "efficientnet_b0":
|
||||
@@ -530,7 +534,7 @@ def preprocess_input_image(model_name):
|
||||
return [inputs[str(*inputs)]]
|
||||
|
||||
|
||||
def get_causal_image_model(hf_name):
|
||||
def get_causal_image_model(hf_name, import_args):
|
||||
model = AutoModelImageClassfication(hf_name)
|
||||
test_input = preprocess_input_image(hf_name)
|
||||
# TFSequenceClassifierOutput(loss=None, logits=<tf.Tensor: shape=(1, 1000), dtype=float32, numpy=
|
||||
|
||||
@@ -8,6 +8,7 @@ from parameterized import parameterized
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
from tank.generate_sharktank import NoImportException
|
||||
import iree.compiler as ireec
|
||||
import pytest
|
||||
import unittest
|
||||
@@ -48,7 +49,9 @@ def load_csv_and_convert(filename, gen=False):
|
||||
)
|
||||
# This is a pytest workaround
|
||||
if gen:
|
||||
with open("tank/dict_configs.py", "w+") as out:
|
||||
with open(
|
||||
os.path.join(os.path.dirname(__file__), "dict_configs.py"), "w+"
|
||||
) as out:
|
||||
out.write("ALL = [\n")
|
||||
for c in model_configs:
|
||||
out.write(str(c) + ",\n")
|
||||
@@ -68,7 +71,9 @@ def get_valid_test_params():
|
||||
dynamic_list = (True, False)
|
||||
# TODO: This is soooo ugly, but for some reason creating the dict at runtime
|
||||
# results in strange pytest failures.
|
||||
load_csv_and_convert("tank/all_models.csv", True)
|
||||
load_csv_and_convert(
|
||||
os.path.join(os.path.dirname(__file__), "all_models.csv"), True
|
||||
)
|
||||
from tank.dict_configs import ALL
|
||||
|
||||
config_list = ALL
|
||||
@@ -135,6 +140,9 @@ class SharkModuleTester:
|
||||
self.config = config
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
import_config = {
|
||||
"batch_size": self.batch_size,
|
||||
}
|
||||
shark_args.local_tank_cache = self.local_tank_cache
|
||||
shark_args.force_update_tank = self.update_tank
|
||||
shark_args.dispatch_benchmarks = self.benchmark_dispatches
|
||||
@@ -161,11 +169,17 @@ class SharkModuleTester:
|
||||
if "winograd" in self.config["flags"]:
|
||||
shark_args.use_winograd = True
|
||||
|
||||
model, func_name, inputs, golden_out = download_model(
|
||||
self.config["model_name"],
|
||||
tank_url=self.tank_url,
|
||||
frontend=self.config["framework"],
|
||||
)
|
||||
try:
|
||||
model, func_name, inputs, golden_out = download_model(
|
||||
self.config["model_name"],
|
||||
tank_url=self.tank_url,
|
||||
frontend=self.config["framework"],
|
||||
import_args=import_config,
|
||||
)
|
||||
except NoImportException:
|
||||
pytest.xfail(
|
||||
reason=f"Artifacts for this model/config must be generated locally. Please make sure {self.config['framework']} is installed."
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
@@ -210,10 +224,12 @@ class SharkModuleTester:
|
||||
self.save_reproducers()
|
||||
|
||||
def benchmark_module(self, shark_module, inputs, dynamic, device):
|
||||
model_config = {
|
||||
"batch_size": self.batch_size,
|
||||
}
|
||||
shark_args.enable_tf32 = self.tf32
|
||||
if shark_args.enable_tf32 == True:
|
||||
shark_module.compile()
|
||||
shark_args.enable_tf32 = False
|
||||
|
||||
shark_args.onnx_bench = self.onnx_bench
|
||||
shark_module.shark_runner.benchmark_all_csv(
|
||||
@@ -222,6 +238,7 @@ class SharkModuleTester:
|
||||
dynamic,
|
||||
device,
|
||||
self.config["framework"],
|
||||
import_args=model_config,
|
||||
)
|
||||
|
||||
def save_reproducers(self):
|
||||
@@ -271,6 +288,9 @@ class SharkModuleTest(unittest.TestCase):
|
||||
@parameterized.expand(param_list, name_func=shark_test_name_func)
|
||||
def test_module(self, dynamic, device, config):
|
||||
self.module_tester = SharkModuleTester(config)
|
||||
self.module_tester.batch_size = self.pytestconfig.getoption(
|
||||
"batchsize"
|
||||
)
|
||||
self.module_tester.benchmark = self.pytestconfig.getoption("benchmark")
|
||||
self.module_tester.save_repro = self.pytestconfig.getoption(
|
||||
"save_repro"
|
||||
@@ -342,6 +362,10 @@ class SharkModuleTest(unittest.TestCase):
|
||||
pytest.xfail(
|
||||
reason="Numerics issues: https://github.com/nod-ai/SHARK/issues/476"
|
||||
)
|
||||
if config["framework"] == "tf" and self.module_tester.batch_size != 1:
|
||||
pytest.xfail(
|
||||
reason="Configurable batch sizes temp. unavailable for tensorflow models."
|
||||
)
|
||||
safe_name = (
|
||||
f"{config['model_name']}_{config['framework']}_{dynamic}_{device}"
|
||||
)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
|
||||
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
|
||||
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
|
||||
microsoft/MiniLM-L12-H384-uncased,True,hf,True,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
|
||||
bert-base-uncased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-base-cased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
@@ -19,9 +21,3 @@ mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","
|
||||
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
t5-base,True,hf_seq2seq,True,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
t5-large,True,hf_seq2seq,True,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
|
||||
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
|
||||
t5-base,True,hf_seq2seq,True,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
t5-large,True,hf_seq2seq,True,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
|
||||
|
@@ -1,4 +1,3 @@
|
||||
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
|
||||
stabilityai/stable-diffusion-2-1-base,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
|
||||
stabilityai/stable-diffusion-2-1,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
|
||||
prompthero/openjourney,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
|
||||
|
||||
|
Reference in New Issue
Block a user