mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
26 Commits
20230218.5
...
20230223.5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a9039b35ed | ||
|
|
a01154a507 | ||
|
|
1d9204282d | ||
|
|
5ff40a0d2d | ||
|
|
fab6d2e4e0 | ||
|
|
abab59c25f | ||
|
|
c25840b585 | ||
|
|
1b3f9125bb | ||
|
|
b5d9f5ba49 | ||
|
|
1c22aa9c8f | ||
|
|
e1d7fb879c | ||
|
|
e912c42bf0 | ||
|
|
e6841acf36 | ||
|
|
bc4459b6f4 | ||
|
|
9b544491e0 | ||
|
|
9c5415b598 | ||
|
|
040dbc317f | ||
|
|
65775046d8 | ||
|
|
b18bc36127 | ||
|
|
f01c526efd | ||
|
|
16168ab6b3 | ||
|
|
4233218629 | ||
|
|
b63fb36dc0 | ||
|
|
4e92304b89 | ||
|
|
2ae047f1a8 | ||
|
|
6d2a485264 |
11
.github/workflows/nightly.yml
vendored
11
.github/workflows/nightly.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
body: |
|
||||
Automatic snapshot release of nod.ai SHARK.
|
||||
draft: true
|
||||
prerelease: false
|
||||
prerelease: true
|
||||
|
||||
- name: Build Package
|
||||
shell: powershell
|
||||
@@ -67,9 +67,9 @@ jobs:
|
||||
# $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
|
||||
# pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
|
||||
- uses: actions/upload-artifact@v2
|
||||
with:
|
||||
path: dist/*
|
||||
#- uses: actions/upload-artifact@v2
|
||||
# with:
|
||||
# path: dist/*
|
||||
|
||||
- name: Upload Release Assets
|
||||
id: upload-release-assets
|
||||
@@ -79,6 +79,7 @@ jobs:
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
assets_path: ./dist/*
|
||||
#asset_content_type: application/vnd.microsoft.portable-executable
|
||||
|
||||
- name: Publish Release
|
||||
id: publish_release
|
||||
@@ -133,7 +134,7 @@ jobs:
|
||||
source iree.venv/bin/activate
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
SHARK_PACKAGE_VERSION=${package_version} \
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://iree-org.github.io/iree/pip-release-links.html
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://openxla.github.io/iree/pip-release-links.html
|
||||
# Install the built wheel
|
||||
pip install ./wheelhouse/nodai*
|
||||
# Validate the Models
|
||||
|
||||
16
README.md
16
README.md
@@ -27,7 +27,7 @@ Other users please ensure you have your latest vendor drivers and Vulkan SDK fro
|
||||
|
||||
Install the Driver from [Prerequisites](https://github.com/nod-ai/SHARK#install-your-hardware-drivers) above
|
||||
|
||||
Download the stable release [539](https://github.com/nod-ai/SHARK/releases/download/20230216.539/shark_sd_20230216_539.exe) or if you are adventurous the latest .exe from [releases page](https://github.com/nod-ai/SHARK/releases).
|
||||
Download the [stable release](https://github.com/nod-ai/shark/releases/latest)
|
||||
|
||||
Double click the .exe and you should have the [UI](http://localhost:8080/) in the browser.
|
||||
|
||||
@@ -215,14 +215,14 @@ python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use g
|
||||
pytest tank/test_models.py -k "MiniLM"
|
||||
```
|
||||
|
||||
|
||||
### How to use your locally built IREE / Torch-MLIR with SHARK
|
||||
If you are a *Torch-mlir developer or an IREE developer* and want to test local changes you can uninstall
|
||||
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
|
||||
with Python bindings and set your PYTHONPATH as mentioned [here](https://github.com/iree-org/iree/tree/main/docs/api_docs/python#install-iree-binaries)
|
||||
for IREE and [here](https://github.com/llvm/torch-mlir/blob/main/development.md#setup-python-environment-to-export-the-built-python-packages)
|
||||
for Torch-MLIR.
|
||||
|
||||
### How to use your locally built Torch-MLIR with SHARK
|
||||
How to use your locally built Torch-MLIR with SHARK:
|
||||
```shell
|
||||
1.) Run `./setup_venv.sh in SHARK` and activate `shark.venv` virtual env.
|
||||
2.) Run `pip uninstall torch-mlir`.
|
||||
@@ -240,9 +240,15 @@ Now the SHARK will use your locally build Torch-MLIR repo.
|
||||
|
||||
## Benchmarking Dispatches
|
||||
|
||||
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your command line argument.
|
||||
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line argument.
|
||||
If you only want to compile specific dispatches, you can specify them with a space seperated string instead of `"All"`. E.G. `--dispatch_benchmarks="0 1 2 10"`
|
||||
|
||||
For example, to generate and run dispatch benchmarks for MiniLM on CUDA:
|
||||
```
|
||||
pytest -k "MiniLM and torch and static and cuda" --benchmark_dispatches=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
|
||||
```
|
||||
The given command will populate `<dispatch_benchmarks_dir>/<model_name>/` with an `ordered_dispatches.txt` that lists and orders the dispatches and their latencies, as well as folders for each dispatch that contain .mlir, .vmfb, and results of the benchmark for that dispatch.
|
||||
|
||||
if you want to instead incorporate this into a python script, you can pass the `dispatch_benchmarks` and `dispatch_benchmarks_dir` commands when initializing `SharkInference`, and the benchmarks will be generated when compiled. E.G:
|
||||
|
||||
```
|
||||
@@ -266,7 +272,7 @@ Output will include:
|
||||
- A .txt file containing benchmark output
|
||||
|
||||
|
||||
See tank/README.md for instructions on how to run model tests and benchmarks from the SHARK tank.
|
||||
See tank/README.md for further instructions on how to run model tests and benchmarks from the SHARK tank.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
from apps.stable_diffusion.scripts.txt2img import txt2img_inf
|
||||
from apps.stable_diffusion.scripts.img2img import img2img_inf
|
||||
from apps.stable_diffusion.scripts.inpaint import inpaint_inf
|
||||
from apps.stable_diffusion.scripts.outpaint import outpaint_inf
|
||||
|
||||
@@ -104,7 +104,7 @@ def img2img_inf(
|
||||
width,
|
||||
device,
|
||||
)
|
||||
if config_obj != new_config_obj:
|
||||
if not img2img_obj or config_obj != new_config_obj:
|
||||
config_obj = new_config_obj
|
||||
args.precision = precision
|
||||
args.batch_size = batch_size
|
||||
@@ -136,11 +136,9 @@ def img2img_inf(
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
if not img2img_obj:
|
||||
sys.exit("text to image pipeline must not return a null value")
|
||||
|
||||
img2img_obj.scheduler = schedulers[scheduler]
|
||||
|
||||
start_time = time.time()
|
||||
@@ -230,6 +228,7 @@ if __name__ == "__main__":
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
@@ -35,8 +35,7 @@ schedulers = None
|
||||
def inpaint_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
image: Image,
|
||||
mask_image: Image,
|
||||
image_loc,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
@@ -62,6 +61,8 @@ def inpaint_inf(
|
||||
args.guidance_scale = guidance_scale
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
args.img_path = image_loc["image"]
|
||||
args.mask_path = image_loc["mask"]
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
@@ -97,7 +98,7 @@ def inpaint_inf(
|
||||
width,
|
||||
device,
|
||||
)
|
||||
if config_obj != new_config_obj:
|
||||
if not inpaint_obj or config_obj != new_config_obj:
|
||||
config_obj = new_config_obj
|
||||
args.precision = precision
|
||||
args.batch_size = batch_size
|
||||
@@ -131,9 +132,6 @@ def inpaint_inf(
|
||||
args.use_tuned,
|
||||
)
|
||||
|
||||
if not inpaint_obj:
|
||||
sys.exit("text to image pipeline must not return a null value")
|
||||
|
||||
inpaint_obj.scheduler = schedulers[scheduler]
|
||||
|
||||
start_time = time.time()
|
||||
@@ -141,6 +139,8 @@ def inpaint_inf(
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
image = Image.open(args.img_path)
|
||||
mask_image = Image.open(args.mask_path)
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
|
||||
275
apps/stable_diffusion/scripts/outpaint.py
Normal file
275
apps/stable_diffusion/scripts/outpaint.py
Normal file
@@ -0,0 +1,275 @@
|
||||
import sys
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
from dataclasses import dataclass
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
OutpaintPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
model_id: str
|
||||
ckpt_loc: str
|
||||
precision: str
|
||||
batch_size: int
|
||||
max_length: int
|
||||
height: int
|
||||
width: int
|
||||
device: str
|
||||
|
||||
|
||||
outpaint_obj = None
|
||||
config_obj = None
|
||||
schedulers = None
|
||||
|
||||
|
||||
# Exposed to UI.
|
||||
def outpaint_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
init_image: str,
|
||||
pixels: int,
|
||||
mask_blur: int,
|
||||
directions: list,
|
||||
noise_q: float,
|
||||
color_variation: float,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
custom_model: str,
|
||||
hf_model_id: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
):
|
||||
global outpaint_obj
|
||||
global config_obj
|
||||
global schedulers
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
args.img_path = init_image
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
types = (
|
||||
".ckpt",
|
||||
".safetensors",
|
||||
) # the tuple of file types
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
if custom_model == "None":
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
args.ckpt_loc = custom_model
|
||||
else:
|
||||
args.hf_model_id = custom_model
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
)
|
||||
if not outpaint_obj or config_obj != new_config_obj:
|
||||
config_obj = new_config_obj
|
||||
args.precision = precision
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = ""
|
||||
args.use_tuned = True
|
||||
args.import_mlir = False
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-inpainting"
|
||||
)
|
||||
schedulers = get_schedulers(model_id)
|
||||
scheduler_obj = schedulers[scheduler]
|
||||
outpaint_obj = OutpaintPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
)
|
||||
|
||||
outpaint_obj.scheduler = schedulers[scheduler]
|
||||
|
||||
start_time = time.time()
|
||||
outpaint_obj.log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
image = Image.open(args.img_path)
|
||||
|
||||
left = True if "left" in directions else False
|
||||
right = True if "right" in directions else False
|
||||
top = True if "up" in directions else False
|
||||
bottom = True if "down" in directions else False
|
||||
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = outpaint_obj.generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
args.pixels,
|
||||
args.mask_blur,
|
||||
left,
|
||||
right,
|
||||
top,
|
||||
bottom,
|
||||
noise_q,
|
||||
color_variation,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
generated_imgs.extend(out_imgs)
|
||||
seeds.append(img_seed)
|
||||
outpaint_obj.log += "\n"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
|
||||
text_output += outpaint_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
return generated_imgs, text_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
if args.img_path is None:
|
||||
print("Flag --img_path is required.")
|
||||
exit()
|
||||
if "inpaint" not in args.hf_model_id:
|
||||
print("Please use inpainting model with --hf_model_id.")
|
||||
exit()
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
image = Image.open(args.img_path)
|
||||
|
||||
outpaint_obj = OutpaintPipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
)
|
||||
|
||||
for current_batch in range(args.batch_count):
|
||||
if current_batch > 0:
|
||||
seed = -1
|
||||
seed = utils.sanitize_seed(seed)
|
||||
|
||||
start_time = time.time()
|
||||
generated_imgs = outpaint_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
image,
|
||||
args.pixels,
|
||||
args.mask_blur,
|
||||
args.left,
|
||||
args.right,
|
||||
args.top,
|
||||
args.bottom,
|
||||
args.noise_q,
|
||||
args.color_variation,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seed,
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
text_output += outpaint_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
save_output_img(generated_imgs[0], seed)
|
||||
print(text_output)
|
||||
@@ -94,7 +94,7 @@ def txt2img_inf(
|
||||
width,
|
||||
device,
|
||||
)
|
||||
if config_obj != new_config_obj:
|
||||
if not txt2img_obj or config_obj != new_config_obj:
|
||||
config_obj = new_config_obj
|
||||
args.precision = precision
|
||||
args.batch_size = batch_size
|
||||
@@ -105,6 +105,7 @@ def txt2img_inf(
|
||||
args.iree_vulkan_target_triple = ""
|
||||
args.use_tuned = True
|
||||
args.import_mlir = False
|
||||
args.img_path = None
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
@@ -126,11 +127,9 @@ def txt2img_inf(
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
if not txt2img_obj:
|
||||
sys.exit("text to image pipeline must not return a null value")
|
||||
|
||||
txt2img_obj.scheduler = schedulers[scheduler]
|
||||
|
||||
start_time = time.time()
|
||||
@@ -199,6 +198,7 @@ if __name__ == "__main__":
|
||||
args.width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
for current_batch in range(args.batch_count):
|
||||
|
||||
@@ -15,12 +15,11 @@ datas += copy_metadata('filelock')
|
||||
datas += copy_metadata('numpy')
|
||||
datas += copy_metadata('tokenizers')
|
||||
datas += copy_metadata('importlib_metadata')
|
||||
datas += copy_metadata('torchvision')
|
||||
datas += copy_metadata('torch-mlir')
|
||||
datas += copy_metadata('diffusers')
|
||||
datas += copy_metadata('transformers')
|
||||
datas += copy_metadata('omegaconf')
|
||||
datas += copy_metadata('safetensors')
|
||||
datas += collect_data_files('diffusers')
|
||||
datas += collect_data_files('transformers')
|
||||
datas += collect_data_files('gradio')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
@@ -44,7 +43,7 @@ a = Analysis(
|
||||
pathex=['.'],
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio', 'apps'],
|
||||
hiddenimports=['shark', 'shark.shark_inference', 'apps'],
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
|
||||
@@ -15,12 +15,11 @@ datas += copy_metadata('filelock')
|
||||
datas += copy_metadata('numpy')
|
||||
datas += copy_metadata('tokenizers')
|
||||
datas += copy_metadata('importlib_metadata')
|
||||
datas += copy_metadata('torchvision')
|
||||
datas += copy_metadata('torch-mlir')
|
||||
datas += copy_metadata('diffusers')
|
||||
datas += copy_metadata('transformers')
|
||||
datas += copy_metadata('omegaconf')
|
||||
datas += copy_metadata('safetensors')
|
||||
datas += collect_data_files('diffusers')
|
||||
datas += collect_data_files('transformers')
|
||||
datas += collect_data_files('gradio')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
@@ -42,7 +41,7 @@ a = Analysis(
|
||||
pathex=['.'],
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio', 'apps'],
|
||||
hiddenimports=['shark', 'shark.shark_inference', 'apps'],
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
|
||||
@@ -8,7 +8,8 @@ from apps.stable_diffusion.src.utils import (
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines import (
|
||||
Text2ImagePipeline,
|
||||
InpaintPipeline,
|
||||
Image2ImagePipeline,
|
||||
InpaintPipeline,
|
||||
OutpaintPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import get_schedulers
|
||||
|
||||
@@ -80,6 +80,7 @@ class SharkifyStableDiffusionModel:
|
||||
batch_size: int = 1,
|
||||
use_base_vae: bool = False,
|
||||
use_tuned: bool = False,
|
||||
low_cpu_mem_usage: bool = False
|
||||
):
|
||||
self.check_params(max_len, width, height)
|
||||
self.max_len = max_len
|
||||
@@ -114,6 +115,7 @@ class SharkifyStableDiffusionModel:
|
||||
if use_tuned:
|
||||
self.model_name = self.model_name + "_tuned"
|
||||
self.model_name = self.model_name + "_" + get_path_stem(self.model_id)
|
||||
self.low_cpu_mem_usage = low_cpu_mem_usage
|
||||
|
||||
def get_extended_name_for_all_model(self):
|
||||
model_name = {}
|
||||
@@ -139,11 +141,12 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
def get_vae_encode(self):
|
||||
class VaeEncodeModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
@@ -165,23 +168,26 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
def get_vae(self):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae):
|
||||
def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.vae = None
|
||||
if custom_vae == "":
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
elif not isinstance(custom_vae, dict):
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
custom_vae,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
else:
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.vae.load_state_dict(custom_vae)
|
||||
self.base_vae = base_vae
|
||||
@@ -196,7 +202,7 @@ class SharkifyStableDiffusionModel:
|
||||
x = x * 255.0
|
||||
return x.round()
|
||||
|
||||
vae = VaeModel()
|
||||
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
inputs = tuple(self.inputs["vae"])
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
shark_vae = compile_through_fx(
|
||||
@@ -211,11 +217,12 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
def get_unet(self):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
@@ -234,7 +241,7 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return noise_pred
|
||||
|
||||
unet = UnetModel()
|
||||
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
input_mask = [True, True, True, False]
|
||||
@@ -251,17 +258,18 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
def get_clip(self):
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="text_encoder",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText()
|
||||
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
shark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
tuple(self.inputs["clip"]),
|
||||
@@ -326,6 +334,8 @@ class SharkifyStableDiffusionModel:
|
||||
if args.hf_model_id == "":
|
||||
sys.exit("Base model configuration for the custom model is missing. Use `--clear_all` and re-run.")
|
||||
print("Loaded vmfbs from cache and successfully fetched base model configuration.")
|
||||
if not need_vae_encode:
|
||||
return vmfbs[:3]
|
||||
return vmfbs
|
||||
|
||||
# Step 2:
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
|
||||
Text2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_inpaint import (
|
||||
InpaintPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_inpaint import (
|
||||
InpaintPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_outpaint import (
|
||||
OutpaintPipeline,
|
||||
)
|
||||
|
||||
@@ -41,8 +41,10 @@ class InpaintPipeline(StableDiffusionPipeline):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
|
||||
self.vae_encode = vae_encode
|
||||
|
||||
def prepare_mask_and_masked_image(self, image, mask):
|
||||
def prepare_mask_and_masked_image(self, image, mask, height, width):
|
||||
# preprocess image
|
||||
image = image.resize((width, height))
|
||||
mask = mask.resize((width, height))
|
||||
if isinstance(image, (Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
|
||||
@@ -191,7 +193,7 @@ class InpaintPipeline(StableDiffusionPipeline):
|
||||
|
||||
# Preprocess mask and image
|
||||
mask, masked_image = self.prepare_mask_and_masked_image(
|
||||
image, mask_image
|
||||
image, mask_image, height, width
|
||||
)
|
||||
|
||||
# Prepare mask latent variables
|
||||
|
||||
@@ -0,0 +1,538 @@
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from PIL import Image, ImageDraw, ImageFilter
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
import math
|
||||
|
||||
|
||||
class OutpaintPipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae_encode: SharkInference,
|
||||
vae: SharkInference,
|
||||
text_encoder: SharkInference,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: SharkInference,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
],
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
|
||||
self.vae_encode = vae_encode
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
dtype,
|
||||
):
|
||||
latents = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
4,
|
||||
height // 8,
|
||||
width // 8,
|
||||
),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
self.scheduler.is_scale_input_called = True
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def prepare_mask_and_masked_image(self, image, mask, mask_blur):
|
||||
if mask_blur > 0:
|
||||
mask = mask.filter(ImageFilter.GaussianBlur(mask_blur))
|
||||
image = image.resize((512, 512))
|
||||
mask = mask.resize((512, 512))
|
||||
|
||||
# preprocess image
|
||||
if isinstance(image, (Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image, list) and isinstance(image[0], Image.Image):
|
||||
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate([i[None, :] for i in image], axis=0)
|
||||
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
# preprocess mask
|
||||
if isinstance(mask, (Image.Image, np.ndarray)):
|
||||
mask = [mask]
|
||||
|
||||
if isinstance(mask, list) and isinstance(mask[0], Image.Image):
|
||||
mask = np.concatenate(
|
||||
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
|
||||
)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
|
||||
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
|
||||
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
def prepare_mask_latents(
|
||||
self,
|
||||
mask,
|
||||
masked_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
):
|
||||
mask = torch.nn.functional.interpolate(
|
||||
mask, size=(height // 8, width // 8)
|
||||
)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
masked_image = masked_image.to(dtype)
|
||||
masked_image_latents = self.vae_encode("forward", (masked_image,))
|
||||
masked_image_latents = torch.from_numpy(masked_image_latents)
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
if mask.shape[0] < batch_size:
|
||||
if not batch_size % mask.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
||||
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
||||
" of masks that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
||||
if masked_image_latents.shape[0] < batch_size:
|
||||
if not batch_size % masked_image_latents.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
masked_image_latents = masked_image_latents.repeat(
|
||||
batch_size // masked_image_latents.shape[0], 1, 1, 1
|
||||
)
|
||||
return mask, masked_image_latents
|
||||
|
||||
def get_matched_noise(
|
||||
self, _np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05
|
||||
):
|
||||
# helper fft routines that keep ortho normalization and auto-shift before and after fft
|
||||
def _fft2(data):
|
||||
if data.ndim > 2: # has channels
|
||||
out_fft = np.zeros(
|
||||
(data.shape[0], data.shape[1], data.shape[2]),
|
||||
dtype=np.complex128,
|
||||
)
|
||||
for c in range(data.shape[2]):
|
||||
c_data = data[:, :, c]
|
||||
out_fft[:, :, c] = np.fft.fft2(
|
||||
np.fft.fftshift(c_data), norm="ortho"
|
||||
)
|
||||
out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
|
||||
else: # one channel
|
||||
out_fft = np.zeros(
|
||||
(data.shape[0], data.shape[1]), dtype=np.complex128
|
||||
)
|
||||
out_fft[:, :] = np.fft.fft2(
|
||||
np.fft.fftshift(data), norm="ortho"
|
||||
)
|
||||
out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
|
||||
|
||||
return out_fft
|
||||
|
||||
def _ifft2(data):
|
||||
if data.ndim > 2: # has channels
|
||||
out_ifft = np.zeros(
|
||||
(data.shape[0], data.shape[1], data.shape[2]),
|
||||
dtype=np.complex128,
|
||||
)
|
||||
for c in range(data.shape[2]):
|
||||
c_data = data[:, :, c]
|
||||
out_ifft[:, :, c] = np.fft.ifft2(
|
||||
np.fft.fftshift(c_data), norm="ortho"
|
||||
)
|
||||
out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
|
||||
else: # one channel
|
||||
out_ifft = np.zeros(
|
||||
(data.shape[0], data.shape[1]), dtype=np.complex128
|
||||
)
|
||||
out_ifft[:, :] = np.fft.ifft2(
|
||||
np.fft.fftshift(data), norm="ortho"
|
||||
)
|
||||
out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
|
||||
|
||||
return out_ifft
|
||||
|
||||
def _get_gaussian_window(width, height, std=3.14, mode=0):
|
||||
window_scale_x = float(width / min(width, height))
|
||||
window_scale_y = float(height / min(width, height))
|
||||
|
||||
window = np.zeros((width, height))
|
||||
x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
|
||||
for y in range(height):
|
||||
fy = (y / height * 2.0 - 1.0) * window_scale_y
|
||||
if mode == 0:
|
||||
window[:, y] = np.exp(-(x**2 + fy**2) * std)
|
||||
else:
|
||||
window[:, y] = (
|
||||
1 / ((x**2 + 1.0) * (fy**2 + 1.0))
|
||||
) ** (std / 3.14)
|
||||
|
||||
return window
|
||||
|
||||
def _get_masked_window_rgb(np_mask_grey, hardness=1.0):
|
||||
np_mask_rgb = np.zeros(
|
||||
(np_mask_grey.shape[0], np_mask_grey.shape[1], 3)
|
||||
)
|
||||
if hardness != 1.0:
|
||||
hardened = np_mask_grey[:] ** hardness
|
||||
else:
|
||||
hardened = np_mask_grey[:]
|
||||
for c in range(3):
|
||||
np_mask_rgb[:, :, c] = hardened[:]
|
||||
return np_mask_rgb
|
||||
|
||||
def _match_cumulative_cdf(source, template):
|
||||
src_values, src_unique_indices, src_counts = np.unique(
|
||||
source.ravel(), return_inverse=True, return_counts=True
|
||||
)
|
||||
tmpl_values, tmpl_counts = np.unique(
|
||||
template.ravel(), return_counts=True
|
||||
)
|
||||
|
||||
# calculate normalized quantiles for each array
|
||||
src_quantiles = np.cumsum(src_counts) / source.size
|
||||
tmpl_quantiles = np.cumsum(tmpl_counts) / template.size
|
||||
|
||||
interp_a_values = np.interp(
|
||||
src_quantiles, tmpl_quantiles, tmpl_values
|
||||
)
|
||||
return interp_a_values[src_unique_indices].reshape(source.shape)
|
||||
|
||||
def _match_histograms(image, reference):
|
||||
if image.ndim != reference.ndim:
|
||||
raise ValueError(
|
||||
"Image and reference must have the same number of channels."
|
||||
)
|
||||
|
||||
if image.shape[-1] != reference.shape[-1]:
|
||||
raise ValueError(
|
||||
"Number of channels in the input image and reference image must match!"
|
||||
)
|
||||
|
||||
matched = np.empty(image.shape, dtype=image.dtype)
|
||||
for channel in range(image.shape[-1]):
|
||||
matched_channel = _match_cumulative_cdf(
|
||||
image[..., channel], reference[..., channel]
|
||||
)
|
||||
matched[..., channel] = matched_channel
|
||||
|
||||
matched = matched.astype(np.float64, copy=False)
|
||||
return matched
|
||||
|
||||
width = _np_src_image.shape[0]
|
||||
height = _np_src_image.shape[1]
|
||||
num_channels = _np_src_image.shape[2]
|
||||
|
||||
np_src_image = _np_src_image[:] * (1.0 - np_mask_rgb)
|
||||
np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.0
|
||||
img_mask = np_mask_grey > 1e-6
|
||||
ref_mask = np_mask_grey < 1e-3
|
||||
|
||||
# rather than leave the masked area black, we get better results from fft by filling the average unmasked color
|
||||
windowed_image = _np_src_image * (
|
||||
1.0 - _get_masked_window_rgb(np_mask_grey)
|
||||
)
|
||||
windowed_image /= np.max(windowed_image)
|
||||
windowed_image += np.average(_np_src_image) * np_mask_rgb
|
||||
|
||||
src_fft = _fft2(
|
||||
windowed_image
|
||||
) # get feature statistics from masked src img
|
||||
src_dist = np.absolute(src_fft)
|
||||
src_phase = src_fft / src_dist
|
||||
|
||||
# create a generator with a static seed to make outpainting deterministic / only follow global seed
|
||||
rng = np.random.default_rng(0)
|
||||
|
||||
noise_window = _get_gaussian_window(
|
||||
width, height, mode=1
|
||||
) # start with simple gaussian noise
|
||||
noise_rgb = rng.random((width, height, num_channels))
|
||||
noise_grey = np.sum(noise_rgb, axis=2) / 3.0
|
||||
# the colorfulness of the starting noise is blended to greyscale with a parameter
|
||||
noise_rgb *= color_variation
|
||||
for c in range(num_channels):
|
||||
noise_rgb[:, :, c] += (1.0 - color_variation) * noise_grey
|
||||
|
||||
noise_fft = _fft2(noise_rgb)
|
||||
for c in range(num_channels):
|
||||
noise_fft[:, :, c] *= noise_window
|
||||
noise_rgb = np.real(_ifft2(noise_fft))
|
||||
shaped_noise_fft = _fft2(noise_rgb)
|
||||
shaped_noise_fft[:, :, :] = (
|
||||
np.absolute(shaped_noise_fft[:, :, :]) ** 2
|
||||
* (src_dist**noise_q)
|
||||
* src_phase
|
||||
) # perform the actual shaping
|
||||
|
||||
# color_variation
|
||||
brightness_variation = 0.0
|
||||
contrast_adjusted_np_src = (
|
||||
_np_src_image[:] * (brightness_variation + 1.0)
|
||||
- brightness_variation * 2.0
|
||||
)
|
||||
|
||||
shaped_noise = np.real(_ifft2(shaped_noise_fft))
|
||||
shaped_noise -= np.min(shaped_noise)
|
||||
shaped_noise /= np.max(shaped_noise)
|
||||
shaped_noise[img_mask, :] = _match_histograms(
|
||||
shaped_noise[img_mask, :] ** 1.0,
|
||||
contrast_adjusted_np_src[ref_mask, :],
|
||||
)
|
||||
shaped_noise = (
|
||||
_np_src_image[:] * (1.0 - np_mask_rgb) + shaped_noise * np_mask_rgb
|
||||
)
|
||||
|
||||
matched_noise = shaped_noise[:]
|
||||
|
||||
return np.clip(matched_noise, 0.0, 1.0)
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
image,
|
||||
pixels,
|
||||
mask_blur,
|
||||
is_left,
|
||||
is_right,
|
||||
is_top,
|
||||
is_bottom,
|
||||
noise_q,
|
||||
color_variation,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the inital latent noise. Also handle out of range seeds.
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Get initial latents
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get text embeddings from prompts
|
||||
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
process_width = width
|
||||
process_height = height
|
||||
left = pixels if is_left else 0
|
||||
right = pixels if is_right else 0
|
||||
up = pixels if is_top else 0
|
||||
down = pixels if is_bottom else 0
|
||||
target_w = math.ceil((image.width + left + right) / 64) * 64
|
||||
target_h = math.ceil((image.height + up + down) / 64) * 64
|
||||
|
||||
if left > 0:
|
||||
left = left * (target_w - image.width) // (left + right)
|
||||
if right > 0:
|
||||
right = target_w - image.width - left
|
||||
if up > 0:
|
||||
up = up * (target_h - image.height) // (up + down)
|
||||
if down > 0:
|
||||
down = target_h - image.height - up
|
||||
|
||||
def expand(
|
||||
init_img,
|
||||
expand_pixels,
|
||||
is_left=False,
|
||||
is_right=False,
|
||||
is_top=False,
|
||||
is_bottom=False,
|
||||
):
|
||||
is_horiz = is_left or is_right
|
||||
is_vert = is_top or is_bottom
|
||||
pixels_horiz = expand_pixels if is_horiz else 0
|
||||
pixels_vert = expand_pixels if is_vert else 0
|
||||
|
||||
res_w = init_img.width + pixels_horiz
|
||||
res_h = init_img.height + pixels_vert
|
||||
process_res_w = math.ceil(res_w / 64) * 64
|
||||
process_res_h = math.ceil(res_h / 64) * 64
|
||||
|
||||
img = Image.new("RGB", (process_res_w, process_res_h))
|
||||
img.paste(
|
||||
init_img,
|
||||
(pixels_horiz if is_left else 0, pixels_vert if is_top else 0),
|
||||
)
|
||||
|
||||
msk = Image.new("RGB", (process_res_w, process_res_h), "white")
|
||||
draw = ImageDraw.Draw(msk)
|
||||
draw.rectangle(
|
||||
(
|
||||
expand_pixels + mask_blur if is_left else 0,
|
||||
expand_pixels + mask_blur if is_top else 0,
|
||||
msk.width - expand_pixels - mask_blur
|
||||
if is_right
|
||||
else res_w,
|
||||
msk.height - expand_pixels - mask_blur
|
||||
if is_bottom
|
||||
else res_h,
|
||||
),
|
||||
fill="black",
|
||||
)
|
||||
|
||||
np_image = (np.asarray(img) / 255.0).astype(np.float64)
|
||||
np_mask = (np.asarray(msk) / 255.0).astype(np.float64)
|
||||
noised = self.get_matched_noise(
|
||||
np_image, np_mask, noise_q, color_variation
|
||||
)
|
||||
output_image = Image.fromarray(
|
||||
np.clip(noised * 255.0, 0.0, 255.0).astype(np.uint8),
|
||||
mode="RGB",
|
||||
)
|
||||
|
||||
target_width = (
|
||||
min(width, init_img.width + pixels_horiz)
|
||||
if is_horiz
|
||||
else img.width
|
||||
)
|
||||
target_height = (
|
||||
min(height, init_img.height + pixels_vert)
|
||||
if is_vert
|
||||
else img.height
|
||||
)
|
||||
crop_region = (
|
||||
0 if is_left else output_image.width - target_width,
|
||||
0 if is_top else output_image.height - target_height,
|
||||
target_width if is_left else output_image.width,
|
||||
target_height if is_top else output_image.height,
|
||||
)
|
||||
mask_to_process = msk.crop(crop_region)
|
||||
image_to_process = output_image.crop(crop_region)
|
||||
|
||||
# Preprocess mask and image
|
||||
mask, masked_image = self.prepare_mask_and_masked_image(
|
||||
image_to_process, mask_to_process, mask_blur
|
||||
)
|
||||
|
||||
# Prepare mask latent variables
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask=mask,
|
||||
masked_image=masked_image,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_img_latents(
|
||||
latents=init_latents,
|
||||
text_embeddings=text_embeddings,
|
||||
guidance_scale=guidance_scale,
|
||||
total_timesteps=self.scheduler.timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
mask=mask,
|
||||
masked_image_latents=masked_image_latents,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
all_imgs = []
|
||||
for i in tqdm(range(0, latents.shape[0], batch_size)):
|
||||
imgs = self.decode_latents(
|
||||
latents=latents[i : i + batch_size],
|
||||
use_base_vae=use_base_vae,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
|
||||
res_img = all_imgs[0].resize(
|
||||
(image_to_process.width, image_to_process.height)
|
||||
)
|
||||
output_image.paste(
|
||||
res_img,
|
||||
(
|
||||
0 if is_left else output_image.width - res_img.width,
|
||||
0 if is_top else output_image.height - res_img.height,
|
||||
),
|
||||
)
|
||||
output_image = output_image.crop((0, 0, res_w, res_h))
|
||||
|
||||
return output_image
|
||||
|
||||
img = image.resize((width, height))
|
||||
if left > 0:
|
||||
img = expand(img, left, is_left=True)
|
||||
if right > 0:
|
||||
img = expand(img, right, is_right=True)
|
||||
if up > 0:
|
||||
img = expand(img, up, is_top=True)
|
||||
if down > 0:
|
||||
img = expand(img, down, is_bottom=True)
|
||||
|
||||
return [img]
|
||||
@@ -201,6 +201,7 @@ class StableDiffusionPipeline:
|
||||
width: int,
|
||||
use_base_vae: bool,
|
||||
use_tuned: bool,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
):
|
||||
if import_mlir:
|
||||
mlir_import = SharkifyStableDiffusionModel(
|
||||
@@ -214,8 +215,13 @@ class StableDiffusionPipeline:
|
||||
width=width,
|
||||
use_base_vae=use_base_vae,
|
||||
use_tuned=use_tuned,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if cls.__name__ in ["Image2ImagePipeline", "InpaintPipeline"]:
|
||||
if cls.__name__ in [
|
||||
"Image2ImagePipeline",
|
||||
"InpaintPipeline",
|
||||
"OutpaintPipeline",
|
||||
]:
|
||||
clip, unet, vae, vae_encode = mlir_import()
|
||||
return cls(
|
||||
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
|
||||
@@ -223,7 +229,11 @@ class StableDiffusionPipeline:
|
||||
clip, unet, vae = mlir_import()
|
||||
return cls(vae, clip, get_tokenizer(), unet, scheduler)
|
||||
try:
|
||||
if cls.__name__ in ["Image2ImagePipeline", "InpaintPipeline"]:
|
||||
if cls.__name__ in [
|
||||
"Image2ImagePipeline",
|
||||
"InpaintPipeline",
|
||||
"OutpaintPipeline",
|
||||
]:
|
||||
return cls(
|
||||
get_vae_encode(),
|
||||
get_vae(),
|
||||
@@ -248,8 +258,13 @@ class StableDiffusionPipeline:
|
||||
width=width,
|
||||
use_base_vae=use_base_vae,
|
||||
use_tuned=use_tuned,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if cls.__name__ in ["Image2ImagePipeline", "InpaintPipeline"]:
|
||||
if cls.__name__ in [
|
||||
"Image2ImagePipeline",
|
||||
"InpaintPipeline",
|
||||
"OutpaintPipeline",
|
||||
]:
|
||||
clip, unet, vae, vae_encode = mlir_import()
|
||||
return cls(
|
||||
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
|
||||
|
||||
@@ -70,7 +70,7 @@ def load_winograd_configs():
|
||||
config_bucket = "gs://shark_tank/sd_tuned/configs/"
|
||||
config_name = f"{args.annotation_model}_winograd_{device}.json"
|
||||
full_gs_url = config_bucket + config_name
|
||||
winograd_config_dir = f"{WORKDIR}configs/" + config_name
|
||||
winograd_config_dir = os.path.join(WORKDIR, "configs", config_name)
|
||||
print("Loading Winograd config file from ", winograd_config_dir)
|
||||
download_public_file(full_gs_url, winograd_config_dir, True)
|
||||
return winograd_config_dir
|
||||
@@ -113,7 +113,7 @@ def load_lower_configs():
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
|
||||
|
||||
full_gs_url = config_bucket + config_name
|
||||
lowering_config_dir = f"{WORKDIR}configs/" + config_name
|
||||
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
|
||||
print("Loading lowering config file from ", lowering_config_dir)
|
||||
download_public_file(full_gs_url, lowering_config_dir, True)
|
||||
return lowering_config_dir
|
||||
@@ -136,11 +136,13 @@ def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
|
||||
|
||||
if args.save_annotation:
|
||||
if model_name.split("_")[-1] != "tuned":
|
||||
out_file_path = (
|
||||
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
|
||||
out_file_path = os.path.join(
|
||||
args.annotation_output, model_name + "_tuned_torch.mlir"
|
||||
)
|
||||
else:
|
||||
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
|
||||
out_file_path = os.path.join(
|
||||
args.annotation_output, model_name + "_torch.mlir"
|
||||
)
|
||||
with open(out_file_path, "w") as f:
|
||||
f.write(str(winograd_model))
|
||||
f.close()
|
||||
|
||||
@@ -35,12 +35,6 @@ p.add_argument(
|
||||
help="Path to the image input for img2img/inpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--mask_path",
|
||||
type=str,
|
||||
help="Path to the mask image input for inpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
@@ -97,6 +91,75 @@ p.add_argument(
|
||||
default=0.8,
|
||||
help="the strength of change applied on the given input image for img2img",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Inpainting and Outpainting Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--mask_path",
|
||||
type=str,
|
||||
help="Path to the mask image input for inpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--pixels",
|
||||
type=int,
|
||||
default=128,
|
||||
choices=range(8, 256, 8),
|
||||
help="Number of expended pixels for one direction for outpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--mask_blur",
|
||||
type=int,
|
||||
default=8,
|
||||
choices=range(0, 64),
|
||||
help="Number of blur pixels for outpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--left",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend left for outpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--right",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend right for outpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--top",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend top for outpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--bottom",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend bottom for outpainting",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--noise_q",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Fall-off exponent for outpainting (lower=higher detail) (min=0.0, max=4.0)",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--color_variation",
|
||||
type=float,
|
||||
default=0.05,
|
||||
help="Color variation for outpainting (min=0.0, max=1.0)",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Model Config and Usage Params
|
||||
##############################################################################
|
||||
@@ -193,6 +256,13 @@ p.add_argument(
|
||||
help="The repo-id of hugging face.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--low_cpu_mem_usage",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Use the accelerate package to reduce cpu memory consumption",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
|
||||
@@ -27,11 +27,16 @@ def resource_path(relative_path):
|
||||
|
||||
dark_theme = resource_path("ui/css/sd_dark_theme.css")
|
||||
|
||||
from apps.stable_diffusion.web.ui import txt2img_web, img2img_web
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_web,
|
||||
img2img_web,
|
||||
inpaint_web,
|
||||
outpaint_web,
|
||||
)
|
||||
|
||||
sd_web = gr.TabbedInterface(
|
||||
[txt2img_web, img2img_web],
|
||||
["Text-to-Image", "Image-to-Image"],
|
||||
[txt2img_web, img2img_web, inpaint_web, outpaint_web],
|
||||
["Text-to-Image", "Image-to-Image", "Inpainting", "Outpainting"],
|
||||
css=dark_theme,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
from apps.stable_diffusion.web.ui.txt2img_ui import txt2img_web
|
||||
from apps.stable_diffusion.web.ui.img2img_ui import img2img_web
|
||||
from apps.stable_diffusion.web.ui.inpaint_ui import inpaint_web
|
||||
from apps.stable_diffusion.web.ui.outpaint_ui import outpaint_web
|
||||
|
||||
@@ -145,11 +145,14 @@
|
||||
}
|
||||
|
||||
/* SHARK theme */
|
||||
body {
|
||||
background-color: var(--color-background-primary);
|
||||
}
|
||||
|
||||
/* display in full width for desktop devices */
|
||||
@media (min-width: 1536px)
|
||||
{
|
||||
.gradio-container .contain {
|
||||
.gradio-container {
|
||||
max-width: var(--size-full) !important;
|
||||
}
|
||||
}
|
||||
@@ -158,10 +161,6 @@
|
||||
padding: 0 var(--size-4) !important;
|
||||
}
|
||||
|
||||
.gradio-container {
|
||||
background-color: var(--color-background-primary);
|
||||
}
|
||||
|
||||
.container {
|
||||
background-color: black !important;
|
||||
padding-top: var(--size-5) !important;
|
||||
|
||||
224
apps/stable_diffusion/web/ui/inpaint_ui.py
Normal file
224
apps/stable_diffusion/web/ui/inpaint_ui.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
from pathlib import Path
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import inpaint_inf
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
ckpt_path = (
|
||||
Path(args.ckpt_dir)
|
||||
if args.ckpt_dir
|
||||
else Path(Path.cwd(), "models")
|
||||
)
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
types = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
ckpt_files = ["None"]
|
||||
for extn in types:
|
||||
files = glob.glob(os.path.join(ckpt_path, extn))
|
||||
ckpt_files.extend(files)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {ckpt_path})",
|
||||
value=args.ckpt_loc if args.ckpt_loc else "None",
|
||||
choices=ckpt_files
|
||||
+ [
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
],
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=1,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
init_image = gr.Image(
|
||||
label="Masked Image",
|
||||
source="upload",
|
||||
tool="sketch",
|
||||
type="filepath",
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
label="Scheduler",
|
||||
value="PNDM",
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"DPMSolverMultistep",
|
||||
"EulerAncestralDiscrete",
|
||||
],
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=args.save_metadata_to_json,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384, 786, value=args.height, step=8, label="Height"
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 786, value=args.width, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
with gr.Row():
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => Math.floor(Math.random() * 4294967295)",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(grid=[2])
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
lines=1,
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
kwargs = dict(
|
||||
fn=inpaint_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
init_image,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
],
|
||||
outputs=[gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt.submit(**kwargs)
|
||||
negative_prompt.submit(**kwargs)
|
||||
stable_diffusion.click(**kwargs)
|
||||
260
apps/stable_diffusion/web/ui/outpaint_ui.py
Normal file
260
apps/stable_diffusion/web/ui/outpaint_ui.py
Normal file
@@ -0,0 +1,260 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
from pathlib import Path
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import outpaint_inf
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
ckpt_path = (
|
||||
Path(args.ckpt_dir)
|
||||
if args.ckpt_dir
|
||||
else Path(Path.cwd(), "models")
|
||||
)
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
types = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
) # the tuple of file types
|
||||
ckpt_files = ["None"]
|
||||
for extn in types:
|
||||
files = glob.glob(os.path.join(ckpt_path, extn))
|
||||
ckpt_files.extend(files)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {ckpt_path})",
|
||||
value=args.ckpt_loc if args.ckpt_loc else "None",
|
||||
choices=ckpt_files
|
||||
+ [
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
],
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=1,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
init_image = gr.Image(label="Input Image", type="filepath")
|
||||
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
label="Scheduler",
|
||||
value="PNDM",
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"DPMSolverMultistep",
|
||||
"EulerAncestralDiscrete",
|
||||
],
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=args.save_metadata_to_json,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
pixels = gr.Slider(
|
||||
8,
|
||||
256,
|
||||
value=args.pixels,
|
||||
step=8,
|
||||
label="Pixels to expand",
|
||||
)
|
||||
mask_blur = gr.Slider(
|
||||
0,
|
||||
64,
|
||||
value=args.mask_blur,
|
||||
step=1,
|
||||
label="Mask blur",
|
||||
)
|
||||
with gr.Row():
|
||||
directions = gr.CheckboxGroup(
|
||||
label="Outpainting direction",
|
||||
choices=["left", "right", "up", "down"],
|
||||
value=["left", "right", "up", "down"],
|
||||
)
|
||||
with gr.Row():
|
||||
noise_q = gr.Slider(
|
||||
0.0,
|
||||
4.0,
|
||||
value=1.0,
|
||||
step=0.01,
|
||||
label="Fall-off exponent (lower=higher detail)",
|
||||
)
|
||||
color_variation = gr.Slider(
|
||||
0.0,
|
||||
1.0,
|
||||
value=0.05,
|
||||
step=0.01,
|
||||
label="Color variation",
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384, 786, value=args.height, step=8, label="Height"
|
||||
)
|
||||
width = gr.Slider(
|
||||
384, 786, value=args.width, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=20, step=1, label="Steps"
|
||||
)
|
||||
with gr.Row():
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => Math.floor(Math.random() * 4294967295)",
|
||||
)
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(grid=[2])
|
||||
std_output = gr.Textbox(
|
||||
value="Nothing to show.",
|
||||
lines=1,
|
||||
show_label=False,
|
||||
)
|
||||
output_dir = args.output_dir if args.output_dir else Path.cwd()
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
output_loc = gr.Textbox(
|
||||
label="Saving Images at",
|
||||
value=output_dir,
|
||||
interactive=False,
|
||||
)
|
||||
kwargs = dict(
|
||||
fn=outpaint_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
init_image,
|
||||
pixels,
|
||||
mask_blur,
|
||||
directions,
|
||||
noise_q,
|
||||
color_variation,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
custom_model,
|
||||
hf_model_id,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
],
|
||||
outputs=[gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
)
|
||||
|
||||
prompt.submit(**kwargs)
|
||||
negative_prompt.submit(**kwargs)
|
||||
stable_diffusion.click(**kwargs)
|
||||
10
conftest.py
10
conftest.py
@@ -60,3 +60,13 @@ def pytest_addoption(parser):
|
||||
default="gs://shark_tank/latest",
|
||||
help="URL to bucket from which to download SHARK tank artifacts. Default is gs://shark_tank/latest",
|
||||
)
|
||||
parser.addoption(
|
||||
"--benchmark_dispatches",
|
||||
default=None,
|
||||
help="Benchmark individual dispatch kernels produced by IREE compiler. Use 'All' for all, or specific dispatches e.g. '0 1 2 10'",
|
||||
)
|
||||
parser.addoption(
|
||||
"--dispatch_benchmarks_dir",
|
||||
default="./temp_dispatch_benchmarks",
|
||||
help="Directory in which dispatch benchmarks are saved.",
|
||||
)
|
||||
|
||||
@@ -162,13 +162,13 @@ def save_tf_model(tf_model_list):
|
||||
tf_model_name = tf_model_name.replace("/", "_")
|
||||
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
|
||||
os.makedirs(tf_model_dir, exist_ok=True)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
input,
|
||||
inputs=input,
|
||||
frontend="tf",
|
||||
)
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=False,
|
||||
dir=tf_model_dir,
|
||||
model_name=tf_model_name,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,16 @@ from distutils.sysconfig import get_python_lib
|
||||
import fileinput
|
||||
from pathlib import Path
|
||||
|
||||
# Diffusers 0.13.1 fails with transformers __init.py errros in BLIP. So remove it for now until we fork it
|
||||
pix2pix_file = Path(
|
||||
get_python_lib()
|
||||
+ "/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py"
|
||||
)
|
||||
if pix2pix_file.exists():
|
||||
print("Removing..%s", pix2pix_file)
|
||||
pix2pix_file.unlink()
|
||||
|
||||
|
||||
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")
|
||||
|
||||
modules_to_comment = ["abc,", "os,", "posixpath,", "_collections_abc,"]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||
--pre
|
||||
|
||||
numpy==1.22.4
|
||||
numpy>1.22.4
|
||||
torchvision
|
||||
pytorch-triton
|
||||
tabulate
|
||||
@@ -15,8 +15,8 @@ iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tensorflow==2.10.1
|
||||
keras==2.10
|
||||
tensorflow>=2.10.1
|
||||
keras>=2.10
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
|
||||
@@ -16,7 +16,7 @@ parameterized
|
||||
|
||||
# Add transformers, diffusers and scipy since it most commonly used
|
||||
transformers
|
||||
diffusers @ git+https://github.com/huggingface/diffusers@4c52982a0be7dd850fb9eac55b11509846e4bbe6
|
||||
diffusers @ git+https://github.com/nod-ai/diffusers@2226767529c805d7997d1a9f218437f2c7fb65e1
|
||||
scipy
|
||||
ftfy
|
||||
gradio
|
||||
|
||||
@@ -9,10 +9,10 @@
|
||||
If that environment does not exist, it creates it.
|
||||
|
||||
.PARAMETER update-src
|
||||
updates to latest version from git .\source
|
||||
git pulls latest version
|
||||
|
||||
.PARAMETER force
|
||||
removes and recreates venv to force update all dependencies
|
||||
removes and recreates venv to force update of all dependencies
|
||||
|
||||
.EXAMPLE
|
||||
.\setup_venv.ps1 --force
|
||||
@@ -26,12 +26,6 @@
|
||||
.OUTPUTS
|
||||
None
|
||||
|
||||
.NOTES
|
||||
Version 1.0
|
||||
Author powderluv, xzuyn
|
||||
Creation Date 2023-02-17
|
||||
PurposeChange Initial script development
|
||||
|
||||
#>
|
||||
|
||||
param([string]$arguments)
|
||||
@@ -56,18 +50,6 @@ if ($arguments -eq "--force"){
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#Write-Host "Installing python"
|
||||
|
||||
#Start-Process winget install Python.Python.3.10 '/quiet InstallAllUsers=1 PrependPath=1' -wait -NoNewWindow
|
||||
|
||||
#Write-Host "python installation completed successfully"
|
||||
|
||||
#Write-Host "Reload environment variables"
|
||||
#$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
|
||||
#Write-Host "Reloaded environment variables"
|
||||
|
||||
|
||||
# redirect stderr into stdout
|
||||
$p = &{python -V} 2>&1
|
||||
# check if an ErrorRecord was returned
|
||||
@@ -78,25 +60,36 @@ $version = if($p -is [System.Management.Automation.ErrorRecord])
|
||||
}
|
||||
else
|
||||
{
|
||||
# otherwise return as is
|
||||
$p
|
||||
# otherwise return complete Python list
|
||||
$ErrorActionPreference = 'SilentlyContinue'
|
||||
$PyVer = py --list
|
||||
}
|
||||
|
||||
Write-Host "Python version found is"
|
||||
Write-Host $p
|
||||
if ($p -notlike "*3.11*")
|
||||
# deactivate any activated venvs
|
||||
if ($PyVer -like "*venv*")
|
||||
{
|
||||
deactivate # make sure we don't update the wrong venv
|
||||
$PyVer = py --list # update list
|
||||
}
|
||||
|
||||
Write-Host "Python versions found are"
|
||||
Write-Host ($PyVer | Out-String) # formatted output with line breaks
|
||||
if (!($PyVer.length -ne 0)) {$p} # return Python --version String if py.exe is unavailable
|
||||
if (!($PyVer -like "*3.11*") -and !($p -like "*3.11*")) # if 3.11 is not in any list
|
||||
{
|
||||
Write-Host "Please install Python 3.11 and try again"
|
||||
break
|
||||
}
|
||||
|
||||
Write-Host "Installing Build Dependencies"
|
||||
py -3.11 -m venv .\shark.venv\
|
||||
# make sure we really use 3.11 from list, even if it's not the default.
|
||||
if (!($PyVer.length -ne 0)) {py -3.11 -m venv .\shark.venv\}
|
||||
else {python -m venv .\shark.venv\}
|
||||
.\shark.venv\Scripts\activate
|
||||
python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install -r requirements.txt
|
||||
pip install --pre torch-mlir torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
|
||||
Write-Host "Building SHARK..."
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
|
||||
@@ -98,7 +98,7 @@ if [[ -z "${USE_IREE}" ]]; then
|
||||
RUNTIME="https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html"
|
||||
else
|
||||
touch ./.use-iree
|
||||
RUNTIME="https://iree-org.github.io/iree/pip-release-links.html"
|
||||
RUNTIME="https://openxla.github.io/iree/pip-release-links.html"
|
||||
fi
|
||||
if [[ -z "${NO_BACKEND}" ]]; then
|
||||
echo "Installing ${RUNTIME}..."
|
||||
@@ -112,7 +112,7 @@ if [[ ! -z "${IMPORTER}" ]]; then
|
||||
if [[ $(uname -s) = 'Linux' ]]; then
|
||||
echo "${Yellow}Linux detected.. installing Linux importer tools"
|
||||
#Always get the importer tools from upstream IREE
|
||||
$PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer.txt" -f https://iree-org.github.io/iree/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
$PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer.txt" -f https://openxla.github.io/iree/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
elif [[ $(uname -s) = 'Darwin' ]]; then
|
||||
echo "${Yellow}macOS detected.. installing macOS importer tools"
|
||||
#Conda seems to have some problems installing these packages and hope they get resolved upstream.
|
||||
@@ -129,7 +129,7 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
TV_VERSION=${TV_VER:9:18}
|
||||
$PYTHON -m pip uninstall -y torch torchvision
|
||||
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
|
||||
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl
|
||||
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully Installed torch + cu117."
|
||||
else
|
||||
|
||||
@@ -9,9 +9,11 @@
|
||||
# -d, --download: set to true if you want to redownload the mlir files
|
||||
# -t --token_count: the number of tokens you want to generate
|
||||
# -pr --prompt: the prompt you want to feed to the model
|
||||
# -m --model_namme: the name of the model, e.g. bloom-560m
|
||||
#####################################################################################
|
||||
|
||||
import os
|
||||
import io
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
@@ -22,14 +24,18 @@ from transformers.models.bloom.configuration_bloom import BloomConfig
|
||||
import json
|
||||
import sys
|
||||
import argparse
|
||||
from cuda.cudart import cudaSetDevice
|
||||
import json
|
||||
import urllib.request
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_public_file
|
||||
|
||||
from transformers import (
|
||||
BloomTokenizerFast,
|
||||
BloomForSequenceClassification,
|
||||
BloomForCausalLM,
|
||||
)
|
||||
from transformers.models.bloom.modeling_bloom import (
|
||||
BloomBlock,
|
||||
build_alibi_tensor,
|
||||
@@ -47,16 +53,22 @@ class ShardedBloom:
|
||||
self.layers_initialized = False
|
||||
|
||||
self.src_folder = src_folder
|
||||
self.n_embed = config["n_embed"]
|
||||
try:
|
||||
self.n_embed = config["n_embed"]
|
||||
except KeyError:
|
||||
self.n_embed = config["hidden_size"]
|
||||
self.vocab_size = config["vocab_size"]
|
||||
self.n_layer = config["n_layer"]
|
||||
self.n_head = config["num_attention_heads"]
|
||||
try:
|
||||
self.n_head = config["num_attention_heads"]
|
||||
except KeyError:
|
||||
self.n_head = config["n_head"]
|
||||
|
||||
def _init_layer(self, layer_name, device, replace, device_idx):
|
||||
if replace or not os.path.exists(
|
||||
f"{self.src_folder}/{layer_name}.vmfb"
|
||||
):
|
||||
f_ = open(f"{self.src_folder}/{layer_name}.mlir")
|
||||
f_ = open(f"{self.src_folder}/{layer_name}.mlir", encoding="utf-8")
|
||||
module = f_.read()
|
||||
f_.close()
|
||||
module = bytes(module, "utf-8")
|
||||
@@ -292,90 +304,352 @@ def _prepare_attn_mask(
|
||||
return combined_attention_mask
|
||||
|
||||
|
||||
def download_560m(destination_folder):
|
||||
def download_model(destination_folder, model_name):
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_0.mlir", destination_folder
|
||||
f"https://{model_name}/config.json", destination_folder
|
||||
)
|
||||
f = open(f"{destination_folder}/config.json")
|
||||
config = json.load(f)
|
||||
f.close()
|
||||
n_blocks = config["n_layer"]
|
||||
download_public_file(
|
||||
f"https://{model_name}/lm_head.mlir", destination_folder
|
||||
)
|
||||
download_public_file(f"https://{model_name}/ln_f.mlir", destination_folder)
|
||||
download_public_file(
|
||||
f"https://{model_name}/word_embeddings.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_1.mlir", destination_folder
|
||||
f"https://{model_name}/word_embeddings_layernorm.mlir",
|
||||
destination_folder,
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_2.mlir", destination_folder
|
||||
f"https://{model_name}/tokenizer.json", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_3.mlir", destination_folder
|
||||
|
||||
for i in range(n_blocks):
|
||||
download_public_file(
|
||||
f"https://{model_name}/bloom_block_{i}.mlir", destination_folder
|
||||
)
|
||||
|
||||
|
||||
def compile_embeddings(embeddings_layer, input_ids, path):
|
||||
input_ids_placeholder = torch_mlir.TensorPlaceholder.like(
|
||||
input_ids, dynamic_axes=[1]
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_4.mlir", destination_folder
|
||||
module = torch_mlir.compile(
|
||||
embeddings_layer,
|
||||
(input_ids_placeholder),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_5.mlir", destination_folder
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
f_ = open(path, "w+")
|
||||
f_.write(str(module))
|
||||
f_.close()
|
||||
return
|
||||
|
||||
|
||||
def compile_word_embeddings_layernorm(
|
||||
embeddings_layer_layernorm, embeds, path
|
||||
):
|
||||
embeds_placeholder = torch_mlir.TensorPlaceholder.like(
|
||||
embeds, dynamic_axes=[1]
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_6.mlir", destination_folder
|
||||
module = torch_mlir.compile(
|
||||
embeddings_layer_layernorm,
|
||||
(embeds_placeholder),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_7.mlir", destination_folder
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
f_ = open(path, "w+")
|
||||
f_.write(str(module))
|
||||
f_.close()
|
||||
return
|
||||
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
|
||||
def compile_to_mlir(
|
||||
bblock,
|
||||
hidden_states,
|
||||
layer_past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
use_cache=None,
|
||||
output_attentions=False,
|
||||
alibi=None,
|
||||
block_index=0,
|
||||
path=".",
|
||||
):
|
||||
fx_g = make_fx(
|
||||
bblock,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
tracing_mode="real",
|
||||
_allow_non_fake_inputs=False,
|
||||
)(hidden_states, alibi, attention_mask)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
hidden_states_placeholder = TensorPlaceholder.like(
|
||||
hidden_states, dynamic_axes=[1]
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_8.mlir", destination_folder
|
||||
attention_mask_placeholder = TensorPlaceholder.like(
|
||||
attention_mask, dynamic_axes=[2, 3]
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_9.mlir", destination_folder
|
||||
alibi_placeholder = TensorPlaceholder.like(alibi, dynamic_axes=[2])
|
||||
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
hidden_states_placeholder,
|
||||
alibi_placeholder,
|
||||
attention_mask_placeholder,
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_10.mlir", destination_folder
|
||||
|
||||
module_placeholder = module
|
||||
module_context = module_placeholder.context
|
||||
|
||||
def check_valid_line(line, line_n, mlir_file_len):
|
||||
if "private" in line:
|
||||
return False
|
||||
if "attributes" in line:
|
||||
return False
|
||||
if mlir_file_len - line_n == 2:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
mlir_file_len = len(str(module).split("\n"))
|
||||
|
||||
def remove_constant_dim(line):
|
||||
if "17x" in line:
|
||||
line = re.sub("17x", "?x", line)
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
|
||||
)
|
||||
if "arith.cmpi eq" in line:
|
||||
line = re.sub("c17", "dim", line)
|
||||
if " 17," in line:
|
||||
line = re.sub(" 17,", " %dim,", line)
|
||||
return line
|
||||
|
||||
module = "\n".join(
|
||||
[
|
||||
remove_constant_dim(line)
|
||||
for line, line_n in zip(
|
||||
str(module).split("\n"), range(mlir_file_len)
|
||||
)
|
||||
if check_valid_line(line, line_n, mlir_file_len)
|
||||
]
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_11.mlir", destination_folder
|
||||
|
||||
module = module_placeholder.parse(module, context=module_context)
|
||||
bytecode_stream = io.BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
f_ = open(path, "w+")
|
||||
f_.write(str(module))
|
||||
f_.close()
|
||||
return
|
||||
|
||||
|
||||
def compile_ln_f(ln_f, hidden_layers, path):
|
||||
hidden_layers_placeholder = torch_mlir.TensorPlaceholder.like(
|
||||
hidden_layers, dynamic_axes=[1]
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_12.mlir", destination_folder
|
||||
module = torch_mlir.compile(
|
||||
ln_f,
|
||||
(hidden_layers_placeholder),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_13.mlir", destination_folder
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
f_ = open(path, "w+")
|
||||
f_.write(str(module))
|
||||
f_.close()
|
||||
return
|
||||
|
||||
|
||||
def compile_lm_head(lm_head, hidden_layers, path):
|
||||
hidden_layers_placeholder = torch_mlir.TensorPlaceholder.like(
|
||||
hidden_layers, dynamic_axes=[1]
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_14.mlir", destination_folder
|
||||
module = torch_mlir.compile(
|
||||
lm_head,
|
||||
(hidden_layers_placeholder),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_15.mlir", destination_folder
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
f_ = open(path, "w+")
|
||||
f_.write(str(module))
|
||||
f_.close()
|
||||
return
|
||||
|
||||
|
||||
def create_mlirs(destination_folder, model_name):
|
||||
model_config = "bigscience/" + model_name
|
||||
sample_input_ids = torch.ones([1, 17], dtype=torch.int64)
|
||||
|
||||
urllib.request.urlretrieve(
|
||||
f"https://huggingface.co/bigscience/{model_name}/resolve/main/config.json",
|
||||
filename=f"{destination_folder}/config.json",
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_16.mlir", destination_folder
|
||||
urllib.request.urlretrieve(
|
||||
f"https://huggingface.co/bigscience/bloom/resolve/main/tokenizer.json",
|
||||
filename=f"{destination_folder}/tokenizer.json",
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_17.mlir", destination_folder
|
||||
|
||||
class HuggingFaceLanguage(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = BloomForCausalLM.from_pretrained(model_config)
|
||||
|
||||
def forward(self, tokens):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
class HuggingFaceBlock(torch.nn.Module):
|
||||
def __init__(self, block):
|
||||
super().__init__()
|
||||
self.model = block
|
||||
|
||||
def forward(self, tokens, alibi, attention_mask):
|
||||
output = self.model(
|
||||
hidden_states=tokens,
|
||||
alibi=alibi,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=True,
|
||||
output_attentions=False,
|
||||
)
|
||||
return (output[0], output[1][0], output[1][1])
|
||||
|
||||
model = HuggingFaceLanguage()
|
||||
|
||||
compile_embeddings(
|
||||
model.model.transformer.word_embeddings,
|
||||
sample_input_ids,
|
||||
f"{destination_folder}/word_embeddings.mlir",
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_18.mlir", destination_folder
|
||||
|
||||
inputs_embeds = model.model.transformer.word_embeddings(sample_input_ids)
|
||||
|
||||
compile_word_embeddings_layernorm(
|
||||
model.model.transformer.word_embeddings_layernorm,
|
||||
inputs_embeds,
|
||||
f"{destination_folder}/word_embeddings_layernorm.mlir",
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_19.mlir", destination_folder
|
||||
|
||||
hidden_states = model.model.transformer.word_embeddings_layernorm(
|
||||
inputs_embeds
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_20.mlir", destination_folder
|
||||
|
||||
input_shape = sample_input_ids.size()
|
||||
|
||||
current_sequence_length = hidden_states.shape[1]
|
||||
past_key_values_length = 0
|
||||
past_key_values = tuple([None] * len(model.model.transformer.h))
|
||||
|
||||
attention_mask = torch.ones(
|
||||
(hidden_states.shape[0], current_sequence_length), device="cpu"
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_21.mlir", destination_folder
|
||||
|
||||
alibi = build_alibi_tensor(
|
||||
attention_mask,
|
||||
model.model.transformer.n_head,
|
||||
hidden_states.dtype,
|
||||
"cpu",
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_22.mlir", destination_folder
|
||||
|
||||
causal_mask = _prepare_attn_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/bloom_block_23.mlir", destination_folder
|
||||
|
||||
head_mask = model.model.transformer.get_head_mask(
|
||||
None, model.model.transformer.config.n_layer
|
||||
)
|
||||
download_public_file("https://bloom-560m/config.json", destination_folder)
|
||||
download_public_file("https://bloom-560m/lm_head.mlir", destination_folder)
|
||||
download_public_file("https://bloom-560m/ln_f.mlir", destination_folder)
|
||||
download_public_file(
|
||||
"https://bloom-560m/word_embeddings.mlir", destination_folder
|
||||
output_attentions = model.model.transformer.config.output_attentions
|
||||
|
||||
all_hidden_states = ()
|
||||
|
||||
for i, (block, layer_past) in enumerate(
|
||||
zip(model.model.transformer.h, past_key_values)
|
||||
):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
proxy_model = HuggingFaceBlock(block)
|
||||
|
||||
compile_to_mlir(
|
||||
proxy_model,
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=causal_mask,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=True,
|
||||
output_attentions=output_attentions,
|
||||
alibi=alibi,
|
||||
block_index=i,
|
||||
path=f"{destination_folder}/bloom_block_{i}.mlir",
|
||||
)
|
||||
|
||||
compile_ln_f(
|
||||
model.model.transformer.ln_f,
|
||||
hidden_states,
|
||||
f"{destination_folder}/ln_f.mlir",
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/word_embeddings_layernorm.mlir", destination_folder
|
||||
)
|
||||
download_public_file(
|
||||
"https://bloom-560m/tokenizer.json", destination_folder
|
||||
hidden_states = model.model.transformer.ln_f(hidden_states)
|
||||
compile_lm_head(
|
||||
model.model.lm_head,
|
||||
hidden_states,
|
||||
f"{destination_folder}/lm_head.mlir",
|
||||
)
|
||||
|
||||
|
||||
@@ -387,6 +661,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-c", "--recompile", default=False, type=bool)
|
||||
parser.add_argument("-d", "--download", default=False, type=bool)
|
||||
parser.add_argument("-t", "--token_count", default=10, type=int)
|
||||
parser.add_argument("-m", "--model_name", default="bloom-560m")
|
||||
parser.add_argument(
|
||||
"-pr",
|
||||
"--prompt",
|
||||
@@ -399,8 +674,10 @@ if __name__ == "__main__":
|
||||
|
||||
if args.device == "cuda" and args.device_list is not None:
|
||||
IS_CUDA = True
|
||||
from cuda.cudart import cudaSetDevice
|
||||
if args.download:
|
||||
download_560m(args.model_path)
|
||||
# download_model(args.model_path, args.model_name)
|
||||
create_mlirs(args.model_path, args.model_name)
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BloomConfig
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
|
||||
|
||||
@@ -139,9 +139,14 @@ def run_benchmark_module(benchmark_cl):
|
||||
benchmark_path
|
||||
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
|
||||
bench_result = run_cmd(" ".join(benchmark_cl))
|
||||
print(bench_result)
|
||||
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
|
||||
match = regex_split.search(bench_result)
|
||||
time = float(match.group(1))
|
||||
unit = match.group(3)
|
||||
try:
|
||||
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
|
||||
match = regex_split.search(bench_result)
|
||||
time = float(match.group(1))
|
||||
unit = match.group(3)
|
||||
except AttributeError:
|
||||
regex_split = re.compile("(\d+[.]*\d*)([a-zA-Z]+)")
|
||||
match = regex_split.search(bench_result)
|
||||
time = float(match.group(1))
|
||||
unit = match.group(2)
|
||||
return 1.0 / (time * 0.001)
|
||||
|
||||
@@ -99,6 +99,7 @@ else:
|
||||
print(
|
||||
f"shark_tank local cache is located at {WORKDIR} . You may change this by setting the --local_tank_cache= flag"
|
||||
)
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
|
||||
|
||||
# Checks whether the directory and files exists.
|
||||
|
||||
@@ -137,6 +137,19 @@ class SharkModuleTester:
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
shark_args.local_tank_cache = self.local_tank_cache
|
||||
shark_args.force_update_tank = self.update_tank
|
||||
shark_args.dispatch_benchmarks = self.benchmark_dispatches
|
||||
if self.benchmark_dispatches is not None:
|
||||
_m = self.config["model_name"].split("/")
|
||||
_m.extend([self.config["framework"], str(dynamic), device])
|
||||
_m = "_".join(_m)
|
||||
shark_args.dispatch_benchmarks_dir = os.path.join(
|
||||
self.dispatch_benchmarks_dir,
|
||||
_m,
|
||||
)
|
||||
if not os.path.exists(self.dispatch_benchmarks_dir):
|
||||
os.mkdir(self.dispatch_benchmarks_dir)
|
||||
if not os.path.exists(shark_args.dispatch_benchmarks_dir):
|
||||
os.mkdir(shark_args.dispatch_benchmarks_dir)
|
||||
if "nhcw-nhwc" in self.config["flags"] and not os.path.isfile(
|
||||
".use-iree"
|
||||
):
|
||||
@@ -278,6 +291,12 @@ class SharkModuleTest(unittest.TestCase):
|
||||
"update_tank"
|
||||
)
|
||||
self.module_tester.tank_url = self.pytestconfig.getoption("tank_url")
|
||||
self.module_tester.benchmark_dispatches = self.pytestconfig.getoption(
|
||||
"benchmark_dispatches"
|
||||
)
|
||||
self.module_tester.dispatch_benchmarks_dir = (
|
||||
self.pytestconfig.getoption("dispatch_benchmarks_dir")
|
||||
)
|
||||
|
||||
if config["xfail_cpu"] == "True" and device == "cpu":
|
||||
pytest.xfail(reason=config["xfail_reason"])
|
||||
|
||||
Reference in New Issue
Block a user