mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
41 Commits
20230316.6
...
20230328.6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d6f740b998 | ||
|
|
594c6b8ea2 | ||
|
|
96b1560da5 | ||
|
|
0ef6a0e234 | ||
|
|
641d535f44 | ||
|
|
5bb7846227 | ||
|
|
8f84258fb8 | ||
|
|
7619e76bbd | ||
|
|
9267eadbfa | ||
|
|
431132b8ee | ||
|
|
fb35e13e7a | ||
|
|
17a67897d1 | ||
|
|
da449b73aa | ||
|
|
0b0526699a | ||
|
|
4fac46f7bb | ||
|
|
49925950f1 | ||
|
|
807947c0c8 | ||
|
|
593428bda4 | ||
|
|
cede9b4fec | ||
|
|
c2360303f0 | ||
|
|
420366c1b8 | ||
|
|
d31bae488c | ||
|
|
c23fcf3748 | ||
|
|
7dbbb1726a | ||
|
|
8b8cc7fd33 | ||
|
|
e3c96a2b9d | ||
|
|
5e3f50647d | ||
|
|
7899e1803a | ||
|
|
d105246b9c | ||
|
|
90c958bca2 | ||
|
|
f99903e023 | ||
|
|
c6f44ef1b3 | ||
|
|
8dcd4d5aeb | ||
|
|
d319f4684e | ||
|
|
54d7b6d83e | ||
|
|
4a622532e5 | ||
|
|
650b2ada58 | ||
|
|
f87f8949f3 | ||
|
|
7dc9bf8148 | ||
|
|
ba48ff8d25 | ||
|
|
638840925c |
4
.github/workflows/test-models.yml
vendored
4
.github/workflows/test-models.yml
vendored
@@ -120,7 +120,7 @@ jobs:
|
||||
if: matrix.suite == 'cuda'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
|
||||
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
|
||||
@@ -158,4 +158,6 @@ jobs:
|
||||
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
python process_skipfiles.py
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd.spec
|
||||
python build_tools/stable_diffusion_testing.py --device=vulkan
|
||||
|
||||
@@ -114,12 +114,12 @@ source shark.venv/bin/activate
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\txt2img.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
|
||||
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
|
||||
```
|
||||
|
||||
#### Linux / macOS Users
|
||||
```shell
|
||||
python3.11 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
|
||||
python3.11 apps/stable_diffusion/scripts/main.py --app=txt2img --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
|
||||
```
|
||||
|
||||
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
|
||||
|
||||
@@ -3,3 +3,4 @@ from apps.stable_diffusion.scripts.img2img import img2img_inf
|
||||
from apps.stable_diffusion.scripts.inpaint import inpaint_inf
|
||||
from apps.stable_diffusion.scripts.outpaint import outpaint_inf
|
||||
from apps.stable_diffusion.scripts.upscaler import upscaler_inf
|
||||
from apps.stable_diffusion.scripts.train_lora_word import lora_train
|
||||
|
||||
@@ -2,6 +2,7 @@ import sys
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Image2ImagePipeline,
|
||||
@@ -12,10 +13,9 @@ from apps.stable_diffusion.src import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
schedulers = None
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
@@ -24,15 +24,15 @@ init_import_mlir = args.import_mlir
|
||||
|
||||
# For stencil, the input image can be of any size but we need to ensure that
|
||||
# it conforms with our model contraints :-
|
||||
# Both width and height should be > 384 and multiple of 8.
|
||||
# Both width and height should be in the range of [128, 768] and multiple of 8.
|
||||
# This utility function performs the transformation on the input image while
|
||||
# also maintaining the aspect ratio before sending it to the stencil pipeline.
|
||||
def resize_stencil(image: Image.Image):
|
||||
width, height = image.size
|
||||
aspect_ratio = width / height
|
||||
min_size = min(width, height)
|
||||
if min_size < 384:
|
||||
n_size = 384
|
||||
if min_size < 128:
|
||||
n_size = 128
|
||||
if width == min_size:
|
||||
width = n_size
|
||||
height = n_size / aspect_ratio
|
||||
@@ -45,6 +45,22 @@ def resize_stencil(image: Image.Image):
|
||||
n_height = height // 8
|
||||
n_width *= 8
|
||||
n_height *= 8
|
||||
|
||||
min_size = min(width, height)
|
||||
if min_size > 768:
|
||||
n_size = 768
|
||||
if width == min_size:
|
||||
height = n_size
|
||||
width = n_size * aspect_ratio
|
||||
else:
|
||||
width = n_size
|
||||
height = n_size / aspect_ratio
|
||||
width = int(width)
|
||||
height = int(height)
|
||||
n_width = width // 8
|
||||
n_height = height // 8
|
||||
n_width *= 8
|
||||
n_height *= 8
|
||||
new_image = image.resize((n_width, n_height))
|
||||
return new_image, n_width, n_height
|
||||
|
||||
@@ -71,14 +87,18 @@ def img2img_inf(
|
||||
use_stencil: str,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
global schedulers
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
@@ -94,10 +114,6 @@ def img2img_inf(
|
||||
image = init_image.convert("RGB")
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
@@ -112,6 +128,10 @@ def img2img_inf(
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
@@ -144,7 +164,7 @@ def img2img_inf(
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=None,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=use_stencil,
|
||||
)
|
||||
if (
|
||||
@@ -153,6 +173,7 @@ def img2img_inf(
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
@@ -167,8 +188,8 @@ def img2img_inf(
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(args.scheduler)
|
||||
|
||||
if use_stencil is not None:
|
||||
args.use_tuned = False
|
||||
@@ -189,6 +210,7 @@ def img2img_inf(
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_stencil=use_stencil,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -208,10 +230,11 @@ def img2img_inf(
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_schedulers(schedulers[scheduler])
|
||||
global_obj.set_sd_scheduler(args.scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
@@ -219,6 +242,7 @@ def img2img_inf(
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
extra_info = {"STRENGTH": strength}
|
||||
text_output = ""
|
||||
for current_batch in range(batch_count):
|
||||
if current_batch > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
@@ -239,26 +263,23 @@ def img2img_inf(
|
||||
cpu_scheduling,
|
||||
use_stencil=use_stencil,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed, extra_info)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={steps}, strength={args.strength}, guidance_scale={guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
|
||||
text_output += global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed, extra_info)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
|
||||
yield generated_imgs, text_output
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
@@ -310,6 +331,7 @@ if __name__ == "__main__":
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_stencil=use_stencil,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
else:
|
||||
img2img_obj = Image2ImagePipeline.from_pretrained(
|
||||
@@ -327,6 +349,7 @@ if __name__ == "__main__":
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
@@ -362,3 +385,7 @@ if __name__ == "__main__":
|
||||
extra_info = {"STRENGTH": args.strength}
|
||||
save_output_img(generated_imgs[0], seed, extra_info)
|
||||
print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
InpaintPipeline,
|
||||
@@ -10,10 +11,9 @@ from apps.stable_diffusion.src import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
schedulers = None
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
@@ -42,14 +42,18 @@ def inpaint_inf(
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
global schedulers
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
@@ -60,10 +64,6 @@ def inpaint_inf(
|
||||
args.mask_path = "not none"
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
@@ -78,6 +78,10 @@ def inpaint_inf(
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
@@ -93,7 +97,7 @@ def inpaint_inf(
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=None,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
)
|
||||
if (
|
||||
@@ -103,6 +107,7 @@ def inpaint_inf(
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.precision = precision
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
@@ -117,14 +122,15 @@ def inpaint_inf(
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
InpaintPipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
custom_vae=args.custom_vae,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
@@ -132,13 +138,13 @@ def inpaint_inf(
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_schedulers(schedulers[scheduler])
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
@@ -147,6 +153,7 @@ def inpaint_inf(
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
image = image_dict["image"]
|
||||
mask_image = image_dict["mask"]
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
@@ -168,26 +175,23 @@ def inpaint_inf(
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
|
||||
text_output += global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
|
||||
yield generated_imgs, text_output
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
@@ -217,6 +221,7 @@ if __name__ == "__main__":
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
custom_vae=args.custom_vae,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
@@ -224,9 +229,9 @@ if __name__ == "__main__":
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
|
||||
for current_batch in range(args.batch_count):
|
||||
@@ -269,3 +274,7 @@ if __name__ == "__main__":
|
||||
|
||||
save_output_img(generated_imgs[0], seed)
|
||||
print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
19
apps/stable_diffusion/scripts/main.py
Normal file
19
apps/stable_diffusion/scripts/main.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.scripts import (
|
||||
img2img,
|
||||
txt2img,
|
||||
# inpaint,
|
||||
# outpaint,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.app == "txt2img":
|
||||
txt2img.main()
|
||||
elif args.app == "img2img":
|
||||
img2img.main()
|
||||
# elif args.app == "inpaint":
|
||||
# inpaint.main()
|
||||
# elif args.app == "outpaint":
|
||||
# outpaint.main()
|
||||
else:
|
||||
print(f"args.app value is {args.app} but this isn't supported")
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
OutpaintPipeline,
|
||||
@@ -10,10 +11,9 @@ from apps.stable_diffusion.src import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
schedulers = None
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
@@ -45,14 +45,18 @@ def outpaint_inf(
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
global schedulers
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
@@ -62,10 +66,6 @@ def outpaint_inf(
|
||||
args.img_path = "not none"
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
@@ -80,6 +80,10 @@ def outpaint_inf(
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
@@ -95,7 +99,7 @@ def outpaint_inf(
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=None,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
)
|
||||
if (
|
||||
@@ -105,6 +109,7 @@ def outpaint_inf(
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.precision = precision
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
@@ -119,8 +124,8 @@ def outpaint_inf(
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
OutpaintPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
@@ -135,10 +140,11 @@ def outpaint_inf(
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_schedulers(schedulers[scheduler])
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
@@ -151,6 +157,7 @@ def outpaint_inf(
|
||||
top = True if "up" in directions else False
|
||||
bottom = True if "down" in directions else False
|
||||
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
@@ -177,26 +184,23 @@ def outpaint_inf(
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
|
||||
text_output += global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
|
||||
yield generated_imgs, text_output
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
@@ -230,6 +234,7 @@ if __name__ == "__main__":
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
|
||||
for current_batch in range(args.batch_count):
|
||||
@@ -294,3 +299,7 @@ if __name__ == "__main__":
|
||||
}
|
||||
save_output_img(generated_imgs[0], seed, extra_info)
|
||||
print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
674
apps/stable_diffusion/scripts/train_lora_word.py
Normal file
674
apps/stable_diffusion/scripts/train_lora_word.py
Normal file
@@ -0,0 +1,674 @@
|
||||
# Install the required libs
|
||||
# pip install -U git+https://github.com/huggingface/diffusers.git
|
||||
# pip install accelerate transformers ftfy
|
||||
|
||||
# HuggingFace Token
|
||||
# YOUR_TOKEN = "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"
|
||||
|
||||
|
||||
# Import required libraries
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from typing import List
|
||||
import random
|
||||
import torch_mlir
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import PIL
|
||||
import logging
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.cross_attention import LoRACrossAttnProcessor
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir.dynamo import make_simple_dynamo_backend
|
||||
import torch._dynamo as dynamo
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
torch._dynamo.config.verbose = True
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.stable_diffusion import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import (
|
||||
CLIPFeatureExtractor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
)
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
from dataclasses import dataclass
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
clear_all,
|
||||
)
|
||||
|
||||
|
||||
# Setup the dataset
|
||||
class LoraDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
tokenizer,
|
||||
size=512,
|
||||
repeats=100,
|
||||
interpolation="bicubic",
|
||||
set="train",
|
||||
prompt="myloraprompt",
|
||||
center_crop=False,
|
||||
):
|
||||
self.data_root = data_root
|
||||
self.tokenizer = tokenizer
|
||||
self.size = size
|
||||
self.center_crop = center_crop
|
||||
self.prompt = prompt
|
||||
|
||||
self.image_paths = [
|
||||
os.path.join(self.data_root, file_path)
|
||||
for file_path in os.listdir(self.data_root)
|
||||
]
|
||||
|
||||
self.num_images = len(self.image_paths)
|
||||
self._length = self.num_images
|
||||
|
||||
if set == "train":
|
||||
self._length = self.num_images * repeats
|
||||
|
||||
self.interpolation = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
}[interpolation]
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = {}
|
||||
image = Image.open(self.image_paths[i % self.num_images])
|
||||
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
example["input_ids"] = self.tokenizer(
|
||||
self.prompt,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids[0]
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
(
|
||||
h,
|
||||
w,
|
||||
) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[
|
||||
(h - crop) // 2 : (h + crop) // 2,
|
||||
(w - crop) // 2 : (w + crop) // 2,
|
||||
]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
image = image.resize(
|
||||
(self.size, self.size), resample=self.interpolation
|
||||
)
|
||||
|
||||
image = np.array(image).astype(np.uint8)
|
||||
image = (image / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
||||
return example
|
||||
|
||||
|
||||
########## Setting up the model ##########
|
||||
def lora_train(
|
||||
prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
training_images_dir: str,
|
||||
lora_save_dir: str,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
print(
|
||||
"Note LoRA training is not compatible with the latest torch-mlir branch"
|
||||
)
|
||||
print(
|
||||
"To run LoRA training you'll need this to follow this guide for the torch-mlir branch: https://github.com/nod-ai/SHARK/tree/main/shark/examples/shark_training/stable_diffusion"
|
||||
)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.steps = steps
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = custom_model
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
args.training_images_dir = training_images_dir
|
||||
args.lora_save_dir = lora_save_dir
|
||||
|
||||
args.precision = precision
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
device_str = device.split("=>", 1)[1].strip().split("://")
|
||||
if len(device_str) > 1:
|
||||
device_str = device_str[0] + ":" + device_str[1]
|
||||
else:
|
||||
device_str = device_str[0]
|
||||
args.device = device_str
|
||||
|
||||
# Load the Stable Diffusion model
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.hf_model_id, subfolder="text_encoder"
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(args.hf_model_id, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.hf_model_id, subfolder="unet"
|
||||
)
|
||||
|
||||
def freeze_params(params):
|
||||
for param in params:
|
||||
param.requires_grad = False
|
||||
|
||||
# Freeze everything but LoRA
|
||||
freeze_params(vae.parameters())
|
||||
freeze_params(unet.parameters())
|
||||
freeze_params(text_encoder.parameters())
|
||||
|
||||
# Move vae and unet to device
|
||||
vae.to(args.device)
|
||||
unet.to(args.device)
|
||||
text_encoder.to(args.device)
|
||||
|
||||
lora_attn_procs = {}
|
||||
for name in unet.attn_processors.keys():
|
||||
cross_attention_dim = (
|
||||
None
|
||||
if name.endswith("attn1.processor")
|
||||
else unet.config.cross_attention_dim
|
||||
)
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[
|
||||
block_id
|
||||
]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
|
||||
lora_attn_procs[name] = LoRACrossAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
|
||||
)
|
||||
|
||||
unet.set_attn_processor(lora_attn_procs)
|
||||
lora_layers = AttnProcsLayers(unet.attn_processors)
|
||||
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = vae
|
||||
|
||||
def forward(self, input):
|
||||
x = self.vae.encode(input, return_dict=False)[0]
|
||||
return x
|
||||
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = unet
|
||||
|
||||
def forward(self, x, y, z):
|
||||
return self.unet.forward(x, y, z, return_dict=False)[0]
|
||||
|
||||
shark_vae = VaeModel()
|
||||
shark_unet = UnetModel()
|
||||
|
||||
####### Creating our training data ########
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.hf_model_id,
|
||||
subfolder="tokenizer",
|
||||
)
|
||||
|
||||
# Let's create the Dataset and Dataloader
|
||||
train_dataset = LoraDataset(
|
||||
data_root=args.training_images_dir,
|
||||
tokenizer=tokenizer,
|
||||
size=vae.sample_size,
|
||||
prompt=args.prompts[0],
|
||||
repeats=100,
|
||||
center_crop=False,
|
||||
set="train",
|
||||
)
|
||||
|
||||
def create_dataloader(train_batch_size=1):
|
||||
return torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=train_batch_size, shuffle=True
|
||||
)
|
||||
|
||||
# Create noise_scheduler for training
|
||||
noise_scheduler = DDPMScheduler.from_config(
|
||||
args.hf_model_id, subfolder="scheduler"
|
||||
)
|
||||
|
||||
######## Training ###########
|
||||
|
||||
# Define hyperparameters for our training. If you are not happy with your results,
|
||||
# you can tune the `learning_rate` and the `max_train_steps`
|
||||
|
||||
# Setting up all training args
|
||||
hyperparameters = {
|
||||
"learning_rate": 5e-04,
|
||||
"scale_lr": True,
|
||||
"max_train_steps": steps,
|
||||
"train_batch_size": batch_size,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_checkpointing": True,
|
||||
"mixed_precision": "fp16",
|
||||
"seed": 42,
|
||||
"output_dir": "sd-concept-output",
|
||||
}
|
||||
# creating output directory
|
||||
cwd = os.getcwd()
|
||||
out_dir = os.path.join(cwd, hyperparameters["output_dir"])
|
||||
while not os.path.exists(str(out_dir)):
|
||||
try:
|
||||
os.mkdir(out_dir)
|
||||
except OSError as error:
|
||||
print("Output directory not created")
|
||||
|
||||
###### Torch-MLIR Compilation ######
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
return len(node_arg) == 0
|
||||
return False
|
||||
|
||||
def transform_fx(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
@make_simple_dynamo_backend
|
||||
def refbackend_torchdynamo_backend(
|
||||
fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
|
||||
):
|
||||
# handling usage of empty tensor without initializing
|
||||
transform_fx(fx_graph)
|
||||
fx_graph.recompile()
|
||||
if _returns_nothing(fx_graph):
|
||||
return fx_graph
|
||||
removed_none_indexes = _remove_nones(fx_graph)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
|
||||
|
||||
mlir_module = torch_mlir.compile(
|
||||
fx_graph, example_inputs, output_type="linalg-on-tensors"
|
||||
)
|
||||
|
||||
bytecode_stream = BytesIO()
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=args.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
def compiled_callable(*inputs):
|
||||
inputs = [x.numpy() for x in inputs]
|
||||
result = shark_module("forward", inputs)
|
||||
if was_unwrapped:
|
||||
result = [
|
||||
result,
|
||||
]
|
||||
if not isinstance(result, list):
|
||||
result = torch.from_numpy(result)
|
||||
else:
|
||||
result = tuple(torch.from_numpy(x) for x in result)
|
||||
result = list(result)
|
||||
for removed_index in removed_none_indexes:
|
||||
result.insert(removed_index, None)
|
||||
result = tuple(result)
|
||||
return result
|
||||
|
||||
return compiled_callable
|
||||
|
||||
def predictions(torch_func, jit_func, batchA, batchB):
|
||||
res = jit_func(batchA.numpy(), batchB.numpy())
|
||||
if res is not None:
|
||||
# prediction = torch.from_numpy(res)
|
||||
prediction = res
|
||||
else:
|
||||
prediction = None
|
||||
return prediction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
train_batch_size = hyperparameters["train_batch_size"]
|
||||
gradient_accumulation_steps = hyperparameters[
|
||||
"gradient_accumulation_steps"
|
||||
]
|
||||
learning_rate = hyperparameters["learning_rate"]
|
||||
if hyperparameters["scale_lr"]:
|
||||
learning_rate = (
|
||||
learning_rate
|
||||
* gradient_accumulation_steps
|
||||
* train_batch_size
|
||||
# * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = torch.optim.AdamW(
|
||||
lora_layers.parameters(), # only optimize the embeddings
|
||||
lr=learning_rate,
|
||||
)
|
||||
|
||||
# Training function
|
||||
def train_func(batch_pixel_values, batch_input_ids):
|
||||
# Convert images to latent space
|
||||
latents = shark_vae(batch_pixel_values).sample().detach()
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
noise_scheduler.num_train_timesteps,
|
||||
(bsz,),
|
||||
device=latents.device,
|
||||
).long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch_input_ids)[0]
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = shark_unet(
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states,
|
||||
)
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
|
||||
)
|
||||
|
||||
loss = (
|
||||
F.mse_loss(noise_pred, target, reduction="none")
|
||||
.mean([1, 2, 3])
|
||||
.mean()
|
||||
)
|
||||
loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
return loss
|
||||
|
||||
def training_function():
|
||||
max_train_steps = hyperparameters["max_train_steps"]
|
||||
output_dir = hyperparameters["output_dir"]
|
||||
gradient_checkpointing = hyperparameters["gradient_checkpointing"]
|
||||
|
||||
train_dataloader = create_dataloader(train_batch_size)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(
|
||||
len(train_dataloader) / gradient_accumulation_steps
|
||||
)
|
||||
num_train_epochs = math.ceil(
|
||||
max_train_steps / num_update_steps_per_epoch
|
||||
)
|
||||
|
||||
# Train!
|
||||
total_batch_size = (
|
||||
train_batch_size
|
||||
* gradient_accumulation_steps
|
||||
# train_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
||||
)
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(
|
||||
f" Instantaneous batch size per device = {train_batch_size}"
|
||||
)
|
||||
logger.info(
|
||||
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
||||
)
|
||||
logger.info(
|
||||
f" Gradient Accumulation steps = {gradient_accumulation_steps}"
|
||||
)
|
||||
logger.info(f" Total optimization steps = {max_train_steps}")
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(
|
||||
# range(max_train_steps), disable=not accelerator.is_local_main_process
|
||||
range(max_train_steps)
|
||||
)
|
||||
progress_bar.set_description("Steps")
|
||||
global_step = 0
|
||||
|
||||
params__ = [
|
||||
i for i in text_encoder.get_input_embeddings().parameters()
|
||||
]
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
unet.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
dynamo_callable = dynamo.optimize(
|
||||
refbackend_torchdynamo_backend
|
||||
)(train_func)
|
||||
lam_func = lambda x, y: dynamo_callable(
|
||||
torch.from_numpy(x), torch.from_numpy(y)
|
||||
)
|
||||
loss = predictions(
|
||||
train_func,
|
||||
lam_func,
|
||||
batch["pixel_values"],
|
||||
batch["input_ids"],
|
||||
)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
logs = {"loss": loss.detach().item()}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= max_train_steps:
|
||||
break
|
||||
|
||||
training_function()
|
||||
|
||||
# Save the lora weights
|
||||
unet.save_attn_procs(args.lora_save_dir)
|
||||
|
||||
for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
|
||||
if param.grad is not None:
|
||||
del param.grad # free some memory
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
if len(args.prompts) != 1:
|
||||
print("Need exactly one prompt for the LoRA word")
|
||||
lora_train(
|
||||
args.prompts[0],
|
||||
args.height,
|
||||
args.width,
|
||||
args.training_steps,
|
||||
args.guidance_scale,
|
||||
args.seed,
|
||||
args.batch_count,
|
||||
args.batch_size,
|
||||
args.scheduler,
|
||||
"None",
|
||||
args.hf_model_id,
|
||||
args.precision,
|
||||
args.device,
|
||||
args.max_length,
|
||||
args.training_images_dir,
|
||||
args.lora_save_dir,
|
||||
)
|
||||
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import transformers
|
||||
import time
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
@@ -9,10 +10,9 @@ from apps.stable_diffusion.src import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
schedulers = None
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
@@ -43,11 +43,13 @@ def txt2img_inf(
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
global schedulers
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
@@ -56,10 +58,6 @@ def txt2img_inf(
|
||||
args.scheduler = scheduler
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
@@ -77,14 +75,9 @@ def txt2img_inf(
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
use_lora = ""
|
||||
if lora_weights == "None" and not lora_hf_id:
|
||||
use_lora = ""
|
||||
elif not lora_hf_id:
|
||||
use_lora = lora_weights
|
||||
else:
|
||||
use_lora = lora_hf_id
|
||||
args.use_lora = use_lora
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
@@ -98,7 +91,7 @@ def txt2img_inf(
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=use_lora,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
)
|
||||
if (
|
||||
@@ -108,6 +101,7 @@ def txt2img_inf(
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.precision = precision
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
@@ -123,8 +117,8 @@ def txt2img_inf(
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
Text2ImagePipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
@@ -141,17 +135,18 @@ def txt2img_inf(
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=use_lora,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_schedulers(schedulers[scheduler])
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
@@ -169,28 +164,23 @@ def txt2img_inf(
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += (
|
||||
f"\nsteps={steps}, guidance_scale={guidance_scale}, seed={seeds}"
|
||||
)
|
||||
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
|
||||
# text_output += txt2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
|
||||
yield generated_imgs, text_output
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
@@ -200,7 +190,6 @@ if __name__ == "__main__":
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
use_lora = args.use_lora
|
||||
txt2img_obj = Text2ImagePipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
@@ -216,7 +205,8 @@ if __name__ == "__main__":
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=use_lora,
|
||||
use_lora=args.use_lora,
|
||||
use_quantize=args.use_quantize,
|
||||
)
|
||||
|
||||
for current_batch in range(args.batch_count):
|
||||
@@ -256,3 +246,7 @@ if __name__ == "__main__":
|
||||
|
||||
save_output_img(generated_imgs[0], seed)
|
||||
print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
UpscalerPipeline,
|
||||
@@ -12,8 +13,6 @@ from apps.stable_diffusion.src import (
|
||||
)
|
||||
|
||||
|
||||
schedulers = None
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
@@ -41,33 +40,28 @@ def upscaler_inf(
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
global schedulers
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.seed = seed
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
args.height = height
|
||||
args.width = width
|
||||
|
||||
if init_image is None:
|
||||
return None, "An Initial Image is required"
|
||||
image = init_image.convert("RGB").resize((args.height, args.width))
|
||||
image = init_image.convert("RGB").resize((height, width))
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
@@ -85,8 +79,14 @@ def upscaler_inf(
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
args.height = 128
|
||||
args.width = 128
|
||||
new_config_obj = Config(
|
||||
"upscaler",
|
||||
args.hf_model_id,
|
||||
@@ -94,10 +94,10 @@ def upscaler_inf(
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
args.height,
|
||||
args.width,
|
||||
device,
|
||||
use_lora=None,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
)
|
||||
if (
|
||||
@@ -118,8 +118,8 @@ def upscaler_inf(
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
UpscalerPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
@@ -135,11 +135,14 @@ def upscaler_inf(
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_schedulers(schedulers[scheduler])
|
||||
global_obj.get_sd_obj().low_res_scheduler = schedulers["DDPM"]
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
global_obj.get_sd_obj().low_res_scheduler = global_obj.get_scheduler(
|
||||
"DDPM"
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
@@ -150,24 +153,32 @@ def upscaler_inf(
|
||||
for current_batch in range(batch_count):
|
||||
if current_batch > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
noise_level,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed, extra_info)
|
||||
generated_imgs.extend(out_imgs)
|
||||
low_res_img = image
|
||||
high_res_img = Image.new("RGB", (height * 4, width * 4))
|
||||
|
||||
for i in range(0, width, 128):
|
||||
for j in range(0, height, 128):
|
||||
box = (j, i, j + 128, i + 128)
|
||||
upscaled_image = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
low_res_img.crop(box),
|
||||
batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
steps,
|
||||
noise_level,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
|
||||
|
||||
save_output_img(high_res_img, img_seed, extra_info)
|
||||
generated_imgs.append(high_res_img)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
@@ -224,6 +235,7 @@ if __name__ == "__main__":
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_lora=args.use_lora,
|
||||
ddpm_scheduler=schedulers["DDPM"],
|
||||
)
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ datas += copy_metadata('omegaconf')
|
||||
datas += copy_metadata('safetensors')
|
||||
datas += collect_data_files('diffusers')
|
||||
datas += collect_data_files('transformers')
|
||||
datas += collect_data_files('pytorch_lightning')
|
||||
datas += collect_data_files('opencv-python')
|
||||
datas += collect_data_files('skimage')
|
||||
datas += collect_data_files('gradio')
|
||||
|
||||
@@ -22,6 +22,7 @@ datas += copy_metadata('safetensors')
|
||||
datas += collect_data_files('diffusers')
|
||||
datas += collect_data_files('transformers')
|
||||
datas += collect_data_files('opencv-python')
|
||||
datas += collect_data_files('pytorch_lightning')
|
||||
datas += collect_data_files('skimage')
|
||||
datas += collect_data_files('gradio')
|
||||
datas += collect_data_files('iree')
|
||||
@@ -42,7 +43,7 @@ hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['scripts/txt2img.py'],
|
||||
['scripts/main.py'],
|
||||
pathex=['.'],
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
|
||||
@@ -18,6 +18,7 @@ from apps.stable_diffusion.src.utils import (
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
get_stencil_model_id,
|
||||
update_lora_weight,
|
||||
)
|
||||
|
||||
|
||||
@@ -99,7 +100,8 @@ class SharkifyStableDiffusionModel:
|
||||
is_inpaint: bool = False,
|
||||
is_upscaler: bool = False,
|
||||
use_stencil: str = None,
|
||||
use_lora: str = ""
|
||||
use_lora: str = "",
|
||||
use_quantize: str = None,
|
||||
):
|
||||
self.check_params(max_len, width, height)
|
||||
self.max_len = max_len
|
||||
@@ -107,6 +109,7 @@ class SharkifyStableDiffusionModel:
|
||||
self.width = width // 8
|
||||
self.batch_size = batch_size
|
||||
self.custom_weights = custom_weights
|
||||
self.use_quantize = use_quantize
|
||||
if custom_weights != "":
|
||||
assert custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
@@ -200,6 +203,7 @@ class SharkifyStableDiffusionModel:
|
||||
use_tuned=self.use_tuned,
|
||||
model_name=self.model_name["vae_encode"],
|
||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_vae_encode
|
||||
|
||||
@@ -255,6 +259,7 @@ class SharkifyStableDiffusionModel:
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
@@ -281,13 +286,14 @@ class SharkifyStableDiffusionModel:
|
||||
use_tuned=self.use_tuned,
|
||||
model_name=self.model_name["vae"],
|
||||
extra_args=get_opt_flags("vae", precision="fp32"),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
def get_controlled_unet(self):
|
||||
class ControlledUnetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, model_id=self.model_id, low_cpu_mem_usage=False
|
||||
self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
|
||||
):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
@@ -295,6 +301,8 @@ class SharkifyStableDiffusionModel:
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(self.unet, use_lora, "unet")
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
@@ -333,6 +341,7 @@ class SharkifyStableDiffusionModel:
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_controlled_unet
|
||||
|
||||
@@ -386,6 +395,7 @@ class SharkifyStableDiffusionModel:
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_cnet
|
||||
|
||||
@@ -399,7 +409,7 @@ class SharkifyStableDiffusionModel:
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if use_lora != "":
|
||||
self.unet.load_attn_procs(use_lora)
|
||||
update_lora_weight(self.unet, use_lora, "unet")
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
if(args.attention_slicing is not None and args.attention_slicing != "none"):
|
||||
@@ -444,6 +454,7 @@ class SharkifyStableDiffusionModel:
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
@@ -481,18 +492,21 @@ class SharkifyStableDiffusionModel:
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
def get_clip(self):
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
|
||||
super().__init__()
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="text_encoder",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(self.text_encoder, use_lora, "text_encoder")
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
@@ -512,6 +526,7 @@ class SharkifyStableDiffusionModel:
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("clip", precision="fp32"),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
@@ -539,6 +554,7 @@ class SharkifyStableDiffusionModel:
|
||||
# Compiles Clip, Unet and Vae with `base_model_id` as defining their input
|
||||
# configiration.
|
||||
def compile_all(self, base_model_id, need_vae_encode, need_stencil):
|
||||
self.base_model_id = base_model_id
|
||||
self.inputs = get_input_info(
|
||||
base_models[base_model_id],
|
||||
self.max_len,
|
||||
@@ -556,7 +572,12 @@ class SharkifyStableDiffusionModel:
|
||||
compiled_controlnet = self.get_control_net()
|
||||
compiled_controlled_unet = self.get_controlled_unet()
|
||||
else:
|
||||
compiled_unet = self.get_unet()
|
||||
# TODO: Plug the experimental "int8" support at right place.
|
||||
if self.use_quantize == "int8":
|
||||
from apps.stable_diffusion.src.models.opt_params import get_unet
|
||||
compiled_unet = get_unet()
|
||||
else:
|
||||
compiled_unet = self.get_unet()
|
||||
if self.custom_vae != "":
|
||||
print("Plugging in custom Vae")
|
||||
compiled_vae = self.get_vae()
|
||||
|
||||
@@ -20,6 +20,15 @@ hf_model_variant_map = {
|
||||
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
|
||||
}
|
||||
|
||||
# TODO: Add the quantized model as a part model_db.json.
|
||||
# This is currently in experimental phase.
|
||||
def get_quantize_model():
|
||||
bucket_key = "gs://shark_tank/prashant_nod"
|
||||
model_key = "unet_int8"
|
||||
iree_flags = get_opt_flags("unet", precision="fp16")
|
||||
if args.height != 512 and args.width != 512 and args.max_length != 77:
|
||||
sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
|
||||
return bucket_key, model_key, iree_flags
|
||||
|
||||
def get_variant_version(hf_model_id):
|
||||
return hf_model_variant_map[hf_model_id]
|
||||
@@ -41,6 +50,12 @@ def get_unet():
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
|
||||
# TODO: Get the quantize model from model_db.json
|
||||
if args.use_quantize == "int8":
|
||||
bk, mk, flags = get_quantize_model()
|
||||
return get_shark_model(bk, mk, flags)
|
||||
|
||||
if "vulkan" not in args.device and args.use_tuned:
|
||||
bucket_key = f"{variant}/{is_tuned}/{args.device}"
|
||||
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
|
||||
|
||||
@@ -31,6 +31,9 @@ from apps.stable_diffusion.src.utils import (
|
||||
end_profiling,
|
||||
)
|
||||
|
||||
SD_STATE_IDLE = "idle"
|
||||
SD_STATE_CANCEL = "cancel"
|
||||
|
||||
|
||||
class StableDiffusionPipeline:
|
||||
def __init__(
|
||||
@@ -58,6 +61,7 @@ class StableDiffusionPipeline:
|
||||
self.scheduler = scheduler
|
||||
# TODO: Implement using logging python utility.
|
||||
self.log = ""
|
||||
self.status = SD_STATE_IDLE
|
||||
|
||||
def encode_prompts(self, prompts, neg_prompts, max_length):
|
||||
# Tokenize text and get embeddings
|
||||
@@ -226,6 +230,7 @@ class StableDiffusionPipeline:
|
||||
masked_image_latents=None,
|
||||
return_all_latents=False,
|
||||
):
|
||||
self.status = SD_STATE_IDLE
|
||||
step_time_sum = 0
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
@@ -275,6 +280,9 @@ class StableDiffusionPipeline:
|
||||
# )
|
||||
step_time_sum += step_time
|
||||
|
||||
if self.status == SD_STATE_CANCEL:
|
||||
break
|
||||
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -313,13 +321,18 @@ class StableDiffusionPipeline:
|
||||
use_stencil: str = None,
|
||||
use_lora: str = "",
|
||||
ddpm_scheduler: DDPMScheduler = None,
|
||||
use_quantize=None,
|
||||
):
|
||||
is_inpaint = cls.__name__ in [
|
||||
"InpaintPipeline",
|
||||
"OutpaintPipeline",
|
||||
]
|
||||
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
|
||||
if import_mlir:
|
||||
if import_mlir or use_lora:
|
||||
if not import_mlir:
|
||||
print(
|
||||
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
|
||||
)
|
||||
mlir_import = SharkifyStableDiffusionModel(
|
||||
model_id,
|
||||
ckpt_loc,
|
||||
@@ -337,6 +350,7 @@ class StableDiffusionPipeline:
|
||||
is_upscaler=is_upscaler,
|
||||
use_stencil=use_stencil,
|
||||
use_lora=use_lora,
|
||||
use_quantize=use_quantize,
|
||||
)
|
||||
if cls.__name__ in [
|
||||
"Image2ImagePipeline",
|
||||
|
||||
@@ -32,4 +32,6 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
get_extended_name,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
get_generation_text_info,
|
||||
update_lora_weight,
|
||||
)
|
||||
|
||||
@@ -76,18 +76,19 @@ def load_winograd_configs():
|
||||
return winograd_config_dir
|
||||
|
||||
|
||||
def load_lower_configs():
|
||||
def load_lower_configs(base_model_id=None):
|
||||
from apps.stable_diffusion.src.models import get_variant_version
|
||||
from apps.stable_diffusion.src.utils.utils import (
|
||||
fetch_and_update_base_model_id,
|
||||
)
|
||||
|
||||
if args.ckpt_loc != "":
|
||||
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
|
||||
else:
|
||||
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
|
||||
if base_model_id == "":
|
||||
base_model_id = args.hf_model_id
|
||||
if not base_model_id:
|
||||
if args.ckpt_loc != "":
|
||||
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
|
||||
else:
|
||||
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
|
||||
if base_model_id == "":
|
||||
base_model_id = args.hf_model_id
|
||||
|
||||
variant, version = get_variant_version(base_model_id)
|
||||
|
||||
@@ -114,7 +115,14 @@ def load_lower_configs():
|
||||
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
|
||||
else:
|
||||
if not spec or spec in ["rdna3", "sm_80"]:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
|
||||
if (
|
||||
version in ["v2_1", "v2_1base"]
|
||||
and args.height == 768
|
||||
and args.width == 768
|
||||
):
|
||||
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
|
||||
|
||||
@@ -212,7 +220,7 @@ def annotate_with_lower_configs(
|
||||
return bytecode
|
||||
|
||||
|
||||
def sd_model_annotation(mlir_model, model_name):
|
||||
def sd_model_annotation(mlir_model, model_name, base_model_id=None):
|
||||
device = get_device()
|
||||
if args.annotation_model == "unet" and device == "vulkan":
|
||||
use_winograd = True
|
||||
@@ -220,7 +228,7 @@ def sd_model_annotation(mlir_model, model_name):
|
||||
winograd_model = annotate_with_winograd(
|
||||
mlir_model, winograd_config_dir, model_name
|
||||
)
|
||||
lowering_config_dir = load_lower_configs()
|
||||
lowering_config_dir = load_lower_configs(base_model_id)
|
||||
tuned_model = annotate_with_lower_configs(
|
||||
winograd_model, lowering_config_dir, model_name, use_winograd
|
||||
)
|
||||
@@ -232,7 +240,7 @@ def sd_model_annotation(mlir_model, model_name):
|
||||
)
|
||||
else:
|
||||
use_winograd = False
|
||||
lowering_config_dir = load_lower_configs()
|
||||
lowering_config_dir = load_lower_configs(base_model_id)
|
||||
tuned_model = annotate_with_lower_configs(
|
||||
mlir_model, lowering_config_dir, model_name, use_winograd
|
||||
)
|
||||
|
||||
@@ -22,6 +22,12 @@ p = argparse.ArgumentParser(
|
||||
### Stable Diffusion Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"-a",
|
||||
"--app",
|
||||
default="txt2img",
|
||||
help="which app to use, one of: txt2img, img2img, outpaint, inpaint",
|
||||
)
|
||||
p.add_argument(
|
||||
"-p",
|
||||
"--prompts",
|
||||
@@ -109,6 +115,31 @@ p.add_argument(
|
||||
help="the strength of change applied on the given input image for img2img",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Stable Diffusion Training Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--lora_save_dir",
|
||||
type=str,
|
||||
default="models/lora/",
|
||||
help="Directory to save the lora fine tuned model",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--training_images_dir",
|
||||
type=str,
|
||||
default="models/lora/training_images/",
|
||||
help="Directory containing images that are an example of the prompt",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--training_steps",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="The no. of steps to train",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Inpainting and Outpainting Params
|
||||
##############################################################################
|
||||
@@ -315,6 +346,14 @@ p.add_argument(
|
||||
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_quantize",
|
||||
type=str,
|
||||
default="none",
|
||||
help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
|
||||
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
@@ -335,7 +374,7 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_large_heap_block_size",
|
||||
default="4147483648",
|
||||
default="2073741824",
|
||||
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
|
||||
)
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
from random import randint
|
||||
import tempfile
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
@@ -21,7 +23,7 @@ from apps.stable_diffusion.src.utils.resources import opt_flags
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
import sys
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
load_pipeline_from_original_stable_diffusion_ckpt,
|
||||
download_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
|
||||
|
||||
@@ -78,7 +80,7 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
frontend="torch",
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, device=args.device, mlir_dialect="linalg"
|
||||
mlir_model, device=args.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
@@ -95,6 +97,7 @@ def compile_through_fx(
|
||||
debug=False,
|
||||
generate_vmfb=True,
|
||||
extra_args=[],
|
||||
base_model_id=None,
|
||||
):
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -116,19 +119,21 @@ def compile_through_fx(
|
||||
if use_tuned:
|
||||
if "vae" in model_name.split("_")[0]:
|
||||
args.annotation_model = "vae"
|
||||
mlir_module = sd_model_annotation(mlir_module, model_name)
|
||||
mlir_module = sd_model_annotation(
|
||||
mlir_module, model_name, base_model_id
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
|
||||
if generate_vmfb:
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
del mlir_module
|
||||
gc.collect()
|
||||
@@ -264,8 +269,9 @@ def set_init_device_flags():
|
||||
|
||||
if (
|
||||
args.precision != "fp16"
|
||||
or args.height != 512
|
||||
or args.width != 512
|
||||
or args.height not in [512, 768]
|
||||
or (args.height == 512 and args.width != 512)
|
||||
or (args.height == 768 and args.width != 768)
|
||||
or args.batch_size != 1
|
||||
or ("vulkan" not in args.device and "cuda" not in args.device)
|
||||
):
|
||||
@@ -299,6 +305,20 @@ def set_init_device_flags():
|
||||
]:
|
||||
args.use_tuned = False
|
||||
|
||||
elif (
|
||||
args.height == 768
|
||||
and args.width == 768
|
||||
and (
|
||||
base_model_id
|
||||
not in [
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
]
|
||||
or "rdna3" not in args.iree_vulkan_target_triple
|
||||
)
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
if args.use_tuned:
|
||||
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
|
||||
else:
|
||||
@@ -368,7 +388,7 @@ def get_available_devices():
|
||||
available_devices.extend(vulkan_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
available_devices.append("cpu")
|
||||
available_devices.append("device => cpu")
|
||||
return available_devices
|
||||
|
||||
|
||||
@@ -454,7 +474,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
"Loading diffusers' pipeline from original stable diffusion checkpoint"
|
||||
)
|
||||
num_in_channels = 9 if is_inpaint else 4
|
||||
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path=custom_weights,
|
||||
extract_ema=extract_ema,
|
||||
from_safetensors=from_safetensors,
|
||||
@@ -464,6 +484,115 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
print("Loading complete")
|
||||
|
||||
|
||||
def processLoRA(model, use_lora, splitting_prefix):
|
||||
state_dict = ""
|
||||
if ".safetensors" in use_lora:
|
||||
state_dict = load_file(use_lora)
|
||||
else:
|
||||
state_dict = torch.load(use_lora)
|
||||
alpha = 0.75
|
||||
visited = []
|
||||
|
||||
# directly update weight in model
|
||||
process_unet = "te" not in splitting_prefix
|
||||
for key in state_dict:
|
||||
if ".alpha" in key or key in visited:
|
||||
continue
|
||||
|
||||
curr_layer = model
|
||||
if ("text" not in key and process_unet) or (
|
||||
"text" in key and not process_unet
|
||||
):
|
||||
layer_infos = (
|
||||
key.split(".")[0].split(splitting_prefix)[-1].split("_")
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
# find the target layer
|
||||
temp_name = layer_infos.pop(0)
|
||||
while len(layer_infos) > -1:
|
||||
try:
|
||||
curr_layer = curr_layer.__getattr__(temp_name)
|
||||
if len(layer_infos) > 0:
|
||||
temp_name = layer_infos.pop(0)
|
||||
elif len(layer_infos) == 0:
|
||||
break
|
||||
except Exception:
|
||||
if len(temp_name) > 0:
|
||||
temp_name += "_" + layer_infos.pop(0)
|
||||
else:
|
||||
temp_name = layer_infos.pop(0)
|
||||
|
||||
pair_keys = []
|
||||
if "lora_down" in key:
|
||||
pair_keys.append(key.replace("lora_down", "lora_up"))
|
||||
pair_keys.append(key)
|
||||
else:
|
||||
pair_keys.append(key)
|
||||
pair_keys.append(key.replace("lora_up", "lora_down"))
|
||||
|
||||
# update weight
|
||||
if len(state_dict[pair_keys[0]].shape) == 4:
|
||||
weight_up = (
|
||||
state_dict[pair_keys[0]]
|
||||
.squeeze(3)
|
||||
.squeeze(2)
|
||||
.to(torch.float32)
|
||||
)
|
||||
weight_down = (
|
||||
state_dict[pair_keys[1]]
|
||||
.squeeze(3)
|
||||
.squeeze(2)
|
||||
.to(torch.float32)
|
||||
)
|
||||
curr_layer.weight.data += alpha * torch.mm(
|
||||
weight_up, weight_down
|
||||
).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_up = state_dict[pair_keys[0]].to(torch.float32)
|
||||
weight_down = state_dict[pair_keys[1]].to(torch.float32)
|
||||
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
|
||||
# update visited list
|
||||
for item in pair_keys:
|
||||
visited.append(item)
|
||||
return model
|
||||
|
||||
|
||||
def update_lora_weight_for_unet(unet, use_lora):
|
||||
extensions = [".bin", ".safetensors", ".pt"]
|
||||
if not any([extension in use_lora for extension in extensions]):
|
||||
# We assume if it is a HF ID with standalone LoRA weights.
|
||||
unet.load_attn_procs(use_lora)
|
||||
return unet
|
||||
|
||||
main_file_name = get_path_stem(use_lora)
|
||||
if ".bin" in use_lora:
|
||||
main_file_name += ".bin"
|
||||
elif ".safetensors" in use_lora:
|
||||
main_file_name += ".safetensors"
|
||||
elif ".pt" in use_lora:
|
||||
main_file_name += ".pt"
|
||||
else:
|
||||
sys.exit("Only .bin and .safetensors format for LoRA is supported")
|
||||
|
||||
try:
|
||||
dir_name = os.path.dirname(use_lora)
|
||||
unet.load_attn_procs(dir_name, weight_name=main_file_name)
|
||||
return unet
|
||||
except:
|
||||
return processLoRA(unet, use_lora, "lora_unet_")
|
||||
|
||||
|
||||
def update_lora_weight(model, use_lora, model_name):
|
||||
if "unet" in model_name:
|
||||
return update_lora_weight_for_unet(model, use_lora)
|
||||
try:
|
||||
return processLoRA(model, use_lora, "lora_te_")
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def load_vmfb(vmfb_path, model, precision):
|
||||
model = "vae" if "base_vae" in model or "vae_encode" in model else model
|
||||
model = "unet" if "stencil" in model else model
|
||||
@@ -629,3 +758,14 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(new_entry, f, indent=4)
|
||||
|
||||
|
||||
def get_generation_text_info(seeds, device):
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={args.height}x{args.width}, batch_count={args.batch_count}, batch_size={args.batch_size}, max_length={args.max_length}"
|
||||
|
||||
return text_output
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
import transformers
|
||||
|
||||
if sys.platform == "darwin":
|
||||
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
|
||||
@@ -62,10 +63,11 @@ from apps.stable_diffusion.web.ui import (
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
lora_train_web,
|
||||
)
|
||||
|
||||
# init global sd pipeline and config
|
||||
global_obj.init()
|
||||
global_obj._init()
|
||||
|
||||
|
||||
def register_button_click(button, selectedid, inputs, outputs):
|
||||
@@ -94,6 +96,10 @@ with gr.Blocks(
|
||||
with gr.TabItem(label="Upscaler", id=4):
|
||||
upscaler_web.render()
|
||||
|
||||
with gr.Tabs(visible=False) as experimental_tabs:
|
||||
with gr.TabItem(label="LoRA Training", id=5):
|
||||
lora_train_web.render()
|
||||
|
||||
register_button_click(
|
||||
txt2img_sendto_img2img,
|
||||
1,
|
||||
|
||||
@@ -38,3 +38,4 @@ from apps.stable_diffusion.web.ui.upscaler_ui import (
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web
|
||||
|
||||
@@ -7,93 +7,100 @@ Procedure to upgrade the dark theme:
|
||||
*/
|
||||
|
||||
:root {
|
||||
--color-accent-soft: var(--neutral-700);
|
||||
--color-background-primary: var(--neutral-950);
|
||||
--color-background-secondary: var(--neutral-900);
|
||||
--color-border-accent: var(--neutral-600);
|
||||
--color-border-primary: var(--neutral-700);
|
||||
--text-color-code-background: var(--neutral-800);
|
||||
--text-color-link-active: var(--secondary-500);
|
||||
--text-color-link: var(--secondary-500);
|
||||
--text-color-link-hover: var(--secondary-400);
|
||||
--text-color-link-visited: var(--secondary-600);
|
||||
--text-color-subdued: var(--neutral-400);
|
||||
--body-background-color: var(--color-background-primary);
|
||||
--body-background-fill: var(--background-fill-primary);
|
||||
--body-text-color: var(--neutral-100);
|
||||
--color-accent-soft: var(--neutral-700);
|
||||
--background-fill-primary: var(--neutral-950);
|
||||
--background-fill-secondary: var(--neutral-900);
|
||||
--border-color-accent: var(--neutral-600);
|
||||
--border-color-primary: var(--neutral-700);
|
||||
--link-text-color-active: var(--secondary-500);
|
||||
--link-text-color: var(--secondary-500);
|
||||
--link-text-color-hover: var(--secondary-400);
|
||||
--link-text-color-visited: var(--secondary-600);
|
||||
--body-text-color-subdued: var(--neutral-400);
|
||||
--shadow-spread: 1px;
|
||||
--block-background: var(--neutral-800);
|
||||
--block-border-color: var(--color-border-primary);
|
||||
--block-border-width: 1px;
|
||||
--block-info-color: var(--text-color-subdued);
|
||||
--block-label-background: var(--color-background-secondary);
|
||||
--block-label-border-color: var(--color-border-primary);
|
||||
--block-label-border-width: 1px;
|
||||
--block-label-color: var(--neutral-200);
|
||||
--block-shadow: none;
|
||||
--block-title-background: none;
|
||||
--block-title-border-color: none;
|
||||
--block-title-border-width: 0px;
|
||||
--block-title-color: var(--neutral-200);
|
||||
--panel-background: var(--color-background-secondary);
|
||||
--panel-border-color: var(--color-border-primary);
|
||||
--checkbox-background: var(--neutral-800);
|
||||
--checkbox-background-focus: var(--checkbox-background);
|
||||
--checkbox-background-hover: var(--checkbox-background);
|
||||
--checkbox-background-selected: var(--secondary-600);
|
||||
--block-background-fill: var(--neutral-800);
|
||||
--block-border-color: var(--border-color-primary);
|
||||
--block_border_width: None;
|
||||
--block-info-text-color: var(--body-text-color-subdued);
|
||||
--block-label-background-fill: var(--background-fill-secondary);
|
||||
--block-label-border-color: var(--border-color-primary);
|
||||
--block_label_border_width: None;
|
||||
--block-label-text-color: var(--neutral-200);
|
||||
--block_shadow: None;
|
||||
--block_title_background_fill: None;
|
||||
--block_title_border_color: None;
|
||||
--block_title_border_width: None;
|
||||
--block-title-text-color: var(--neutral-200);
|
||||
--panel-background-fill: var(--background-fill-secondary);
|
||||
--panel-border-color: var(--border-color-primary);
|
||||
--panel_border_width: None;
|
||||
--checkbox-background-color: var(--neutral-800);
|
||||
--checkbox-background-color-focus: var(--checkbox-background-color);
|
||||
--checkbox-background-color-hover: var(--checkbox-background-color);
|
||||
--checkbox-background-color-selected: var(--secondary-600);
|
||||
--checkbox-border-color: var(--neutral-700);
|
||||
--checkbox-border-color-focus: var(--secondary-500);
|
||||
--checkbox-border-color-hover: var(--neutral-600);
|
||||
--checkbox-border-color-selected: var(--secondary-600);
|
||||
--checkbox-label-background: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
|
||||
--checkbox-label-background-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
|
||||
--checkbox-label-background-selected: var(--checkbox-label-background);
|
||||
--checkbox-label-border-color: var(--color-border-primary);
|
||||
--checkbox-label-border-color-hover: var(--color-border-primary);
|
||||
--checkbox-text-color: var(--body-text-color);
|
||||
--checkbox-text-color-selected: var(--checkbox-text-color);
|
||||
--error-background: var(--color-background-primary);
|
||||
--error-border-color: var(--color-border-primary);
|
||||
--error-border-width: var(--error-border-width);
|
||||
--error-color: #ef4444;
|
||||
--input-background: var(--neutral-800);
|
||||
--input-background-focus: var(--secondary-600);
|
||||
--input-background-hover: var(--input-background);
|
||||
--input-border-color: var(--color-border-primary);
|
||||
--checkbox-border-width: var(--input-border-width);
|
||||
--checkbox-label-background-fill: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
|
||||
--checkbox-label-background-fill-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
|
||||
--checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
|
||||
--checkbox-label-border-color: var(--border-color-primary);
|
||||
--checkbox-label-border-color-hover: var(--checkbox-label-border-color);
|
||||
--checkbox-label-border-width: var(--input-border-width);
|
||||
--checkbox-label-text-color: var(--body-text-color);
|
||||
--checkbox-label-text-color-selected: var(--checkbox-label-text-color);
|
||||
--error-background-fill: var(--background-fill-primary);
|
||||
--error-border-color: var(--border-color-primary);
|
||||
--error_border_width: None;
|
||||
--error-text-color: #ef4444;
|
||||
--input-background-fill: var(--neutral-800);
|
||||
--input-background-fill-focus: var(--secondary-600);
|
||||
--input-background-fill-hover: var(--input-background-fill);
|
||||
--input-border-color: var(--border-color-primary);
|
||||
--input-border-color-focus: var(--neutral-700);
|
||||
--input-border-color-hover: var(--color-border-primary);
|
||||
--input-border-color-hover: var(--input-border-color);
|
||||
--input_border_width: None;
|
||||
--input-placeholder-color: var(--neutral-500);
|
||||
--input-shadow: var(--input-shadow);
|
||||
--input_shadow: None;
|
||||
--input-shadow-focus: 0 0 0 var(--shadow-spread) var(--neutral-700), var(--shadow-inset);
|
||||
--loader-color: var(--color-accent);
|
||||
--stat-color-background: linear-gradient(to right, var(--primary-400), var(--primary-600));
|
||||
--loader_color: None;
|
||||
--slider_color: None;
|
||||
--stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600));
|
||||
--table-border-color: var(--neutral-700);
|
||||
--table-even-background: var(--neutral-950);
|
||||
--table-odd-background: var(--neutral-900);
|
||||
--table-even-background-fill: var(--neutral-950);
|
||||
--table-odd-background-fill: var(--neutral-900);
|
||||
--table-row-focus: var(--color-accent-soft);
|
||||
--button-cancel-background: linear-gradient(to bottom right, #dc2626, #b91c1c);
|
||||
--button-cancel-background-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
|
||||
--button-border-width: var(--input-border-width);
|
||||
--button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
|
||||
--button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
|
||||
--button-cancel-border-color: #dc2626;
|
||||
--button-cancel-border-color-hover: var(--button-cancel-border-color);
|
||||
--button-cancel-text-color: white;
|
||||
--button-cancel-text-color-hover: var(--button-cancel-text-color);
|
||||
--button-primary-background: linear-gradient(to bottom right, var(--primary-600), var(--primary-700));
|
||||
--button-primary-background-hover: linear-gradient(to bottom right, var(--primary-600), var(--primary-600));
|
||||
--button-primary-border-color: var(--primary-600);
|
||||
--button-primary-background-fill: linear-gradient(to bottom right, var(--primary-500), var(--primary-600));
|
||||
--button-primary-background-fill-hover: linear-gradient(to bottom right, var(--primary-500), var(--primary-500));
|
||||
--button-primary-border-color: var(--primary-500);
|
||||
--button-primary-border-color-hover: var(--button-primary-border-color);
|
||||
--button-primary-text-color: white;
|
||||
--button-primary-text-color-hover: var(--button-primary-text-color);
|
||||
--button-secondary-background: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700));
|
||||
--button-secondary-background-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600));
|
||||
--button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700));
|
||||
--button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600));
|
||||
--button-secondary-border-color: var(--neutral-600);
|
||||
--button-secondary-border-color-hover: var(--button-secondary-border-color);
|
||||
--button-secondary-text-color: white;
|
||||
--button-secondary-text-color-hover: var(--button-secondary-text-color);
|
||||
--block-border-width: 1px;
|
||||
--block-label-border-width: 1px;
|
||||
--form-gap-width: 1px;
|
||||
--error-border-width: 1px;
|
||||
--input-border-width: 1px;
|
||||
}
|
||||
|
||||
/* SHARK theme */
|
||||
body {
|
||||
background-color: var(--color-background-primary);
|
||||
}
|
||||
|
||||
/* display in full width for desktop devices */
|
||||
@media (min-width: 1536px)
|
||||
@@ -131,7 +138,7 @@ body {
|
||||
}
|
||||
|
||||
#prompt_box textarea, #negative_prompt_box textarea {
|
||||
background-color: var(--color-background-primary) !important;
|
||||
background-color: var(--background-fill-primary) !important;
|
||||
}
|
||||
|
||||
#prompt_examples {
|
||||
@@ -143,7 +150,6 @@ body {
|
||||
}
|
||||
|
||||
#ui_body {
|
||||
background-color: var(--color-background-secondary) !important;
|
||||
padding: var(--size-2) !important;
|
||||
border-radius: 0.5em !important;
|
||||
}
|
||||
@@ -172,6 +178,7 @@ footer {
|
||||
|
||||
/* Hide "remove buttons" from ui dropdowns */
|
||||
#custom_model .token-remove.remove-all,
|
||||
#lora_weights .token-remove.remove-all,
|
||||
#scheduler .token-remove.remove-all,
|
||||
#device .token-remove.remove-all,
|
||||
#stencil_model .token-remove.remove-all {
|
||||
|
||||
@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_files,
|
||||
scheduler_list,
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
|
||||
|
||||
@@ -73,6 +74,21 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
value="None",
|
||||
choices=["None", "canny", "openpose", "scribble"],
|
||||
)
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
@@ -129,21 +145,23 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
label="Denoising Strength",
|
||||
)
|
||||
with gr.Row():
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
@@ -153,6 +171,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
@@ -174,8 +193,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Column(scale=1, min_width=150):
|
||||
clear_queue = gr.Button("Clear Queue")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -228,6 +245,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
use_stencil,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
],
|
||||
outputs=[img2img_gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
@@ -236,6 +255,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
clear_queue.click(
|
||||
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_files,
|
||||
scheduler_list,
|
||||
predefined_paint_models,
|
||||
cancel_sd,
|
||||
)
|
||||
|
||||
|
||||
@@ -68,6 +69,21 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
type="pil",
|
||||
).style(height=350)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
@@ -131,21 +147,23 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
with gr.Row():
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
@@ -155,6 +173,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
@@ -176,8 +195,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Column(scale=1, min_width=150):
|
||||
clear_queue = gr.Button("Clear Queue")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -230,6 +247,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
],
|
||||
outputs=[inpaint_gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
@@ -238,6 +257,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
clear_queue.click(
|
||||
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
205
apps/stable_diffusion/web/ui/lora_train_ui.py
Normal file
205
apps/stable_diffusion/web/ui/lora_train_ui.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import lora_train
|
||||
from apps.stable_diffusion.src import prompt_examples, args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
get_custom_model_path,
|
||||
get_custom_model_files,
|
||||
scheduler_list_txt2img,
|
||||
predefined_models,
|
||||
)
|
||||
|
||||
with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="image_dir_box_outer"):
|
||||
training_images_dir = gr.Textbox(
|
||||
label="ImageDirectory",
|
||||
value=args.training_images_dir,
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=scheduler_list_txt2img,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384, 768, value=args.height, step=8, label="Height"
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 768, value=args.width, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1,
|
||||
2000,
|
||||
value=args.training_steps,
|
||||
step=1,
|
||||
label="Training Steps",
|
||||
)
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
train_lora = gr.Button("Train LoRA")
|
||||
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
examples=prompt_examples,
|
||||
inputs=prompt,
|
||||
cache_examples=False,
|
||||
elem_id="prompt_examples",
|
||||
)
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
lines=1,
|
||||
show_label=False,
|
||||
)
|
||||
lora_save_dir = (
|
||||
args.lora_save_dir if args.lora_save_dir else Path.cwd()
|
||||
)
|
||||
lora_save_dir = Path(lora_save_dir, "lora")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Lora at",
|
||||
value=lora_save_dir,
|
||||
)
|
||||
|
||||
kwargs = dict(
|
||||
fn=lora_train,
|
||||
inputs=[
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
training_images_dir,
|
||||
output_loc,
|
||||
],
|
||||
outputs=[std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
train_click = train_lora.click(**kwargs)
|
||||
stop_batch.click(fn=None, cancels=[prompt_submit, train_click])
|
||||
@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_files,
|
||||
scheduler_list,
|
||||
predefined_paint_models,
|
||||
cancel_sd,
|
||||
)
|
||||
|
||||
|
||||
@@ -65,6 +66,21 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
label="Input Image", type="pil"
|
||||
).style(height=300)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
@@ -150,21 +166,23 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
1, 100, value=20, step=1, label="Steps"
|
||||
)
|
||||
with gr.Row():
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
@@ -174,6 +192,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
@@ -195,8 +214,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Column(scale=1, min_width=150):
|
||||
clear_queue = gr.Button("Clear Queue")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -250,6 +267,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
],
|
||||
outputs=[outpaint_gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
@@ -258,6 +277,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
clear_queue.click(
|
||||
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_files,
|
||||
scheduler_list_txt2img,
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
|
||||
with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
@@ -69,15 +70,13 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
lines=1,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
with gr.Accordion(
|
||||
label="Lora based inference option", open=False
|
||||
):
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path()})",
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files(),
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
@@ -142,22 +141,25 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Row():
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
@@ -179,8 +181,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Column(scale=1, min_width=150):
|
||||
clear_queue = gr.Button("Clear Queue")
|
||||
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
@@ -249,8 +249,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
clear_queue.click(
|
||||
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
from apps.stable_diffusion.web.utils.png_metadata import (
|
||||
|
||||
@@ -65,6 +65,21 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
label="Input Image", type="pil"
|
||||
).style(height=300)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
@@ -88,18 +103,16 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
height = gr.Slider(
|
||||
128,
|
||||
512,
|
||||
value=128,
|
||||
value=args.height,
|
||||
step=128,
|
||||
label="Height",
|
||||
interactive=False,
|
||||
)
|
||||
width = gr.Slider(
|
||||
128,
|
||||
512,
|
||||
value=128,
|
||||
value=args.width,
|
||||
step=128,
|
||||
label="Width",
|
||||
interactive=False,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
@@ -109,7 +122,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
interactive=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
@@ -132,21 +144,23 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
label="Noise Level",
|
||||
)
|
||||
with gr.Row():
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
@@ -156,6 +170,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
@@ -177,8 +192,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Column(scale=1, min_width=150):
|
||||
clear_queue = gr.Button("Clear Queue")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -228,6 +241,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
],
|
||||
outputs=[upscaler_gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
@@ -236,6 +251,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
clear_queue.click(
|
||||
stop_batch.click(
|
||||
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
|
||||
)
|
||||
|
||||
@@ -5,6 +5,10 @@ import glob
|
||||
from pathlib import Path
|
||||
from apps.stable_diffusion.src import args
|
||||
from dataclasses import dataclass
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -70,24 +74,57 @@ def resource_path(relative_path):
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
def get_custom_model_path():
|
||||
return Path(args.ckpt_dir) if args.ckpt_dir else Path(Path.cwd(), "models")
|
||||
def get_custom_model_path(model="models"):
|
||||
prefix = Path(args.ckpt_dir) if args.ckpt_dir else Path(Path.cwd())
|
||||
match model:
|
||||
case "models":
|
||||
return Path(prefix, "models")
|
||||
case "vae":
|
||||
return Path(prefix, "models/vae")
|
||||
case "lora":
|
||||
return Path(prefix, "models/lora")
|
||||
case _:
|
||||
return ""
|
||||
|
||||
|
||||
def get_custom_model_pathfile(custom_model_name):
|
||||
return os.path.join(get_custom_model_path(), custom_model_name)
|
||||
def get_custom_model_pathfile(custom_model_name, model="models"):
|
||||
return os.path.join(get_custom_model_path(model), custom_model_name)
|
||||
|
||||
|
||||
def get_custom_model_files():
|
||||
def get_custom_model_files(model="models"):
|
||||
ckpt_files = []
|
||||
for extn in custom_model_filetypes:
|
||||
file_types = custom_model_filetypes
|
||||
if model == "lora":
|
||||
file_types = custom_model_filetypes + ("*.pt", "*.bin")
|
||||
for extn in file_types:
|
||||
files = [
|
||||
os.path.basename(x)
|
||||
for x in glob.glob(os.path.join(get_custom_model_path(), extn))
|
||||
for x in glob.glob(
|
||||
os.path.join(get_custom_model_path(model), extn)
|
||||
)
|
||||
]
|
||||
ckpt_files.extend(files)
|
||||
return sorted(ckpt_files, key=str.casefold)
|
||||
|
||||
|
||||
def get_custom_vae_or_lora_weights(weights, hf_id, model):
|
||||
use_weight = ""
|
||||
if weights == "None" and not hf_id:
|
||||
use_weight = ""
|
||||
elif not hf_id:
|
||||
use_weight = get_custom_model_pathfile(weights, model)
|
||||
else:
|
||||
use_weight = hf_id
|
||||
return use_weight
|
||||
|
||||
|
||||
def cancel_sd():
|
||||
# Try catch it, as gc can delete global_obj.sd_obj while switching model
|
||||
try:
|
||||
global_obj.set_sd_status(SD_STATE_CANCEL)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
available_devices = get_available_devices()
|
||||
|
||||
@@ -8,39 +8,64 @@ Also we could avoid memory leak when switching models by clearing the cache.
|
||||
"""
|
||||
|
||||
|
||||
def init():
|
||||
global sd_obj
|
||||
global config_obj
|
||||
sd_obj = None
|
||||
config_obj = None
|
||||
def _init():
|
||||
global _sd_obj
|
||||
global _config_obj
|
||||
global _schedulers
|
||||
_sd_obj = None
|
||||
_config_obj = None
|
||||
_schedulers = None
|
||||
|
||||
|
||||
def set_sd_obj(value):
|
||||
global sd_obj
|
||||
sd_obj = value
|
||||
global _sd_obj
|
||||
_sd_obj = value
|
||||
|
||||
|
||||
def set_sd_scheduler(key):
|
||||
global _sd_obj
|
||||
_sd_obj.scheduler = _schedulers[key]
|
||||
|
||||
|
||||
def set_sd_status(value):
|
||||
global _sd_obj
|
||||
_sd_obj.status = value
|
||||
|
||||
|
||||
def set_cfg_obj(value):
|
||||
global config_obj
|
||||
config_obj = value
|
||||
global _config_obj
|
||||
_config_obj = value
|
||||
|
||||
|
||||
def set_schedulers(value):
|
||||
global sd_obj
|
||||
sd_obj.scheduler = value
|
||||
global _schedulers
|
||||
_schedulers = value
|
||||
|
||||
|
||||
def get_sd_obj():
|
||||
return sd_obj
|
||||
return _sd_obj
|
||||
|
||||
|
||||
def get_sd_status():
|
||||
return _sd_obj.status
|
||||
|
||||
|
||||
def get_cfg_obj():
|
||||
return config_obj
|
||||
return _config_obj
|
||||
|
||||
|
||||
def get_scheduler(key):
|
||||
return _schedulers[key]
|
||||
|
||||
|
||||
def clear_cache():
|
||||
global sd_obj
|
||||
global config_obj
|
||||
del sd_obj
|
||||
del config_obj
|
||||
global _sd_obj
|
||||
global _config_obj
|
||||
global _schedulers
|
||||
del _sd_obj
|
||||
del _config_obj
|
||||
del _schedulers
|
||||
gc.collect()
|
||||
_sd_obj = None
|
||||
_config_obj = None
|
||||
_schedulers = None
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
|
||||
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
|
||||
source $GITHUB_WORKSPACE/shark.venv/bin/activate
|
||||
python generate_sharktank.py
|
||||
python tank/generate_sharktank.py
|
||||
|
||||
@@ -87,11 +87,22 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"dreamlike-art/dreamlike-diffusion-1.0",
|
||||
]
|
||||
counter = 0
|
||||
for import_opt in import_options:
|
||||
for model_name in hf_model_names:
|
||||
if model_name in to_skip:
|
||||
continue
|
||||
for use_tune in tuned_options:
|
||||
if (
|
||||
model_name == "stabilityai/stable-diffusion-2-1"
|
||||
and use_tune == tuned_options[0]
|
||||
):
|
||||
continue
|
||||
elif (
|
||||
model_name == "stabilityai/stable-diffusion-2-1-base"
|
||||
and use_tune == tuned_options[1]
|
||||
):
|
||||
continue
|
||||
command = (
|
||||
[
|
||||
executable, # executable is the python from the venv used to run this
|
||||
@@ -174,9 +185,23 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
else:
|
||||
print(command)
|
||||
print("failed to generate image for this configuration")
|
||||
if "2_1_base" in model_name:
|
||||
with open(dumpfile_name, "r+") as f:
|
||||
output = f.readlines()
|
||||
print("\n".join(output))
|
||||
if model_name == "CompVis/stable-diffusion-v1-4":
|
||||
print("failed a known successful model.")
|
||||
exit(1)
|
||||
if os.name == "nt":
|
||||
counter += 1
|
||||
if counter % 2 == 0:
|
||||
extra_flags.append(
|
||||
"--iree_vulkan_target_triple=rdna2-unknown-windows"
|
||||
)
|
||||
else:
|
||||
if counter != 1:
|
||||
extra_flags.remove(
|
||||
"--iree_vulkan_target_triple=rdna2-unknown-windows"
|
||||
)
|
||||
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
|
||||
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
|
||||
f.write(header)
|
||||
|
||||
@@ -70,3 +70,9 @@ def pytest_addoption(parser):
|
||||
default="./temp_dispatch_benchmarks",
|
||||
help="Directory in which dispatch benchmarks are saved.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--batchsize",
|
||||
default=1,
|
||||
type=int,
|
||||
help="Batch size for the tested model.",
|
||||
)
|
||||
|
||||
@@ -6,36 +6,16 @@ from distutils.sysconfig import get_python_lib
|
||||
import fileinput
|
||||
from pathlib import Path
|
||||
|
||||
# Diffusers 0.13.1 fails with transformers __init.py errros in BLIP. So remove it for now until we fork it
|
||||
pix2pix_init = Path(get_python_lib() + "/diffusers/__init__.py")
|
||||
for line in fileinput.input(pix2pix_init, inplace=True):
|
||||
if "Pix2Pix" in line:
|
||||
if not line.startswith("#"):
|
||||
print(f"#{line}", end="")
|
||||
else:
|
||||
print(f"{line[1:]}", end="")
|
||||
else:
|
||||
print(line, end="")
|
||||
pix2pix_init = Path(get_python_lib() + "/diffusers/pipelines/__init__.py")
|
||||
for line in fileinput.input(pix2pix_init, inplace=True):
|
||||
if "Pix2Pix" in line:
|
||||
if not line.startswith("#"):
|
||||
print(f"#{line}", end="")
|
||||
else:
|
||||
print(f"{line[1:]}", end="")
|
||||
else:
|
||||
print(line, end="")
|
||||
pix2pix_init = Path(
|
||||
get_python_lib() + "/diffusers/pipelines/stable_diffusion/__init__.py"
|
||||
# Temorary workaround for transformers/__init__.py.
|
||||
path_to_tranformers_hook = Path(
|
||||
get_python_lib()
|
||||
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
|
||||
)
|
||||
for line in fileinput.input(pix2pix_init, inplace=True):
|
||||
if "StableDiffusionPix2PixZeroPipeline" in line:
|
||||
if not line.startswith("#"):
|
||||
print(f"#{line}", end="")
|
||||
else:
|
||||
print(f"{line[1:]}", end="")
|
||||
else:
|
||||
print(line, end="")
|
||||
if path_to_tranformers_hook.is_file():
|
||||
pass
|
||||
else:
|
||||
with open(path_to_tranformers_hook, "w") as f:
|
||||
f.write("module_collection_mode = 'pyz+py'")
|
||||
|
||||
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
--pre
|
||||
|
||||
numpy>1.22.4
|
||||
torchvision
|
||||
pytorch-triton
|
||||
torchvision==0.16.0.dev20230322
|
||||
tabulate
|
||||
|
||||
tqdm
|
||||
@@ -15,8 +15,8 @@ iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tf-nightly
|
||||
keras>=2.10
|
||||
tensorflow>2.11
|
||||
keras
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
@@ -33,6 +33,7 @@ lit
|
||||
pyyaml
|
||||
python-dateutil
|
||||
sacremoses
|
||||
sentencepiece
|
||||
|
||||
# web dependecies.
|
||||
gradio
|
||||
|
||||
@@ -16,7 +16,7 @@ parameterized
|
||||
|
||||
# Add transformers, diffusers and scipy since it most commonly used
|
||||
transformers
|
||||
diffusers
|
||||
diffusers @ git+https://github.com/huggingface/diffusers@main
|
||||
scipy
|
||||
ftfy
|
||||
gradio
|
||||
@@ -25,6 +25,7 @@ omegaconf
|
||||
safetensors
|
||||
opencv-python
|
||||
scikit-image
|
||||
pytorch_lightning # for runwayml models
|
||||
|
||||
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
|
||||
pefile
|
||||
|
||||
@@ -45,7 +45,7 @@ if ($arguments -eq "--force"){
|
||||
Remove-Item .\shark.venv -Force -Recurse
|
||||
if (Test-Path .\shark.venv\) {
|
||||
Write-Host 'could not remove .\shark-venv - please try running ".\setup_venv.ps1 --force" again!'
|
||||
break
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -78,12 +78,12 @@ if (!($PyVer.length -ne 0)) {$p} # return Python --version String if py.exe is u
|
||||
if (!($PyVer -like "*3.11*") -and !($p -like "*3.11*")) # if 3.11 is not in any list
|
||||
{
|
||||
Write-Host "Please install Python 3.11 and try again"
|
||||
break
|
||||
exit 34
|
||||
}
|
||||
|
||||
Write-Host "Installing Build Dependencies"
|
||||
# make sure we really use 3.11 from list, even if it's not the default.
|
||||
if (!($PyVer.length -ne 0)) {py -3.11 -m venv .\shark.venv\}
|
||||
if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark.venv\}
|
||||
else {python -m venv .\shark.venv\}
|
||||
.\shark.venv\Scripts\activate
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
@@ -129,11 +129,11 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
TV_VERSION=${TV_VER:9:18}
|
||||
$PYTHON -m pip uninstall -y torch torchvision
|
||||
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
|
||||
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl
|
||||
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu118/torch-${TORCH_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu118/torchvision-${TV_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch + cu117."
|
||||
echo "Successfully Installed torch + cu118."
|
||||
else
|
||||
echo "Could not install torch + cu117." >&2
|
||||
echo "Could not install torch + cu118." >&2
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
@@ -35,8 +35,9 @@ def run_cmd(cmd, debug=False):
|
||||
stderr=subprocess.PIPE,
|
||||
check=True,
|
||||
)
|
||||
result_str = result.stdout.decode()
|
||||
return result_str
|
||||
stdout = result.stdout.decode()
|
||||
stderr = result.stderr.decode()
|
||||
return stdout, stderr
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e.output)
|
||||
sys.exit(f"Exiting program due to error running {cmd}")
|
||||
|
||||
@@ -90,6 +90,7 @@ def build_benchmark_args(
|
||||
benchmark_cl.append(f"--task_topology_max_group_count={num_cpus}")
|
||||
# if time_extractor:
|
||||
# benchmark_cl.append(time_extractor)
|
||||
benchmark_cl.append(f"--print_statistics=true")
|
||||
return benchmark_cl
|
||||
|
||||
|
||||
@@ -129,7 +130,8 @@ def build_benchmark_args_non_tensor_input(
|
||||
|
||||
def run_benchmark_module(benchmark_cl):
|
||||
"""
|
||||
Run benchmark command, extract result and return iteration/seconds.
|
||||
Run benchmark command, extract result and return iteration/seconds, host
|
||||
peak memory, and device peak memory.
|
||||
|
||||
# TODO: Add an example of the benchmark command.
|
||||
Input: benchmark command.
|
||||
@@ -138,15 +140,22 @@ def run_benchmark_module(benchmark_cl):
|
||||
assert os.path.exists(
|
||||
benchmark_path
|
||||
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
|
||||
bench_result = run_cmd(" ".join(benchmark_cl))
|
||||
bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))
|
||||
try:
|
||||
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
|
||||
match = regex_split.search(bench_result)
|
||||
time = float(match.group(1))
|
||||
match = regex_split.search(bench_stdout)
|
||||
time_ms = float(match.group(1))
|
||||
unit = match.group(3)
|
||||
except AttributeError:
|
||||
regex_split = re.compile("(\d+[.]*\d*)([a-zA-Z]+)")
|
||||
match = regex_split.search(bench_result)
|
||||
time = float(match.group(1))
|
||||
match = regex_split.search(bench_stdout)
|
||||
time_ms = float(match.group(1))
|
||||
unit = match.group(2)
|
||||
return 1.0 / (time * 0.001)
|
||||
iter_per_second = 1.0 / (time_ms * 0.001)
|
||||
|
||||
# Extract peak memory.
|
||||
host_regex = re.compile(r".*HOST_LOCAL:\s*([0-9]+)B peak")
|
||||
host_peak_b = int(host_regex.search(bench_stderr).group(1))
|
||||
device_regex = re.compile(r".*DEVICE_LOCAL:\s*([0-9]+)B peak")
|
||||
device_peak_b = int(device_regex.search(bench_stderr).group(1))
|
||||
return iter_per_second, host_peak_b, device_peak_b
|
||||
|
||||
@@ -52,7 +52,7 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
|
||||
# Get the iree-compiler arguments given frontend.
|
||||
def get_iree_frontend_args(frontend):
|
||||
if frontend in ["torch", "pytorch", "linalg"]:
|
||||
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
|
||||
return ["--iree-llvmcpu-target-cpu-features=host"]
|
||||
elif frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
return [
|
||||
@@ -188,21 +188,23 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
|
||||
benchmark_bash.write(" ".join(benchmark_cl))
|
||||
benchmark_bash.close()
|
||||
|
||||
benchmark_data = run_benchmark_module(benchmark_cl)
|
||||
iter_per_second, _, _ = run_benchmark_module(
|
||||
benchmark_cl
|
||||
)
|
||||
|
||||
benchmark_file = open(
|
||||
f"{bench_dir}/{d_}/{d_}_data.txt", "w+"
|
||||
)
|
||||
benchmark_file.write(f"DISPATCH: {d_}\n")
|
||||
benchmark_file.write(str(benchmark_data) + "\n")
|
||||
benchmark_file.write(str(iter_per_second) + "\n")
|
||||
benchmark_file.write(
|
||||
"SHARK BENCHMARK RESULT: "
|
||||
+ str(1 / (benchmark_data * 0.001))
|
||||
+ str(1 / (iter_per_second * 0.001))
|
||||
+ "\n"
|
||||
)
|
||||
benchmark_file.close()
|
||||
|
||||
benchmark_runtimes[d_] = 1 / (benchmark_data * 0.001)
|
||||
benchmark_runtimes[d_] = 1 / (iter_per_second * 0.001)
|
||||
|
||||
elif ".mlir" in f_ and "benchmark" not in f_:
|
||||
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
|
||||
|
||||
@@ -30,11 +30,10 @@ def get_iree_gpu_args():
|
||||
in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86", "sm_89"]
|
||||
) and (shark_args.enable_tf32 == True):
|
||||
return [
|
||||
"--iree-hal-cuda-disable-loop-nounroll-wa",
|
||||
f"--iree-hal-cuda-llvm-target-arch={sm_arch}",
|
||||
]
|
||||
else:
|
||||
return ["--iree-hal-cuda-disable-loop-nounroll-wa"]
|
||||
return []
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
|
||||
@@ -22,7 +22,8 @@ from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
|
||||
|
||||
def get_vulkan_device_name():
|
||||
vulkaninfo_dump = run_cmd("vulkaninfo").split(linesep)
|
||||
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
|
||||
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
|
||||
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
|
||||
if len(vulkaninfo_list) == 0:
|
||||
raise ValueError("No device name found in VulkanInfo!")
|
||||
|
||||
@@ -21,9 +21,17 @@ from shark.iree_utils.benchmark_utils import (
|
||||
from shark.parser import shark_args
|
||||
from datetime import datetime
|
||||
import time
|
||||
from typing import Optional
|
||||
import csv
|
||||
import os
|
||||
|
||||
TF_CPU_DEVICE = "/CPU:0"
|
||||
TF_GPU_DEVICE = "/GPU:0"
|
||||
|
||||
|
||||
def _bytes_to_mb_str(bytes_: Optional[int]) -> str:
|
||||
return "" if bytes_ is None else f"{bytes_ / 1e6:.6f}"
|
||||
|
||||
|
||||
class OnnxFusionOptions(object):
|
||||
def __init__(self):
|
||||
@@ -70,6 +78,7 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
self.vmfb_file = None
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.extra_args = extra_args
|
||||
self.import_args = {}
|
||||
SharkRunner.__init__(
|
||||
self,
|
||||
mlir_module,
|
||||
@@ -104,40 +113,56 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
|
||||
def benchmark_torch(self, modelname):
|
||||
import torch
|
||||
import torch._dynamo as dynamo
|
||||
from tank.model_utils import get_torch_model
|
||||
|
||||
if self.device == "cuda":
|
||||
torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
||||
if self.enable_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
print(
|
||||
"Currently disabled TensorFloat32 calculations in pytorch benchmarks."
|
||||
)
|
||||
# torch.backends.cuda.matmul.allow_tf32 = True
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
torch_device = torch.device(
|
||||
"cuda:0" if self.device == "cuda" else "cpu"
|
||||
)
|
||||
HFmodel, input = get_torch_model(modelname)[:2]
|
||||
HFmodel, input = get_torch_model(modelname, self.import_args)[:2]
|
||||
frontend_model = HFmodel.model
|
||||
frontend_model.to(torch_device)
|
||||
input.to(torch_device)
|
||||
|
||||
# frontend_model = torch.compile(frontend_model, mode="max-autotune", backend="inductor")
|
||||
# TODO: re-enable as soon as pytorch CUDA context issues are resolved
|
||||
# try:
|
||||
# frontend_model = torch.compile(
|
||||
# frontend_model, mode="max-autotune", backend="inductor"
|
||||
# )
|
||||
# except RuntimeError:
|
||||
# frontend_model = HFmodel.model
|
||||
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
frontend_model.forward(input)
|
||||
|
||||
if self.device == "cuda":
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = frontend_model.forward(input)
|
||||
if i == shark_args.num_iterations - 1:
|
||||
end = time.time()
|
||||
break
|
||||
end = time.time()
|
||||
if self.device == "cuda":
|
||||
stats = torch.cuda.memory_stats()
|
||||
device_peak_b = stats["allocated_bytes.all.peak"]
|
||||
else:
|
||||
device_peak_b = None
|
||||
|
||||
print(
|
||||
f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
f"{shark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/shark_args.num_iterations)*1000}",
|
||||
"", # host_peak_b (CPU usage) is not reported by PyTorch.
|
||||
_bytes_to_mb_str(device_peak_b),
|
||||
]
|
||||
|
||||
def benchmark_tf(self, modelname):
|
||||
@@ -155,38 +180,55 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
|
||||
from tank.model_utils_tf import get_tf_model
|
||||
|
||||
# tf_device = "/GPU:0" if self.device == "cuda" else "/CPU:0"
|
||||
tf_device = "/CPU:0"
|
||||
# tf_device = TF_GPU_DEVICE if self.device == "cuda" else TF_CPU_DEVICE
|
||||
tf_device = TF_CPU_DEVICE
|
||||
with tf.device(tf_device):
|
||||
(
|
||||
model,
|
||||
input,
|
||||
) = get_tf_model(
|
||||
modelname
|
||||
modelname, self.import_args
|
||||
)[:2]
|
||||
frontend_model = model
|
||||
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
frontend_model.forward(*input)
|
||||
|
||||
if tf_device == TF_GPU_DEVICE:
|
||||
tf.config.experimental.reset_memory_stats(tf_device)
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = frontend_model.forward(*input)
|
||||
if i == shark_args.num_iterations - 1:
|
||||
end = time.time()
|
||||
break
|
||||
end = time.time()
|
||||
if tf_device == TF_GPU_DEVICE:
|
||||
memory_info = tf.config.experimental.get_memory_info(tf_device)
|
||||
device_peak_b = memory_info["peak"]
|
||||
else:
|
||||
# tf.config.experimental does not currently support measuring
|
||||
# CPU memory usage.
|
||||
device_peak_b = None
|
||||
|
||||
print(
|
||||
f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
f"{shark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/shark_args.num_iterations)*1000}",
|
||||
"", # host_peak_b (CPU usage) is not reported by TensorFlow.
|
||||
_bytes_to_mb_str(device_peak_b),
|
||||
]
|
||||
|
||||
def benchmark_c(self):
|
||||
result = run_benchmark_module(self.benchmark_cl)
|
||||
print(f"Shark-IREE-C benchmark:{result} iter/second")
|
||||
return [f"{result}", f"{1000/result}"]
|
||||
iter_per_second, host_peak_b, device_peak_b = run_benchmark_module(
|
||||
self.benchmark_cl
|
||||
)
|
||||
print(f"Shark-IREE-C benchmark:{iter_per_second} iter/second")
|
||||
return [
|
||||
f"{iter_per_second}",
|
||||
f"{1000/iter_per_second}",
|
||||
_bytes_to_mb_str(host_peak_b),
|
||||
_bytes_to_mb_str(device_peak_b),
|
||||
]
|
||||
|
||||
def benchmark_python(self, inputs):
|
||||
input_list = [x for x in inputs]
|
||||
@@ -196,8 +238,7 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = self.run("forward", input_list)
|
||||
if i == shark_args.num_iterations - 1:
|
||||
end = time.time()
|
||||
end = time.time()
|
||||
print(
|
||||
f"Shark-IREE Python benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
@@ -306,11 +347,19 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
return comp_str
|
||||
|
||||
def benchmark_all_csv(
|
||||
self, inputs: tuple, modelname, dynamic, device_str, frontend
|
||||
self,
|
||||
inputs: tuple,
|
||||
modelname,
|
||||
dynamic,
|
||||
device_str,
|
||||
frontend,
|
||||
import_args,
|
||||
):
|
||||
self.setup_cl(inputs)
|
||||
self.import_args = import_args
|
||||
field_names = [
|
||||
"model",
|
||||
"batch_size",
|
||||
"engine",
|
||||
"dialect",
|
||||
"device",
|
||||
@@ -324,7 +373,12 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
"tags",
|
||||
"notes",
|
||||
"datetime",
|
||||
"host_memory_mb",
|
||||
"device_memory_mb",
|
||||
"measured_host_memory_mb",
|
||||
"measured_device_memory_mb",
|
||||
]
|
||||
# "frontend" must be the first element.
|
||||
engines = ["frontend", "shark_python", "shark_iree_c"]
|
||||
if shark_args.onnx_bench == True:
|
||||
engines.append("onnxruntime")
|
||||
@@ -336,75 +390,77 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
|
||||
with open("bench_results.csv", mode="a", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=field_names)
|
||||
bench_result = {}
|
||||
bench_result["model"] = modelname
|
||||
bench_info = {}
|
||||
bench_info["model"] = modelname
|
||||
bench_info["batch_size"] = str(import_args["batch_size"])
|
||||
bench_info["dialect"] = self.mlir_dialect
|
||||
bench_info["iterations"] = shark_args.num_iterations
|
||||
if dynamic == True:
|
||||
bench_result["shape_type"] = "dynamic"
|
||||
bench_info["shape_type"] = "dynamic"
|
||||
else:
|
||||
bench_result["shape_type"] = "static"
|
||||
bench_result["device"] = device_str
|
||||
bench_info["shape_type"] = "static"
|
||||
bench_info["device"] = device_str
|
||||
if "fp16" in modelname:
|
||||
bench_result["data_type"] = "float16"
|
||||
bench_info["data_type"] = "float16"
|
||||
else:
|
||||
bench_result["data_type"] = inputs[0].dtype
|
||||
bench_info["data_type"] = inputs[0].dtype
|
||||
|
||||
for e in engines:
|
||||
(
|
||||
bench_result["param_count"],
|
||||
bench_result["tags"],
|
||||
bench_result["notes"],
|
||||
) = ["", "", ""]
|
||||
engine_result = {}
|
||||
if e == "frontend":
|
||||
bench_result["engine"] = frontend
|
||||
engine_result["engine"] = frontend
|
||||
if check_requirements(frontend):
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
engine_result["iter/sec"],
|
||||
engine_result["ms/iter"],
|
||||
engine_result["host_memory_mb"],
|
||||
engine_result["device_memory_mb"],
|
||||
) = self.benchmark_frontend(modelname)
|
||||
self.frontend_result = bench_result["ms/iter"]
|
||||
bench_result["vs. PyTorch/TF"] = "baseline"
|
||||
self.frontend_result = engine_result["ms/iter"]
|
||||
engine_result["vs. PyTorch/TF"] = "baseline"
|
||||
(
|
||||
bench_result["param_count"],
|
||||
bench_result["tags"],
|
||||
bench_result["notes"],
|
||||
engine_result["param_count"],
|
||||
engine_result["tags"],
|
||||
engine_result["notes"],
|
||||
) = self.get_metadata(modelname)
|
||||
else:
|
||||
self.frontend_result = None
|
||||
continue
|
||||
|
||||
elif e == "shark_python":
|
||||
bench_result["engine"] = "shark_python"
|
||||
engine_result["engine"] = "shark_python"
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
engine_result["iter/sec"],
|
||||
engine_result["ms/iter"],
|
||||
) = self.benchmark_python(inputs)
|
||||
|
||||
bench_result[
|
||||
engine_result[
|
||||
"vs. PyTorch/TF"
|
||||
] = self.compare_bench_results(
|
||||
self.frontend_result, bench_result["ms/iter"]
|
||||
self.frontend_result, engine_result["ms/iter"]
|
||||
)
|
||||
|
||||
elif e == "shark_iree_c":
|
||||
bench_result["engine"] = "shark_iree_c"
|
||||
engine_result["engine"] = "shark_iree_c"
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
engine_result["iter/sec"],
|
||||
engine_result["ms/iter"],
|
||||
engine_result["host_memory_mb"],
|
||||
engine_result["device_memory_mb"],
|
||||
) = self.benchmark_c()
|
||||
|
||||
bench_result[
|
||||
engine_result[
|
||||
"vs. PyTorch/TF"
|
||||
] = self.compare_bench_results(
|
||||
self.frontend_result, bench_result["ms/iter"]
|
||||
self.frontend_result, engine_result["ms/iter"]
|
||||
)
|
||||
|
||||
elif e == "onnxruntime":
|
||||
bench_result["engine"] = "onnxruntime"
|
||||
engine_result["engine"] = "onnxruntime"
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
engine_result["iter/sec"],
|
||||
engine_result["ms/iter"],
|
||||
) = self.benchmark_onnx(modelname, inputs)
|
||||
|
||||
bench_result["dialect"] = self.mlir_dialect
|
||||
bench_result["iterations"] = shark_args.num_iterations
|
||||
bench_result["datetime"] = str(datetime.now())
|
||||
writer.writerow(bench_result)
|
||||
engine_result["datetime"] = str(datetime.now())
|
||||
writer.writerow(bench_info | engine_result)
|
||||
|
||||
@@ -139,11 +139,21 @@ def download_model(
|
||||
tank_url="gs://shark_tank/latest",
|
||||
frontend=None,
|
||||
tuned=None,
|
||||
import_args={"batch_size": "1"},
|
||||
):
|
||||
model_name = model_name.replace("/", "_")
|
||||
dyn_str = "_dynamic" if dynamic else ""
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
model_dir_name = model_name + "_" + frontend
|
||||
if import_args["batch_size"] != 1:
|
||||
model_dir_name = (
|
||||
model_name
|
||||
+ "_"
|
||||
+ frontend
|
||||
+ "_BS"
|
||||
+ str(import_args["batch_size"])
|
||||
)
|
||||
else:
|
||||
model_dir_name = model_name + "_" + frontend
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
|
||||
|
||||
@@ -194,9 +204,17 @@ def download_model(
|
||||
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
|
||||
filename = os.path.join(model_dir, model_name + suffix)
|
||||
|
||||
if not os.path.exists(filename):
|
||||
from tank.generate_sharktank import gen_shark_files
|
||||
|
||||
print(
|
||||
"The model data was not found. Trying to generate artifacts locally."
|
||||
)
|
||||
gen_shark_files(model_name, frontend, WORKDIR, import_args)
|
||||
|
||||
assert os.path.exists(filename), f"MLIR not found at {filename}"
|
||||
with open(filename, mode="rb") as f:
|
||||
mlir_file = f.read()
|
||||
|
||||
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
|
||||
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
|
||||
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
|
||||
|
||||
@@ -297,6 +297,7 @@ def transform_fx(fx_g):
|
||||
if node.target in [
|
||||
torch.ops.aten.arange,
|
||||
torch.ops.aten.empty,
|
||||
torch.ops.aten.zeros,
|
||||
]:
|
||||
node.kwargs = kwargs_dict
|
||||
# Inputs and outputs of aten.var.mean should be upcasted to fp32.
|
||||
|
||||
@@ -22,7 +22,7 @@ bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
bert-large-uncased,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
|
||||
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390",""
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
|
||||
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/344",""
|
||||
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/388","macos"
|
||||
@@ -35,3 +35,12 @@ squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","mac
|
||||
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
|
||||
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
|
||||
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/1243",""
|
||||
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,False,False,"Torchvision imports issue",""
|
||||
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
|
||||
gpt2,mhlo,tf,1e-2,1e-3,default,None,True,False,False,"",""
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.",""
|
||||
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported",""
|
||||
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
|
||||
|
@@ -70,7 +70,7 @@ if __name__ == "__main__":
|
||||
backend_config = "dylib"
|
||||
# backend = "cuda"
|
||||
# backend_config = "cuda"
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
|
||||
flatbuffer_blob = compile_str(
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
|
||||
@@ -146,7 +146,6 @@ if __name__ == "__main__":
|
||||
backend_config = "cuda"
|
||||
args = [
|
||||
"--iree-cuda-llvm-target-arch=sm_80",
|
||||
"--iree-hal-cuda-disable-loop-nounroll-wa",
|
||||
"--iree-enable-fusion-with-reduction-ops",
|
||||
]
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ if __name__ == "__main__":
|
||||
backend_config = "dylib"
|
||||
# backend = "cuda"
|
||||
# backend_config = "cuda"
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
|
||||
flatbuffer_blob = compile_str(
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
|
||||
@@ -86,7 +86,7 @@ if __name__ == "__main__":
|
||||
backend_config = "dylib"
|
||||
# backend = "cuda"
|
||||
# backend_config = "cuda"
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
|
||||
flatbuffer_blob = compile_str(
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
|
||||
@@ -33,9 +33,10 @@ def create_hash(file_name):
|
||||
return file_hash.hexdigest()
|
||||
|
||||
|
||||
def save_torch_model(torch_model_list):
|
||||
def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
from tank.model_utils import (
|
||||
get_hf_model,
|
||||
get_hf_seq2seq_model,
|
||||
get_vision_model,
|
||||
get_hf_img_cls_model,
|
||||
get_fp16_model,
|
||||
@@ -58,8 +59,7 @@ def save_torch_model(torch_model_list):
|
||||
if model_type == "stable_diffusion":
|
||||
args.use_tuned = False
|
||||
args.import_mlir = True
|
||||
args.use_tuned = False
|
||||
args.local_tank_cache = WORKDIR
|
||||
args.local_tank_cache = local_tank_cache
|
||||
|
||||
precision_values = ["fp16"]
|
||||
seq_lengths = [64, 77]
|
||||
@@ -74,24 +74,41 @@ def save_torch_model(torch_model_list):
|
||||
width=512,
|
||||
height=512,
|
||||
use_base_vae=False,
|
||||
custom_vae="",
|
||||
debug=True,
|
||||
sharktank_dir=WORKDIR,
|
||||
sharktank_dir=local_tank_cache,
|
||||
generate_vmfb=False,
|
||||
)
|
||||
model()
|
||||
continue
|
||||
if model_type == "vision":
|
||||
model, input, _ = get_vision_model(torch_model_name)
|
||||
model, input, _ = get_vision_model(
|
||||
torch_model_name, import_args
|
||||
)
|
||||
elif model_type == "hf":
|
||||
model, input, _ = get_hf_model(torch_model_name)
|
||||
model, input, _ = get_hf_model(torch_model_name, import_args)
|
||||
elif model_type == "hf_seq2seq":
|
||||
model, input, _ = get_hf_seq2seq_model(
|
||||
torch_model_name, import_args
|
||||
)
|
||||
elif model_type == "hf_img_cls":
|
||||
model, input, _ = get_hf_img_cls_model(torch_model_name)
|
||||
model, input, _ = get_hf_img_cls_model(
|
||||
torch_model_name, import_args
|
||||
)
|
||||
elif model_type == "fp16":
|
||||
model, input, _ = get_fp16_model(torch_model_name)
|
||||
model, input, _ = get_fp16_model(torch_model_name, import_args)
|
||||
torch_model_name = torch_model_name.replace("/", "_")
|
||||
torch_model_dir = os.path.join(
|
||||
WORKDIR, str(torch_model_name) + "_torch"
|
||||
)
|
||||
if import_args["batch_size"] != 1:
|
||||
torch_model_dir = os.path.join(
|
||||
local_tank_cache,
|
||||
str(torch_model_name)
|
||||
+ "_torch"
|
||||
+ f"_BS{str(import_args['batch_size'])}",
|
||||
)
|
||||
else:
|
||||
torch_model_dir = os.path.join(
|
||||
local_tank_cache, str(torch_model_name) + "_torch"
|
||||
)
|
||||
os.makedirs(torch_model_dir, exist_ok=True)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
@@ -115,12 +132,14 @@ def save_torch_model(torch_model_list):
|
||||
)
|
||||
|
||||
|
||||
def save_tf_model(tf_model_list):
|
||||
def save_tf_model(tf_model_list, local_tank_cache, import_args):
|
||||
from tank.model_utils_tf import (
|
||||
get_causal_image_model,
|
||||
get_masked_lm_model,
|
||||
get_causal_lm_model,
|
||||
get_keras_model,
|
||||
get_TFhf_model,
|
||||
get_tfhf_seq2seq_model,
|
||||
)
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -145,16 +164,38 @@ def save_tf_model(tf_model_list):
|
||||
input = None
|
||||
print(f"Generating artifacts for model {tf_model_name}")
|
||||
if model_type == "hf":
|
||||
model, input, _ = get_causal_lm_model(tf_model_name)
|
||||
if model_type == "img":
|
||||
model, input, _ = get_causal_image_model(tf_model_name)
|
||||
if model_type == "keras":
|
||||
model, input, _ = get_keras_model(tf_model_name)
|
||||
if model_type == "TFhf":
|
||||
model, input, _ = get_TFhf_model(tf_model_name)
|
||||
model, input, _ = get_masked_lm_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
elif model_type == "img":
|
||||
model, input, _ = get_causal_image_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
elif model_type == "keras":
|
||||
model, input, _ = get_keras_model(tf_model_name, import_args)
|
||||
elif model_type == "TFhf":
|
||||
model, input, _ = get_TFhf_model(tf_model_name, import_args)
|
||||
elif model_type == "tfhf_seq2seq":
|
||||
model, input, _ = get_tfhf_seq2seq_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
elif model_type == "hf_causallm":
|
||||
model, input, _ = get_causal_lm_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
|
||||
tf_model_name = tf_model_name.replace("/", "_")
|
||||
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
|
||||
if import_args["batch_size"] != 1:
|
||||
tf_model_dir = os.path.join(
|
||||
local_tank_cache,
|
||||
str(tf_model_name)
|
||||
+ "_tf"
|
||||
+ f"_BS{str(import_args['batch_size'])}",
|
||||
)
|
||||
else:
|
||||
tf_model_dir = os.path.join(
|
||||
local_tank_cache, str(tf_model_name) + "_tf"
|
||||
)
|
||||
os.makedirs(tf_model_dir, exist_ok=True)
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
@@ -166,13 +207,9 @@ def save_tf_model(tf_model_list):
|
||||
dir=tf_model_dir,
|
||||
model_name=tf_model_name,
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(tf_model_dir, tf_model_name + "_tf" + ".mlir")
|
||||
)
|
||||
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
|
||||
|
||||
|
||||
def save_tflite_model(tflite_model_list):
|
||||
def save_tflite_model(tflite_model_list, local_tank_cache, import_args):
|
||||
from shark.tflite_utils import TFLitePreprocessor
|
||||
|
||||
with open(tflite_model_list) as csvfile:
|
||||
@@ -184,18 +221,18 @@ def save_tflite_model(tflite_model_list):
|
||||
print("tflite_model_name", tflite_model_name)
|
||||
print("tflite_model_link", tflite_model_link)
|
||||
tflite_model_name_dir = os.path.join(
|
||||
WORKDIR, str(tflite_model_name) + "_tflite"
|
||||
local_tank_cache, str(tflite_model_name) + "_tflite"
|
||||
)
|
||||
os.makedirs(tflite_model_name_dir, exist_ok=True)
|
||||
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
|
||||
|
||||
# Preprocess to get SharkImporter input args
|
||||
# Preprocess to get SharkImporter input import_args
|
||||
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
|
||||
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
|
||||
inputs = tflite_preprocessor.get_inputs()
|
||||
tflite_interpreter = tflite_preprocessor.get_interpreter()
|
||||
|
||||
# Use SharkImporter to get SharkInference input args
|
||||
# Use SharkImporter to get SharkInference input import_args
|
||||
my_shark_importer = SharkImporter(
|
||||
module=tflite_interpreter,
|
||||
inputs=inputs,
|
||||
@@ -219,6 +256,71 @@ def save_tflite_model(tflite_model_list):
|
||||
)
|
||||
|
||||
|
||||
def check_requirements(frontend):
|
||||
import importlib
|
||||
|
||||
has_pkgs = False
|
||||
if frontend == "torch":
|
||||
tv_spec = importlib.util.find_spec("torchvision")
|
||||
has_pkgs = tv_spec is not None
|
||||
|
||||
elif frontend in ["tensorflow", "tf"]:
|
||||
tf_spec = importlib.util.find_spec("tensorflow")
|
||||
has_pkgs = tf_spec is not None
|
||||
|
||||
return has_pkgs
|
||||
|
||||
|
||||
class NoImportException(Exception):
|
||||
"Raised when requirements are not met for OTF model artifact generation."
|
||||
pass
|
||||
|
||||
|
||||
def gen_shark_files(modelname, frontend, tank_dir, importer_args):
|
||||
# If a model's artifacts are requested by shark_downloader but they don't exist in the cloud, we call this function to generate the artifacts on-the-fly.
|
||||
# TODO: Add TFlite support.
|
||||
import tempfile
|
||||
|
||||
import_args = importer_args
|
||||
if check_requirements(frontend):
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tf_model_list.csv"
|
||||
)
|
||||
custom_model_csv = tempfile.NamedTemporaryFile(
|
||||
dir=os.path.dirname(__file__),
|
||||
delete=True,
|
||||
)
|
||||
# Create a temporary .csv with only the desired entry.
|
||||
if frontend == "tf":
|
||||
with open(tf_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
if row[0] == modelname:
|
||||
target = row
|
||||
with open(custom_model_csv.name, mode="w") as trg:
|
||||
writer = csv.writer(trg)
|
||||
writer.writerow(["modelname", "src"])
|
||||
writer.writerow(target)
|
||||
save_tf_model(custom_model_csv.name, tank_dir, import_args)
|
||||
|
||||
elif frontend == "torch":
|
||||
with open(torch_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
if row[0] == modelname:
|
||||
target = row
|
||||
with open(custom_model_csv.name, mode="w") as trg:
|
||||
writer = csv.writer(trg)
|
||||
writer.writerow(["modelname", "src"])
|
||||
writer.writerow(target)
|
||||
save_torch_model(custom_model_csv.name, tank_dir, import_args)
|
||||
else:
|
||||
raise NoImportException
|
||||
|
||||
|
||||
# Validates whether the file is present or not.
|
||||
def is_valid_file(arg):
|
||||
if not os.path.exists(arg):
|
||||
@@ -228,7 +330,7 @@ def is_valid_file(arg):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Note, all of these flags are overridden by the import of args from stable_args.py, flags are duplicated temporarily to preserve functionality
|
||||
# Note, all of these flags are overridden by the import of import_args from stable_args.py, flags are duplicated temporarily to preserve functionality
|
||||
# parser = argparse.ArgumentParser()
|
||||
# parser.add_argument(
|
||||
# "--torch_model_csv",
|
||||
@@ -256,23 +358,26 @@ if __name__ == "__main__":
|
||||
# )
|
||||
# parser.add_argument("--upload", type=bool, default=False)
|
||||
|
||||
# old_args = parser.parse_args()
|
||||
|
||||
# old_import_args = parser.parse_import_args()
|
||||
import_args = {
|
||||
"batch_size": "1",
|
||||
}
|
||||
print(import_args)
|
||||
home = str(Path.home())
|
||||
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
|
||||
WORKDIR = os.path.join(os.path.dirname(__file__), "..", "gen_shark_tank")
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tank", "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tank", "tf_model_list.csv"
|
||||
os.path.dirname(__file__), "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
|
||||
tflite_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tank", "tflite", "tflite_model_list.csv"
|
||||
os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
|
||||
)
|
||||
|
||||
save_torch_model(
|
||||
os.path.join(os.path.dirname(__file__), "tank", "torch_sd_list.csv")
|
||||
os.path.join(os.path.dirname(__file__), "torch_sd_list.csv"),
|
||||
WORKDIR,
|
||||
import_args,
|
||||
)
|
||||
save_torch_model(torch_model_csv)
|
||||
save_tf_model(tf_model_csv)
|
||||
save_tflite_model(tflite_model_csv)
|
||||
save_torch_model(torch_model_csv, WORKDIR, import_args)
|
||||
save_tf_model(tf_model_csv, WORKDIR, import_args)
|
||||
save_tflite_model(tflite_model_csv, WORKDIR, import_args)
|
||||
@@ -31,4 +31,12 @@ xlm-roberta-base,False,False,-,-,-
|
||||
facebook/convnext-tiny-224,False,False,-,-,-
|
||||
efficientnet-v2-s,False,False,22M,"image-classification,cnn","Includes MBConv and Fused-MBConv"
|
||||
mnasnet1_0,False,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
|
||||
bert-large-uncased,True,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
t5-base,True,False,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
t5-large,True,False,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
efficientnet_b0,True,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
|
||||
efficientnet_b7,True,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
|
||||
gpt2,True,False,110M,"nlp;transformer-decoder;auto-regressive","12 layers, 768 hidden units, 12 attention heads"
|
||||
t5-base,True,False,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
t5-large,True,False,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
|
||||
|
@@ -1,5 +1,4 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
@@ -7,6 +6,8 @@ import sys
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
BATCH_SIZE = 1
|
||||
|
||||
vision_models = [
|
||||
"alexnet",
|
||||
"resnet101",
|
||||
@@ -17,6 +18,8 @@ vision_models = [
|
||||
"wide_resnet50_2",
|
||||
"mobilenet_v3_small",
|
||||
"mnasnet1_0",
|
||||
"efficientnet_b0",
|
||||
"efficientnet_b7",
|
||||
]
|
||||
hf_img_cls_models = [
|
||||
"google/vit-base-patch16-224",
|
||||
@@ -25,17 +28,23 @@ hf_img_cls_models = [
|
||||
"microsoft/beit-base-patch16-224-pt22k-ft22k",
|
||||
"nvidia/mit-b0",
|
||||
]
|
||||
hf_seq2seq_models = [
|
||||
"t5-base",
|
||||
"t5-large",
|
||||
]
|
||||
|
||||
|
||||
def get_torch_model(modelname):
|
||||
def get_torch_model(modelname, import_args):
|
||||
if modelname in vision_models:
|
||||
return get_vision_model(modelname)
|
||||
return get_vision_model(modelname, import_args)
|
||||
elif modelname in hf_img_cls_models:
|
||||
return get_hf_img_cls_model(modelname)
|
||||
return get_hf_img_cls_model(modelname, import_args)
|
||||
elif modelname in hf_seq2seq_models:
|
||||
return get_hf_seq2seq_model(modelname, import_args)
|
||||
elif "fp16" in modelname:
|
||||
return get_fp16_model(modelname)
|
||||
return get_fp16_model(modelname, import_args)
|
||||
else:
|
||||
return get_hf_model(modelname)
|
||||
return get_hf_model(modelname, import_args)
|
||||
|
||||
|
||||
##################### Hugging Face Image Classification Models ###################################
|
||||
@@ -78,13 +87,14 @@ class HuggingFaceImageClassification(torch.nn.Module):
|
||||
return self.model.forward(inputs)[0]
|
||||
|
||||
|
||||
def get_hf_img_cls_model(name):
|
||||
def get_hf_img_cls_model(name, import_args):
|
||||
model = HuggingFaceImageClassification(name)
|
||||
# you can use preprocess_input_image to get the test_input or just random value.
|
||||
test_input = preprocess_input_image(name)
|
||||
# test_input = torch.FloatTensor(1, 3, 224, 224).uniform_(-1, 1)
|
||||
# print("test_input.shape: ", test_input.shape)
|
||||
# test_input.shape: torch.Size([1, 3, 224, 224])
|
||||
test_input = test_input.repeat(int(import_args["batch_size"]), 1, 1, 1)
|
||||
actual_out = model(test_input)
|
||||
# print("actual_out.shape: ", actual_out.shape)
|
||||
# actual_out.shape: torch.Size([1, 1000])
|
||||
@@ -114,18 +124,58 @@ class HuggingFaceLanguage(torch.nn.Module):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
def get_hf_model(name):
|
||||
def get_hf_model(name, import_args):
|
||||
from transformers import (
|
||||
BertTokenizer,
|
||||
)
|
||||
|
||||
model = HuggingFaceLanguage(name)
|
||||
# TODO: Currently the test input is set to (1,128)
|
||||
test_input = torch.randint(2, (1, 128))
|
||||
test_input = torch.randint(2, (int(import_args["batch_size"]), 128))
|
||||
actual_out = model(test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
##################### Hugging Face Seq2SeqLM Models ###################################
|
||||
|
||||
# We use a maximum sequence length of 512 since this is the default used in the T5 config.
|
||||
T5_MAX_SEQUENCE_LENGTH = 512
|
||||
|
||||
|
||||
class HFSeq2SeqLanguageModel(torch.nn.Module):
|
||||
def __init__(self, model_name):
|
||||
super().__init__()
|
||||
from transformers import AutoTokenizer, T5Model
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.tokenization_kwargs = {
|
||||
"pad_to_multiple_of": T5_MAX_SEQUENCE_LENGTH,
|
||||
"padding": True,
|
||||
"return_tensors": "pt",
|
||||
}
|
||||
self.model = T5Model.from_pretrained(model_name, return_dict=True)
|
||||
|
||||
def preprocess_input(self, text):
|
||||
return self.tokenizer(text, **self.tokenization_kwargs)
|
||||
|
||||
def forward(self, input_ids, decoder_input_ids):
|
||||
return self.model.forward(
|
||||
input_ids, decoder_input_ids=decoder_input_ids
|
||||
)[0]
|
||||
|
||||
|
||||
def get_hf_seq2seq_model(name, import_args):
|
||||
m = HFSeq2SeqLanguageModel(name)
|
||||
encoded_input_ids = m.preprocess_input(
|
||||
"Studies have been shown that owning a dog is good for you"
|
||||
).input_ids
|
||||
decoder_input_ids = m.preprocess_input("Studies show that").input_ids
|
||||
decoder_input_ids = m.model._shift_right(decoder_input_ids)
|
||||
|
||||
test_input = (encoded_input_ids, decoder_input_ids)
|
||||
actual_out = m.forward(*test_input)
|
||||
return m, test_input, actual_out
|
||||
|
||||
|
||||
################################################################################
|
||||
|
||||
##################### Torch Vision Models ###################################
|
||||
@@ -141,27 +191,55 @@ class VisionModule(torch.nn.Module):
|
||||
return self.model.forward(input)
|
||||
|
||||
|
||||
def get_vision_model(torch_model):
|
||||
def get_vision_model(torch_model, import_args):
|
||||
import torchvision.models as models
|
||||
|
||||
default_image_size = (224, 224)
|
||||
|
||||
vision_models_dict = {
|
||||
"alexnet": models.alexnet(weights="DEFAULT"),
|
||||
"resnet18": models.resnet18(weights="DEFAULT"),
|
||||
"resnet50": models.resnet50(weights="DEFAULT"),
|
||||
"resnet50_fp16": models.resnet50(weights="DEFAULT"),
|
||||
"resnet101": models.resnet101(weights="DEFAULT"),
|
||||
"squeezenet1_0": models.squeezenet1_0(weights="DEFAULT"),
|
||||
"wide_resnet50_2": models.wide_resnet50_2(weights="DEFAULT"),
|
||||
"mobilenet_v3_small": models.mobilenet_v3_small(weights="DEFAULT"),
|
||||
"mnasnet1_0": models.mnasnet1_0(weights="DEFAULT"),
|
||||
"alexnet": (models.alexnet(weights="DEFAULT"), default_image_size),
|
||||
"resnet18": (models.resnet18(weights="DEFAULT"), default_image_size),
|
||||
"resnet50": (models.resnet50(weights="DEFAULT"), default_image_size),
|
||||
"resnet50_fp16": (
|
||||
models.resnet50(weights="DEFAULT"),
|
||||
default_image_size,
|
||||
),
|
||||
"resnet101": (models.resnet101(weights="DEFAULT"), default_image_size),
|
||||
"squeezenet1_0": (
|
||||
models.squeezenet1_0(weights="DEFAULT"),
|
||||
default_image_size,
|
||||
),
|
||||
"wide_resnet50_2": (
|
||||
models.wide_resnet50_2(weights="DEFAULT"),
|
||||
default_image_size,
|
||||
),
|
||||
"mobilenet_v3_small": (
|
||||
models.mobilenet_v3_small(weights="DEFAULT"),
|
||||
default_image_size,
|
||||
),
|
||||
"mnasnet1_0": (
|
||||
models.mnasnet1_0(weights="DEFAULT"),
|
||||
default_image_size,
|
||||
),
|
||||
# EfficientNet input image size varies on the size of the model.
|
||||
"efficientnet_b0": (
|
||||
models.efficientnet_b0(weights="DEFAULT"),
|
||||
(224, 224),
|
||||
),
|
||||
"efficientnet_b7": (
|
||||
models.efficientnet_b7(weights="DEFAULT"),
|
||||
(600, 600),
|
||||
),
|
||||
}
|
||||
if isinstance(torch_model, str):
|
||||
fp16_model = None
|
||||
if "fp16" in torch_model:
|
||||
fp16_model = True
|
||||
torch_model = vision_models_dict[torch_model]
|
||||
torch_model, input_image_size = vision_models_dict[torch_model]
|
||||
model = VisionModule(torch_model)
|
||||
test_input = torch.randn(1, 3, 224, 224)
|
||||
test_input = torch.randn(
|
||||
int(import_args["batch_size"]), 3, *input_image_size
|
||||
)
|
||||
actual_out = model(test_input)
|
||||
if fp16_model is not None:
|
||||
test_input_fp16 = test_input.to(
|
||||
@@ -202,13 +280,14 @@ class BertHalfPrecisionModel(torch.nn.Module):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
def get_fp16_model(torch_model):
|
||||
def get_fp16_model(torch_model, import_args):
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
modelname = torch_model.replace("_fp16", "")
|
||||
model = BertHalfPrecisionModel(modelname)
|
||||
tokenizer = AutoTokenizer.from_pretrained(modelname)
|
||||
text = "Replace me by any text you like."
|
||||
text = [text] * int(import_args["batch_size"])
|
||||
test_input_fp16 = tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
BertTokenizer,
|
||||
TFBertModel,
|
||||
)
|
||||
|
||||
BATCH_SIZE = 1
|
||||
MAX_SEQUENCE_LENGTH = 128
|
||||
|
||||
################################## MHLO/TF models #########################################
|
||||
# TODO : Generate these lists or fetch model source from tank/tf/tf_model_list.csv
|
||||
keras_models = ["resnet50", "efficientnet-v2-s"]
|
||||
keras_models = [
|
||||
"resnet50",
|
||||
"efficientnet_b0",
|
||||
"efficientnet_b7",
|
||||
"efficientnet-v2-s",
|
||||
]
|
||||
maskedlm_models = [
|
||||
"albert-base-v2",
|
||||
"bert-base-uncased",
|
||||
@@ -32,36 +31,61 @@ maskedlm_models = [
|
||||
"hf-internal-testing/tiny-random-flaubert",
|
||||
"xlm-roberta",
|
||||
]
|
||||
causallm_models = [
|
||||
"gpt2",
|
||||
]
|
||||
tfhf_models = [
|
||||
"microsoft/MiniLM-L12-H384-uncased",
|
||||
]
|
||||
tfhf_seq2seq_models = [
|
||||
"t5-base",
|
||||
"t5-large",
|
||||
]
|
||||
img_models = [
|
||||
"google/vit-base-patch16-224",
|
||||
"facebook/convnext-tiny-224",
|
||||
]
|
||||
|
||||
|
||||
def get_tf_model(name):
|
||||
def get_tf_model(name, import_args):
|
||||
if name in keras_models:
|
||||
return get_keras_model(name)
|
||||
return get_keras_model(name, import_args)
|
||||
elif name in maskedlm_models:
|
||||
return get_causal_lm_model(name)
|
||||
return get_masked_lm_model(name, import_args)
|
||||
elif name in causallm_models:
|
||||
return get_causal_lm_model(name, import_args)
|
||||
elif name in tfhf_models:
|
||||
return get_TFhf_model(name)
|
||||
return get_TFhf_model(name, import_args)
|
||||
elif name in img_models:
|
||||
return get_causal_image_model(name)
|
||||
return get_causal_image_model(name, import_args)
|
||||
elif name in tfhf_seq2seq_models:
|
||||
return get_tfhf_seq2seq_model(name, import_args)
|
||||
else:
|
||||
raise Exception(
|
||||
"TF model not found! Please check that the modelname has been input correctly."
|
||||
)
|
||||
|
||||
|
||||
##################### Tensorflow Hugging Face LM Models ###################################
|
||||
##################### Tensorflow Hugging Face Bert Models ###################################
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
BertTokenizer,
|
||||
TFBertModel,
|
||||
)
|
||||
|
||||
BERT_MAX_SEQUENCE_LENGTH = 128
|
||||
|
||||
# Create a set of 2-dimensional inputs
|
||||
tf_bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -81,27 +105,37 @@ class TFHuggingFaceLanguage(tf.Module):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
def get_TFhf_model(name):
|
||||
def get_TFhf_model(name, import_args):
|
||||
model = TFHuggingFaceLanguage(name)
|
||||
tokenizer = BertTokenizer.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
text = "Replace me by any text you'd like."
|
||||
text = [text] * BATCH_SIZE
|
||||
encoded_input = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
for key in encoded_input:
|
||||
encoded_input[key] = tf.expand_dims(
|
||||
tf.convert_to_tensor(encoded_input[key]), 0
|
||||
)
|
||||
test_input = (
|
||||
encoded_input["input_ids"],
|
||||
encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"],
|
||||
max_length=BERT_MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
test_input = [
|
||||
tf.reshape(
|
||||
tf.convert_to_tensor(encoded_input["input_ids"], dtype=tf.int32),
|
||||
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
|
||||
),
|
||||
tf.reshape(
|
||||
tf.convert_to_tensor(
|
||||
encoded_input["attention_mask"], dtype=tf.int32
|
||||
),
|
||||
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
|
||||
),
|
||||
tf.reshape(
|
||||
tf.convert_to_tensor(
|
||||
encoded_input["token_type_ids"], dtype=tf.int32
|
||||
),
|
||||
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
|
||||
),
|
||||
]
|
||||
actual_out = model.forward(*test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
@@ -115,34 +149,40 @@ def compare_tensors_tf(tf_tensor, numpy_tensor):
|
||||
return np.allclose(tf_to_numpy, numpy_tensor, rtol, atol)
|
||||
|
||||
|
||||
##################### Tensorflow Hugging Face Masked LM Models ###################################
|
||||
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
|
||||
import tensorflow as tf
|
||||
|
||||
# Create a set of input signature.
|
||||
input_signature_maskedlm = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
# For supported models please see here:
|
||||
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForCasualLM
|
||||
|
||||
|
||||
# Tokenizer for language models
|
||||
def preprocess_input(
|
||||
model_name, text="This is just used to compile the model"
|
||||
model_name, max_length, text="This is just used to compile the model"
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
text = [text] * BATCH_SIZE
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
return_tensors="tf",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
max_length=max_length,
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
##################### Tensorflow Hugging Face Masked LM Models ###################################
|
||||
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
MASKED_LM_MAX_SEQUENCE_LENGTH = 128
|
||||
|
||||
# Create a set of input signature.
|
||||
input_signature_maskedlm = [
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, MASKED_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, MASKED_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# For supported models please see here:
|
||||
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForMaskedLM
|
||||
class MaskedLM(tf.Module):
|
||||
def __init__(self, model_name):
|
||||
super(MaskedLM, self).__init__()
|
||||
@@ -156,19 +196,143 @@ class MaskedLM(tf.Module):
|
||||
return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
|
||||
def get_masked_lm_model(
|
||||
hf_name, import_args, text="Hello, this is the default text."
|
||||
):
|
||||
model = MaskedLM(hf_name)
|
||||
encoded_input = preprocess_input(hf_name, text)
|
||||
encoded_input = preprocess_input(
|
||||
hf_name, MASKED_LM_MAX_SEQUENCE_LENGTH, text
|
||||
)
|
||||
test_input = (encoded_input["input_ids"], encoded_input["attention_mask"])
|
||||
actual_out = model.forward(*test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
##################### Tensorflow Hugging Face Causal LM Models ###################################
|
||||
|
||||
from transformers import AutoConfig, TFAutoModelForCausalLM, TFGPT2Model
|
||||
|
||||
CAUSAL_LM_MAX_SEQUENCE_LENGTH = 1024
|
||||
|
||||
input_signature_causallm = [
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, CAUSAL_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, CAUSAL_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# For supported models please see here:
|
||||
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForCausalLM
|
||||
# For more background, see:
|
||||
# https://huggingface.co/blog/tf-xla-generate
|
||||
class CausalLM(tf.Module):
|
||||
def __init__(self, model_name):
|
||||
super(CausalLM, self).__init__()
|
||||
# Decoder-only models need left padding.
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name, padding_side="left", pad_token="</s>"
|
||||
)
|
||||
self.tokenization_kwargs = {
|
||||
"pad_to_multiple_of": CAUSAL_LM_MAX_SEQUENCE_LENGTH,
|
||||
"padding": True,
|
||||
"return_tensors": "tf",
|
||||
}
|
||||
self.model = TFGPT2Model.from_pretrained(model_name, return_dict=True)
|
||||
self.model.predict = lambda x, y: self.model(
|
||||
input_ids=x, attention_mask=y
|
||||
)[0]
|
||||
|
||||
def preprocess_input(self, text):
|
||||
return self.tokenizer(text, **self.tokenization_kwargs)
|
||||
|
||||
@tf.function(input_signature=input_signature_causallm, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.model.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
def get_causal_lm_model(
|
||||
hf_name, import_args, text="Hello, this is the default text."
|
||||
):
|
||||
model = CausalLM(hf_name)
|
||||
batched_text = [text] * BATCH_SIZE
|
||||
encoded_input = model.preprocess_input(batched_text)
|
||||
test_input = (encoded_input["input_ids"], encoded_input["attention_mask"])
|
||||
actual_out = model.forward(*test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
##################### TensorflowHugging Face Seq2SeqLM Models ###################################
|
||||
|
||||
# We use a maximum sequence length of 512 since this is the default used in the T5 config.
|
||||
T5_MAX_SEQUENCE_LENGTH = 512
|
||||
|
||||
input_signature_t5 = [
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, T5_MAX_SEQUENCE_LENGTH],
|
||||
dtype=tf.int32,
|
||||
name="input_ids",
|
||||
),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, T5_MAX_SEQUENCE_LENGTH],
|
||||
dtype=tf.int32,
|
||||
name="attention_mask",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class TFHFSeq2SeqLanguageModel(tf.Module):
|
||||
def __init__(self, model_name):
|
||||
super(TFHFSeq2SeqLanguageModel, self).__init__()
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
TFAutoModelForSeq2SeqLM,
|
||||
TFT5Model,
|
||||
)
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.tokenization_kwargs = {
|
||||
"pad_to_multiple_of": T5_MAX_SEQUENCE_LENGTH,
|
||||
"padding": True,
|
||||
"return_tensors": "tf",
|
||||
}
|
||||
self.model = TFT5Model.from_pretrained(model_name, return_dict=True)
|
||||
self.model.predict = lambda x, y: self.model(x, decoder_input_ids=y)[0]
|
||||
|
||||
def preprocess_input(self, text):
|
||||
return self.tokenizer(text, **self.tokenization_kwargs)
|
||||
|
||||
@tf.function(input_signature=input_signature_t5, jit_compile=True)
|
||||
def forward(self, input_ids, decoder_input_ids):
|
||||
return self.model.predict(input_ids, decoder_input_ids)
|
||||
|
||||
|
||||
def get_tfhf_seq2seq_model(name, import_args):
|
||||
m = TFHFSeq2SeqLanguageModel(name)
|
||||
text = "Studies have been shown that owning a dog is good for you"
|
||||
batched_text = [text] * BATCH_SIZE
|
||||
encoded_input_ids = m.preprocess_input(batched_text).input_ids
|
||||
|
||||
text = "Studies show that"
|
||||
batched_text = [text] * BATCH_SIZE
|
||||
decoder_input_ids = m.preprocess_input(batched_text).input_ids
|
||||
decoder_input_ids = m.model._shift_right(decoder_input_ids)
|
||||
|
||||
test_input = (encoded_input_ids, decoder_input_ids)
|
||||
actual_out = m.forward(*test_input)
|
||||
return m, test_input, actual_out
|
||||
|
||||
|
||||
##################### TensorFlow Keras Resnet Models #########################################################
|
||||
# Static shape, including batch size (1).
|
||||
# Can be dynamic once dynamic shape support is ready.
|
||||
RESNET_INPUT_SHAPE = [1, 224, 224, 3]
|
||||
EFFICIENTNET_INPUT_SHAPE = [1, 384, 384, 3]
|
||||
RESNET_INPUT_SHAPE = [BATCH_SIZE, 224, 224, 3]
|
||||
EFFICIENTNET_V2_S_INPUT_SHAPE = [BATCH_SIZE, 384, 384, 3]
|
||||
EFFICIENTNET_B0_INPUT_SHAPE = [BATCH_SIZE, 224, 224, 3]
|
||||
EFFICIENTNET_B7_INPUT_SHAPE = [BATCH_SIZE, 600, 600, 3]
|
||||
|
||||
|
||||
class ResNetModule(tf.Module):
|
||||
@@ -195,25 +359,79 @@ class ResNetModule(tf.Module):
|
||||
return tf.keras.applications.resnet50.preprocess_input(image)
|
||||
|
||||
|
||||
class EfficientNetModule(tf.Module):
|
||||
class EfficientNetB0Module(tf.Module):
|
||||
def __init__(self):
|
||||
super(EfficientNetModule, self).__init__()
|
||||
self.m = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
|
||||
super(EfficientNetB0Module, self).__init__()
|
||||
self.m = tf.keras.applications.efficientnet.EfficientNetB0(
|
||||
weights="imagenet",
|
||||
include_top=True,
|
||||
input_shape=tuple(EFFICIENTNET_INPUT_SHAPE[1:]),
|
||||
input_shape=tuple(EFFICIENTNET_B0_INPUT_SHAPE[1:]),
|
||||
)
|
||||
self.m.predict = lambda x: self.m.call(x, training=False)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[tf.TensorSpec(EFFICIENTNET_INPUT_SHAPE, tf.float32)],
|
||||
input_signature=[
|
||||
tf.TensorSpec(EFFICIENTNET_B0_INPUT_SHAPE, tf.float32)
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def forward(self, inputs):
|
||||
return self.m.predict(inputs)
|
||||
|
||||
def input_shape(self):
|
||||
return EFFICIENTNET_INPUT_SHAPE
|
||||
return EFFICIENTNET_B0_INPUT_SHAPE
|
||||
|
||||
def preprocess_input(self, image):
|
||||
return tf.keras.applications.efficientnet.preprocess_input(image)
|
||||
|
||||
|
||||
class EfficientNetB7Module(tf.Module):
|
||||
def __init__(self):
|
||||
super(EfficientNetB7Module, self).__init__()
|
||||
self.m = tf.keras.applications.efficientnet.EfficientNetB7(
|
||||
weights="imagenet",
|
||||
include_top=True,
|
||||
input_shape=tuple(EFFICIENTNET_B7_INPUT_SHAPE[1:]),
|
||||
)
|
||||
self.m.predict = lambda x: self.m.call(x, training=False)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
tf.TensorSpec(EFFICIENTNET_B7_INPUT_SHAPE, tf.float32)
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def forward(self, inputs):
|
||||
return self.m.predict(inputs)
|
||||
|
||||
def input_shape(self):
|
||||
return EFFICIENTNET_B7_INPUT_SHAPE
|
||||
|
||||
def preprocess_input(self, image):
|
||||
return tf.keras.applications.efficientnet.preprocess_input(image)
|
||||
|
||||
|
||||
class EfficientNetV2SModule(tf.Module):
|
||||
def __init__(self):
|
||||
super(EfficientNetV2SModule, self).__init__()
|
||||
self.m = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
|
||||
weights="imagenet",
|
||||
include_top=True,
|
||||
input_shape=tuple(EFFICIENTNET_V2_S_INPUT_SHAPE[1:]),
|
||||
)
|
||||
self.m.predict = lambda x: self.m.call(x, training=False)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
tf.TensorSpec(EFFICIENTNET_V2_S_INPUT_SHAPE, tf.float32)
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def forward(self, inputs):
|
||||
return self.m.predict(inputs)
|
||||
|
||||
def input_shape(self):
|
||||
return EFFICIENTNET_V2_S_INPUT_SHAPE
|
||||
|
||||
def preprocess_input(self, image):
|
||||
return tf.keras.applications.efficientnet_v2.preprocess_input(image)
|
||||
@@ -224,12 +442,17 @@ def load_image(path_to_image, width, height, channels):
|
||||
image = tf.image.decode_image(image, channels=channels)
|
||||
image = tf.image.resize(image, (width, height))
|
||||
image = image[tf.newaxis, :]
|
||||
image = tf.tile(image, [BATCH_SIZE, 1, 1, 1])
|
||||
return image
|
||||
|
||||
|
||||
def get_keras_model(modelname):
|
||||
def get_keras_model(modelname, import_args):
|
||||
if modelname == "efficientnet-v2-s":
|
||||
model = EfficientNetModule()
|
||||
model = EfficientNetV2SModule()
|
||||
elif modelname == "efficientnet_b0":
|
||||
model = EfficientNetB0Module()
|
||||
elif modelname == "efficientnet_b7":
|
||||
model = EfficientNetB7Module()
|
||||
else:
|
||||
model = ResNetModule()
|
||||
|
||||
@@ -256,7 +479,7 @@ import requests
|
||||
|
||||
# Create a set of input signature.
|
||||
input_signature_img_cls = [
|
||||
tf.TensorSpec(shape=[1, 3, 224, 224], dtype=tf.float32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, 3, 224, 224], dtype=tf.float32),
|
||||
]
|
||||
|
||||
|
||||
@@ -304,11 +527,14 @@ def preprocess_input_image(model_name):
|
||||
)
|
||||
# inputs: {'pixel_values': <tf.Tensor: shape=(1, 3, 224, 224), dtype=float32, numpy=array([[[[]]]], dtype=float32)>}
|
||||
inputs = feature_extractor(images=image, return_tensors="tf")
|
||||
inputs["pixel_values"] = tf.tile(
|
||||
inputs["pixel_values"], [BATCH_SIZE, 1, 1, 1]
|
||||
)
|
||||
|
||||
return [inputs[str(*inputs)]]
|
||||
|
||||
|
||||
def get_causal_image_model(hf_name):
|
||||
def get_causal_image_model(hf_name, import_args):
|
||||
model = AutoModelImageClassfication(hf_name)
|
||||
test_input = preprocess_input_image(hf_name)
|
||||
# TFSequenceClassifierOutput(loss=None, logits=<tf.Tensor: shape=(1, 1000), dtype=float32, numpy=
|
||||
|
||||
@@ -8,6 +8,7 @@ from parameterized import parameterized
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
from tank.generate_sharktank import NoImportException
|
||||
import iree.compiler as ireec
|
||||
import pytest
|
||||
import unittest
|
||||
@@ -48,7 +49,9 @@ def load_csv_and_convert(filename, gen=False):
|
||||
)
|
||||
# This is a pytest workaround
|
||||
if gen:
|
||||
with open("tank/dict_configs.py", "w+") as out:
|
||||
with open(
|
||||
os.path.join(os.path.dirname(__file__), "dict_configs.py"), "w+"
|
||||
) as out:
|
||||
out.write("ALL = [\n")
|
||||
for c in model_configs:
|
||||
out.write(str(c) + ",\n")
|
||||
@@ -68,7 +71,9 @@ def get_valid_test_params():
|
||||
dynamic_list = (True, False)
|
||||
# TODO: This is soooo ugly, but for some reason creating the dict at runtime
|
||||
# results in strange pytest failures.
|
||||
load_csv_and_convert("tank/all_models.csv", True)
|
||||
load_csv_and_convert(
|
||||
os.path.join(os.path.dirname(__file__), "all_models.csv"), True
|
||||
)
|
||||
from tank.dict_configs import ALL
|
||||
|
||||
config_list = ALL
|
||||
@@ -135,6 +140,9 @@ class SharkModuleTester:
|
||||
self.config = config
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
import_config = {
|
||||
"batch_size": self.batch_size,
|
||||
}
|
||||
shark_args.local_tank_cache = self.local_tank_cache
|
||||
shark_args.force_update_tank = self.update_tank
|
||||
shark_args.dispatch_benchmarks = self.benchmark_dispatches
|
||||
@@ -161,12 +169,27 @@ class SharkModuleTester:
|
||||
if "winograd" in self.config["flags"]:
|
||||
shark_args.use_winograd = True
|
||||
|
||||
model, func_name, inputs, golden_out = download_model(
|
||||
self.config["model_name"],
|
||||
tank_url=self.tank_url,
|
||||
frontend=self.config["framework"],
|
||||
)
|
||||
|
||||
dl_gen_attempts = 2
|
||||
for i in range(dl_gen_attempts):
|
||||
try:
|
||||
model, func_name, inputs, golden_out = download_model(
|
||||
self.config["model_name"],
|
||||
tank_url=self.tank_url,
|
||||
frontend=self.config["framework"],
|
||||
import_args=import_config,
|
||||
)
|
||||
except NoImportException as err:
|
||||
pytest.xfail(
|
||||
reason=f"Artifacts for this model/config must be generated locally. Please make sure {self.config['framework']} is installed."
|
||||
)
|
||||
except AssertionError as err:
|
||||
if i < dl_gen_attempts - 1:
|
||||
continue
|
||||
else:
|
||||
pytest.xfail(
|
||||
"Generating OTF may require exiting the subprocess for files to be available."
|
||||
)
|
||||
break
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
device=device,
|
||||
@@ -210,10 +233,12 @@ class SharkModuleTester:
|
||||
self.save_reproducers()
|
||||
|
||||
def benchmark_module(self, shark_module, inputs, dynamic, device):
|
||||
model_config = {
|
||||
"batch_size": self.batch_size,
|
||||
}
|
||||
shark_args.enable_tf32 = self.tf32
|
||||
if shark_args.enable_tf32 == True:
|
||||
shark_module.compile()
|
||||
shark_args.enable_tf32 = False
|
||||
|
||||
shark_args.onnx_bench = self.onnx_bench
|
||||
shark_module.shark_runner.benchmark_all_csv(
|
||||
@@ -222,6 +247,7 @@ class SharkModuleTester:
|
||||
dynamic,
|
||||
device,
|
||||
self.config["framework"],
|
||||
import_args=model_config,
|
||||
)
|
||||
|
||||
def save_reproducers(self):
|
||||
@@ -271,6 +297,9 @@ class SharkModuleTest(unittest.TestCase):
|
||||
@parameterized.expand(param_list, name_func=shark_test_name_func)
|
||||
def test_module(self, dynamic, device, config):
|
||||
self.module_tester = SharkModuleTester(config)
|
||||
self.module_tester.batch_size = self.pytestconfig.getoption(
|
||||
"batchsize"
|
||||
)
|
||||
self.module_tester.benchmark = self.pytestconfig.getoption("benchmark")
|
||||
self.module_tester.save_repro = self.pytestconfig.getoption(
|
||||
"save_repro"
|
||||
@@ -342,6 +371,10 @@ class SharkModuleTest(unittest.TestCase):
|
||||
pytest.xfail(
|
||||
reason="Numerics issues: https://github.com/nod-ai/SHARK/issues/476"
|
||||
)
|
||||
if config["framework"] == "tf" and self.module_tester.batch_size != 1:
|
||||
pytest.xfail(
|
||||
reason="Configurable batch sizes temp. unavailable for tensorflow models."
|
||||
)
|
||||
safe_name = (
|
||||
f"{config['model_name']}_{config['framework']}_{dynamic}_{device}"
|
||||
)
|
||||
|
||||
@@ -19,3 +19,10 @@ facebook/convnext-tiny-224,img
|
||||
google/vit-base-patch16-224,img
|
||||
efficientnet-v2-s,keras
|
||||
bert-large-uncased,hf
|
||||
t5-base,tfhf_seq2seq
|
||||
t5-large,tfhf_seq2seq
|
||||
efficientnet_b0,keras
|
||||
efficientnet_b7,keras
|
||||
gpt2,hf_causallm
|
||||
t5-base,tfhf_seq2seq
|
||||
t5-large,tfhf_seq2seq
|
||||
|
||||
|
@@ -1,4 +1,6 @@
|
||||
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
|
||||
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
|
||||
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
|
||||
microsoft/MiniLM-L12-H384-uncased,True,hf,True,66M,"nlp;bert-variant;transformer-encoder","Large version has 12 layers; 384 hidden size; Smaller than BERTbase (66M params vs 109M params)"
|
||||
bert-base-uncased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-base-cased,True,hf,True,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
@@ -18,4 +20,4 @@ nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encod
|
||||
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
|
||||
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
|
||||
|
@@ -1,4 +1,3 @@
|
||||
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
|
||||
stabilityai/stable-diffusion-2-1-base,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
|
||||
stabilityai/stable-diffusion-2-1,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
|
||||
prompthero/openjourney,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
|
||||
|
||||
|
Reference in New Issue
Block a user