mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
34 Commits
20230205.4
...
20230213.5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
899cb9cc1f | ||
|
|
0464c7e558 | ||
|
|
f64e1fb926 | ||
|
|
ef7d31293d | ||
|
|
6d54eb68dc | ||
|
|
30eb10c990 | ||
|
|
591bbcd058 | ||
|
|
99aa77d036 | ||
|
|
9c13f1e635 | ||
|
|
24af983cfb | ||
|
|
67842a7525 | ||
|
|
3159a6f3e1 | ||
|
|
b2f3c96835 | ||
|
|
6582475955 | ||
|
|
41ee65b377 | ||
|
|
83fe477066 | ||
|
|
4ca84ee4ee | ||
|
|
c28cc4c919 | ||
|
|
e9864cb3f7 | ||
|
|
83c69ecd49 | ||
|
|
3595b4aaff | ||
|
|
3a9cfe113a | ||
|
|
c9966127da | ||
|
|
51300d33a7 | ||
|
|
5af124c5a5 | ||
|
|
eeb20b531a | ||
|
|
9dca842c22 | ||
|
|
1eb9436836 | ||
|
|
9604d9ce81 | ||
|
|
481d0553d8 | ||
|
|
60035cd63a | ||
|
|
d35f992ace | ||
|
|
157ae64f9d | ||
|
|
ffa17f6057 |
10
.github/workflows/nightly.yml
vendored
10
.github/workflows/nightly.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
python-version: ["3.11"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
@@ -52,10 +52,10 @@ jobs:
|
||||
./setup_venv.ps1
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd.spec
|
||||
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
|
||||
signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
|
||||
#signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd_cli.spec
|
||||
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
|
||||
signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
|
||||
#signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
|
||||
|
||||
|
||||
# GHA windows VM OOMs so disable for now
|
||||
@@ -92,7 +92,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
python-version: ["3.11"]
|
||||
backend: [IREE, SHARK]
|
||||
|
||||
steps:
|
||||
@@ -143,7 +143,7 @@ jobs:
|
||||
then
|
||||
export SHA=$(git log -1 --format='%h')
|
||||
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
|
||||
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/latest/
|
||||
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
|
||||
fi
|
||||
rm -rf ./wheelhouse/nodai*
|
||||
|
||||
|
||||
11
.github/workflows/test-models.yml
vendored
11
.github/workflows/test-models.yml
vendored
@@ -31,7 +31,7 @@ jobs:
|
||||
matrix:
|
||||
os: [7950x, icelake, a100, MacStudio, ubuntu-latest]
|
||||
suite: [cpu,cuda,vulkan]
|
||||
python-version: ["3.10"]
|
||||
python-version: ["3.11"]
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
suite: lint
|
||||
@@ -111,7 +111,7 @@ jobs:
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./shark_tmp/shark_cache" -k cpu
|
||||
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k cpu
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
|
||||
|
||||
@@ -121,7 +121,7 @@ jobs:
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./shark_tmp/shark_cache" -k cuda
|
||||
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k cuda
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
|
||||
# Disabled due to black image bug
|
||||
@@ -136,7 +136,7 @@ jobs:
|
||||
export DYLD_LIBRARY_PATH=/usr/local/lib/
|
||||
echo $PATH
|
||||
pip list | grep -E "torch|iree"
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./shark_tmp/shark_cache" -k vulkan
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" -k vulkan --update_tank
|
||||
|
||||
- name: Validate Vulkan Models (a100)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
|
||||
@@ -144,7 +144,7 @@ jobs:
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./shark_tmp/shark_cache" -k vulkan
|
||||
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k vulkan
|
||||
python build_tools/stable_diffusion_testing.py --device=vulkan
|
||||
|
||||
- name: Validate Vulkan Models (Windows)
|
||||
@@ -158,4 +158,5 @@ jobs:
|
||||
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
./shark.venv/Scripts/activate
|
||||
python build_tools/stable_diffusion_testing.py --device=vulkan
|
||||
|
||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -159,6 +159,9 @@ cython_debug/
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# vscode related
|
||||
.vscode
|
||||
|
||||
# Shark related artefacts
|
||||
*venv/
|
||||
shark_tmp/
|
||||
@@ -172,3 +175,10 @@ onnx_models/
|
||||
|
||||
# Generated images
|
||||
generated_imgs/
|
||||
|
||||
# Custom model related artefacts
|
||||
variants.json
|
||||
models/
|
||||
|
||||
# models folder
|
||||
apps/stable_diffusion/web/models/
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
from apps.stable_diffusion.scripts.txt2img import txt2img_inf
|
||||
from apps.stable_diffusion.scripts.img2img import img2img_inf
|
||||
|
||||
@@ -0,0 +1,329 @@
|
||||
import os
|
||||
|
||||
if "AMD_ENABLE_LLPC" not in os.environ:
|
||||
os.environ["AMD_ENABLE_LLPC"] = "1"
|
||||
|
||||
import sys
|
||||
import json
|
||||
import torch
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from PIL import Image, PngImagePlugin
|
||||
from datetime import datetime as dt
|
||||
from dataclasses import dataclass
|
||||
from csv import DictWriter
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Image2ImagePipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
model_id: str
|
||||
ckpt_loc: str
|
||||
precision: str
|
||||
batch_size: int
|
||||
max_length: int
|
||||
height: int
|
||||
width: int
|
||||
device: str
|
||||
|
||||
|
||||
# This has to come before importing cache objects
|
||||
if args.clear_all:
|
||||
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
|
||||
from glob import glob
|
||||
import shutil
|
||||
|
||||
vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
|
||||
for vmfb in vmfbs:
|
||||
if os.path.exists(vmfb):
|
||||
os.remove(vmfb)
|
||||
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
|
||||
# TODO: Remove this once we have better weight updation logic.
|
||||
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
|
||||
for yaml in inference_yaml:
|
||||
if os.path.exists(yaml):
|
||||
os.remove(yaml)
|
||||
home = os.path.expanduser("~")
|
||||
if os.name == "nt": # Windows
|
||||
appdata = os.getenv("LOCALAPPDATA")
|
||||
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
|
||||
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
|
||||
elif os.name == "unix":
|
||||
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
|
||||
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
|
||||
|
||||
|
||||
# save output images and the inputs correspoding to it.
|
||||
def save_output_img(output_img):
|
||||
output_path = args.output_dir if args.output_dir else Path.cwd()
|
||||
generated_imgs_path = Path(output_path, "generated_imgs")
|
||||
generated_imgs_path.mkdir(parents=True, exist_ok=True)
|
||||
csv_path = Path(generated_imgs_path, "imgs_details.csv")
|
||||
|
||||
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
|
||||
out_img_name = (
|
||||
f"{prompt_slice}_{args.seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
|
||||
)
|
||||
|
||||
if args.output_img_format == "jpg":
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
output_img.save(out_img_path, quality=95, subsampling=0)
|
||||
else:
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
|
||||
pngInfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if args.write_metadata_to_png:
|
||||
pngInfo.add_text(
|
||||
"parameters",
|
||||
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {args.seed}, Size: {args.width}x{args.height}, Model: {args.hf_model_id}",
|
||||
)
|
||||
|
||||
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
|
||||
|
||||
if args.output_img_format not in ["png", "jpg"]:
|
||||
print(
|
||||
f"[ERROR] Format {args.output_img_format} is not supported yet."
|
||||
"Image saved as png instead. Supported formats: png / jpg"
|
||||
)
|
||||
|
||||
new_entry = {
|
||||
"VARIANT": args.hf_model_id,
|
||||
"SCHEDULER": args.scheduler,
|
||||
"PROMPT": args.prompts[0],
|
||||
"NEG_PROMPT": args.negative_prompts[0],
|
||||
"IMG_INPUT": args.img_path,
|
||||
"SEED": args.seed,
|
||||
"CFG_SCALE": args.guidance_scale,
|
||||
"PRECISION": args.precision,
|
||||
"STEPS": args.steps,
|
||||
"HEIGHT": args.height,
|
||||
"WIDTH": args.width,
|
||||
"MAX_LENGTH": args.max_length,
|
||||
"OUTPUT": out_img_path,
|
||||
}
|
||||
|
||||
with open(csv_path, "a") as csv_obj:
|
||||
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
|
||||
dictwriter_obj.writerow(new_entry)
|
||||
csv_obj.close()
|
||||
|
||||
if args.save_metadata_to_json:
|
||||
del new_entry["OUTPUT"]
|
||||
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(new_entry, f, indent=4)
|
||||
|
||||
|
||||
img2img_obj = None
|
||||
config_obj = None
|
||||
schedulers = None
|
||||
|
||||
|
||||
# Exposed to UI.
|
||||
def img2img_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
init_image: str,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
):
|
||||
global img2img_obj
|
||||
global config_obj
|
||||
global schedulers
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.seed = seed
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
args.img_path = init_image
|
||||
image = Image.open(args.img_path)
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = custom_model
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
if image is None:
|
||||
return None, "An Initial Image is required"
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
)
|
||||
if config_obj != new_config_obj:
|
||||
config_obj = new_config_obj
|
||||
args.precision = precision
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.use_tuned = True
|
||||
args.import_mlir = True
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "runwayml/stable-diffusion-inpainting"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
img2img_obj = Image2ImagePipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
)
|
||||
|
||||
if not img2img_obj:
|
||||
sys.exit("text to image pipeline must not return a null value")
|
||||
|
||||
img2img_obj.scheduler = schedulers[scheduler]
|
||||
|
||||
start_time = time.time()
|
||||
img2img_obj.log = ""
|
||||
generated_imgs = img2img_obj.generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
save_output_img(generated_imgs[0])
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
text_output += img2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.img_path is None:
|
||||
print("Flag --img_path is required.")
|
||||
exit()
|
||||
|
||||
# When the models get uploaded, it should be default to False.
|
||||
args.import_mlir = True
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
image = Image.open(args.img_path)
|
||||
|
||||
# Adjust for height and width based on model
|
||||
|
||||
img2img_obj = Image2ImagePipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
generated_imgs = img2img_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
image,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
args.seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
text_output += img2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
save_output_img(generated_imgs[0])
|
||||
print(text_output)
|
||||
|
||||
350
apps/stable_diffusion/scripts/inpaint.py
Normal file
350
apps/stable_diffusion/scripts/inpaint.py
Normal file
@@ -0,0 +1,350 @@
|
||||
import os
|
||||
|
||||
if "AMD_ENABLE_LLPC" not in os.environ:
|
||||
os.environ["AMD_ENABLE_LLPC"] = "1"
|
||||
|
||||
import sys
|
||||
import json
|
||||
import torch
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from PIL import Image, PngImagePlugin
|
||||
from datetime import datetime as dt
|
||||
from dataclasses import dataclass
|
||||
from csv import DictWriter
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
InpaintPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
model_id: str
|
||||
ckpt_loc: str
|
||||
precision: str
|
||||
batch_size: int
|
||||
max_length: int
|
||||
height: int
|
||||
width: int
|
||||
device: str
|
||||
|
||||
|
||||
# This has to come before importing cache objects
|
||||
if args.clear_all:
|
||||
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
|
||||
from glob import glob
|
||||
import shutil
|
||||
|
||||
vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
|
||||
for vmfb in vmfbs:
|
||||
if os.path.exists(vmfb):
|
||||
os.remove(vmfb)
|
||||
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
|
||||
# TODO: Remove this once we have better weight updation logic.
|
||||
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
|
||||
for yaml in inference_yaml:
|
||||
if os.path.exists(yaml):
|
||||
os.remove(yaml)
|
||||
home = os.path.expanduser("~")
|
||||
if os.name == "nt": # Windows
|
||||
appdata = os.getenv("LOCALAPPDATA")
|
||||
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
|
||||
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
|
||||
elif os.name == "unix":
|
||||
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
|
||||
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
|
||||
|
||||
|
||||
# save output images and the inputs corresponding to it.
|
||||
def save_output_img(output_img, img_seed):
|
||||
output_path = args.output_dir if args.output_dir else Path.cwd()
|
||||
generated_imgs_path = Path(output_path, "generated_imgs")
|
||||
generated_imgs_path.mkdir(parents=True, exist_ok=True)
|
||||
csv_path = Path(generated_imgs_path, "imgs_details.csv")
|
||||
|
||||
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
|
||||
out_img_name = (
|
||||
f"{prompt_slice}_{img_seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
|
||||
)
|
||||
|
||||
img_model = args.hf_model_id
|
||||
if args.ckpt_loc:
|
||||
img_model = os.path.basename(args.ckpt_loc)
|
||||
|
||||
if args.output_img_format == "jpg":
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
output_img.save(out_img_path, quality=95, subsampling=0)
|
||||
else:
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
|
||||
pngInfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if args.write_metadata_to_png:
|
||||
pngInfo.add_text(
|
||||
"parameters",
|
||||
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
|
||||
)
|
||||
|
||||
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
|
||||
|
||||
if args.output_img_format not in ["png", "jpg"]:
|
||||
print(
|
||||
f"[ERROR] Format {args.output_img_format} is not supported yet."
|
||||
"Image saved as png instead. Supported formats: png / jpg"
|
||||
)
|
||||
|
||||
new_entry = {
|
||||
"VARIANT": img_model,
|
||||
"SCHEDULER": args.scheduler,
|
||||
"PROMPT": args.prompts[0],
|
||||
"NEG_PROMPT": args.negative_prompts[0],
|
||||
"IMG_INPUT": args.img_path,
|
||||
"MASK_INPUT": args.mask_path,
|
||||
"SEED": img_seed,
|
||||
"CFG_SCALE": args.guidance_scale,
|
||||
"PRECISION": args.precision,
|
||||
"STEPS": args.steps,
|
||||
"HEIGHT": args.height,
|
||||
"WIDTH": args.width,
|
||||
"MAX_LENGTH": args.max_length,
|
||||
"OUTPUT": out_img_path,
|
||||
}
|
||||
|
||||
with open(csv_path, "a") as csv_obj:
|
||||
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
|
||||
dictwriter_obj.writerow(new_entry)
|
||||
csv_obj.close()
|
||||
|
||||
if args.save_metadata_to_json:
|
||||
del new_entry["OUTPUT"]
|
||||
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(new_entry, f, indent=4)
|
||||
|
||||
|
||||
inpaint_obj = None
|
||||
config_obj = None
|
||||
schedulers = None
|
||||
|
||||
|
||||
# Exposed to UI.
|
||||
def inpaint_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
image: Image,
|
||||
mask_image: Image,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
):
|
||||
global inpaint_obj
|
||||
global config_obj
|
||||
global schedulers
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = custom_model
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
)
|
||||
if config_obj != new_config_obj:
|
||||
config_obj = new_config_obj
|
||||
args.precision = precision
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.use_tuned = True
|
||||
args.import_mlir = False
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
inpaint_obj = InpaintPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
)
|
||||
|
||||
if not inpaint_obj:
|
||||
sys.exit("text to image pipeline must not return a null value")
|
||||
|
||||
inpaint_obj.scheduler = schedulers[scheduler]
|
||||
|
||||
start_time = time.time()
|
||||
inpaint_obj.log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = inpaint_obj.generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
mask_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
inpaint_obj.log += "\n"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
|
||||
text_output += inpaint_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.img_path is None:
|
||||
print("Flag --img_path is required.")
|
||||
exit()
|
||||
if args.mask_path is None:
|
||||
print("Flag --mask_path is required.")
|
||||
exit()
|
||||
if "inpaint" not in args.hf_model_id:
|
||||
print("Please use inpainting model with --hf_model_id.")
|
||||
exit()
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
image = Image.open(args.img_path)
|
||||
mask_image = Image.open(args.mask_path)
|
||||
|
||||
inpaint_obj = InpaintPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
)
|
||||
|
||||
for current_batch in range(args.batch_count):
|
||||
if current_batch > 0:
|
||||
seed = -1
|
||||
seed = utils.sanitize_seed(seed)
|
||||
|
||||
start_time = time.time()
|
||||
generated_imgs = inpaint_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
image,
|
||||
mask_image,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
text_output += inpaint_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
save_output_img(generated_imgs[0], seed)
|
||||
print(text_output)
|
||||
@@ -18,6 +18,7 @@ from apps.stable_diffusion.src import (
|
||||
Text2ImagePipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
)
|
||||
|
||||
|
||||
@@ -59,8 +60,8 @@ if args.clear_all:
|
||||
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
|
||||
|
||||
|
||||
# save output images and the inputs correspoding to it.
|
||||
def save_output_img(output_img):
|
||||
# save output images and the inputs corresponding to it.
|
||||
def save_output_img(output_img, img_seed):
|
||||
output_path = args.output_dir if args.output_dir else Path.cwd()
|
||||
generated_imgs_path = Path(output_path, "generated_imgs")
|
||||
generated_imgs_path.mkdir(parents=True, exist_ok=True)
|
||||
@@ -68,9 +69,13 @@ def save_output_img(output_img):
|
||||
|
||||
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
|
||||
out_img_name = (
|
||||
f"{prompt_slice}_{args.seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
|
||||
f"{prompt_slice}_{img_seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
|
||||
)
|
||||
|
||||
img_model = args.hf_model_id
|
||||
if args.ckpt_loc:
|
||||
img_model = os.path.basename(args.ckpt_loc)
|
||||
|
||||
if args.output_img_format == "jpg":
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
output_img.save(out_img_path, quality=95, subsampling=0)
|
||||
@@ -81,7 +86,7 @@ def save_output_img(output_img):
|
||||
if args.write_metadata_to_png:
|
||||
pngInfo.add_text(
|
||||
"parameters",
|
||||
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {args.seed}, Size: {args.width}x{args.height}, Model: {args.hf_model_id}",
|
||||
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
|
||||
)
|
||||
|
||||
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
|
||||
@@ -93,11 +98,11 @@ def save_output_img(output_img):
|
||||
)
|
||||
|
||||
new_entry = {
|
||||
"VARIANT": args.hf_model_id,
|
||||
"VARIANT": img_model,
|
||||
"SCHEDULER": args.scheduler,
|
||||
"PROMPT": args.prompts[0],
|
||||
"NEG_PROMPT": args.negative_prompts[0],
|
||||
"SEED": args.seed,
|
||||
"SEED": img_seed,
|
||||
"CFG_SCALE": args.guidance_scale,
|
||||
"PRECISION": args.precision,
|
||||
"STEPS": args.steps,
|
||||
@@ -133,6 +138,7 @@ def txt2img_inf(
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
@@ -150,7 +156,6 @@ def txt2img_inf(
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.seed = seed
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
|
||||
@@ -211,6 +216,7 @@ def txt2img_inf(
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
@@ -227,30 +233,38 @@ def txt2img_inf(
|
||||
|
||||
start_time = time.time()
|
||||
txt2img_obj.log = ""
|
||||
generated_imgs = txt2img_obj.generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = txt2img_obj.generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
txt2img_obj.log += "\n"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
save_output_img(generated_imgs[0])
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
|
||||
text_output += txt2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
@@ -263,12 +277,14 @@ if __name__ == "__main__":
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
|
||||
txt2img_obj = Text2ImagePipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
@@ -278,32 +294,40 @@ if __name__ == "__main__":
|
||||
args.use_tuned,
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
generated_imgs = txt2img_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
args.seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
text_output += txt2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
for current_batch in range(args.batch_count):
|
||||
if current_batch > 0:
|
||||
seed = -1
|
||||
seed = utils.sanitize_seed(seed)
|
||||
|
||||
save_output_img(generated_imgs[0])
|
||||
print(text_output)
|
||||
start_time = time.time()
|
||||
generated_imgs = txt2img_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
# TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start
|
||||
text_output += txt2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
save_output_img(generated_imgs[0], seed)
|
||||
print(text_output)
|
||||
|
||||
@@ -30,7 +30,8 @@ datas += [
|
||||
( 'src/utils/resources/model_db.json', 'resources' ),
|
||||
( 'src/utils/resources/opt_flags.json', 'resources' ),
|
||||
( 'src/utils/resources/base_model.json', 'resources' ),
|
||||
( 'web/logos/*', 'logos' )
|
||||
( 'web/ui/css/*', 'css' ),
|
||||
( 'web/ui/logos/*', 'logos' )
|
||||
]
|
||||
|
||||
binaries = []
|
||||
|
||||
@@ -4,5 +4,9 @@ from apps.stable_diffusion.src.utils import (
|
||||
prompt_examples,
|
||||
get_available_devices,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines import Text2ImagePipeline
|
||||
from apps.stable_diffusion.src.pipelines import (
|
||||
Text2ImagePipeline,
|
||||
InpaintPipeline,
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import get_schedulers
|
||||
|
||||
@@ -2,6 +2,7 @@ from apps.stable_diffusion.src.models.model_wrappers import (
|
||||
SharkifyStableDiffusionModel,
|
||||
)
|
||||
from apps.stable_diffusion.src.models.opt_params import (
|
||||
get_vae_encode,
|
||||
get_vae,
|
||||
get_unet,
|
||||
get_clip,
|
||||
|
||||
@@ -13,6 +13,9 @@ from apps.stable_diffusion.src.utils import (
|
||||
fetch_or_delete_vmfbs,
|
||||
preprocessCKPT,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,15 +30,19 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
|
||||
elif shape[i] == "width":
|
||||
new_shape.append(width)
|
||||
elif isinstance(shape[i], str):
|
||||
mul_val = int(shape[i].split("*")[0])
|
||||
if "batch_size" in shape[i]:
|
||||
mul_val = int(shape[i].split("*")[0])
|
||||
new_shape.append(batch_size * mul_val)
|
||||
elif "height" in shape[i]:
|
||||
new_shape.append(height * mul_val)
|
||||
elif "width" in shape[i]:
|
||||
new_shape.append(width * mul_val)
|
||||
else:
|
||||
new_shape.append(shape[i])
|
||||
return new_shape
|
||||
|
||||
|
||||
# Get the input info for various models i.e. "unet", "clip", "vae".
|
||||
# Get the input info for various models i.e. "unet", "clip", "vae", "vae_encode".
|
||||
def get_input_info(model_info, max_len, width, height, batch_size):
|
||||
dtype_config = {"f32": torch.float32, "i64": torch.int64}
|
||||
input_map = defaultdict(list)
|
||||
@@ -65,6 +72,7 @@ class SharkifyStableDiffusionModel:
|
||||
self,
|
||||
model_id: str,
|
||||
custom_weights: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
max_len: int = 64,
|
||||
width: int = 512,
|
||||
@@ -79,12 +87,13 @@ class SharkifyStableDiffusionModel:
|
||||
self.width = width // 8
|
||||
self.batch_size = batch_size
|
||||
self.custom_weights = custom_weights
|
||||
if self.custom_weights != "":
|
||||
assert self.custom_weights.lower().endswith(
|
||||
if custom_weights != "":
|
||||
assert custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
|
||||
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
|
||||
self.model_id = model_id if custom_weights == "" else custom_weights
|
||||
self.custom_vae = custom_vae
|
||||
self.precision = precision
|
||||
self.base_vae = use_base_vae
|
||||
self.model_name = (
|
||||
@@ -101,17 +110,21 @@ class SharkifyStableDiffusionModel:
|
||||
self.use_tuned = use_tuned
|
||||
if use_tuned:
|
||||
self.model_name = self.model_name + "_tuned"
|
||||
# We need a better naming convention for the .vmfbs because despite
|
||||
# using the custom model variant the .vmfb names remain the same and
|
||||
# it'll always pick up the compiled .vmfb instead of compiling the
|
||||
# custom model.
|
||||
# So, currently, we add `self.model_id` in the `self.model_name` of
|
||||
# .vmfb file.
|
||||
# TODO: Have a better way of naming the vmfbs using self.model_name.
|
||||
model_name = re.sub(r"\W+", "_", self.model_id)
|
||||
if model_name[0] == "_":
|
||||
model_name = model_name[1:]
|
||||
self.model_name = self.model_name + "_" + model_name
|
||||
self.model_name = self.model_name + "_" + get_path_stem(self.model_id)
|
||||
|
||||
def get_extended_name_for_all_model(self):
|
||||
model_name = {}
|
||||
sub_model_list = ["clip", "unet", "vae", "vae_encode"]
|
||||
for model in sub_model_list:
|
||||
sub_model = model
|
||||
model_config = self.model_name
|
||||
if "vae" == model:
|
||||
if self.custom_vae != "":
|
||||
model_config = model_config + get_path_stem(self.custom_vae)
|
||||
if self.base_vae:
|
||||
sub_model = "base_vae"
|
||||
model_name[model] = get_extended_name(sub_model + model_config)
|
||||
return model_name
|
||||
|
||||
def check_params(self, max_len, width, height):
|
||||
if not (max_len >= 32 and max_len <= 77):
|
||||
@@ -121,14 +134,40 @@ class SharkifyStableDiffusionModel:
|
||||
if not (height % 8 == 0 and height >= 384):
|
||||
sys.exit("height should be greater than 384 and multiple of 8")
|
||||
|
||||
def get_vae(self):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, base_vae=self.base_vae):
|
||||
def get_vae_encode(self):
|
||||
class VaeEncodeModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
latents = self.vae.encode(input).latent_dist.sample()
|
||||
return 0.18215 * latents
|
||||
|
||||
vae_encode = VaeEncodeModel()
|
||||
inputs = tuple(self.inputs["vae_encode"])
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
shark_vae_encode = compile_through_fx(
|
||||
vae_encode,
|
||||
inputs,
|
||||
is_f16=is_f16,
|
||||
use_tuned=self.use_tuned,
|
||||
model_name=self.model_name["vae_encode"],
|
||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
||||
)
|
||||
return shark_vae_encode
|
||||
|
||||
def get_vae(self):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id if custom_vae == "" else custom_vae,
|
||||
subfolder="vae",
|
||||
)
|
||||
self.base_vae = base_vae
|
||||
|
||||
def forward(self, input):
|
||||
@@ -144,13 +183,12 @@ class SharkifyStableDiffusionModel:
|
||||
vae = VaeModel()
|
||||
inputs = tuple(self.inputs["vae"])
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
vae_name = "base_vae" if self.base_vae else "vae"
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
is_f16=is_f16,
|
||||
use_tuned=self.use_tuned,
|
||||
model_name=vae_name + self.model_name,
|
||||
model_name=self.model_name["vae"],
|
||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
||||
)
|
||||
return shark_vae
|
||||
@@ -187,7 +225,7 @@ class SharkifyStableDiffusionModel:
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
model_name="unet" + self.model_name,
|
||||
model_name=self.model_name["unet"],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
@@ -211,46 +249,103 @@ class SharkifyStableDiffusionModel:
|
||||
shark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
tuple(self.inputs["clip"]),
|
||||
model_name="clip" + self.model_name,
|
||||
model_name=self.model_name["clip"],
|
||||
extra_args=get_opt_flags("clip", precision="fp32"),
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
def __call__(self):
|
||||
vmfbs = fetch_or_delete_vmfbs(
|
||||
self.model_name, self.base_vae, self.precision
|
||||
# Compiles Clip, Unet and Vae with `base_model_id` as defining their input
|
||||
# configiration.
|
||||
def compile_all(self, base_model_id, need_vae_encode):
|
||||
self.inputs = get_input_info(
|
||||
base_models[base_model_id],
|
||||
self.max_len,
|
||||
self.width,
|
||||
self.height,
|
||||
self.batch_size,
|
||||
)
|
||||
compiled_unet = self.get_unet()
|
||||
if self.custom_vae != "":
|
||||
print("Plugging in custom Vae")
|
||||
compiled_vae = self.get_vae()
|
||||
compiled_clip = self.get_clip()
|
||||
if need_vae_encode:
|
||||
compiled_vae_encode = self.get_vae_encode()
|
||||
return compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode
|
||||
|
||||
return compiled_clip, compiled_unet, compiled_vae
|
||||
|
||||
def __call__(self):
|
||||
# Step 1:
|
||||
# -- Fetch all vmfbs for the model, if present, else delete the lot.
|
||||
need_vae_encode = args.img_path is not None
|
||||
self.model_name = self.get_extended_name_for_all_model()
|
||||
vmfbs = fetch_or_delete_vmfbs(self.model_name, need_vae_encode, self.precision)
|
||||
if vmfbs[0]:
|
||||
print("Loading vmfbs from cache")
|
||||
# -- If all vmfbs are indeed present, we also try and fetch the base
|
||||
# model configuration for running SD with custom checkpoints.
|
||||
if self.custom_weights != "":
|
||||
args.hf_model_id = fetch_and_update_base_model_id(self.custom_weights)
|
||||
if args.hf_model_id == "":
|
||||
sys.exit("Base model configuration for the custom model is missing. Use `--clear_all` and re-run.")
|
||||
print("Loaded vmfbs from cache and successfully fetched base model configuration.")
|
||||
return vmfbs
|
||||
|
||||
# Step 2:
|
||||
# -- If vmfbs weren't found, we try to see if the base model configuration
|
||||
# for the required SD run is known to us and bypass the retry mechanism.
|
||||
model_to_run = ""
|
||||
if self.custom_weights != "":
|
||||
model_to_run = self.custom_weights
|
||||
assert self.custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
|
||||
preprocessCKPT(self.custom_weights)
|
||||
else:
|
||||
model_to_run = args.hf_model_id
|
||||
# For custom Vae user can provide either the repo-id or a checkpoint file,
|
||||
# and for a checkpoint file we'd need to process it via Diffusers' script.
|
||||
if self.custom_vae.lower().endswith((".ckpt", ".safetensors")):
|
||||
preprocessCKPT(self.custom_vae)
|
||||
self.custom_vae = get_path_to_diffusers_checkpoint(self.custom_vae)
|
||||
base_model_fetched = fetch_and_update_base_model_id(model_to_run)
|
||||
if base_model_fetched != "":
|
||||
print("Compiling all the models with the fetched base model configuration.")
|
||||
if args.ckpt_loc != "":
|
||||
args.hf_model_id = base_model_fetched
|
||||
return self.compile_all(base_model_fetched, need_vae_encode)
|
||||
|
||||
# Step 3:
|
||||
# -- This is the retry mechanism where the base model's configuration is not
|
||||
# known to us and figure that out by trial and error.
|
||||
print("Inferring base model configuration.")
|
||||
for model_id in base_models:
|
||||
self.inputs = get_input_info(
|
||||
base_models[model_id],
|
||||
self.max_len,
|
||||
self.width,
|
||||
self.height,
|
||||
self.batch_size,
|
||||
)
|
||||
try:
|
||||
compiled_unet = self.get_unet()
|
||||
compiled_vae = self.get_vae()
|
||||
compiled_clip = self.get_clip()
|
||||
if need_vae_encode:
|
||||
compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode = self.compile_all(model_id, need_vae_encode)
|
||||
else:
|
||||
compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id, need_vae_encode)
|
||||
except Exception as e:
|
||||
if args.enable_stack_trace:
|
||||
traceback.print_exc()
|
||||
print("Retrying with a different base model configuration")
|
||||
continue
|
||||
# -- Once a successful compilation has taken place we'd want to store
|
||||
# the base model's configuration inferred.
|
||||
fetch_and_update_base_model_id(model_to_run, model_id)
|
||||
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
|
||||
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
|
||||
# model and rely on retrying method to find the input configuration, we should also update
|
||||
# the knowledge of base model id accordingly into `args.hf_model_id`.
|
||||
if args.ckpt_loc != "":
|
||||
args.hf_model_id = model_id
|
||||
if need_vae_encode:
|
||||
return (
|
||||
compiled_clip,
|
||||
compiled_unet,
|
||||
compiled_vae,
|
||||
compiled_vae_encode,
|
||||
)
|
||||
return compiled_clip, compiled_unet, compiled_vae
|
||||
sys.exit(
|
||||
"Cannot compile the model. Please re-run the command with `--enable_stack_trace` flag and create an issue with detailed log at https://github.com/nod-ai/SHARK/issues"
|
||||
|
||||
@@ -16,6 +16,8 @@ hf_model_variant_map = {
|
||||
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1base"],
|
||||
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
|
||||
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
|
||||
"runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
|
||||
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
|
||||
}
|
||||
|
||||
|
||||
@@ -52,6 +54,23 @@ def get_unet():
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_vae_encode():
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
is_tuned = "tuned" if args.use_tuned else "untuned"
|
||||
if "vulkan" not in args.device and args.use_tuned:
|
||||
bucket_key = f"{variant}/{is_tuned}/{args.device}"
|
||||
model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}/{args.device}"
|
||||
else:
|
||||
bucket_key = f"{variant}/{is_tuned}"
|
||||
model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}"
|
||||
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, "vae", is_tuned, args.precision
|
||||
)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_vae():
|
||||
variant, version = get_variant_version(args.hf_model_id)
|
||||
# Tuned model is present only for `fp16` precision.
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
|
||||
Text2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_inpaint import (
|
||||
InpaintPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from random import randint
|
||||
from PIL import Image
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
|
||||
|
||||
class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae_encode: SharkInference,
|
||||
vae: SharkInference,
|
||||
text_encoder: SharkInference,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: SharkInference,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
],
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
|
||||
self.vae_encode = vae_encode
|
||||
|
||||
def prepare_image_latents(
|
||||
self,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
dtype,
|
||||
):
|
||||
# Pre process image -> get image encoded -> process latents
|
||||
|
||||
# TODO: process with variable HxW combos
|
||||
|
||||
# Pre process image
|
||||
image = image.resize((height, width)) # Current support for 512x512
|
||||
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
|
||||
image_arr = image_arr / 255.0
|
||||
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
|
||||
image_arr = 2 * (image_arr - 0.5)
|
||||
|
||||
# image encode
|
||||
latents = self.encode_image((image_arr,))
|
||||
latents = torch.from_numpy(latents).to(dtype)
|
||||
|
||||
# set scheduler steps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# add noise to data
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
|
||||
return latents
|
||||
|
||||
def encode_image(self, input_image):
|
||||
vae_encode_start = time.time()
|
||||
latents = self.vae_encode("forward", input_image)
|
||||
vae_inf_time = (time.time() - vae_encode_start) * 1000
|
||||
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
|
||||
|
||||
return latents
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the inital latent noise. Also handle out of range seeds.
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get text embeddings from prompts
|
||||
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
# Prepare input image latent
|
||||
image_latents = self.prepare_image_latents(
|
||||
image=image,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_img_latents(
|
||||
latents=image_latents,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
total_timesteps=self.scheduler.timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
for i in tqdm(range(0, latents.shape[0], batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
|
||||
return all_imgs
|
||||
|
||||
@@ -0,0 +1,229 @@
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from PIL import Image
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
|
||||
|
||||
class InpaintPipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae_encode: SharkInference,
|
||||
vae: SharkInference,
|
||||
text_encoder: SharkInference,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: SharkInference,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
],
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
|
||||
self.vae_encode = vae_encode
|
||||
|
||||
def prepare_mask_and_masked_image(self, image, mask):
|
||||
# preprocess image
|
||||
if isinstance(image, (Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image, list) and isinstance(image[0], Image.Image):
|
||||
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate([i[None, :] for i in image], axis=0)
|
||||
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
# preprocess mask
|
||||
if isinstance(mask, (Image.Image, np.ndarray)):
|
||||
mask = [mask]
|
||||
|
||||
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
|
||||
mask = np.concatenate(
|
||||
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
|
||||
)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
|
||||
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
|
||||
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
dtype,
|
||||
):
|
||||
latents = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
4,
|
||||
height // 8,
|
||||
width // 8,
|
||||
),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
self.scheduler.is_scale_input_called = True
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def prepare_mask_latents(
|
||||
self,
|
||||
mask,
|
||||
masked_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
):
|
||||
mask = torch.nn.functional.interpolate(
|
||||
mask, size=(height // 8, width // 8)
|
||||
)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
masked_image = masked_image.to(dtype)
|
||||
masked_image_latents = self.vae_encode("forward", (masked_image,))
|
||||
masked_image_latents = torch.from_numpy(masked_image_latents)
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
if mask.shape[0] < batch_size:
|
||||
if not batch_size % mask.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
||||
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
||||
" of masks that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
||||
if masked_image_latents.shape[0] < batch_size:
|
||||
if not batch_size % masked_image_latents.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
masked_image_latents = masked_image_latents.repeat(
|
||||
batch_size // masked_image_latents.shape[0], 1, 1, 1
|
||||
)
|
||||
return mask, masked_image_latents
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
image,
|
||||
mask_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the inital latent noise. Also handle out of range seeds.
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get initial latents
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get text embeddings from prompts
|
||||
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
# Preprocess mask and image
|
||||
mask, masked_image = self.prepare_mask_and_masked_image(
|
||||
image, mask_image
|
||||
)
|
||||
|
||||
# Prepare mask latent variables
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask=mask,
|
||||
masked_image=masked_image,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_img_latents(
|
||||
latents=init_latents,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
total_timesteps=self.scheduler.timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
mask=mask,
|
||||
masked_image_latents=masked_image_latents,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
for i in tqdm(range(0, latents.shape[0], batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
|
||||
return all_imgs
|
||||
@@ -89,6 +89,7 @@ class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the inital latent noise. Also handle out of range seeds.
|
||||
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import CLIPTokenizer
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
@@ -16,6 +17,7 @@ from shark.shark_inference import SharkInference
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae_encode,
|
||||
get_vae,
|
||||
get_clip,
|
||||
get_unet,
|
||||
@@ -112,6 +114,8 @@ class StableDiffusionPipeline:
|
||||
total_timesteps,
|
||||
dtype,
|
||||
cpu_scheduling,
|
||||
mask=None,
|
||||
masked_image_latents=None,
|
||||
return_all_latents=False,
|
||||
):
|
||||
step_time_sum = 0
|
||||
@@ -122,6 +126,15 @@ class StableDiffusionPipeline:
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
latent_model_input = self.scheduler.scale_model_input(latents, t)
|
||||
if mask is not None and masked_image_latents is not None:
|
||||
latent_model_input = torch.cat(
|
||||
[
|
||||
torch.from_numpy(np.asarray(latent_model_input)),
|
||||
mask,
|
||||
masked_image_latents,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype)
|
||||
if cpu_scheduling:
|
||||
latent_model_input = latent_model_input.detach().numpy()
|
||||
|
||||
@@ -177,6 +190,7 @@ class StableDiffusionPipeline:
|
||||
import_mlir: bool,
|
||||
model_id: str,
|
||||
ckpt_loc: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
max_length: int,
|
||||
batch_size: int,
|
||||
@@ -186,9 +200,12 @@ class StableDiffusionPipeline:
|
||||
use_tuned: bool,
|
||||
):
|
||||
if import_mlir:
|
||||
# TODO: Delet this when on-the-fly tuning of models work.
|
||||
use_tuned = False
|
||||
mlir_import = SharkifyStableDiffusionModel(
|
||||
model_id,
|
||||
ckpt_loc,
|
||||
custom_vae,
|
||||
precision,
|
||||
max_len=max_length,
|
||||
batch_size=batch_size,
|
||||
@@ -197,8 +214,23 @@ class StableDiffusionPipeline:
|
||||
use_base_vae=use_base_vae,
|
||||
use_tuned=use_tuned,
|
||||
)
|
||||
if cls.__name__ in ["Image2ImagePipeline", "InpaintPipeline"]:
|
||||
clip, unet, vae, vae_encode = mlir_import()
|
||||
return cls(
|
||||
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
|
||||
)
|
||||
clip, unet, vae = mlir_import()
|
||||
return cls(vae, clip, get_tokenizer(), unet, scheduler)
|
||||
|
||||
if cls.__name__ in ["Image2ImagePipeline", "InpaintPipeline"]:
|
||||
return cls(
|
||||
get_vae_encode(),
|
||||
get_vae(),
|
||||
get_clip(),
|
||||
get_tokenizer(),
|
||||
get_unet(),
|
||||
scheduler,
|
||||
)
|
||||
return cls(
|
||||
get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler
|
||||
)
|
||||
|
||||
@@ -21,5 +21,9 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
get_opt_flags,
|
||||
preprocessCKPT,
|
||||
fetch_or_delete_vmfbs,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
sanitize_seed,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
)
|
||||
|
||||
@@ -29,6 +29,14 @@
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
"1*batch_size",3,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
@@ -77,6 +85,126 @@
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
"1*batch_size",3,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
"1*batch_size",4,"height","width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
"2*batch_size",
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
}
|
||||
},
|
||||
"runwayml/stable-diffusion-inpainting": {
|
||||
"unet": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
9,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
768
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
"1*batch_size",3,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
"1*batch_size",4,"height","width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
"2*batch_size",
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
}
|
||||
},
|
||||
"stabilityai/stable-diffusion-2-inpainting": {
|
||||
"unet": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
9,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
1024
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
"1*batch_size",3,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
|
||||
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
|
||||
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
|
||||
"stablediffusion/inpaint_v1":"runwayml/stable-diffusion-inpainting",
|
||||
"stablediffusion/inpaint_v2":"stabilityai/stable-diffusion-2-inpainting",
|
||||
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
|
||||
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
|
||||
"openjourney/v1_4":"prompthero/openjourney",
|
||||
|
||||
@@ -42,6 +42,17 @@
|
||||
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
|
||||
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
|
||||
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
|
||||
"stablediffusion/inpaint_v1/unet/fp16/length_77/untuned":"unet_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v1/unet/fp32/length_77/untuned":"unet_inpaint_fp32",
|
||||
"stablediffusion/inpaint_v1/vae_encode/fp16/length_77/untuned":"vae_encode_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v1/vae_encode/fp32/length_77/untuned":"vae_encode_inpaint_fp32",
|
||||
"stablediffusion/inpaint_v1/vae/fp16/length_77/untuned":"vae_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v1/vae/fp32/length_77/untuned":"vae_inpaint_fp32",
|
||||
"stablediffusion/inpaint_v1/clip/fp32/length_77/untuned":"clip_inpaint_fp32",
|
||||
"stablediffusion/inpaint_v2/unet/fp16/length_77/untuned":"unet_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v2/vae_encode/fp16/length_77/untuned":"vae_encode_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v2/vae/fp16/length_77/untuned":"vae_inpaint_fp16",
|
||||
"stablediffusion/inpaint_v2/clip/fp32/length_77/untuned":"clip_inpaint_fp32",
|
||||
"anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
|
||||
"anythingv3/v2_1base/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
|
||||
"anythingv3/v2_1base/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import io
|
||||
from shark.model_annotation import model_annotation, create_context
|
||||
from shark.iree_utils._common import iree_target_map, run_cmd
|
||||
from shark.shark_downloader import (
|
||||
@@ -97,21 +98,25 @@ def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
|
||||
search_op="conv",
|
||||
winograd=True,
|
||||
)
|
||||
with open(out_file_path, "w") as f:
|
||||
f.write(str(winograd_model))
|
||||
f.close()
|
||||
return winograd_model, out_file_path
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
winograd_model.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
with open(out_file_path, "w") as f:
|
||||
f.write(str(winograd_model))
|
||||
f.close()
|
||||
return bytecode, out_file_path
|
||||
|
||||
|
||||
# For Unet annotate the model with tuned lowering configs
|
||||
def annotate_with_lower_configs(
|
||||
input_mlir, lowering_config_dir, model_name, use_winograd
|
||||
):
|
||||
def dump_after_mlir(input_mlir, model_name, use_winograd):
|
||||
if use_winograd:
|
||||
dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
|
||||
preprocess_flag = (
|
||||
"--iree-preprocessing-pass-pipeline='builtin.module"
|
||||
"(func.func(iree-preprocessing-convert-conv2d-to-img2col,"
|
||||
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
|
||||
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"iree-preprocessing-convert-conv2d-to-img2col,"
|
||||
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
|
||||
"iree-linalg-ext-convert-conv2d-to-winograd))' "
|
||||
)
|
||||
@@ -119,11 +124,12 @@ def annotate_with_lower_configs(
|
||||
dump_after = "iree-preprocessing-pad-linalg-ops"
|
||||
preprocess_flag = (
|
||||
"--iree-preprocessing-pass-pipeline='builtin.module"
|
||||
"(func.func(iree-preprocessing-convert-conv2d-to-img2col,"
|
||||
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
|
||||
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"iree-preprocessing-convert-conv2d-to-img2col,"
|
||||
"iree-preprocessing-pad-linalg-ops{pad-size=32}))' "
|
||||
)
|
||||
|
||||
# Dump IR after padding/img2col/winograd passes
|
||||
device_spec_args = ""
|
||||
device = get_device()
|
||||
if device == "cuda":
|
||||
@@ -151,6 +157,14 @@ def annotate_with_lower_configs(
|
||||
f"2>{args.annotation_output}/dump_after_winograd.mlir "
|
||||
)
|
||||
|
||||
|
||||
# For Unet annotate the model with tuned lowering configs
|
||||
def annotate_with_lower_configs(
|
||||
input_mlir, lowering_config_dir, model_name, use_winograd
|
||||
):
|
||||
# Dump IR after padding/img2col/winograd passes
|
||||
dump_after_mlir(input_mlir, model_name, use_winograd)
|
||||
|
||||
# Annotate the model with lowering configs in the config file
|
||||
with create_context() as ctx:
|
||||
tuned_model = model_annotation(
|
||||
@@ -168,10 +182,15 @@ def annotate_with_lower_configs(
|
||||
)
|
||||
else:
|
||||
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
tuned_model.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
with open(out_file_path, "w") as f:
|
||||
f.write(str(tuned_model))
|
||||
f.close()
|
||||
return tuned_model, out_file_path
|
||||
return bytecode, out_file_path
|
||||
|
||||
|
||||
def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
|
||||
@@ -207,7 +226,7 @@ def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
|
||||
mlir_model, lowering_config_dir, model_name, use_winograd
|
||||
)
|
||||
print(f"Saved the annotated mlir in {output_path}.")
|
||||
return tuned_model, output_path
|
||||
return tuned_model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -17,18 +17,30 @@ p = argparse.ArgumentParser(
|
||||
p.add_argument(
|
||||
"-p",
|
||||
"--prompts",
|
||||
action="append",
|
||||
default=[],
|
||||
nargs="+",
|
||||
default=["cyberpunk forest by Salvador Dali"],
|
||||
help="text of which images to be generated.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--negative_prompts",
|
||||
nargs="+",
|
||||
default=[""],
|
||||
default=["trees, green"],
|
||||
help="text you don't want to see in the generated image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--img_path",
|
||||
type=str,
|
||||
help="Path to the image input for img2img/inpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--mask_path",
|
||||
type=str,
|
||||
help="Path to the mask image input for inpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
@@ -39,8 +51,8 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="the seed to use.",
|
||||
default=-1,
|
||||
help="the seed to use. -1 for a random one.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -48,7 +60,7 @@ p.add_argument(
|
||||
type=int,
|
||||
default=1,
|
||||
choices=range(1, 4),
|
||||
help="the number of inferences to be made in a single `run`.",
|
||||
help="the number of inferences to be made in a single `batch_count`.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -148,10 +160,10 @@ p.add_argument(
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--runs",
|
||||
"--batch_count",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of images to be generated with random seeds in single execution",
|
||||
help="number of batch to be generated with random seeds in single execution",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -161,6 +173,13 @@ p.add_argument(
|
||||
help="Path to SD's .ckpt file.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--custom_vae",
|
||||
type=str,
|
||||
default="",
|
||||
help="HuggingFace repo-id or path to SD model's checkpoint whose Vae needs to be plugged in.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hf_model_id",
|
||||
type=str,
|
||||
@@ -279,7 +298,7 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--write_metadata_to_png",
|
||||
default=False,
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
|
||||
)
|
||||
@@ -292,7 +311,7 @@ p.add_argument(
|
||||
"--progress_bar",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for removing the pregress bar animation during image generation",
|
||||
help="flag for removing the progress bar animation during image generation",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import os
|
||||
import gc
|
||||
import json
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
@@ -11,26 +14,30 @@ from shark.iree_utils.gpu_utils import get_cuda_sm_cc
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.resources import opt_flags
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
import sys, functools, operator
|
||||
import sys
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
load_pipeline_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
|
||||
|
||||
def get_vmfb_path_name(model_name):
|
||||
def get_extended_name(model_name):
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else "-".join(args.device.split("://"))
|
||||
)
|
||||
extended_name = "{}_{}".format(model_name, device)
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
|
||||
return [vmfb_path, extended_name]
|
||||
return extended_name
|
||||
|
||||
|
||||
def get_vmfb_path_name(model_name):
|
||||
vmfb_path = os.path.join(os.getcwd(), model_name + ".vmfb")
|
||||
return vmfb_path
|
||||
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
if args.load_vmfb or args.save_vmfb:
|
||||
[vmfb_path, extended_name] = get_vmfb_path_name(model_name)
|
||||
vmfb_path = get_vmfb_path_name(model_name)
|
||||
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
|
||||
print(f"loading existing vmfb from: {vmfb_path}")
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
@@ -44,7 +51,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), extended_name, extra_args
|
||||
os.getcwd(), model_name, extra_args
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
@@ -54,11 +61,13 @@ def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
|
||||
# Downloads the model from shark_tank and returns the shark_module.
|
||||
def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.parser import shark_args
|
||||
|
||||
# Set local shark_tank cache directory.
|
||||
shark_args.local_tank_cache = args.local_tank_cache
|
||||
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
if "cuda" in args.device:
|
||||
shark_args.enable_tf32 = True
|
||||
|
||||
@@ -93,26 +102,19 @@ def compile_through_fx(
|
||||
)
|
||||
|
||||
if use_tuned:
|
||||
tuned_model_path = f"{args.annotation_output}/{model_name}_torch.mlir"
|
||||
if not os.path.exists(tuned_model_path):
|
||||
if "vae" in model_name.split("_")[0]:
|
||||
args.annotation_model = "vae"
|
||||
|
||||
tuned_model, tuned_model_path = sd_model_annotation(
|
||||
mlir_module, model_name
|
||||
)
|
||||
del mlir_module, tuned_model
|
||||
gc.collect()
|
||||
|
||||
with open(tuned_model_path, "rb") as f:
|
||||
mlir_module = f.read()
|
||||
f.close()
|
||||
if "vae" in model_name.split("_")[0]:
|
||||
args.annotation_model = "vae"
|
||||
mlir_module = sd_model_annotation(mlir_module, model_name)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
|
||||
del mlir_module
|
||||
gc.collect()
|
||||
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
@@ -250,11 +252,7 @@ def set_init_device_flags():
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif "cuda" in args.device and get_cuda_sm_cc() not in [
|
||||
"sm_80",
|
||||
"sm_84",
|
||||
"sm_86",
|
||||
]:
|
||||
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80"]:
|
||||
args.use_tuned = False
|
||||
|
||||
elif args.use_base_vae and args.hf_model_id not in [
|
||||
@@ -280,6 +278,8 @@ def set_init_device_flags():
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
]:
|
||||
args.import_mlir = True
|
||||
|
||||
@@ -360,6 +360,11 @@ def get_opt_flags(model, precision="fp16"):
|
||||
return iree_flags
|
||||
|
||||
|
||||
def get_path_stem(path):
|
||||
path = Path(path)
|
||||
return path.stem
|
||||
|
||||
|
||||
def get_path_to_diffusers_checkpoint(custom_weights):
|
||||
path = Path(custom_weights)
|
||||
diffusers_path = path.parent.absolute()
|
||||
@@ -401,7 +406,7 @@ def preprocessCKPT(custom_weights):
|
||||
|
||||
|
||||
def load_vmfb(vmfb_path, model, precision):
|
||||
model = "vae" if "base_vae" in model else model
|
||||
model = "vae" if "base_vae" in model or "vae_encode" in model else model
|
||||
precision = "fp32" if "clip" in model else precision
|
||||
extra_args = get_opt_flags(model, precision)
|
||||
shark_module = SharkInference(mlir_module=None, device=args.device)
|
||||
@@ -409,25 +414,68 @@ def load_vmfb(vmfb_path, model, precision):
|
||||
return shark_module
|
||||
|
||||
|
||||
# This utility returns vmfbs of Clip, Unet and Vae, in case all three of them
|
||||
# This utility returns vmfbs of Clip, Unet, Vae and Vae_encode, in case all of them
|
||||
# are present; deletes them otherwise.
|
||||
def fetch_or_delete_vmfbs(basic_model_name, use_base_vae, precision="fp32"):
|
||||
model_name = ["clip", "unet", "base_vae" if use_base_vae else "vae"]
|
||||
def fetch_or_delete_vmfbs(
|
||||
extended_model_name, need_vae_encode, precision="fp32"
|
||||
):
|
||||
vmfb_path = [
|
||||
get_vmfb_path_name(model + basic_model_name)[0] for model in model_name
|
||||
get_vmfb_path_name(extended_model_name[model])
|
||||
for model in extended_model_name
|
||||
]
|
||||
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
|
||||
all_vmfb_present = functools.reduce(operator.__and__, vmfb_present)
|
||||
compiled_models = [None] * 3
|
||||
all_vmfb_present = True
|
||||
compiled_models = []
|
||||
for i in range(3):
|
||||
all_vmfb_present = all_vmfb_present and vmfb_present[i]
|
||||
compiled_models.append(None)
|
||||
if need_vae_encode:
|
||||
all_vmfb_present = all_vmfb_present and vmfb_present[3]
|
||||
compiled_models.append(None)
|
||||
|
||||
# We need to delete vmfbs only if some of the models were compiled.
|
||||
if not all_vmfb_present:
|
||||
for i in range(len(vmfb_path)):
|
||||
for i in range(len(compiled_models)):
|
||||
if vmfb_present[i]:
|
||||
os.remove(vmfb_path[i])
|
||||
print("Deleted: ", vmfb_path[i])
|
||||
else:
|
||||
for i in range(len(vmfb_path)):
|
||||
model_name = [model for model in extended_model_name.keys()]
|
||||
for i in range(len(compiled_models)):
|
||||
compiled_models[i] = load_vmfb(
|
||||
vmfb_path[i], model_name[i], precision
|
||||
)
|
||||
return compiled_models
|
||||
|
||||
|
||||
# `fetch_and_update_base_model_id` is a resource utility function which
|
||||
# helps maintaining mapping of the model to run with its base model.
|
||||
# If `base_model` is "", then this function tries to fetch the base model
|
||||
# info for the `model_to_run`.
|
||||
def fetch_and_update_base_model_id(model_to_run, base_model=""):
|
||||
variants_path = os.path.join(os.getcwd(), "variants.json")
|
||||
data = {model_to_run: base_model}
|
||||
json_data = {}
|
||||
if os.path.exists(variants_path):
|
||||
with open(variants_path, "r", encoding="utf-8") as jsonFile:
|
||||
json_data = json.load(jsonFile)
|
||||
# Return with base_model's info if base_model is "".
|
||||
if base_model == "":
|
||||
if model_to_run in json_data:
|
||||
base_model = json_data[model_to_run]
|
||||
return base_model
|
||||
elif base_model == "":
|
||||
return base_model
|
||||
# Update JSON data to contain an entry mapping model_to_run with base_model.
|
||||
json_data.update(data)
|
||||
with open(variants_path, "w", encoding="utf-8") as jsonFile:
|
||||
json.dump(json_data, jsonFile)
|
||||
|
||||
|
||||
# Generate and return a new seed if the provided one is not in the supported range (including -1)
|
||||
def sanitize_seed(seed):
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
return seed
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import glob
|
||||
|
||||
|
||||
if "AMD_ENABLE_LLPC" not in os.environ:
|
||||
os.environ["AMD_ENABLE_LLPC"] = "1"
|
||||
@@ -10,259 +9,16 @@ if sys.platform == "darwin":
|
||||
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
|
||||
|
||||
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(
|
||||
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.src import (
|
||||
prompt_examples,
|
||||
args,
|
||||
get_available_devices,
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.web.ui import txt2img_web, img2img_web
|
||||
|
||||
|
||||
sd_web = gr.TabbedInterface(
|
||||
[txt2img_web, img2img_web], ["Text-to-Image", "Image-to-Image"]
|
||||
)
|
||||
from apps.stable_diffusion.scripts import txt2img_inf
|
||||
|
||||
nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
sdlogo_loc = resource_path("logos/sd-demo-logo.png")
|
||||
|
||||
|
||||
demo_css = resource_path("css/sd_dark_theme.css")
|
||||
|
||||
|
||||
with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
logo2 = Image.open(sdlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=100)
|
||||
with gr.Column(scale=5, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=logo2,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="demo_title",
|
||||
).style(width=150, height=100)
|
||||
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
ckpt_path = (
|
||||
Path(args.ckpt_dir)
|
||||
if args.ckpt_dir
|
||||
else Path(Path.cwd(), "models")
|
||||
)
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
types = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
ckpt_files = ["None"]
|
||||
for extn in types:
|
||||
files = glob.glob(os.path.join(ckpt_path, extn))
|
||||
ckpt_files.extend(files)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {ckpt_path})",
|
||||
value="None",
|
||||
choices=ckpt_files
|
||||
+ [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"prompthero/openjourney",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
],
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value="cyberpunk forest by Salvador Dali",
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value="trees, green",
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
label="Scheduler",
|
||||
value="SharkEulerDiscrete",
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"LMSDiscrete",
|
||||
"DPMSolverMultistep",
|
||||
"EulerDiscrete",
|
||||
"EulerAncestralDiscrete",
|
||||
"SharkEulerDiscrete",
|
||||
],
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1, 4, value=1, step=1, label="Number of Images"
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384, 786, value=512, step=8, label="Height"
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 786, value=512, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp16",
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=64,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=50, step=1, label="Steps"
|
||||
)
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=7.5,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Row():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=False,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(value=-1, precision=0, label="Seed")
|
||||
available_devices = get_available_devices()
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => Math.floor(Math.random() * 4294967295)",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image")
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
examples=prompt_examples,
|
||||
inputs=prompt,
|
||||
cache_examples=False,
|
||||
elem_id="prompt_examples",
|
||||
)
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(grid=[2], height="auto")
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
lines=4,
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
|
||||
prompt.submit(
|
||||
txt2img_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
],
|
||||
outputs=[gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
stable_diffusion.click(
|
||||
txt2img_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
],
|
||||
outputs=[gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
shark_web.queue()
|
||||
shark_web.launch(
|
||||
sd_web.queue()
|
||||
sd_web.launch(
|
||||
share=args.share,
|
||||
inbrowser=True,
|
||||
server_name="0.0.0.0",
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 33 KiB |
2
apps/stable_diffusion/web/ui/__init__.py
Normal file
2
apps/stable_diffusion/web/ui/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from apps.stable_diffusion.web.ui.txt2img_ui import txt2img_web
|
||||
from apps.stable_diffusion.web.ui.img2img_ui import img2img_web
|
||||
@@ -166,7 +166,7 @@
|
||||
}
|
||||
|
||||
#demo_title {
|
||||
background-color: : var(--color-background-primary);
|
||||
background-color: var(--color-background-primary);
|
||||
border-radius: 0 !important;
|
||||
border: 0;
|
||||
padding-top: 15px;
|
||||
244
apps/stable_diffusion/web/ui/img2img_ui.py
Normal file
244
apps/stable_diffusion/web/ui/img2img_ui.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
from pathlib import Path
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import img2img_inf
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
get_available_devices,
|
||||
)
|
||||
|
||||
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(
|
||||
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
sdlogo_loc = resource_path("logos/sd-demo-logo.png")
|
||||
demo_css = resource_path("css/sd_dark_theme.css")
|
||||
|
||||
with gr.Blocks(title="Image-to-Image", css=demo_css) as img2img_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
logo2 = Image.open(sdlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=100)
|
||||
with gr.Column(scale=5, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=logo2,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="demo_title",
|
||||
).style(width=150, height=100)
|
||||
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
ckpt_path = (
|
||||
Path(args.ckpt_dir)
|
||||
if args.ckpt_dir
|
||||
else Path(Path.cwd(), "models")
|
||||
)
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
types = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
ckpt_files = ["None"]
|
||||
for extn in types:
|
||||
files = glob.glob(os.path.join(ckpt_path, extn))
|
||||
ckpt_files.extend(files)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {ckpt_path})",
|
||||
value=args.ckpt_loc if args.ckpt_loc else "None",
|
||||
choices=ckpt_files
|
||||
+ [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"prompthero/openjourney",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
],
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=1,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
init_image = gr.Image(label="Input Image", type="filepath")
|
||||
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"LMSDiscrete",
|
||||
"DPMSolverMultistep",
|
||||
"EulerDiscrete",
|
||||
"EulerAncestralDiscrete",
|
||||
"SharkEulerDiscrete",
|
||||
],
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=args.save_metadata_to_json,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384, 786, value=args.height, step=8, label="Height"
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 786, value=args.width, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Row():
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
10,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=False,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=False,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
available_devices = get_available_devices()
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => Math.floor(Math.random() * 4294967295)",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(grid=[2], height="auto")
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
lines=4,
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
kwargs = dict(
|
||||
fn=img2img_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
init_image,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
],
|
||||
outputs=[gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt.submit(**kwargs)
|
||||
stable_diffusion.click(**kwargs)
|
||||
|
Before Width: | Height: | Size: 10 KiB After Width: | Height: | Size: 10 KiB |
|
Before Width: | Height: | Size: 5.0 KiB After Width: | Height: | Size: 5.0 KiB |
248
apps/stable_diffusion/web/ui/txt2img_ui.py
Normal file
248
apps/stable_diffusion/web/ui/txt2img_ui.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
from pathlib import Path
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import txt2img_inf
|
||||
from apps.stable_diffusion.src import (
|
||||
prompt_examples,
|
||||
args,
|
||||
get_available_devices,
|
||||
)
|
||||
|
||||
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(
|
||||
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
sdlogo_loc = resource_path("logos/sd-demo-logo.png")
|
||||
demo_css = resource_path("css/sd_dark_theme.css")
|
||||
|
||||
with gr.Blocks(title="Text-to-Image", css=demo_css) as txt2img_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
logo2 = Image.open(sdlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=100)
|
||||
with gr.Column(scale=5, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=logo2,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="demo_title",
|
||||
).style(width=150, height=100)
|
||||
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
ckpt_path = (
|
||||
Path(args.ckpt_dir)
|
||||
if args.ckpt_dir
|
||||
else Path(Path.cwd(), "models")
|
||||
)
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
types = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
ckpt_files = ["None"]
|
||||
for extn in types:
|
||||
files = glob.glob(os.path.join(ckpt_path, extn))
|
||||
ckpt_files.extend(files)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {ckpt_path})",
|
||||
value=args.ckpt_loc if args.ckpt_loc else "None",
|
||||
choices=ckpt_files
|
||||
+ [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"prompthero/openjourney",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
],
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=1,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"LMSDiscrete",
|
||||
"DPMSolverMultistep",
|
||||
"EulerDiscrete",
|
||||
"EulerAncestralDiscrete",
|
||||
"SharkEulerDiscrete",
|
||||
],
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=args.save_metadata_to_json,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384, 786, value=args.height, step=8, label="Height"
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 786, value=args.width, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Row():
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
10,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
available_devices = get_available_devices()
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => Math.floor(Math.random() * 4294967295)",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image")
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
examples=prompt_examples,
|
||||
inputs=prompt,
|
||||
cache_examples=False,
|
||||
elem_id="prompt_examples",
|
||||
)
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(grid=[2], height="auto")
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
lines=4,
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
kwargs = dict(
|
||||
fn=txt2img_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
],
|
||||
outputs=[gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt.submit(**kwargs)
|
||||
stable_diffusion.click(**kwargs)
|
||||
@@ -29,7 +29,7 @@ def compare_images(new_filename, golden_filename):
|
||||
golden = np.array(Image.open(golden_filename)) / 255.0
|
||||
diff = np.abs(new - golden)
|
||||
mean = np.mean(diff)
|
||||
if mean > 0.01:
|
||||
if mean > 0.1:
|
||||
subprocess.run(
|
||||
["gsutil", "cp", new_filename, "gs://shark_tank/testdata/builder/"]
|
||||
)
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
|
||||
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
|
||||
source $GITHUB_WORKSPACE/shark.venv/bin/activate
|
||||
python generate_sharktank.py --upload=False --ci_tank_dir=True
|
||||
python generate_sharktank.py
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
rm -rf ./test_images
|
||||
mkdir test_images
|
||||
python shark/examples/shark_inference/stable_diffusion/main.py --device=vulkan --output_dir=./test_images --no-load_vmfb --no-use_tuned
|
||||
python shark/examples/shark_inference/stable_diffusion/main.py --device=vulkan --output_dir=./test_images --no-load_vmfb --no-use_tuned --beta_models=True
|
||||
|
||||
python build_tools/image_comparison.py -n ./test_images/*.png
|
||||
exit $?
|
||||
@@ -23,8 +23,7 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
os.mkdir("./test_images")
|
||||
os.mkdir("./test_images/golden")
|
||||
hf_model_names = model_config_dicts[0].values()
|
||||
tuned_options = ["--no-use_tuned"] #'use_tuned']
|
||||
devices = ["vulkan"]
|
||||
tuned_options = ["--no-use_tuned", "use_tuned"]
|
||||
if beta:
|
||||
extra_flags.append("--beta_models=True")
|
||||
for model_name in hf_model_names:
|
||||
@@ -33,15 +32,19 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
"python",
|
||||
"apps/stable_diffusion/scripts/txt2img.py",
|
||||
"--device=" + device,
|
||||
"--output_dir=./test_images/" + model_name,
|
||||
"--prompt=cyberpunk forest by Salvador Dali",
|
||||
"--output_dir="
|
||||
+ os.path.join(os.getcwd(), "test_images", model_name),
|
||||
"--hf_model_id=" + model_name,
|
||||
use_tune,
|
||||
]
|
||||
command += extra_flags
|
||||
generated_image = not subprocess.call(
|
||||
command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
|
||||
command, stdout=subprocess.DEVNULL
|
||||
)
|
||||
if generated_image:
|
||||
print(" ".join(command))
|
||||
print("Successfully generated image")
|
||||
os.makedirs(
|
||||
"./test_images/golden/" + model_name, exist_ok=True
|
||||
)
|
||||
@@ -49,18 +52,16 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
"gs://shark_tank/testdata/golden/" + model_name,
|
||||
"./test_images/golden/" + model_name,
|
||||
)
|
||||
comparison = [
|
||||
"python",
|
||||
"build_tools/image_comparison.py",
|
||||
"--golden_url=gs://shark_tank/testdata/golden/"
|
||||
+ model_name
|
||||
+ "/*.png",
|
||||
"--newfile=./test_images/" + model_name + "/*.png",
|
||||
]
|
||||
test_file = glob("./test_images/" + model_name + "/*.png")[0]
|
||||
test_file_path = os.path.join(
|
||||
os.getcwd(), "test_images", model_name, "generated_imgs"
|
||||
)
|
||||
test_file = glob(test_file_path + "/*.png")[0]
|
||||
golden_path = "./test_images/golden/" + model_name + "/*.png"
|
||||
golden_file = glob(golden_path)[0]
|
||||
compare_images(test_file, golden_file)
|
||||
else:
|
||||
print(" ".join(command))
|
||||
print("failed to generate image for this configuration")
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -2,18 +2,16 @@
|
||||
"""SHARK Tank"""
|
||||
# python generate_sharktank.py, you have to give a csv tile with [model_name, model_download_url]
|
||||
# will generate local shark tank folder like this:
|
||||
# HOME
|
||||
# /.local
|
||||
# /shark_tank
|
||||
# /albert_lite_base
|
||||
# /...model_name...
|
||||
# /SHARK
|
||||
# /gen_shark_tank
|
||||
# /albert_lite_base
|
||||
# /...model_name...
|
||||
#
|
||||
|
||||
import os
|
||||
import csv
|
||||
import argparse
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
import subprocess as sp
|
||||
import hashlib
|
||||
import numpy as np
|
||||
@@ -267,16 +265,17 @@ if __name__ == "__main__":
|
||||
# old_args = parser.parse_args()
|
||||
|
||||
home = str(Path.home())
|
||||
if args.ci_tank_dir == True:
|
||||
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
|
||||
else:
|
||||
WORKDIR = os.path.join(home, ".local/shark_tank/")
|
||||
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tank", "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tank", "tf_model_list.csv"
|
||||
)
|
||||
tflite_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tank", "tflite", "tflite_model_list.csv"
|
||||
)
|
||||
|
||||
if args.torch_model_csv:
|
||||
save_torch_model(args.torch_model_csv)
|
||||
|
||||
if args.tf_model_csv:
|
||||
save_tf_model(args.tf_model_csv)
|
||||
|
||||
if args.tflite_model_csv:
|
||||
save_tflite_model(args.tflite_model_csv)
|
||||
save_torch_model(torch_model_csv)
|
||||
save_tf_model(tf_model_csv)
|
||||
save_tflite_model(tflite_model_csv)
|
||||
|
||||
@@ -42,7 +42,7 @@ Green=`tput setaf 2`
|
||||
Yellow=`tput setaf 3`
|
||||
|
||||
# Assume no binary torch-mlir.
|
||||
# Currently available for macOS m1&intel (3.10) and Linux(3.7,3.8,3.9,3.10)
|
||||
# Currently available for macOS m1&intel (3.11) and Linux(3.8,3.10,3.11)
|
||||
torch_mlir_bin=false
|
||||
if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "${Yellow}Apple macOS detected"
|
||||
@@ -60,12 +60,12 @@ if [[ $(uname -s) = 'Darwin' ]]; then
|
||||
fi
|
||||
echo "${Yellow}Run the following commands to setup your SSL certs for your Python version if you see SSL errors with tests"
|
||||
echo "${Yellow}/Applications/Python\ 3.XX/Install\ Certificates.command"
|
||||
if [ "$PYTHON_VERSION_X_Y" == "3.10" ]; then
|
||||
if [ "$PYTHON_VERSION_X_Y" == "3.11" ]; then
|
||||
torch_mlir_bin=true
|
||||
fi
|
||||
elif [[ $(uname -s) = 'Linux' ]]; then
|
||||
echo "${Yellow}Linux detected"
|
||||
if [ "$PYTHON_VERSION_X_Y" == "3.7" ] || [ "$PYTHON_VERSION_X_Y" == "3.8" ] || [ "$PYTHON_VERSION_X_Y" == "3.9" ] || [ "$PYTHON_VERSION_X_Y" == "3.10" ] ; then
|
||||
if [ "$PYTHON_VERSION_X_Y" == "3.8" ] || [ "$PYTHON_VERSION_X_Y" == "3.10" ] || [ "$PYTHON_VERSION_X_Y" == "3.11" ] ; then
|
||||
torch_mlir_bin=true
|
||||
fi
|
||||
else
|
||||
@@ -89,7 +89,7 @@ if [ "$torch_mlir_bin" = true ]; then
|
||||
fi
|
||||
else
|
||||
echo "${Red}No binaries found for Python $PYTHON_VERSION_X_Y on $(uname -s)"
|
||||
echo "${Yello}Python 3.10 supported on macOS and 3.7,3.8,3.9 and 3.10 on Linux"
|
||||
echo "${Yello}Python 3.11 supported on macOS and 3.8,3.10 and 3.11 on Linux"
|
||||
echo "${Red}Please build torch-mlir from source in your environment"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import torchdynamo
|
||||
import torch
|
||||
import torch_mlir
|
||||
import torch._dynamo as torchdynamo
|
||||
from shark.sharkdynamo.utils import make_shark_compiler
|
||||
|
||||
|
||||
|
||||
@@ -15,24 +15,6 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
def dir_path(path):
|
||||
if os.path.isdir(path):
|
||||
return path
|
||||
else:
|
||||
os.mkdir(path)
|
||||
return path
|
||||
|
||||
|
||||
def dir_file(path):
|
||||
if os.path.isfile(path):
|
||||
return path
|
||||
else:
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"readable_file:{path} is not a valid file"
|
||||
)
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="SHARK runner.")
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
@@ -40,12 +22,6 @@ parser.add_argument(
|
||||
default="cpu",
|
||||
help="Device on which shark_runner runs. options are cpu, cuda, and vulkan",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repro_dir",
|
||||
help="Directory to which module files will be saved for reproduction or debugging.",
|
||||
type=dir_path,
|
||||
default="shark_tmp",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_tf32",
|
||||
type=bool,
|
||||
@@ -83,10 +59,16 @@ parser.add_argument(
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update_tank",
|
||||
default=False,
|
||||
default=True,
|
||||
action="store_true",
|
||||
help="When enabled, SHARK downloader will update local shark_tank if local hash is different from latest upstream hash.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force_update_tank",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="When enabled, SHARK downloader will force an update of local shark_tank artifacts for each request.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local_tank_cache",
|
||||
default=None,
|
||||
|
||||
@@ -82,7 +82,7 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
self.vmfb_file = export_iree_module_to_vmfb(
|
||||
mlir_module,
|
||||
device,
|
||||
shark_args.repro_dir,
|
||||
".",
|
||||
self.mlir_dialect,
|
||||
extra_args=self.extra_args,
|
||||
)
|
||||
|
||||
@@ -79,23 +79,21 @@ input_type_to_np_dtype = {
|
||||
# Save the model in the home local so it needn't be fetched everytime in the CI.
|
||||
home = str(Path.home())
|
||||
alt_path = os.path.join(os.path.dirname(__file__), "../gen_shark_tank/")
|
||||
custom_path_list = None
|
||||
if shark_args.local_tank_cache is not None:
|
||||
custom_path_list = shark_args.local_tank_cache.split("/")
|
||||
custom_path = shark_args.local_tank_cache
|
||||
|
||||
if os.path.exists(alt_path):
|
||||
WORKDIR = alt_path
|
||||
print(
|
||||
f"Using {WORKDIR} as shark_tank directory. Delete this directory if you aren't working from locally generated shark_tank."
|
||||
)
|
||||
if custom_path_list:
|
||||
custom_path = os.path.join(*custom_path_list)
|
||||
if custom_path is not None:
|
||||
if not os.path.exists(custom_path):
|
||||
os.mkdir(custom_path)
|
||||
|
||||
WORKDIR = custom_path
|
||||
|
||||
print(f"Using {WORKDIR} as local shark_tank cache directory.")
|
||||
|
||||
elif os.path.exists(alt_path):
|
||||
WORKDIR = alt_path
|
||||
print(
|
||||
f"Using {WORKDIR} as shark_tank directory. Delete this directory if you aren't working from locally generated shark_tank."
|
||||
)
|
||||
else:
|
||||
WORKDIR = os.path.join(home, ".local/shark_tank/")
|
||||
print(
|
||||
@@ -148,15 +146,14 @@ def download_model(
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
|
||||
|
||||
if shark_args.update_tank == True:
|
||||
print(f"Updating artifacts for model {model_name}...")
|
||||
download_public_file(full_gs_url, model_dir)
|
||||
|
||||
elif not check_dir_exists(
|
||||
if not check_dir_exists(
|
||||
model_dir_name, frontend=frontend, dynamic=dyn_str
|
||||
):
|
||||
print(f"Downloading artifacts for model {model_name}...")
|
||||
download_public_file(full_gs_url, model_dir)
|
||||
elif shark_args.force_update_tank == True:
|
||||
print(f"Force-updating artifacts for model {model_name}...")
|
||||
download_public_file(full_gs_url, model_dir)
|
||||
else:
|
||||
if not _internet_connected():
|
||||
print(
|
||||
@@ -178,7 +175,11 @@ def download_model(
|
||||
)
|
||||
except FileNotFoundError:
|
||||
upstream_hash = None
|
||||
if local_hash != upstream_hash:
|
||||
if local_hash != upstream_hash and shark_args.update_tank == True:
|
||||
print(f"Updating artifacts for model {model_name}...")
|
||||
download_public_file(full_gs_url, model_dir)
|
||||
|
||||
elif local_hash != upstream_hash:
|
||||
print(
|
||||
"Hash does not match upstream in gs://shark_tank/latest. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
|
||||
)
|
||||
|
||||
@@ -81,7 +81,7 @@ class SharkImporter:
|
||||
self.return_str,
|
||||
)
|
||||
|
||||
def _tf_mlir(self, func_name, save_dir="./shark_tmp/"):
|
||||
def _tf_mlir(self, func_name, save_dir="."):
|
||||
from iree.compiler import tf as tfc
|
||||
|
||||
return tfc.compile_module(
|
||||
@@ -91,7 +91,7 @@ class SharkImporter:
|
||||
output_file=save_dir,
|
||||
)
|
||||
|
||||
def _tflite_mlir(self, func_name, save_dir="./shark_tmp/"):
|
||||
def _tflite_mlir(self, func_name, save_dir="."):
|
||||
from iree.compiler import tflite as tflitec
|
||||
|
||||
self.mlir_model = tflitec.compile_file(
|
||||
|
||||
@@ -3,7 +3,7 @@ import time
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from functorch._src.compile_utils import strip_overloads
|
||||
from torch._functorch.compile_utils import strip_overloads
|
||||
from shark.shark_inference import SharkInference
|
||||
from torch._decomp import get_decompositions
|
||||
|
||||
@@ -119,14 +119,19 @@ def make_shark_compiler(use_tracing: bool, device: str, verbose=False):
|
||||
example_inputs,
|
||||
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
)
|
||||
import io
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
linalg_module.operation.write_bytecode(bytecode_stream)
|
||||
mlir_module = bytecode_stream.getvalue()
|
||||
|
||||
shark_module = SharkInference(
|
||||
linalg_module, "forward", mlir_dialect="linalg", device=device
|
||||
mlir_module, mlir_dialect="linalg", device=device
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
def forward(*inputs):
|
||||
result = shark_module.forward(inputs)
|
||||
result = shark_module("forward", inputs)
|
||||
result = tuple() if result is None else result
|
||||
return (result,) if was_unwrapped else result
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ def get_torch_mlir_module(
|
||||
if jit_trace:
|
||||
ignore_traced_shapes = True
|
||||
|
||||
tempfile.tempdir = shark_args.repro_dir
|
||||
tempfile.tempdir = "."
|
||||
|
||||
mlir_module = torch_mlir.compile(
|
||||
module,
|
||||
|
||||
@@ -136,7 +136,7 @@ class SharkModuleTester:
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
shark_args.local_tank_cache = self.local_tank_cache
|
||||
shark_args.update_tank = self.update_tank
|
||||
shark_args.force_update_tank = self.update_tank
|
||||
if "nhcw-nhwc" in self.config["flags"] and not os.path.isfile(
|
||||
".use-iree"
|
||||
):
|
||||
@@ -212,12 +212,11 @@ class SharkModuleTester:
|
||||
)
|
||||
|
||||
def save_reproducers(self):
|
||||
# Saves contents of IREE TempFileSaver temporary directory to ./shark_tmp/saved/<test_case>.
|
||||
src = os.path.join(*self.temp_dir.split("/"))
|
||||
saves = os.path.join(".", "shark_tmp", "saved")
|
||||
trg = os.path.join(saves, self.tmp_prefix)
|
||||
if not os.path.isdir(saves):
|
||||
os.mkdir(saves)
|
||||
# Saves contents of IREE TempFileSaver temporary directory to ./{temp_dir}/saved/<test_case>.
|
||||
src = self.temp_dir
|
||||
trg = os.path.join("reproducers", self.tmp_prefix)
|
||||
if not os.path.isdir("reproducers"):
|
||||
os.mkdir("reproducers")
|
||||
if not os.path.isdir(trg):
|
||||
os.mkdir(trg)
|
||||
files = os.listdir(src)
|
||||
@@ -227,10 +226,7 @@ class SharkModuleTester:
|
||||
def upload_repro(self):
|
||||
import subprocess
|
||||
|
||||
src = os.path.join(*self.temp_dir.split("/"))
|
||||
repro_path = os.path.join(
|
||||
".", "shark_tmp", "saved", self.tmp_prefix, "*"
|
||||
)
|
||||
repro_path = os.path.join("reproducers", self.tmp_prefix, "*")
|
||||
|
||||
bashCommand = f"gsutil cp -r {repro_path} gs://shark-public/builder/repro_artifacts/{self.ci_sha}/{self.tmp_prefix}/"
|
||||
process = subprocess.run(bashCommand.split())
|
||||
@@ -329,11 +325,8 @@ class SharkModuleTest(unittest.TestCase):
|
||||
)
|
||||
self.module_tester.tmp_prefix = safe_name.replace("/", "_")
|
||||
|
||||
if not os.path.isdir("shark_tmp"):
|
||||
os.mkdir("shark_tmp")
|
||||
|
||||
tempdir = tempfile.TemporaryDirectory(
|
||||
prefix=self.module_tester.tmp_prefix, dir="shark_tmp"
|
||||
prefix=self.module_tester.tmp_prefix, dir="."
|
||||
)
|
||||
self.module_tester.temp_dir = tempdir.name
|
||||
|
||||
|
||||
Reference in New Issue
Block a user