mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
64 Commits
20230308.5
...
20230324.6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
17a67897d1 | ||
|
|
da449b73aa | ||
|
|
0b0526699a | ||
|
|
4fac46f7bb | ||
|
|
49925950f1 | ||
|
|
807947c0c8 | ||
|
|
593428bda4 | ||
|
|
cede9b4fec | ||
|
|
c2360303f0 | ||
|
|
420366c1b8 | ||
|
|
d31bae488c | ||
|
|
c23fcf3748 | ||
|
|
7dbbb1726a | ||
|
|
8b8cc7fd33 | ||
|
|
e3c96a2b9d | ||
|
|
5e3f50647d | ||
|
|
7899e1803a | ||
|
|
d105246b9c | ||
|
|
90c958bca2 | ||
|
|
f99903e023 | ||
|
|
c6f44ef1b3 | ||
|
|
8dcd4d5aeb | ||
|
|
d319f4684e | ||
|
|
54d7b6d83e | ||
|
|
4a622532e5 | ||
|
|
650b2ada58 | ||
|
|
f87f8949f3 | ||
|
|
7dc9bf8148 | ||
|
|
ba48ff8d25 | ||
|
|
638840925c | ||
|
|
b661656c03 | ||
|
|
0225434389 | ||
|
|
7ffe20b1c2 | ||
|
|
d8f0c4655d | ||
|
|
7e8d3ec0df | ||
|
|
9c08eec565 | ||
|
|
2d2c523ac5 | ||
|
|
f17b3128c0 | ||
|
|
7c7e630099 | ||
|
|
2dd1491ec1 | ||
|
|
236357fb61 | ||
|
|
7bc38719de | ||
|
|
bdbe992769 | ||
|
|
e6b925e012 | ||
|
|
771120b76c | ||
|
|
a8ce7680db | ||
|
|
b6dcf2401b | ||
|
|
62b5a9fd49 | ||
|
|
2f133e9d5c | ||
|
|
f898a1d332 | ||
|
|
b94266d2b9 | ||
|
|
1b08242aaa | ||
|
|
691030fbab | ||
|
|
16ad7d57a3 | ||
|
|
c561ebf43c | ||
|
|
97fdff7f19 | ||
|
|
ce6d82eab2 | ||
|
|
b8f4b18951 | ||
|
|
b23d3aa584 | ||
|
|
495670d9b6 | ||
|
|
815e23a0b8 | ||
|
|
783538fe11 | ||
|
|
996c645f6a | ||
|
|
1f7d249a62 |
5
.flake8
Normal file
5
.flake8
Normal file
@@ -0,0 +1,5 @@
|
||||
[flake8]
|
||||
count = 1
|
||||
show-source = 1
|
||||
select = E9,F63,F7,F82
|
||||
exclude = lit.cfg.py
|
||||
7
.github/workflows/test-models.yml
vendored
7
.github/workflows/test-models.yml
vendored
@@ -99,11 +99,12 @@ jobs:
|
||||
run: |
|
||||
# black format check
|
||||
black --version
|
||||
black --line-length 79 --check .
|
||||
black --check .
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude lit.cfg.py
|
||||
flake8 . --statistics
|
||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude lit.cfg.py
|
||||
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
|
||||
--statistics --exclude lit.cfg.py
|
||||
|
||||
- name: Validate Models on CPU
|
||||
if: matrix.suite == 'cpu'
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -168,6 +168,8 @@ shark_tmp/
|
||||
*.vmfb
|
||||
.use-iree
|
||||
tank/dict_configs.py
|
||||
*.csv
|
||||
reproducers/
|
||||
|
||||
# ORT related artefacts
|
||||
cache_models/
|
||||
@@ -182,3 +184,6 @@ models/
|
||||
|
||||
# models folder
|
||||
apps/stable_diffusion/web/models/
|
||||
|
||||
# Stencil annotators.
|
||||
stencil_annotator/
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
[style]
|
||||
based_on_style = google
|
||||
column_limit = 80
|
||||
@@ -2,3 +2,5 @@ from apps.stable_diffusion.scripts.txt2img import txt2img_inf
|
||||
from apps.stable_diffusion.scripts.img2img import img2img_inf
|
||||
from apps.stable_diffusion.scripts.inpaint import inpaint_inf
|
||||
from apps.stable_diffusion.scripts.outpaint import outpaint_inf
|
||||
from apps.stable_diffusion.scripts.upscaler import upscaler_inf
|
||||
from apps.stable_diffusion.scripts.train_lora_word import lora_train
|
||||
|
||||
@@ -2,7 +2,7 @@ import sys
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
from dataclasses import dataclass
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Image2ImagePipeline,
|
||||
@@ -13,36 +13,47 @@ from apps.stable_diffusion.src import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
model_id: str
|
||||
ckpt_loc: str
|
||||
precision: str
|
||||
batch_size: int
|
||||
max_length: int
|
||||
height: int
|
||||
width: int
|
||||
device: str
|
||||
use_stencil: str
|
||||
|
||||
|
||||
img2img_obj = None
|
||||
config_obj = None
|
||||
schedulers = None
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
init_import_mlir = args.import_mlir
|
||||
|
||||
|
||||
# For stencil, the input image can be of any size but we need to ensure that
|
||||
# it conforms with our model contraints :-
|
||||
# Both width and height should be > 384 and multiple of 8.
|
||||
# This utility function performs the transformation on the input image while
|
||||
# also maintaining the aspect ratio before sending it to the stencil pipeline.
|
||||
def resize_stencil(image: Image.Image):
|
||||
width, height = image.size
|
||||
aspect_ratio = width / height
|
||||
min_size = min(width, height)
|
||||
if min_size < 384:
|
||||
n_size = 384
|
||||
if width == min_size:
|
||||
width = n_size
|
||||
height = n_size / aspect_ratio
|
||||
else:
|
||||
height = n_size
|
||||
width = n_size * aspect_ratio
|
||||
width = int(width)
|
||||
height = int(height)
|
||||
n_width = width // 8
|
||||
n_height = height // 8
|
||||
n_width *= 8
|
||||
n_height *= 8
|
||||
new_image = image.resize((n_width, n_height))
|
||||
return new_image, n_width, n_height
|
||||
|
||||
|
||||
# Exposed to UI.
|
||||
def img2img_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
init_image: Image,
|
||||
init_image,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
@@ -60,10 +71,18 @@ def img2img_inf(
|
||||
use_stencil: str,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
):
|
||||
global img2img_obj
|
||||
global config_obj
|
||||
global schedulers
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
@@ -79,10 +98,6 @@ def img2img_inf(
|
||||
image = init_image.convert("RGB")
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
@@ -93,10 +108,14 @@ def img2img_inf(
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = custom_model
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
@@ -105,6 +124,7 @@ def img2img_inf(
|
||||
if use_stencil is not None:
|
||||
args.scheduler = "DDIM"
|
||||
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
|
||||
image, width, height = resize_stencil(image)
|
||||
elif args.scheduler != "PNDM":
|
||||
if "Shark" in args.scheduler:
|
||||
print(
|
||||
@@ -119,6 +139,7 @@ def img2img_inf(
|
||||
args.precision = precision
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
new_config_obj = Config(
|
||||
"img2img",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
precision,
|
||||
@@ -127,10 +148,16 @@ def img2img_inf(
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_stencil,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=use_stencil,
|
||||
)
|
||||
if not img2img_obj or config_obj != new_config_obj:
|
||||
config_obj = new_config_obj
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
@@ -145,55 +172,65 @@ def img2img_inf(
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
|
||||
if use_stencil is not None:
|
||||
args.use_tuned = False
|
||||
img2img_obj = StencilPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_stencil=use_stencil,
|
||||
global_obj.set_sd_obj(
|
||||
StencilPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_stencil=use_stencil,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
else:
|
||||
img2img_obj = Image2ImagePipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
global_obj.set_sd_obj(
|
||||
Image2ImagePipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
|
||||
img2img_obj.scheduler = schedulers[scheduler]
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
img2img_obj.log = ""
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
extra_info = {"STRENGTH": strength}
|
||||
text_output = ""
|
||||
for current_batch in range(batch_count):
|
||||
if current_batch > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = img2img_obj.generate_images(
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
@@ -210,20 +247,18 @@ def img2img_inf(
|
||||
cpu_scheduling,
|
||||
use_stencil=use_stencil,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed, extra_info)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
img2img_obj.log += "\n"
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={steps}, strength={args.strength}, guidance_scale={guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
|
||||
text_output += img2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed, extra_info)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
|
||||
return generated_imgs, text_output
|
||||
|
||||
@@ -236,6 +271,7 @@ if __name__ == "__main__":
|
||||
print("Flag --img_path is required.")
|
||||
exit()
|
||||
|
||||
image = Image.open(args.img_path).convert("RGB")
|
||||
# When the models get uploaded, it should be default to False.
|
||||
args.import_mlir = True
|
||||
|
||||
@@ -243,6 +279,7 @@ if __name__ == "__main__":
|
||||
if use_stencil:
|
||||
args.scheduler = "DDIM"
|
||||
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
|
||||
image, args.width, args.height = resize_stencil(image)
|
||||
elif args.scheduler != "PNDM":
|
||||
if "Shark" in args.scheduler:
|
||||
print(
|
||||
@@ -257,9 +294,7 @@ if __name__ == "__main__":
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
image = Image.open(args.img_path).convert("RGB")
|
||||
seed = utils.sanitize_seed(args.seed)
|
||||
# Adjust for height and width based on model
|
||||
|
||||
@@ -279,6 +314,8 @@ if __name__ == "__main__":
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_stencil=use_stencil,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
else:
|
||||
img2img_obj = Image2ImagePipeline.from_pretrained(
|
||||
@@ -295,6 +332,8 @@ if __name__ == "__main__":
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import sys
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
from dataclasses import dataclass
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
InpaintPipeline,
|
||||
@@ -12,24 +11,9 @@ from apps.stable_diffusion.src import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
model_id: str
|
||||
ckpt_loc: str
|
||||
precision: str
|
||||
batch_size: int
|
||||
max_length: int
|
||||
height: int
|
||||
width: int
|
||||
device: str
|
||||
|
||||
|
||||
inpaint_obj = None
|
||||
config_obj = None
|
||||
schedulers = None
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
@@ -58,10 +42,18 @@ def inpaint_inf(
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
):
|
||||
global inpaint_obj
|
||||
global config_obj
|
||||
global schedulers
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
@@ -72,10 +64,6 @@ def inpaint_inf(
|
||||
args.mask_path = "not none"
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
@@ -86,16 +74,21 @@ def inpaint_inf(
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = custom_model
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
"inpaint",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
precision,
|
||||
@@ -104,10 +97,17 @@ def inpaint_inf(
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
)
|
||||
if not inpaint_obj or config_obj != new_config_obj:
|
||||
config_obj = new_config_obj
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.precision = precision
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
@@ -122,36 +122,42 @@ def inpaint_inf(
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
inpaint_obj = InpaintPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
InpaintPipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
custom_vae=args.custom_vae,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
|
||||
inpaint_obj.scheduler = schedulers[scheduler]
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
inpaint_obj.log = ""
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
image = image_dict["image"]
|
||||
mask_image = image_dict["mask"]
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = inpaint_obj.generate_images(
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
@@ -169,20 +175,18 @@ def inpaint_inf(
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
inpaint_obj.log += "\n"
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
|
||||
text_output += inpaint_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
|
||||
return generated_imgs, text_output
|
||||
|
||||
@@ -213,18 +217,21 @@ if __name__ == "__main__":
|
||||
mask_image = Image.open(args.mask_path)
|
||||
|
||||
inpaint_obj = InpaintPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
custom_vae=args.custom_vae,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
|
||||
for current_batch in range(args.batch_count):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import sys
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
from dataclasses import dataclass
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
OutpaintPipeline,
|
||||
@@ -12,24 +11,9 @@ from apps.stable_diffusion.src import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
model_id: str
|
||||
ckpt_loc: str
|
||||
precision: str
|
||||
batch_size: int
|
||||
max_length: int
|
||||
height: int
|
||||
width: int
|
||||
device: str
|
||||
|
||||
|
||||
outpaint_obj = None
|
||||
config_obj = None
|
||||
schedulers = None
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
@@ -40,7 +24,7 @@ init_import_mlir = args.import_mlir
|
||||
def outpaint_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
init_image: Image,
|
||||
init_image,
|
||||
pixels: int,
|
||||
mask_blur: int,
|
||||
directions: list,
|
||||
@@ -61,10 +45,18 @@ def outpaint_inf(
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
):
|
||||
global outpaint_obj
|
||||
global config_obj
|
||||
global schedulers
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
@@ -74,10 +66,6 @@ def outpaint_inf(
|
||||
args.img_path = "not none"
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
@@ -88,16 +76,21 @@ def outpaint_inf(
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = custom_model
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
"outpaint",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
precision,
|
||||
@@ -106,10 +99,17 @@ def outpaint_inf(
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
)
|
||||
if not outpaint_obj or config_obj != new_config_obj:
|
||||
config_obj = new_config_obj
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.precision = precision
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
@@ -124,27 +124,30 @@ def outpaint_inf(
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
outpaint_obj = OutpaintPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
OutpaintPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
|
||||
outpaint_obj.scheduler = schedulers[scheduler]
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
outpaint_obj.log = ""
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
@@ -154,10 +157,11 @@ def outpaint_inf(
|
||||
top = True if "up" in directions else False
|
||||
bottom = True if "down" in directions else False
|
||||
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = outpaint_obj.generate_images(
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
init_image,
|
||||
@@ -180,20 +184,18 @@ def outpaint_inf(
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
outpaint_obj.log += "\n"
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
|
||||
text_output += outpaint_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
|
||||
return generated_imgs, text_output
|
||||
|
||||
@@ -232,6 +234,7 @@ if __name__ == "__main__":
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
|
||||
for current_batch in range(args.batch_count):
|
||||
|
||||
674
apps/stable_diffusion/scripts/train_lora_word.py
Normal file
674
apps/stable_diffusion/scripts/train_lora_word.py
Normal file
@@ -0,0 +1,674 @@
|
||||
# Install the required libs
|
||||
# pip install -U git+https://github.com/huggingface/diffusers.git
|
||||
# pip install accelerate transformers ftfy
|
||||
|
||||
# HuggingFace Token
|
||||
# YOUR_TOKEN = "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"
|
||||
|
||||
|
||||
# Import required libraries
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from typing import List
|
||||
import random
|
||||
import torch_mlir
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import PIL
|
||||
import logging
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.cross_attention import LoRACrossAttnProcessor
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir.dynamo import make_simple_dynamo_backend
|
||||
import torch._dynamo as dynamo
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
torch._dynamo.config.verbose = True
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.stable_diffusion import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import (
|
||||
CLIPFeatureExtractor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
)
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
from dataclasses import dataclass
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
clear_all,
|
||||
)
|
||||
|
||||
|
||||
# Setup the dataset
|
||||
class LoraDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
tokenizer,
|
||||
size=512,
|
||||
repeats=100,
|
||||
interpolation="bicubic",
|
||||
set="train",
|
||||
prompt="myloraprompt",
|
||||
center_crop=False,
|
||||
):
|
||||
self.data_root = data_root
|
||||
self.tokenizer = tokenizer
|
||||
self.size = size
|
||||
self.center_crop = center_crop
|
||||
self.prompt = prompt
|
||||
|
||||
self.image_paths = [
|
||||
os.path.join(self.data_root, file_path)
|
||||
for file_path in os.listdir(self.data_root)
|
||||
]
|
||||
|
||||
self.num_images = len(self.image_paths)
|
||||
self._length = self.num_images
|
||||
|
||||
if set == "train":
|
||||
self._length = self.num_images * repeats
|
||||
|
||||
self.interpolation = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
}[interpolation]
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = {}
|
||||
image = Image.open(self.image_paths[i % self.num_images])
|
||||
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
example["input_ids"] = self.tokenizer(
|
||||
self.prompt,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids[0]
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
(
|
||||
h,
|
||||
w,
|
||||
) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[
|
||||
(h - crop) // 2 : (h + crop) // 2,
|
||||
(w - crop) // 2 : (w + crop) // 2,
|
||||
]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
image = image.resize(
|
||||
(self.size, self.size), resample=self.interpolation
|
||||
)
|
||||
|
||||
image = np.array(image).astype(np.uint8)
|
||||
image = (image / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
||||
return example
|
||||
|
||||
|
||||
########## Setting up the model ##########
|
||||
def lora_train(
|
||||
prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
training_images_dir: str,
|
||||
lora_save_dir: str,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
print(
|
||||
"Note LoRA training is not compatible with the latest torch-mlir branch"
|
||||
)
|
||||
print(
|
||||
"To run LoRA training you'll need this to follow this guide for the torch-mlir branch: https://github.com/nod-ai/SHARK/tree/main/shark/examples/shark_training/stable_diffusion"
|
||||
)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.steps = steps
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = custom_model
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
args.training_images_dir = training_images_dir
|
||||
args.lora_save_dir = lora_save_dir
|
||||
|
||||
args.precision = precision
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
device_str = device.split("=>", 1)[1].strip().split("://")
|
||||
if len(device_str) > 1:
|
||||
device_str = device_str[0] + ":" + device_str[1]
|
||||
else:
|
||||
device_str = device_str[0]
|
||||
args.device = device_str
|
||||
|
||||
# Load the Stable Diffusion model
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.hf_model_id, subfolder="text_encoder"
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(args.hf_model_id, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.hf_model_id, subfolder="unet"
|
||||
)
|
||||
|
||||
def freeze_params(params):
|
||||
for param in params:
|
||||
param.requires_grad = False
|
||||
|
||||
# Freeze everything but LoRA
|
||||
freeze_params(vae.parameters())
|
||||
freeze_params(unet.parameters())
|
||||
freeze_params(text_encoder.parameters())
|
||||
|
||||
# Move vae and unet to device
|
||||
vae.to(args.device)
|
||||
unet.to(args.device)
|
||||
text_encoder.to(args.device)
|
||||
|
||||
lora_attn_procs = {}
|
||||
for name in unet.attn_processors.keys():
|
||||
cross_attention_dim = (
|
||||
None
|
||||
if name.endswith("attn1.processor")
|
||||
else unet.config.cross_attention_dim
|
||||
)
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[
|
||||
block_id
|
||||
]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
|
||||
lora_attn_procs[name] = LoRACrossAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
|
||||
)
|
||||
|
||||
unet.set_attn_processor(lora_attn_procs)
|
||||
lora_layers = AttnProcsLayers(unet.attn_processors)
|
||||
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = vae
|
||||
|
||||
def forward(self, input):
|
||||
x = self.vae.encode(input, return_dict=False)[0]
|
||||
return x
|
||||
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = unet
|
||||
|
||||
def forward(self, x, y, z):
|
||||
return self.unet.forward(x, y, z, return_dict=False)[0]
|
||||
|
||||
shark_vae = VaeModel()
|
||||
shark_unet = UnetModel()
|
||||
|
||||
####### Creating our training data ########
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.hf_model_id,
|
||||
subfolder="tokenizer",
|
||||
)
|
||||
|
||||
# Let's create the Dataset and Dataloader
|
||||
train_dataset = LoraDataset(
|
||||
data_root=args.training_images_dir,
|
||||
tokenizer=tokenizer,
|
||||
size=vae.sample_size,
|
||||
prompt=args.prompts[0],
|
||||
repeats=100,
|
||||
center_crop=False,
|
||||
set="train",
|
||||
)
|
||||
|
||||
def create_dataloader(train_batch_size=1):
|
||||
return torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=train_batch_size, shuffle=True
|
||||
)
|
||||
|
||||
# Create noise_scheduler for training
|
||||
noise_scheduler = DDPMScheduler.from_config(
|
||||
args.hf_model_id, subfolder="scheduler"
|
||||
)
|
||||
|
||||
######## Training ###########
|
||||
|
||||
# Define hyperparameters for our training. If you are not happy with your results,
|
||||
# you can tune the `learning_rate` and the `max_train_steps`
|
||||
|
||||
# Setting up all training args
|
||||
hyperparameters = {
|
||||
"learning_rate": 5e-04,
|
||||
"scale_lr": True,
|
||||
"max_train_steps": steps,
|
||||
"train_batch_size": batch_size,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_checkpointing": True,
|
||||
"mixed_precision": "fp16",
|
||||
"seed": 42,
|
||||
"output_dir": "sd-concept-output",
|
||||
}
|
||||
# creating output directory
|
||||
cwd = os.getcwd()
|
||||
out_dir = os.path.join(cwd, hyperparameters["output_dir"])
|
||||
while not os.path.exists(str(out_dir)):
|
||||
try:
|
||||
os.mkdir(out_dir)
|
||||
except OSError as error:
|
||||
print("Output directory not created")
|
||||
|
||||
###### Torch-MLIR Compilation ######
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
return len(node_arg) == 0
|
||||
return False
|
||||
|
||||
def transform_fx(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
@make_simple_dynamo_backend
|
||||
def refbackend_torchdynamo_backend(
|
||||
fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
|
||||
):
|
||||
# handling usage of empty tensor without initializing
|
||||
transform_fx(fx_graph)
|
||||
fx_graph.recompile()
|
||||
if _returns_nothing(fx_graph):
|
||||
return fx_graph
|
||||
removed_none_indexes = _remove_nones(fx_graph)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
|
||||
|
||||
mlir_module = torch_mlir.compile(
|
||||
fx_graph, example_inputs, output_type="linalg-on-tensors"
|
||||
)
|
||||
|
||||
bytecode_stream = BytesIO()
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=args.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
def compiled_callable(*inputs):
|
||||
inputs = [x.numpy() for x in inputs]
|
||||
result = shark_module("forward", inputs)
|
||||
if was_unwrapped:
|
||||
result = [
|
||||
result,
|
||||
]
|
||||
if not isinstance(result, list):
|
||||
result = torch.from_numpy(result)
|
||||
else:
|
||||
result = tuple(torch.from_numpy(x) for x in result)
|
||||
result = list(result)
|
||||
for removed_index in removed_none_indexes:
|
||||
result.insert(removed_index, None)
|
||||
result = tuple(result)
|
||||
return result
|
||||
|
||||
return compiled_callable
|
||||
|
||||
def predictions(torch_func, jit_func, batchA, batchB):
|
||||
res = jit_func(batchA.numpy(), batchB.numpy())
|
||||
if res is not None:
|
||||
# prediction = torch.from_numpy(res)
|
||||
prediction = res
|
||||
else:
|
||||
prediction = None
|
||||
return prediction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
train_batch_size = hyperparameters["train_batch_size"]
|
||||
gradient_accumulation_steps = hyperparameters[
|
||||
"gradient_accumulation_steps"
|
||||
]
|
||||
learning_rate = hyperparameters["learning_rate"]
|
||||
if hyperparameters["scale_lr"]:
|
||||
learning_rate = (
|
||||
learning_rate
|
||||
* gradient_accumulation_steps
|
||||
* train_batch_size
|
||||
# * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = torch.optim.AdamW(
|
||||
lora_layers.parameters(), # only optimize the embeddings
|
||||
lr=learning_rate,
|
||||
)
|
||||
|
||||
# Training function
|
||||
def train_func(batch_pixel_values, batch_input_ids):
|
||||
# Convert images to latent space
|
||||
latents = shark_vae(batch_pixel_values).sample().detach()
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
noise_scheduler.num_train_timesteps,
|
||||
(bsz,),
|
||||
device=latents.device,
|
||||
).long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch_input_ids)[0]
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = shark_unet(
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states,
|
||||
)
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
|
||||
)
|
||||
|
||||
loss = (
|
||||
F.mse_loss(noise_pred, target, reduction="none")
|
||||
.mean([1, 2, 3])
|
||||
.mean()
|
||||
)
|
||||
loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
return loss
|
||||
|
||||
def training_function():
|
||||
max_train_steps = hyperparameters["max_train_steps"]
|
||||
output_dir = hyperparameters["output_dir"]
|
||||
gradient_checkpointing = hyperparameters["gradient_checkpointing"]
|
||||
|
||||
train_dataloader = create_dataloader(train_batch_size)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(
|
||||
len(train_dataloader) / gradient_accumulation_steps
|
||||
)
|
||||
num_train_epochs = math.ceil(
|
||||
max_train_steps / num_update_steps_per_epoch
|
||||
)
|
||||
|
||||
# Train!
|
||||
total_batch_size = (
|
||||
train_batch_size
|
||||
* gradient_accumulation_steps
|
||||
# train_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
||||
)
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(
|
||||
f" Instantaneous batch size per device = {train_batch_size}"
|
||||
)
|
||||
logger.info(
|
||||
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
||||
)
|
||||
logger.info(
|
||||
f" Gradient Accumulation steps = {gradient_accumulation_steps}"
|
||||
)
|
||||
logger.info(f" Total optimization steps = {max_train_steps}")
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(
|
||||
# range(max_train_steps), disable=not accelerator.is_local_main_process
|
||||
range(max_train_steps)
|
||||
)
|
||||
progress_bar.set_description("Steps")
|
||||
global_step = 0
|
||||
|
||||
params__ = [
|
||||
i for i in text_encoder.get_input_embeddings().parameters()
|
||||
]
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
unet.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
dynamo_callable = dynamo.optimize(
|
||||
refbackend_torchdynamo_backend
|
||||
)(train_func)
|
||||
lam_func = lambda x, y: dynamo_callable(
|
||||
torch.from_numpy(x), torch.from_numpy(y)
|
||||
)
|
||||
loss = predictions(
|
||||
train_func,
|
||||
lam_func,
|
||||
batch["pixel_values"],
|
||||
batch["input_ids"],
|
||||
)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
logs = {"loss": loss.detach().item()}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= max_train_steps:
|
||||
break
|
||||
|
||||
training_function()
|
||||
|
||||
# Save the lora weights
|
||||
unet.save_attn_procs(args.lora_save_dir)
|
||||
|
||||
for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
|
||||
if param.grad is not None:
|
||||
del param.grad # free some memory
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
if len(args.prompts) != 1:
|
||||
print("Need exactly one prompt for the LoRA word")
|
||||
lora_train(
|
||||
args.prompts[0],
|
||||
args.height,
|
||||
args.width,
|
||||
args.training_steps,
|
||||
args.guidance_scale,
|
||||
args.seed,
|
||||
args.batch_count,
|
||||
args.batch_size,
|
||||
args.scheduler,
|
||||
"None",
|
||||
args.hf_model_id,
|
||||
args.precision,
|
||||
args.device,
|
||||
args.max_length,
|
||||
args.training_images_dir,
|
||||
args.lora_save_dir,
|
||||
)
|
||||
@@ -1,7 +1,6 @@
|
||||
import sys
|
||||
import torch
|
||||
import transformers
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Text2ImagePipeline,
|
||||
@@ -11,24 +10,9 @@ from apps.stable_diffusion.src import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
model_id: str
|
||||
ckpt_loc: str
|
||||
precision: str
|
||||
batch_size: int
|
||||
max_length: int
|
||||
height: int
|
||||
width: int
|
||||
device: str
|
||||
|
||||
|
||||
txt2img_obj = None
|
||||
config_obj = None
|
||||
schedulers = None
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
@@ -54,10 +38,18 @@ def txt2img_inf(
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
):
|
||||
global txt2img_obj
|
||||
global config_obj
|
||||
global schedulers
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
@@ -66,10 +58,6 @@ def txt2img_inf(
|
||||
args.scheduler = scheduler
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
@@ -80,16 +68,21 @@ def txt2img_inf(
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = custom_model
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
"txt2img",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
precision,
|
||||
@@ -98,10 +91,17 @@ def txt2img_inf(
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
)
|
||||
if not txt2img_obj or config_obj != new_config_obj:
|
||||
config_obj = new_config_obj
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.precision = precision
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
@@ -117,35 +117,40 @@ def txt2img_inf(
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
txt2img_obj = Text2ImagePipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
Text2ImagePipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
|
||||
txt2img_obj.scheduler = schedulers[scheduler]
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
txt2img_obj.log = ""
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = txt2img_obj.generate_images(
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
batch_size,
|
||||
@@ -159,25 +164,20 @@ def txt2img_inf(
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
txt2img_obj.log += "\n"
|
||||
yield generated_imgs, generated_imgs[0], txt2img_obj.log
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += (
|
||||
f"\nsteps={steps}, guidance_scale={guidance_scale}, seed={seeds}"
|
||||
)
|
||||
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
|
||||
# text_output += txt2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
|
||||
yield generated_imgs, text_output
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -190,21 +190,23 @@ if __name__ == "__main__":
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
|
||||
txt2img_obj = Text2ImagePipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
use_quantize=args.use_quantize,
|
||||
)
|
||||
|
||||
for current_batch in range(args.batch_count):
|
||||
|
||||
273
apps/stable_diffusion/scripts/upscaler.py
Normal file
273
apps/stable_diffusion/scripts/upscaler.py
Normal file
@@ -0,0 +1,273 @@
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
UpscalerPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_use_tuned = args.use_tuned
|
||||
init_import_mlir = args.import_mlir
|
||||
|
||||
|
||||
# Exposed to UI.
|
||||
def upscaler_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
init_image,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
noise_level: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.seed = seed
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
|
||||
if init_image is None:
|
||||
return None, "An Initial Image is required"
|
||||
image = init_image.convert("RGB").resize((height, width))
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = get_custom_model_pathfile(custom_model)
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
args.height = 128
|
||||
args.width = 128
|
||||
new_config_obj = Config(
|
||||
"upscaler",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
args.height,
|
||||
args.width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
)
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
|
||||
args.use_tuned = init_use_tuned
|
||||
args.import_mlir = init_import_mlir
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
global_obj.set_sd_obj(
|
||||
UpscalerPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_lora=args.use_lora,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
global_obj.get_sd_obj().low_res_scheduler = global_obj.get_scheduler(
|
||||
"DDPM"
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
extra_info = {"NOISE LEVEL": noise_level}
|
||||
for current_batch in range(batch_count):
|
||||
if current_batch > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
low_res_img = image
|
||||
high_res_img = Image.new("RGB", (height * 4, width * 4))
|
||||
|
||||
for i in range(0, width, 128):
|
||||
for j in range(0, height, 128):
|
||||
box = (j, i, j + 128, i + 128)
|
||||
upscaled_image = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
low_res_img.crop(box),
|
||||
batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
steps,
|
||||
noise_level,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
|
||||
|
||||
save_output_img(high_res_img, img_seed, extra_info)
|
||||
generated_imgs.append(high_res_img)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
|
||||
text_output += global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
yield generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
if args.img_path is None:
|
||||
print("Flag --img_path is required.")
|
||||
exit()
|
||||
|
||||
# When the models get uploaded, it should be default to False.
|
||||
args.import_mlir = True
|
||||
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
image = (
|
||||
Image.open(args.img_path)
|
||||
.convert("RGB")
|
||||
.resize((args.height, args.width))
|
||||
)
|
||||
seed = utils.sanitize_seed(args.seed)
|
||||
# Adjust for height and width based on model
|
||||
|
||||
upscaler_obj = UpscalerPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_lora=args.use_lora,
|
||||
ddpm_scheduler=schedulers["DDPM"],
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
generated_imgs = upscaler_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
image,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.steps,
|
||||
args.noise_level,
|
||||
args.guidance_scale,
|
||||
seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, noise_level={args.noise_level}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
text_output += upscaler_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
extra_info = {"NOISE LEVEL": args.noise_level}
|
||||
save_output_img(generated_imgs[0], seed, extra_info)
|
||||
print(text_output)
|
||||
@@ -1,6 +1,7 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
from PyInstaller.utils.hooks import collect_submodules
|
||||
|
||||
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
|
||||
@@ -20,7 +21,9 @@ datas += copy_metadata('omegaconf')
|
||||
datas += copy_metadata('safetensors')
|
||||
datas += collect_data_files('diffusers')
|
||||
datas += collect_data_files('transformers')
|
||||
datas += collect_data_files('pytorch_lightning')
|
||||
datas += collect_data_files('opencv-python')
|
||||
datas += collect_data_files('skimage')
|
||||
datas += collect_data_files('gradio')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
@@ -38,13 +41,15 @@ binaries = []
|
||||
|
||||
block_cipher = None
|
||||
|
||||
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['web/index.py'],
|
||||
pathex=['.'],
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=['shark', 'shark.shark_inference', 'apps'],
|
||||
hiddenimports=hiddenimports,
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import collect_submodules
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
|
||||
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
@@ -21,6 +22,8 @@ datas += copy_metadata('safetensors')
|
||||
datas += collect_data_files('diffusers')
|
||||
datas += collect_data_files('transformers')
|
||||
datas += collect_data_files('opencv-python')
|
||||
datas += collect_data_files('pytorch_lightning')
|
||||
datas += collect_data_files('skimage')
|
||||
datas += collect_data_files('gradio')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
@@ -36,13 +39,15 @@ binaries = []
|
||||
|
||||
block_cipher = None
|
||||
|
||||
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['scripts/txt2img.py'],
|
||||
pathex=['.'],
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=['shark', 'shark.shark_inference', 'apps'],
|
||||
hiddenimports=hiddenimports,
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
|
||||
@@ -12,5 +12,6 @@ from apps.stable_diffusion.src.pipelines import (
|
||||
InpaintPipeline,
|
||||
OutpaintPipeline,
|
||||
StencilPipeline,
|
||||
UpscalerPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import get_schedulers
|
||||
|
||||
@@ -5,6 +5,7 @@ import torch
|
||||
import safetensors.torch
|
||||
import traceback
|
||||
import sys
|
||||
import os
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
get_opt_flags,
|
||||
@@ -16,6 +17,8 @@ from apps.stable_diffusion.src.utils import (
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
get_stencil_model_id,
|
||||
update_lora_weight,
|
||||
)
|
||||
|
||||
|
||||
@@ -30,13 +33,23 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
|
||||
elif shape[i] == "width":
|
||||
new_shape.append(width)
|
||||
elif isinstance(shape[i], str):
|
||||
mul_val = int(shape[i].split("*")[0])
|
||||
if "batch_size" in shape[i]:
|
||||
new_shape.append(batch_size * mul_val)
|
||||
elif "height" in shape[i]:
|
||||
new_shape.append(height * mul_val)
|
||||
elif "width" in shape[i]:
|
||||
new_shape.append(width * mul_val)
|
||||
if "*" in shape[i]:
|
||||
mul_val = int(shape[i].split("*")[0])
|
||||
if "batch_size" in shape[i]:
|
||||
new_shape.append(batch_size * mul_val)
|
||||
elif "height" in shape[i]:
|
||||
new_shape.append(height * mul_val)
|
||||
elif "width" in shape[i]:
|
||||
new_shape.append(width * mul_val)
|
||||
elif "/" in shape[i]:
|
||||
import math
|
||||
div_val = int(shape[i].split("/")[1])
|
||||
if "batch_size" in shape[i]:
|
||||
new_shape.append(math.ceil(batch_size / div_val))
|
||||
elif "height" in shape[i]:
|
||||
new_shape.append(math.ceil(height / div_val))
|
||||
elif "width" in shape[i]:
|
||||
new_shape.append(math.ceil(width / div_val))
|
||||
else:
|
||||
new_shape.append(shape[i])
|
||||
return new_shape
|
||||
@@ -81,7 +94,14 @@ class SharkifyStableDiffusionModel:
|
||||
use_base_vae: bool = False,
|
||||
use_tuned: bool = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
is_inpaint: bool = False
|
||||
debug: bool = False,
|
||||
sharktank_dir: str = "",
|
||||
generate_vmfb: bool = True,
|
||||
is_inpaint: bool = False,
|
||||
is_upscaler: bool = False,
|
||||
use_stencil: str = None,
|
||||
use_lora: str = "",
|
||||
use_quantize: str = None,
|
||||
):
|
||||
self.check_params(max_len, width, height)
|
||||
self.max_len = max_len
|
||||
@@ -89,6 +109,7 @@ class SharkifyStableDiffusionModel:
|
||||
self.width = width // 8
|
||||
self.batch_size = batch_size
|
||||
self.custom_weights = custom_weights
|
||||
self.use_quantize = use_quantize
|
||||
if custom_weights != "":
|
||||
assert custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
@@ -102,7 +123,8 @@ class SharkifyStableDiffusionModel:
|
||||
self.precision = precision
|
||||
self.base_vae = use_base_vae
|
||||
self.model_name = (
|
||||
str(batch_size)
|
||||
"_"
|
||||
+ str(batch_size)
|
||||
+ "_"
|
||||
+ str(max_len)
|
||||
+ "_"
|
||||
@@ -112,12 +134,23 @@ class SharkifyStableDiffusionModel:
|
||||
+ "_"
|
||||
+ precision
|
||||
)
|
||||
print(f'use_tuned? sharkify: {use_tuned}')
|
||||
self.use_tuned = use_tuned
|
||||
if use_tuned:
|
||||
self.model_name = self.model_name + "_tuned"
|
||||
self.model_name = self.model_name + "_" + get_path_stem(self.model_id)
|
||||
self.low_cpu_mem_usage = low_cpu_mem_usage
|
||||
self.is_inpaint = is_inpaint
|
||||
self.is_upscaler = is_upscaler
|
||||
self.use_stencil = get_stencil_model_id(use_stencil)
|
||||
if use_lora != "":
|
||||
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
|
||||
self.use_lora = use_lora
|
||||
|
||||
print(self.model_name)
|
||||
self.debug = debug
|
||||
self.sharktank_dir = sharktank_dir
|
||||
self.generate_vmfb = generate_vmfb
|
||||
|
||||
def get_extended_name_for_all_model(self, mask_to_fetch):
|
||||
model_name = {}
|
||||
@@ -141,10 +174,10 @@ class SharkifyStableDiffusionModel:
|
||||
def check_params(self, max_len, width, height):
|
||||
if not (max_len >= 32 and max_len <= 77):
|
||||
sys.exit("please specify max_len in the range [32, 77].")
|
||||
if not (width % 8 == 0 and width >= 384):
|
||||
sys.exit("width should be greater than 384 and multiple of 8")
|
||||
if not (height % 8 == 0 and height >= 384):
|
||||
sys.exit("height should be greater than 384 and multiple of 8")
|
||||
if not (width % 8 == 0 and width >= 128):
|
||||
sys.exit("width should be greater than 128 and multiple of 8")
|
||||
if not (height % 8 == 0 and height >= 128):
|
||||
sys.exit("height should be greater than 128 and multiple of 8")
|
||||
|
||||
def get_vae_encode(self):
|
||||
class VaeEncodeModel(torch.nn.Module):
|
||||
@@ -170,6 +203,7 @@ class SharkifyStableDiffusionModel:
|
||||
use_tuned=self.use_tuned,
|
||||
model_name=self.model_name["vae_encode"],
|
||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_vae_encode
|
||||
|
||||
@@ -212,27 +246,63 @@ class SharkifyStableDiffusionModel:
|
||||
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
inputs = tuple(self.inputs["vae"])
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
|
||||
if self.debug:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
is_f16=is_f16,
|
||||
use_tuned=self.use_tuned,
|
||||
model_name=self.model_name["vae"],
|
||||
debug=self.debug,
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
def get_vae_upscaler(self):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.vae.decode(input, return_dict=False)[0]
|
||||
x = (x / 2 + 0.5).clamp(0, 1)
|
||||
return x
|
||||
|
||||
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
inputs = tuple(self.inputs["vae"])
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
use_tuned=self.use_tuned,
|
||||
model_name=self.model_name["vae"],
|
||||
extra_args=get_opt_flags("vae", precision="fp32"),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
def get_controlled_unet(self):
|
||||
class ControlledUnetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, model_id=self.model_id, low_cpu_mem_usage=False
|
||||
self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
|
||||
):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
"takuma104/control_sd15_canny", # TODO: ADD with model ID
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(self.unet, use_lora, "unet")
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
@@ -271,18 +341,18 @@ class SharkifyStableDiffusionModel:
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_controlled_unet
|
||||
|
||||
def get_control_net(self):
|
||||
class StencilControlNetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, model_id=self.model_id, low_cpu_mem_usage=False
|
||||
self, model_id=self.use_stencil, low_cpu_mem_usage=False
|
||||
):
|
||||
super().__init__()
|
||||
self.cnet = ControlNetModel.from_pretrained(
|
||||
"takuma104/control_sd15_canny", # TODO: ADD with model ID
|
||||
subfolder="controlnet",
|
||||
model_id,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.in_channels = self.cnet.in_channels
|
||||
@@ -325,18 +395,21 @@ class SharkifyStableDiffusionModel:
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_cnet
|
||||
|
||||
def get_unet(self):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(self.unet, use_lora, "unet")
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
if(args.attention_slicing is not None and args.attention_slicing != "none"):
|
||||
@@ -360,6 +433,53 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return noise_pred
|
||||
|
||||
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
input_mask = [True, True, True, False]
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
|
||||
if self.debug:
|
||||
os.makedirs(
|
||||
save_dir,
|
||||
exist_ok=True,
|
||||
)
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
model_name=self.model_name["unet"],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
debug=self.debug,
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
def get_unet_upscaler(self):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(self, latent, timestep, text_embedding, noise_level):
|
||||
unet_out = self.unet.forward(
|
||||
latent,
|
||||
timestep,
|
||||
text_embedding,
|
||||
noise_level,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
return unet_out
|
||||
|
||||
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
@@ -372,28 +492,41 @@ class SharkifyStableDiffusionModel:
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_unet
|
||||
|
||||
def get_clip(self):
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
|
||||
super().__init__()
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="text_encoder",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(self.text_encoder, use_lora, "text_encoder")
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
|
||||
if self.debug:
|
||||
os.makedirs(
|
||||
save_dir,
|
||||
exist_ok=True,
|
||||
)
|
||||
shark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
tuple(self.inputs["clip"]),
|
||||
model_name=self.model_name["clip"],
|
||||
debug=self.debug,
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("clip", precision="fp32"),
|
||||
base_model_id=self.base_model_id,
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
@@ -421,6 +554,7 @@ class SharkifyStableDiffusionModel:
|
||||
# Compiles Clip, Unet and Vae with `base_model_id` as defining their input
|
||||
# configiration.
|
||||
def compile_all(self, base_model_id, need_vae_encode, need_stencil):
|
||||
self.base_model_id = base_model_id
|
||||
self.inputs = get_input_info(
|
||||
base_models[base_model_id],
|
||||
self.max_len,
|
||||
@@ -428,6 +562,9 @@ class SharkifyStableDiffusionModel:
|
||||
self.height,
|
||||
self.batch_size,
|
||||
)
|
||||
if self.is_upscaler:
|
||||
return self.get_clip(), self.get_unet_upscaler(), self.get_vae_upscaler()
|
||||
|
||||
compiled_controlnet = None
|
||||
compiled_controlled_unet = None
|
||||
compiled_unet = None
|
||||
@@ -435,7 +572,12 @@ class SharkifyStableDiffusionModel:
|
||||
compiled_controlnet = self.get_control_net()
|
||||
compiled_controlled_unet = self.get_controlled_unet()
|
||||
else:
|
||||
compiled_unet = self.get_unet()
|
||||
# TODO: Plug the experimental "int8" support at right place.
|
||||
if self.use_quantize == "int8":
|
||||
from apps.stable_diffusion.src.models.opt_params import get_unet
|
||||
compiled_unet = get_unet()
|
||||
else:
|
||||
compiled_unet = self.get_unet()
|
||||
if self.custom_vae != "":
|
||||
print("Plugging in custom Vae")
|
||||
compiled_vae = self.get_vae()
|
||||
@@ -453,8 +595,8 @@ class SharkifyStableDiffusionModel:
|
||||
# Step 1:
|
||||
# -- Fetch all vmfbs for the model, if present, else delete the lot.
|
||||
need_vae_encode, need_stencil = False, False
|
||||
if args.img_path is not None:
|
||||
if args.use_stencil is not None:
|
||||
if not self.is_upscaler and args.img_path is not None:
|
||||
if self.use_stencil is not None:
|
||||
need_stencil = True
|
||||
else:
|
||||
need_vae_encode = True
|
||||
@@ -512,6 +654,7 @@ class SharkifyStableDiffusionModel:
|
||||
else:
|
||||
compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id, need_vae_encode, need_stencil)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("Retrying with a different base model configuration")
|
||||
continue
|
||||
# -- Once a successful compilation has taken place we'd want to store
|
||||
|
||||
@@ -20,6 +20,15 @@ hf_model_variant_map = {
|
||||
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
|
||||
}
|
||||
|
||||
# TODO: Add the quantized model as a part model_db.json.
|
||||
# This is currently in experimental phase.
|
||||
def get_quantize_model():
|
||||
bucket_key = "gs://shark_tank/prashant_nod"
|
||||
model_key = "unet_int8"
|
||||
iree_flags = get_opt_flags("unet", precision="fp16")
|
||||
if args.height != 512 and args.width != 512 and args.max_length != 77:
|
||||
sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
|
||||
return bucket_key, model_key, iree_flags
|
||||
|
||||
def get_variant_version(hf_model_id):
|
||||
return hf_model_variant_map[hf_model_id]
|
||||
@@ -41,6 +50,12 @@ def get_unet():
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
|
||||
# TODO: Get the quantize model from model_db.json
|
||||
if args.use_quantize == "int8":
|
||||
bk, mk, flags = get_quantize_model()
|
||||
return get_shark_model(bk, mk, flags)
|
||||
|
||||
if "vulkan" not in args.device and args.use_tuned:
|
||||
bucket_key = f"{variant}/{is_tuned}/{args.device}"
|
||||
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
|
||||
|
||||
@@ -13,3 +13,6 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_outpain
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_stencil import (
|
||||
StencilPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_upscaler import (
|
||||
UpscalerPipeline,
|
||||
)
|
||||
|
||||
@@ -20,9 +20,6 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils i
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
|
||||
import cv2
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
|
||||
@@ -0,0 +1,310 @@
|
||||
import inspect
|
||||
import torch
|
||||
import time
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
start_profiling,
|
||||
end_profiling,
|
||||
)
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
if isinstance(image, torch.Tensor):
|
||||
return image
|
||||
elif isinstance(image, Image.Image):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], Image.Image):
|
||||
w, h = image[0].size
|
||||
w, h = map(
|
||||
lambda x: x - x % 64, (w, h)
|
||||
) # resize to integer multiple of 64
|
||||
|
||||
image = [np.array(i.resize((w, h)))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = 2.0 * image - 1.0
|
||||
image = torch.from_numpy(image)
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, dim=0)
|
||||
return image
|
||||
|
||||
|
||||
class UpscalerPipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae: SharkInference,
|
||||
text_encoder: SharkInference,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: SharkInference,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
],
|
||||
low_res_scheduler: Union[
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
],
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
|
||||
self.low_res_scheduler = low_res_scheduler
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
accepts_eta = "eta" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
|
||||
latents = 1 / 0.08333 * (latents.float())
|
||||
latents_numpy = latents
|
||||
if cpu_scheduling:
|
||||
latents_numpy = latents.detach().numpy()
|
||||
|
||||
profile_device = start_profiling(file_path="vae.rdc")
|
||||
vae_start = time.time()
|
||||
images = self.vae("forward", (latents_numpy,))
|
||||
vae_inf_time = (time.time() - vae_start) * 1000
|
||||
end_profiling(profile_device)
|
||||
self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"
|
||||
|
||||
images = torch.from_numpy(images)
|
||||
images = (images.detach().cpu() * 255.0).numpy()
|
||||
images = images.round()
|
||||
|
||||
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
|
||||
pil_images = [Image.fromarray(image) for image in images.numpy()]
|
||||
return pil_images
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
dtype,
|
||||
):
|
||||
latents = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
4,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
self.scheduler.is_scale_input_called = True
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def produce_img_latents(
|
||||
self,
|
||||
latents,
|
||||
image,
|
||||
text_embeddings,
|
||||
guidance_scale,
|
||||
noise_level,
|
||||
total_timesteps,
|
||||
dtype,
|
||||
cpu_scheduling,
|
||||
extra_step_kwargs,
|
||||
return_all_latents=False,
|
||||
):
|
||||
step_time_sum = 0
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
latent_model_input = self.scheduler.scale_model_input(
|
||||
latent_model_input, t
|
||||
)
|
||||
latent_model_input = torch.cat([latent_model_input, image], dim=1)
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
if cpu_scheduling:
|
||||
latent_model_input = latent_model_input.detach().numpy()
|
||||
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
noise_level,
|
||||
),
|
||||
)
|
||||
end_profiling(profile_device)
|
||||
noise_pred = torch.from_numpy(noise_pred)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
if cpu_scheduling:
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs
|
||||
).prev_sample
|
||||
else:
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs
|
||||
)
|
||||
|
||||
latent_history.append(latents)
|
||||
step_time = (time.time() - step_start_time) * 1000
|
||||
# self.log += (
|
||||
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
|
||||
# )
|
||||
step_time_sum += step_time
|
||||
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
if not return_all_latents:
|
||||
return latents
|
||||
all_latents = torch.cat(latent_history, dim=0)
|
||||
return all_latents
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
noise_level,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the inital latent noise. Also handle out of range seeds.
|
||||
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get text embeddings from prompts
|
||||
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
|
||||
|
||||
# 4. Preprocess image
|
||||
image = preprocess(image).to(dtype)
|
||||
|
||||
# 5. Add noise to image
|
||||
noise_level = torch.tensor([noise_level], dtype=torch.long)
|
||||
noise = torch.randn(
|
||||
image.shape,
|
||||
generator=generator,
|
||||
).to(dtype)
|
||||
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
||||
image = torch.cat([image] * 2)
|
||||
noise_level = torch.cat([noise_level] * image.shape[0])
|
||||
|
||||
height, width = image.shape[2:]
|
||||
# Get initial latents
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
eta = 0.0
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
# guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_img_latents(
|
||||
latents=init_latents,
|
||||
image=image,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
noise_level=noise_level,
|
||||
total_timesteps=self.scheduler.timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
extra_step_kwargs=extra_step_kwargs,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
for i in tqdm(range(0, latents.shape[0], batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
|
||||
return all_imgs
|
||||
@@ -7,6 +7,7 @@ import time
|
||||
from typing import Union
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
@@ -30,6 +31,9 @@ from apps.stable_diffusion.src.utils import (
|
||||
end_profiling,
|
||||
)
|
||||
|
||||
SD_STATE_IDLE = "idle"
|
||||
SD_STATE_CANCEL = "cancel"
|
||||
|
||||
|
||||
class StableDiffusionPipeline:
|
||||
def __init__(
|
||||
@@ -57,6 +61,7 @@ class StableDiffusionPipeline:
|
||||
self.scheduler = scheduler
|
||||
# TODO: Implement using logging python utility.
|
||||
self.log = ""
|
||||
self.status = SD_STATE_IDLE
|
||||
|
||||
def encode_prompts(self, prompts, neg_prompts, max_length):
|
||||
# Tokenize text and get embeddings
|
||||
@@ -161,15 +166,6 @@ class StableDiffusionPipeline:
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
down_block_res_samples = control[0:12]
|
||||
mid_block_res_sample = control[12:]
|
||||
down_block_res_samples = [
|
||||
down_block_res_sample * controlnet_conditioning_scale
|
||||
for down_block_res_sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = (
|
||||
mid_block_res_sample[0] * controlnet_conditioning_scale
|
||||
)
|
||||
timestep = timestep.detach().numpy()
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
@@ -181,19 +177,19 @@ class StableDiffusionPipeline:
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
down_block_res_samples[0],
|
||||
down_block_res_samples[1],
|
||||
down_block_res_samples[2],
|
||||
down_block_res_samples[3],
|
||||
down_block_res_samples[4],
|
||||
down_block_res_samples[5],
|
||||
down_block_res_samples[6],
|
||||
down_block_res_samples[7],
|
||||
down_block_res_samples[8],
|
||||
down_block_res_samples[9],
|
||||
down_block_res_samples[10],
|
||||
down_block_res_samples[11],
|
||||
mid_block_res_sample,
|
||||
control[0],
|
||||
control[1],
|
||||
control[2],
|
||||
control[3],
|
||||
control[4],
|
||||
control[5],
|
||||
control[6],
|
||||
control[7],
|
||||
control[8],
|
||||
control[9],
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
@@ -234,6 +230,7 @@ class StableDiffusionPipeline:
|
||||
masked_image_latents=None,
|
||||
return_all_latents=False,
|
||||
):
|
||||
self.status = SD_STATE_IDLE
|
||||
step_time_sum = 0
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
@@ -283,6 +280,9 @@ class StableDiffusionPipeline:
|
||||
# )
|
||||
step_time_sum += step_time
|
||||
|
||||
if self.status == SD_STATE_CANCEL:
|
||||
break
|
||||
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -317,13 +317,22 @@ class StableDiffusionPipeline:
|
||||
use_base_vae: bool,
|
||||
use_tuned: bool,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
use_stencil: bool = False,
|
||||
debug: bool = False,
|
||||
use_stencil: str = None,
|
||||
use_lora: str = "",
|
||||
ddpm_scheduler: DDPMScheduler = None,
|
||||
use_quantize=None,
|
||||
):
|
||||
is_inpaint = cls.__name__ in [
|
||||
"InpaintPipeline",
|
||||
"OutpaintPipeline",
|
||||
]
|
||||
if import_mlir:
|
||||
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
|
||||
if import_mlir or use_lora:
|
||||
if not import_mlir:
|
||||
print(
|
||||
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
|
||||
)
|
||||
mlir_import = SharkifyStableDiffusionModel(
|
||||
model_id,
|
||||
ckpt_loc,
|
||||
@@ -336,7 +345,12 @@ class StableDiffusionPipeline:
|
||||
use_base_vae=use_base_vae,
|
||||
use_tuned=use_tuned,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
debug=debug,
|
||||
is_inpaint=is_inpaint,
|
||||
is_upscaler=is_upscaler,
|
||||
use_stencil=use_stencil,
|
||||
use_lora=use_lora,
|
||||
use_quantize=use_quantize,
|
||||
)
|
||||
if cls.__name__ in [
|
||||
"Image2ImagePipeline",
|
||||
@@ -352,6 +366,12 @@ class StableDiffusionPipeline:
|
||||
return cls(
|
||||
controlnet, vae, clip, get_tokenizer(), unet, scheduler
|
||||
)
|
||||
if cls.__name__ in ["UpscalerPipeline"]:
|
||||
clip, unet, vae = mlir_import()
|
||||
return cls(
|
||||
vae, clip, get_tokenizer(), unet, scheduler, ddpm_scheduler
|
||||
)
|
||||
|
||||
clip, unet, vae = mlir_import()
|
||||
return cls(vae, clip, get_tokenizer(), unet, scheduler)
|
||||
try:
|
||||
@@ -392,6 +412,7 @@ class StableDiffusionPipeline:
|
||||
use_tuned=use_tuned,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
is_inpaint=is_inpaint,
|
||||
is_upscaler=is_upscaler,
|
||||
)
|
||||
if cls.__name__ in [
|
||||
"Image2ImagePipeline",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from diffusers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
DDPMScheduler,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
@@ -19,6 +20,10 @@ def get_schedulers(model_id):
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DDPM"] = DDPMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
|
||||
@@ -90,8 +90,8 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
def _import(self):
|
||||
scaling_model = ScalingModel()
|
||||
self.scaling_model = compile_through_fx(
|
||||
scaling_model,
|
||||
(example_latent, example_sigma),
|
||||
model=scaling_model,
|
||||
inputs=(example_latent, example_sigma),
|
||||
model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
|
||||
@@ -13,6 +13,7 @@ from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.stencils.stencil_utils import (
|
||||
controlnet_hint_conversion,
|
||||
get_stencil_model_id,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.utils import (
|
||||
get_shark_model,
|
||||
@@ -31,4 +32,6 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
get_extended_name,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
get_generation_text_info,
|
||||
update_lora_weight,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,52 @@
|
||||
{
|
||||
"stabilityai/stable-diffusion-x4-upscaler": {
|
||||
"unet": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
7,
|
||||
"8*height",
|
||||
"8*width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
1024
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"noise_level": {
|
||||
"shape": [2],
|
||||
"dtype": "i64"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
"1*batch_size",4,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
"2*batch_size",
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
}
|
||||
},
|
||||
"stabilityai/stable-diffusion-2-1": {
|
||||
"unet": {
|
||||
"latents": {
|
||||
@@ -110,7 +158,7 @@
|
||||
"dtype": "f32"
|
||||
},
|
||||
"controlnet_hint": {
|
||||
"shape": [1, 3, 512, 512],
|
||||
"shape": [1, 3, "8*height", "8*width"],
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
@@ -143,55 +191,55 @@
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control1": {
|
||||
"shape": [2, 320, 64, 64],
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control2": {
|
||||
"shape": [2, 320, 64, 64],
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control3": {
|
||||
"shape": [2, 320, 64, 64],
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control4": {
|
||||
"shape": [2, 320, 32, 32],
|
||||
"shape": [2, 320, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control5": {
|
||||
"shape": [2, 640, 32, 32],
|
||||
"shape": [2, 640, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control6": {
|
||||
"shape": [2, 640, 32, 32],
|
||||
"shape": [2, 640, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control7": {
|
||||
"shape": [2, 640, 16, 16],
|
||||
"shape": [2, 640, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control8": {
|
||||
"shape": [2, 1280, 16, 16],
|
||||
"shape": [2, 1280, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control9": {
|
||||
"shape": [2, 1280, 16, 16],
|
||||
"shape": [2, 1280, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control10": {
|
||||
"shape": [2, 1280, 8, 8],
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control11": {
|
||||
"shape": [2, 1280, 8, 8],
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control12": {
|
||||
"shape": [2, 1280, 8, 8],
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control13": {
|
||||
"shape": [2, 1280, 8, 8],
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
@@ -333,4 +381,4 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,10 +18,15 @@
|
||||
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned",
|
||||
"stablediffusion/v1_4/unet/fp16/length_77/tuned/cuda":"unet_8dec_fp16_cuda_tuned",
|
||||
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
|
||||
"stablediffusion/v1_4/unet/fp32/length_64/untuned":"unet_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
|
||||
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
|
||||
"stablediffusion/v1_4/vae/fp32/length_64/untuned":"vae_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
|
||||
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
|
||||
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp32_CompVis_stable_diffusion_v1_4",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
|
||||
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",
|
||||
|
||||
@@ -76,18 +76,19 @@ def load_winograd_configs():
|
||||
return winograd_config_dir
|
||||
|
||||
|
||||
def load_lower_configs():
|
||||
def load_lower_configs(base_model_id=None):
|
||||
from apps.stable_diffusion.src.models import get_variant_version
|
||||
from apps.stable_diffusion.src.utils.utils import (
|
||||
fetch_and_update_base_model_id,
|
||||
)
|
||||
|
||||
if args.ckpt_loc != "":
|
||||
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
|
||||
else:
|
||||
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
|
||||
if base_model_id == "":
|
||||
base_model_id = args.hf_model_id
|
||||
if not base_model_id:
|
||||
if args.ckpt_loc != "":
|
||||
base_model_id = fetch_and_update_base_model_id(args.ckpt_loc)
|
||||
else:
|
||||
base_model_id = fetch_and_update_base_model_id(args.hf_model_id)
|
||||
if base_model_id == "":
|
||||
base_model_id = args.hf_model_id
|
||||
|
||||
variant, version = get_variant_version(base_model_id)
|
||||
|
||||
@@ -114,7 +115,14 @@ def load_lower_configs():
|
||||
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
|
||||
else:
|
||||
if not spec or spec in ["rdna3", "sm_80"]:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
|
||||
if (
|
||||
version in ["v2_1", "v2_1base"]
|
||||
and args.height == 768
|
||||
and args.width == 768
|
||||
):
|
||||
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
|
||||
|
||||
@@ -212,7 +220,7 @@ def annotate_with_lower_configs(
|
||||
return bytecode
|
||||
|
||||
|
||||
def sd_model_annotation(mlir_model, model_name):
|
||||
def sd_model_annotation(mlir_model, model_name, base_model_id=None):
|
||||
device = get_device()
|
||||
if args.annotation_model == "unet" and device == "vulkan":
|
||||
use_winograd = True
|
||||
@@ -220,7 +228,7 @@ def sd_model_annotation(mlir_model, model_name):
|
||||
winograd_model = annotate_with_winograd(
|
||||
mlir_model, winograd_config_dir, model_name
|
||||
)
|
||||
lowering_config_dir = load_lower_configs()
|
||||
lowering_config_dir = load_lower_configs(base_model_id)
|
||||
tuned_model = annotate_with_lower_configs(
|
||||
winograd_model, lowering_config_dir, model_name, use_winograd
|
||||
)
|
||||
@@ -232,7 +240,7 @@ def sd_model_annotation(mlir_model, model_name):
|
||||
)
|
||||
else:
|
||||
use_winograd = False
|
||||
lowering_config_dir = load_lower_configs()
|
||||
lowering_config_dir = load_lower_configs(base_model_id)
|
||||
tuned_model = annotate_with_lower_configs(
|
||||
mlir_model, lowering_config_dir, model_name, use_winograd
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -6,6 +7,13 @@ def path_expand(s):
|
||||
return Path(s).expanduser().resolve()
|
||||
|
||||
|
||||
def is_valid_file(arg):
|
||||
if not os.path.exists(arg):
|
||||
return None
|
||||
else:
|
||||
return arg
|
||||
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
@@ -61,7 +69,7 @@ p.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=512,
|
||||
choices=range(384, 769, 8),
|
||||
choices=range(128, 769, 8),
|
||||
help="the height of the output image.",
|
||||
)
|
||||
|
||||
@@ -69,7 +77,7 @@ p.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
default=512,
|
||||
choices=range(384, 769, 8),
|
||||
choices=range(128, 769, 8),
|
||||
help="the width of the output image.",
|
||||
)
|
||||
|
||||
@@ -80,6 +88,13 @@ p.add_argument(
|
||||
help="the value to be used for guidance scaling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--noise_level",
|
||||
type=int,
|
||||
default=20,
|
||||
help="the value to be used for noise level of upscaler.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--max_length",
|
||||
type=int,
|
||||
@@ -94,6 +109,31 @@ p.add_argument(
|
||||
help="the strength of change applied on the given input image for img2img",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Stable Diffusion Training Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--lora_save_dir",
|
||||
type=str,
|
||||
default="models/lora/",
|
||||
help="Directory to save the lora fine tuned model",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--training_images_dir",
|
||||
type=str,
|
||||
default="models/lora/training_images/",
|
||||
help="Directory containing images that are an example of the prompt",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--training_steps",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="The no. of steps to train",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Inpainting and Outpainting Params
|
||||
##############################################################################
|
||||
@@ -289,10 +329,25 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--use_stencil",
|
||||
choices=["canny"],
|
||||
choices=["canny", "openpose", "scribble"],
|
||||
help="Enable the stencil feature.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_lora",
|
||||
type=str,
|
||||
default="",
|
||||
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_quantize",
|
||||
type=str,
|
||||
default="none",
|
||||
help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
|
||||
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
@@ -313,7 +368,7 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_large_heap_block_size",
|
||||
default="4147483648",
|
||||
default="2073741824",
|
||||
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
|
||||
)
|
||||
|
||||
@@ -402,6 +457,12 @@ p.add_argument(
|
||||
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--import_debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="if import_mlir is True, saves mlir via the debug option in shark importer. Does nothing if import_mlir is false (the default)",
|
||||
)
|
||||
##############################################################################
|
||||
### Web UI flags
|
||||
##############################################################################
|
||||
@@ -461,3 +522,7 @@ p.add_argument(
|
||||
)
|
||||
|
||||
args, unknown = p.parse_known_args()
|
||||
if args.import_debug:
|
||||
os.environ["IREE_SAVE_TEMPS"] = os.path.join(
|
||||
os.getcwd(), args.hf_model_id.replace("/", "_")
|
||||
)
|
||||
|
||||
2
apps/stable_diffusion/src/utils/stencils/__init__.py
Normal file
2
apps/stable_diffusion/src/utils/stencils/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
|
||||
from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector
|
||||
@@ -0,0 +1,62 @@
|
||||
import requests
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# from annotator.util import annotator_ckpts_path
|
||||
from apps.stable_diffusion.src.utils.stencils.openpose.body import Body
|
||||
from apps.stable_diffusion.src.utils.stencils.openpose.hand import Hand
|
||||
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
|
||||
draw_bodypose,
|
||||
draw_handpose,
|
||||
handDetect,
|
||||
)
|
||||
|
||||
|
||||
body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth"
|
||||
hand_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/hand_pose_model.pth"
|
||||
|
||||
|
||||
class OpenposeDetector:
|
||||
def __init__(self):
|
||||
cwd = Path.cwd()
|
||||
ckpt_path = Path(cwd, "stencil_annotator")
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
body_modelpath = ckpt_path / "body_pose_model.pth"
|
||||
hand_modelpath = ckpt_path / "hand_pose_model.pth"
|
||||
|
||||
if not body_modelpath.is_file():
|
||||
r = requests.get(body_model_path, allow_redirects=True)
|
||||
open(body_modelpath, "wb").write(r.content)
|
||||
if not hand_modelpath.is_file():
|
||||
r = requests.get(hand_model_path, allow_redirects=True)
|
||||
open(hand_modelpath, "wb").write(r.content)
|
||||
|
||||
self.body_estimation = Body(body_modelpath)
|
||||
self.hand_estimation = Hand(hand_modelpath)
|
||||
|
||||
def __call__(self, oriImg, hand=False):
|
||||
oriImg = oriImg[:, :, ::-1].copy()
|
||||
with torch.no_grad():
|
||||
candidate, subset = self.body_estimation(oriImg)
|
||||
canvas = np.zeros_like(oriImg)
|
||||
canvas = draw_bodypose(canvas, candidate, subset)
|
||||
if hand:
|
||||
hands_list = handDetect(candidate, subset, oriImg)
|
||||
all_hand_peaks = []
|
||||
for x, y, w, is_left in hands_list:
|
||||
peaks = self.hand_estimation(
|
||||
oriImg[y : y + w, x : x + w, :]
|
||||
)
|
||||
peaks[:, 0] = np.where(
|
||||
peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x
|
||||
)
|
||||
peaks[:, 1] = np.where(
|
||||
peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y
|
||||
)
|
||||
all_hand_peaks.append(peaks)
|
||||
canvas = draw_handpose(canvas, all_hand_peaks)
|
||||
return canvas, dict(
|
||||
candidate=candidate.tolist(), subset=subset.tolist()
|
||||
)
|
||||
499
apps/stable_diffusion/src/utils/stencils/openpose/body.py
Normal file
499
apps/stable_diffusion/src/utils/stencils/openpose/body.py
Normal file
@@ -0,0 +1,499 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
from scipy.ndimage.filters import gaussian_filter
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
|
||||
make_layers,
|
||||
transfer,
|
||||
padRightDownCorner,
|
||||
)
|
||||
|
||||
|
||||
class BodyPoseModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(BodyPoseModel, self).__init__()
|
||||
|
||||
# these layers have no relu layer
|
||||
no_relu_layers = [
|
||||
"conv5_5_CPM_L1",
|
||||
"conv5_5_CPM_L2",
|
||||
"Mconv7_stage2_L1",
|
||||
"Mconv7_stage2_L2",
|
||||
"Mconv7_stage3_L1",
|
||||
"Mconv7_stage3_L2",
|
||||
"Mconv7_stage4_L1",
|
||||
"Mconv7_stage4_L2",
|
||||
"Mconv7_stage5_L1",
|
||||
"Mconv7_stage5_L2",
|
||||
"Mconv7_stage6_L1",
|
||||
"Mconv7_stage6_L1",
|
||||
]
|
||||
blocks = {}
|
||||
block0 = OrderedDict(
|
||||
[
|
||||
("conv1_1", [3, 64, 3, 1, 1]),
|
||||
("conv1_2", [64, 64, 3, 1, 1]),
|
||||
("pool1_stage1", [2, 2, 0]),
|
||||
("conv2_1", [64, 128, 3, 1, 1]),
|
||||
("conv2_2", [128, 128, 3, 1, 1]),
|
||||
("pool2_stage1", [2, 2, 0]),
|
||||
("conv3_1", [128, 256, 3, 1, 1]),
|
||||
("conv3_2", [256, 256, 3, 1, 1]),
|
||||
("conv3_3", [256, 256, 3, 1, 1]),
|
||||
("conv3_4", [256, 256, 3, 1, 1]),
|
||||
("pool3_stage1", [2, 2, 0]),
|
||||
("conv4_1", [256, 512, 3, 1, 1]),
|
||||
("conv4_2", [512, 512, 3, 1, 1]),
|
||||
("conv4_3_CPM", [512, 256, 3, 1, 1]),
|
||||
("conv4_4_CPM", [256, 128, 3, 1, 1]),
|
||||
]
|
||||
)
|
||||
|
||||
# Stage 1
|
||||
block1_1 = OrderedDict(
|
||||
[
|
||||
("conv5_1_CPM_L1", [128, 128, 3, 1, 1]),
|
||||
("conv5_2_CPM_L1", [128, 128, 3, 1, 1]),
|
||||
("conv5_3_CPM_L1", [128, 128, 3, 1, 1]),
|
||||
("conv5_4_CPM_L1", [128, 512, 1, 1, 0]),
|
||||
("conv5_5_CPM_L1", [512, 38, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
|
||||
block1_2 = OrderedDict(
|
||||
[
|
||||
("conv5_1_CPM_L2", [128, 128, 3, 1, 1]),
|
||||
("conv5_2_CPM_L2", [128, 128, 3, 1, 1]),
|
||||
("conv5_3_CPM_L2", [128, 128, 3, 1, 1]),
|
||||
("conv5_4_CPM_L2", [128, 512, 1, 1, 0]),
|
||||
("conv5_5_CPM_L2", [512, 19, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
blocks["block1_1"] = block1_1
|
||||
blocks["block1_2"] = block1_2
|
||||
|
||||
self.model0 = make_layers(block0, no_relu_layers)
|
||||
|
||||
# Stages 2 - 6
|
||||
for i in range(2, 7):
|
||||
blocks["block%d_1" % i] = OrderedDict(
|
||||
[
|
||||
("Mconv1_stage%d_L1" % i, [185, 128, 7, 1, 3]),
|
||||
("Mconv2_stage%d_L1" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv3_stage%d_L1" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv4_stage%d_L1" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv5_stage%d_L1" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv6_stage%d_L1" % i, [128, 128, 1, 1, 0]),
|
||||
("Mconv7_stage%d_L1" % i, [128, 38, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
|
||||
blocks["block%d_2" % i] = OrderedDict(
|
||||
[
|
||||
("Mconv1_stage%d_L2" % i, [185, 128, 7, 1, 3]),
|
||||
("Mconv2_stage%d_L2" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv3_stage%d_L2" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv4_stage%d_L2" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv5_stage%d_L2" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv6_stage%d_L2" % i, [128, 128, 1, 1, 0]),
|
||||
("Mconv7_stage%d_L2" % i, [128, 19, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
|
||||
for k in blocks.keys():
|
||||
blocks[k] = make_layers(blocks[k], no_relu_layers)
|
||||
|
||||
self.model1_1 = blocks["block1_1"]
|
||||
self.model2_1 = blocks["block2_1"]
|
||||
self.model3_1 = blocks["block3_1"]
|
||||
self.model4_1 = blocks["block4_1"]
|
||||
self.model5_1 = blocks["block5_1"]
|
||||
self.model6_1 = blocks["block6_1"]
|
||||
|
||||
self.model1_2 = blocks["block1_2"]
|
||||
self.model2_2 = blocks["block2_2"]
|
||||
self.model3_2 = blocks["block3_2"]
|
||||
self.model4_2 = blocks["block4_2"]
|
||||
self.model5_2 = blocks["block5_2"]
|
||||
self.model6_2 = blocks["block6_2"]
|
||||
|
||||
def forward(self, x):
|
||||
out1 = self.model0(x)
|
||||
|
||||
out1_1 = self.model1_1(out1)
|
||||
out1_2 = self.model1_2(out1)
|
||||
out2 = torch.cat([out1_1, out1_2, out1], 1)
|
||||
|
||||
out2_1 = self.model2_1(out2)
|
||||
out2_2 = self.model2_2(out2)
|
||||
out3 = torch.cat([out2_1, out2_2, out1], 1)
|
||||
|
||||
out3_1 = self.model3_1(out3)
|
||||
out3_2 = self.model3_2(out3)
|
||||
out4 = torch.cat([out3_1, out3_2, out1], 1)
|
||||
|
||||
out4_1 = self.model4_1(out4)
|
||||
out4_2 = self.model4_2(out4)
|
||||
out5 = torch.cat([out4_1, out4_2, out1], 1)
|
||||
|
||||
out5_1 = self.model5_1(out5)
|
||||
out5_2 = self.model5_2(out5)
|
||||
out6 = torch.cat([out5_1, out5_2, out1], 1)
|
||||
|
||||
out6_1 = self.model6_1(out6)
|
||||
out6_2 = self.model6_2(out6)
|
||||
|
||||
return out6_1, out6_2
|
||||
|
||||
|
||||
class Body(object):
|
||||
def __init__(self, model_path):
|
||||
self.model = BodyPoseModel()
|
||||
if torch.cuda.is_available():
|
||||
self.model = self.model.cuda()
|
||||
model_dict = transfer(self.model, torch.load(model_path))
|
||||
self.model.load_state_dict(model_dict)
|
||||
self.model.eval()
|
||||
|
||||
def __call__(self, oriImg):
|
||||
scale_search = [0.5]
|
||||
boxsize = 368
|
||||
stride = 8
|
||||
padValue = 128
|
||||
thre1 = 0.1
|
||||
thre2 = 0.05
|
||||
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
|
||||
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
|
||||
paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
|
||||
|
||||
for m in range(len(multiplier)):
|
||||
scale = multiplier[m]
|
||||
imageToTest = cv2.resize(
|
||||
oriImg,
|
||||
(0, 0),
|
||||
fx=scale,
|
||||
fy=scale,
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
imageToTest_padded, pad = padRightDownCorner(
|
||||
imageToTest, stride, padValue
|
||||
)
|
||||
im = (
|
||||
np.transpose(
|
||||
np.float32(imageToTest_padded[:, :, :, np.newaxis]),
|
||||
(3, 2, 0, 1),
|
||||
)
|
||||
/ 256
|
||||
- 0.5
|
||||
)
|
||||
im = np.ascontiguousarray(im)
|
||||
|
||||
data = torch.from_numpy(im).float()
|
||||
if torch.cuda.is_available():
|
||||
data = data.cuda()
|
||||
with torch.no_grad():
|
||||
Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
|
||||
Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
|
||||
Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
|
||||
|
||||
# extract outputs, resize, and remove padding
|
||||
heatmap = np.transpose(
|
||||
np.squeeze(Mconv7_stage6_L2), (1, 2, 0)
|
||||
) # output 1 is heatmaps
|
||||
heatmap = cv2.resize(
|
||||
heatmap,
|
||||
(0, 0),
|
||||
fx=stride,
|
||||
fy=stride,
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
heatmap = heatmap[
|
||||
: imageToTest_padded.shape[0] - pad[2],
|
||||
: imageToTest_padded.shape[1] - pad[3],
|
||||
:,
|
||||
]
|
||||
heatmap = cv2.resize(
|
||||
heatmap,
|
||||
(oriImg.shape[1], oriImg.shape[0]),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
|
||||
# paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
|
||||
paf = np.transpose(
|
||||
np.squeeze(Mconv7_stage6_L1), (1, 2, 0)
|
||||
) # output 0 is PAFs
|
||||
paf = cv2.resize(
|
||||
paf,
|
||||
(0, 0),
|
||||
fx=stride,
|
||||
fy=stride,
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
paf = paf[
|
||||
: imageToTest_padded.shape[0] - pad[2],
|
||||
: imageToTest_padded.shape[1] - pad[3],
|
||||
:,
|
||||
]
|
||||
paf = cv2.resize(
|
||||
paf,
|
||||
(oriImg.shape[1], oriImg.shape[0]),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
|
||||
heatmap_avg += heatmap_avg + heatmap / len(multiplier)
|
||||
paf_avg += +paf / len(multiplier)
|
||||
|
||||
all_peaks = []
|
||||
peak_counter = 0
|
||||
|
||||
for part in range(18):
|
||||
map_ori = heatmap_avg[:, :, part]
|
||||
one_heatmap = gaussian_filter(map_ori, sigma=3)
|
||||
|
||||
map_left = np.zeros(one_heatmap.shape)
|
||||
map_left[1:, :] = one_heatmap[:-1, :]
|
||||
map_right = np.zeros(one_heatmap.shape)
|
||||
map_right[:-1, :] = one_heatmap[1:, :]
|
||||
map_up = np.zeros(one_heatmap.shape)
|
||||
map_up[:, 1:] = one_heatmap[:, :-1]
|
||||
map_down = np.zeros(one_heatmap.shape)
|
||||
map_down[:, :-1] = one_heatmap[:, 1:]
|
||||
|
||||
peaks_binary = np.logical_and.reduce(
|
||||
(
|
||||
one_heatmap >= map_left,
|
||||
one_heatmap >= map_right,
|
||||
one_heatmap >= map_up,
|
||||
one_heatmap >= map_down,
|
||||
one_heatmap > thre1,
|
||||
)
|
||||
)
|
||||
peaks = list(
|
||||
zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])
|
||||
) # note reverse
|
||||
peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
|
||||
peak_id = range(peak_counter, peak_counter + len(peaks))
|
||||
peaks_with_score_and_id = [
|
||||
peaks_with_score[i] + (peak_id[i],)
|
||||
for i in range(len(peak_id))
|
||||
]
|
||||
|
||||
all_peaks.append(peaks_with_score_and_id)
|
||||
peak_counter += len(peaks)
|
||||
|
||||
# find connection in the specified sequence, center 29 is in the position 15
|
||||
limbSeq = [
|
||||
[2, 3],
|
||||
[2, 6],
|
||||
[3, 4],
|
||||
[4, 5],
|
||||
[6, 7],
|
||||
[7, 8],
|
||||
[2, 9],
|
||||
[9, 10],
|
||||
[10, 11],
|
||||
[2, 12],
|
||||
[12, 13],
|
||||
[13, 14],
|
||||
[2, 1],
|
||||
[1, 15],
|
||||
[15, 17],
|
||||
[1, 16],
|
||||
[16, 18],
|
||||
[3, 17],
|
||||
[6, 18],
|
||||
]
|
||||
# the middle joints heatmap correpondence
|
||||
mapIdx = [
|
||||
[31, 32],
|
||||
[39, 40],
|
||||
[33, 34],
|
||||
[35, 36],
|
||||
[41, 42],
|
||||
[43, 44],
|
||||
[19, 20],
|
||||
[21, 22],
|
||||
[23, 24],
|
||||
[25, 26],
|
||||
[27, 28],
|
||||
[29, 30],
|
||||
[47, 48],
|
||||
[49, 50],
|
||||
[53, 54],
|
||||
[51, 52],
|
||||
[55, 56],
|
||||
[37, 38],
|
||||
[45, 46],
|
||||
]
|
||||
|
||||
connection_all = []
|
||||
special_k = []
|
||||
mid_num = 10
|
||||
|
||||
for k in range(len(mapIdx)):
|
||||
score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
|
||||
candA = all_peaks[limbSeq[k][0] - 1]
|
||||
candB = all_peaks[limbSeq[k][1] - 1]
|
||||
nA = len(candA)
|
||||
nB = len(candB)
|
||||
indexA, indexB = limbSeq[k]
|
||||
if nA != 0 and nB != 0:
|
||||
connection_candidate = []
|
||||
for i in range(nA):
|
||||
for j in range(nB):
|
||||
vec = np.subtract(candB[j][:2], candA[i][:2])
|
||||
norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
|
||||
norm = max(0.001, norm)
|
||||
vec = np.divide(vec, norm)
|
||||
|
||||
startend = list(
|
||||
zip(
|
||||
np.linspace(
|
||||
candA[i][0], candB[j][0], num=mid_num
|
||||
),
|
||||
np.linspace(
|
||||
candA[i][1], candB[j][1], num=mid_num
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
vec_x = np.array(
|
||||
[
|
||||
score_mid[
|
||||
int(round(startend[I][1])),
|
||||
int(round(startend[I][0])),
|
||||
0,
|
||||
]
|
||||
for I in range(len(startend))
|
||||
]
|
||||
)
|
||||
vec_y = np.array(
|
||||
[
|
||||
score_mid[
|
||||
int(round(startend[I][1])),
|
||||
int(round(startend[I][0])),
|
||||
1,
|
||||
]
|
||||
for I in range(len(startend))
|
||||
]
|
||||
)
|
||||
|
||||
score_midpts = np.multiply(
|
||||
vec_x, vec[0]
|
||||
) + np.multiply(vec_y, vec[1])
|
||||
score_with_dist_prior = sum(score_midpts) / len(
|
||||
score_midpts
|
||||
) + min(0.5 * oriImg.shape[0] / norm - 1, 0)
|
||||
criterion1 = len(
|
||||
np.nonzero(score_midpts > thre2)[0]
|
||||
) > 0.8 * len(score_midpts)
|
||||
criterion2 = score_with_dist_prior > 0
|
||||
if criterion1 and criterion2:
|
||||
connection_candidate.append(
|
||||
[
|
||||
i,
|
||||
j,
|
||||
score_with_dist_prior,
|
||||
score_with_dist_prior
|
||||
+ candA[i][2]
|
||||
+ candB[j][2],
|
||||
]
|
||||
)
|
||||
|
||||
connection_candidate = sorted(
|
||||
connection_candidate, key=lambda x: x[2], reverse=True
|
||||
)
|
||||
connection = np.zeros((0, 5))
|
||||
for c in range(len(connection_candidate)):
|
||||
i, j, s = connection_candidate[c][0:3]
|
||||
if i not in connection[:, 3] and j not in connection[:, 4]:
|
||||
connection = np.vstack(
|
||||
[connection, [candA[i][3], candB[j][3], s, i, j]]
|
||||
)
|
||||
if len(connection) >= min(nA, nB):
|
||||
break
|
||||
|
||||
connection_all.append(connection)
|
||||
else:
|
||||
special_k.append(k)
|
||||
connection_all.append([])
|
||||
|
||||
# last number in each row is the total parts number of that person
|
||||
# the second last number in each row is the score of the overall configuration
|
||||
subset = -1 * np.ones((0, 20))
|
||||
candidate = np.array(
|
||||
[item for sublist in all_peaks for item in sublist]
|
||||
)
|
||||
|
||||
for k in range(len(mapIdx)):
|
||||
if k not in special_k:
|
||||
partAs = connection_all[k][:, 0]
|
||||
partBs = connection_all[k][:, 1]
|
||||
indexA, indexB = np.array(limbSeq[k]) - 1
|
||||
|
||||
for i in range(len(connection_all[k])): # = 1:size(temp,1)
|
||||
found = 0
|
||||
subset_idx = [-1, -1]
|
||||
for j in range(len(subset)): # 1:size(subset,1):
|
||||
if (
|
||||
subset[j][indexA] == partAs[i]
|
||||
or subset[j][indexB] == partBs[i]
|
||||
):
|
||||
subset_idx[found] = j
|
||||
found += 1
|
||||
|
||||
if found == 1:
|
||||
j = subset_idx[0]
|
||||
if subset[j][indexB] != partBs[i]:
|
||||
subset[j][indexB] = partBs[i]
|
||||
subset[j][-1] += 1
|
||||
subset[j][-2] += (
|
||||
candidate[partBs[i].astype(int), 2]
|
||||
+ connection_all[k][i][2]
|
||||
)
|
||||
elif found == 2: # if found 2 and disjoint, merge them
|
||||
j1, j2 = subset_idx
|
||||
membership = (
|
||||
(subset[j1] >= 0).astype(int)
|
||||
+ (subset[j2] >= 0).astype(int)
|
||||
)[:-2]
|
||||
if len(np.nonzero(membership == 2)[0]) == 0: # merge
|
||||
subset[j1][:-2] += subset[j2][:-2] + 1
|
||||
subset[j1][-2:] += subset[j2][-2:]
|
||||
subset[j1][-2] += connection_all[k][i][2]
|
||||
subset = np.delete(subset, j2, 0)
|
||||
else: # as like found == 1
|
||||
subset[j1][indexB] = partBs[i]
|
||||
subset[j1][-1] += 1
|
||||
subset[j1][-2] += (
|
||||
candidate[partBs[i].astype(int), 2]
|
||||
+ connection_all[k][i][2]
|
||||
)
|
||||
|
||||
# if find no partA in the subset, create a new subset
|
||||
elif not found and k < 17:
|
||||
row = -1 * np.ones(20)
|
||||
row[indexA] = partAs[i]
|
||||
row[indexB] = partBs[i]
|
||||
row[-1] = 2
|
||||
row[-2] = (
|
||||
sum(
|
||||
candidate[
|
||||
connection_all[k][i, :2].astype(int), 2
|
||||
]
|
||||
)
|
||||
+ connection_all[k][i][2]
|
||||
)
|
||||
subset = np.vstack([subset, row])
|
||||
# delete some rows of subset which has few parts occur
|
||||
deleteIdx = []
|
||||
for i in range(len(subset)):
|
||||
if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
|
||||
deleteIdx.append(i)
|
||||
subset = np.delete(subset, deleteIdx, axis=0)
|
||||
|
||||
# candidate: x, y, score, id
|
||||
return candidate, subset
|
||||
205
apps/stable_diffusion/src/utils/stencils/openpose/hand.py
Normal file
205
apps/stable_diffusion/src/utils/stencils/openpose/hand.py
Normal file
@@ -0,0 +1,205 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from scipy.ndimage.filters import gaussian_filter
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from skimage.measure import label
|
||||
from collections import OrderedDict
|
||||
from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import (
|
||||
make_layers,
|
||||
transfer,
|
||||
padRightDownCorner,
|
||||
npmax,
|
||||
)
|
||||
|
||||
|
||||
class HandPoseModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(HandPoseModel, self).__init__()
|
||||
|
||||
# these layers have no relu layer
|
||||
no_relu_layers = [
|
||||
"conv6_2_CPM",
|
||||
"Mconv7_stage2",
|
||||
"Mconv7_stage3",
|
||||
"Mconv7_stage4",
|
||||
"Mconv7_stage5",
|
||||
"Mconv7_stage6",
|
||||
]
|
||||
# stage 1
|
||||
block1_0 = OrderedDict(
|
||||
[
|
||||
("conv1_1", [3, 64, 3, 1, 1]),
|
||||
("conv1_2", [64, 64, 3, 1, 1]),
|
||||
("pool1_stage1", [2, 2, 0]),
|
||||
("conv2_1", [64, 128, 3, 1, 1]),
|
||||
("conv2_2", [128, 128, 3, 1, 1]),
|
||||
("pool2_stage1", [2, 2, 0]),
|
||||
("conv3_1", [128, 256, 3, 1, 1]),
|
||||
("conv3_2", [256, 256, 3, 1, 1]),
|
||||
("conv3_3", [256, 256, 3, 1, 1]),
|
||||
("conv3_4", [256, 256, 3, 1, 1]),
|
||||
("pool3_stage1", [2, 2, 0]),
|
||||
("conv4_1", [256, 512, 3, 1, 1]),
|
||||
("conv4_2", [512, 512, 3, 1, 1]),
|
||||
("conv4_3", [512, 512, 3, 1, 1]),
|
||||
("conv4_4", [512, 512, 3, 1, 1]),
|
||||
("conv5_1", [512, 512, 3, 1, 1]),
|
||||
("conv5_2", [512, 512, 3, 1, 1]),
|
||||
("conv5_3_CPM", [512, 128, 3, 1, 1]),
|
||||
]
|
||||
)
|
||||
|
||||
block1_1 = OrderedDict(
|
||||
[
|
||||
("conv6_1_CPM", [128, 512, 1, 1, 0]),
|
||||
("conv6_2_CPM", [512, 22, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
|
||||
blocks = {}
|
||||
blocks["block1_0"] = block1_0
|
||||
blocks["block1_1"] = block1_1
|
||||
|
||||
# stage 2-6
|
||||
for i in range(2, 7):
|
||||
blocks["block%d" % i] = OrderedDict(
|
||||
[
|
||||
("Mconv1_stage%d" % i, [150, 128, 7, 1, 3]),
|
||||
("Mconv2_stage%d" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv3_stage%d" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv4_stage%d" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv5_stage%d" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv6_stage%d" % i, [128, 128, 1, 1, 0]),
|
||||
("Mconv7_stage%d" % i, [128, 22, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
|
||||
for k in blocks.keys():
|
||||
blocks[k] = make_layers(blocks[k], no_relu_layers)
|
||||
|
||||
self.model1_0 = blocks["block1_0"]
|
||||
self.model1_1 = blocks["block1_1"]
|
||||
self.model2 = blocks["block2"]
|
||||
self.model3 = blocks["block3"]
|
||||
self.model4 = blocks["block4"]
|
||||
self.model5 = blocks["block5"]
|
||||
self.model6 = blocks["block6"]
|
||||
|
||||
def forward(self, x):
|
||||
out1_0 = self.model1_0(x)
|
||||
out1_1 = self.model1_1(out1_0)
|
||||
concat_stage2 = torch.cat([out1_1, out1_0], 1)
|
||||
out_stage2 = self.model2(concat_stage2)
|
||||
concat_stage3 = torch.cat([out_stage2, out1_0], 1)
|
||||
out_stage3 = self.model3(concat_stage3)
|
||||
concat_stage4 = torch.cat([out_stage3, out1_0], 1)
|
||||
out_stage4 = self.model4(concat_stage4)
|
||||
concat_stage5 = torch.cat([out_stage4, out1_0], 1)
|
||||
out_stage5 = self.model5(concat_stage5)
|
||||
concat_stage6 = torch.cat([out_stage5, out1_0], 1)
|
||||
out_stage6 = self.model6(concat_stage6)
|
||||
return out_stage6
|
||||
|
||||
|
||||
class Hand(object):
|
||||
def __init__(self, model_path):
|
||||
self.model = HandPoseModel()
|
||||
if torch.cuda.is_available():
|
||||
self.model = self.model.cuda()
|
||||
model_dict = transfer(self.model, torch.load(model_path))
|
||||
self.model.load_state_dict(model_dict)
|
||||
self.model.eval()
|
||||
|
||||
def __call__(self, oriImg):
|
||||
scale_search = [0.5, 1.0, 1.5, 2.0]
|
||||
# scale_search = [0.5]
|
||||
boxsize = 368
|
||||
stride = 8
|
||||
padValue = 128
|
||||
thre = 0.05
|
||||
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
|
||||
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22))
|
||||
# paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
|
||||
|
||||
for m in range(len(multiplier)):
|
||||
scale = multiplier[m]
|
||||
imageToTest = cv2.resize(
|
||||
oriImg,
|
||||
(0, 0),
|
||||
fx=scale,
|
||||
fy=scale,
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
imageToTest_padded, pad = padRightDownCorner(
|
||||
imageToTest, stride, padValue
|
||||
)
|
||||
im = (
|
||||
np.transpose(
|
||||
np.float32(imageToTest_padded[:, :, :, np.newaxis]),
|
||||
(3, 2, 0, 1),
|
||||
)
|
||||
/ 256
|
||||
- 0.5
|
||||
)
|
||||
im = np.ascontiguousarray(im)
|
||||
|
||||
data = torch.from_numpy(im).float()
|
||||
if torch.cuda.is_available():
|
||||
data = data.cuda()
|
||||
# data = data.permute([2, 0, 1]).unsqueeze(0).float()
|
||||
with torch.no_grad():
|
||||
output = self.model(data).cpu().numpy()
|
||||
# output = self.model(data).numpy()q
|
||||
|
||||
# extract outputs, resize, and remove padding
|
||||
heatmap = np.transpose(
|
||||
np.squeeze(output), (1, 2, 0)
|
||||
) # output 1 is heatmaps
|
||||
heatmap = cv2.resize(
|
||||
heatmap,
|
||||
(0, 0),
|
||||
fx=stride,
|
||||
fy=stride,
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
heatmap = heatmap[
|
||||
: imageToTest_padded.shape[0] - pad[2],
|
||||
: imageToTest_padded.shape[1] - pad[3],
|
||||
:,
|
||||
]
|
||||
heatmap = cv2.resize(
|
||||
heatmap,
|
||||
(oriImg.shape[1], oriImg.shape[0]),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
|
||||
heatmap_avg += heatmap / len(multiplier)
|
||||
|
||||
all_peaks = []
|
||||
for part in range(21):
|
||||
map_ori = heatmap_avg[:, :, part]
|
||||
one_heatmap = gaussian_filter(map_ori, sigma=3)
|
||||
binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
|
||||
# 全部小于阈值
|
||||
if np.sum(binary) == 0:
|
||||
all_peaks.append([0, 0])
|
||||
continue
|
||||
label_img, label_numbers = label(
|
||||
binary, return_num=True, connectivity=binary.ndim
|
||||
)
|
||||
max_index = (
|
||||
np.argmax(
|
||||
[
|
||||
np.sum(map_ori[label_img == i])
|
||||
for i in range(1, label_numbers + 1)
|
||||
]
|
||||
)
|
||||
+ 1
|
||||
)
|
||||
label_img[label_img != max_index] = 0
|
||||
map_ori[label_img == 0] = 0
|
||||
|
||||
y, x = npmax(map_ori)
|
||||
all_peaks.append([x, y])
|
||||
return np.array(all_peaks)
|
||||
@@ -0,0 +1,272 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
import cv2
|
||||
from collections import OrderedDict
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def make_layers(block, no_relu_layers):
|
||||
layers = []
|
||||
for layer_name, v in block.items():
|
||||
if "pool" in layer_name:
|
||||
layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])
|
||||
layers.append((layer_name, layer))
|
||||
else:
|
||||
conv2d = nn.Conv2d(
|
||||
in_channels=v[0],
|
||||
out_channels=v[1],
|
||||
kernel_size=v[2],
|
||||
stride=v[3],
|
||||
padding=v[4],
|
||||
)
|
||||
layers.append((layer_name, conv2d))
|
||||
if layer_name not in no_relu_layers:
|
||||
layers.append(("relu_" + layer_name, nn.ReLU(inplace=True)))
|
||||
|
||||
return nn.Sequential(OrderedDict(layers))
|
||||
|
||||
|
||||
def padRightDownCorner(img, stride, padValue):
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
|
||||
pad = 4 * [None]
|
||||
pad[0] = 0 # up
|
||||
pad[1] = 0 # left
|
||||
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
|
||||
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
|
||||
|
||||
img_padded = img
|
||||
pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
|
||||
img_padded = np.concatenate((pad_up, img_padded), axis=0)
|
||||
pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
|
||||
img_padded = np.concatenate((pad_left, img_padded), axis=1)
|
||||
pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
|
||||
img_padded = np.concatenate((img_padded, pad_down), axis=0)
|
||||
pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
|
||||
img_padded = np.concatenate((img_padded, pad_right), axis=1)
|
||||
|
||||
return img_padded, pad
|
||||
|
||||
|
||||
# transfer caffe model to pytorch which will match the layer name
|
||||
def transfer(model, model_weights):
|
||||
transfered_model_weights = {}
|
||||
for weights_name in model.state_dict().keys():
|
||||
transfered_model_weights[weights_name] = model_weights[
|
||||
".".join(weights_name.split(".")[1:])
|
||||
]
|
||||
return transfered_model_weights
|
||||
|
||||
|
||||
# draw the body keypoint and lims
|
||||
def draw_bodypose(canvas, candidate, subset):
|
||||
stickwidth = 4
|
||||
limbSeq = [
|
||||
[2, 3],
|
||||
[2, 6],
|
||||
[3, 4],
|
||||
[4, 5],
|
||||
[6, 7],
|
||||
[7, 8],
|
||||
[2, 9],
|
||||
[9, 10],
|
||||
[10, 11],
|
||||
[2, 12],
|
||||
[12, 13],
|
||||
[13, 14],
|
||||
[2, 1],
|
||||
[1, 15],
|
||||
[15, 17],
|
||||
[1, 16],
|
||||
[16, 18],
|
||||
[3, 17],
|
||||
[6, 18],
|
||||
]
|
||||
|
||||
colors = [
|
||||
[255, 0, 0],
|
||||
[255, 85, 0],
|
||||
[255, 170, 0],
|
||||
[255, 255, 0],
|
||||
[170, 255, 0],
|
||||
[85, 255, 0],
|
||||
[0, 255, 0],
|
||||
[0, 255, 85],
|
||||
[0, 255, 170],
|
||||
[0, 255, 255],
|
||||
[0, 170, 255],
|
||||
[0, 85, 255],
|
||||
[0, 0, 255],
|
||||
[85, 0, 255],
|
||||
[170, 0, 255],
|
||||
[255, 0, 255],
|
||||
[255, 0, 170],
|
||||
[255, 0, 85],
|
||||
]
|
||||
for i in range(18):
|
||||
for n in range(len(subset)):
|
||||
index = int(subset[n][i])
|
||||
if index == -1:
|
||||
continue
|
||||
x, y = candidate[index][0:2]
|
||||
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
|
||||
for i in range(17):
|
||||
for n in range(len(subset)):
|
||||
index = subset[n][np.array(limbSeq[i]) - 1]
|
||||
if -1 in index:
|
||||
continue
|
||||
cur_canvas = canvas.copy()
|
||||
Y = candidate[index.astype(int), 0]
|
||||
X = candidate[index.astype(int), 1]
|
||||
mX = np.mean(X)
|
||||
mY = np.mean(Y)
|
||||
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
|
||||
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
||||
polygon = cv2.ellipse2Poly(
|
||||
(int(mY), int(mX)),
|
||||
(int(length / 2), stickwidth),
|
||||
int(angle),
|
||||
0,
|
||||
360,
|
||||
1,
|
||||
)
|
||||
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
|
||||
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
|
||||
return canvas
|
||||
|
||||
|
||||
# image drawed by opencv is not good.
|
||||
def draw_handpose(canvas, all_hand_peaks, show_number=False):
|
||||
edges = [
|
||||
[0, 1],
|
||||
[1, 2],
|
||||
[2, 3],
|
||||
[3, 4],
|
||||
[0, 5],
|
||||
[5, 6],
|
||||
[6, 7],
|
||||
[7, 8],
|
||||
[0, 9],
|
||||
[9, 10],
|
||||
[10, 11],
|
||||
[11, 12],
|
||||
[0, 13],
|
||||
[13, 14],
|
||||
[14, 15],
|
||||
[15, 16],
|
||||
[0, 17],
|
||||
[17, 18],
|
||||
[18, 19],
|
||||
[19, 20],
|
||||
]
|
||||
|
||||
for peaks in all_hand_peaks:
|
||||
for ie, e in enumerate(edges):
|
||||
if np.sum(np.all(peaks[e], axis=1) == 0) == 0:
|
||||
x1, y1 = peaks[e[0]]
|
||||
x2, y2 = peaks[e[1]]
|
||||
cv2.line(
|
||||
canvas,
|
||||
(x1, y1),
|
||||
(x2, y2),
|
||||
matplotlib.colors.hsv_to_rgb(
|
||||
[ie / float(len(edges)), 1.0, 1.0]
|
||||
)
|
||||
* 255,
|
||||
thickness=2,
|
||||
)
|
||||
|
||||
for i, keyponit in enumerate(peaks):
|
||||
x, y = keyponit
|
||||
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
|
||||
if show_number:
|
||||
cv2.putText(
|
||||
canvas,
|
||||
str(i),
|
||||
(x, y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.3,
|
||||
(0, 0, 0),
|
||||
lineType=cv2.LINE_AA,
|
||||
)
|
||||
return canvas
|
||||
|
||||
|
||||
# detect hand according to body pose keypoints
|
||||
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
|
||||
def handDetect(candidate, subset, oriImg):
|
||||
# right hand: wrist 4, elbow 3, shoulder 2
|
||||
# left hand: wrist 7, elbow 6, shoulder 5
|
||||
ratioWristElbow = 0.33
|
||||
detect_result = []
|
||||
image_height, image_width = oriImg.shape[0:2]
|
||||
for person in subset.astype(int):
|
||||
# if any of three not detected
|
||||
has_left = np.sum(person[[5, 6, 7]] == -1) == 0
|
||||
has_right = np.sum(person[[2, 3, 4]] == -1) == 0
|
||||
if not (has_left or has_right):
|
||||
continue
|
||||
hands = []
|
||||
# left hand
|
||||
if has_left:
|
||||
left_shoulder_index, left_elbow_index, left_wrist_index = person[
|
||||
[5, 6, 7]
|
||||
]
|
||||
x1, y1 = candidate[left_shoulder_index][:2]
|
||||
x2, y2 = candidate[left_elbow_index][:2]
|
||||
x3, y3 = candidate[left_wrist_index][:2]
|
||||
hands.append([x1, y1, x2, y2, x3, y3, True])
|
||||
# right hand
|
||||
if has_right:
|
||||
(
|
||||
right_shoulder_index,
|
||||
right_elbow_index,
|
||||
right_wrist_index,
|
||||
) = person[[2, 3, 4]]
|
||||
x1, y1 = candidate[right_shoulder_index][:2]
|
||||
x2, y2 = candidate[right_elbow_index][:2]
|
||||
x3, y3 = candidate[right_wrist_index][:2]
|
||||
hands.append([x1, y1, x2, y2, x3, y3, False])
|
||||
|
||||
for x1, y1, x2, y2, x3, y3, is_left in hands:
|
||||
x = x3 + ratioWristElbow * (x3 - x2)
|
||||
y = y3 + ratioWristElbow * (y3 - y2)
|
||||
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
|
||||
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
|
||||
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
|
||||
# x-y refers to the center --> offset to topLeft point
|
||||
x -= width / 2
|
||||
y -= width / 2 # width = height
|
||||
# overflow the image
|
||||
if x < 0:
|
||||
x = 0
|
||||
if y < 0:
|
||||
y = 0
|
||||
width1 = width
|
||||
width2 = width
|
||||
if x + width > image_width:
|
||||
width1 = image_width - x
|
||||
if y + width > image_height:
|
||||
width2 = image_height - y
|
||||
width = min(width1, width2)
|
||||
# the max hand box value is 20 pixels
|
||||
if width >= 20:
|
||||
detect_result.append([int(x), int(y), int(width), is_left])
|
||||
|
||||
"""
|
||||
return value: [[x, y, w, True if left hand else False]].
|
||||
width=height since the network require squared input.
|
||||
x, y is the coordinate of top left
|
||||
"""
|
||||
return detect_result
|
||||
|
||||
|
||||
# get max index of 2d array
|
||||
def npmax(array):
|
||||
arrayindex = array.argmax(1)
|
||||
arrayvalue = array.max(1)
|
||||
i = arrayvalue.argmax()
|
||||
j = arrayindex[i]
|
||||
return (i,)
|
||||
@@ -1,8 +1,10 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
|
||||
from apps.stable_diffusion.src.utils.stencils import (
|
||||
CannyDetector,
|
||||
OpenposeDetector,
|
||||
)
|
||||
|
||||
stencil = {}
|
||||
|
||||
@@ -26,23 +28,6 @@ def HWC3(x):
|
||||
return y
|
||||
|
||||
|
||||
def resize_image(input_image, resolution):
|
||||
H, W, C = input_image.shape
|
||||
H = float(H)
|
||||
W = float(W)
|
||||
k = float(resolution) / min(H, W)
|
||||
H *= k
|
||||
W *= k
|
||||
H = int(np.round(H / 64.0)) * 64
|
||||
W = int(np.round(W / 64.0)) * 64
|
||||
img = cv2.resize(
|
||||
input_image,
|
||||
(W, H),
|
||||
interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA,
|
||||
)
|
||||
return img
|
||||
|
||||
|
||||
def controlnet_hint_shaping(
|
||||
controlnet_hint, height, width, dtype, num_images_per_prompt=1
|
||||
):
|
||||
@@ -125,7 +110,13 @@ def controlnet_hint_conversion(
|
||||
match use_stencil:
|
||||
case "canny":
|
||||
print("Detecting edge with canny")
|
||||
controlnet_hint = hint_canny(image, width)
|
||||
controlnet_hint = hint_canny(image)
|
||||
case "openpose":
|
||||
print("Detecting human pose")
|
||||
controlnet_hint = hint_openpose(image)
|
||||
case "scribble":
|
||||
print("Working with scribble")
|
||||
controlnet_hint = hint_scribble(image)
|
||||
case _:
|
||||
return None
|
||||
controlnet_hint = controlnet_hint_shaping(
|
||||
@@ -134,22 +125,62 @@ def controlnet_hint_conversion(
|
||||
return controlnet_hint
|
||||
|
||||
|
||||
stencil_to_model_id_map = {
|
||||
"canny": "lllyasviel/sd-controlnet-canny",
|
||||
"depth": "lllyasviel/sd-controlnet-depth",
|
||||
"hed": "lllyasviel/sd-controlnet-hed",
|
||||
"mlsd": "lllyasviel/sd-controlnet-mlsd",
|
||||
"normal": "lllyasviel/sd-controlnet-normal",
|
||||
"openpose": "lllyasviel/sd-controlnet-openpose",
|
||||
"scribble": "lllyasviel/sd-controlnet-scribble",
|
||||
"seg": "lllyasviel/sd-controlnet-seg",
|
||||
}
|
||||
|
||||
|
||||
def get_stencil_model_id(use_stencil):
|
||||
if use_stencil in stencil_to_model_id_map:
|
||||
return stencil_to_model_id_map[use_stencil]
|
||||
return None
|
||||
|
||||
|
||||
# Stencil 1. Canny
|
||||
def hint_canny(
|
||||
image: Image.Image,
|
||||
width=512,
|
||||
height=512,
|
||||
low_threshold=100,
|
||||
high_threshold=200,
|
||||
):
|
||||
with torch.no_grad():
|
||||
input_image = np.array(image)
|
||||
image_resolution = width
|
||||
|
||||
img = resize_image(HWC3(input_image), image_resolution)
|
||||
|
||||
if not "canny" in stencil:
|
||||
stencil["canny"] = CannyDetector()
|
||||
detected_map = stencil["canny"](img, low_threshold, high_threshold)
|
||||
detected_map = stencil["canny"](
|
||||
input_image, low_threshold, high_threshold
|
||||
)
|
||||
detected_map = HWC3(detected_map)
|
||||
return detected_map
|
||||
|
||||
|
||||
# Stencil 2. OpenPose.
|
||||
def hint_openpose(
|
||||
image: Image.Image,
|
||||
):
|
||||
with torch.no_grad():
|
||||
input_image = np.array(image)
|
||||
|
||||
if not "openpose" in stencil:
|
||||
stencil["openpose"] = OpenposeDetector()
|
||||
|
||||
detected_map, _ = stencil["openpose"](input_image)
|
||||
detected_map = HWC3(detected_map)
|
||||
return detected_map
|
||||
|
||||
|
||||
# Stencil 3. Scribble.
|
||||
def hint_scribble(image: Image.Image):
|
||||
with torch.no_grad():
|
||||
input_image = np.array(image)
|
||||
|
||||
detected_map = np.zeros_like(input_image, dtype=np.uint8)
|
||||
detected_map[np.min(input_image, axis=2) < 127] = 255
|
||||
return detected_map
|
||||
|
||||
@@ -8,6 +8,9 @@ from csv import DictWriter
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from random import randint
|
||||
import tempfile
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
@@ -20,16 +23,12 @@ from apps.stable_diffusion.src.utils.resources import opt_flags
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
import sys
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
load_pipeline_from_original_stable_diffusion_ckpt,
|
||||
download_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
|
||||
|
||||
def get_extended_name(model_name):
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else "-".join(args.device.split("://"))
|
||||
)
|
||||
device = args.device.split("://", 1)[0]
|
||||
extended_name = "{}_{}".format(model_name, device)
|
||||
return extended_name
|
||||
|
||||
@@ -81,7 +80,7 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
frontend="torch",
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, device=args.device, mlir_dialect="linalg"
|
||||
mlir_model, device=args.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
@@ -94,33 +93,55 @@ def compile_through_fx(
|
||||
is_f16=False,
|
||||
f16_input_mask=None,
|
||||
use_tuned=False,
|
||||
save_dir=tempfile.gettempdir(),
|
||||
debug=False,
|
||||
generate_vmfb=True,
|
||||
extra_args=[],
|
||||
base_model_id=None,
|
||||
):
|
||||
from shark.parser import shark_args
|
||||
|
||||
if "cuda" in args.device:
|
||||
shark_args.enable_tf32 = True
|
||||
|
||||
mlir_module, func_name = import_with_fx(
|
||||
model, inputs, is_f16, f16_input_mask
|
||||
(
|
||||
mlir_module,
|
||||
func_name,
|
||||
) = import_with_fx(
|
||||
model=model,
|
||||
inputs=inputs,
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=f16_input_mask,
|
||||
debug=debug,
|
||||
model_name=model_name,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
|
||||
if use_tuned:
|
||||
if "vae" in model_name.split("_")[0]:
|
||||
args.annotation_model = "vae"
|
||||
mlir_module = sd_model_annotation(mlir_module, model_name)
|
||||
mlir_module = sd_model_annotation(
|
||||
mlir_module, model_name, base_model_id
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
|
||||
if generate_vmfb:
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
del mlir_module
|
||||
gc.collect()
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
del mlir_module
|
||||
gc.collect()
|
||||
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
vulkan_runtime_flags = [
|
||||
@@ -248,8 +269,8 @@ def set_init_device_flags():
|
||||
|
||||
if (
|
||||
args.precision != "fp16"
|
||||
or args.height != 512
|
||||
or args.width != 512
|
||||
or args.height not in [512, 768]
|
||||
or args.width not in [512, 768]
|
||||
or args.batch_size != 1
|
||||
or ("vulkan" not in args.device and "cuda" not in args.device)
|
||||
):
|
||||
@@ -283,6 +304,20 @@ def set_init_device_flags():
|
||||
]:
|
||||
args.use_tuned = False
|
||||
|
||||
elif (
|
||||
args.height == 768
|
||||
and args.width == 768
|
||||
and (
|
||||
base_model_id
|
||||
not in [
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
]
|
||||
or "rdna3" not in args.iree_vulkan_target_triple
|
||||
)
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
if args.use_tuned:
|
||||
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
|
||||
else:
|
||||
@@ -438,7 +473,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
"Loading diffusers' pipeline from original stable diffusion checkpoint"
|
||||
)
|
||||
num_in_channels = 9 if is_inpaint else 4
|
||||
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path=custom_weights,
|
||||
extract_ema=extract_ema,
|
||||
from_safetensors=from_safetensors,
|
||||
@@ -448,6 +483,115 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
print("Loading complete")
|
||||
|
||||
|
||||
def processLoRA(model, use_lora, splitting_prefix):
|
||||
state_dict = ""
|
||||
if ".safetensors" in use_lora:
|
||||
state_dict = load_file(use_lora)
|
||||
else:
|
||||
state_dict = torch.load(use_lora)
|
||||
alpha = 0.75
|
||||
visited = []
|
||||
|
||||
# directly update weight in model
|
||||
process_unet = "te" not in splitting_prefix
|
||||
for key in state_dict:
|
||||
if ".alpha" in key or key in visited:
|
||||
continue
|
||||
|
||||
curr_layer = model
|
||||
if ("text" not in key and process_unet) or (
|
||||
"text" in key and not process_unet
|
||||
):
|
||||
layer_infos = (
|
||||
key.split(".")[0].split(splitting_prefix)[-1].split("_")
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
# find the target layer
|
||||
temp_name = layer_infos.pop(0)
|
||||
while len(layer_infos) > -1:
|
||||
try:
|
||||
curr_layer = curr_layer.__getattr__(temp_name)
|
||||
if len(layer_infos) > 0:
|
||||
temp_name = layer_infos.pop(0)
|
||||
elif len(layer_infos) == 0:
|
||||
break
|
||||
except Exception:
|
||||
if len(temp_name) > 0:
|
||||
temp_name += "_" + layer_infos.pop(0)
|
||||
else:
|
||||
temp_name = layer_infos.pop(0)
|
||||
|
||||
pair_keys = []
|
||||
if "lora_down" in key:
|
||||
pair_keys.append(key.replace("lora_down", "lora_up"))
|
||||
pair_keys.append(key)
|
||||
else:
|
||||
pair_keys.append(key)
|
||||
pair_keys.append(key.replace("lora_up", "lora_down"))
|
||||
|
||||
# update weight
|
||||
if len(state_dict[pair_keys[0]].shape) == 4:
|
||||
weight_up = (
|
||||
state_dict[pair_keys[0]]
|
||||
.squeeze(3)
|
||||
.squeeze(2)
|
||||
.to(torch.float32)
|
||||
)
|
||||
weight_down = (
|
||||
state_dict[pair_keys[1]]
|
||||
.squeeze(3)
|
||||
.squeeze(2)
|
||||
.to(torch.float32)
|
||||
)
|
||||
curr_layer.weight.data += alpha * torch.mm(
|
||||
weight_up, weight_down
|
||||
).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_up = state_dict[pair_keys[0]].to(torch.float32)
|
||||
weight_down = state_dict[pair_keys[1]].to(torch.float32)
|
||||
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
|
||||
# update visited list
|
||||
for item in pair_keys:
|
||||
visited.append(item)
|
||||
return model
|
||||
|
||||
|
||||
def update_lora_weight_for_unet(unet, use_lora):
|
||||
extensions = [".bin", ".safetensors", ".pt"]
|
||||
if not any([extension in use_lora for extension in extensions]):
|
||||
# We assume if it is a HF ID with standalone LoRA weights.
|
||||
unet.load_attn_procs(use_lora)
|
||||
return unet
|
||||
|
||||
main_file_name = get_path_stem(use_lora)
|
||||
if ".bin" in use_lora:
|
||||
main_file_name += ".bin"
|
||||
elif ".safetensors" in use_lora:
|
||||
main_file_name += ".safetensors"
|
||||
elif ".pt" in use_lora:
|
||||
main_file_name += ".pt"
|
||||
else:
|
||||
sys.exit("Only .bin and .safetensors format for LoRA is supported")
|
||||
|
||||
try:
|
||||
dir_name = os.path.dirname(use_lora)
|
||||
unet.load_attn_procs(dir_name, weight_name=main_file_name)
|
||||
return unet
|
||||
except:
|
||||
return processLoRA(unet, use_lora, "lora_unet_")
|
||||
|
||||
|
||||
def update_lora_weight(model, use_lora, model_name):
|
||||
if "unet" in model_name:
|
||||
return update_lora_weight_for_unet(model, use_lora)
|
||||
try:
|
||||
return processLoRA(model, use_lora, "lora_te_")
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def load_vmfb(vmfb_path, model, precision):
|
||||
model = "vae" if "base_vae" in model or "vae_encode" in model else model
|
||||
model = "unet" if "stencil" in model else model
|
||||
@@ -563,7 +707,7 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
|
||||
img_model = args.hf_model_id
|
||||
if args.ckpt_loc:
|
||||
img_model = os.path.basename(args.ckpt_loc)
|
||||
img_model = Path(os.path.basename(args.ckpt_loc)).stem
|
||||
|
||||
if args.output_img_format == "jpg":
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
@@ -603,7 +747,7 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
|
||||
new_entry.update(extra_info)
|
||||
|
||||
with open(csv_path, "a") as csv_obj:
|
||||
with open(csv_path, "a", encoding="utf-8") as csv_obj:
|
||||
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
|
||||
dictwriter_obj.writerow(new_entry)
|
||||
csv_obj.close()
|
||||
@@ -613,3 +757,14 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(new_entry, f, indent=4)
|
||||
|
||||
|
||||
def get_generation_text_info(seeds, device):
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={args.height}x{args.width}, batch_count={args.batch_count}, batch_size={args.batch_size}, max_length={args.max_length}"
|
||||
|
||||
return text_output
|
||||
|
||||
@@ -1,17 +1,22 @@
|
||||
import os
|
||||
import sys
|
||||
import transformers
|
||||
|
||||
if sys.platform == "darwin":
|
||||
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
|
||||
|
||||
import gradio as gr
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src import args, clear_all
|
||||
from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
clear_gradio_tmp_imgs_folder,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import get_custom_model_path
|
||||
|
||||
# clear all gradio tmp images from the last session
|
||||
# Clear all gradio tmp images from the last session
|
||||
clear_gradio_tmp_imgs_folder()
|
||||
# Create the custom model folder if it doesn't already exist
|
||||
get_custom_model_path().mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
@@ -33,23 +38,37 @@ from apps.stable_diffusion.web.ui import (
|
||||
txt2img_sendto_img2img,
|
||||
txt2img_sendto_inpaint,
|
||||
txt2img_sendto_outpaint,
|
||||
txt2img_sendto_upscaler,
|
||||
img2img_web,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_sendto_inpaint,
|
||||
img2img_sendto_outpaint,
|
||||
img2img_sendto_upscaler,
|
||||
inpaint_web,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_sendto_img2img,
|
||||
inpaint_sendto_outpaint,
|
||||
inpaint_sendto_upscaler,
|
||||
outpaint_web,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_sendto_img2img,
|
||||
outpaint_sendto_inpaint,
|
||||
outpaint_sendto_upscaler,
|
||||
upscaler_web,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
lora_train_web,
|
||||
)
|
||||
|
||||
# init global sd pipeline and config
|
||||
global_obj._init()
|
||||
|
||||
|
||||
def register_button_click(button, selectedid, inputs, outputs):
|
||||
button.click(
|
||||
@@ -74,6 +93,12 @@ with gr.Blocks(
|
||||
inpaint_web.render()
|
||||
with gr.TabItem(label="Outpainting", id=3):
|
||||
outpaint_web.render()
|
||||
with gr.TabItem(label="Upscaler", id=4):
|
||||
upscaler_web.render()
|
||||
|
||||
with gr.Tabs(visible=False) as experimental_tabs:
|
||||
with gr.TabItem(label="LoRA Training", id=5):
|
||||
lora_train_web.render()
|
||||
|
||||
register_button_click(
|
||||
txt2img_sendto_img2img,
|
||||
@@ -93,6 +118,12 @@ with gr.Blocks(
|
||||
[txt2img_gallery],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
txt2img_sendto_upscaler,
|
||||
4,
|
||||
[txt2img_gallery],
|
||||
[upscaler_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
img2img_sendto_inpaint,
|
||||
2,
|
||||
@@ -105,6 +136,12 @@ with gr.Blocks(
|
||||
[img2img_gallery],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
img2img_sendto_upscaler,
|
||||
4,
|
||||
[img2img_gallery],
|
||||
[upscaler_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
inpaint_sendto_img2img,
|
||||
1,
|
||||
@@ -117,6 +154,12 @@ with gr.Blocks(
|
||||
[inpaint_gallery],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
inpaint_sendto_upscaler,
|
||||
4,
|
||||
[inpaint_gallery],
|
||||
[upscaler_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
outpaint_sendto_img2img,
|
||||
1,
|
||||
@@ -129,6 +172,30 @@ with gr.Blocks(
|
||||
[outpaint_gallery],
|
||||
[inpaint_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
outpaint_sendto_upscaler,
|
||||
4,
|
||||
[outpaint_gallery],
|
||||
[upscaler_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
upscaler_sendto_img2img,
|
||||
1,
|
||||
[upscaler_gallery],
|
||||
[img2img_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
upscaler_sendto_inpaint,
|
||||
2,
|
||||
[upscaler_gallery],
|
||||
[inpaint_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
upscaler_sendto_outpaint,
|
||||
3,
|
||||
[upscaler_gallery],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
|
||||
|
||||
sd_web.queue()
|
||||
|
||||
@@ -4,6 +4,7 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
txt2img_sendto_img2img,
|
||||
txt2img_sendto_inpaint,
|
||||
txt2img_sendto_outpaint,
|
||||
txt2img_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.img2img_ui import (
|
||||
img2img_web,
|
||||
@@ -11,6 +12,7 @@ from apps.stable_diffusion.web.ui.img2img_ui import (
|
||||
img2img_init_image,
|
||||
img2img_sendto_inpaint,
|
||||
img2img_sendto_outpaint,
|
||||
img2img_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.inpaint_ui import (
|
||||
inpaint_web,
|
||||
@@ -18,6 +20,7 @@ from apps.stable_diffusion.web.ui.inpaint_ui import (
|
||||
inpaint_init_image,
|
||||
inpaint_sendto_img2img,
|
||||
inpaint_sendto_outpaint,
|
||||
inpaint_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.outpaint_ui import (
|
||||
outpaint_web,
|
||||
@@ -25,4 +28,14 @@ from apps.stable_diffusion.web.ui.outpaint_ui import (
|
||||
outpaint_init_image,
|
||||
outpaint_sendto_img2img,
|
||||
outpaint_sendto_inpaint,
|
||||
outpaint_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.upscaler_ui import (
|
||||
upscaler_web,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web
|
||||
|
||||
@@ -1,153 +1,106 @@
|
||||
|
||||
/* Overwrite the Gradio default theme with their .dark theme declarations */
|
||||
/*
|
||||
Apply Gradio dark theme to the default Gradio theme.
|
||||
Procedure to upgrade the dark theme:
|
||||
- Using your browser, visit http://localhost:8080/?__theme=dark
|
||||
- Open your browser inspector, search for the .dark css class
|
||||
- Copy .dark class declarations, apply them here into :root
|
||||
*/
|
||||
|
||||
:root {
|
||||
--color-focus-primary: var(--color-grey-700);
|
||||
--color-focus-secondary: var(--color-grey-600);
|
||||
--color-focus-ring: rgb(55 65 81);
|
||||
--color-background-primary: var(--color-grey-950);
|
||||
--color-background-secondary: var(--color-grey-900);
|
||||
--color-background-tertiary: var(--color-grey-800);
|
||||
--color-text-body: var(--color-grey-100);
|
||||
--color-text-label: var(--color-grey-200);
|
||||
--color-text-placeholder: var(--color-grey);
|
||||
--color-text-subdued: var(--color-grey-400);
|
||||
--color-text-link-base: var(--color-blue-500);
|
||||
--color-text-link-hover: var(--color-blue-400);
|
||||
--color-text-link-visited: var(--color-blue-600);
|
||||
--color-text-link-active: var(--color-blue-500);
|
||||
--color-text-code-background: var(--color-grey-800);
|
||||
--color-text-code-border: color.border-primary;
|
||||
--color-border-primary: var(--color-grey-700);
|
||||
--color-border-secondary: var(--color-grey-600);
|
||||
--color-border-highlight: var(--color-accent-base);
|
||||
--color-accent-base: var(--color-orange-500);
|
||||
--color-accent-light: var(--color-orange-300);
|
||||
--color-accent-dark: var(--color-orange-700);
|
||||
--color-functional-error-base: var(--color-red-400);
|
||||
--color-functional-error-subdued: var(--color-red-300);
|
||||
--color-functional-error-background: var(--color-background-primary);
|
||||
--color-functional-info-base: var(--color-yellow);
|
||||
--color-functional-info-subdued: var(--color-yellow-300);
|
||||
--color-functional-success-base: var(--color-green);
|
||||
--color-functional-success-subdued: var(--color-green-300);
|
||||
--shadow-spread: 2px;
|
||||
--api-background: linear-gradient(to bottom, rgba(255, 216, 180, .05), transparent);
|
||||
--api-pill-background: var(--color-orange-400);
|
||||
--api-pill-border: var(--color-orange-600);
|
||||
--api-pill-text: var(--color-orange-900);
|
||||
--block-border-color: var(--color-border-primary);
|
||||
--block-background: var(--color-background-tertiary);
|
||||
--uploadable-border-color-hover: var(--color-border-primary);
|
||||
--uploadable-border-color-loaded: var(--color-functional-success);
|
||||
--uploadable-text-color: var(--color-text-subdued);
|
||||
--block_label-border-color: var(--color-border-primary);
|
||||
--block_label-icon-color: var(--color-text-label);
|
||||
--block_label-shadow: var(--shadow-drop);
|
||||
--block_label-background: var(--color-background-secondary);
|
||||
--icon_button-icon-color-base: var(--color-text-label);
|
||||
--icon_button-icon-color-hover: var(--color-text-label);
|
||||
--icon_button-background-base: var(--color-background-primary);
|
||||
--icon_button-background-hover: var(--color-background-primary);
|
||||
--icon_button-border-color-base: var(--color-background-primary);
|
||||
--icon_button-border-color-hover: var(--color-border-secondary);
|
||||
--input-text-color: var(--color-text-body);
|
||||
--input-border-color-base: var(--color-border-primary);
|
||||
--input-border-color-hover: var(--color-border-primary);
|
||||
--input-border-color-focus: var(--color-border-primary);
|
||||
--input-background-base: var(--color-background-tertiary);
|
||||
--input-background-hover: var(--color-background-tertiary);
|
||||
--input-background-focus: var(--color-background-tertiary);
|
||||
--input-shadow: var(--shadow-inset);
|
||||
--checkbox-border-color-base: var(--color-border-primary);
|
||||
--checkbox-border-color-hover: var(--color-focus-primary);
|
||||
--checkbox-border-color-focus: var(--color-blue-500);
|
||||
--checkbox-background-base: var(--color-background-primary);
|
||||
--checkbox-background-hover: var(--color-background-primary);
|
||||
--checkbox-background-focus: var(--color-background-primary);
|
||||
--checkbox-background-selected: var(--color-blue-600);
|
||||
--checkbox-label-border-color-base: var(--color-border-primary);
|
||||
--checkbox-label-border-color-hover: var(--color-border-primary);
|
||||
--checkbox-label-border-color-focus: var(--color-border-secondary);
|
||||
--checkbox-label-background-base: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
|
||||
--checkbox-label-background-hover: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
|
||||
--checkbox-label-background-focus: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
|
||||
--form-seperator-color: var(--color-border-primary);
|
||||
--button-primary-border-color-base: var(--color-orange-600);
|
||||
--button-primary-border-color-hover: var(--color-orange-600);
|
||||
--button-primary-border-color-focus: var(--color-orange-600);
|
||||
--button-primary-text-color-base: white;
|
||||
--button-primary-text-color-hover: white;
|
||||
--button-primary-text-color-focus: white;
|
||||
--button-primary-background-base: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-700));
|
||||
--button-primary-background-hover: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-500));
|
||||
--button-primary-background-focus: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-500));
|
||||
--button-secondary-border-color-base: var(--color-grey-600);
|
||||
--button-secondary-border-color-hover: var(--color-grey-600);
|
||||
--button-secondary-border-color-focus: var(--color-grey-600);
|
||||
--button-secondary-text-color-base: white;
|
||||
--button-secondary-text-color-hover: white;
|
||||
--button-secondary-text-color-focus: white;
|
||||
--button-secondary-background-base: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-700));
|
||||
--button-secondary-background-hover: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-600));
|
||||
--button-secondary-background-focus: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-600));
|
||||
--button-cancel-border-color-base: var(--color-red-600);
|
||||
--button-cancel-border-color-hover: var(--color-red-600);
|
||||
--button-cancel-border-color-focus: var(--color-red-600);
|
||||
--button-cancel-text-color-base: white;
|
||||
--button-cancel-text-color-hover: white;
|
||||
--button-cancel-text-color-focus: white;
|
||||
--button-cancel-background-base: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-700));
|
||||
--button-cancel-background-focus: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-500));
|
||||
--button-cancel-background-hover: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-500));
|
||||
--button-plain-border-color-base: var(--color-grey-600);
|
||||
--button-plain-border-color-hover: var(--color-grey-500);
|
||||
--button-plain-border-color-focus: var(--color-grey-500);
|
||||
--button-plain-text-color-base: var(--color-text-body);
|
||||
--button-plain-text-color-hover: var(--color-text-body);
|
||||
--button-plain-text-color-focus: var(--color-text-body);
|
||||
--button-plain-background-base: var(--color-grey-700);
|
||||
--button-plain-background-hover: var(--color-grey-700);
|
||||
--button-plain-background-focus: var(--color-grey-700);
|
||||
--gallery-label-background-base: var(--color-grey-50);
|
||||
--gallery-label-background-hover: var(--color-grey-50);
|
||||
--gallery-label-border-color-base: var(--color-border-primary);
|
||||
--gallery-label-border-color-hover: var(--color-border-primary);
|
||||
--gallery-thumb-background-base: var(--color-grey-900);
|
||||
--gallery-thumb-background-hover: var(--color-grey-900);
|
||||
--gallery-thumb-border-color-base: var(--color-border-primary);
|
||||
--gallery-thumb-border-color-hover: var(--color-accent-base);
|
||||
--gallery-thumb-border-color-focus: var(--color-blue-500);
|
||||
--gallery-thumb-border-color-selected: var(--color-accent-base);
|
||||
--chatbot-border-border-color-base: transparent;
|
||||
--chatbot-border-border-color-latest: transparent;
|
||||
--chatbot-user-background-base: ;
|
||||
--chatbot-user-background-latest: ;
|
||||
--chatbot-user-text-color-base: white;
|
||||
--chatbot-user-text-color-latest: white;
|
||||
--chatbot-bot-background-base: ;
|
||||
--chatbot-bot-background-latest: ;
|
||||
--chatbot-bot-text-color-base: white;
|
||||
--chatbot-bot-text-color-latest: white;
|
||||
--label-gradient-from: var(--color-orange-400);
|
||||
--label-gradient-to: var(--color-orange-600);
|
||||
--table-odd-background: var(--color-grey-900);
|
||||
--table-even-background: var(--color-grey-950);
|
||||
--table-background-edit: transparent;
|
||||
--dataset-gallery-background-base: var(--color-background-primary);
|
||||
--dataset-gallery-background-hover: var(--color-grey-800);
|
||||
--dataset-dataframe-border-base: var(--color-border-primary);
|
||||
--dataset-dataframe-border-hover: var(--color-border-secondary);
|
||||
--dataset-table-background-base: transparent;
|
||||
--dataset-table-background-hover: var(--color-grey-700);
|
||||
--dataset-table-border-base: var(--color-grey-800);
|
||||
--dataset-table-border-hover: var(--color-grey-800);
|
||||
--body-background-fill: var(--background-fill-primary);
|
||||
--body-text-color: var(--neutral-100);
|
||||
--color-accent-soft: var(--neutral-700);
|
||||
--background-fill-primary: var(--neutral-950);
|
||||
--background-fill-secondary: var(--neutral-900);
|
||||
--border-color-accent: var(--neutral-600);
|
||||
--border-color-primary: var(--neutral-700);
|
||||
--link-text-color-active: var(--secondary-500);
|
||||
--link-text-color: var(--secondary-500);
|
||||
--link-text-color-hover: var(--secondary-400);
|
||||
--link-text-color-visited: var(--secondary-600);
|
||||
--body-text-color-subdued: var(--neutral-400);
|
||||
--shadow-spread: 1px;
|
||||
--block-background-fill: var(--neutral-800);
|
||||
--block-border-color: var(--border-color-primary);
|
||||
--block_border_width: None;
|
||||
--block-info-text-color: var(--body-text-color-subdued);
|
||||
--block-label-background-fill: var(--background-fill-secondary);
|
||||
--block-label-border-color: var(--border-color-primary);
|
||||
--block_label_border_width: None;
|
||||
--block-label-text-color: var(--neutral-200);
|
||||
--block_shadow: None;
|
||||
--block_title_background_fill: None;
|
||||
--block_title_border_color: None;
|
||||
--block_title_border_width: None;
|
||||
--block-title-text-color: var(--neutral-200);
|
||||
--panel-background-fill: var(--background-fill-secondary);
|
||||
--panel-border-color: var(--border-color-primary);
|
||||
--panel_border_width: None;
|
||||
--checkbox-background-color: var(--neutral-800);
|
||||
--checkbox-background-color-focus: var(--checkbox-background-color);
|
||||
--checkbox-background-color-hover: var(--checkbox-background-color);
|
||||
--checkbox-background-color-selected: var(--secondary-600);
|
||||
--checkbox-border-color: var(--neutral-700);
|
||||
--checkbox-border-color-focus: var(--secondary-500);
|
||||
--checkbox-border-color-hover: var(--neutral-600);
|
||||
--checkbox-border-color-selected: var(--secondary-600);
|
||||
--checkbox-border-width: var(--input-border-width);
|
||||
--checkbox-label-background-fill: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
|
||||
--checkbox-label-background-fill-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
|
||||
--checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
|
||||
--checkbox-label-border-color: var(--border-color-primary);
|
||||
--checkbox-label-border-color-hover: var(--checkbox-label-border-color);
|
||||
--checkbox-label-border-width: var(--input-border-width);
|
||||
--checkbox-label-text-color: var(--body-text-color);
|
||||
--checkbox-label-text-color-selected: var(--checkbox-label-text-color);
|
||||
--error-background-fill: var(--background-fill-primary);
|
||||
--error-border-color: var(--border-color-primary);
|
||||
--error_border_width: None;
|
||||
--error-text-color: #ef4444;
|
||||
--input-background-fill: var(--neutral-800);
|
||||
--input-background-fill-focus: var(--secondary-600);
|
||||
--input-background-fill-hover: var(--input-background-fill);
|
||||
--input-border-color: var(--border-color-primary);
|
||||
--input-border-color-focus: var(--neutral-700);
|
||||
--input-border-color-hover: var(--input-border-color);
|
||||
--input_border_width: None;
|
||||
--input-placeholder-color: var(--neutral-500);
|
||||
--input_shadow: None;
|
||||
--input-shadow-focus: 0 0 0 var(--shadow-spread) var(--neutral-700), var(--shadow-inset);
|
||||
--loader_color: None;
|
||||
--slider_color: None;
|
||||
--stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600));
|
||||
--table-border-color: var(--neutral-700);
|
||||
--table-even-background-fill: var(--neutral-950);
|
||||
--table-odd-background-fill: var(--neutral-900);
|
||||
--table-row-focus: var(--color-accent-soft);
|
||||
--button-border-width: var(--input-border-width);
|
||||
--button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
|
||||
--button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
|
||||
--button-cancel-border-color: #dc2626;
|
||||
--button-cancel-border-color-hover: var(--button-cancel-border-color);
|
||||
--button-cancel-text-color: white;
|
||||
--button-cancel-text-color-hover: var(--button-cancel-text-color);
|
||||
--button-primary-background-fill: linear-gradient(to bottom right, var(--primary-500), var(--primary-600));
|
||||
--button-primary-background-fill-hover: linear-gradient(to bottom right, var(--primary-500), var(--primary-500));
|
||||
--button-primary-border-color: var(--primary-500);
|
||||
--button-primary-border-color-hover: var(--button-primary-border-color);
|
||||
--button-primary-text-color: white;
|
||||
--button-primary-text-color-hover: var(--button-primary-text-color);
|
||||
--button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700));
|
||||
--button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600));
|
||||
--button-secondary-border-color: var(--neutral-600);
|
||||
--button-secondary-border-color-hover: var(--button-secondary-border-color);
|
||||
--button-secondary-text-color: white;
|
||||
--button-secondary-text-color-hover: var(--button-secondary-text-color);
|
||||
--block-border-width: 1px;
|
||||
--block-label-border-width: 1px;
|
||||
--form-gap-width: 1px;
|
||||
--error-border-width: 1px;
|
||||
--input-border-width: 1px;
|
||||
}
|
||||
|
||||
/* SHARK theme */
|
||||
body {
|
||||
background-color: var(--color-background-primary);
|
||||
}
|
||||
|
||||
/* display in full width for desktop devices */
|
||||
@media (min-width: 1536px)
|
||||
@@ -185,7 +138,7 @@ body {
|
||||
}
|
||||
|
||||
#prompt_box textarea, #negative_prompt_box textarea {
|
||||
background-color: var(--color-background-primary) !important;
|
||||
background-color: var(--background-fill-primary) !important;
|
||||
}
|
||||
|
||||
#prompt_examples {
|
||||
@@ -197,7 +150,6 @@ body {
|
||||
}
|
||||
|
||||
#ui_body {
|
||||
background-color: var(--color-background-secondary) !important;
|
||||
padding: var(--size-2) !important;
|
||||
border-radius: 0.5em !important;
|
||||
}
|
||||
@@ -219,3 +171,29 @@ footer {
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
/* Import Png info box */
|
||||
#txt2img_prompt_image .fixed-height {
|
||||
height: var(--size-32);
|
||||
}
|
||||
|
||||
/* Hide "remove buttons" from ui dropdowns */
|
||||
#custom_model .token-remove.remove-all,
|
||||
#lora_weights .token-remove.remove-all,
|
||||
#scheduler .token-remove.remove-all,
|
||||
#device .token-remove.remove-all,
|
||||
#stencil_model .token-remove.remove-all {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* Hide selected items from ui dropdowns */
|
||||
#custom_model .options .item .inner-item,
|
||||
#scheduler .options .item .inner-item,
|
||||
#device .options .item .inner-item,
|
||||
#stencil_model .options .item .inner-item {
|
||||
display:none;
|
||||
}
|
||||
|
||||
/* Hide the download icon from the nod logo */
|
||||
#top_logo .download {
|
||||
display: none;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
from pathlib import Path
|
||||
import os
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import img2img_inf
|
||||
@@ -9,6 +7,11 @@ from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
get_custom_model_path,
|
||||
get_custom_model_files,
|
||||
scheduler_list,
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,34 +30,18 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
ckpt_path = (
|
||||
Path(args.ckpt_dir)
|
||||
if args.ckpt_dir
|
||||
else Path(Path.cwd(), "models")
|
||||
)
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
types = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
ckpt_files = ["None"]
|
||||
for extn in types:
|
||||
files = glob.glob(os.path.join(ckpt_path, extn))
|
||||
ckpt_files.extend(files)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {ckpt_path})",
|
||||
value=args.ckpt_loc if args.ckpt_loc else "None",
|
||||
choices=ckpt_files
|
||||
+ [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"prompthero/openjourney",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
],
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
@@ -82,21 +69,33 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
with gr.Accordion(label="Stencil Options", open=False):
|
||||
with gr.Row():
|
||||
use_stencil = gr.Dropdown(
|
||||
elem_id="stencil_model",
|
||||
label="Stencil model",
|
||||
value="None",
|
||||
choices=["None", "canny"],
|
||||
choices=["None", "canny", "openpose", "scribble"],
|
||||
)
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="Scheduler",
|
||||
value="PNDM",
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"DPMSolverMultistep",
|
||||
"EulerAncestralDiscrete",
|
||||
],
|
||||
choices=scheduler_list,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -143,24 +142,26 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
1,
|
||||
value=args.strength,
|
||||
step=0.01,
|
||||
label="Strength",
|
||||
label="Denoising Strength",
|
||||
)
|
||||
with gr.Row():
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
@@ -170,24 +171,28 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => Math.floor(Math.random() * 4294967295)",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -213,6 +218,9 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
img2img_sendto_outpaint = gr.Button(
|
||||
value="SendTo Outpaint"
|
||||
)
|
||||
img2img_sendto_upscaler = gr.Button(
|
||||
value="SendTo Upscaler"
|
||||
)
|
||||
|
||||
kwargs = dict(
|
||||
fn=img2img_inf,
|
||||
@@ -237,11 +245,17 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
use_stencil,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
],
|
||||
outputs=[img2img_gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt.submit(**kwargs)
|
||||
negative_prompt.submit(**kwargs)
|
||||
stable_diffusion.click(**kwargs)
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
from pathlib import Path
|
||||
import os
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import inpaint_inf
|
||||
@@ -9,6 +7,11 @@ from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
get_custom_model_path,
|
||||
get_custom_model_files,
|
||||
scheduler_list,
|
||||
predefined_paint_models,
|
||||
cancel_sd,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,30 +30,18 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
ckpt_path = (
|
||||
Path(args.ckpt_dir)
|
||||
if args.ckpt_dir
|
||||
else Path(Path.cwd(), "models")
|
||||
)
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
types = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
ckpt_files = ["None"]
|
||||
for extn in types:
|
||||
files = glob.glob(os.path.join(ckpt_path, extn))
|
||||
ckpt_files.extend(files)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {ckpt_path})",
|
||||
value=args.ckpt_loc if args.ckpt_loc else "None",
|
||||
choices=ckpt_files
|
||||
+ [
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
],
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_paint_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
@@ -78,17 +69,28 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
type="pil",
|
||||
).style(height=350)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="Scheduler",
|
||||
value="PNDM",
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"DPMSolverMultistep",
|
||||
"EulerAncestralDiscrete",
|
||||
],
|
||||
choices=scheduler_list,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -145,21 +147,23 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
with gr.Row():
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
@@ -169,24 +173,28 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => Math.floor(Math.random() * 4294967295)",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -212,6 +220,9 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
inpaint_sendto_outpaint = gr.Button(
|
||||
value="SendTo Outpaint"
|
||||
)
|
||||
inpaint_sendto_upscaler = gr.Button(
|
||||
value="SendTo Upscaler"
|
||||
)
|
||||
|
||||
kwargs = dict(
|
||||
fn=inpaint_inf,
|
||||
@@ -236,11 +247,17 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
],
|
||||
outputs=[inpaint_gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt.submit(**kwargs)
|
||||
negative_prompt.submit(**kwargs)
|
||||
stable_diffusion.click(**kwargs)
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
205
apps/stable_diffusion/web/ui/lora_train_ui.py
Normal file
205
apps/stable_diffusion/web/ui/lora_train_ui.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import lora_train
|
||||
from apps.stable_diffusion.src import prompt_examples, args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
get_custom_model_path,
|
||||
get_custom_model_files,
|
||||
scheduler_list_txt2img,
|
||||
predefined_models,
|
||||
)
|
||||
|
||||
with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="image_dir_box_outer"):
|
||||
training_images_dir = gr.Textbox(
|
||||
label="ImageDirectory",
|
||||
value=args.training_images_dir,
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=scheduler_list_txt2img,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384, 768, value=args.height, step=8, label="Height"
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 768, value=args.width, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1,
|
||||
2000,
|
||||
value=args.training_steps,
|
||||
step=1,
|
||||
label="Training Steps",
|
||||
)
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
train_lora = gr.Button("Train LoRA")
|
||||
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
examples=prompt_examples,
|
||||
inputs=prompt,
|
||||
cache_examples=False,
|
||||
elem_id="prompt_examples",
|
||||
)
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
lines=1,
|
||||
show_label=False,
|
||||
)
|
||||
lora_save_dir = (
|
||||
args.lora_save_dir if args.lora_save_dir else Path.cwd()
|
||||
)
|
||||
lora_save_dir = Path(lora_save_dir, "lora")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Lora at",
|
||||
value=lora_save_dir,
|
||||
)
|
||||
|
||||
kwargs = dict(
|
||||
fn=lora_train,
|
||||
inputs=[
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
training_images_dir,
|
||||
output_loc,
|
||||
],
|
||||
outputs=[std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
train_click = train_lora.click(**kwargs)
|
||||
stop_batch.click(fn=None, cancels=[prompt_submit, train_click])
|
||||
@@ -1,7 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
from pathlib import Path
|
||||
import os
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import outpaint_inf
|
||||
@@ -9,6 +7,11 @@ from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
get_custom_model_path,
|
||||
get_custom_model_files,
|
||||
scheduler_list,
|
||||
predefined_paint_models,
|
||||
cancel_sd,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,30 +30,18 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
ckpt_path = (
|
||||
Path(args.ckpt_dir)
|
||||
if args.ckpt_dir
|
||||
else Path(Path.cwd(), "models")
|
||||
)
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
types = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
ckpt_files = ["None"]
|
||||
for extn in types:
|
||||
files = glob.glob(os.path.join(ckpt_path, extn))
|
||||
ckpt_files.extend(files)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {ckpt_path})",
|
||||
value=args.ckpt_loc if args.ckpt_loc else "None",
|
||||
choices=ckpt_files
|
||||
+ [
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
],
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_paint_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
@@ -75,17 +66,28 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
label="Input Image", type="pil"
|
||||
).style(height=300)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="Scheduler",
|
||||
value="PNDM",
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"DPMSolverMultistep",
|
||||
"EulerAncestralDiscrete",
|
||||
],
|
||||
choices=scheduler_list,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -164,21 +166,23 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
1, 100, value=20, step=1, label="Steps"
|
||||
)
|
||||
with gr.Row():
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
@@ -188,24 +192,28 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => Math.floor(Math.random() * 4294967295)",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -229,6 +237,9 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
with gr.Row():
|
||||
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
outpaint_sendto_upscaler = gr.Button(
|
||||
value="SendTo Upscaler"
|
||||
)
|
||||
|
||||
kwargs = dict(
|
||||
fn=outpaint_inf,
|
||||
@@ -256,11 +267,17 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
],
|
||||
outputs=[outpaint_gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt.submit(**kwargs)
|
||||
negative_prompt.submit(**kwargs)
|
||||
stable_diffusion.click(**kwargs)
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
from pathlib import Path
|
||||
import os
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import txt2img_inf
|
||||
@@ -9,9 +7,13 @@ from apps.stable_diffusion.src import prompt_examples, args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
get_custom_model_path,
|
||||
get_custom_model_files,
|
||||
scheduler_list_txt2img,
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
@@ -27,39 +29,33 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
ckpt_path = (
|
||||
Path(args.ckpt_dir)
|
||||
if args.ckpt_dir
|
||||
else Path(Path.cwd(), "models")
|
||||
)
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
types = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
ckpt_files = ["None"]
|
||||
for extn in types:
|
||||
files = glob.glob(os.path.join(ckpt_path, extn))
|
||||
ckpt_files.extend(files)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {ckpt_path})",
|
||||
value=args.ckpt_loc if args.ckpt_loc else "None",
|
||||
choices=ckpt_files
|
||||
+ [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"prompthero/openjourney",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
],
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Column(scale=10):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Column(scale=1, min_width=170):
|
||||
png_info_img = gr.Image(
|
||||
label="Import PNG info",
|
||||
elem_id="txt2img_prompt_image",
|
||||
type="pil",
|
||||
tool="None",
|
||||
visible=True,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
@@ -74,21 +70,28 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
lines=1,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"LMSDiscrete",
|
||||
"KDPM2Discrete",
|
||||
"DPMSolverMultistep",
|
||||
"EulerDiscrete",
|
||||
"EulerAncestralDiscrete",
|
||||
"SharkEulerDiscrete",
|
||||
],
|
||||
choices=scheduler_list_txt2img,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -138,40 +141,47 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Row():
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => Math.floor(Math.random() * 4294967295)",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
examples=prompt_examples,
|
||||
@@ -205,6 +215,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
txt2img_sendto_outpaint = gr.Button(
|
||||
value="SendTo Outpaint"
|
||||
)
|
||||
txt2img_sendto_upscaler = gr.Button(
|
||||
value="SendTo Upscaler"
|
||||
)
|
||||
|
||||
kwargs = dict(
|
||||
fn=txt2img_inf,
|
||||
@@ -226,11 +239,41 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
],
|
||||
outputs=[txt2img_gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt.submit(**kwargs)
|
||||
negative_prompt.submit(**kwargs)
|
||||
stable_diffusion.click(**kwargs)
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
from apps.stable_diffusion.web.utils.png_metadata import (
|
||||
import_png_metadata,
|
||||
)
|
||||
|
||||
png_info_img.change(
|
||||
fn=import_png_metadata,
|
||||
inputs=[
|
||||
png_info_img,
|
||||
],
|
||||
outputs=[
|
||||
png_info_img,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
scheduler,
|
||||
guidance_scale,
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
],
|
||||
)
|
||||
|
||||
256
apps/stable_diffusion/web/ui/upscaler_ui.py
Normal file
256
apps/stable_diffusion/web/ui/upscaler_ui.py
Normal file
@@ -0,0 +1,256 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import upscaler_inf
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
get_custom_model_path,
|
||||
get_custom_model_files,
|
||||
scheduler_list,
|
||||
predefined_upscaler_models,
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_upscaler_models,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=1,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
upscaler_init_image = gr.Image(
|
||||
label="Input Image", type="pil"
|
||||
).style(height=300)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="Scheduler",
|
||||
value="DDIM",
|
||||
choices=scheduler_list,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=args.save_metadata_to_json,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
128,
|
||||
512,
|
||||
value=args.height,
|
||||
step=128,
|
||||
label="Height",
|
||||
)
|
||||
width = gr.Slider(
|
||||
128,
|
||||
512,
|
||||
value=args.width,
|
||||
step=128,
|
||||
label="Width",
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
noise_level = gr.Slider(
|
||||
0,
|
||||
100,
|
||||
value=args.noise_level,
|
||||
step=1,
|
||||
label="Noise Level",
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
upscaler_gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(grid=[2])
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
lines=1,
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
with gr.Row():
|
||||
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
upscaler_sendto_outpaint = gr.Button(
|
||||
value="SendTo Outpaint"
|
||||
)
|
||||
|
||||
kwargs = dict(
|
||||
fn=upscaler_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
upscaler_init_image,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
noise_level,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
],
|
||||
outputs=[upscaler_gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
|
||||
)
|
||||
@@ -1,6 +1,69 @@
|
||||
import os
|
||||
import sys
|
||||
from apps.stable_diffusion.src import get_available_devices
|
||||
import glob
|
||||
from pathlib import Path
|
||||
from apps.stable_diffusion.src import args
|
||||
from dataclasses import dataclass
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
mode: str
|
||||
model_id: str
|
||||
ckpt_loc: str
|
||||
precision: str
|
||||
batch_size: int
|
||||
max_length: int
|
||||
height: int
|
||||
width: int
|
||||
device: str
|
||||
use_lora: str
|
||||
use_stencil: str
|
||||
|
||||
|
||||
custom_model_filetypes = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
|
||||
scheduler_list = [
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"DPMSolverMultistep",
|
||||
"EulerAncestralDiscrete",
|
||||
]
|
||||
scheduler_list_txt2img = [
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"LMSDiscrete",
|
||||
"KDPM2Discrete",
|
||||
"DPMSolverMultistep",
|
||||
"EulerDiscrete",
|
||||
"EulerAncestralDiscrete",
|
||||
"SharkEulerDiscrete",
|
||||
]
|
||||
|
||||
predefined_models = [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"prompthero/openjourney",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
]
|
||||
|
||||
predefined_paint_models = [
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
]
|
||||
predefined_upscaler_models = [
|
||||
"stabilityai/stable-diffusion-x4-upscaler",
|
||||
]
|
||||
|
||||
|
||||
def resource_path(relative_path):
|
||||
@@ -11,5 +74,56 @@ def resource_path(relative_path):
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
def get_custom_model_path(model="models"):
|
||||
match model:
|
||||
case "models":
|
||||
return Path(Path.cwd(), "models")
|
||||
case "vae":
|
||||
return Path(Path.cwd(), "models/vae")
|
||||
case "lora":
|
||||
return Path(Path.cwd(), "models/lora")
|
||||
case _:
|
||||
return ""
|
||||
|
||||
|
||||
def get_custom_model_pathfile(custom_model_name, model="models"):
|
||||
return os.path.join(get_custom_model_path(model), custom_model_name)
|
||||
|
||||
|
||||
def get_custom_model_files(model="models"):
|
||||
ckpt_files = []
|
||||
file_types = custom_model_filetypes
|
||||
if model == "lora":
|
||||
file_types = custom_model_filetypes + ("*.pt", "*.bin")
|
||||
for extn in file_types:
|
||||
files = [
|
||||
os.path.basename(x)
|
||||
for x in glob.glob(
|
||||
os.path.join(get_custom_model_path(model), extn)
|
||||
)
|
||||
]
|
||||
ckpt_files.extend(files)
|
||||
return sorted(ckpt_files, key=str.casefold)
|
||||
|
||||
|
||||
def get_custom_vae_or_lora_weights(weights, hf_id, model):
|
||||
use_weight = ""
|
||||
if weights == "None" and not hf_id:
|
||||
use_weight = ""
|
||||
elif not hf_id:
|
||||
use_weight = get_custom_model_pathfile(weights, model)
|
||||
else:
|
||||
use_weight = hf_id
|
||||
return use_weight
|
||||
|
||||
|
||||
def cancel_sd():
|
||||
# Try catch it, as gc can delete global_obj.sd_obj while switching model
|
||||
try:
|
||||
global_obj.set_sd_status(SD_STATE_CANCEL)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
available_devices = get_available_devices()
|
||||
|
||||
68
apps/stable_diffusion/web/utils/global_obj.py
Normal file
68
apps/stable_diffusion/web/utils/global_obj.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import gc
|
||||
|
||||
|
||||
"""
|
||||
The global objects include SD pipeline and config.
|
||||
Maintaining the global objects would avoid creating extra pipeline objects when switching modes.
|
||||
Also we could avoid memory leak when switching models by clearing the cache.
|
||||
"""
|
||||
|
||||
|
||||
def _init():
|
||||
global _sd_obj
|
||||
global _config_obj
|
||||
global _schedulers
|
||||
_sd_obj = None
|
||||
_config_obj = None
|
||||
_schedulers = None
|
||||
|
||||
|
||||
def set_sd_obj(value):
|
||||
global _sd_obj
|
||||
_sd_obj = value
|
||||
|
||||
|
||||
def set_sd_scheduler(key):
|
||||
global _sd_obj
|
||||
_sd_obj.scheduler = _schedulers[key]
|
||||
|
||||
|
||||
def set_sd_status(value):
|
||||
global _sd_obj
|
||||
_sd_obj.status = value
|
||||
|
||||
|
||||
def set_cfg_obj(value):
|
||||
global _config_obj
|
||||
_config_obj = value
|
||||
|
||||
|
||||
def set_schedulers(value):
|
||||
global _schedulers
|
||||
_schedulers = value
|
||||
|
||||
|
||||
def get_sd_obj():
|
||||
return _sd_obj
|
||||
|
||||
|
||||
def get_sd_status():
|
||||
return _sd_obj.status
|
||||
|
||||
|
||||
def get_cfg_obj():
|
||||
return _config_obj
|
||||
|
||||
|
||||
def get_scheduler(key):
|
||||
return _schedulers[key]
|
||||
|
||||
|
||||
def clear_cache():
|
||||
global _sd_obj
|
||||
global _config_obj
|
||||
del _sd_obj
|
||||
del _config_obj
|
||||
gc.collect()
|
||||
_sd_obj = None
|
||||
_config_obj = None
|
||||
148
apps/stable_diffusion/web/utils/png_metadata.py
Normal file
148
apps/stable_diffusion/web/utils/png_metadata.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
png_info_img,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
scheduler,
|
||||
guidance_scale,
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
scheduler_list_txt2img,
|
||||
predefined_models,
|
||||
)
|
||||
|
||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
|
||||
re_param = re.compile(re_param_code)
|
||||
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
||||
|
||||
|
||||
def parse_generation_parameters(x: str):
|
||||
res = {}
|
||||
prompt = ""
|
||||
negative_prompt = ""
|
||||
done_with_prompt = False
|
||||
|
||||
*lines, lastline = x.strip().split("\n")
|
||||
if len(re_param.findall(lastline)) < 3:
|
||||
lines.append(lastline)
|
||||
lastline = ""
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
if line.startswith("Negative prompt:"):
|
||||
done_with_prompt = True
|
||||
line = line[16:].strip()
|
||||
|
||||
if done_with_prompt:
|
||||
negative_prompt += ("" if negative_prompt == "" else "\n") + line
|
||||
else:
|
||||
prompt += ("" if prompt == "" else "\n") + line
|
||||
|
||||
res["Prompt"] = prompt
|
||||
res["Negative prompt"] = negative_prompt
|
||||
|
||||
for k, v in re_param.findall(lastline):
|
||||
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
|
||||
m = re_imagesize.match(v)
|
||||
if m is not None:
|
||||
res[k + "-1"] = m.group(1)
|
||||
res[k + "-2"] = m.group(2)
|
||||
else:
|
||||
res[k] = v
|
||||
|
||||
# Missing CLIP skip means it was set to 1 (the default)
|
||||
if "Clip skip" not in res:
|
||||
res["Clip skip"] = "1"
|
||||
|
||||
hypernet = res.get("Hypernet", None)
|
||||
if hypernet is not None:
|
||||
res[
|
||||
"Prompt"
|
||||
] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
|
||||
|
||||
if "Hires resize-1" not in res:
|
||||
res["Hires resize-1"] = 0
|
||||
res["Hires resize-2"] = 0
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def import_png_metadata(pil_data):
|
||||
try:
|
||||
png_info = pil_data.info["parameters"]
|
||||
metadata = parse_generation_parameters(png_info)
|
||||
png_hf_model_id = ""
|
||||
png_custom_model = ""
|
||||
|
||||
if "Model" in metadata:
|
||||
# Remove extension from model info
|
||||
if metadata["Model"].endswith(".safetensors") or metadata[
|
||||
"Model"
|
||||
].endswith(".ckpt"):
|
||||
metadata["Model"] = Path(metadata["Model"]).stem
|
||||
# Check for the model name match with one of the local ckpt or safetensors files
|
||||
if Path(
|
||||
get_custom_model_pathfile(metadata["Model"] + ".ckpt")
|
||||
).is_file():
|
||||
png_custom_model = metadata["Model"] + ".ckpt"
|
||||
if Path(
|
||||
get_custom_model_pathfile(metadata["Model"] + ".safetensors")
|
||||
).is_file():
|
||||
png_custom_model = metadata["Model"] + ".safetensors"
|
||||
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
|
||||
if metadata["Model"] in predefined_models:
|
||||
png_custom_model = metadata["Model"]
|
||||
# If nothing had matched, check vendor/hf_model_id
|
||||
if not png_custom_model and metadata["Model"].count("/"):
|
||||
png_hf_model_id = metadata["Model"]
|
||||
# No matching model was found
|
||||
if not png_custom_model and not png_hf_model_id:
|
||||
print(
|
||||
"Import PNG info: Unable to find a matching model for %s"
|
||||
% metadata["Model"]
|
||||
)
|
||||
|
||||
outputs = {
|
||||
png_info_img: None,
|
||||
negative_prompt: metadata["Negative prompt"],
|
||||
steps: int(metadata["Steps"]),
|
||||
guidance_scale: float(metadata["CFG scale"]),
|
||||
seed: int(metadata["Seed"]),
|
||||
width: float(metadata["Size-1"]),
|
||||
height: float(metadata["Size-2"]),
|
||||
}
|
||||
if "Model" in metadata and png_custom_model:
|
||||
outputs[custom_model] = png_custom_model
|
||||
outputs[hf_model_id] = ""
|
||||
if "Model" in metadata and png_hf_model_id:
|
||||
outputs[custom_model] = "None"
|
||||
outputs[hf_model_id] = png_hf_model_id
|
||||
if "Prompt" in metadata:
|
||||
outputs[prompt] = metadata["Prompt"]
|
||||
if "Sampler" in metadata:
|
||||
if metadata["Sampler"] in scheduler_list_txt2img:
|
||||
outputs[scheduler] = metadata["Sampler"]
|
||||
else:
|
||||
print(
|
||||
"Import PNG info: Unable to find a scheduler for %s"
|
||||
% metadata["Sampler"]
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
except Exception as ex:
|
||||
if pil_data and pil_data.info.get("parameters"):
|
||||
print("import_png_metadata failed with %s" % ex)
|
||||
pass
|
||||
|
||||
return {
|
||||
png_info_img: None,
|
||||
}
|
||||
@@ -2,4 +2,4 @@
|
||||
|
||||
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
|
||||
source $GITHUB_WORKSPACE/shark.venv/bin/activate
|
||||
python generate_sharktank.py
|
||||
python tank/generate_sharktank.py
|
||||
|
||||
@@ -87,11 +87,22 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"dreamlike-art/dreamlike-diffusion-1.0",
|
||||
]
|
||||
counter = 0
|
||||
for import_opt in import_options:
|
||||
for model_name in hf_model_names:
|
||||
if model_name in to_skip:
|
||||
continue
|
||||
for use_tune in tuned_options:
|
||||
if (
|
||||
model_name == "stabilityai/stable-diffusion-2-1"
|
||||
and use_tune == tuned_options[0]
|
||||
):
|
||||
continue
|
||||
elif (
|
||||
model_name == "stabilityai/stable-diffusion-2-1-base"
|
||||
and use_tune == tuned_options[1]
|
||||
):
|
||||
continue
|
||||
command = (
|
||||
[
|
||||
executable, # executable is the python from the venv used to run this
|
||||
@@ -174,9 +185,23 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
else:
|
||||
print(command)
|
||||
print("failed to generate image for this configuration")
|
||||
if "2_1_base" in model_name:
|
||||
with open(dumpfile_name, "r+") as f:
|
||||
output = f.readlines()
|
||||
print("\n".join(output))
|
||||
if model_name == "CompVis/stable-diffusion-v1-4":
|
||||
print("failed a known successful model.")
|
||||
exit(1)
|
||||
if os.name == "nt":
|
||||
counter += 1
|
||||
if counter % 2 == 0:
|
||||
extra_flags.append(
|
||||
"--iree_vulkan_target_triple=rdna2-unknown-windows"
|
||||
)
|
||||
else:
|
||||
if counter != 1:
|
||||
extra_flags.remove(
|
||||
"--iree_vulkan_target_triple=rdna2-unknown-windows"
|
||||
)
|
||||
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
|
||||
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
|
||||
f.write(header)
|
||||
|
||||
@@ -40,7 +40,7 @@ cmake --build build/
|
||||
*Prepare the model*
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvm-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
```
|
||||
*Prepare the input*
|
||||
|
||||
@@ -65,18 +65,18 @@ A tool for benchmarking other models is built and can be invoked with a command
|
||||
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
|
||||
```
|
||||
VAE and Autoencoder are also available
|
||||
```bash
|
||||
# VAE
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
|
||||
|
||||
# CLIP Autoencoder
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
iree-compile --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
|
||||
```
|
||||
|
||||
@@ -6,15 +6,16 @@ from distutils.sysconfig import get_python_lib
|
||||
import fileinput
|
||||
from pathlib import Path
|
||||
|
||||
# Diffusers 0.13.1 fails with transformers __init.py errros in BLIP. So remove it for now until we fork it
|
||||
pix2pix_file = Path(
|
||||
# Temorary workaround for transformers/__init__.py.
|
||||
path_to_tranformers_hook = Path(
|
||||
get_python_lib()
|
||||
+ "/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py"
|
||||
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
|
||||
)
|
||||
if pix2pix_file.exists():
|
||||
print("Removing..%s", pix2pix_file)
|
||||
pix2pix_file.unlink()
|
||||
|
||||
if path_to_tranformers_hook.is_file():
|
||||
pass
|
||||
else:
|
||||
with open(path_to_tranformers_hook, "w") as f:
|
||||
f.write("module_collection_mode = 'pyz+py'")
|
||||
|
||||
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")
|
||||
|
||||
@@ -42,3 +43,16 @@ for line in fileinput.input(path_to_skipfiles, inplace=True):
|
||||
print(line, end="")
|
||||
else:
|
||||
print(line, end="")
|
||||
|
||||
# For getting around scikit-image's packaging, laze_loader has had a patch merged but yet to be released.
|
||||
# Refer: https://github.com/scientific-python/lazy_loader
|
||||
path_to_lazy_loader = Path(get_python_lib() + "/lazy_loader/__init__.py")
|
||||
|
||||
for line in fileinput.input(path_to_lazy_loader, inplace=True):
|
||||
if 'stubfile = filename if filename.endswith("i")' in line:
|
||||
print(
|
||||
' stubfile = (filename if filename.endswith("i") else f"{os.path.splitext(filename)[0]}.pyi")',
|
||||
end="",
|
||||
)
|
||||
else:
|
||||
print(line, end="")
|
||||
|
||||
@@ -10,3 +10,8 @@ requires = [
|
||||
"iree-runtime>=20221022.190",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.black]
|
||||
line-length = 79
|
||||
include = '\.pyi?$'
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ lit
|
||||
pyyaml
|
||||
python-dateutil
|
||||
sacremoses
|
||||
sentencepiece
|
||||
|
||||
# web dependecies.
|
||||
gradio
|
||||
|
||||
@@ -16,7 +16,7 @@ parameterized
|
||||
|
||||
# Add transformers, diffusers and scipy since it most commonly used
|
||||
transformers
|
||||
diffusers @ git+https://github.com/nod-ai/diffusers@stable_stencil
|
||||
diffusers @ git+https://github.com/huggingface/diffusers@main
|
||||
scipy
|
||||
ftfy
|
||||
gradio
|
||||
@@ -24,6 +24,8 @@ altair
|
||||
omegaconf
|
||||
safetensors
|
||||
opencv-python
|
||||
scikit-image
|
||||
pytorch_lightning # for runwayml models
|
||||
|
||||
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
|
||||
pefile
|
||||
|
||||
@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
|
||||
python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install -r requirements.txt
|
||||
pip install --pre torch-mlir==20230228.763 torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
|
||||
Write-Host "Building SHARK..."
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
|
||||
@@ -78,9 +78,9 @@ $PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
|
||||
if [ "$torch_mlir_bin" = true ]; then
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
|
||||
$PYTHON -m pip install --pre --no-cache-dir torch-mlir==20230228.763 -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
|
||||
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
|
||||
else
|
||||
$PYTHON -m pip install --pre torch-mlir==20230228.763 -f https://llvm.github.io/torch-mlir/package-index/
|
||||
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch-mlir"
|
||||
else
|
||||
|
||||
18
shark/examples/shark_inference/llama/README.md
Normal file
18
shark/examples/shark_inference/llama/README.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# SHARK LLaMA
|
||||
|
||||
## TORCH-MLIR Version
|
||||
|
||||
```
|
||||
https://github.com/nod-ai/torch-mlir.git
|
||||
```
|
||||
Then check out the `complex` branch and `git submodule update --init` and then build with `.\build_tools\python_deploy\build_windows.ps1`
|
||||
|
||||
### Setup & Run
|
||||
```
|
||||
git clone https://github.com/nod-ai/llama.git
|
||||
```
|
||||
Then in this repository
|
||||
```
|
||||
pip install -e .
|
||||
python llama/shark_model.py
|
||||
```
|
||||
@@ -30,6 +30,7 @@ import sys
|
||||
import argparse
|
||||
import json
|
||||
import urllib.request
|
||||
import subprocess
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
@@ -310,7 +311,7 @@ def _prepare_attn_mask(
|
||||
|
||||
def download_model(destination_folder, model_name):
|
||||
download_public_file(
|
||||
f"gs://shark_tank/sharded_bloom/{model_name}", destination_folder
|
||||
f"gs://shark_tank/sharded_bloom/{model_name}/", destination_folder
|
||||
)
|
||||
|
||||
|
||||
@@ -634,6 +635,75 @@ def create_mlirs(destination_folder, model_name):
|
||||
)
|
||||
|
||||
|
||||
def run_large_model(
|
||||
token_count,
|
||||
recompile,
|
||||
model_path,
|
||||
prompt,
|
||||
device_list,
|
||||
script_path,
|
||||
device,
|
||||
):
|
||||
f = open(f"{model_path}/prompt.txt", "w+")
|
||||
f.write(prompt)
|
||||
f.close()
|
||||
for i in range(token_count):
|
||||
if i == 0:
|
||||
will_compile = recompile
|
||||
else:
|
||||
will_compile = False
|
||||
f = open(f"{model_path}/prompt.txt", "r")
|
||||
prompt = f.read()
|
||||
f.close()
|
||||
|
||||
subprocess.run(
|
||||
[
|
||||
"python",
|
||||
script_path,
|
||||
model_path,
|
||||
"start",
|
||||
str(will_compile),
|
||||
"cpu",
|
||||
"None",
|
||||
prompt,
|
||||
]
|
||||
)
|
||||
for i in range(config["n_layer"]):
|
||||
if device_list is not None:
|
||||
device_idx = str(device_list[i % len(device_list)])
|
||||
else:
|
||||
device_idx = "None"
|
||||
subprocess.run(
|
||||
[
|
||||
"python",
|
||||
script_path,
|
||||
model_path,
|
||||
str(i),
|
||||
str(will_compile),
|
||||
device,
|
||||
device_idx,
|
||||
prompt,
|
||||
]
|
||||
)
|
||||
subprocess.run(
|
||||
[
|
||||
"python",
|
||||
script_path,
|
||||
model_path,
|
||||
"end",
|
||||
str(will_compile),
|
||||
"cpu",
|
||||
"None",
|
||||
prompt,
|
||||
]
|
||||
)
|
||||
|
||||
f = open(f"{model_path}/prompt.txt", "r")
|
||||
output = f.read()
|
||||
f.close()
|
||||
print(output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(prog="Bloom-560m")
|
||||
parser.add_argument("-p", "--model_path")
|
||||
@@ -644,6 +714,11 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-t", "--token_count", default=10, type=int)
|
||||
parser.add_argument("-m", "--model_name", default="bloom-560m")
|
||||
parser.add_argument("-cm", "--create_mlirs", default=False, type=bool)
|
||||
|
||||
parser.add_argument(
|
||||
"-lm", "--large_model_memory_efficient", default=False, type=bool
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-pr",
|
||||
"--prompt",
|
||||
@@ -651,6 +726,11 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.create_mlirs and args.large_model_memory_efficient:
|
||||
print(
|
||||
"Warning: If you need to use memory efficient mode, you probably want to use 'download' instead"
|
||||
)
|
||||
|
||||
if not os.path.isdir(args.model_path):
|
||||
os.mkdir(args.model_path)
|
||||
|
||||
@@ -674,37 +754,59 @@ if __name__ == "__main__":
|
||||
if args.prompt is not None:
|
||||
input_ids = tokenizer.encode(args.prompt, return_tensors="pt")
|
||||
|
||||
shardedbloom = ShardedBloom(args.model_path)
|
||||
shardedbloom.init_layers(
|
||||
device=args.device, replace=args.recompile, device_idx=args.device_list
|
||||
)
|
||||
shardedbloom.load_layers()
|
||||
if args.large_model_memory_efficient:
|
||||
f = open(f"{args.model_path}/config.json")
|
||||
config = json.load(f)
|
||||
f.close()
|
||||
|
||||
if args.prompt is not None:
|
||||
for _ in range(args.token_count):
|
||||
next_token = shardedbloom.forward_pass(
|
||||
torch.tensor(input_ids), device=args.device
|
||||
)
|
||||
input_ids = torch.cat(
|
||||
[input_ids, next_token.unsqueeze(-1)], dim=-1
|
||||
self_path = os.path.dirname(os.path.abspath(__file__))
|
||||
script_path = os.path.join(self_path, "sharded_bloom_large_models.py")
|
||||
|
||||
if args.prompt is not None:
|
||||
run_large_model(
|
||||
args.token_count,
|
||||
args.recompile,
|
||||
args.model_path,
|
||||
args.prompt,
|
||||
args.device_list,
|
||||
script_path,
|
||||
args.device,
|
||||
)
|
||||
|
||||
print(tokenizer.decode(input_ids.squeeze()))
|
||||
else:
|
||||
while True:
|
||||
prompt = input("Enter Prompt: ")
|
||||
try:
|
||||
token_count = int(
|
||||
input("Enter number of tokens you want to generate: ")
|
||||
)
|
||||
except:
|
||||
print(
|
||||
"Invalid integer entered. Using default value of 10"
|
||||
)
|
||||
token_count = 10
|
||||
|
||||
run_large_model(
|
||||
token_count,
|
||||
args.recompile,
|
||||
args.model_path,
|
||||
prompt,
|
||||
args.device_list,
|
||||
script_path,
|
||||
args.device,
|
||||
)
|
||||
|
||||
else:
|
||||
while True:
|
||||
prompt = input("Enter Prompt: ")
|
||||
try:
|
||||
token_count = int(
|
||||
input("Enter number of tokens you want to generate: ")
|
||||
)
|
||||
except:
|
||||
print("Invalid integer entered. Using default value of 10")
|
||||
token_count = 10
|
||||
shardedbloom = ShardedBloom(args.model_path)
|
||||
shardedbloom.init_layers(
|
||||
device=args.device,
|
||||
replace=args.recompile,
|
||||
device_idx=args.device_list,
|
||||
)
|
||||
shardedbloom.load_layers()
|
||||
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
||||
|
||||
for _ in range(token_count):
|
||||
if args.prompt is not None:
|
||||
for _ in range(args.token_count):
|
||||
next_token = shardedbloom.forward_pass(
|
||||
torch.tensor(input_ids), device=args.device
|
||||
)
|
||||
@@ -713,3 +815,28 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
print(tokenizer.decode(input_ids.squeeze()))
|
||||
|
||||
else:
|
||||
while True:
|
||||
prompt = input("Enter Prompt: ")
|
||||
try:
|
||||
token_count = int(
|
||||
input("Enter number of tokens you want to generate: ")
|
||||
)
|
||||
except:
|
||||
print(
|
||||
"Invalid integer entered. Using default value of 10"
|
||||
)
|
||||
token_count = 10
|
||||
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
||||
|
||||
for _ in range(token_count):
|
||||
next_token = shardedbloom.forward_pass(
|
||||
torch.tensor(input_ids), device=args.device
|
||||
)
|
||||
input_ids = torch.cat(
|
||||
[input_ids, next_token.unsqueeze(-1)], dim=-1
|
||||
)
|
||||
|
||||
print(tokenizer.decode(input_ids.squeeze()))
|
||||
|
||||
381
shark/examples/shark_inference/sharded_bloom_large_models.py
Normal file
381
shark/examples/shark_inference/sharded_bloom_large_models.py
Normal file
@@ -0,0 +1,381 @@
|
||||
import sys
|
||||
import os
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BloomConfig
|
||||
import re
|
||||
from shark.shark_inference import SharkInference
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
from transformers.models.bloom.modeling_bloom import (
|
||||
BloomBlock,
|
||||
build_alibi_tensor,
|
||||
)
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
batch_size, source_length = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else source_length
|
||||
|
||||
expanded_mask = (
|
||||
mask[:, None, None, :]
|
||||
.expand(batch_size, 1, tgt_len, source_length)
|
||||
.to(dtype)
|
||||
)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
def _prepare_attn_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
):
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape,
|
||||
inputs_embeds.dtype,
|
||||
past_key_values_length=past_key_values_length,
|
||||
).to(attention_mask.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
expanded_attn_mask = _expand_mask(
|
||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask
|
||||
if combined_attention_mask is None
|
||||
else expanded_attn_mask + combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size,
|
||||
dtype: torch.dtype,
|
||||
past_key_values_length: int = 0,
|
||||
):
|
||||
"""
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
batch_size, target_length = input_ids_shape
|
||||
mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
|
||||
mask.masked_fill_(intermediate_mask, 0)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
target_length, past_key_values_length, dtype=dtype
|
||||
),
|
||||
mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
expanded_mask = mask[None, None, :, :].expand(
|
||||
batch_size, 1, target_length, target_length + past_key_values_length
|
||||
)
|
||||
return expanded_mask
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
working_dir = sys.argv[1]
|
||||
layer_name = sys.argv[2]
|
||||
will_compile = sys.argv[3]
|
||||
device = sys.argv[4]
|
||||
device_idx = sys.argv[5]
|
||||
prompt = sys.argv[6]
|
||||
|
||||
if device_idx.lower().strip() == "none":
|
||||
device_idx = None
|
||||
else:
|
||||
device_idx = int(device_idx)
|
||||
|
||||
if will_compile.lower().strip() == "true":
|
||||
will_compile = True
|
||||
else:
|
||||
will_compile = False
|
||||
|
||||
f = open(f"{working_dir}/config.json")
|
||||
config = json.load(f)
|
||||
f.close()
|
||||
|
||||
layers_initialized = False
|
||||
try:
|
||||
n_embed = config["n_embed"]
|
||||
except KeyError:
|
||||
n_embed = config["hidden_size"]
|
||||
vocab_size = config["vocab_size"]
|
||||
n_layer = config["n_layer"]
|
||||
try:
|
||||
n_head = config["num_attention_heads"]
|
||||
except KeyError:
|
||||
n_head = config["n_head"]
|
||||
|
||||
if not os.path.isdir(working_dir):
|
||||
os.mkdir(working_dir)
|
||||
|
||||
if layer_name == "start":
|
||||
tokenizer = AutoTokenizer.from_pretrained(working_dir)
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
||||
|
||||
mlir_str = ""
|
||||
|
||||
if will_compile:
|
||||
f = open(f"{working_dir}/word_embeddings.mlir", encoding="utf-8")
|
||||
mlir_str = f.read()
|
||||
f.close()
|
||||
|
||||
mlir_str = bytes(mlir_str, "utf-8")
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_str,
|
||||
device="cpu",
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=None,
|
||||
)
|
||||
|
||||
if will_compile:
|
||||
shark_module.save_module(
|
||||
module_name=f"{working_dir}/word_embeddings",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-stream-resource-max-allocation-size=1000000000",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
],
|
||||
)
|
||||
|
||||
shark_module.load_module(f"{working_dir}/word_embeddings.vmfb")
|
||||
input_embeds = shark_module(
|
||||
inputs=(input_ids,), function_name="forward"
|
||||
)
|
||||
input_embeds = torch.tensor(input_embeds).float()
|
||||
|
||||
mlir_str = ""
|
||||
|
||||
if will_compile:
|
||||
f = open(
|
||||
f"{working_dir}/word_embeddings_layernorm.mlir",
|
||||
encoding="utf-8",
|
||||
)
|
||||
mlir_str = f.read()
|
||||
f.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_str,
|
||||
device="cpu",
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=None,
|
||||
)
|
||||
|
||||
if will_compile:
|
||||
shark_module.save_module(
|
||||
module_name=f"{working_dir}/word_embeddings_layernorm",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-stream-resource-max-allocation-size=1000000000",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
],
|
||||
)
|
||||
|
||||
shark_module.load_module(
|
||||
f"{working_dir}/word_embeddings_layernorm.vmfb"
|
||||
)
|
||||
hidden_states = shark_module(
|
||||
inputs=(input_embeds,), function_name="forward"
|
||||
)
|
||||
hidden_states = torch.tensor(hidden_states).float()
|
||||
|
||||
torch.save(hidden_states, f"{working_dir}/hidden_states_0.pt")
|
||||
|
||||
attention_mask = torch.ones(
|
||||
[hidden_states.shape[0], len(input_ids[0])]
|
||||
)
|
||||
|
||||
attention_mask = torch.tensor(attention_mask).float()
|
||||
|
||||
alibi = build_alibi_tensor(
|
||||
attention_mask,
|
||||
n_head,
|
||||
hidden_states.dtype,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
torch.save(alibi, f"{working_dir}/alibi.pt")
|
||||
|
||||
causal_mask = _prepare_attn_mask(
|
||||
attention_mask, input_ids.size(), input_embeds, 0
|
||||
)
|
||||
causal_mask = torch.tensor(causal_mask).float()
|
||||
|
||||
torch.save(causal_mask, f"{working_dir}/causal_mask.pt")
|
||||
|
||||
elif layer_name in [str(x) for x in range(n_layer)]:
|
||||
hidden_states = torch.load(
|
||||
f"{working_dir}/hidden_states_{layer_name}.pt"
|
||||
)
|
||||
alibi = torch.load(f"{working_dir}/alibi.pt")
|
||||
causal_mask = torch.load(f"{working_dir}/causal_mask.pt")
|
||||
|
||||
mlir_str = ""
|
||||
|
||||
if will_compile:
|
||||
f = open(
|
||||
f"{working_dir}/bloom_block_{layer_name}.mlir",
|
||||
encoding="utf-8",
|
||||
)
|
||||
mlir_str = f.read()
|
||||
f.close()
|
||||
|
||||
mlir_str = bytes(mlir_str, "utf-8")
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_str,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
|
||||
if will_compile:
|
||||
shark_module.save_module(
|
||||
module_name=f"{working_dir}/bloom_block_{layer_name}",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-stream-resource-max-allocation-size=1000000000",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
],
|
||||
)
|
||||
|
||||
shark_module.load_module(
|
||||
f"{working_dir}/bloom_block_{layer_name}.vmfb"
|
||||
)
|
||||
|
||||
output = shark_module(
|
||||
inputs=(
|
||||
hidden_states.detach().numpy(),
|
||||
alibi.detach().numpy(),
|
||||
causal_mask.detach().numpy(),
|
||||
),
|
||||
function_name="forward",
|
||||
)
|
||||
|
||||
hidden_states = torch.tensor(output[0]).float()
|
||||
|
||||
torch.save(
|
||||
hidden_states,
|
||||
f"{working_dir}/hidden_states_{int(layer_name) + 1}.pt",
|
||||
)
|
||||
|
||||
elif layer_name == "end":
|
||||
mlir_str = ""
|
||||
|
||||
if will_compile:
|
||||
f = open(f"{working_dir}/ln_f.mlir", encoding="utf-8")
|
||||
mlir_str = f.read()
|
||||
f.close()
|
||||
|
||||
mlir_str = bytes(mlir_str, "utf-8")
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_str,
|
||||
device="cpu",
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=None,
|
||||
)
|
||||
|
||||
if will_compile:
|
||||
shark_module.save_module(
|
||||
module_name=f"{working_dir}/ln_f",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-stream-resource-max-allocation-size=1000000000",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
],
|
||||
)
|
||||
|
||||
shark_module.load_module(f"{working_dir}/ln_f.vmfb")
|
||||
|
||||
hidden_states = torch.load(f"{working_dir}/hidden_states_{n_layer}.pt")
|
||||
|
||||
hidden_states = shark_module(
|
||||
inputs=(hidden_states,), function_name="forward"
|
||||
)
|
||||
|
||||
mlir_str = ""
|
||||
|
||||
if will_compile:
|
||||
f = open(f"{working_dir}/lm_head.mlir", encoding="utf-8")
|
||||
mlir_str = f.read()
|
||||
f.close()
|
||||
|
||||
mlir_str = bytes(mlir_str, "utf-8")
|
||||
|
||||
if config["n_embed"] == 14336:
|
||||
|
||||
def get_state_dict():
|
||||
d = torch.load(
|
||||
f"{working_dir}/pytorch_model_00001-of-00072.bin"
|
||||
)
|
||||
return OrderedDict(
|
||||
(k.replace("word_embeddings.", ""), v)
|
||||
for k, v in d.items()
|
||||
)
|
||||
|
||||
def load_causal_lm_head():
|
||||
linear = nn.utils.skip_init(
|
||||
nn.Linear, 14336, 250880, bias=False, dtype=torch.float
|
||||
)
|
||||
linear.load_state_dict(get_state_dict(), strict=False)
|
||||
return linear.float()
|
||||
|
||||
lm_head = load_causal_lm_head()
|
||||
|
||||
logits = lm_head(torch.tensor(hidden_states).float())
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_str,
|
||||
device="cpu",
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=None,
|
||||
)
|
||||
|
||||
if will_compile:
|
||||
shark_module.save_module(
|
||||
module_name=f"{working_dir}/lm_head",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-stream-resource-max-allocation-size=1000000000",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
],
|
||||
)
|
||||
|
||||
shark_module.load_module(f"{working_dir}/lm_head.vmfb")
|
||||
|
||||
logits = shark_module(
|
||||
inputs=(hidden_states,), function_name="forward"
|
||||
)
|
||||
|
||||
logits = torch.tensor(logits).float()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(working_dir)
|
||||
|
||||
next_token = tokenizer.decode(torch.argmax(logits[:, -1, :], dim=-1))
|
||||
|
||||
f = open(f"{working_dir}/prompt.txt", "w+")
|
||||
f.write(prompt + next_token)
|
||||
f.close()
|
||||
@@ -1,9 +1,6 @@
|
||||
# Stable Diffusion Fine Tuning
|
||||
|
||||
## Installation
|
||||
|
||||
<details>
|
||||
<summary>Installation (Linux)</summary>
|
||||
## Installation (Linux)
|
||||
|
||||
### Activate shark.venv Virtual Environment
|
||||
|
||||
@@ -14,21 +11,21 @@ source shark.venv/bin/activate
|
||||
python -m pip install --upgrade pip
|
||||
```
|
||||
|
||||
### Install dependencies
|
||||
## Install dependencies
|
||||
|
||||
# Run the following installation commands:
|
||||
### Run the following installation commands:
|
||||
```
|
||||
pip install -U git+https://github.com/huggingface/diffusers.git
|
||||
pip install accelerate transformers ftfy
|
||||
```
|
||||
|
||||
# Build torch-mlir with the following branch:
|
||||
### Build torch-mlir with the following branch:
|
||||
|
||||
Please cherry-pick this branch of torch-mlir: https://github.com/vivekkhandelwal1/torch-mlir/tree/sd-ops
|
||||
and build it locally. You can find the instructions for using locally build Torch-MLIR,
|
||||
here: https://github.com/nod-ai/SHARK#how-to-use-your-locally-built-iree--torch-mlir-with-shark
|
||||
|
||||
### Run the Stable diffusion fine tuning
|
||||
## Run the Stable diffusion fine tuning
|
||||
|
||||
To run the model with the default set of images and params, run:
|
||||
```shell
|
||||
@@ -42,3 +39,5 @@ For example, you can run the training for a limited set of steps via the dynamo
|
||||
```
|
||||
python stable_diffusion_fine_tuning.py --training_steps=1 --inference_steps=1 --num_inference_samples=1 --train_batch_size=1 --use_torchdynamo=True
|
||||
```
|
||||
|
||||
You can also specify the device to be used via the flag `--device`. The default value is `cpu`, for GPU execution you can specify `--device="cuda"`.
|
||||
|
||||
@@ -445,8 +445,11 @@ freeze_params(params_to_freeze)
|
||||
|
||||
|
||||
# Move vae and unet to device
|
||||
vae.to(args.device)
|
||||
unet.to(args.device)
|
||||
# For the dynamo path default compilation device is `cpu`, since torch-mlir
|
||||
# supports only that. Therefore, convert to device only for PyTorch path.
|
||||
if not args.use_torchdynamo:
|
||||
vae.to(args.device)
|
||||
unet.to(args.device)
|
||||
|
||||
# Keep vae in eval mode as we don't train it
|
||||
vae.eval()
|
||||
@@ -885,7 +888,9 @@ pipe = StableDiffusionPipeline.from_pretrained(
|
||||
scheduler=DPMSolverMultistepScheduler.from_pretrained(
|
||||
hyperparameters["output_dir"], subfolder="scheduler"
|
||||
),
|
||||
).to(args.device)
|
||||
)
|
||||
if not args.use_torchdynamo:
|
||||
pipe.to(args.device)
|
||||
|
||||
# Run the Stable Diffusion pipeline
|
||||
# Don't forget to use the placeholder token in your prompt
|
||||
|
||||
@@ -19,10 +19,14 @@ import sys
|
||||
import subprocess
|
||||
|
||||
|
||||
def run_cmd(cmd):
|
||||
def run_cmd(cmd, debug=False):
|
||||
"""
|
||||
Inputs: cli command string.
|
||||
"""
|
||||
if debug:
|
||||
print("IREE run command: \n\n")
|
||||
print(cmd)
|
||||
print("\n\n")
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
@@ -31,8 +35,9 @@ def run_cmd(cmd):
|
||||
stderr=subprocess.PIPE,
|
||||
check=True,
|
||||
)
|
||||
result_str = result.stdout.decode()
|
||||
return result_str
|
||||
stdout = result.stdout.decode()
|
||||
stderr = result.stderr.decode()
|
||||
return stdout, stderr
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e.output)
|
||||
sys.exit(f"Exiting program due to error running {cmd}")
|
||||
|
||||
@@ -90,6 +90,7 @@ def build_benchmark_args(
|
||||
benchmark_cl.append(f"--task_topology_max_group_count={num_cpus}")
|
||||
# if time_extractor:
|
||||
# benchmark_cl.append(time_extractor)
|
||||
benchmark_cl.append(f"--print_statistics=true")
|
||||
return benchmark_cl
|
||||
|
||||
|
||||
@@ -129,7 +130,8 @@ def build_benchmark_args_non_tensor_input(
|
||||
|
||||
def run_benchmark_module(benchmark_cl):
|
||||
"""
|
||||
Run benchmark command, extract result and return iteration/seconds.
|
||||
Run benchmark command, extract result and return iteration/seconds, host
|
||||
peak memory, and device peak memory.
|
||||
|
||||
# TODO: Add an example of the benchmark command.
|
||||
Input: benchmark command.
|
||||
@@ -138,15 +140,22 @@ def run_benchmark_module(benchmark_cl):
|
||||
assert os.path.exists(
|
||||
benchmark_path
|
||||
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
|
||||
bench_result = run_cmd(" ".join(benchmark_cl))
|
||||
bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))
|
||||
try:
|
||||
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
|
||||
match = regex_split.search(bench_result)
|
||||
time = float(match.group(1))
|
||||
match = regex_split.search(bench_stdout)
|
||||
time_ms = float(match.group(1))
|
||||
unit = match.group(3)
|
||||
except AttributeError:
|
||||
regex_split = re.compile("(\d+[.]*\d*)([a-zA-Z]+)")
|
||||
match = regex_split.search(bench_result)
|
||||
time = float(match.group(1))
|
||||
match = regex_split.search(bench_stdout)
|
||||
time_ms = float(match.group(1))
|
||||
unit = match.group(2)
|
||||
return 1.0 / (time * 0.001)
|
||||
iter_per_second = 1.0 / (time_ms * 0.001)
|
||||
|
||||
# Extract peak memory.
|
||||
host_regex = re.compile(r".*HOST_LOCAL:\s*([0-9]+)B peak")
|
||||
host_peak_b = int(host_regex.search(bench_stderr).group(1))
|
||||
device_regex = re.compile(r".*DEVICE_LOCAL:\s*([0-9]+)B peak")
|
||||
device_peak_b = int(device_regex.search(bench_stderr).group(1))
|
||||
return iter_per_second, host_peak_b, device_peak_b
|
||||
|
||||
@@ -52,11 +52,11 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
|
||||
# Get the iree-compiler arguments given frontend.
|
||||
def get_iree_frontend_args(frontend):
|
||||
if frontend in ["torch", "pytorch", "linalg"]:
|
||||
return ["--iree-llvm-target-cpu-features=host"]
|
||||
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
|
||||
return ["--iree-llvmcpu-target-cpu-features=host"]
|
||||
elif frontend in ["tensorflow", "tf", "mhlo"]:
|
||||
return [
|
||||
"--iree-llvm-target-cpu-features=host",
|
||||
"--iree-llvmcpu-target-cpu-features=host",
|
||||
"--iree-mhlo-demote-i64-to-i32=false",
|
||||
"--iree-flow-demote-i64-to-i32",
|
||||
]
|
||||
@@ -188,21 +188,23 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
|
||||
benchmark_bash.write(" ".join(benchmark_cl))
|
||||
benchmark_bash.close()
|
||||
|
||||
benchmark_data = run_benchmark_module(benchmark_cl)
|
||||
iter_per_second, _, _ = run_benchmark_module(
|
||||
benchmark_cl
|
||||
)
|
||||
|
||||
benchmark_file = open(
|
||||
f"{bench_dir}/{d_}/{d_}_data.txt", "w+"
|
||||
)
|
||||
benchmark_file.write(f"DISPATCH: {d_}\n")
|
||||
benchmark_file.write(str(benchmark_data) + "\n")
|
||||
benchmark_file.write(str(iter_per_second) + "\n")
|
||||
benchmark_file.write(
|
||||
"SHARK BENCHMARK RESULT: "
|
||||
+ str(1 / (benchmark_data * 0.001))
|
||||
+ str(1 / (iter_per_second * 0.001))
|
||||
+ "\n"
|
||||
)
|
||||
benchmark_file.close()
|
||||
|
||||
benchmark_runtimes[d_] = 1 / (benchmark_data * 0.001)
|
||||
benchmark_runtimes[d_] = 1 / (iter_per_second * 0.001)
|
||||
|
||||
elif ".mlir" in f_ and "benchmark" not in f_:
|
||||
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
|
||||
@@ -293,7 +295,8 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
haldriver = ireert.get_driver(device)
|
||||
|
||||
haldevice = haldriver.create_device(
|
||||
haldriver.query_available_devices()[device_idx]["device_id"]
|
||||
haldriver.query_available_devices()[device_idx]["device_id"],
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
config = ireert.Config(device=haldevice)
|
||||
else:
|
||||
@@ -402,5 +405,10 @@ def get_results(
|
||||
|
||||
def get_iree_runtime_config(device):
|
||||
device = iree_device_map(device)
|
||||
config = ireert.Config(device=ireert.get_device(device))
|
||||
haldriver = ireert.get_driver(device)
|
||||
haldevice = haldriver.create_device_by_uri(
|
||||
device,
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
config = ireert.Config(device=haldevice)
|
||||
return config
|
||||
|
||||
@@ -44,4 +44,4 @@ def get_iree_cpu_args():
|
||||
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
|
||||
raise Exception(error_message)
|
||||
print(f"Target triple found:{target_triple}")
|
||||
return [f"-iree-llvm-target-triple={target_triple}"]
|
||||
return [f"--iree-llvmcpu-target-triple={target_triple}"]
|
||||
|
||||
@@ -30,11 +30,10 @@ def get_iree_gpu_args():
|
||||
in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86", "sm_89"]
|
||||
) and (shark_args.enable_tf32 == True):
|
||||
return [
|
||||
"--iree-hal-cuda-disable-loop-nounroll-wa",
|
||||
f"--iree-hal-cuda-llvm-target-arch={sm_arch}",
|
||||
]
|
||||
else:
|
||||
return ["--iree-hal-cuda-disable-loop-nounroll-wa"]
|
||||
return []
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
|
||||
@@ -22,7 +22,8 @@ from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
|
||||
|
||||
def get_vulkan_device_name():
|
||||
vulkaninfo_dump = run_cmd("vulkaninfo").split(linesep)
|
||||
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
|
||||
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
|
||||
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
|
||||
if len(vulkaninfo_list) == 0:
|
||||
raise ValueError("No device name found in VulkanInfo!")
|
||||
|
||||
@@ -108,4 +108,14 @@ parser.add_argument(
|
||||
help="Enables the --iree-flow-enable-conv-winograd-transform flag.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--device_allocator",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="Specifies one or more HAL device allocator specs "
|
||||
"to augment the base device allocator",
|
||||
choices=["debug", "caching"],
|
||||
)
|
||||
|
||||
shark_args, unknown = parser.parse_known_args()
|
||||
|
||||
@@ -21,9 +21,17 @@ from shark.iree_utils.benchmark_utils import (
|
||||
from shark.parser import shark_args
|
||||
from datetime import datetime
|
||||
import time
|
||||
from typing import Optional
|
||||
import csv
|
||||
import os
|
||||
|
||||
TF_CPU_DEVICE = "/CPU:0"
|
||||
TF_GPU_DEVICE = "/GPU:0"
|
||||
|
||||
|
||||
def _bytes_to_mb_str(bytes_: Optional[int]) -> str:
|
||||
return "" if bytes_ is None else f"{bytes_ / 1e6:.6f}"
|
||||
|
||||
|
||||
class OnnxFusionOptions(object):
|
||||
def __init__(self):
|
||||
@@ -126,18 +134,26 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
frontend_model.forward(input)
|
||||
|
||||
if self.device == "cuda":
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = frontend_model.forward(input)
|
||||
if i == shark_args.num_iterations - 1:
|
||||
end = time.time()
|
||||
break
|
||||
end = time.time()
|
||||
if self.device == "cuda":
|
||||
stats = torch.cuda.memory_stats()
|
||||
device_peak_b = stats["allocated_bytes.all.peak"]
|
||||
else:
|
||||
device_peak_b = None
|
||||
|
||||
print(
|
||||
f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
f"{shark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/shark_args.num_iterations)*1000}",
|
||||
"", # host_peak_b (CPU usage) is not reported by PyTorch.
|
||||
_bytes_to_mb_str(device_peak_b),
|
||||
]
|
||||
|
||||
def benchmark_tf(self, modelname):
|
||||
@@ -155,8 +171,8 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
|
||||
from tank.model_utils_tf import get_tf_model
|
||||
|
||||
# tf_device = "/GPU:0" if self.device == "cuda" else "/CPU:0"
|
||||
tf_device = "/CPU:0"
|
||||
# tf_device = TF_GPU_DEVICE if self.device == "cuda" else TF_CPU_DEVICE
|
||||
tf_device = TF_CPU_DEVICE
|
||||
with tf.device(tf_device):
|
||||
(
|
||||
model,
|
||||
@@ -169,24 +185,41 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
frontend_model.forward(*input)
|
||||
|
||||
if tf_device == TF_GPU_DEVICE:
|
||||
tf.config.experimental.reset_memory_stats(tf_device)
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = frontend_model.forward(*input)
|
||||
if i == shark_args.num_iterations - 1:
|
||||
end = time.time()
|
||||
break
|
||||
end = time.time()
|
||||
if tf_device == TF_GPU_DEVICE:
|
||||
memory_info = tf.config.experimental.get_memory_info(tf_device)
|
||||
device_peak_b = memory_info["peak"]
|
||||
else:
|
||||
# tf.config.experimental does not currently support measuring
|
||||
# CPU memory usage.
|
||||
device_peak_b = None
|
||||
|
||||
print(
|
||||
f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
f"{shark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/shark_args.num_iterations)*1000}",
|
||||
"", # host_peak_b (CPU usage) is not reported by TensorFlow.
|
||||
_bytes_to_mb_str(device_peak_b),
|
||||
]
|
||||
|
||||
def benchmark_c(self):
|
||||
result = run_benchmark_module(self.benchmark_cl)
|
||||
print(f"Shark-IREE-C benchmark:{result} iter/second")
|
||||
return [f"{result}", f"{1000/result}"]
|
||||
iter_per_second, host_peak_b, device_peak_b = run_benchmark_module(
|
||||
self.benchmark_cl
|
||||
)
|
||||
print(f"Shark-IREE-C benchmark:{iter_per_second} iter/second")
|
||||
return [
|
||||
f"{iter_per_second}",
|
||||
f"{1000/iter_per_second}",
|
||||
_bytes_to_mb_str(host_peak_b),
|
||||
_bytes_to_mb_str(device_peak_b),
|
||||
]
|
||||
|
||||
def benchmark_python(self, inputs):
|
||||
input_list = [x for x in inputs]
|
||||
@@ -196,8 +229,7 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = self.run("forward", input_list)
|
||||
if i == shark_args.num_iterations - 1:
|
||||
end = time.time()
|
||||
end = time.time()
|
||||
print(
|
||||
f"Shark-IREE Python benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
@@ -324,7 +356,12 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
"tags",
|
||||
"notes",
|
||||
"datetime",
|
||||
"host_memory_mb",
|
||||
"device_memory_mb",
|
||||
"measured_host_memory_mb",
|
||||
"measured_device_memory_mb",
|
||||
]
|
||||
# "frontend" must be the first element.
|
||||
engines = ["frontend", "shark_python", "shark_iree_c"]
|
||||
if shark_args.onnx_bench == True:
|
||||
engines.append("onnxruntime")
|
||||
@@ -336,75 +373,76 @@ for currently supported models. Exiting benchmark ONNX."
|
||||
|
||||
with open("bench_results.csv", mode="a", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=field_names)
|
||||
bench_result = {}
|
||||
bench_result["model"] = modelname
|
||||
bench_info = {}
|
||||
bench_info["model"] = modelname
|
||||
bench_info["dialect"] = self.mlir_dialect
|
||||
bench_info["iterations"] = shark_args.num_iterations
|
||||
if dynamic == True:
|
||||
bench_result["shape_type"] = "dynamic"
|
||||
bench_info["shape_type"] = "dynamic"
|
||||
else:
|
||||
bench_result["shape_type"] = "static"
|
||||
bench_result["device"] = device_str
|
||||
bench_info["shape_type"] = "static"
|
||||
bench_info["device"] = device_str
|
||||
if "fp16" in modelname:
|
||||
bench_result["data_type"] = "float16"
|
||||
bench_info["data_type"] = "float16"
|
||||
else:
|
||||
bench_result["data_type"] = inputs[0].dtype
|
||||
bench_info["data_type"] = inputs[0].dtype
|
||||
|
||||
for e in engines:
|
||||
(
|
||||
bench_result["param_count"],
|
||||
bench_result["tags"],
|
||||
bench_result["notes"],
|
||||
) = ["", "", ""]
|
||||
engine_result = {}
|
||||
if e == "frontend":
|
||||
bench_result["engine"] = frontend
|
||||
engine_result["engine"] = frontend
|
||||
if check_requirements(frontend):
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
engine_result["iter/sec"],
|
||||
engine_result["ms/iter"],
|
||||
engine_result["host_memory_mb"],
|
||||
engine_result["device_memory_mb"],
|
||||
) = self.benchmark_frontend(modelname)
|
||||
self.frontend_result = bench_result["ms/iter"]
|
||||
bench_result["vs. PyTorch/TF"] = "baseline"
|
||||
self.frontend_result = engine_result["ms/iter"]
|
||||
engine_result["vs. PyTorch/TF"] = "baseline"
|
||||
(
|
||||
bench_result["param_count"],
|
||||
bench_result["tags"],
|
||||
bench_result["notes"],
|
||||
engine_result["param_count"],
|
||||
engine_result["tags"],
|
||||
engine_result["notes"],
|
||||
) = self.get_metadata(modelname)
|
||||
else:
|
||||
self.frontend_result = None
|
||||
continue
|
||||
|
||||
elif e == "shark_python":
|
||||
bench_result["engine"] = "shark_python"
|
||||
engine_result["engine"] = "shark_python"
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
engine_result["iter/sec"],
|
||||
engine_result["ms/iter"],
|
||||
) = self.benchmark_python(inputs)
|
||||
|
||||
bench_result[
|
||||
engine_result[
|
||||
"vs. PyTorch/TF"
|
||||
] = self.compare_bench_results(
|
||||
self.frontend_result, bench_result["ms/iter"]
|
||||
self.frontend_result, engine_result["ms/iter"]
|
||||
)
|
||||
|
||||
elif e == "shark_iree_c":
|
||||
bench_result["engine"] = "shark_iree_c"
|
||||
engine_result["engine"] = "shark_iree_c"
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
engine_result["iter/sec"],
|
||||
engine_result["ms/iter"],
|
||||
engine_result["host_memory_mb"],
|
||||
engine_result["device_memory_mb"],
|
||||
) = self.benchmark_c()
|
||||
|
||||
bench_result[
|
||||
engine_result[
|
||||
"vs. PyTorch/TF"
|
||||
] = self.compare_bench_results(
|
||||
self.frontend_result, bench_result["ms/iter"]
|
||||
self.frontend_result, engine_result["ms/iter"]
|
||||
)
|
||||
|
||||
elif e == "onnxruntime":
|
||||
bench_result["engine"] = "onnxruntime"
|
||||
engine_result["engine"] = "onnxruntime"
|
||||
(
|
||||
bench_result["iter/sec"],
|
||||
bench_result["ms/iter"],
|
||||
engine_result["iter/sec"],
|
||||
engine_result["ms/iter"],
|
||||
) = self.benchmark_onnx(modelname, inputs)
|
||||
|
||||
bench_result["dialect"] = self.mlir_dialect
|
||||
bench_result["iterations"] = shark_args.num_iterations
|
||||
bench_result["datetime"] = str(datetime.now())
|
||||
writer.writerow(bench_result)
|
||||
engine_result["datetime"] = str(datetime.now())
|
||||
writer.writerow(bench_info | engine_result)
|
||||
|
||||
@@ -150,10 +150,14 @@ def download_model(
|
||||
if not check_dir_exists(
|
||||
model_dir_name, frontend=frontend, dynamic=dyn_str
|
||||
):
|
||||
print(f"Downloading artifacts for model {model_name}...")
|
||||
print(
|
||||
f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
|
||||
)
|
||||
download_public_file(full_gs_url, model_dir)
|
||||
elif shark_args.force_update_tank == True:
|
||||
print(f"Force-updating artifacts for model {model_name}...")
|
||||
print(
|
||||
f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
|
||||
)
|
||||
download_public_file(full_gs_url, model_dir)
|
||||
else:
|
||||
if not _internet_connected():
|
||||
@@ -190,8 +194,14 @@ def download_model(
|
||||
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
|
||||
filename = os.path.join(model_dir, model_name + suffix)
|
||||
|
||||
with open(filename, mode="rb") as f:
|
||||
mlir_file = f.read()
|
||||
try:
|
||||
with open(filename, mode="rb") as f:
|
||||
mlir_file = f.read()
|
||||
except FileNotFoundError:
|
||||
from tank.generate_sharktank import gen_shark_files
|
||||
|
||||
tank_dir = WORKDIR
|
||||
gen_shark_files(model_name, frontend, tank_dir)
|
||||
|
||||
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
|
||||
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
|
||||
|
||||
@@ -4,6 +4,17 @@
|
||||
import sys
|
||||
import tempfile
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
|
||||
def create_hash(file_name):
|
||||
with open(file_name, "rb") as f:
|
||||
file_hash = hashlib.blake2b()
|
||||
while chunk := f.read(2**20):
|
||||
file_hash.update(chunk)
|
||||
|
||||
return file_hash.hexdigest()
|
||||
|
||||
|
||||
# List of the supported frontends.
|
||||
supported_frontends = {
|
||||
@@ -140,6 +151,7 @@ class SharkImporter:
|
||||
outputs_name = "golden_out.npz"
|
||||
func_file_name = "function_name"
|
||||
model_name_mlir = model_name + "_" + self.frontend + ".mlir"
|
||||
print(f"saving {model_name_mlir} to {dir}")
|
||||
try:
|
||||
inputs = [x.cpu().detach() for x in inputs]
|
||||
except AttributeError:
|
||||
@@ -150,11 +162,11 @@ class SharkImporter:
|
||||
np.savez(os.path.join(dir, inputs_name), *inputs)
|
||||
np.savez(os.path.join(dir, outputs_name), *outputs)
|
||||
np.save(os.path.join(dir, func_file_name), np.array(func_name))
|
||||
|
||||
if self.frontend == "torch":
|
||||
with open(os.path.join(dir, model_name_mlir), "wb") as mlir_file:
|
||||
mlir_file.write(mlir_data)
|
||||
|
||||
mlir_hash = create_hash(os.path.join(dir, model_name_mlir))
|
||||
np.save(os.path.join(dir, "hash"), np.array(mlir_hash))
|
||||
return
|
||||
|
||||
def import_debug(
|
||||
@@ -285,6 +297,7 @@ def transform_fx(fx_g):
|
||||
if node.target in [
|
||||
torch.ops.aten.arange,
|
||||
torch.ops.aten.empty,
|
||||
torch.ops.aten.zeros,
|
||||
]:
|
||||
node.kwargs = kwargs_dict
|
||||
# Inputs and outputs of aten.var.mean should be upcasted to fp32.
|
||||
@@ -377,7 +390,10 @@ def import_with_fx(
|
||||
|
||||
golden_values = None
|
||||
if debug:
|
||||
golden_values = model(*inputs)
|
||||
try:
|
||||
golden_values = model(*inputs)
|
||||
except:
|
||||
golden_values = None
|
||||
# TODO: Control the decompositions.
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
|
||||
@@ -35,3 +35,18 @@ squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","mac
|
||||
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
|
||||
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,"",""
|
||||
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,"",""
|
||||
gpt2,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
|
||||
|
@@ -63,14 +63,14 @@ if __name__ == "__main__":
|
||||
# Compile the model using IREE
|
||||
backend = "dylib-llvm-aot"
|
||||
args = [
|
||||
"--iree-llvm-target-cpu-features=host",
|
||||
"--iree-llvmcpu-target-cpu-features=host",
|
||||
"--iree-mhlo-demote-i64-to-i32=false",
|
||||
"--iree-flow-demote-i64-to-i32",
|
||||
]
|
||||
backend_config = "dylib"
|
||||
# backend = "cuda"
|
||||
# backend_config = "cuda"
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
|
||||
flatbuffer_blob = compile_str(
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
|
||||
@@ -136,7 +136,7 @@ if __name__ == "__main__":
|
||||
backend = "dylib-llvm-aot"
|
||||
if backend == "dylib-llvm-aot":
|
||||
args = [
|
||||
"--iree-llvm-target-cpu-features=host",
|
||||
"--iree-llvmcpu-target-cpu-features=host",
|
||||
"--iree-mhlo-demote-i64-to-i32=false",
|
||||
"--iree-flow-demote-i64-to-i32",
|
||||
]
|
||||
@@ -146,7 +146,6 @@ if __name__ == "__main__":
|
||||
backend_config = "cuda"
|
||||
args = [
|
||||
"--iree-cuda-llvm-target-arch=sm_80",
|
||||
"--iree-hal-cuda-disable-loop-nounroll-wa",
|
||||
"--iree-enable-fusion-with-reduction-ops",
|
||||
]
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ if __name__ == "__main__":
|
||||
# Compile the model using IREE
|
||||
backend = "dylib-llvm-aot"
|
||||
args = [
|
||||
"--iree-llvm-target-cpu-features=host",
|
||||
"--iree-llvmcpu-target-cpu-features=host",
|
||||
"--iree-mhlo-demote-i64-to-i32=false",
|
||||
"--iree-stream-resource-index-bits=64",
|
||||
"--iree-vm-target-index-bits=64",
|
||||
@@ -91,7 +91,7 @@ if __name__ == "__main__":
|
||||
backend_config = "dylib"
|
||||
# backend = "cuda"
|
||||
# backend_config = "cuda"
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
|
||||
flatbuffer_blob = compile_str(
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
|
||||
@@ -79,14 +79,14 @@ if __name__ == "__main__":
|
||||
# Compile the model using IREE
|
||||
backend = "dylib-llvm-aot"
|
||||
args = [
|
||||
"--iree-llvm-target-cpu-features=host",
|
||||
"--iree-llvmcpu-target-cpu-features=host",
|
||||
"--iree-mhlo-demote-i64-to-i32=false",
|
||||
"--iree-flow-demote-i64-to-i32",
|
||||
]
|
||||
backend_config = "dylib"
|
||||
# backend = "cuda"
|
||||
# backend_config = "cuda"
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-enable-fusion-with-reduction-ops"]
|
||||
flatbuffer_blob = compile_str(
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
|
||||
@@ -33,9 +33,10 @@ def create_hash(file_name):
|
||||
return file_hash.hexdigest()
|
||||
|
||||
|
||||
def save_torch_model(torch_model_list):
|
||||
def save_torch_model(torch_model_list, local_tank_cache):
|
||||
from tank.model_utils import (
|
||||
get_hf_model,
|
||||
get_hf_seq2seq_model,
|
||||
get_vision_model,
|
||||
get_hf_img_cls_model,
|
||||
get_fp16_model,
|
||||
@@ -52,14 +53,14 @@ def save_torch_model(torch_model_list):
|
||||
|
||||
tracing_required = False if tracing_required == "False" else True
|
||||
is_dynamic = False if is_dynamic == "False" else True
|
||||
|
||||
print("generating artifacts for: " + torch_model_name)
|
||||
model = None
|
||||
input = None
|
||||
if model_type == "stable_diffusion":
|
||||
args.use_tuned = False
|
||||
args.import_mlir = True
|
||||
args.use_tuned = False
|
||||
args.local_tank_cache = WORKDIR
|
||||
args.local_tank_cache = local_tank_cache
|
||||
|
||||
precision_values = ["fp16"]
|
||||
seq_lengths = [64, 77]
|
||||
@@ -75,7 +76,7 @@ def save_torch_model(torch_model_list):
|
||||
height=512,
|
||||
use_base_vae=False,
|
||||
debug=True,
|
||||
sharktank_dir=WORKDIR,
|
||||
sharktank_dir=local_tank_cache,
|
||||
generate_vmfb=False,
|
||||
)
|
||||
model()
|
||||
@@ -84,13 +85,15 @@ def save_torch_model(torch_model_list):
|
||||
model, input, _ = get_vision_model(torch_model_name)
|
||||
elif model_type == "hf":
|
||||
model, input, _ = get_hf_model(torch_model_name)
|
||||
elif model_type == "hf_seq2seq":
|
||||
model, input, _ = get_hf_seq2seq_model(torch_model_name)
|
||||
elif model_type == "hf_img_cls":
|
||||
model, input, _ = get_hf_img_cls_model(torch_model_name)
|
||||
elif model_type == "fp16":
|
||||
model, input, _ = get_fp16_model(torch_model_name)
|
||||
torch_model_name = torch_model_name.replace("/", "_")
|
||||
torch_model_dir = os.path.join(
|
||||
WORKDIR, str(torch_model_name) + "_torch"
|
||||
local_tank_cache, str(torch_model_name) + "_torch"
|
||||
)
|
||||
os.makedirs(torch_model_dir, exist_ok=True)
|
||||
|
||||
@@ -105,12 +108,6 @@ def save_torch_model(torch_model_list):
|
||||
dir=torch_model_dir,
|
||||
model_name=torch_model_name,
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(
|
||||
torch_model_dir, torch_model_name + "_torch" + ".mlir"
|
||||
)
|
||||
)
|
||||
np.save(os.path.join(torch_model_dir, "hash"), np.array(mlir_hash))
|
||||
# Generate torch dynamic models.
|
||||
if is_dynamic:
|
||||
mlir_importer.import_debug(
|
||||
@@ -121,12 +118,14 @@ def save_torch_model(torch_model_list):
|
||||
)
|
||||
|
||||
|
||||
def save_tf_model(tf_model_list):
|
||||
def save_tf_model(tf_model_list, local_tank_cache):
|
||||
from tank.model_utils_tf import (
|
||||
get_causal_image_model,
|
||||
get_masked_lm_model,
|
||||
get_causal_lm_model,
|
||||
get_keras_model,
|
||||
get_TFhf_model,
|
||||
get_tfhf_seq2seq_model,
|
||||
)
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -152,15 +151,19 @@ def save_tf_model(tf_model_list):
|
||||
print(f"Generating artifacts for model {tf_model_name}")
|
||||
if model_type == "hf":
|
||||
model, input, _ = get_causal_lm_model(tf_model_name)
|
||||
if model_type == "img":
|
||||
elif model_type == "img":
|
||||
model, input, _ = get_causal_image_model(tf_model_name)
|
||||
if model_type == "keras":
|
||||
elif model_type == "keras":
|
||||
model, input, _ = get_keras_model(tf_model_name)
|
||||
if model_type == "TFhf":
|
||||
elif model_type == "TFhf":
|
||||
model, input, _ = get_TFhf_model(tf_model_name)
|
||||
elif model_type == "tfhf_seq2seq":
|
||||
model, input, _ = get_tfhf_seq2seq_model(tf_model_name)
|
||||
|
||||
tf_model_name = tf_model_name.replace("/", "_")
|
||||
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
|
||||
tf_model_dir = os.path.join(
|
||||
local_tank_cache, str(tf_model_name) + "_tf"
|
||||
)
|
||||
os.makedirs(tf_model_dir, exist_ok=True)
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
@@ -178,7 +181,7 @@ def save_tf_model(tf_model_list):
|
||||
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
|
||||
|
||||
|
||||
def save_tflite_model(tflite_model_list):
|
||||
def save_tflite_model(tflite_model_list, local_tank_cache):
|
||||
from shark.tflite_utils import TFLitePreprocessor
|
||||
|
||||
with open(tflite_model_list) as csvfile:
|
||||
@@ -190,7 +193,7 @@ def save_tflite_model(tflite_model_list):
|
||||
print("tflite_model_name", tflite_model_name)
|
||||
print("tflite_model_link", tflite_model_link)
|
||||
tflite_model_name_dir = os.path.join(
|
||||
WORKDIR, str(tflite_model_name) + "_tflite"
|
||||
local_tank_cache, str(tflite_model_name) + "_tflite"
|
||||
)
|
||||
os.makedirs(tflite_model_name_dir, exist_ok=True)
|
||||
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
|
||||
@@ -225,6 +228,45 @@ def save_tflite_model(tflite_model_list):
|
||||
)
|
||||
|
||||
|
||||
def gen_shark_files(modelname, frontend, tank_dir):
|
||||
# If a model's artifacts are requested by shark_downloader but they don't exist in the cloud, we call this function to generate the artifacts on-the-fly.
|
||||
# TODO: Add TFlite support.
|
||||
import tempfile
|
||||
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
|
||||
custom_model_csv = tempfile.NamedTemporaryFile(
|
||||
dir=os.path.dirname(__file__),
|
||||
delete=True,
|
||||
)
|
||||
# Create a temporary .csv with only the desired entry.
|
||||
if frontend == "tf":
|
||||
with open(tf_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
if row[0] == modelname:
|
||||
target = row
|
||||
with open(custom_model_csv.name, mode="w") as trg:
|
||||
writer = csv.writer(trg)
|
||||
writer.writerow(["modelname", "src"])
|
||||
writer.writerow(target)
|
||||
save_tf_model(custom_model_csv.name, tank_dir)
|
||||
|
||||
if frontend == "torch":
|
||||
with open(torch_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
if row[0] == modelname:
|
||||
target = row
|
||||
with open(custom_model_csv.name, mode="w") as trg:
|
||||
writer = csv.writer(trg)
|
||||
writer.writerow(["modelname", "src"])
|
||||
writer.writerow(target)
|
||||
save_torch_model(custom_model_csv.name, tank_dir)
|
||||
|
||||
|
||||
# Validates whether the file is present or not.
|
||||
def is_valid_file(arg):
|
||||
if not os.path.exists(arg):
|
||||
@@ -265,17 +307,19 @@ if __name__ == "__main__":
|
||||
# old_args = parser.parse_args()
|
||||
|
||||
home = str(Path.home())
|
||||
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
|
||||
WORKDIR = os.path.join(os.path.dirname(__file__), "..", "gen_shark_tank")
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tank", "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tank", "tf_model_list.csv"
|
||||
os.path.dirname(__file__), "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
|
||||
tflite_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tank", "tflite", "tflite_model_list.csv"
|
||||
os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
|
||||
)
|
||||
|
||||
save_torch_model(torch_model_csv)
|
||||
save_tf_model(tf_model_csv)
|
||||
save_tflite_model(tflite_model_csv)
|
||||
save_torch_model(
|
||||
os.path.join(os.path.dirname(__file__), "torch_sd_list.csv"),
|
||||
WORKDIR,
|
||||
)
|
||||
save_torch_model(torch_model_csv, WORKDIR)
|
||||
save_tf_model(tf_model_csv, WORKDIR)
|
||||
save_tflite_model(tflite_model_csv, WORKDIR)
|
||||
@@ -31,4 +31,12 @@ xlm-roberta-base,False,False,-,-,-
|
||||
facebook/convnext-tiny-224,False,False,-,-,-
|
||||
efficientnet-v2-s,False,False,22M,"image-classification,cnn","Includes MBConv and Fused-MBConv"
|
||||
mnasnet1_0,False,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
|
||||
bert-large-uncased,True,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
t5-base,True,False,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
t5-large,True,False,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
efficientnet_b0,True,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
|
||||
efficientnet_b7,True,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
|
||||
gpt2,True,False,110M,"nlp;transformer-decoder;auto-regressive","12 layers, 768 hidden units, 12 attention heads"
|
||||
t5-base,True,False,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
t5-large,True,False,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
|
||||
|
@@ -7,6 +7,8 @@ import sys
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
BATCH_SIZE = 1
|
||||
|
||||
vision_models = [
|
||||
"alexnet",
|
||||
"resnet101",
|
||||
@@ -17,6 +19,8 @@ vision_models = [
|
||||
"wide_resnet50_2",
|
||||
"mobilenet_v3_small",
|
||||
"mnasnet1_0",
|
||||
"efficientnet_b0",
|
||||
"efficientnet_b7",
|
||||
]
|
||||
hf_img_cls_models = [
|
||||
"google/vit-base-patch16-224",
|
||||
@@ -25,6 +29,10 @@ hf_img_cls_models = [
|
||||
"microsoft/beit-base-patch16-224-pt22k-ft22k",
|
||||
"nvidia/mit-b0",
|
||||
]
|
||||
hf_seq2seq_models = [
|
||||
"t5-base",
|
||||
"t5-large",
|
||||
]
|
||||
|
||||
|
||||
def get_torch_model(modelname):
|
||||
@@ -32,6 +40,8 @@ def get_torch_model(modelname):
|
||||
return get_vision_model(modelname)
|
||||
elif modelname in hf_img_cls_models:
|
||||
return get_hf_img_cls_model(modelname)
|
||||
elif modelname in hf_seq2seq_models:
|
||||
return get_hf_seq2seq_model(modelname)
|
||||
elif "fp16" in modelname:
|
||||
return get_fp16_model(modelname)
|
||||
else:
|
||||
@@ -85,6 +95,7 @@ def get_hf_img_cls_model(name):
|
||||
# test_input = torch.FloatTensor(1, 3, 224, 224).uniform_(-1, 1)
|
||||
# print("test_input.shape: ", test_input.shape)
|
||||
# test_input.shape: torch.Size([1, 3, 224, 224])
|
||||
test_input = test_input.repeat(BATCH_SIZE, 1, 1, 1)
|
||||
actual_out = model(test_input)
|
||||
# print("actual_out.shape: ", actual_out.shape)
|
||||
# actual_out.shape: torch.Size([1, 1000])
|
||||
@@ -121,11 +132,52 @@ def get_hf_model(name):
|
||||
|
||||
model = HuggingFaceLanguage(name)
|
||||
# TODO: Currently the test input is set to (1,128)
|
||||
test_input = torch.randint(2, (1, 128))
|
||||
test_input = torch.randint(2, (BATCH_SIZE, 128))
|
||||
actual_out = model(test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
##################### Hugging Face Seq2SeqLM Models ###################################
|
||||
|
||||
# We use a maximum sequence length of 512 since this is the default used in the T5 config.
|
||||
T5_MAX_SEQUENCE_LENGTH = 512
|
||||
|
||||
|
||||
class HFSeq2SeqLanguageModel(torch.nn.Module):
|
||||
def __init__(self, model_name):
|
||||
super().__init__()
|
||||
from transformers import AutoTokenizer, T5Model
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.tokenization_kwargs = {
|
||||
"pad_to_multiple_of": T5_MAX_SEQUENCE_LENGTH,
|
||||
"padding": True,
|
||||
"return_tensors": "pt",
|
||||
}
|
||||
self.model = T5Model.from_pretrained(model_name, return_dict=True)
|
||||
|
||||
def preprocess_input(self, text):
|
||||
return self.tokenizer(text, **self.tokenization_kwargs)
|
||||
|
||||
def forward(self, input_ids, decoder_input_ids):
|
||||
return self.model.forward(
|
||||
input_ids, decoder_input_ids=decoder_input_ids
|
||||
)[0]
|
||||
|
||||
|
||||
def get_hf_seq2seq_model(name):
|
||||
m = HFSeq2SeqLanguageModel(name)
|
||||
encoded_input_ids = m.preprocess_input(
|
||||
"Studies have been shown that owning a dog is good for you"
|
||||
).input_ids
|
||||
decoder_input_ids = m.preprocess_input("Studies show that").input_ids
|
||||
decoder_input_ids = m.model._shift_right(decoder_input_ids)
|
||||
|
||||
test_input = (encoded_input_ids, decoder_input_ids)
|
||||
actual_out = m.forward(*test_input)
|
||||
return m, test_input, actual_out
|
||||
|
||||
|
||||
################################################################################
|
||||
|
||||
##################### Torch Vision Models ###################################
|
||||
@@ -144,24 +196,50 @@ class VisionModule(torch.nn.Module):
|
||||
def get_vision_model(torch_model):
|
||||
import torchvision.models as models
|
||||
|
||||
default_image_size = (224, 224)
|
||||
|
||||
vision_models_dict = {
|
||||
"alexnet": models.alexnet(weights="DEFAULT"),
|
||||
"resnet18": models.resnet18(weights="DEFAULT"),
|
||||
"resnet50": models.resnet50(weights="DEFAULT"),
|
||||
"resnet50_fp16": models.resnet50(weights="DEFAULT"),
|
||||
"resnet101": models.resnet101(weights="DEFAULT"),
|
||||
"squeezenet1_0": models.squeezenet1_0(weights="DEFAULT"),
|
||||
"wide_resnet50_2": models.wide_resnet50_2(weights="DEFAULT"),
|
||||
"mobilenet_v3_small": models.mobilenet_v3_small(weights="DEFAULT"),
|
||||
"mnasnet1_0": models.mnasnet1_0(weights="DEFAULT"),
|
||||
"alexnet": (models.alexnet(weights="DEFAULT"), default_image_size),
|
||||
"resnet18": (models.resnet18(weights="DEFAULT"), default_image_size),
|
||||
"resnet50": (models.resnet50(weights="DEFAULT"), default_image_size),
|
||||
"resnet50_fp16": (
|
||||
models.resnet50(weights="DEFAULT"),
|
||||
default_image_size,
|
||||
),
|
||||
"resnet101": (models.resnet101(weights="DEFAULT"), default_image_size),
|
||||
"squeezenet1_0": (
|
||||
models.squeezenet1_0(weights="DEFAULT"),
|
||||
default_image_size,
|
||||
),
|
||||
"wide_resnet50_2": (
|
||||
models.wide_resnet50_2(weights="DEFAULT"),
|
||||
default_image_size,
|
||||
),
|
||||
"mobilenet_v3_small": (
|
||||
models.mobilenet_v3_small(weights="DEFAULT"),
|
||||
default_image_size,
|
||||
),
|
||||
"mnasnet1_0": (
|
||||
models.mnasnet1_0(weights="DEFAULT"),
|
||||
default_image_size,
|
||||
),
|
||||
# EfficientNet input image size varies on the size of the model.
|
||||
"efficientnet_b0": (
|
||||
models.efficientnet_b0(weights="DEFAULT"),
|
||||
(224, 224),
|
||||
),
|
||||
"efficientnet_b7": (
|
||||
models.efficientnet_b7(weights="DEFAULT"),
|
||||
(600, 600),
|
||||
),
|
||||
}
|
||||
if isinstance(torch_model, str):
|
||||
fp16_model = None
|
||||
if "fp16" in torch_model:
|
||||
fp16_model = True
|
||||
torch_model = vision_models_dict[torch_model]
|
||||
torch_model, input_image_size = vision_models_dict[torch_model]
|
||||
model = VisionModule(torch_model)
|
||||
test_input = torch.randn(1, 3, 224, 224)
|
||||
test_input = torch.randn(BATCH_SIZE, 3, 224, 224)
|
||||
actual_out = model(test_input)
|
||||
if fp16_model is not None:
|
||||
test_input_fp16 = test_input.to(
|
||||
@@ -209,6 +287,7 @@ def get_fp16_model(torch_model):
|
||||
model = BertHalfPrecisionModel(modelname)
|
||||
tokenizer = AutoTokenizer.from_pretrained(modelname)
|
||||
text = "Replace me by any text you like."
|
||||
text = [text] * BATCH_SIZE
|
||||
test_input_fp16 = tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
|
||||
@@ -7,11 +7,15 @@ from transformers import (
|
||||
)
|
||||
|
||||
BATCH_SIZE = 1
|
||||
MAX_SEQUENCE_LENGTH = 128
|
||||
|
||||
################################## MHLO/TF models #########################################
|
||||
# TODO : Generate these lists or fetch model source from tank/tf/tf_model_list.csv
|
||||
keras_models = ["resnet50", "efficientnet-v2-s"]
|
||||
keras_models = [
|
||||
"resnet50",
|
||||
"efficientnet_b0",
|
||||
"efficientnet_b7",
|
||||
"efficientnet-v2-s",
|
||||
]
|
||||
maskedlm_models = [
|
||||
"albert-base-v2",
|
||||
"bert-base-uncased",
|
||||
@@ -32,9 +36,16 @@ maskedlm_models = [
|
||||
"hf-internal-testing/tiny-random-flaubert",
|
||||
"xlm-roberta",
|
||||
]
|
||||
causallm_models = [
|
||||
"gpt2",
|
||||
]
|
||||
tfhf_models = [
|
||||
"microsoft/MiniLM-L12-H384-uncased",
|
||||
]
|
||||
tfhf_seq2seq_models = [
|
||||
"t5-base",
|
||||
"t5-large",
|
||||
]
|
||||
img_models = [
|
||||
"google/vit-base-patch16-224",
|
||||
"facebook/convnext-tiny-224",
|
||||
@@ -45,23 +56,35 @@ def get_tf_model(name):
|
||||
if name in keras_models:
|
||||
return get_keras_model(name)
|
||||
elif name in maskedlm_models:
|
||||
return get_masked_lm_model(name)
|
||||
elif name in causallm_models:
|
||||
return get_causal_lm_model(name)
|
||||
elif name in tfhf_models:
|
||||
return get_TFhf_model(name)
|
||||
elif name in img_models:
|
||||
return get_causal_image_model(name)
|
||||
elif name in tfhf_seq2seq_models:
|
||||
return get_tfhf_seq2seq_model(name)
|
||||
else:
|
||||
raise Exception(
|
||||
"TF model not found! Please check that the modelname has been input correctly."
|
||||
)
|
||||
|
||||
|
||||
##################### Tensorflow Hugging Face LM Models ###################################
|
||||
##################### Tensorflow Hugging Face Bert Models ###################################
|
||||
BERT_MAX_SEQUENCE_LENGTH = 128
|
||||
|
||||
# Create a set of 2-dimensional inputs
|
||||
tf_bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -87,21 +110,31 @@ def get_TFhf_model(name):
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
text = "Replace me by any text you'd like."
|
||||
text = [text] * BATCH_SIZE
|
||||
encoded_input = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
for key in encoded_input:
|
||||
encoded_input[key] = tf.expand_dims(
|
||||
tf.convert_to_tensor(encoded_input[key]), 0
|
||||
)
|
||||
test_input = (
|
||||
encoded_input["input_ids"],
|
||||
encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"],
|
||||
max_length=BERT_MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
test_input = [
|
||||
tf.reshape(
|
||||
tf.convert_to_tensor(encoded_input["input_ids"], dtype=tf.int32),
|
||||
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
|
||||
),
|
||||
tf.reshape(
|
||||
tf.convert_to_tensor(
|
||||
encoded_input["attention_mask"], dtype=tf.int32
|
||||
),
|
||||
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
|
||||
),
|
||||
tf.reshape(
|
||||
tf.convert_to_tensor(
|
||||
encoded_input["token_type_ids"], dtype=tf.int32
|
||||
),
|
||||
[BATCH_SIZE, BERT_MAX_SEQUENCE_LENGTH],
|
||||
),
|
||||
]
|
||||
actual_out = model.forward(*test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
@@ -115,34 +148,41 @@ def compare_tensors_tf(tf_tensor, numpy_tensor):
|
||||
return np.allclose(tf_to_numpy, numpy_tensor, rtol, atol)
|
||||
|
||||
|
||||
##################### Tensorflow Hugging Face Masked LM Models ###################################
|
||||
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
|
||||
import tensorflow as tf
|
||||
|
||||
# Create a set of input signature.
|
||||
input_signature_maskedlm = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
# For supported models please see here:
|
||||
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForCasualLM
|
||||
|
||||
|
||||
# Tokenizer for language models
|
||||
def preprocess_input(
|
||||
model_name, text="This is just used to compile the model"
|
||||
model_name, max_length, text="This is just used to compile the model"
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
text = [text] * BATCH_SIZE
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
return_tensors="tf",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
max_length=max_length,
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
##################### Tensorflow Hugging Face Masked LM Models ###################################
|
||||
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
|
||||
import tensorflow as tf
|
||||
|
||||
MASKED_LM_MAX_SEQUENCE_LENGTH = 128
|
||||
|
||||
# Create a set of input signature.
|
||||
input_signature_maskedlm = [
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, MASKED_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, MASKED_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# For supported models please see here:
|
||||
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForMaskedLM
|
||||
class MaskedLM(tf.Module):
|
||||
def __init__(self, model_name):
|
||||
super(MaskedLM, self).__init__()
|
||||
@@ -156,19 +196,139 @@ class MaskedLM(tf.Module):
|
||||
return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
|
||||
def get_masked_lm_model(hf_name, text="Hello, this is the default text."):
|
||||
model = MaskedLM(hf_name)
|
||||
encoded_input = preprocess_input(hf_name, text)
|
||||
encoded_input = preprocess_input(
|
||||
hf_name, MASKED_LM_MAX_SEQUENCE_LENGTH, text
|
||||
)
|
||||
test_input = (encoded_input["input_ids"], encoded_input["attention_mask"])
|
||||
actual_out = model.forward(*test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
##################### Tensorflow Hugging Face Causal LM Models ###################################
|
||||
|
||||
from transformers import AutoConfig, TFAutoModelForCausalLM, TFGPT2Model
|
||||
|
||||
CAUSAL_LM_MAX_SEQUENCE_LENGTH = 1024
|
||||
|
||||
input_signature_causallm = [
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, CAUSAL_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, CAUSAL_LM_MAX_SEQUENCE_LENGTH], dtype=tf.int32
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# For supported models please see here:
|
||||
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForCausalLM
|
||||
# For more background, see:
|
||||
# https://huggingface.co/blog/tf-xla-generate
|
||||
class CausalLM(tf.Module):
|
||||
def __init__(self, model_name):
|
||||
super(CausalLM, self).__init__()
|
||||
# Decoder-only models need left padding.
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name, padding_side="left", pad_token="</s>"
|
||||
)
|
||||
self.tokenization_kwargs = {
|
||||
"pad_to_multiple_of": CAUSAL_LM_MAX_SEQUENCE_LENGTH,
|
||||
"padding": True,
|
||||
"return_tensors": "tf",
|
||||
}
|
||||
self.model = TFGPT2Model.from_pretrained(model_name, return_dict=True)
|
||||
self.model.predict = lambda x, y: self.model(
|
||||
input_ids=x, attention_mask=y
|
||||
)[0]
|
||||
|
||||
def preprocess_input(self, text):
|
||||
return self.tokenizer(text, **self.tokenization_kwargs)
|
||||
|
||||
@tf.function(input_signature=input_signature_causallm, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.model.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
|
||||
model = CausalLM(hf_name)
|
||||
batched_text = [text] * BATCH_SIZE
|
||||
encoded_input = model.preprocess_input(batched_text)
|
||||
test_input = (encoded_input["input_ids"], encoded_input["attention_mask"])
|
||||
actual_out = model.forward(*test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
##################### TensorflowHugging Face Seq2SeqLM Models ###################################
|
||||
|
||||
# We use a maximum sequence length of 512 since this is the default used in the T5 config.
|
||||
T5_MAX_SEQUENCE_LENGTH = 512
|
||||
|
||||
input_signature_t5 = [
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, T5_MAX_SEQUENCE_LENGTH],
|
||||
dtype=tf.int32,
|
||||
name="input_ids",
|
||||
),
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, T5_MAX_SEQUENCE_LENGTH],
|
||||
dtype=tf.int32,
|
||||
name="attention_mask",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class TFHFSeq2SeqLanguageModel(tf.Module):
|
||||
def __init__(self, model_name):
|
||||
super(TFHFSeq2SeqLanguageModel, self).__init__()
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
TFAutoModelForSeq2SeqLM,
|
||||
TFT5Model,
|
||||
)
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.tokenization_kwargs = {
|
||||
"pad_to_multiple_of": T5_MAX_SEQUENCE_LENGTH,
|
||||
"padding": True,
|
||||
"return_tensors": "tf",
|
||||
}
|
||||
self.model = TFT5Model.from_pretrained(model_name, return_dict=True)
|
||||
self.model.predict = lambda x, y: self.model(x, decoder_input_ids=y)[0]
|
||||
|
||||
def preprocess_input(self, text):
|
||||
return self.tokenizer(text, **self.tokenization_kwargs)
|
||||
|
||||
@tf.function(input_signature=input_signature_t5, jit_compile=True)
|
||||
def forward(self, input_ids, decoder_input_ids):
|
||||
return self.model.predict(input_ids, decoder_input_ids)
|
||||
|
||||
|
||||
def get_tfhf_seq2seq_model(name):
|
||||
m = TFHFSeq2SeqLanguageModel(name)
|
||||
text = "Studies have been shown that owning a dog is good for you"
|
||||
batched_text = [text] * BATCH_SIZE
|
||||
encoded_input_ids = m.preprocess_input(batched_text).input_ids
|
||||
|
||||
text = "Studies show that"
|
||||
batched_text = [text] * BATCH_SIZE
|
||||
decoder_input_ids = m.preprocess_input(batched_text).input_ids
|
||||
decoder_input_ids = m.model._shift_right(decoder_input_ids)
|
||||
|
||||
test_input = (encoded_input_ids, decoder_input_ids)
|
||||
actual_out = m.forward(*test_input)
|
||||
return m, test_input, actual_out
|
||||
|
||||
|
||||
##################### TensorFlow Keras Resnet Models #########################################################
|
||||
# Static shape, including batch size (1).
|
||||
# Can be dynamic once dynamic shape support is ready.
|
||||
RESNET_INPUT_SHAPE = [1, 224, 224, 3]
|
||||
EFFICIENTNET_INPUT_SHAPE = [1, 384, 384, 3]
|
||||
RESNET_INPUT_SHAPE = [BATCH_SIZE, 224, 224, 3]
|
||||
EFFICIENTNET_V2_S_INPUT_SHAPE = [BATCH_SIZE, 384, 384, 3]
|
||||
EFFICIENTNET_B0_INPUT_SHAPE = [BATCH_SIZE, 224, 224, 3]
|
||||
EFFICIENTNET_B7_INPUT_SHAPE = [BATCH_SIZE, 600, 600, 3]
|
||||
|
||||
|
||||
class ResNetModule(tf.Module):
|
||||
@@ -195,25 +355,79 @@ class ResNetModule(tf.Module):
|
||||
return tf.keras.applications.resnet50.preprocess_input(image)
|
||||
|
||||
|
||||
class EfficientNetModule(tf.Module):
|
||||
class EfficientNetB0Module(tf.Module):
|
||||
def __init__(self):
|
||||
super(EfficientNetModule, self).__init__()
|
||||
self.m = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
|
||||
super(EfficientNetB0Module, self).__init__()
|
||||
self.m = tf.keras.applications.efficientnet.EfficientNetB0(
|
||||
weights="imagenet",
|
||||
include_top=True,
|
||||
input_shape=tuple(EFFICIENTNET_INPUT_SHAPE[1:]),
|
||||
input_shape=tuple(EFFICIENTNET_B0_INPUT_SHAPE[1:]),
|
||||
)
|
||||
self.m.predict = lambda x: self.m.call(x, training=False)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[tf.TensorSpec(EFFICIENTNET_INPUT_SHAPE, tf.float32)],
|
||||
input_signature=[
|
||||
tf.TensorSpec(EFFICIENTNET_B0_INPUT_SHAPE, tf.float32)
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def forward(self, inputs):
|
||||
return self.m.predict(inputs)
|
||||
|
||||
def input_shape(self):
|
||||
return EFFICIENTNET_INPUT_SHAPE
|
||||
return EFFICIENTNET_B0_INPUT_SHAPE
|
||||
|
||||
def preprocess_input(self, image):
|
||||
return tf.keras.applications.efficientnet.preprocess_input(image)
|
||||
|
||||
|
||||
class EfficientNetB7Module(tf.Module):
|
||||
def __init__(self):
|
||||
super(EfficientNetB7Module, self).__init__()
|
||||
self.m = tf.keras.applications.efficientnet.EfficientNetB7(
|
||||
weights="imagenet",
|
||||
include_top=True,
|
||||
input_shape=tuple(EFFICIENTNET_B7_INPUT_SHAPE[1:]),
|
||||
)
|
||||
self.m.predict = lambda x: self.m.call(x, training=False)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
tf.TensorSpec(EFFICIENTNET_B7_INPUT_SHAPE, tf.float32)
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def forward(self, inputs):
|
||||
return self.m.predict(inputs)
|
||||
|
||||
def input_shape(self):
|
||||
return EFFICIENTNET_B7_INPUT_SHAPE
|
||||
|
||||
def preprocess_input(self, image):
|
||||
return tf.keras.applications.efficientnet.preprocess_input(image)
|
||||
|
||||
|
||||
class EfficientNetV2SModule(tf.Module):
|
||||
def __init__(self):
|
||||
super(EfficientNetV2SModule, self).__init__()
|
||||
self.m = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
|
||||
weights="imagenet",
|
||||
include_top=True,
|
||||
input_shape=tuple(EFFICIENTNET_V2_S_INPUT_SHAPE[1:]),
|
||||
)
|
||||
self.m.predict = lambda x: self.m.call(x, training=False)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
tf.TensorSpec(EFFICIENTNET_V2_S_INPUT_SHAPE, tf.float32)
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def forward(self, inputs):
|
||||
return self.m.predict(inputs)
|
||||
|
||||
def input_shape(self):
|
||||
return EFFICIENTNET_V2_S_INPUT_SHAPE
|
||||
|
||||
def preprocess_input(self, image):
|
||||
return tf.keras.applications.efficientnet_v2.preprocess_input(image)
|
||||
@@ -224,12 +438,17 @@ def load_image(path_to_image, width, height, channels):
|
||||
image = tf.image.decode_image(image, channels=channels)
|
||||
image = tf.image.resize(image, (width, height))
|
||||
image = image[tf.newaxis, :]
|
||||
image = tf.tile(image, [BATCH_SIZE, 1, 1, 1])
|
||||
return image
|
||||
|
||||
|
||||
def get_keras_model(modelname):
|
||||
if modelname == "efficientnet-v2-s":
|
||||
model = EfficientNetModule()
|
||||
model = EfficientNetV2SModule()
|
||||
elif modelname == "efficientnet_b0":
|
||||
model = EfficientNetB0Module()
|
||||
elif modelname == "efficientnet_b7":
|
||||
model = EfficientNetB7Module()
|
||||
else:
|
||||
model = ResNetModule()
|
||||
|
||||
@@ -256,7 +475,7 @@ import requests
|
||||
|
||||
# Create a set of input signature.
|
||||
input_signature_img_cls = [
|
||||
tf.TensorSpec(shape=[1, 3, 224, 224], dtype=tf.float32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, 3, 224, 224], dtype=tf.float32),
|
||||
]
|
||||
|
||||
|
||||
@@ -304,6 +523,9 @@ def preprocess_input_image(model_name):
|
||||
)
|
||||
# inputs: {'pixel_values': <tf.Tensor: shape=(1, 3, 224, 224), dtype=float32, numpy=array([[[[]]]], dtype=float32)>}
|
||||
inputs = feature_extractor(images=image, return_tensors="tf")
|
||||
inputs["pixel_values"] = tf.tile(
|
||||
inputs["pixel_values"], [BATCH_SIZE, 1, 1, 1]
|
||||
)
|
||||
|
||||
return [inputs[str(*inputs)]]
|
||||
|
||||
|
||||
@@ -48,7 +48,9 @@ def load_csv_and_convert(filename, gen=False):
|
||||
)
|
||||
# This is a pytest workaround
|
||||
if gen:
|
||||
with open("tank/dict_configs.py", "w+") as out:
|
||||
with open(
|
||||
os.path.join(os.path.dirname(__file__), "dict_configs.py"), "w+"
|
||||
) as out:
|
||||
out.write("ALL = [\n")
|
||||
for c in model_configs:
|
||||
out.write(str(c) + ",\n")
|
||||
@@ -68,7 +70,9 @@ def get_valid_test_params():
|
||||
dynamic_list = (True, False)
|
||||
# TODO: This is soooo ugly, but for some reason creating the dict at runtime
|
||||
# results in strange pytest failures.
|
||||
load_csv_and_convert("tank/all_models.csv", True)
|
||||
load_csv_and_convert(
|
||||
os.path.join(os.path.dirname(__file__), "all_models.csv"), True
|
||||
)
|
||||
from tank.dict_configs import ALL
|
||||
|
||||
config_list = ALL
|
||||
|
||||
@@ -19,3 +19,10 @@ facebook/convnext-tiny-224,img
|
||||
google/vit-base-patch16-224,img
|
||||
efficientnet-v2-s,keras
|
||||
bert-large-uncased,hf
|
||||
t5-base,tfhf_seq2seq
|
||||
t5-large,tfhf_seq2seq
|
||||
efficientnet_b0,keras
|
||||
efficientnet_b7,keras
|
||||
gpt2,hf_causallm
|
||||
t5-base,tfhf_seq2seq
|
||||
t5-large,tfhf_seq2seq
|
||||
|
||||
|
@@ -18,4 +18,10 @@ nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encod
|
||||
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
|
||||
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
t5-base,True,hf_seq2seq,True,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
t5-large,True,hf_seq2seq,True,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
|
||||
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
|
||||
t5-base,True,hf_seq2seq,True,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
t5-large,True,hf_seq2seq,True,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
|
||||
|
4
tank/torch_sd_list.csv
Normal file
4
tank/torch_sd_list.csv
Normal file
@@ -0,0 +1,4 @@
|
||||
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
|
||||
stabilityai/stable-diffusion-2-1-base,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
|
||||
stabilityai/stable-diffusion-2-1,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
|
||||
prompthero/openjourney,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
|
||||
|
Reference in New Issue
Block a user