Compare commits


69 Commits

Author SHA1 Message Date
Ean Garvey
fd3fdeef07 Update vulkan_utils.py 2024-01-09 12:39:11 -06:00
Ean Garvey
80906155f8 Revert to 20231212 2024-01-09 12:31:38 -06:00
Ean Garvey
1c094d96eb Update vulkan_utils.py 2024-01-09 12:17:37 -06:00
Ean Garvey
94805aa3c6 Update setup_venv.sh 2024-01-08 23:01:51 -06:00
Ean Garvey
cfb494edff Update setup_venv.ps1 2024-01-08 23:01:22 -06:00
Ean Garvey
10294bef33 Update setup.py 2024-01-08 23:01:02 -06:00
Ean Garvey
8051b33005 Update pyproject.toml 2024-01-08 23:00:42 -06:00
Ean Garvey
10af1f25c7 Update setup.py 2024-01-08 22:38:05 -06:00
Ean Garvey
4305153d21 Update pyproject.toml 2024-01-08 22:37:49 -06:00
Ean Garvey
ab5639ff66 Update pyproject.toml 2024-01-08 21:13:09 -06:00
Ean Garvey
4cf0383768 Update setup.py 2024-01-08 21:12:43 -06:00
Ean Garvey
01845d593d Remove linux builds from 1.0 nightly workflow
It seems that the VMs used for these workflows are no longer available. Removing linux builds since publishing .exe is sufficient for the one-shot nightly workflows we trigger for SHARK-1.0
2024-01-08 18:28:09 -06:00
Ean Garvey
5b15ceee35 Move IREE pins for linux. 2024-01-08 18:23:00 -06:00
Ean Garvey
04b21295ee Move IREE pins for windows. 2024-01-08 18:22:09 -06:00
Stefan Kapusniak
dda7e8a163 (Shark-1) UI: Fix 'keyword argument' error in txt2img (#2058)
* Removes the incorrect valid_base_models keyword argument left in text2img_inf
2024-01-08 16:47:15 -06:00
Stefan Kapusniak
7fdd1952ae (Shark 1.0) UI/SD UX improvements for SDXL (#2057)
* SDXL Tab
  * Filter VAEs in dropdown in the same manner as models
  * Set default VAE selection to 'madebyollin/sdxl-vae-fp16-fix'
  * Set default image size to 768x768 to match current Vulkan constraints
* SharkifySDModel Base Unet Model Determination
  * Always use the model_to_run as the base model for unet, if it is in
base_model.json, instead of potentially trying to compile for other base
models.
  * Allow SharkSDPipelines to define a 'favor_base_models' @classmethod,
returning a list of sane base model names for the pipeline. Exclude base
models not in that list from compilation attempts when trying to determine
a base unet model.
  * Add a 'favor_base_models' method for both Normal and SDXL Txt2Img
Pipelines. Define the method as returning 'None' in the base class.
2024-01-06 11:59:18 -06:00
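For illustration, a minimal Python sketch of how a 'favor_base_models' hook like the one described in the commit above could behave; the class names and model identifier here are hypothetical, not the actual SHARK pipeline classes.

```python
# Hypothetical sketch of the 'favor_base_models' hook described above.
# Class names and model identifiers are illustrative, not the SHARK sources.
class SharkSDPipeline:
    @classmethod
    def favor_base_models(cls):
        # The base class returns None: no restriction on base unet candidates.
        return None


class Txt2ImgSDXLPipeline(SharkSDPipeline):
    @classmethod
    def favor_base_models(cls):
        # Only these base models are tried when determining the base unet.
        return ["stabilityai/stable-diffusion-xl-base-1.0"]


def candidate_base_models(pipeline_cls, known_base_models):
    """Filter compilation candidates by the pipeline's favored list."""
    favored = pipeline_cls.favor_base_models()
    if favored is None:
        return known_base_models
    return [m for m in known_base_models if m in favored]
```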
Ean Garvey
0a6f6fad86 (1.0) Fix non-square controlnet dimensions. (#2056)
* Fix non-square controlnet dims

* Hide batch size slider in txt2img
2024-01-04 21:33:55 -06:00
Ean Garvey
6853a33728 Pin iree versions and fix quant matmul flags. (#2055)
Restricts quantized matmul reassociation flags to cpu compiles of llama2 and pins IREE versions for shark 1.0
2024-01-04 14:22:54 -06:00
Stefan Kapusniak
3887d83f5d (Shark 1.0) UI: Upgrade to gradio 4.12.0 and fix breakage (#2051)
* (Shark 1.0) UI: Upgrade to gradio 4.12.0 and fix breakage

* Upgrade Shark 1.0 to gradio 4.12.0
* Add javascript workaround for gradio currently ignoring @media rules in custom CSS. This fixes UI not showing at full width on desktop (>1536px width).

* (Shark 1.0) UI: Re-enable gallery download buttons

* Re-enable gallery download buttons, as this is now working again in Gradio 4.12
2024-01-03 19:00:08 -06:00
Stefan Kapusniak
8d9b5b3afa SD/UI: Merge Lora Selection Boxes, Add LoRA Strength (#2052)
* Merges LoRA selection in the UI into a single selection, rather than
one for LoRAs under ./models and another for Hugging Face Id
* Add LoRA strength to UI and pipeline parameters.
* Add a `--lora_strength` command line argument.
* Bake LoRA strength into .vmfb naming when a LoRA is specified.
* Use LoRA embedded alpha values and (up tensor dimension *
LoRA strength) for final alpha when applying LoRA weights rather
than a hardcoded value of 0.75
* Adds additional cases to the LoRA weight application that are
present for weight application in the Kohya scripts.
* Include lora strength when reading and writing png metadata.
* Allow lora_strength to be set above 1.0 in the UI, so similar effects
to the prior (overdriven alpha) implementation can be obtained.
2024-01-03 18:59:47 -06:00
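One plausible reading of the alpha handling described in the commit above, assuming the Kohya-style convention of scaling by alpha divided by rank and then by the user strength; the function name and signature are hypothetical, not the actual SHARK code.

```python
import torch

def apply_lora_delta(base_weight, lora_up, lora_down, embedded_alpha, lora_strength=1.0):
    """Sketch of merging a LoRA delta with a user-controlled strength.

    Instead of a hardcoded 0.75, the scale combines the alpha embedded in the
    LoRA file, the up tensor's inner dimension (the rank), and the strength
    chosen via the UI / --lora_strength.
    """
    rank = lora_up.shape[1]  # up tensor dimension
    scale = (embedded_alpha / rank) * lora_strength
    return base_weight + scale * (lora_up @ lora_down)

# Example: a rank-4 LoRA applied to an 8x16 weight at strength 1.2.
w = torch.zeros(8, 16)
up, down = torch.randn(8, 4), torch.randn(4, 16)
merged = apply_lora_delta(w, up, down, embedded_alpha=4.0, lora_strength=1.2)
```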
xzuyn
16c03e4b44 fix for TypeError: Image2ImagePipeline.generate_images() missing 1 required positional argument: 'images' (#2053) 2024-01-03 18:54:54 -06:00
Stefan Kapusniak
17dab8334d Setup a separate .shark1.venv for Shark-1.0 (#2043)
* Updates setup_venv.ps1 to create and use ./shark1.venv/
* Updates setup_venv.sh to default to creating ./shark1.venv/
2023-12-16 23:18:29 -06:00
Stefan Kapusniak
f692a012e1 UI: Fixes for Gradio 4.7.1/4.8.0 update (#2024)
* Upgrade Gradio pin from 4.7.1 to 4.8.0.
* Make Nod AI logos visible again.
* Remove image toolbars from png import boxes.
* Set Input Images on img2img, outpaint and upscaler tabs to be upload
only.
* Change Image control to an ImageEditor control for masking on the
inpaint tab. Remove previous height restriction as this hides the
editing controls.
* Move Input Image/Masked Image on img2img, inpaint, outpaint and
upscaler tabs to be the first control on their tabs.
* Remove download buttons from all galleries as they download some
html rather than the image (gradio issue #6595)
* Remove add new row and column from Output Gallery parameters
dataframe.
* Add partial workaround for not being able to select text in the Output
Gallery parameters dataframe (gradio issue #6086)
* Fix uglified formatting of subdirectory selection dropdown, refresh
button, and open folder buttons on the Output Gallery tab.
* Force Output Gallery to use the full width of the Gallery control
for the preview overlay when an image is selected, rather than
an overlay the width of the selected image.
* Fix sendto buttons.
* Reset Inpaint ImageEditor control with the Mask Layer after generation
is complete, as it gets lost if the image was sent to the tab from
another tab rather than being uploaded. Also rework queuing and
progress rendering along this codepath. This doesn't solve the
underlying problem of the Mask Layer being removed, but does get inpaint
fully working with the Gradio update.
2023-12-14 14:56:37 -06:00
Vivek Khandelwal
3cc643b2de Add support for StableLM-3B model (#2019)
* Add support for StableLM-3B model

* Add support for Quantized StableLM-3B model

* Update stablelm_pipeline.py
2023-12-12 22:39:50 +05:30
Phaneesh Barwaria
bf70e80d20 vulkan device id fix (#2028) 2023-12-08 19:00:26 -06:00
Ean Garvey
7159698496 (Studio) Fix controlnet switching. (#2026)
* Fix controlnet switching.

* Fix txt2img + control adapters
2023-12-07 00:52:36 -06:00
gpetters94
7e12d1782a Fix stencil pipeline to use input image (#2027) 2023-12-07 00:25:18 -06:00
Ean Garvey
bb5f133e1c Many UI fixes and controlnet improvements (#2025)
* multi-controlnet UI and perf fixes

* Controlnet fixes
2023-12-06 20:10:06 -06:00
Richard Pastirčák
3af0c6c658 #1843 - Add Export Default settings button (#2016)
* #1843 - Add Export Default settings button

* #1843 reformatting unit tests

---------

Co-authored-by: Richard Pastirčák <richard.pastircak@student.tuke.sk>
2023-12-06 14:58:17 -06:00
Ean Garvey
3322b7264f (vicuna.py) Move enable_tracy_tracing outside of BenchmarkRunInfo (#2011) 2023-12-06 14:57:32 -06:00
Ean Garvey
eeb7bdd143 Fix nodlogo (#2023) 2023-12-06 14:57:16 -06:00
Ean Garvey
2d6f48821d Fix SharkEulerDiscrete (#2022) 2023-12-06 12:25:06 -06:00
Gaurav Shukla
c74b55f24e [ui] Add UI for sharding
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-12-06 17:25:49 +05:30
Elias Joseph
1a723645fb finalized fixes for sharded llama2 2023-12-06 15:35:29 +05:30
Eliasj42
dfdd3b1f78 improved sharded performance and fixed issue with lmhead on rocm (#2008)
* improved sharded performance and fixed issue with lmhead on rocm

* mmap shards + disable sharing of device arrays across devices

* fix device_idx for non-layer vmfbs

* fix time calc for sharded

---------

Co-authored-by: Elias Joseph <elias@nod-labs.com>
Co-authored-by: PhaneeshB <b.phaneesh@gmail.com>
2023-12-05 11:53:44 -08:00
Ean Garvey
6384780d16 Fixes to llama2 cpu compilation and studio UI, schedulers (#2013)
* Fix some issues with defaults

Fixes to llama2 cpu compilation (turns off data tiling for old argmax
mode)

---------

Co-authored-by: Max Dawkins <max.dawkins@gmail.com>
2023-12-05 11:19:19 -05:00
gpetters94
db0c53ae59 Fix zoedepth (#2010) 2023-12-05 04:31:50 -05:00
Ean Garvey
ce9ce3a7c8 (SD) Fix schedulers and multi-controlnet. (#2006)
* (SD) Fixes schedulers if receiving noise preds as numpy arrays

* Fix schedulers and stencil name

* Multicontrolnet fixes
2023-12-05 03:29:18 -06:00
Ean Garvey
d72da3801f (Studio) Update gradio and multicontrolnet UI. (#2001)
* (Studio) Update gradio and multicontrolnet UI.

* Fixes for outputgallery, exe build

* Fix image return types.

* Update Gradio to 4.7.1

* Fix send buttons and hiresfix

* Various bugfixes and SDXL additions.

* More UI fixes and txt2img_sdxl presets.

* Enable SDXL-Turbo and custom models, custom VAE for SDXL

* img2img ui tweaks
2023-12-04 12:37:51 -06:00
Eliasj42
9c50edc664 fixed functionality of sharded vicuna/llama2 (#1982)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-12-04 09:11:52 -08:00
Abhishek Varma
a1b7110550 [SDXL] Add SDXL pipeline to SHARK (#1941)
* [SDXL] Add SDXL pipeline to SHARK

-- This commit adds SDXL pipeline to SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* (SDXL) Fix --ondemand and vae scale factor use, and fix VAE flags.

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-12-02 03:15:15 -06:00
gpetters94
ff15fd74f6 Add multicontrolnet (#1958) 2023-12-01 13:51:20 -06:00
gpetters94
552b2c3ee3 Add controlmode (#1957) 2023-12-01 13:04:47 -06:00
Ean Garvey
795fc33001 Update default compilation flags for data tiling. (#2000)
* Update default CPU compilation flags.

c5a6cdc8dd

52eb7e9b82

tweak CPU iree-compile flags to match upstream changes.

* Add an option for data tiling on SD models.
2023-11-30 17:05:37 -06:00
gpetters94
2910841fe6 Fix an importer issue on Linux (#1986) 2023-11-30 10:50:33 -06:00
Vivek Khandelwal
396a054856 Fix Sharded Falcon-180b 2023-11-30 21:51:57 +05:30
Vivek Khandelwal
5c66948d4f Fix unsharded Falcon pipeline 2023-11-30 21:51:57 +05:30
Ean Garvey
ed3dda94c0 Cleanup xfails in pytest suite. (#1995) 2023-11-29 23:16:15 -06:00
Quinn Dawkins
d31d28b082 [SD] Add flag to collapse reduction dims pre dispatch formation (#1999) 2023-11-30 00:09:17 -05:00
Evan Ruttenberg
78c607e1d3 Fix typo in default_rocm_arch (#1998) 2023-11-29 20:40:56 -05:00
Vivek Khandelwal
666e601dd9 Remove sharding support for non-180B falcon variants 2023-11-27 13:45:13 +05:30
Vivek Khandelwal
ca58908e5b Add Falcon-GPTQ Support for 2-way sharding 2023-11-27 13:45:13 +05:30
Jakub Kuderski
1f5b39f56e [vicuna.py] Add option to enable tracing (#1993)
This makes the program wait for tracy profiler to connect before exiting
and flush profiling data after each token.

I don't know how to select the tracy iree-runtime variant
programmatically -- instead, print an error and exit.
2023-11-24 12:25:03 -08:00
Jakub Kuderski
2da31c4109 [vicuna.py] Rework benchmark statistics calculation (#1992)
- Move statistics out of the main loop
- Add 'end-to-end' numbers
- Switch the main display unit from s to ms
- Start measuring time at 0

The new print format looks like this:
```
Number of iterations: 5
Num tokens: 1 (prompt), 512 (generated), 513 (total)
Prefill: avg. 0.01 ms (stdev 0.00), avg. 97.99 tokens/s
Decode: avg. 4840.44 ms (stdev 28.80), avg. 97.99 tokens/s
Decode end-2-end: avg. 85.78 tokens/s (w/o prompt), avg. 95.98 (w/ prompt)
```
2023-11-23 12:04:03 -05:00
Ean Garvey
da50a16242 Create specified dir if needed during save_mlir and fix vulkan device fetching without URI/ID (#1989) 2023-11-23 01:01:41 -06:00
Stefan Kapusniak
ce38d49f05 Add .mlir to startup shark_tmp cleanup (#1991)
* Add .mlir to the files that are deleted from `./shark_tmp` when studio
is started.
* refactor/rename existing gradio temp file cleanup on startup to be
consistent with a general `./shark_tmp` cleanup
2023-11-22 14:34:28 -06:00
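A minimal sketch of that kind of startup cleanup, assuming a simple glob over ./shark_tmp; the function name and file patterns are assumptions for illustration, not the actual SHARK helper.

```python
from pathlib import Path

def clean_shark_tmp(tmp_dir="./shark_tmp"):
    # Remove leftover gradio temp images and .mlir artifacts on startup.
    # Patterns are illustrative; the real cleanup may match other files too.
    tmp = Path(tmp_dir)
    if not tmp.is_dir():
        return
    for pattern in ("*.png", "*.mlir"):
        for stale in tmp.glob(pattern):
            stale.unlink(missing_ok=True)
```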
PhaneeshB
2f780f0d38 quick fix rocm None device 2023-11-22 21:17:25 +05:30
Ean Garvey
d051c3a4a7 Use clean_device_info() by default and don't write .mlir to /tmp/ (#1984)
* Move clean_device_info to compile_utils

* Update compile_utils.py

* Fix .mlir writes for some user-level permissions

* Fix cases where full URI is given

* Fix conditionals.

* Fix device path handling in vulkan utils.
2023-11-20 13:10:31 -06:00
Ean Garvey
1b11c82c9d Small UI tweaks for chatbot, fix torchvision requirements (#1988)
- add torchvision to setup_venv.ps1 -- we need this for the torchvision::nms that is now a dependency of controlnet features.
- Don't have bad flashy orange updates when using the chatbot
- Don't limit the height of the chatbot -- there are mixed opinions and solutions around this one. I think the default (400) is just way too small and LLMs generate plenty enough to justify matching the output.
2023-11-21 00:09:10 +05:30
gpetters94
80a33d427f Save intermediate values of controlnet (#1981) 2023-11-17 19:05:41 -05:00
Stefan Kapusniak
4125a26294 API/Docs: Fix incorrect cors arguments listing (#1983)
* Replace `api_cors_origin` in the api/koboldcpp doc with the correct
 `api_accept_origin`
2023-11-17 12:29:01 -06:00
Ean Garvey
905d0103ff Revert "Re-enable SD tunings without matmuls. (#1976)" (#1979)
This reverts commit 70817bb50a.
2023-11-17 23:44:33 +05:30
Stefan Kapusniak
192b3b2c61 UI: Output gallery cleanups (#1959)
* Workaround gradio bug that causes the parameters frame to always show
scrollbars.
* Remove the original funky method of setting the number of image
columns in the gallery using _fn= javascript events. The version
of gradio we now have pinned allows doing this by setting the property
on the gallery directly and also doesn't keep resetting the columns on
other events being fired.
2023-11-15 22:20:42 -06:00
Stefan Kapusniak
8f9adc4a2a UI: Display top tag frequencies for selected LoRA (#1972)
* Adds a function to webui utils to read metadata from
.safetensors LoRA files, and do limited parsing of the format written
out by the Kohya SS scripts (https://github.com/kohya-ss/sd-scripts)
to get tag frequency and trained model information.
* Adds a new common_ui_events.py file for gradio event handlers
needed for multiple UI tabs, and adds an event handler for binding to
the change event of the LoRA selection boxes, that outputs HTML
to display the LoRA tag frequency and model information.
* Adds an HTML gradio control to each of the SD tabs to show the
LoRA model name, and most frequently trained tags.
* Bind the change event of the LoRA selection box on each tab
to our new event handler, with the output set to the relevant HTML
control.
2023-11-15 22:19:54 -06:00
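For context, a .safetensors file starts with an 8-byte little-endian header length followed by a JSON header whose optional '__metadata__' entry stores string key/value pairs; a minimal sketch of pulling Kohya-style tag frequencies from it might look like the following (the helper names are hypothetical; 'ss_tag_frequency' is the key the Kohya scripts write).

```python
import json
import struct
from collections import Counter

def read_safetensors_metadata(path):
    """Read the '__metadata__' dict from a .safetensors file header."""
    with open(path, "rb") as f:
        header_len = struct.unpack("<Q", f.read(8))[0]  # little-endian u64
        header = json.loads(f.read(header_len))
    return header.get("__metadata__", {})

def top_lora_tags(metadata, n=10):
    """Aggregate Kohya 'ss_tag_frequency' metadata ({dataset: {tag: count}})
    and return the n most frequently trained tags."""
    counts = Counter()
    for per_dataset in json.loads(metadata.get("ss_tag_frequency", "{}")).values():
        counts.update(per_dataset)
    return counts.most_common(n)
```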
Ean Garvey
70817bb50a Re-enable SD tunings without matmuls. (#1976) 2023-11-15 20:42:53 -06:00
jinchen62
dd37c26d36 Update brevitas quant api (#1975) 2023-11-15 10:04:07 -08:00
PhaneeshB
a708879c6c fix iree version mismatch 2023-11-15 01:24:42 +05:30
Ean Garvey
bb1b49eb6f Add --no-index to setup_venv.sh runtime pip install. 2023-11-14 21:44:20 +05:30
Ean Garvey
f6d41affd9 (SHARK Studio) Add Turbine-based llm chatbot. (#1933)
* Dan shark studio (#1970)

* Fix issue in Falcon-GPTQ

* initial webui and llama2

---------

Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>

* Fix formatting.

---------

Co-authored-by: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2023-11-14 09:56:28 -06:00
78 changed files with 6452 additions and 1710 deletions

View File

@@ -74,80 +74,3 @@ jobs:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
linux-build:
runs-on: a100
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]
backend: [IREE, SHARK]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Setup pip cache
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
run: |
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude shark.venv,lit.cfg.py
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
- name: Build and validate the IREE package
if: ${{ matrix.backend == 'IREE' }}
continue-on-error: true
run: |
cd $GITHUB_WORKSPACE
USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
source iree.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://openxla.github.io/iree/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
/bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt
if !(grep -Fxq " failed" pytest_results.txt)
then
export SHA=$(git log -1 --format='%h')
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
fi
rm -rf ./wheelhouse/nodai*
- name: Build and validate the SHARK Runtime package
if: ${{ matrix.backend == 'SHARK' }}
run: |
cd $GITHUB_WORKSPACE
./setup_venv.sh
source shark.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
pytest --ci --ci_sha=${SHORT_SHA} -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt

View File

@@ -112,7 +112,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
pytest --benchmark=native --update_tank -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
python build_tools/vicuna_testing.py
@@ -121,9 +121,9 @@ jobs:
if: matrix.suite == 'cuda'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
pytest --benchmark=native --update_tank -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
@@ -144,10 +144,10 @@ jobs:
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
pytest --update_tank -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan --no-exit_on_fail
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'

View File

@@ -25,7 +25,7 @@ from apps.stable_diffusion.src import args
# Brevitas
from typing import List, Tuple
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
@@ -101,7 +101,7 @@ class H2OGPTModel(torch.nn.Module):
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=128,

File diff suppressed because it is too large.

View File

@@ -69,91 +69,7 @@ class CompiledLMHeadEmbeddingLayer(torch.nn.Module):
return torch.tensor(new_hidden_states)
class DecoderLayer(torch.nn.Module):
def __init__(self, decoder_layer_model, falcon_variant):
super().__init__()
self.model = decoder_layer_model
def forward(self, hidden_states, attention_mask):
output = self.model.forward(
hidden_states=hidden_states,
alibi=None,
attention_mask=attention_mask,
use_cache=True,
)
return (output[0], output[1][0], output[1][1])
class CompiledDecoderLayer(torch.nn.Module):
def __init__(
self, layer_id, device_idx, falcon_variant, device, precision
):
super().__init__()
self.layer_id = layer_id
self.device_index = device_idx
self.falcon_variant = falcon_variant
self.device = device
self.precision = precision
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
alibi: torch.Tensor = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
import gc
torch.cuda.empty_cache()
gc.collect()
from pathlib import Path
from apps.language_models.utils import get_vmfb_from_path
self.falcon_vmfb_path = Path(
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
)
print("vmfb path for layer: ", self.falcon_vmfb_path)
self.model = get_vmfb_from_path(
self.falcon_vmfb_path,
self.device,
"linalg",
device_id=self.device_index,
)
if self.model is None:
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")
else:
new_hidden_states, pkv1, pkv2 = self.model(
"forward",
(
hidden_states,
attention_mask,
),
)
del self.model
return tuple(
[
torch.tensor(new_hidden_states),
tuple(
[
torch.tensor(pkv1),
torch.tensor(pkv2),
]
),
]
)
class EightDecoderLayer(torch.nn.Module):
class FourWayShardingDecoderLayer(torch.nn.Module):
def __init__(self, decoder_layer_model, falcon_variant):
super().__init__()
self.model = decoder_layer_model
@@ -175,163 +91,78 @@ class EightDecoderLayer(torch.nn.Module):
outputs[-1][1],
)
)
if self.falcon_variant == "7b":
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
)
elif self.falcon_variant == "40b":
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
)
elif self.falcon_variant == "180b":
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
(new_pkv150, new_pkv151),
(new_pkv160, new_pkv161),
(new_pkv170, new_pkv171),
(new_pkv180, new_pkv181),
(new_pkv190, new_pkv191),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
new_pkv150,
new_pkv151,
new_pkv160,
new_pkv161,
new_pkv170,
new_pkv171,
new_pkv180,
new_pkv181,
new_pkv190,
new_pkv191,
)
else:
raise ValueError(
"Unsupported Falcon variant: ", self.falcon_variant
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
(new_pkv150, new_pkv151),
(new_pkv160, new_pkv161),
(new_pkv170, new_pkv171),
(new_pkv180, new_pkv181),
(new_pkv190, new_pkv191),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
new_pkv150,
new_pkv151,
new_pkv160,
new_pkv161,
new_pkv170,
new_pkv171,
new_pkv180,
new_pkv181,
new_pkv190,
new_pkv191,
)
return result
class CompiledEightDecoderLayer(torch.nn.Module):
class CompiledFourWayShardingDecoderLayer(torch.nn.Module):
def __init__(
self, layer_id, device_idx, falcon_variant, device, precision
self, layer_id, device_idx, falcon_variant, device, precision, model
):
super().__init__()
self.layer_id = layer_id
@@ -339,12 +170,14 @@ class CompiledEightDecoderLayer(torch.nn.Module):
self.falcon_variant = falcon_variant
self.device = device
self.precision = precision
self.model = model
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
alibi: torch.Tensor = None,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
@@ -354,24 +187,12 @@ class CompiledEightDecoderLayer(torch.nn.Module):
torch.cuda.empty_cache()
gc.collect()
from pathlib import Path
from apps.language_models.utils import get_vmfb_from_path
self.falcon_vmfb_path = Path(
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
)
print("vmfb path for layer: ", self.falcon_vmfb_path)
self.model = get_vmfb_from_path(
self.falcon_vmfb_path,
self.device,
"linalg",
device_id=self.device_index,
)
if self.model is None:
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.detach().numpy()
attention_mask = attention_mask.to(torch.float32).detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")
@@ -383,196 +204,452 @@ class CompiledEightDecoderLayer(torch.nn.Module):
attention_mask,
),
)
del self.model
if self.falcon_variant == "7b":
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
(
torch.tensor(output[31]),
torch.tensor(output[32]),
),
(
torch.tensor(output[33]),
torch.tensor(output[34]),
),
(
torch.tensor(output[35]),
torch.tensor(output[36]),
),
(
torch.tensor(output[37]),
torch.tensor(output[38]),
),
(
torch.tensor(output[39]),
torch.tensor(output[40]),
),
)
return result
class TwoWayShardingDecoderLayer(torch.nn.Module):
def __init__(self, decoder_layer_model, falcon_variant):
super().__init__()
self.model = decoder_layer_model
self.falcon_variant = falcon_variant
def forward(self, hidden_states, attention_mask):
new_pkvs = []
for layer in self.model:
outputs = layer(
hidden_states=hidden_states,
alibi=None,
attention_mask=attention_mask,
use_cache=True,
)
elif self.falcon_variant == "40b":
result = (
torch.tensor(output[0]),
hidden_states = outputs[0]
new_pkvs.append(
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
)
elif self.falcon_variant == "180b":
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
(
torch.tensor(output[31]),
torch.tensor(output[32]),
),
(
torch.tensor(output[33]),
torch.tensor(output[34]),
),
(
torch.tensor(output[35]),
torch.tensor(output[36]),
),
(
torch.tensor(output[37]),
torch.tensor(output[38]),
),
(
torch.tensor(output[39]),
torch.tensor(output[40]),
),
outputs[-1][0],
outputs[-1][1],
)
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
(new_pkv150, new_pkv151),
(new_pkv160, new_pkv161),
(new_pkv170, new_pkv171),
(new_pkv180, new_pkv181),
(new_pkv190, new_pkv191),
(new_pkv200, new_pkv201),
(new_pkv210, new_pkv211),
(new_pkv220, new_pkv221),
(new_pkv230, new_pkv231),
(new_pkv240, new_pkv241),
(new_pkv250, new_pkv251),
(new_pkv260, new_pkv261),
(new_pkv270, new_pkv271),
(new_pkv280, new_pkv281),
(new_pkv290, new_pkv291),
(new_pkv300, new_pkv301),
(new_pkv310, new_pkv311),
(new_pkv320, new_pkv321),
(new_pkv330, new_pkv331),
(new_pkv340, new_pkv341),
(new_pkv350, new_pkv351),
(new_pkv360, new_pkv361),
(new_pkv370, new_pkv371),
(new_pkv380, new_pkv381),
(new_pkv390, new_pkv391),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
new_pkv150,
new_pkv151,
new_pkv160,
new_pkv161,
new_pkv170,
new_pkv171,
new_pkv180,
new_pkv181,
new_pkv190,
new_pkv191,
new_pkv200,
new_pkv201,
new_pkv210,
new_pkv211,
new_pkv220,
new_pkv221,
new_pkv230,
new_pkv231,
new_pkv240,
new_pkv241,
new_pkv250,
new_pkv251,
new_pkv260,
new_pkv261,
new_pkv270,
new_pkv271,
new_pkv280,
new_pkv281,
new_pkv290,
new_pkv291,
new_pkv300,
new_pkv301,
new_pkv310,
new_pkv311,
new_pkv320,
new_pkv321,
new_pkv330,
new_pkv331,
new_pkv340,
new_pkv341,
new_pkv350,
new_pkv351,
new_pkv360,
new_pkv361,
new_pkv370,
new_pkv371,
new_pkv380,
new_pkv381,
new_pkv390,
new_pkv391,
)
return result
class CompiledTwoWayShardingDecoderLayer(torch.nn.Module):
def __init__(
self, layer_id, device_idx, falcon_variant, device, precision, model
):
super().__init__()
self.layer_id = layer_id
self.device_index = device_idx
self.falcon_variant = falcon_variant
self.device = device
self.precision = precision
self.model = model
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
import gc
torch.cuda.empty_cache()
gc.collect()
if self.model is None:
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.to(torch.float32).detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")
else:
raise ValueError(
"Unsupported Falcon variant: ", self.falcon_variant
output = self.model(
"forward",
(
hidden_states,
attention_mask,
),
)
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
(
torch.tensor(output[31]),
torch.tensor(output[32]),
),
(
torch.tensor(output[33]),
torch.tensor(output[34]),
),
(
torch.tensor(output[35]),
torch.tensor(output[36]),
),
(
torch.tensor(output[37]),
torch.tensor(output[38]),
),
(
torch.tensor(output[39]),
torch.tensor(output[40]),
),
(
torch.tensor(output[41]),
torch.tensor(output[42]),
),
(
torch.tensor(output[43]),
torch.tensor(output[44]),
),
(
torch.tensor(output[45]),
torch.tensor(output[46]),
),
(
torch.tensor(output[47]),
torch.tensor(output[48]),
),
(
torch.tensor(output[49]),
torch.tensor(output[50]),
),
(
torch.tensor(output[51]),
torch.tensor(output[52]),
),
(
torch.tensor(output[53]),
torch.tensor(output[54]),
),
(
torch.tensor(output[55]),
torch.tensor(output[56]),
),
(
torch.tensor(output[57]),
torch.tensor(output[58]),
),
(
torch.tensor(output[59]),
torch.tensor(output[60]),
),
(
torch.tensor(output[61]),
torch.tensor(output[62]),
),
(
torch.tensor(output[63]),
torch.tensor(output[64]),
),
(
torch.tensor(output[65]),
torch.tensor(output[66]),
),
(
torch.tensor(output[67]),
torch.tensor(output[68]),
),
(
torch.tensor(output[69]),
torch.tensor(output[70]),
),
(
torch.tensor(output[71]),
torch.tensor(output[72]),
),
(
torch.tensor(output[73]),
torch.tensor(output[74]),
),
(
torch.tensor(output[75]),
torch.tensor(output[76]),
),
(
torch.tensor(output[77]),
torch.tensor(output[78]),
),
(
torch.tensor(output[79]),
torch.tensor(output[80]),
),
)
return result

View File

@@ -5,7 +5,7 @@ from typing import List, Any
from transformers import StoppingCriteria
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
@@ -37,7 +37,7 @@ class VisionModel(torch.nn.Module):
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -52,7 +52,7 @@ class VisionModel(torch.nn.Module):
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -93,7 +93,7 @@ class FirstLlamaModel(torch.nn.Module):
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -157,7 +157,7 @@ class SecondLlamaModel(torch.nn.Module):
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,

View File

@@ -24,7 +24,9 @@ class FirstVicuna(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -36,7 +38,7 @@ class FirstVicuna(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -79,7 +81,9 @@ class SecondVicuna7B(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -91,7 +95,7 @@ class SecondVicuna7B(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -329,7 +333,9 @@ class SecondVicuna13B(torch.nn.Module):
torch.float32 if accumulates == "fp32" else torch.float16
)
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -341,7 +347,7 @@ class SecondVicuna13B(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -627,7 +633,9 @@ class SecondVicuna70B(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -639,7 +647,7 @@ class SecondVicuna70B(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,

View File

@@ -24,7 +24,9 @@ class FirstVicunaGPU(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -36,7 +38,7 @@ class FirstVicunaGPU(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -78,7 +80,9 @@ class SecondVicuna7BGPU(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -90,7 +94,7 @@ class SecondVicuna7BGPU(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -327,7 +331,9 @@ class SecondVicuna13BGPU(torch.nn.Module):
torch.float32 if accumulates == "fp32" else torch.float16
)
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -339,7 +345,7 @@ class SecondVicuna13BGPU(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -625,7 +631,9 @@ class SecondVicuna70BGPU(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -637,7 +645,7 @@ class SecondVicuna70BGPU(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,

View File

@@ -1,4 +1,5 @@
import torch
import time
class FirstVicunaLayer(torch.nn.Module):
@@ -66,7 +67,6 @@ class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers, lmhead, embedding, norm):
super().__init__()
self.model = model
# assert len(layers) == len(model.model.layers)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers = layers
@@ -110,9 +110,11 @@ class LMHeadCompiled(torch.nn.Module):
self.model = shark_module
def forward(self, hidden_states):
hidden_states = hidden_states.detach()
hidden_states_sample = hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output
@@ -136,8 +138,9 @@ class VicunaNormCompiled(torch.nn.Module):
hidden_states.detach()
except:
pass
output = self.model("forward", (hidden_states,))
output = self.model("forward", (hidden_states,), send_to_host=True)
output = torch.tensor(output)
return output
@@ -158,15 +161,18 @@ class VicunaEmbeddingCompiled(torch.nn.Module):
def forward(self, input_ids):
input_ids.detach()
output = self.model("forward", (input_ids,))
output = self.model("forward", (input_ids,), send_to_host=True)
output = torch.tensor(output)
return output
class CompiledVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
def __init__(self, shark_module, idx, breakpoints):
super().__init__()
self.model = shark_module
self.idx = idx
self.breakpoints = breakpoints
def forward(
self,
@@ -177,10 +183,11 @@ class CompiledVicunaLayer(torch.nn.Module):
output_attentions=False,
use_cache=True,
):
if self.breakpoints is None:
is_breakpoint = False
else:
is_breakpoint = self.idx + 1 in self.breakpoints
if past_key_value is None:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
output = self.model(
"first_vicuna_forward",
(
@@ -188,11 +195,17 @@ class CompiledVicunaLayer(torch.nn.Module):
attention_mask,
position_ids,
),
send_to_host=is_breakpoint,
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
if is_breakpoint:
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
else:
output0 = output[0]
output1 = output[1]
output2 = output[2]
return (
output0,
@@ -202,11 +215,8 @@ class CompiledVicunaLayer(torch.nn.Module):
),
)
else:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv0 = past_key_value[0].detach()
pkv1 = past_key_value[1].detach()
pkv0 = past_key_value[0]
pkv1 = past_key_value[1]
output = self.model(
"second_vicuna_forward",
(
@@ -216,11 +226,17 @@ class CompiledVicunaLayer(torch.nn.Module):
pkv0,
pkv1,
),
send_to_host=is_breakpoint,
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
if is_breakpoint:
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
else:
output0 = output[0]
output1 = output[1]
output2 = output[2]
return (
output0,

View File

@@ -6,10 +6,10 @@ from apps.language_models.src.model_wrappers.falcon_sharded_model import (
CompiledLNFEmbeddingLayer,
LMHeadEmbeddingLayer,
CompiledLMHeadEmbeddingLayer,
DecoderLayer,
EightDecoderLayer,
CompiledDecoderLayer,
CompiledEightDecoderLayer,
FourWayShardingDecoderLayer,
TwoWayShardingDecoderLayer,
CompiledFourWayShardingDecoderLayer,
CompiledTwoWayShardingDecoderLayer,
ShardedFalconModel,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
@@ -94,6 +94,13 @@ parser.add_argument(
action=argparse.BooleanOptionalAction,
help="Run model as sharded",
)
parser.add_argument(
"--num_shards",
type=int,
default=4,
choices=[2, 4],
help="Number of shards.",
)
class ShardedFalcon(SharkLLMBase):
@@ -122,6 +129,10 @@ class ShardedFalcon(SharkLLMBase):
--hf_auth_token flag. You can ask for the access to the model
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
)
if args.sharded and "180b" not in self.model_name:
raise ValueError("Sharding supported only for Falcon-180B")
self.hf_auth_token = hf_auth_token
self.max_padding_length = 100
self.device = device
@@ -131,7 +142,7 @@ class ShardedFalcon(SharkLLMBase):
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.src_model = self.get_src_model()
self.shark_model = self.compile(compressed=args.compressed)
self.shark_model = self.compile()
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
@@ -146,20 +157,17 @@ class ShardedFalcon(SharkLLMBase):
def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {
"torch_dtype": torch.float,
"torch_dtype": torch.float32,
"trust_remote_code": True,
"token": self.hf_auth_token,
}
if self.precision == "int4":
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
if self.precision == "int4":
falcon_model = falcon_model.to(torch.float32)
return falcon_model
def compile_layer(
@@ -225,7 +233,7 @@ class ShardedFalcon(SharkLLMBase):
elif layer_id in ["ln_f", "lm_head"]:
f16_input_mask = [True]
elif "_" in layer_id or type(layer_id) == int:
f16_input_mask = [True, False]
f16_input_mask = [True, True]
else:
raise ValueError("Unsupported layer: ", layer_id)
@@ -288,28 +296,16 @@ class ShardedFalcon(SharkLLMBase):
return shark_module, device_idx
def compile(self, compressed=False):
def compile(self):
sample_input_ids = torch.zeros([100], dtype=torch.int64)
sample_attention_mask = torch.zeros(
[1, 1, 100, 100], dtype=torch.float32
)
num_group_layers = 1
if "7b" in self.model_name:
num_in_features = 4544
if compressed:
num_group_layers = 8
elif "40b" in self.model_name:
num_in_features = 8192
if compressed:
num_group_layers = 15
else:
num_in_features = 14848
sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
if compressed:
num_group_layers = 20
num_group_layers = int(
20 * (4 / args.num_shards)
) # 4 is the number of default shards
sample_hidden_states = torch.zeros(
[1, 100, num_in_features], dtype=torch.float32
[1, 100, 14848], dtype=torch.float32
)
# Determine number of available devices
@@ -319,6 +315,10 @@ class ShardedFalcon(SharkLLMBase):
haldriver = ireert.get_driver(self.device)
num_devices = len(haldriver.query_available_devices())
if num_devices < 2:
raise ValueError(
"Cannot run Falcon-180B on a single ROCM device."
)
lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
print("Compiling Layer lm_head")
@@ -326,7 +326,9 @@ class ShardedFalcon(SharkLLMBase):
lm_head,
[sample_hidden_states],
"lm_head",
device_idx=0 % num_devices if self.device == "rocm" else None,
device_idx=(0 % num_devices) % args.num_shards
if self.device == "rocm"
else None,
)
shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
@@ -338,7 +340,9 @@ class ShardedFalcon(SharkLLMBase):
word_embedding,
[sample_input_ids],
"word_embeddings",
device_idx=1 % num_devices if self.device == "rocm" else None,
device_idx=(1 % num_devices) % args.num_shards
if self.device == "rocm"
else None,
)
shark_word_embedding = CompiledWordEmbeddingsLayer(
shark_word_embedding
@@ -350,7 +354,9 @@ class ShardedFalcon(SharkLLMBase):
ln_f,
[sample_hidden_states],
"ln_f",
device_idx=2 % num_devices if self.device == "rocm" else None,
device_idx=(2 % num_devices) % args.num_shards
if self.device == "rocm"
else None,
)
shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
@@ -360,24 +366,21 @@ class ShardedFalcon(SharkLLMBase):
):
device_idx = i % num_devices if self.device == "rocm" else None
layer_id = i
pytorch_class = DecoderLayer
compiled_class = CompiledDecoderLayer
if compressed:
layer_id = (
str(i * num_group_layers)
+ "_"
+ str((i + 1) * num_group_layers)
)
pytorch_class = EightDecoderLayer
compiled_class = CompiledEightDecoderLayer
layer_id = (
str(i * num_group_layers)
+ "_"
+ str((i + 1) * num_group_layers)
)
pytorch_class = FourWayShardingDecoderLayer
compiled_class = CompiledFourWayShardingDecoderLayer
if args.num_shards == 2:
pytorch_class = TwoWayShardingDecoderLayer
compiled_class = CompiledTwoWayShardingDecoderLayer
print("Compiling Layer {}".format(layer_id))
if compressed:
layer_i = self.src_model.transformer.h[
i * num_group_layers : (i + 1) * num_group_layers
]
else:
layer_i = self.src_model.transformer.h[i]
layer_i = self.src_model.transformer.h[
i * num_group_layers : (i + 1) * num_group_layers
]
pytorch_layer_i = pytorch_class(
layer_i, args.falcon_variant_to_use
@@ -388,13 +391,13 @@ class ShardedFalcon(SharkLLMBase):
layer_id,
device_idx=device_idx,
)
del shark_module
shark_layer_i = compiled_class(
layer_id,
device_idx,
args.falcon_variant_to_use,
self.device,
self.precision,
shark_module,
)
shark_layers.append(shark_layer_i)
@@ -668,20 +671,17 @@ class UnshardedFalcon(SharkLLMBase):
def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {
"torch_dtype": torch.float,
"torch_dtype": torch.float32,
"trust_remote_code": True,
"token": self.hf_auth_token,
}
if self.precision == "int4":
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
if self.precision == "int4":
falcon_model = falcon_model.to(torch.float32)
return falcon_model
def compile(self):

View File

@@ -132,7 +132,7 @@ import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from typing import List, Tuple
from io import BytesIO
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl

View File

@@ -4,13 +4,49 @@ from transformers import AutoTokenizer, StoppingCriteria, AutoModelForCausalLM
from io import BytesIO
from pathlib import Path
from apps.language_models.utils import (
get_torch_mlir_module_bytecode,
get_vmfb_from_path,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.src.model_wrappers.stablelm_model import (
StableLMModel,
)
import argparse
parser = argparse.ArgumentParser(
prog="stablelm runner",
description="runs a StableLM model",
)
parser.add_argument(
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
"--stablelm_vmfb_path", default=None, help="path to StableLM's vmfb"
)
parser.add_argument(
"--stablelm_mlir_path",
default=None,
help="path to StableLM's mlir file",
)
parser.add_argument(
"--use_precompiled_model",
default=True,
action=argparse.BooleanOptionalAction,
help="use the precompiled vmfb",
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=True,
action=argparse.BooleanOptionalAction,
help="download precompile mlir from shark tank",
)
parser.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication token for stablelm-3B model.",
)
class StopOnTokens(StoppingCriteria):
@@ -29,7 +65,7 @@ class SharkStableLM(SharkLLMBase):
self,
model_name,
hf_model_path="stabilityai/stablelm-tuned-alpha-3b",
max_num_tokens=512,
max_num_tokens=256,
device="cuda",
precision="fp32",
debug="False",
@@ -37,6 +73,14 @@ class SharkStableLM(SharkLLMBase):
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_len = 256
self.device = device
if precision != "int4" and args.hf_auth_token == None:
raise ValueError(
""" HF auth token required for StableLM-3B. Pass it using
--hf_auth_token flag. You can ask for the access to the model
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
)
self.hf_auth_token = args.hf_auth_token
self.precision = precision
self.debug = debug
self.tokenizer = self.get_tokenizer()
@@ -50,9 +94,23 @@ class SharkStableLM(SharkLLMBase):
return False
def get_src_model(self):
kwargs = {}
if self.precision == "int4":
self.hf_model_path = "TheBloke/stablelm-zephyr-3b-GPTQ"
from transformers import GPTQConfig
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["device_map"] = "cpu"
print("[DEBUG] Loading Model")
model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, torch_dtype=torch.float32
self.hf_model_path,
trust_remote_code=True,
torch_dtype=torch.float32,
use_auth_token=self.hf_auth_token,
**kwargs,
)
print("[DEBUG] Model loaded successfully")
return model
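# Note: the int4 branch above relies on the CPU GPTQ path in
# transformers/optimum -- exllama kernels are GPU-only, which is presumably why
# disable_exllama=True and device_map="cpu" are set before pulling the
# TheBloke/stablelm-zephyr-3b-GPTQ checkpoint; torch_dtype=torch.float32 keeps
# computation in a dtype the CPU path supports.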
def get_model_inputs(self):
@@ -61,9 +119,7 @@ class SharkStableLM(SharkLLMBase):
return input_ids, attention_mask
def compile(self):
tmp_model_name = (
f"stableLM_linalg_{self.precision}_seqLen{self.max_sequence_len}"
)
tmp_model_name = f"{self.model_name}_linalg_{self.precision}_seqLen{self.max_sequence_len}"
# device = "cuda" # "cpu"
# TODO: vmfb and mlir name should include precision and device
@@ -83,13 +139,19 @@ class SharkStableLM(SharkLLMBase):
print(
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
if not mlir_path.exists():
model = StableLMModel(self.get_src_model())
model_inputs = self.get_model_inputs()
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
from shark.shark_importer import import_with_fx
ts_graph = import_with_fx(
model,
model_inputs,
is_f16=True if self.precision in ["fp16"] else False,
precision=self.precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
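# Note: f16_input_mask=[False, False] presumably marks that neither model
# input (input_ids, attention_mask) should be cast to fp16 even when the
# weights are, since both are integer tensors.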
module = torch_mlir.compile(
ts_graph,
[*model_inputs],
@@ -100,15 +162,16 @@ class SharkStableLM(SharkLLMBase):
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(tmp_model_name + ".mlir", "wb")
f_.write(bytecode)
print("Saved mlir")
f_.close()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
print("Saved mlir at: ", mlir_path)
f_.close()
del bytecode
from shark.shark_inference import SharkInference
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
mlir_module=mlir_path, device=self.device, mlir_dialect="tm_tensor"
)
shark_module.compile()
@@ -120,14 +183,22 @@ class SharkStableLM(SharkLLMBase):
return shark_module
def get_tokenizer(self):
tok = AutoTokenizer.from_pretrained(self.hf_model_path)
tok = AutoTokenizer.from_pretrained(
self.hf_model_path,
use_auth_token=self.hf_auth_token,
)
tok.add_special_tokens({"pad_token": "<PAD>"})
# print("[DEBUG] Sucessfully loaded the tokenizer to the memory")
return tok
def generate(self, prompt):
words_list = []
import time
start = time.time()
count = 0
for i in range(self.max_num_tokens):
count = count + 1
params = {
"new_text": prompt,
}
@@ -145,6 +216,12 @@ class SharkStableLM(SharkLLMBase):
if detok == "":
break
prompt = prompt + detok
end = time.time()
print(
"\n\nTime taken is {:.2f} tokens/second\n".format(
count / (end - start)
)
)
return words_list
def generate_new_token(self, params):
@@ -178,10 +255,46 @@ class SharkStableLM(SharkLLMBase):
return ret_dict
# Initialize a StopOnTokens object
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
if __name__ == "__main__":
args = parser.parse_args()
stable_lm = SharkStableLM(
model_name="stablelm_zephyr_3b",
hf_model_path="stabilityai/stablelm-zephyr-3b",
device=args.device,
precision=args.precision,
)
default_prompt_text = "The weather is always wonderful"
continue_execution = True
print("\n-----\nScript executing for the following config: \n")
print("StableLM Model: ", stable_lm.hf_model_path)
print("Precision: ", args.precision)
print("Device: ", args.device)
while continue_execution:
use_default_prompt = input(
"\nDo you wish to use the default prompt text? Y/N ?: "
)
if use_default_prompt in ["Y", "y"]:
prompt = default_prompt_text
else:
prompt = input("Please enter the prompt text: ")
print("\nPrompt Text: ", prompt)
res_str = stable_lm.generate(prompt)
torch.cuda.empty_cache()
import gc
gc.collect()
print(
"\n\n-----\nHere's the complete formatted result: \n\n",
prompt + "".join(res_str),
)
continue_execution = input(
"\nDo you wish to run script one more time? Y/N ?: "
)
continue_execution = (
True if continue_execution in ["Y", "y"] else False
)

View File

@@ -105,6 +105,7 @@ def main():
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
control_mode=args.control_mode,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -0,0 +1,96 @@
import torch
import time
from apps.stable_diffusion.src import (
args,
Text2ImageSDXLPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
def main():
if args.clear_all:
clear_all()
# TODO: prompt_embeds and text_embeds from base_model.json require fixing
args.precision = "fp16"
args.height = 1024
args.width = 1024
args.max_length = 77
args.scheduler = "DDIM"
print(
"Using default supported configuration for SDXL :-\nprecision=fp16, width*height= 1024*1024, max_length=77 and scheduler=DDIM"
)
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
)
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
for current_batch in range(args.batch_count):
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
args.negative_prompts,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += (
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
)
text_output += (
f"seed={seeds[current_batch]}, size={args.height}x{args.width}"
)
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
# TODO: with --batch_count=x, txt2img_obj.log re-displays info from every previous iteration on each batch
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -19,6 +19,9 @@ a = Analysis(
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
module_collection_mode={
'gradio': 'py', # Collect gradio package as source .py files
},
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

View File

@@ -31,6 +31,7 @@ datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += copy_metadata("gradio")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
@@ -60,6 +61,7 @@ datas += [
("src/utils/resources/opt_flags.json", "resources"),
("src/utils/resources/base_model.json", "resources"),
("web/ui/css/*", "ui/css"),
("web/ui/js/*", "ui/js"),
("web/ui/logos/*", "logos"),
(
"../language_models/src/pipelines/minigpt4_utils/configs/*",
@@ -75,6 +77,7 @@ datas += [
# hidden imports for pyinstaller
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
hiddenimports += [
x for x in collect_submodules("diffusers") if "tests" not in x
]
@@ -85,4 +88,4 @@ hiddenimports += [
if not any(kw in x for kw in blacklist)
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]
hiddenimports += ["iree._runtime"]

View File

@@ -9,6 +9,7 @@ from apps.stable_diffusion.src.utils import (
)
from apps.stable_diffusion.src.pipelines import (
Text2ImagePipeline,
Text2ImageSDXLPipeline,
Image2ImagePipeline,
InpaintPipeline,
OutpaintPipeline,

View File

@@ -1,5 +1,5 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from transformers import CLIPTextModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection
from collections import defaultdict
from pathlib import Path
import torch
@@ -24,6 +24,8 @@ from apps.stable_diffusion.src.utils import (
get_stencil_model_id,
update_lora_weight,
)
from shark.shark_downloader import download_public_file
from shark.shark_inference import SharkInference
# These shapes are parameter dependent.
@@ -55,6 +57,10 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape.append(math.ceil(height / div_val))
elif "width" in shape[i]:
new_shape.append(math.ceil(width / div_val))
elif "+" in shape[i]:
# Currently only SDXL hits this case; update this if any other
# model's shapes need the "+" operator.
new_shape.append(height + width)
else:
new_shape.append(shape[i])
return new_shape
@@ -67,6 +73,70 @@ def check_compilation(model, model_name):
)
def shark_compile_after_ir(
module_name,
device,
vmfb_path,
precision,
ir_path=None,
):
if ir_path:
print(f"[DEBUG] mlir found at {ir_path.absolute()}")
module = SharkInference(
mlir_module=ir_path,
device=device,
mlir_dialect="tm_tensor",
)
print(f"Will get extra flag for {module_name} and precision = {precision}")
path = module.save_module(
vmfb_path.parent.absolute(),
vmfb_path.stem,
extra_args=get_opt_flags(module_name, precision=precision),
)
print(f"Saved {module_name} vmfb at {path}")
module.load_module(path)
return module
def process_vmfb_ir_sdxl(extended_model_name, model_name, device, precision):
name_split = extended_model_name.split("_")
if "vae" in model_name:
name_split[5] = "fp32"
extended_model_name_for_vmfb = "_".join(name_split)
extended_model_name_for_mlir = "_".join(name_split[:-1])
vmfb_path = Path(extended_model_name_for_vmfb + ".vmfb")
if "vulkan" in device:
_device = args.iree_vulkan_target_triple
_device = _device.replace("-", "_")
vmfb_path = Path(extended_model_name_for_vmfb + f"_vulkan.vmfb")
if vmfb_path.exists():
shark_module = SharkInference(
None,
device=device,
mlir_dialect="tm_tensor",
)
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=[])
return shark_module, None
mlir_path = Path(extended_model_name_for_mlir + ".mlir")
if not mlir_path.exists():
print(f"Looking into gs://shark_tank/SDXL/mlir/{mlir_path.name}")
download_public_file(
f"gs://shark_tank/SDXL/mlir/{mlir_path.name}",
mlir_path.absolute(),
single_file=True,
)
if mlir_path.exists():
return (
shark_compile_after_ir(
model_name, device, vmfb_path, precision, mlir_path
),
None,
)
return None, None
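# Note on the lookup order implemented above: an existing <name>.vmfb (with a
# "_vulkan" suffix when targeting Vulkan) is loaded directly; otherwise the
# .mlir is fetched from gs://shark_tank/SDXL/mlir and compiled through
# shark_compile_after_ir(); if neither artifact exists the caller gets
# (None, None) and falls back to importing the model itself.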
class SharkifyStableDiffusionModel:
def __init__(
self,
@@ -86,13 +156,17 @@ class SharkifyStableDiffusionModel:
generate_vmfb: bool = True,
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
is_sdxl: bool = False,
stencils: list[str] = [],
use_lora: str = "",
lora_strength: float = 0.75,
use_quantize: str = None,
return_mlir: bool = False,
favored_base_models=None,
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.is_sdxl = is_sdxl
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
@@ -118,9 +192,7 @@ class SharkifyStableDiffusionModel:
)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":
self.model_id = "stabilityai/stable-diffusion-2-1-base"
self.favored_base_models = favored_base_models
self.custom_vae = custom_vae
self.precision = precision
self.base_vae = use_base_vae
@@ -136,6 +208,7 @@ class SharkifyStableDiffusionModel:
+ "_"
+ precision
)
self.model_namedata = self.model_name
print(f"use_tuned? sharkify: {use_tuned}")
self.use_tuned = use_tuned
if use_tuned:
@@ -144,12 +217,17 @@ class SharkifyStableDiffusionModel:
self.low_cpu_mem_usage = low_cpu_mem_usage
self.is_inpaint = is_inpaint
self.is_upscaler = is_upscaler
self.use_stencil = get_stencil_model_id(use_stencil)
self.stencils = [get_stencil_model_id(x) for x in stencils]
if use_lora != "":
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
self.model_name = (
self.model_name
+ "_"
+ get_path_stem(use_lora)
+ f"@{int(lora_strength*100)}"
)
self.use_lora = use_lora
self.lora_strength = lora_strength
print(self.model_name)
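# Worked example of the LoRA-aware naming above (illustrative values, assuming
# get_path_stem() returns the file's stem): with
# use_lora="models/loras/paper-cut.safetensors" and lora_strength=0.75 the
# model name gains "_paper-cut@75", so vmfbs compiled at different strengths
# do not collide in the cache.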
self.model_name = self.get_extended_name_for_all_model()
self.debug = debug
self.sharktank_dir = sharktank_dir
@@ -171,19 +249,22 @@ class SharkifyStableDiffusionModel:
args.hf_model_id = self.base_model_id
self.return_mlir = return_mlir
def get_extended_name_for_all_model(self):
def get_extended_name_for_all_model(self, model_list=None):
model_name = {}
sub_model_list = [
"clip",
"clip2",
"unet",
"unet512",
"stencil_unet",
"stencil_unet_512",
"vae",
"vae_encode",
"stencil_adaptor",
"stencil_adaptor_512",
"stencil_adapter",
"stencil_adapter_512",
]
if model_list is not None:
sub_model_list = model_list
index = 0
for model in sub_model_list:
sub_model = model
@@ -195,10 +276,24 @@ class SharkifyStableDiffusionModel:
)
if self.base_vae:
sub_model = "base_vae"
if "stencil_adaptor" == model and self.use_stencil is not None:
model_config = model_config + get_path_stem(self.use_stencil)
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
if "stencil_adapter" in model:
stencil_names = []
for i, stencil in enumerate(self.stencils):
if stencil is not None:
cnet_config = (
self.model_namedata
+ "_sd15_"
+ stencil.split("_")[-1]
)
stencil_names.append(
get_extended_name(sub_model + cnet_config)
)
model_name[model] = stencil_names
else:
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
return model_name
def check_params(self, max_len, width, height):
@@ -342,6 +437,105 @@ class SharkifyStableDiffusionModel:
)
return shark_vae, vae_mlir
def get_vae_sdxl(self):
# TODO: Remove this after convergence with shark_tank. This should just be part of
# opt_params.py.
shark_module_or_none = process_vmfb_ir_sdxl(
self.model_name["vae"], "vae", args.device, self.precision
)
if shark_module_or_none[0]:
return shark_module_or_none
class VaeModel(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
base_vae=self.base_vae,
custom_vae=self.custom_vae,
low_cpu_mem_usage=False,
):
super().__init__()
self.vae = None
if custom_vae == "":
print(f"Loading default vae, with target {model_id}")
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif not isinstance(custom_vae, dict):
precision = "fp16" if "fp16" in custom_vae else None
print(f"Loading custom vae, with target {custom_vae}")
if os.path.exists(custom_vae):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
custom_vae = "/".join(
[
custom_vae.split("/")[-2].split("\\")[-1],
custom_vae.split("/")[-1],
]
)
print("Using hub to get custom vae")
try:
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
variant=precision,
)
except:
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
print(f"Loading custom vae, with state {custom_vae}")
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.vae.load_state_dict(custom_vae)
self.base_vae = base_vae
def forward(self, latents):
image = self.vae.decode(latents / 0.13025, return_dict=False)[
0
]
return image
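# Note: 0.13025 is the scaling_factor from the SDXL VAE config (SD 1.x/2.x
# use 0.18215); dividing by it here undoes the scaling applied when the
# latents were produced.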
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
# Keep the VAE in float32, since it overflows in float16 (as in the
# upstream SDXL pipeline).
if not self.custom_vae:
is_f16 = False
elif "16" in self.custom_vae:
is_f16 = True
else:
is_f16 = False
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
shark_vae, vae_mlir = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
extended_model_name=self.model_name["vae"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae, vae_mlir
def get_controlled_unet(self, use_large=False):
class ControlledUnetModel(torch.nn.Module):
def __init__(
@@ -349,6 +543,7 @@ class SharkifyStableDiffusionModel:
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
lora_strength=self.lora_strength,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
@@ -357,8 +552,10 @@ class SharkifyStableDiffusionModel:
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
update_lora_weight(
self.unet, use_lora, "unet", lora_strength
)
self.in_channels = self.unet.config.in_channels
self.train(False)
def forward(
@@ -380,25 +577,54 @@ class SharkifyStableDiffusionModel:
control11,
control12,
control13,
scale1,
scale2,
scale3,
scale4,
scale5,
scale6,
scale7,
scale8,
scale9,
scale10,
scale11,
scale12,
scale13,
):
# TODO: Average pooling
db_res_samples = [
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
]
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
db_res_samples = tuple(
[
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
control1 * scale1,
control2 * scale2,
control3 * scale3,
control4 * scale4,
control5 * scale5,
control6 * scale6,
control7 * scale7,
control8 * scale8,
control9 * scale9,
control10 * scale10,
control11 * scale11,
control12 * scale12,
]
)
mb_res_samples = control13
mb_res_samples = control13 * scale13
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents,
@@ -446,6 +672,19 @@ class SharkifyStableDiffusionModel:
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
]
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
@@ -462,17 +701,19 @@ class SharkifyStableDiffusionModel:
)
return shark_controlled_unet, controlled_unet_mlir
def get_control_net(self, use_large=False):
def get_control_net(self, stencil_id, use_large=False):
stencil_id = get_stencil_model_id(stencil_id)
adapter_id, base_model_safe_id, ext_model_name = (None, None, None)
print(f"Importing ControlNet adapter from {stencil_id}")
class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.use_stencil, low_cpu_mem_usage=False
):
def __init__(self, model_id=stencil_id, low_cpu_mem_usage=False):
super().__init__()
self.cnet = ControlNetModel.from_pretrained(
model_id,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.cnet.in_channels
self.in_channels = self.cnet.config.in_channels
self.train(False)
def forward(
@@ -481,6 +722,19 @@ class SharkifyStableDiffusionModel:
timestep,
text_embedding,
stencil_image_input,
acc1,
acc2,
acc3,
acc4,
acc5,
acc6,
acc7,
acc8,
acc9,
acc10,
acc11,
acc12,
acc13,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
# TODO: guidance is NOT NEEDED; change `get_input_info` accordingly later
@@ -502,6 +756,20 @@ class SharkifyStableDiffusionModel:
)
return tuple(
list(down_block_res_samples) + [mid_block_res_sample]
) + (
acc1 + down_block_res_samples[0],
acc2 + down_block_res_samples[1],
acc3 + down_block_res_samples[2],
acc4 + down_block_res_samples[3],
acc5 + down_block_res_samples[4],
acc6 + down_block_res_samples[5],
acc7 + down_block_res_samples[6],
acc8 + down_block_res_samples[7],
acc9 + down_block_res_samples[8],
acc10 + down_block_res_samples[9],
acc11 + down_block_res_samples[10],
acc12 + down_block_res_samples[11],
acc13 + mid_block_res_sample,
)
scnet = StencilControlNetModel(
@@ -509,7 +777,25 @@ class SharkifyStableDiffusionModel:
)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_adaptor"])
inputs = tuple(self.inputs["stencil_adapter"])
model_name = "stencil_adapter_512" if use_large else "stencil_adapter"
stencil_names = self.get_extended_name_for_all_model([model_name])
ext_model_name = stencil_names[model_name]
if isinstance(ext_model_name, list):
desired_name = None
print(ext_model_name)
for i in ext_model_name:
if stencil_id.split("_")[-1] in i:
desired_name = i
else:
continue
if desired_name:
ext_model_name = desired_name
else:
raise Exception(
f"Could not find extended configuration for {stencil_id}"
)
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
@@ -517,21 +803,15 @@ class SharkifyStableDiffusionModel:
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
*inputs[3:],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
)
else:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor"]
)
input_mask = [True, True, True, True]
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
save_dir = os.path.join(self.sharktank_dir, ext_model_name)
input_mask = [True, True, True, True] + ([True] * 13)
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
inputs,
extended_model_name=self.model_name[model_name],
extended_model_name=ext_model_name,
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
@@ -550,6 +830,7 @@ class SharkifyStableDiffusionModel:
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
lora_strength=self.lora_strength,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
@@ -558,7 +839,9 @@ class SharkifyStableDiffusionModel:
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
update_lora_weight(
self.unet, use_lora, "unet", lora_strength
)
self.in_channels = self.unet.config.in_channels
self.train(False)
if (
@@ -688,6 +971,101 @@ class SharkifyStableDiffusionModel:
)
return shark_unet, unet_mlir
def get_unet_sdxl(self):
# TODO: Remove this after convergence with shark_tank. This should just be part of
# opt_params.py.
shark_module_or_none = process_vmfb_ir_sdxl(
self.model_name["unet"], "unet", args.device, self.precision
)
if shark_module_or_none[0]:
return shark_module_or_none
class UnetModel(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
):
super().__init__()
try:
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
variant="fp16",
)
except:
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if (
args.attention_slicing is not None
and args.attention_slicing != "none"
):
if args.attention_slicing.isdigit():
self.unet.set_attention_slice(
int(args.attention_slicing)
)
else:
self.unet.set_attention_slice(args.attention_slicing)
def forward(
self,
latent,
timestep,
prompt_embeds,
text_embeds,
time_ids,
guidance_scale,
):
added_cond_kwargs = {
"text_embeds": text_embeds,
"time_ids": time_ids,
}
noise_pred = self.unet.forward(
latent,
timestep,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=None,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
input_mask = [True, True, True, True, True, True]
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(
@@ -695,6 +1073,7 @@ class SharkifyStableDiffusionModel:
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
lora_strength=self.lora_strength,
):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
@@ -704,7 +1083,10 @@ class SharkifyStableDiffusionModel:
)
if use_lora != "":
update_lora_weight(
self.text_encoder, use_lora, "text_encoder"
self.text_encoder,
use_lora,
"text_encoder",
lora_strength,
)
def forward(self, input):
@@ -735,6 +1117,78 @@ class SharkifyStableDiffusionModel:
)
return shark_clip, clip_mlir
def get_clip_sdxl(self, clip_index=1):
if clip_index == 1:
extended_model_name = self.model_name["clip"]
model_name = "clip"
else:
extended_model_name = self.model_name["clip2"]
model_name = "clip2"
# TODO: Remove this after convergence with shark_tank. This should just be part of
# opt_params.py.
shark_module_or_none = process_vmfb_ir_sdxl(
extended_model_name, f"clip", args.device, self.precision
)
if shark_module_or_none[0]:
return shark_module_or_none
class CLIPText(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
clip_index=1,
):
super().__init__()
if clip_index == 1:
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
self.text_encoder = (
CLIPTextModelWithProjection.from_pretrained(
model_id,
subfolder="text_encoder_2",
low_cpu_mem_usage=low_cpu_mem_usage,
)
)
def forward(self, input):
prompt_embeds = self.text_encoder(
input,
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
return prompt_embeds, pooled_prompt_embeds
clip_model = CLIPText(
low_cpu_mem_usage=self.low_cpu_mem_usage, clip_index=clip_index
)
save_dir = os.path.join(self.sharktank_dir, extended_model_name)
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_clip, clip_mlir = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
extended_model_name=extended_model_name,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
base_model_id=self.base_model_id,
model_name="clip",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_clip, clip_mlir
def process_custom_vae(self):
custom_vae = self.custom_vae.lower()
if not custom_vae.endswith((".ckpt", ".safetensors")):
@@ -767,7 +1221,9 @@ class SharkifyStableDiffusionModel:
}
return vae_dict
def compile_unet_variants(self, model, use_large=False):
def compile_unet_variants(self, model, use_large=False, base_model=""):
if self.is_sdxl:
return self.get_unet_sdxl()
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler(use_large=use_large)
@@ -809,21 +1265,53 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def sdxl_clip(self):
try:
self.inputs["clip"] = self.get_input_info_for(
base_models["sdxl_clip"]
)
compiled_clip, clip_mlir = self.get_clip_sdxl(clip_index=1)
compiled_clip2, clip_mlir2 = self.get_clip_sdxl(clip_index=2)
check_compilation(compiled_clip, "Clip")
check_compilation(compiled_clip, "Clip2")
if self.return_mlir:
return clip_mlir, clip_mlir2
return compiled_clip, compiled_clip2
except Exception as e:
sys.exit(e)
def unet(self, use_large=False):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
stencil_count = 0
for stencil in self.stencils:
stencil_count += 1
model = "stencil_unet" if stencil_count > 0 else "unet"
compiled_unet = None
unet_inputs = base_models[model]
# if the model to run *is* a base model, then we should treat it as such
if self.model_to_run in unet_inputs:
self.base_model_id = self.model_to_run
if self.base_model_id != "":
self.inputs["unet"] = self.get_input_info_for(
unet_inputs[self.base_model_id]
)
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
model, use_large=use_large, base_model=self.base_model_id
)
else:
for model_id in unet_inputs:
# restrict base models to check if we were given a specific list of valid ones
allowed_base_model_ids = unet_inputs
if self.favored_base_models is not None:
allowed_base_model_ids = self.favored_base_models
print(f"self.favored_base_models: {self.favored_base_models}")
print(f"allowed_base_model_ids: {allowed_base_model_ids}")
# try compiling with each base model until we find one that works (or not)
for model_id in allowed_base_model_ids:
self.base_model_id = model_id
self.inputs["unet"] = self.get_input_info_for(
unet_inputs[model_id]
@@ -831,12 +1319,12 @@ class SharkifyStableDiffusionModel:
try:
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
model, use_large=use_large, base_model=model_id
)
except Exception as e:
print(e)
print(
"Retrying with a different base model configuration"
f"Retrying with a different base model configuration, as {model_id} did not work"
)
continue
@@ -870,7 +1358,10 @@ class SharkifyStableDiffusionModel:
is_base_vae = self.base_vae
if self.is_upscaler:
self.base_vae = True
compiled_vae, vae_mlir = self.get_vae()
if self.is_sdxl:
compiled_vae, vae_mlir = self.get_vae_sdxl()
else:
compiled_vae, vae_mlir = self.get_vae()
self.base_vae = is_base_vae
check_compilation(compiled_vae, "Vae")
@@ -880,18 +1371,18 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def controlnet(self, use_large=False):
def controlnet(self, stencil_id, use_large=False):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
self.inputs["stencil_adapter"] = self.get_input_info_for(
base_models["stencil_adapter"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
use_large=use_large
compiled_stencil_adapter, controlnet_mlir = self.get_control_net(
stencil_id, use_large=use_large
)
check_compilation(compiled_stencil_adaptor, "Stencil")
check_compilation(compiled_stencil_adapter, "Stencil")
if self.return_mlir:
return controlnet_mlir
return compiled_stencil_adaptor
return compiled_stencil_adapter
except Exception as e:
sys.exit(e)

View File

@@ -123,8 +123,11 @@ def get_clip():
return get_shark_model(bucket, model_name, iree_flags)
def get_tokenizer():
def get_tokenizer(subfolder="tokenizer", hf_model_id=None):
if hf_model_id is not None:
args.hf_model_id = hf_model_id
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id, subfolder="tokenizer"
args.hf_model_id, subfolder=subfolder
)
return tokenizer
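# Example use of the new parameters: the SDXL pipeline fetches its second
# tokenizer with
#     get_tokenizer("tokenizer_2", "stabilityai/stable-diffusion-xl-base-1.0")
# while existing callers keep the previous behaviour by passing no arguments.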

View File

@@ -1,6 +1,9 @@
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
Text2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img_sdxl import (
Text2ImageSDXLPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
Image2ImagePipeline,
)

View File

@@ -56,9 +56,12 @@ class Image2ImagePipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.vae_encode = None
def load_vae_encode(self):
@@ -158,8 +161,11 @@ class Image2ImagePipeline(StableDiffusionPipeline):
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
stencils,
images,
resample_type,
control_mode,
preprocessed_hints=[],
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):

View File

@@ -51,9 +51,12 @@ class InpaintPipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.vae_encode = None
def load_vae_encode(self):

View File

@@ -52,9 +52,12 @@ class OutpaintPipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.vae_encode = None
def load_vae_encode(self):

View File

@@ -25,12 +25,22 @@ from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.utils import controlnet_hint_conversion
from apps.stable_diffusion.src.utils import (
controlnet_hint_conversion,
controlnet_hint_reshaping,
)
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
from apps.stable_diffusion.src.utils import (
resamplers,
resampler_list,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
class StencilPipeline(StableDiffusionPipeline):
@@ -54,29 +64,69 @@ class StencilPipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
controlnet_names: list[str],
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.controlnet = None
self.controlnet_512 = None
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.controlnet = [None] * len(controlnet_names)
self.controlnet_512 = [None] * len(controlnet_names)
self.controlnet_id = [str] * len(controlnet_names)
self.controlnet_512_id = [str] * len(controlnet_names)
self.controlnet_names = controlnet_names
self.vae_encode = None
def load_controlnet(self):
if self.controlnet is not None:
def load_vae_encode(self):
if self.vae_encode is not None:
return
self.controlnet = self.sd_model.controlnet()
def unload_controlnet(self):
del self.controlnet
self.controlnet = None
if self.import_mlir or self.use_lora:
self.vae_encode = self.sd_model.vae_encode()
else:
try:
self.vae_encode = get_vae_encode()
except:
print("download pipeline failed, falling back to import_mlir")
self.vae_encode = self.sd_model.vae_encode()
def load_controlnet_512(self):
if self.controlnet_512 is not None:
def unload_vae_encode(self):
del self.vae_encode
self.vae_encode = None
def load_controlnet(self, index, model_name):
if model_name is None:
return
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
if (
self.controlnet[index] is not None
and self.controlnet_id[index] is not None
and self.controlnet_id[index] == model_name
):
return
self.controlnet_id[index] = model_name
self.controlnet[index] = self.sd_model.controlnet(model_name)
def unload_controlnet_512(self):
del self.controlnet_512
self.controlnet_512 = None
def unload_controlnet(self, index):
del self.controlnet[index]
self.controlnet_id[index] = None
self.controlnet[index] = None
def load_controlnet_512(self, index, model_name):
if (
self.controlnet_512[index] is not None
and self.controlnet_512_id[index] == model_name
):
return
self.controlnet_512_id[index] = model_name
self.controlnet_512[index] = self.sd_model.controlnet(
model_name, use_large=True
)
def unload_controlnet_512(self, index):
del self.controlnet_512[index]
self.controlnet_512_id[index] = None
self.controlnet_512[index] = None
def prepare_latents(
self,
@@ -103,6 +153,58 @@ class StencilPipeline(StableDiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents
def prepare_image_latents(
self,
image,
batch_size,
height,
width,
generator,
num_inference_steps,
strength,
dtype,
resample_type,
):
# Pre-process image -> encode image -> process latents
# TODO: process with variable HxW combos
# Pre-process image
resample_type = (
resamplers[resample_type]
if resample_type in resampler_list
# Fallback to Lanczos
else Image.Resampling.LANCZOS
)
image = image.resize((width, height), resample=resample_type)
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
image_arr = 2 * (image_arr - 0.5)
# set scheduler steps
self.scheduler.set_timesteps(num_inference_steps)
init_timestep = min(
int(num_inference_steps * strength), num_inference_steps
)
t_start = max(num_inference_steps - init_timestep, 0)
# timesteps reduced as per strength
timesteps = self.scheduler.timesteps[t_start:]
# new number of steps to be used as per strength will be
# num_inference_steps = num_inference_steps - t_start
# image encode
latents = self.encode_image((image_arr,))
latents = torch.from_numpy(latents).to(dtype)
# add noise to data
noise = torch.randn(latents.shape, generator=generator, dtype=dtype)
latents = self.scheduler.add_noise(
latents, noise, timesteps[0].repeat(1)
)
return latents, timesteps
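# Worked example of the strength-based truncation above, assuming
# num_inference_steps=50 and strength=0.8:
#   init_timestep = min(int(50 * 0.8), 50) = 40
#   t_start       = max(50 - 40, 0)        = 10
# so only the last 40 scheduler timesteps are run, and noise for timesteps[0]
# (the 11th original step) is added to the encoded image latents.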
def produce_stencil_latents(
self,
latents,
@@ -111,8 +213,9 @@ class StencilPipeline(StableDiffusionPipeline):
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
stencil_hints=[None],
controlnet_conditioning_scale: float = 1.0,
control_mode="Balanced", # Prompt, Balanced, or Controlnet
mask=None,
masked_image_latents=None,
return_all_latents=False,
@@ -121,12 +224,24 @@ class StencilPipeline(StableDiffusionPipeline):
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
assert control_mode in ["Prompt", "Balanced", "Controlnet"]
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
self.load_controlnet()
else:
self.load_unet_512()
self.load_controlnet_512()
for i, name in enumerate(self.controlnet_names):
use_names = []
if name is not None:
use_names.append(name)
else:
continue
if text_embeddings.shape[1] <= self.model_max_length:
self.load_controlnet(i, name)
else:
self.load_controlnet_512(i, name)
self.controlnet_names = use_names
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
@@ -149,33 +264,93 @@ class StencilPipeline(StableDiffusionPipeline):
).to(dtype)
else:
latent_model_input_1 = latent_model_input
if text_embeddings.shape[1] <= self.model_max_length:
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
else:
control = self.controlnet_512(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
# Multicontrolnet
height = latent_model_input_1.shape[2]
width = latent_model_input_1.shape[3]
dtype = latent_model_input_1.dtype
control_acc = (
[torch.zeros((2, 320, height, width), dtype=dtype)] * 3
+ [
torch.zeros(
(2, 320, int(height / 2), int(width / 2)), dtype=dtype
)
]
+ [
torch.zeros(
(2, 640, int(height / 2), int(width / 2)), dtype=dtype
)
]
* 2
+ [
torch.zeros(
(2, 640, int(height / 4), int(width / 4)), dtype=dtype
)
]
+ [
torch.zeros(
(2, 1280, int(height / 4), int(width / 4)), dtype=dtype
)
]
* 2
+ [
torch.zeros(
(2, 1280, int(height / 8), int(width / 8)), dtype=dtype
)
]
* 4
)
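# The accumulator above seeds one zero tensor per ControlNet output so that
# several stencils can be summed: 3 + 1 + 2 + 1 + 2 + 4 = 13 tensors, matching
# the 12 down-block residuals plus the mid-block residual that each ControlNet
# wrapper returns (acc13 pairs with the mid-block sample).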
for i, controlnet_hint in enumerate(stencil_hints):
if controlnet_hint is None:
pass
if text_embeddings.shape[1] <= self.model_max_length:
control = self.controlnet[i](
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
*control_acc,
),
send_to_host=False,
)
else:
control = self.controlnet_512[i](
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
*control_acc,
),
send_to_host=False,
)
control_acc = control[13:]
control = control[:13]
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
dtype = latents.dtype
if control_mode == "Balanced":
control_scale = [
torch.tensor(1.0, dtype=dtype) for _ in range(len(control))
]
elif control_mode == "Prompt":
control_scale = [
torch.tensor(0.825**x, dtype=dtype)
for x in range(len(control))
]
elif control_mode == "Controlnet":
control_scale = [
torch.tensor(float(guidance_scale), dtype=dtype)
for _ in range(len(control))
]
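# Worked example of the three control modes above for the 13 residuals,
# assuming guidance_scale=7.5:
#   "Balanced"   -> every residual scaled by 1.0
#   "Prompt"     -> residual x scaled by 0.825**x (1.0, 0.825, ... ~0.1 at x=12)
#   "Controlnet" -> every residual scaled by 7.5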
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
@@ -197,11 +372,23 @@ class StencilPipeline(StableDiffusionPipeline):
control[10],
control[11],
control[12],
control_scale[0],
control_scale[1],
control_scale[2],
control_scale[3],
control_scale[4],
control_scale[5],
control_scale[6],
control_scale[7],
control_scale[8],
control_scale[9],
control_scale[10],
control_scale[11],
control_scale[12],
),
send_to_host=False,
)
else:
print(self.unet_512)
noise_pred = self.unet_512(
"forward",
(
@@ -222,6 +409,19 @@ class StencilPipeline(StableDiffusionPipeline):
control[10],
control[11],
control[12],
control_scale[0],
control_scale[1],
control_scale[2],
control_scale[3],
control_scale[4],
control_scale[5],
control_scale[6],
control_scale[7],
control_scale[8],
control_scale[9],
control_scale[10],
control_scale[11],
control_scale[12],
),
send_to_host=False,
)
@@ -245,8 +445,9 @@ class StencilPipeline(StableDiffusionPipeline):
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
self.unload_controlnet()
self.unload_controlnet_512()
for i in range(len(self.controlnet_names)):
self.unload_controlnet(i)
self.unload_controlnet_512(i)
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -255,6 +456,17 @@ class StencilPipeline(StableDiffusionPipeline):
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def encode_image(self, input_image):
self.load_vae_encode()
vae_encode_start = time.time()
latents = self.vae_encode("forward", input_image)
vae_inf_time = (time.time() - vae_encode_start) * 1000
if self.ondemand:
self.unload_vae_encode()
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
return latents
def generate_images(
self,
prompts,
@@ -272,14 +484,48 @@ class StencilPipeline(StableDiffusionPipeline):
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
stencils,
stencil_images,
resample_type,
control_mode,
preprocessed_hints,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.
controlnet_hint = controlnet_hint_conversion(
image, use_stencil, height, width, dtype, num_images_per_prompt=1
)
# controlnet_hint = controlnet_hint_conversion(
# image, use_stencil, height, width, dtype, num_images_per_prompt=1
# )
stencil_hints = []
self.sd_model.stencils = stencils
for i, hint in enumerate(preprocessed_hints):
if hint is not None:
hint = controlnet_hint_reshaping(
hint,
height,
width,
dtype,
num_images_per_prompt=1,
)
stencil_hints.append(hint)
for i, stencil in enumerate(stencils):
if stencil is None:
continue
if len(stencil_hints) > i:
if stencil_hints[i] is not None:
print(f"Using preprocessed controlnet hint for {stencil}")
continue
image = stencil_images[i]
stencil_hints.append(
controlnet_hint_conversion(
image,
stencil,
height,
width,
dtype,
num_images_per_prompt=1,
)
)
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
@@ -307,17 +553,30 @@ class StencilPipeline(StableDiffusionPipeline):
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Prepare initial latent.
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
final_timesteps = self.scheduler.timesteps
if image is not None:
# Prepare input image latent
init_latents, final_timesteps = self.prepare_image_latents(
image=image,
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
resample_type=resample_type,
)
else:
# Prepare initial latent.
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
final_timesteps = self.scheduler.timesteps
# Get Image latents
latents = self.produce_stencil_latents(
@@ -327,7 +586,8 @@ class StencilPipeline(StableDiffusionPipeline):
total_timesteps=final_timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
controlnet_hint=controlnet_hint,
control_mode=control_mode,
stencil_hints=stencil_hints,
)
# Img latents -> PIL images

View File

@@ -18,7 +18,10 @@ from diffusers import (
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.schedulers import (
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
@@ -46,9 +49,19 @@ class Text2ImagePipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
@classmethod
def favored_base_models(cls, model_id):
return [
"stabilityai/stable-diffusion-2-1",
"CompVis/stable-diffusion-v1-4",
]
def prepare_latents(
self,

View File

@@ -0,0 +1,236 @@
import torch
import numpy as np
from random import randint
from typing import Union
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import (
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Text2ImageSDXLPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
is_fp32_vae: bool,
):
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.is_fp32_vae = is_fp32_vae
@classmethod
def favored_base_models(cls, model_id):
if "turbo" in model_id:
return [
"stabilityai/sdxl-turbo",
"stabilityai/stable-diffusion-xl-base-1.0",
]
else:
return [
"stabilityai/stable-diffusion-xl-base-1.0",
"stabilityai/sdxl-turbo",
]
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def _get_add_time_ids(
self, original_size, crops_coords_top_left, target_size, dtype
):
add_time_ids = list(
original_size + crops_coords_top_left + target_size
)
# self.unet.config.addition_time_embed_dim IS 256.
# self.text_encoder_2.config.projection_dim IS 1280.
passed_add_embed_dim = 256 * len(add_time_ids) + 1280
expected_add_embed_dim = 2816
# self.unet.add_embedding.linear_1.in_features IS 2816.
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
return add_time_ids
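# Worked check of the guard above: add_time_ids holds
# original_size + crops_coords_top_left + target_size = 6 scalars, so
# passed_add_embed_dim = 256 * 6 + 1280 = 2816, matching the 2816 in_features
# of unet.add_embedding.linear_1 for SDXL base.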
def generate_images(
self,
prompts,
neg_prompts,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the initial latent noise. Also handle out of range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents.
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings.
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt_sdxl(
prompt=prompts,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
negative_prompt=neg_prompts,
)
# Prepare timesteps.
self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps
# Prepare added time ids & embeddings.
original_size = (height, width)
target_size = (height, width)
crops_coords_top_left = (0, 0)
add_text_embeds = pooled_prompt_embeds
add_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
dtype=prompt_embeds.dtype,
)
prompt_embeds = torch.cat(
[negative_prompt_embeds, prompt_embeds], dim=0
)
add_text_embeds = torch.cat(
[negative_pooled_prompt_embeds, add_text_embeds], dim=0
)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds
add_text_embeds = add_text_embeds.to(dtype)
add_time_ids = add_time_ids.repeat(batch_size * 1, 1)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(dtype)
prompt_embeds = prompt_embeds.to(dtype)
add_time_ids = add_time_ids.to(dtype)
# Get Image latents.
latents = self.produce_img_latents_sdxl(
init_latents,
timesteps,
add_text_embeds,
add_time_ids,
prompt_embeds,
cpu_scheduling,
guidance_scale,
dtype,
)
# Img latents -> PIL images.
all_imgs = []
self.load_vae()
for i in range(0, latents.shape[0], batch_size):
imgs = self.decode_latents_sdxl(
latents[i : i + batch_size], is_fp32_vae=self.is_fp32_vae
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -94,9 +94,12 @@ class UpscalerPipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.low_res_scheduler = low_res_scheduler
self.status = SD_STATE_IDLE

View File

@@ -20,7 +20,10 @@ from diffusers import (
HeunDiscreteScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.schedulers import (
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae,
@@ -33,6 +36,8 @@ from apps.stable_diffusion.src.utils import (
end_profiling,
)
import sys
import gc
from typing import List, Optional
SD_STATE_IDLE = "idle"
SD_STATE_CANCEL = "cancel"
@@ -50,6 +55,7 @@ class StableDiffusionPipeline:
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
@@ -59,21 +65,26 @@ class StableDiffusionPipeline:
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
is_f32_vae: bool = False,
):
self.vae = None
self.text_encoder = None
self.text_encoder_2 = None
self.unet = None
self.unet_512 = None
self.model_max_length = 77
self.scheduler = scheduler
# TODO: Implement using the Python logging utility.
self.log = ""
self.status = SD_STATE_IDLE
self.sd_model = sd_model
self.scheduler = scheduler
self.import_mlir = import_mlir
self.use_lora = use_lora
self.lora_strength = lora_strength
self.ondemand = ondemand
self.is_f32_vae = is_f32_vae
# TODO: Find a better workaround for fetching base_model_id early
# enough for CLIPTokenizer.
try:
@@ -83,6 +94,10 @@ class StableDiffusionPipeline:
self.unload_unet()
self.tokenizer = get_tokenizer()
def favored_base_models(cls, model_id):
# all base models can be candidate base models for unet compilation
return None
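# Example: Text2ImageSDXLPipeline overrides this to return
# ["stabilityai/stable-diffusion-xl-base-1.0", "stabilityai/sdxl-turbo"]
# (order flipped for "turbo" model ids), so unet() only attempts SDXL-shaped
# base models instead of every entry in base_model.json.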
def load_clip(self):
if self.text_encoder is not None:
return
@@ -106,6 +121,34 @@ class StableDiffusionPipeline:
del self.text_encoder
self.text_encoder = None
def load_clip_sdxl(self):
if self.text_encoder and self.text_encoder_2:
return
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. "
"Importing MLIR anyways."
)
self.text_encoder, self.text_encoder_2 = self.sd_model.sdxl_clip()
else:
try:
# TODO: Fix this for SDXL
self.text_encoder = get_clip()
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
(
self.text_encoder,
self.text_encoder_2,
) = self.sd_model.sdxl_clip()
def unload_clip_sdxl(self):
del self.text_encoder, self.text_encoder_2
self.text_encoder = None
self.text_encoder_2 = None
def load_unet(self):
if self.unet is not None:
return
@@ -159,6 +202,182 @@ class StableDiffusionPipeline:
def unload_vae(self):
del self.vae
self.vae = None
gc.collect()
def encode_prompt_sdxl(
self,
prompt: str,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
hf_model_id: Optional[
str
] = "stabilityai/stable-diffusion-xl-base-1.0",
):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Define tokenizers and text encoders
self.tokenizer_2 = get_tokenizer("tokenizer_2", hf_model_id)
self.load_clip_sdxl()
tokenizers = (
[self.tokenizer, self.tokenizer_2]
if self.tokenizer is not None
else [self.tokenizer_2]
)
text_encoders = (
[self.text_encoder, self.text_encoder_2]
if self.text_encoder is not None
else [self.text_encoder_2]
)
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt]
for prompt, tokenizer, text_encoder in zip(
prompts, tokenizers, text_encoders
):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(
prompt, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[
-1
] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(
untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
)
print(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
text_encoder_output = text_encoder("forward", (text_input_ids,))
prompt_embeds = torch.from_numpy(text_encoder_output[0])
pooled_prompt_embeds = torch.from_numpy(text_encoder_output[1])
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
# get unconditional embeddings for classifier free guidance
zero_out_negative_prompt = (
negative_prompt is None
and self.config.force_zeros_for_empty_prompt
)
if (
do_classifier_free_guidance
and negative_prompt_embeds is None
and zero_out_negative_prompt
):
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(
pooled_prompt_embeds
)
elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(
negative_prompt
):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt, negative_prompt_2]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(
uncond_tokens, tokenizers, text_encoders
):
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_encoder_output = text_encoder(
"forward", (uncond_input.input_ids,)
)
negative_prompt_embeds = torch.from_numpy(
text_encoder_output[0]
)
negative_pooled_prompt_embeds = torch.from_numpy(
text_encoder_output[1]
)
negative_prompt_embeds_list.append(negative_prompt_embeds)
negative_prompt_embeds = torch.concat(
negative_prompt_embeds_list, dim=-1
)
if self.ondemand:
self.unload_clip_sdxl()
gc.collect()
# TODO: Look into dtype for text_encoder_2!
prompt_embeds = prompt_embeds.to(dtype=torch.float16)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(
bs_embed * num_images_per_prompt, seq_len, -1
)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=torch.float32)
negative_prompt_embeds = negative_prompt_embeds.repeat(
1, num_images_per_prompt, 1
)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, -1
)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(
1, num_images_per_prompt
).view(bs_embed * num_images_per_prompt, -1)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
1, num_images_per_prompt
).view(bs_embed * num_images_per_prompt, -1)
return (
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
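
Illustrative sketch (not part of the diff): the four tensors returned above are normally concatenated for classifier-free guidance before the denoising loop starts. The snippet below shows that assembly with random tensors in typical SDXL shapes; the negative-first ordering follows the usual diffusers convention and is an assumption, not something this hunk itself establishes.

import torch

def assemble_cfg_inputs(prompt_embeds, negative_prompt_embeds,
                        pooled, negative_pooled):
    # Unconditional embeddings first, conditional second, stacked on the
    # batch dimension so one unet call covers both guidance branches.
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
    add_text_embeds = torch.cat([negative_pooled, pooled], dim=0)
    return prompt_embeds, add_text_embeds

pe, ate = assemble_cfg_inputs(
    torch.randn(1, 77, 2048), torch.randn(1, 77, 2048),
    torch.randn(1, 1280), torch.randn(1, 1280),
)
print(pe.shape, ate.shape)  # torch.Size([2, 77, 2048]) torch.Size([2, 1280])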
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
@@ -186,6 +405,7 @@ class StableDiffusionPipeline:
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
gc.collect()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings
@@ -298,6 +518,8 @@ class StableDiffusionPipeline:
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
gc.collect()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -306,6 +528,96 @@ class StableDiffusionPipeline:
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def produce_img_latents_sdxl(
self,
latents,
total_timesteps,
add_text_embeds,
add_time_ids,
prompt_embeds,
cpu_scheduling,
guidance_scale,
dtype,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
# return None
self.status = SD_STATE_IDLE
step_time_sum = 0
extra_step_kwargs = {"generator": None}
self.load_unet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
# expand the latents if we are doing classifier free guidance
if isinstance(latents, np.ndarray):
latents = torch.tensor(latents)
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
prompt_embeds,
add_text_embeds,
add_time_ids,
guidance_scale,
),
send_to_host=True,
)
if not isinstance(latents, torch.Tensor):
latents = torch.from_numpy(latents).to("cpu")
noise_pred = torch.from_numpy(noise_pred).to("cpu")
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
)[0]
latents = latents.detach().numpy()
noise_pred = noise_pred.detach().numpy()
step_time = (time.time() - step_start_time) * 1000
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
if self.ondemand:
self.unload_unet()
gc.collect()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
return latents
def decode_latents_sdxl(self, latents, is_fp32_vae):
# latents are in unet dtype here so switch if we want to use fp32
if is_fp32_vae:
print("Casting latents to float32 for VAE")
latents = latents.to(torch.float32)
images = self.vae("forward", (latents,))
images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1)
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
return pil_images
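
Illustrative sketch (not part of the diff): the float32 cast above exists because a VAE compiled for fp32 expects fp32 inputs, while the latents arrive in the unet's dtype (typically fp16). A tiny plain-PyTorch illustration of the same boundary; FakeFp32Vae is a stand-in, not the compiled SHARK VAE.

import torch

class FakeFp32Vae(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Conv2d(4, 3, 1)  # fp32 weights by default

    def forward(self, latents):
        return self.proj(latents)

vae = FakeFp32Vae()
latents_fp16 = torch.randn(1, 4, 8, 8, dtype=torch.float16)

# Mirrors the is_fp32_vae branch: cast before handing latents to the fp32 module.
images = vae(latents_fp16.to(torch.float32))
print(images.dtype)  # torch.float32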
@classmethod
def from_pretrained(
cls,
@@ -338,8 +650,10 @@ class StableDiffusionPipeline:
ondemand: bool,
low_cpu_mem_usage: bool = False,
debug: bool = False,
use_stencil: str = None,
stencils: list[str] = [],
# stencil_images: list[Image] = []
use_lora: str = "",
lora_strength: float = 0.75,
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
):
@@ -355,7 +669,11 @@ class StableDiffusionPipeline:
"OutpaintPipeline",
]
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"]
print(f"model_id", model_id)
print(f"ckpt_loc", ckpt_loc)
print(f"favored_base_models:", cls.favored_base_models(model_id))
sd_model = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
@@ -371,9 +689,14 @@ class StableDiffusionPipeline:
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
use_stencil=use_stencil,
is_sdxl=is_sdxl,
stencils=stencils,
use_lora=use_lora,
lora_strength=lora_strength,
use_quantize=use_quantize,
favored_base_models=cls.favored_base_models(
model_id if model_id != "" else ckpt_loc
),
)
if cls.__name__ in ["UpscalerPipeline"]:
@@ -383,10 +706,35 @@ class StableDiffusionPipeline:
sd_model,
import_mlir,
use_lora,
lora_strength,
ondemand,
)
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
if cls.__name__ == "StencilPipeline":
return cls(
scheduler,
sd_model,
import_mlir,
use_lora,
lora_strength,
ondemand,
stencils,
)
if cls.__name__ == "Text2ImageSDXLPipeline":
is_fp32_vae = True if "16" not in custom_vae else False
return cls(
scheduler,
sd_model,
import_mlir,
use_lora,
lora_strength,
ondemand,
is_fp32_vae,
)
return cls(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
# #####################################################
# Implements text embeddings with weights from prompts
@@ -498,9 +846,10 @@ class StableDiffusionPipeline:
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
gc.collect()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings.numpy()
return text_embeddings.numpy().astype(np.float16)
from typing import List, Optional, Union

View File

@@ -1,4 +1,7 @@
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers

View File

@@ -1,4 +1,5 @@
from diffusers import (
LCMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
DDPMScheduler,
@@ -15,9 +16,21 @@ from diffusers import (
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
SharkEulerAncestralDiscreteScheduler,
)
def get_schedulers(model_id):
# TODO: Robust scheduler setup on pipeline creation -- if we don't
# set batch_size here, the SHARK schedulers will
# compile with batch size = 1 regardless of whether the model
# outputs latents of a larger batch size, e.g. SDXL.
# However, obviously, searching for whether the base model ID
# contains "xl" is not very robust.
batch_size = 2 if "xl" in model_id.lower() else 1
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_id,
@@ -39,6 +52,10 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
@@ -84,6 +101,12 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers[
"SharkEulerAncestralDiscrete"
] = SharkEulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverSinglestep"
] = DPMSolverSinglestepScheduler.from_pretrained(
@@ -100,5 +123,6 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
schedulers["SharkEulerDiscrete"].compile(batch_size)
schedulers["SharkEulerAncestralDiscrete"].compile(batch_size)
return schedulers
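
Illustrative sketch (not part of the diff): the batch_size heuristic above exists because the SDXL pipeline hands the compiled scheduler modules latents whose leading dimension is already doubled for classifier-free guidance (see the "2*batch_size" shapes added to the model config further down), so those modules must be built for batch size 2. The string check itself, isolated:

def scheduler_batch_size(model_id: str) -> int:
    # SDXL-family model IDs get batch size 2; everything else keeps 1.
    return 2 if "xl" in model_id.lower() else 1

print(scheduler_batch_size("stabilityai/stable-diffusion-xl-base-1.0"))  # 2
print(scheduler_batch_size("stabilityai/stable-diffusion-2-1-base"))     # 1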

View File

@@ -0,0 +1,251 @@
import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
EulerAncestralDiscreteScheduler,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.configuration_utils import register_to_config
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_shark_model,
args,
)
import torch
class SharkEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
super().__init__(
num_train_timesteps,
beta_start,
beta_end,
beta_schedule,
trained_betas,
prediction_type,
timestep_spacing,
steps_offset,
)
# TODO: make it dynamic so we don't have to worry about batch size
self.batch_size = None
self.init_input_shape = None
def compile(self, batch_size=1):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
device = args.device.split(":", 1)[0].strip()
self.batch_size = batch_size
model_input = {
"eulera": {
"output": torch.randn(
batch_size, 4, args.height // 8, args.width // 8
),
"latent": torch.randn(
batch_size, 4, args.height // 8, args.width // 8
),
"sigma": torch.tensor(1).to(torch.float32),
"sigma_from": torch.tensor(1).to(torch.float32),
"sigma_to": torch.tensor(1).to(torch.float32),
"noise": torch.randn(
batch_size, 4, args.height // 8, args.width // 8
),
},
}
example_latent = model_input["eulera"]["latent"]
example_output = model_input["eulera"]["output"]
example_noise = model_input["eulera"]["noise"]
if args.precision == "fp16":
example_latent = example_latent.half()
example_output = example_output.half()
example_noise = example_noise.half()
example_sigma = model_input["eulera"]["sigma"]
example_sigma_from = model_input["eulera"]["sigma_from"]
example_sigma_to = model_input["eulera"]["sigma_to"]
class ScalingModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, latent, sigma):
return latent / ((sigma**2 + 1) ** 0.5)
class SchedulerStepEpsilonModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(
self, noise_pred, latent, sigma, sigma_from, sigma_to, noise
):
sigma_up = (
sigma_to**2
* (sigma_from**2 - sigma_to**2)
/ sigma_from**2
) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
dt = sigma_down - sigma
pred_original_sample = latent - sigma * noise_pred
derivative = (latent - pred_original_sample) / sigma
prev_sample = latent + derivative * dt
return prev_sample + noise * sigma_up
class SchedulerStepVPredictionModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(
self, noise_pred, sigma, sigma_from, sigma_to, latent, noise
):
sigma_up = (
sigma_to**2
* (sigma_from**2 - sigma_to**2)
/ sigma_from**2
) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
dt = sigma_down - sigma
pred_original_sample = noise_pred * (
-sigma / (sigma**2 + 1) ** 0.5
) + (latent / (sigma**2 + 1))
derivative = (latent - pred_original_sample) / sigma
prev_sample = latent + derivative * dt
return prev_sample + noise * sigma_up
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
def _import(self):
scaling_model = ScalingModel()
self.scaling_model, _ = compile_through_fx(
model=scaling_model,
inputs=(example_latent, example_sigma),
extended_model_name=f"euler_a_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
pred_type_model_dict = {
"epsilon": SchedulerStepEpsilonModel(),
"v_prediction": SchedulerStepVPredictionModel(),
}
step_model = pred_type_model_dict[self.config.prediction_type]
self.step_model, _ = compile_through_fx(
step_model,
(
example_output,
example_latent,
example_sigma,
example_sigma_from,
example_sigma_to,
example_noise,
),
extended_model_name=f"euler_a_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
if args.import_mlir:
_import(self)
else:
try:
self.scaling_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_a_scale_model_input_" + args.precision,
iree_flags,
)
self.step_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_a_step_"
+ self.config.prediction_type
+ args.precision,
iree_flags,
)
except:
print(
"failed to download model, falling back and using import_mlir"
)
args.import_mlir = True
_import(self)
def scale_model_input(self, sample, timestep):
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
return self.scaling_model(
"forward",
(
sample,
sigma,
),
send_to_host=False,
)
def step(
self,
noise_pred,
timestep,
latent,
generator: Optional[torch.Generator] = None,
return_dict: Optional[bool] = False,
):
step_inputs = []
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
sigma_from = self.sigmas[self.step_index]
sigma_to = self.sigmas[self.step_index + 1]
noise = randn_tensor(
torch.Size(noise_pred.shape),
dtype=torch.float16,
device="cpu",
generator=generator,
)
step_inputs = [
noise_pred,
latent,
sigma,
sigma_from,
sigma_to,
noise,
]
# TODO: deal with dynamic inputs in turbine flow.
# update step index since we're done with the variable and will return with compiled module output.
self._step_index += 1
if noise_pred.shape[0] < self.batch_size:
for i in [0, 1, 5]:
try:
step_inputs[i] = torch.tensor(step_inputs[i])
except:
step_inputs[i] = torch.tensor(step_inputs[i].to_host())
step_inputs[i] = torch.cat(
(step_inputs[i], step_inputs[i]), axis=0
)
return self.step_model(
"forward",
tuple(step_inputs),
send_to_host=True,
)
return self.step_model(
"forward",
tuple(step_inputs),
send_to_host=False,
)
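
Illustrative sketch (not part of the diff): the compiled step module above implements the standard Euler-ancestral update, split into a deterministic move down to sigma_down plus a stochastic kick of size sigma_up. A plain-PyTorch restatement of the epsilon-prediction branch with arbitrary numbers, useful for sanity-checking the compiled module against the reference math:

import torch

def euler_ancestral_step_epsilon(noise_pred, latent, sigma, sigma_from,
                                 sigma_to, noise):
    sigma_up = (
        sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2
    ) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    dt = sigma_down - sigma
    pred_original_sample = latent - sigma * noise_pred
    derivative = (latent - pred_original_sample) / sigma
    prev_sample = latent + derivative * dt
    return prev_sample + noise * sigma_up

x = torch.zeros(1, 4, 8, 8)
out = euler_ancestral_step_epsilon(
    torch.full_like(x, 0.1), x,
    torch.tensor(14.6), torch.tensor(14.6), torch.tensor(9.7),
    torch.randn_like(x),
)
print(out.shape)  # torch.Size([1, 4, 8, 8])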

View File

@@ -2,12 +2,9 @@ import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.configuration_utils import register_to_config
from apps.stable_diffusion.src.utils import (
compile_through_fx,
@@ -27,6 +24,13 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
interpolation_type: str = "linear",
use_karras_sigmas: bool = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
timestep_spacing: str = "linspace",
timestep_type: str = "discrete",
steps_offset: int = 0,
):
super().__init__(
num_train_timesteps,
@@ -35,20 +39,29 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
beta_schedule,
trained_betas,
prediction_type,
interpolation_type,
use_karras_sigmas,
sigma_min,
sigma_max,
timestep_spacing,
timestep_type,
steps_offset,
)
# TODO: make it dynamic so we don't have to worry about batch size
self.batch_size = 1
def compile(self):
def compile(self, batch_size=1):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
BATCH_SIZE = args.batch_size
device = args.device.split(":", 1)[0].strip()
self.batch_size = batch_size
model_input = {
"euler": {
"latent": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
batch_size, 4, args.height // 8, args.width // 8
),
"output": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
batch_size, 4, args.height // 8, args.width // 8
),
"sigma": torch.tensor(1).to(torch.float32),
"dt": torch.tensor(1).to(torch.float32),
@@ -70,12 +83,32 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
def forward(self, latent, sigma):
return latent / ((sigma**2 + 1) ** 0.5)
class SchedulerStepModel(torch.nn.Module):
class SchedulerStepEpsilonModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, noise_pred, sigma_hat, latent, dt):
pred_original_sample = latent - sigma_hat * noise_pred
derivative = (latent - pred_original_sample) / sigma_hat
return latent + derivative * dt
class SchedulerStepSampleModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, noise_pred, sigma_hat, latent, dt):
pred_original_sample = noise_pred
derivative = (latent - pred_original_sample) / sigma_hat
return latent + derivative * dt
class SchedulerStepVPredictionModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, noise_pred, sigma, latent, dt):
pred_original_sample = latent - sigma * noise_pred
pred_original_sample = noise_pred * (
-sigma / (sigma**2 + 1) ** 0.5
) + (latent / (sigma**2 + 1))
derivative = (latent - pred_original_sample) / sigma
return latent + derivative * dt
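
Illustrative aside (not part of the diff): the v-prediction model above recovers the denoised sample as x0 = latent/(sigma^2 + 1) - v * sigma/sqrt(sigma^2 + 1), instead of the epsilon form x0 = latent - sigma * eps, and then applies the same Euler derivative step. A quick standalone comparison of the two reconstructions with concrete, arbitrary numbers:

import torch

def x0_from_v(v_pred, latent, sigma):
    # Matches SchedulerStepVPredictionModel's reconstruction.
    return v_pred * (-sigma / (sigma**2 + 1) ** 0.5) + latent / (sigma**2 + 1)

def x0_from_eps(eps_pred, latent, sigma):
    # Matches SchedulerStepEpsilonModel's reconstruction.
    return latent - sigma * eps_pred

latent = torch.full((1, 4, 2, 2), 2.0)
sigma = torch.tensor(3.0)
zero = torch.zeros_like(latent)

print(x0_from_v(zero, latent, sigma)[0, 0, 0, 0])    # tensor(0.2000) = 2 / (9 + 1)
print(x0_from_eps(zero, latent, sigma)[0, 0, 0, 0])  # tensor(2.)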
@@ -90,16 +123,22 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
self.scaling_model, _ = compile_through_fx(
model=scaling_model,
inputs=(example_latent, example_sigma),
extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
extended_model_name=f"euler_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
step_model = SchedulerStepModel()
pred_type_model_dict = {
"epsilon": SchedulerStepEpsilonModel(),
"v_prediction": SchedulerStepVPredictionModel(),
"sample": SchedulerStepSampleModel(),
"original_sample": SchedulerStepSampleModel(),
}
step_model = pred_type_model_dict[self.config.prediction_type]
self.step_model, _ = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
extended_model_name=f"euler_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
@@ -109,6 +148,11 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
else:
try:
step_model_type = (
"sample"
if "sample" in self.config.prediction_type
else self.config.prediction_type
)
self.scaling_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_scale_model_input_" + args.precision,
@@ -116,7 +160,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
)
self.step_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_step_" + args.precision,
"euler_step_" + step_model_type + args.precision,
iree_flags,
)
except:
@@ -127,8 +171,9 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
_import(self)
def scale_model_input(self, sample, timestep):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
return self.scaling_model(
"forward",
(
@@ -138,15 +183,61 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
send_to_host=False,
)
def step(self, noise_pred, timestep, latent):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
dt = self.sigmas[step_index + 1] - sigma
def step(
self,
noise_pred,
timestep,
latent,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: Optional[bool] = False,
):
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
gamma = (
min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
if s_tmin <= sigma <= s_tmax
else 0.0
)
sigma_hat = sigma * (gamma + 1)
noise_pred = (
torch.from_numpy(noise_pred)
if isinstance(noise_pred, np.ndarray)
else noise_pred
)
noise = randn_tensor(
torch.Size(noise_pred.shape),
dtype=torch.float16,
device="cpu",
generator=generator,
)
eps = noise * s_noise
if gamma > 0:
latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5
if self.config.prediction_type == "v_prediction":
sigma_hat = sigma
dt = self.sigmas[self.step_index + 1] - sigma_hat
self._step_index += 1
return self.step_model(
"forward",
(
noise_pred,
sigma,
sigma_hat,
latent,
dt,
),

View File

@@ -13,6 +13,7 @@ from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.stencils.stencil_utils import (
controlnet_hint_conversion,
controlnet_hint_reshaping,
get_stencil_model_id,
)
from apps.stable_diffusion.src.utils.utils import (

View File

@@ -8,6 +8,15 @@
"dtype":"i64"
}
},
"sdxl_clip": {
"token" : {
"shape" : [
"1*batch_size",
"max_len"
],
"dtype":"i64"
}
},
"vae_encode": {
"image" : {
"shape" : [
@@ -179,9 +188,95 @@
"shape": [2],
"dtype": "i64"
}
},
"stabilityai/sdxl-turbo": {
"latents": {
"shape": [
"2*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"prompt_embeds": {
"shape": [
"2*batch_size",
"max_len",
2048
],
"dtype": "f32"
},
"text_embeds": {
"shape": [
"2*batch_size",
1280
],
"dtype": "f32"
},
"time_ids": {
"shape": [
"2*batch_size",
6
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 1,
"dtype": "f32"
}
},
"stabilityai/stable-diffusion-xl-base-1.0": {
"latents": {
"shape": [
"2*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"prompt_embeds": {
"shape": [
"2*batch_size",
"max_len",
2048
],
"dtype": "f32"
},
"text_embeds": {
"shape": [
"2*batch_size",
1280
],
"dtype": "f32"
},
"time_ids": {
"shape": [
"2*batch_size",
6
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 1,
"dtype": "f32"
}
}
},
"stencil_adaptor": {
"stencil_adapter": {
"latents": {
"shape": [
"1*batch_size",
@@ -208,6 +303,58 @@
"controlnet_hint": {
"shape": [1, 3, "8*height", "8*width"],
"dtype": "f32"
},
"acc1": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"acc2": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"acc3": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"acc4": {
"shape": [2, 320, "height/2", "width/2"],
"dtype": "f32"
},
"acc5": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"acc6": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"acc7": {
"shape": [2, 640, "height/4", "width/4"],
"dtype": "f32"
},
"acc8": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"acc9": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"acc10": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"acc11": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"acc12": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"acc13": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
}
},
"stencil_unet": {
@@ -290,7 +437,59 @@
"control13": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"scale1": {
"shape": 1,
"dtype": "f32"
},
"scale2": {
"shape": 1,
"dtype": "f32"
},
"scale3": {
"shape": 1,
"dtype": "f32"
},
"scale4": {
"shape": 1,
"dtype": "f32"
},
"scale5": {
"shape": 1,
"dtype": "f32"
},
"scale6": {
"shape": 1,
"dtype": "f32"
},
"scale7": {
"shape": 1,
"dtype": "f32"
},
"scale8": {
"shape": 1,
"dtype": "f32"
},
"scale9": {
"shape": 1,
"dtype": "f32"
},
"scale10": {
"shape": 1,
"dtype": "f32"
},
"scale11": {
"shape": 1,
"dtype": "f32"
},
"scale12": {
"shape": 1,
"dtype": "f32"
},
"scale13": {
"shape": 1,
"dtype": "f32"
}
}
}
}
}
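
Illustrative aside (not part of the diff): the new SDXL entries describe input shapes with symbolic dimensions ("2*batch_size", "max_len", "height", "width") rather than integers. The resolver that turns these into concrete shapes is not shown in this diff; the sketch below uses assumed substitution rules (latent height/width = image size // 8, max_len = 77) purely to show how such specs could be read.

def resolve_shape(shape, batch_size=1, height=1024, width=1024, max_len=77):
    env = {
        "batch_size": batch_size,
        "max_len": max_len,
        "height": height // 8,  # latent-space dims, assuming an 8x VAE downsample
        "width": width // 8,
    }
    if isinstance(shape, int):
        return [shape]
    # Evaluate simple expressions like "2*batch_size" against the env above.
    return [
        d if isinstance(d, int) else int(eval(d, {"__builtins__": {}}, env))
        for d in shape
    ]

print(resolve_shape(["2*batch_size", "max_len", 2048]))       # [2, 77, 2048]
print(resolve_shape(["2*batch_size", 4, "height", "width"]))  # [2, 4, 128, 128]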

View File

@@ -59,24 +59,28 @@
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
"--iree-opt-data-tiling=False"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
"--iree-opt-data-tiling=False"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
"--iree-opt-data-tiling=False"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
"--iree-opt-data-tiling=False"
]
}
}

View File

@@ -1,4 +1,5 @@
[["A high tech solarpunk utopia in the Amazon rainforest"],
["Astrophotography, the shark nebula, nebula with a tiny shark-like cloud in the middle in the middle, hubble telescope, vivid colors"],
["A pikachu fine dining with a view to the Eiffel Tower"],
["A mecha robot in a favela in expressionist style"],
["an insect robot preparing a delicious meal"],

View File

@@ -85,7 +85,7 @@ p.add_argument(
"--height",
type=int,
default=512,
choices=range(128, 769, 8),
choices=range(128, 1025, 8),
help="The height of the output image.",
)
@@ -93,7 +93,7 @@ p.add_argument(
"--width",
type=int,
default=512,
choices=range(128, 769, 8),
choices=range(128, 1025, 8),
help="The width of the output image.",
)
@@ -420,6 +420,13 @@ p.add_argument(
help="Enable the stencil feature.",
)
p.add_argument(
"--control_mode",
choices=["Prompt", "Balanced", "Controlnet"],
default="Balanced",
help="How Controlnet injection should be prioritized.",
)
p.add_argument(
"--use_lora",
type=str,
@@ -428,6 +435,13 @@ p.add_argument(
"file (~3 MB).",
)
p.add_argument(
"--lora_strength",
type=float,
default=1.0,
help="Strength (alpha) scaling factor to use when applying LoRA weights",
)
p.add_argument(
"--use_quantize",
type=str,
@@ -460,6 +474,13 @@ p.add_argument(
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)
p.add_argument(
"--autogen",
type=bool,
default="False",
help="Only used for a gradio workaround.",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################
@@ -587,6 +608,13 @@ p.add_argument(
help="Controls constant folding in iree-compile for all SD models.",
)
p.add_argument(
"--data_tiling",
default=False,
action=argparse.BooleanOptionalAction,
help="Controls data tiling in iree-compile for all SD models.",
)
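
Illustrative aside (not part of the diff): because --data_tiling is declared with argparse.BooleanOptionalAction (Python 3.9+), argparse also generates a --no-data_tiling form automatically, and the get_opt_flags change further down appends --iree-opt-data-tiling=False whenever the flag ends up False. A self-contained demonstration of the parsing behaviour with a throwaway parser:

import argparse

p = argparse.ArgumentParser()
p.add_argument(
    "--data_tiling",
    default=False,
    action=argparse.BooleanOptionalAction,
)

print(p.parse_args([]).data_tiling)                    # False (default)
print(p.parse_args(["--data_tiling"]).data_tiling)     # True
print(p.parse_args(["--no-data_tiling"]).data_tiling)  # False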
##############################################################################
# Web UI flags
##############################################################################

View File

@@ -1,6 +1,10 @@
import numpy as np
from PIL import Image
import torch
import os
from pathlib import Path
import torchvision
import time
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
@@ -10,6 +14,31 @@ from apps.stable_diffusion.src.utils.stencils import (
stencil = {}
def save_img(img):
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)
subdir = Path(get_generated_imgs_path(), "preprocessed_control_hints")
os.makedirs(subdir, exist_ok=True)
if isinstance(img, Image.Image):
img.save(
os.path.join(
subdir, "controlnet_" + str(int(time.time())) + ".png"
)
)
elif isinstance(img, np.ndarray):
img = Image.fromarray(img)
img.save(os.path.join(subdir, str(int(time.time())) + ".png"))
else:
converter = torchvision.transforms.ToPILImage()
for i in img:
converter(i).save(
os.path.join(subdir, str(int(time.time())) + ".png")
)
def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
@@ -29,7 +58,7 @@ def HWC3(x):
return y
def controlnet_hint_shaping(
def controlnet_hint_reshaping(
controlnet_hint, height, width, dtype, num_images_per_prompt=1
):
channels = 3
@@ -48,10 +77,12 @@ def controlnet_hint_shaping(
)
return controlnet_hint
else:
raise ValueError(
f"Acceptble shape of `stencil` are any of ({channels}, {height}, {width}),"
+ f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
+ f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
return controlnet_hint_reshaping(
Image.fromarray(controlnet_hint.detach().numpy()),
height,
width,
dtype,
num_images_per_prompt,
)
elif isinstance(controlnet_hint, np.ndarray):
# np.ndarray: acceptable shape is any of hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_prompt)
@@ -78,29 +109,38 @@ def controlnet_hint_shaping(
) # b h w c -> b c h w
return controlnet_hint
else:
raise ValueError(
f"Acceptble shape of `stencil` are any of ({width}, {channels}), "
+ f"({height}, {width}, {channels}), "
+ f"(1, {height}, {width}, {channels}) or "
+ f"({num_images_per_prompt}, {channels}, {height}, {width}) but is {controlnet_hint.shape}"
return controlnet_hint_reshaping(
Image.fromarray(controlnet_hint),
height,
width,
dtype,
num_images_per_prompt,
)
elif isinstance(controlnet_hint, Image.Image):
controlnet_hint = controlnet_hint.convert(
"RGB"
) # make sure 3 channel RGB format
if controlnet_hint.size == (width, height):
controlnet_hint = controlnet_hint.convert(
"RGB"
) # make sure 3 channel RGB format
controlnet_hint = np.array(controlnet_hint) # to numpy
controlnet_hint = np.array(controlnet_hint).astype(
np.float16
) # to numpy
controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
return controlnet_hint_shaping(
controlnet_hint, height, width, num_images_per_prompt
return controlnet_hint_reshaping(
controlnet_hint, height, width, dtype, num_images_per_prompt
)
else:
raise ValueError(
f"Acceptable image size of `stencil` is ({width}, {height}) but is {controlnet_hint.size}"
)
(hint_w, hint_h) = controlnet_hint.size
left = int((hint_w - width) / 2)
right = left + height
controlnet_hint = controlnet_hint.crop((left, 0, right, hint_h))
controlnet_hint = controlnet_hint.resize((width, height))
return controlnet_hint_reshaping(
controlnet_hint, height, width, dtype, num_images_per_prompt
)
else:
raise ValueError(
f"Acceptable type of `stencil` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
f"Acceptible controlnet input types are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
)
@@ -110,20 +150,26 @@ def controlnet_hint_conversion(
controlnet_hint = None
match use_stencil:
case "canny":
print("Detecting edge with canny")
print(
"Converting controlnet hint to edge detection mask with canny preprocessor."
)
controlnet_hint = hint_canny(image)
case "openpose":
print("Detecting human pose")
print(
"Detecting human pose in controlnet hint with openpose preprocessor."
)
controlnet_hint = hint_openpose(image)
case "scribble":
print("Working with scribble")
print("Using your scribble as a controlnet hint.")
controlnet_hint = hint_scribble(image)
case "zoedepth":
print("Working with ZoeDepth")
print(
"Converting controlnet hint to a depth mapping with ZoeDepth."
)
controlnet_hint = hint_zoedepth(image)
case _:
return None
controlnet_hint = controlnet_hint_shaping(
controlnet_hint = controlnet_hint_reshaping(
controlnet_hint, height, width, dtype, num_images_per_prompt
)
return controlnet_hint
@@ -161,6 +207,7 @@ def hint_canny(
detected_map = stencil["canny"](
input_image, low_threshold, high_threshold
)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map
@@ -176,6 +223,7 @@ def hint_openpose(
stencil["openpose"] = OpenposeDetector()
detected_map, _ = stencil["openpose"](input_image)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map
@@ -187,6 +235,7 @@ def hint_scribble(image: Image.Image):
detected_map = np.zeros_like(input_image, dtype=np.uint8)
detected_map[np.min(input_image, axis=2) < 127] = 255
save_img(detected_map)
return detected_map
@@ -199,5 +248,6 @@ def hint_zoedepth(image: Image.Image):
stencil["depth"] = ZoeDetector()
detected_map = stencil["depth"](input_image)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map
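
Illustrative sketch (not part of the diff): controlnet_hint_conversion dispatches on the stencil name, runs one preprocessor per hint, saves the preprocessed image, and only then reshapes it. A compact restatement of that dispatch pattern with stub detectors; the real Canny/Openpose/Scribble/ZoeDepth detectors are heavier objects cached in the module-level stencil dict.

import numpy as np
from PIL import Image

def fake_scribble(img):
    # Stand-in for the thresholding done by hint_scribble above.
    arr = np.array(img.convert("RGB"))
    out = np.zeros(arr.shape[:2], dtype=np.uint8)
    out[np.min(arr, axis=2) < 127] = 255
    return out

PREPROCESSORS = {"scribble": fake_scribble}

def hint_for(use_stencil, image):
    fn = PREPROCESSORS.get(use_stencil)
    return fn(image) if fn else None

img = Image.new("RGB", (64, 64), "white")
print(hint_for("scribble", img).shape)  # (64, 64)
print(hint_for("unknown", img))         # None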

View File

@@ -30,9 +30,15 @@ class ZoeDetector:
pretrained=False,
force_reload=False,
)
model.load_state_dict(
torch.load(modelpath, map_location=model.device)["model"]
)
# Hack to fix the ZoeDepth import issue
model_keys = model.state_dict().keys()
loaded_dict = torch.load(modelpath, map_location=model.device)["model"]
loaded_keys = loaded_dict.keys()
for key in loaded_keys - model_keys:
loaded_dict.pop(key)
model.load_state_dict(loaded_dict)
model.eval()
self.model = model

View File

@@ -6,6 +6,7 @@ from PIL import PngImagePlugin
from PIL import Image
from datetime import datetime as dt
from csv import DictWriter
from dataclasses import dataclass
from pathlib import Path
import numpy as np
from random import (
@@ -118,7 +119,7 @@ def compile_through_fx(
is_f16=False,
f16_input_mask=None,
use_tuned=False,
save_dir=tempfile.gettempdir(),
save_dir="",
debug=False,
generate_vmfb=True,
extra_args=None,
@@ -541,6 +542,8 @@ def get_opt_flags(model, precision="fp16"):
iree_flags.append(
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
)
if args.data_tiling == False:
iree_flags.append("--iree-opt-data-tiling=False")
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
iree_flags += opt_flags[model][is_tuned][precision][
@@ -563,6 +566,10 @@ def get_opt_flags(model, precision="fp16"):
iree_flags += opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
][device]
if "vae" not in model:
# Due to lack of support for multi-reduce, we always collapse reduction
# dims before dispatch formation right now.
iree_flags += ["--iree-flow-collapse-reduction-dims"]
return iree_flags
@@ -632,30 +639,51 @@ def convert_original_vae(vae_checkpoint):
return converted_vae_checkpoint
def processLoRA(model, use_lora, splitting_prefix):
@dataclass
class LoRAweight:
up: torch.tensor
down: torch.tensor
mid: torch.tensor
alpha: torch.float32 = 1.0
def processLoRA(model, use_lora, splitting_prefix, lora_strength):
state_dict = ""
if ".safetensors" in use_lora:
state_dict = load_file(use_lora)
else:
state_dict = torch.load(use_lora)
alpha = 0.75
visited = []
# directly update weight in model
process_unet = "te" not in splitting_prefix
# gather the weights from the LoRA in a more convenient form, assumes
# everything will have an up.weight. Unsure if this is a safe assumption.
weight_dict: dict[str, LoRAweight] = {}
for key in state_dict:
if ".alpha" in key or key in visited:
continue
if key.startswith(splitting_prefix) and key.endswith("up.weight"):
stem = key.split("up.weight")[0]
weight_key = stem.removesuffix(".lora_")
weight_key = weight_key.removesuffix("_lora_")
weight_key = weight_key.removesuffix(".lora_linear_layer.")
if weight_key not in weight_dict:
weight_dict[weight_key] = LoRAweight(
state_dict[f"{stem}up.weight"],
state_dict[f"{stem}down.weight"],
state_dict.get(f"{stem}mid.weight", None),
state_dict[f"{weight_key}.alpha"]
/ state_dict[f"{stem}up.weight"].shape[1]
if f"{weight_key}.alpha" in state_dict
else 1.0,
)
# Directly update weight in model
# Mostly adaptations of https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py
# and similar code in https://github.com/huggingface/diffusers/issues/3064
# TODO: handle mid weights (how do they even work?)
for key, lora_weight in weight_dict.items():
curr_layer = model
if ("text" not in key and process_unet) or (
"text" in key and not process_unet
):
layer_infos = (
key.split(".")[0].split(splitting_prefix)[-1].split("_")
)
else:
continue
layer_infos = key.split(".")[0].split(splitting_prefix)[-1].split("_")
# find the target layer
temp_name = layer_infos.pop(0)
@@ -672,46 +700,46 @@ def processLoRA(model, use_lora, splitting_prefix):
else:
temp_name = layer_infos.pop(0)
pair_keys = []
if "lora_down" in key:
pair_keys.append(key.replace("lora_down", "lora_up"))
pair_keys.append(key)
else:
pair_keys.append(key)
pair_keys.append(key.replace("lora_up", "lora_down"))
# update weight
if len(state_dict[pair_keys[0]].shape) == 4:
weight_up = (
state_dict[pair_keys[0]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
)
weight = curr_layer.weight.data
scale = lora_weight.alpha * lora_strength
if len(weight.size()) == 2:
if len(lora_weight.up.shape) == 4:
weight_up = (
lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
)
weight_down = (
lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
)
change = (
torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
)
else:
change = torch.mm(lora_weight.up, lora_weight.down)
elif lora_weight.down.size()[2:4] == (1, 1):
weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = (
state_dict[pair_keys[1]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
)
curr_layer.weight.data += alpha * torch.mm(
weight_up, weight_down
).unsqueeze(2).unsqueeze(3)
change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
weight_up = state_dict[pair_keys[0]].to(torch.float32)
weight_down = state_dict[pair_keys[1]].to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
# update visited list
for item in pair_keys:
visited.append(item)
change = torch.nn.functional.conv2d(
lora_weight.down.permute(1, 0, 2, 3),
lora_weight.up,
).permute(1, 0, 2, 3)
curr_layer.weight.data += change * scale
return model
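
Illustrative sketch (not part of the diff): for linear layers the rewritten merge reduces to the standard LoRA formula W' = W + strength * alpha_scale * (up @ down), with the convolutional cases handled via squeezing or a conv2d of the pair. A tiny numeric check of the linear case with random matrices and a made-up alpha scale:

import torch

def merge_linear_lora(weight, up, down, alpha_scale=1.0, strength=0.75):
    # Rank-r update of the base weight; up is (out, r), down is (r, in).
    return weight + strength * alpha_scale * torch.mm(up, down)

w = torch.zeros(8, 8)
up, down = torch.randn(8, 4), torch.randn(4, 8)  # rank-4 LoRA pair
merged = merge_linear_lora(w, up, down, alpha_scale=0.5, strength=1.0)
print(torch.allclose(merged, 0.5 * up @ down))  # True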
def update_lora_weight_for_unet(unet, use_lora):
def update_lora_weight_for_unet(unet, use_lora, lora_strength):
extensions = [".bin", ".safetensors", ".pt"]
if not any([extension in use_lora for extension in extensions]):
# We assume it is a HF ID with standalone LoRA weights.
unet.load_attn_procs(use_lora)
print(
f"updated unet weights via diffusers load_attn_procs from LoRA: {use_lora}"
)
return unet
main_file_name = get_path_stem(use_lora)
@@ -727,16 +755,21 @@ def update_lora_weight_for_unet(unet, use_lora):
try:
dir_name = os.path.dirname(use_lora)
unet.load_attn_procs(dir_name, weight_name=main_file_name)
print(
f"updated unet weights via diffusers load_attn_procs from LoRA: {use_lora}"
)
return unet
except:
return processLoRA(unet, use_lora, "lora_unet_")
print(f"updated unet weights manually from LoRA: {use_lora}")
return processLoRA(unet, use_lora, "lora_unet_", lora_strength)
def update_lora_weight(model, use_lora, model_name):
def update_lora_weight(model, use_lora, model_name, lora_strength):
if "unet" in model_name:
return update_lora_weight_for_unet(model, use_lora)
return update_lora_weight_for_unet(model, use_lora, lora_strength)
try:
return processLoRA(model, use_lora, "lora_te_")
print(f"updating CLIP weights from LoRA: {use_lora}")
return processLoRA(model, use_lora, "lora_te_", lora_strength)
except:
return None
@@ -892,7 +925,7 @@ def save_output_img(output_img, img_seed, extra_info=None):
img_lora = None
if args.use_lora:
img_lora = Path(os.path.basename(args.use_lora)).stem
img_lora = f"{Path(os.path.basename(args.use_lora)).stem}:{args.lora_strength}"
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
@@ -1002,8 +1035,7 @@ def get_generation_text_info(seeds, device):
# Both width and height should be in the range of [128, 768] and multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.
def resize_stencil(image: Image.Image):
width, height = image.size
def resize_stencil(image: Image.Image, width, height):
aspect_ratio = width / height
min_size = min(width, height)
if min_size < 128:

View File

@@ -19,6 +19,9 @@ a = Analysis(
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
module_collection_mode={
'gradio': 'py', # Collect gradio package as source .py files
},
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

View File

@@ -12,7 +12,6 @@ from apps.stable_diffusion.web.api.utils import (
decode_base64_to_image,
get_model_from_request,
get_scheduler_from_request,
get_lora_params,
get_device,
GenerationInputData,
GenerationResponseData,
@@ -180,7 +179,6 @@ def txt2img_api(InputData: Txt2ImgInputData):
scheduler = get_scheduler_from_request(
InputData, "txt2img_hires" if InputData.enable_hr else "txt2img"
)
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
print(
f"Prompt: {InputData.prompt}, "
@@ -208,8 +206,8 @@ def txt2img_api(InputData: Txt2ImgInputData):
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
lora_weights=frozen_args.use_lora,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
use_hiresfix=InputData.enable_hr,
@@ -270,7 +268,6 @@ def img2img_api(
fallback_model="stabilityai/stable-diffusion-2-1-base",
)
scheduler = get_scheduler_from_request(InputData, "img2img")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.init_images[0])
mask_image = (
@@ -308,8 +305,8 @@ def img2img_api(
use_stencil=InputData.use_stencil,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
lora_weights=frozen_args.use_lora,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
resample_type=frozen_args.resample_type,
@@ -358,7 +355,6 @@ def inpaint_api(
fallback_model="stabilityai/stable-diffusion-2-inpainting",
)
scheduler = get_scheduler_from_request(InputData, "inpaint")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.image)
mask = decode_base64_to_image(InputData.mask)
@@ -374,7 +370,8 @@ def inpaint_api(
res = inpaint_inf(
InputData.prompt,
InputData.negative_prompt,
{"image": init_image, "mask": mask},
init_image,
mask,
InputData.height,
InputData.width,
InputData.is_full_res,
@@ -392,8 +389,8 @@ def inpaint_api(
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
lora_weights=frozen_args.use_lora,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)
@@ -447,7 +444,6 @@ def outpaint_api(
fallback_model="stabilityai/stable-diffusion-2-inpainting",
)
scheduler = get_scheduler_from_request(InputData, "outpaint")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.init_images[0])
@@ -483,8 +479,8 @@ def outpaint_api(
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
lora_weights=frozen_args.use_lora,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)
@@ -530,7 +526,6 @@ def upscaler_api(
fallback_model="stabilityai/stable-diffusion-x4-upscaler",
)
scheduler = get_scheduler_from_request(InputData, "upscaler")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.init_images[0])
@@ -562,8 +557,8 @@ def upscaler_api(
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
lora_weights=frozen_args.use_lora,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)

View File

@@ -191,17 +191,6 @@ def get_scheduler_from_request(
)
def get_lora_params(use_lora: str):
# TODO: since the inference functions in the webui, which we are
# still calling into for the api, jam these back together again before
# handing them off to the pipeline, we should remove this nonsense
# and unify their selection in the UI and command line args proper
if use_lora in get_custom_model_files("lora"):
return (use_lora, "")
return ("None", use_lora)
def get_device(device_str: str):
# first substring match in the list available devices, with first
# device when none are matched

View File

@@ -75,17 +75,20 @@ if __name__ == "__main__":
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
from apps.stable_diffusion.web.utils.gradio_configs import (
config_gradio_tmp_imgs_folder,
from apps.stable_diffusion.web.utils.tmp_configs import (
config_tmp,
shark_tmp,
)
config_gradio_tmp_imgs_folder()
config_tmp()
import gradio as gr
# Create custom models folders if they don't exist
from apps.stable_diffusion.web.ui.utils import (
create_custom_models_folders,
nodicon_loc,
mask_editor_value_for_gallery_data,
mask_editor_value_for_image_file,
)
create_custom_models_folders()
@@ -97,8 +100,6 @@ if __name__ == "__main__":
)
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_custom_model,
@@ -109,6 +110,16 @@ if __name__ == "__main__":
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
# SDXL
txt2img_sdxl_web,
txt2img_sdxl_custom_model,
txt2img_sdxl_gallery,
txt2img_sdxl_png_info_img,
txt2img_sdxl_status,
txt2img_sdxl_sendto_img2img,
txt2img_sdxl_sendto_inpaint,
txt2img_sdxl_sendto_outpaint,
txt2img_sdxl_sendto_upscaler,
# h2ogpt_upload,
# h2ogpt_web,
img2img_web,
@@ -145,7 +156,7 @@ if __name__ == "__main__":
upscaler_sendto_outpaint,
# lora_train_web,
# model_web,
# model_config_web,
model_config_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
@@ -159,6 +170,7 @@ if __name__ == "__main__":
outputgallery_watch,
outputgallery_filename,
outputgallery_sendto_txt2img,
outputgallery_sendto_txt2img_sdxl,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,
@@ -168,11 +180,21 @@ if __name__ == "__main__":
# init global sd pipeline and config
global_obj._init()
def register_button_click(button, selectedid, inputs, outputs):
def register_sendto_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
x.root[0].image.path if len(x.root) != 0 else None,
gr.Tabs(selected=selectedid),
),
inputs,
outputs,
)
def register_sendto_editor_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
mask_editor_value_for_gallery_data(x),
gr.Tabs(selected=selectedid),
),
inputs,
outputs,
@@ -183,24 +205,45 @@ if __name__ == "__main__":
lambda x: (
"None",
x,
gr.Tabs.update(selected=selectedid),
gr.Tabs(selected=selectedid),
),
inputs,
outputs,
queue=False,
)
def register_outputgallery_button(button, selectedid, inputs, outputs):
def register_outputgallery_sendto_button(
button, selectedid, inputs, outputs
):
button.click(
lambda x: (
x,
gr.Tabs.update(selected=selectedid),
gr.Tabs(selected=selectedid),
),
inputs,
outputs,
)
def register_outputgallery_sendto_editor_button(
button, selectedid, inputs, outputs
):
button.click(
lambda x: (
mask_editor_value_for_image_file(x),
gr.Tabs(selected=selectedid),
),
inputs,
outputs,
)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
gradio_workarounds = resource_path("ui/js/sd_gradio_workarounds.js")
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="SHARK AI Studio"
css=dark_theme,
js=gradio_workarounds,
analytics_enabled=False,
title="SHARK AI Studio",
) as sd_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
@@ -225,34 +268,37 @@ if __name__ == "__main__":
if args.output_gallery:
with gr.TabItem(label="Output Gallery", id=5) as og_tab:
outputgallery_web.render()
# extra output gallery configuration
outputgallery_tab_select(og_tab.select)
outputgallery_watch(
[
txt2img_status,
img2img_status,
inpaint_status,
outpaint_status,
upscaler_status,
]
)
# with gr.TabItem(label="Model Manager", id=6):
# model_web.render()
# with gr.TabItem(label="LoRA Training (Experimental)", id=7):
# lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=8):
stablelm_chat.render()
# with gr.TabItem(
# label="Generate Sharding Config (Experimental)", id=9
# ):
# model_config_web.render()
with gr.TabItem(label="MultiModal (Experimental)", id=10):
minigpt4_web.render()
# with gr.TabItem(
# label="Generate Sharding Config (Experimental)", id=9
# ):
# model_config_web.render()
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
# minigpt4_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
with gr.TabItem(label="Text-to-Image (SDXL)", id=13):
txt2img_sdxl_web.render()
# extra output gallery configuration
outputgallery_tab_select(og_tab.select)
outputgallery_watch(
[
txt2img_status,
img2img_status,
inpaint_status,
outpaint_status,
upscaler_status,
txt2img_sdxl_status,
],
)
actual_port = app.usable_port()
if actual_port != args.server_port:
@@ -264,133 +310,139 @@ if __name__ == "__main__":
)
# send to buttons
register_button_click(
register_sendto_click(
txt2img_sendto_img2img,
1,
[txt2img_gallery],
[img2img_init_image, tabs],
)
register_button_click(
register_sendto_editor_click(
txt2img_sendto_inpaint,
2,
[txt2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
register_sendto_click(
txt2img_sendto_outpaint,
3,
[txt2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
register_sendto_click(
txt2img_sendto_upscaler,
4,
[txt2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
register_sendto_editor_click(
img2img_sendto_inpaint,
2,
[img2img_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
register_sendto_click(
img2img_sendto_outpaint,
3,
[img2img_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
register_sendto_click(
img2img_sendto_upscaler,
4,
[img2img_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
register_sendto_click(
inpaint_sendto_img2img,
1,
[inpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
register_sendto_click(
inpaint_sendto_outpaint,
3,
[inpaint_gallery],
[outpaint_init_image, tabs],
)
register_button_click(
register_sendto_click(
inpaint_sendto_upscaler,
4,
[inpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
register_sendto_click(
outpaint_sendto_img2img,
1,
[outpaint_gallery],
[img2img_init_image, tabs],
)
register_button_click(
register_sendto_editor_click(
outpaint_sendto_inpaint,
2,
[outpaint_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
register_sendto_click(
outpaint_sendto_upscaler,
4,
[outpaint_gallery],
[upscaler_init_image, tabs],
)
register_button_click(
register_sendto_click(
upscaler_sendto_img2img,
1,
[upscaler_gallery],
[img2img_init_image, tabs],
)
register_button_click(
register_sendto_editor_click(
upscaler_sendto_inpaint,
2,
[upscaler_gallery],
[inpaint_init_image, tabs],
)
register_button_click(
register_sendto_click(
upscaler_sendto_outpaint,
3,
[upscaler_gallery],
[outpaint_init_image, tabs],
)
if args.output_gallery:
register_outputgallery_button(
register_outputgallery_sendto_button(
outputgallery_sendto_txt2img,
0,
[outputgallery_filename],
[txt2img_png_info_img, tabs],
)
register_outputgallery_button(
register_outputgallery_sendto_button(
outputgallery_sendto_img2img,
1,
[outputgallery_filename],
[img2img_init_image, tabs],
)
register_outputgallery_button(
register_outputgallery_sendto_editor_button(
outputgallery_sendto_inpaint,
2,
[outputgallery_filename],
[inpaint_init_image, tabs],
)
register_outputgallery_button(
register_outputgallery_sendto_button(
outputgallery_sendto_outpaint,
3,
[outputgallery_filename],
[outpaint_init_image, tabs],
)
register_outputgallery_button(
register_outputgallery_sendto_button(
outputgallery_sendto_upscaler,
4,
[outputgallery_filename],
[upscaler_init_image, tabs],
)
register_outputgallery_sendto_button(
outputgallery_sendto_txt2img_sdxl,
0,
[outputgallery_filename],
[txt2img_sdxl_png_info_img, tabs],
)
register_modelmanager_button(
modelmanager_sendto_txt2img,
0,

View File

@@ -10,6 +10,18 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.txt2img_sdxl_ui import (
txt2img_sdxl_inf,
txt2img_sdxl_web,
txt2img_sdxl_custom_model,
txt2img_sdxl_gallery,
txt2img_sdxl_status,
txt2img_sdxl_png_info_img,
txt2img_sdxl_sendto_img2img,
txt2img_sdxl_sendto_inpaint,
txt2img_sdxl_sendto_outpaint,
txt2img_sdxl_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_inf,
img2img_web,
@@ -76,6 +88,7 @@ from apps.stable_diffusion.web.ui.outputgallery_ui import (
outputgallery_watch,
outputgallery_filename,
outputgallery_sendto_txt2img,
outputgallery_sendto_txt2img_sdxl,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,

View File

@@ -0,0 +1,64 @@
import gradio as gr
from apps.stable_diffusion.web.ui.utils import (
HSLHue,
hsl_color,
get_lora_metadata,
)
# Answers HTML to show the most frequent tags used when a LoRA was trained,
# taken from the metadata of its .safetensors file.
def lora_changed(lora_file):
# tag frequency percentage that gets the maximum amount of the starting hue
TAG_COLOR_THRESHOLD = 0.55
# tag frequency percentage above which a tag is displayed
TAG_DISPLAY_THRESHOLD = 0.65
# template for the html used to display a tag
TAG_HTML_TEMPLATE = '<span class="lora-tag" style="border: 1px solid {color};">{tag}</span>'
if lora_file == "None":
return ["<div><i>No LoRA selected</i></div>"]
elif not lora_file.lower().endswith(".safetensors"):
return [
"<div><i>Only metadata queries for .safetensors files are currently supported</i></div>"
]
else:
metadata = get_lora_metadata(lora_file)
if metadata:
frequencies = metadata["frequencies"]
return [
"".join(
[
f'<div class="lora-model">Trained against weights in: {metadata["model"]}</div>'
]
+ [
TAG_HTML_TEMPLATE.format(
color=hsl_color(
(tag[1] - TAG_COLOR_THRESHOLD)
/ (1 - TAG_COLOR_THRESHOLD),
start=HSLHue.RED,
end=HSLHue.GREEN,
),
tag=tag[0],
)
for tag in frequencies
if tag[1] > TAG_DISPLAY_THRESHOLD
],
)
]
elif metadata is None:
return [
"<div><i>This LoRA does not publish tag frequency metadata</i></div>"
]
else:
return [
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
]
def lora_strength_changed(strength):
if strength > 1.0:
return gr.Number(elem_classes="value-out-of-range")
else:
return gr.Number(elem_classes="")
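# A hypothetical illustration (not part of the module) of the metadata shape
# lora_changed() consumes; the exact structure returned by get_lora_metadata()
# is an assumption inferred from the tag[0]/tag[1] indexing above.
_example_metadata = {
    "model": "runwayml/stable-diffusion-v1-5",
    "frequencies": [("portrait", 0.92), ("smiling", 0.70), ("outdoors", 0.40)],
}
# With TAG_DISPLAY_THRESHOLD = 0.65 only "portrait" and "smiling" would be shown
# as tag pills, and "portrait" would sit closest to the green end of the hue range.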

View File

@@ -105,7 +105,19 @@ body {
background-color: var(--background-fill-primary);
}
/* display in full width for desktop devices */
.generating.svelte-zlszon.svelte-zlszon {
border: none;
}
.generating {
border: none !important;
}
#chatbot {
height: 100% !important;
}
/* display in full width for desktop devices, but see below */
@media (min-width: 1536px)
{
.gradio-container {
@@ -113,12 +125,17 @@ body {
}
}
.gradio-container .contain {
padding: 0 var(--size-4) !important;
/* media rules in custom css don't appear to be applied in
gradio versions > 4.7, so we have to define a class which
we will manually need to add and remove using javascript.
Remove this once this is fixed in gradio.
*/
.gradio-container-size-full {
max-width: var(--size-full) !important;
}
#ui_title {
padding: var(--size-2) 0 0 var(--size-1);
.gradio-container .contain {
padding: 0 var(--size-4) !important;
}
#top_logo {
@@ -128,6 +145,10 @@ body {
border: 0;
}
#ui_title {
padding: var(--size-2) 0 0 var(--size-1);
}
#demo_title_outer {
border-radius: 0;
}
@@ -170,6 +191,8 @@ footer {
aspect-ratio: unset;
max-height: calc(55vh - (2 * var(--spacing-lg)));
}
/* fix width and height of gallery items when on very large desktop screens, but see below */
@media (min-width: 1921px) {
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
#gallery .grid-wrap, #gallery .preview{
@@ -181,6 +204,20 @@ footer {
max-height: 770px !important;
}
}
/* media rules in custom css don't appear to be applied in
gradio versions > 4.7, so we have to define classes which
we will manually need to add and remove using javascript.
Remove this once this is fixed in gradio.
*/
.gallery-force-height768 .grid-wrap, .gallery-force-height768 .preview {
min-height: calc(768px + 4px + var(--size-14)) !important;
max-height: calc(768px + 4px + var(--size-14)) !important;
}
.gallery-limit-height768 .thumbnail-item.thumbnail-lg {
max-height: 770px !important;
}
/* Don't upscale when viewing in solo image mode */
#gallery .preview img {
object-fit: scale-down;
@@ -222,18 +259,19 @@ footer {
display:none;
}
/* Hide the download icon from the nod logo */
#top_logo button {
display: none;
}
/* workarounds for container=false not currently working for dropdowns */
.dropdown_no_container {
padding: 0 !important;
}
#output_subdir_container :first-child {
border: none;
#output_subdir_container {
background-color: var(--block-background-fill);
padding-right: 8px;
}
/* number input value is out of range */
.value-out-of-range input[type="number"] {
color: red !important;
}
/* reduced animation load when generating */
@@ -246,16 +284,54 @@ footer {
background-color: var(--block-label-background-fill);
}
/* lora tag pills */
.lora-tags {
border: 1px solid var(--border-color-primary);
color: var(--block-info-text-color) !important;
padding: var(--block-padding);
}
.lora-tag {
display: inline-block;
height: 2em;
color: rgb(212 212 212) !important;
margin-right: 5pt;
margin-bottom: 5pt;
padding: 2pt 5pt;
border-radius: 5pt;
white-space: nowrap;
}
.lora-model {
margin-bottom: var(--spacing-lg);
color: var(--block-info-text-color) !important;
line-height: var(--line-sm);
}
/* output gallery tab */
.output_parameters_dataframe table.table {
/* works around a gradio bug that always shows scrollbars */
overflow: clip auto;
}
.output_parameters_dataframe .cell-wrap span {
/* inadequate workaround for gradio issue #6086 */
user-select:text !important;
-moz-user-select:text !important;
-webkit-user-select:text !important;
-o-user-select:text !important;
-ms-user-select:text !important;
}
.output_parameters_dataframe tbody td {
font-size: small;
line-height: var(--line-xs)
line-height: var(--line-xs);
}
.output_icon_button {
max-width: 30px;
align-self: end;
padding-bottom: 8px;
padding-bottom: 16px !important;
}
.outputgallery_sendto {
@@ -272,6 +348,11 @@ footer {
object-fit: contain !important;
}
/* use the whole gallery area for previews */
#outputgallery_gallery .preview {
width: inherit;
}
/* centered logo for when there are no images */
#top_logo.logo_centered {
height: 100%;

View File

@@ -5,6 +5,13 @@ import gradio as gr
import PIL
from math import ceil
from PIL import Image
from gradio.components.image_editor import (
Brush,
Eraser,
EditorData,
EditorValue,
)
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -14,6 +21,10 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
@@ -29,6 +40,11 @@ from apps.stable_diffusion.src.utils import (
get_generation_text_info,
resampler_list,
)
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
ZoeDetector,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
import numpy as np
@@ -58,14 +74,17 @@ def img2img_inf(
precision: str,
device: str,
max_length: int,
use_stencil: str,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
control_mode: str,
stencils: list,
images: list,
preprocessed_hints: list,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -87,14 +106,25 @@ def img2img_inf(
args.img_path = "not none"
args.ondemand = ondemand
if image_dict is None:
for i, stencil in enumerate(stencils):
if images[i] is None and stencil is not None:
continue
if images[i] is not None:
if isinstance(images[i], dict):
images[i] = images[i]["composite"]
images[i] = images[i].convert("RGB")
if image_dict is None and images[0] is None:
return None, "An Initial Image is required"
if use_stencil == "scribble":
image = image_dict["mask"].convert("RGB")
elif isinstance(image_dict, PIL.Image.Image):
if isinstance(image_dict, PIL.Image.Image):
image = image_dict.convert("RGB")
else:
elif image_dict:
image = image_dict["image"].convert("RGB")
else:
# TODO: enable t2i + controlnets
image = None
if image:
image, _, _ = resize_stencil(image, width, height)
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
@@ -114,19 +144,18 @@ def img2img_inf(
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
use_stencil = None if use_stencil == "None" else use_stencil
args.use_stencil = use_stencil
if use_stencil is not None:
args.scheduler = "DDIM"
stencil_count = 0
for stencil in stencils:
if stencil is not None:
stencil_count += 1
if stencil_count > 0:
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, width, height = resize_stencil(image)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete "
@@ -136,6 +165,7 @@ def img2img_inf(
cpu_scheduling = not args.scheduler.startswith("Shark")
args.precision = precision
dtype = torch.float32 if precision == "fp32" else torch.half
print(stencils)
new_config_obj = Config(
"img2img",
args.hf_model_id,
@@ -148,13 +178,19 @@ def img2img_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=use_stencil,
lora_strength=args.lora_strength,
stencils=stencils,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
or any(
global_obj.get_cfg_obj().stencils[idx] != stencil
for idx, stencil in enumerate(stencils)
)
):
print("clearing config because you changed something important")
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_count = batch_count
@@ -170,12 +206,12 @@ def img2img_inf(
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
else "runwayml/stable-diffusion-v1-5"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(args.scheduler)
if use_stencil is not None:
if stencil_count > 0:
args.use_tuned = False
global_obj.set_sd_obj(
StencilPipeline.from_pretrained(
@@ -192,9 +228,10 @@ def img2img_inf(
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
stencils=stencils,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -216,6 +253,7 @@ def img2img_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -249,8 +287,11 @@ def img2img_inf(
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
stencils,
images,
resample_type=resample_type,
control_mode=control_mode,
preprocessed_hints=preprocessed_hints,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
@@ -270,12 +311,18 @@ def img2img_inf(
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Image-to-Image", current_batch + 1, batch_count, batch_size
)
), stencils, images
return generated_imgs, text_output, ""
return generated_imgs, text_output, "", stencils, images
with gr.Blocks(title="Image-to-Image") as img2img_web:
# Stencils
# TODO: Add more stencils here
STENCIL_COUNT = 2
stencils = gr.State([None] * STENCIL_COUNT)
images = gr.State([None] * STENCIL_COUNT)
preprocessed_hints = gr.State([None] * STENCIL_COUNT)
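# Slot i of these three parallel lists is assumed to hold, respectively, the
# stencil/controlnet model name, the editor image, and the preprocessed hint
# for "Controlnet i+1" below; update_cn_input() and cnet_preview() fill them by index.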
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
@@ -284,6 +331,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
value=nod_logo,
show_label=False,
interactive=False,
show_download_button=False,
elem_id="top_logo",
width=150,
height=50,
@@ -291,6 +339,13 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
# TODO: make this import image prompt info if it exists
img2img_init_image = gr.Image(
label="Input Image",
type="pil",
interactive=True,
sources=["upload"],
)
with gr.Row():
# janky fix for overflowing text
i2i_model_info = (
@@ -337,104 +392,449 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
lines=2,
elem_id="negative_prompt_box",
)
# TODO: make this import image prompt info if it exists
img2img_init_image = gr.Image(
label="Input Image",
source="upload",
tool="sketch",
type="pil",
height=300,
)
with gr.Accordion(label="Multistencil Options", open=False):
choices = [
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
]
def cnet_preview(
model,
input_image,
index,
stencils,
images,
preprocessed_hints,
):
if isinstance(input_image, PIL.Image.Image):
img_dict = {
"background": None,
"layers": [None],
"composite": input_image,
}
input_image = EditorValue(img_dict)
images[index] = input_image
if model:
stencils[index] = model
match model:
case "canny":
canny = CannyDetector()
result = canny(
np.array(input_image["composite"]),
100,
200,
)
preprocessed_hints[index] = Image.fromarray(
result
)
return (
Image.fromarray(result),
stencils,
images,
preprocessed_hints,
)
case "openpose":
openpose = OpenposeDetector()
result = openpose(
np.array(input_image["composite"])
)
preprocessed_hints[index] = Image.fromarray(
result[0]
)
return (
Image.fromarray(result[0]),
stencils,
images,
preprocessed_hints,
)
case "zoedepth":
zoedepth = ZoeDetector()
result = zoedepth(
np.array(input_image["composite"])
)
preprocessed_hints[index] = Image.fromarray(
result
)
return (
Image.fromarray(result),
stencils,
images,
preprocessed_hints,
)
case "scribble":
preprocessed_hints[index] = input_image[
"composite"
]
return (
input_image["composite"],
stencils,
images,
preprocessed_hints,
)
case _:
preprocessed_hints[index] = None
return (
None,
stencils,
images,
preprocessed_hints,
)
def import_original(original_img, width, height):
resized_img, _, _ = resize_stencil(
original_img, width, height
)
img_dict = {
"background": resized_img,
"layers": [resized_img],
"composite": None,
}
return gr.ImageEditor(
value=EditorValue(img_dict),
crop_size=(width, height),
)
def create_canvas(width, height):
data = Image.fromarray(
np.zeros(
shape=(height, width, 3),
dtype=np.uint8,
)
+ 255
)
img_dict = {
"background": data,
"layers": [data],
"composite": None,
}
return EditorValue(img_dict)
def update_cn_input(
model,
width,
height,
stencils,
images,
preprocessed_hints,
index,
):
if model is None:
stencils[index] = None
images[index] = None
preprocessed_hints[index] = None
return [
gr.ImageEditor(value=None, visible=False),
gr.Image(value=None),
gr.Slider(visible=False),
gr.Slider(visible=False),
gr.Button(visible=False),
gr.Button(visible=False),
stencils,
images,
preprocessed_hints,
]
elif model == "scribble":
return [
gr.ImageEditor(
visible=True,
interactive=True,
show_label=False,
image_mode="RGB",
type="pil",
brush=Brush(
colors=["#000000"],
color_mode="fixed",
default_size=2,
),
),
gr.Image(
visible=True,
show_label=False,
interactive=True,
show_download_button=False,
),
gr.Slider(visible=True, label="Canvas Width"),
gr.Slider(visible=True, label="Canvas Height"),
gr.Button(visible=True),
gr.Button(visible=False),
stencils,
images,
preprocessed_hints,
]
else:
return [
gr.ImageEditor(
visible=True,
image_mode="RGB",
type="pil",
interactive=True,
),
gr.Image(
visible=True,
show_label=False,
interactive=True,
show_download_button=False,
),
gr.Slider(visible=True, label="Input Width"),
gr.Slider(visible=True, label="Input Height"),
gr.Button(visible=False),
gr.Button(visible=True),
stencils,
images,
preprocessed_hints,
]
with gr.Accordion(label="Stencil Options", open=False):
with gr.Row():
use_stencil = gr.Dropdown(
elem_id="stencil_model",
label="Stencil model",
value="None",
choices=[
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
with gr.Column():
cnet_1 = gr.Button(
value="Generate controlnet input"
)
cnet_1_model = gr.Dropdown(
label="Controlnet 1",
value="None",
choices=choices,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=768,
value=512,
step=8,
visible=False,
)
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=768,
value=512,
step=8,
visible=False,
)
make_canvas = gr.Button(
value="Make Canvas!",
visible=False,
)
use_input_img_1 = gr.Button(
value="Use Original Image",
visible=False,
)
cnet_1_image = gr.ImageEditor(
visible=False,
image_mode="RGB",
interactive=True,
show_label=True,
label="Input Image",
type="pil",
)
cnet_1_output = gr.Image(
value=None,
visible=True,
label="Preprocessed Hint",
interactive=True,
)
use_input_img_1.click(
import_original,
[img2img_init_image, canvas_width, canvas_height],
[cnet_1_image],
)
cnet_1_model.change(
fn=(
lambda m, w, h, s, i, p: update_cn_input(
m, w, h, s, i, p, 0
)
),
inputs=[
cnet_1_model,
canvas_width,
canvas_height,
stencils,
images,
preprocessed_hints,
],
outputs=[
cnet_1_image,
cnet_1_output,
canvas_width,
canvas_height,
make_canvas,
use_input_img_1,
stencils,
images,
preprocessed_hints,
],
)
make_canvas.click(
create_canvas,
[canvas_width, canvas_height],
[
cnet_1_image,
],
)
gr.on(
triggers=[cnet_1.click],
fn=(
lambda a, b, s, i, p: cnet_preview(
a, b, 0, s, i, p
)
),
inputs=[
cnet_1_model,
cnet_1_image,
stencils,
images,
preprocessed_hints,
],
outputs=[
cnet_1_output,
stencils,
images,
preprocessed_hints,
],
)
def show_canvas(choice):
if choice == "scribble":
return (
gr.Slider.update(visible=True),
gr.Slider.update(visible=True),
gr.Button.update(visible=True),
)
else:
return (
gr.Slider.update(visible=False),
gr.Slider.update(visible=False),
gr.Button.update(visible=False),
)
def create_canvas(w, h):
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=1,
with gr.Column():
cnet_2 = gr.Button(
value="Generate controlnet input"
)
cnet_2_model = gr.Dropdown(
label="Controlnet 2",
value="None",
choices=choices,
)
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=768,
value=512,
step=8,
visible=False,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=768,
value=512,
step=8,
visible=False,
)
make_canvas = gr.Button(
value="Make Canvas!",
visible=False,
)
use_input_img_2 = gr.Button(
value="Use Original Image",
visible=False,
)
cnet_2_image = gr.ImageEditor(
visible=False,
image_mode="RGB",
interactive=True,
type="pil",
show_label=True,
label="Input Image",
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
use_input_img_2.click(
import_original,
[img2img_init_image, canvas_width, canvas_height],
[cnet_2_image],
)
create_button = gr.Button(
label="Start",
value="Open drawing canvas!",
visible=False,
)
create_button.click(
fn=create_canvas,
inputs=[canvas_width, canvas_height],
outputs=[img2img_init_image],
)
use_stencil.change(
fn=show_canvas,
inputs=use_stencil,
outputs=[canvas_width, canvas_height, create_button],
cnet_2_output = gr.Image(
value=None,
visible=True,
label="Preprocessed Hint",
interactive=True,
)
cnet_2_model.change(
fn=(
lambda m, w, h, s, i, p: update_cn_input(
m, w, h, s, i, p, 1  # second controlnet uses slot 1
)
),
inputs=[
cnet_2_model,
canvas_width,
canvas_height,
stencils,
images,
preprocessed_hints,
],
outputs=[
cnet_2_image,
cnet_2_output,
canvas_width,
canvas_height,
make_canvas,
use_input_img_2,
stencils,
images,
preprocessed_hints,
],
)
make_canvas.click(
create_canvas,
[canvas_width, canvas_height],
[
cnet_2_image,
],
)
cnet_2.click(
fn=(
lambda a, b, s, i, p: cnet_preview(
a, b, 1, s, i, p
)
),
inputs=[
cnet_2_model,
cnet_2_image,
stencils,
images,
preprocessed_hints,
],
outputs=[
cnet_2_output,
stencils,
images,
preprocessed_hints,
],
)
control_mode = gr.Radio(
choices=["Prompt", "Balanced", "Controlnet"],
value="Balanced",
label="Control Mode",
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
i2i_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
lora_weights = gr.Dropdown(
allow_custom_value=True,
label=f"Standalone LoRA Weights",
info=i2i_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# the number is checked on change, so to allow 0.n values
# we have to allow 0, or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
@@ -610,16 +1010,25 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
precision,
device,
max_length,
use_stencil,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
resample_type,
control_mode,
stencils,
images,
preprocessed_hints,
],
outputs=[
img2img_gallery,
std_output,
img2img_status,
stencils,
images,
],
outputs=[img2img_gallery, std_output, img2img_status],
show_progress="minimal" if args.progress_bar else "none",
)
@@ -638,3 +1047,18 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -3,7 +3,15 @@ import torch
import time
import sys
import gradio as gr
import PIL.ImageOps
from PIL import Image
from gradio.components.image_editor import (
Brush,
Eraser,
EditorData,
EditorValue,
)
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -13,6 +21,10 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_paint_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
@@ -35,11 +47,53 @@ init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
def set_image_states(editor_data):
input_mask = editor_data["layers"][0]
# inpaint_inf wants white mask on black background (?), whilst ImageEditor
# delivers black mask on transparent (0 opacity) background
inference_mask = Image.new(
mode="RGB", size=input_mask.size, color=(255, 255, 255)
)
inference_mask.paste(input_mask, input_mask)
inference_mask = PIL.ImageOps.invert(inference_mask)
return (
# we set the ImageEditor data again, because it likes to clear
# the image layers (which include the mask) if the user hasn't
# used the upload button, and we sent it an image
# TODO: work out what is going wrong in that case so we don't have
# to do this
{
"background": editor_data["background"],
"layers": [input_mask],
"composite": None,
},
editor_data["background"],
input_mask,
inference_mask,
)
def reload_image_editor(editor_image, editor_mask):
# we set the ImageEditor data again, because it likes to clear
# the image layers (which include the mask) if the user hasn't
# used the upload button, and we sent it the image
# TODO: work out what is going wrong in that case so we don't have
# to do this
return {
"background": editor_image,
"layers": [editor_mask],
"composite": None,
}
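# A hypothetical, self-contained illustration (not part of the original module)
# of the conversion set_image_states() performs, using a 4x4 editor layer with a
# single painted pixel:
_example_layer = Image.new("RGBA", (4, 4), (0, 0, 0, 0))  # transparent background
_example_layer.putpixel((1, 1), (0, 0, 0, 255))  # one painted (black) pixel
_example_mask = Image.new(mode="RGB", size=_example_layer.size, color=(255, 255, 255))
_example_mask.paste(_example_layer, _example_layer)  # painted pixel lands black on white
_example_mask = PIL.ImageOps.invert(_example_mask)  # white mask on black, as inpaint_inf expects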
# Exposed to UI.
def inpaint_inf(
prompt: str,
negative_prompt: str,
image_dict,
image,
mask_image,
height: int,
width: int,
inpaint_full_res: bool,
@@ -58,7 +112,7 @@ def inpaint_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: int,
):
@@ -99,9 +153,8 @@ def inpaint_inf(
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
@@ -120,7 +173,8 @@ def inpaint_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=None,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
if (
@@ -164,6 +218,7 @@ def inpaint_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -173,8 +228,6 @@ def inpaint_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
image = image_dict["image"]
mask_image = image_dict["mask"]
text_output = ""
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
@@ -221,6 +274,9 @@ def inpaint_inf(
with gr.Blocks(title="Inpainting") as inpaint_web:
editor_image = gr.State()
editor_mask = gr.State()
inference_mask = gr.State()
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
@@ -229,6 +285,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
value=nod_logo,
show_label=False,
interactive=False,
show_download_button=False,
elem_id="top_logo",
width=150,
height=50,
@@ -236,6 +293,16 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
inpaint_init_image = gr.Sketchpad(
label="Masked Image",
type="pil",
sources=("clipboard", "upload"),
interactive=True,
brush=Brush(
colors=["#000000"],
color_mode="fixed",
),
)
with gr.Row():
# janky fix for overflowing text
inpaint_model_info = (
@@ -285,39 +352,32 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
lines=2,
elem_id="negative_prompt_box",
)
inpaint_init_image = gr.Image(
label="Masked Image",
source="upload",
tool="sketch",
type="pil",
height=350,
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
inpaint_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
inpaint_lora_info = f"LoRA Path: {inpaint_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=inpaint_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# the number is checked on change, so to allow 0.n values
# we have to allow 0, or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
@@ -441,6 +501,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
elem_id="gallery",
columns=[2],
object_fit="contain",
# TODO: Re-enable download when fixed in Gradio
show_download_button=False,
)
std_output = gr.Textbox(
value=f"{inpaint_model_info}\n"
@@ -477,7 +539,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
inputs=[
prompt,
negative_prompt,
inpaint_init_image,
editor_image,
inference_mask,
height,
width,
inpaint_full_res,
@@ -496,7 +559,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
],
@@ -507,14 +570,64 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
fn=lambda bc, bs: status_label("Inpaint", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=inpaint_status,
show_progress="none",
)
set_image_states_args = dict(
fn=set_image_states,
inputs=[inpaint_init_image],
outputs=[
inpaint_init_image,
editor_image,
editor_mask,
inference_mask,
],
show_progress="none",
)
reload_image_editor_args = dict(
fn=reload_image_editor,
inputs=[editor_image, editor_mask],
outputs=[inpaint_init_image],
show_progress="none",
)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
# all these trigger generation
prompt_submit = (
prompt.submit(**set_image_states_args)
.then(**status_kwargs)
.then(**kwargs)
.then(**reload_image_editor_args)
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
neg_prompt_submit = (
negative_prompt.submit(**set_image_states_args)
.then(**status_kwargs)
.then(**kwargs)
.then(**reload_image_editor_args)
)
generate_click = (
stable_diffusion.click(**set_image_states_args)
.then(**status_kwargs)
.then(**kwargs)
.then(**reload_image_editor_args)
)
# Attempts to cancel generation
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
# Updates LoRA information when one is selected
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -0,0 +1,49 @@
// workaround for gradio after 4.7 not applying any @media rules from the custom .css file
() => {
console.log(`innerWidth: ${window.innerWidth}` )
// 1536px rules
const mediaQuery1536 = window.matchMedia('(min-width: 1536px)')
function handleWidth1536(event) {
// display in full width for desktop devices
document.querySelectorAll(".gradio-container")
.forEach( (node) => {
if (event.matches) {
node.classList.add("gradio-container-size-full");
} else {
node.classList.remove("gradio-container-size-full")
}
});
}
mediaQuery1536.addEventListener("change", handleWidth1536);
mediaQuery1536.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1536}));
// 1921px rules
const mediaQuery1921 = window.matchMedia('(min-width: 1921px)')
function handleWidth1921(event) {
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
/* Limit height to 768px_height + 2px_margin_height for the thumbnails */
document.querySelectorAll("#gallery")
.forEach( (node) => {
if (event.matches) {
node.classList.add("gallery-force-height768");
node.classList.add("gallery-limit-height768");
} else {
node.classList.remove("gallery-force-height768");
node.classList.remove("gallery-limit-height768");
}
});
}
mediaQuery1921.addEventListener("change", handleWidth1921);
mediaQuery1921.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1921}));
}
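A hedged sketch of how a script like this is typically attached, assuming the js argument that gradio 4.x Blocks accept; the actual wiring used by this app is not shown in this diff, and the filename below is a placeholder.
import gradio as gr

with open("sd_dark_theme.js") as f:  # placeholder path
    custom_js = f.read()

with gr.Blocks(js=custom_js) as demo:  # Blocks runs the single js function on page load
    ...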

View File

@@ -23,6 +23,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
value=nod_logo,
show_label=False,
interactive=False,
show_download_button=False,
elem_id="top_logo",
width=150,
height=50,
@@ -237,9 +238,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
max_length,
training_images_dir,
output_loc,
get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
),
get_custom_vae_or_lora_weights(lora_weights, "lora"),
],
outputs=[std_output],
show_progress="minimal" if args.progress_bar else "none",

View File

@@ -104,8 +104,8 @@ with gr.Blocks() as model_web:
civit_models = gr.Gallery(
label="Civitai Model Gallery",
value=None,
interactive=True,
visible=False,
show_download_button=False,
)
with gr.Row(visible=False) as sendto_btns:

View File

@@ -3,9 +3,11 @@ import torch
import time
import gradio as gr
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -61,7 +63,7 @@ def outpaint_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
):
@@ -101,9 +103,8 @@ def outpaint_inf(
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
@@ -122,7 +123,8 @@ def outpaint_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=None,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
if (
@@ -164,6 +166,7 @@ def outpaint_inf(
args.use_base_vae,
args.use_tuned,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -237,6 +240,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
value=nod_logo,
show_label=False,
interactive=False,
show_download_button=False,
elem_id="top_logo",
width=150,
height=50,
@@ -244,6 +248,9 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
outpaint_init_image = gr.Image(
label="Input Image", type="pil", sources=["upload"]
)
with gr.Row():
outpaint_model_info = (
f"Custom Model Path: {str(get_custom_model_path())}"
@@ -291,37 +298,32 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
lines=2,
elem_id="negative_prompt_box",
)
outpaint_init_image = gr.Image(
label="Input Image",
type="pil",
height=300,
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
outpaint_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
outpaint_lora_info = f"LoRA Path: {outpaint_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=outpaint_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# the number is checked on change, so to allow 0.n values
# we have to allow 0, or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
@@ -524,7 +526,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
],
@@ -546,3 +548,18 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -80,27 +80,26 @@ with gr.Blocks() as outputgallery_web:
label="Getting subdirectories...",
value=nod_logo,
interactive=False,
show_download_button=False,
visible=True,
show_label=True,
elem_id="top_logo",
elem_classes="logo_centered",
)
gallery = gr.Gallery(
label="",
value=gallery_files.value,
visible=False,
show_label=True,
columns=2,
columns=4,
)
with gr.Column(scale=4):
with gr.Box():
with gr.Row():
with gr.Group():
with gr.Row(elem_id="output_subdir_container"):
with gr.Column(
scale=15,
min_width=160,
elem_id="output_subdir_container",
):
subdirectories = gr.Dropdown(
label=f"Subdirectories of {output_dir}",
@@ -108,7 +107,7 @@ with gr.Blocks() as outputgallery_web:
choices=subdirectory_paths.value,
value="",
interactive=True,
elem_classes="dropdown_no_container",
# elem_classes="dropdown_no_container",
allow_custom_value=True,
)
with gr.Column(
@@ -148,10 +147,12 @@ with gr.Blocks() as outputgallery_web:
) as parameters_accordian:
image_parameters = gr.DataFrame(
headers=["Parameter", "Value"],
col_count=2,
col_count=(2, "fixed"),
row_count=(1, "fixed"),
wrap=True,
elem_classes="output_parameters_dataframe",
value=[["Status", "No image selected"]],
interactive=False,
)
with gr.Accordion(label="Send To", open=True):
@@ -162,6 +163,12 @@ with gr.Blocks() as outputgallery_web:
elem_classes="outputgallery_sendto",
size="sm",
)
outputgallery_sendto_txt2img_sdxl = gr.Button(
value="Txt2Img XL",
interactive=False,
elem_classes="outputgallery_sendto",
size="sm",
)
outputgallery_sendto_img2img = gr.Button(
value="Img2Img",
@@ -195,15 +202,18 @@ with gr.Blocks() as outputgallery_web:
def on_clear_gallery():
return [
gr.Gallery.update(
gr.Gallery(
value=[],
visible=False,
),
gr.Image.update(
gr.Image(
visible=True,
),
]
def on_image_columns_change(columns):
return gr.Gallery(columns=columns)
def on_select_subdir(subdir) -> list:
# evt.value is the subdirectory name
new_images = outputgallery_filenames(subdir)
@@ -212,12 +222,12 @@ with gr.Blocks() as outputgallery_web:
)
return [
new_images,
gr.Gallery.update(
gr.Gallery(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image.update(
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
@@ -251,16 +261,16 @@ with gr.Blocks() as outputgallery_web:
)
return [
gr.Dropdown.update(
gr.Dropdown(
choices=refreshed_subdirs,
value=new_subdir,
),
refreshed_subdirs,
new_images,
gr.Gallery.update(
gr.Gallery(
value=new_images, label=new_label, visible=len(new_images) > 0
),
gr.Image.update(
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
@@ -286,12 +296,12 @@ with gr.Blocks() as outputgallery_web:
return [
new_images,
gr.Gallery.update(
gr.Gallery(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image.update(
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
@@ -316,12 +326,18 @@ with gr.Blocks() as outputgallery_web:
else:
return [
filename,
list(map(list, params["parameters"].items())),
gr.DataFrame(
value=list(map(list, params["parameters"].items())),
row_count=(len(params["parameters"]), "fixed"),
),
]
return [
filename,
[["Status", "No parameters found"]],
gr.DataFrame(
value=[["Status", "No parameters found"]],
row_count=(1, "fixed"),
),
]
def on_outputgallery_filename_change(filename: str) -> list:
@@ -329,12 +345,12 @@ with gr.Blocks() as outputgallery_web:
return [
# disable or enable each of the sendto button based on whether
# an image is selected
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button(interactive=exists),
gr.Button(interactive=exists),
gr.Button(interactive=exists),
gr.Button(interactive=exists),
gr.Button(interactive=exists),
gr.Button(interactive=exists),
]
# The first time our tab is selected we need to do an initial refresh
@@ -365,53 +381,6 @@ with gr.Blocks() as outputgallery_web:
gr.update(),
)
# Unfortunately as of gradio 3.34.0 gr.update against Galleries doesn't
# support things set with .style, nor the elem_classes kwarg, so we have
# to directly set things up via JavaScript if we want the client to take
# notice of our changes to the number of columns after it decides to put
# them back to the original number when we change something
def js_set_columns_in_browser(timeout_length):
return f"""
(new_cols) => {{
setTimeout(() => {{
required_style = "auto ".repeat(new_cols).trim();
gallery = document.querySelector('#outputgallery_gallery .grid-container');
if (gallery) {{
gallery.style.gridTemplateColumns = required_style
}}
}}, {timeout_length});
return []; // prevents console error from gradio
}}
"""
# --- Wire handlers up to the actions
# Many actions reset the number of columns shown in the gallery on the
# browser end, so we have to set them back to what we think they should
# be after the initial action.
#
# None of the actions on this tab trigger inference, and we want the
# user to be able to do them whilst other tabs have ongoing inference
# running. Waiting in the queue behind inference jobs would mean the UI
# can't fully respond until the inference tasks complete,
# hence queue=False on all of these.
set_gallery_columns_immediate = dict(
fn=None,
inputs=[image_columns],
# gradio blanks the UI on Chrome on Linux on gallery select if
# I don't put an output here
outputs=[dev_null],
_js=js_set_columns_in_browser(0),
queue=False,
)
# setting columns after selecting a gallery item needs a real
# timeout length for the number of columns to actually be applied.
# Not really sure why, maybe something has to finish animating?
set_gallery_columns_delayed = dict(
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
)
# clearing images when we need to completely change what's in the
# gallery avoids current images being shown replacing piecemeal and
# prevents weirdness and errors if the user selects an image during the
@@ -423,38 +392,42 @@ with gr.Blocks() as outputgallery_web:
queue=False,
)
image_columns.change(**set_gallery_columns_immediate)
subdirectories.select(**clear_gallery).then(
on_select_subdir,
[subdirectories],
[gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)
open_subdir.click(
on_open_subdir, inputs=[subdirectories], queue=False
).then(**set_gallery_columns_immediate)
open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False)
refresh.click(**clear_gallery).then(
on_refresh,
[subdirectories],
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)
image_columns.change(
fn=on_image_columns_change,
inputs=[image_columns],
outputs=[gallery],
queue=False,
)
gallery.select(
on_select_image,
[gallery_files],
[outputgallery_filename, image_parameters],
queue=False,
).then(**set_gallery_columns_delayed)
)
outputgallery_filename.change(
on_outputgallery_filename_change,
[outputgallery_filename],
[
outputgallery_sendto_txt2img,
outputgallery_sendto_txt2img_sdxl,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,
@@ -477,16 +450,17 @@ with gr.Blocks() as outputgallery_web:
open_subdir,
],
queue=False,
).then(**set_gallery_columns_immediate)
)
# We should have been passed a list of components on other tabs that update
# when a new image has been generated on that tab, so set things up so the user
# will see that new image if they are looking at today's subdirectory
def outputgallery_watch(components: gr.Textbox):
def outputgallery_watch(components: gr.Textbox, queued_components=[]):
for component in components:
component.change(
on_new_image,
inputs=[subdirectories, subdirectory_paths, component],
outputs=[gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
queue=component in queued_components,
show_progress="none",
)

View File

@@ -6,6 +6,7 @@ from transformers import (
AutoModelForCausalLM,
)
from apps.stable_diffusion.web.ui.utils import available_devices
from shark.iree_utils.compile_utils import clean_device_info
from datetime import datetime as dt
import json
import sys
@@ -132,27 +133,6 @@ def get_default_config():
c.split_into_layers()
def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by LLM pipeline
# Multiple devices only supported for vulkan and rocm (as of now).
# default device must be selected for all others
device_id = None
device = (
raw_device
if "=>" not in raw_device
else raw_device.split("=>")[1].strip()
)
if "://" in device:
device, device_id = device.split("://")
device_id = int(device_id) # using device index in webui
if device not in ["rocm", "vulkan"]:
device_id = None
return device, device_id
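# Hypothetical spot-checks of the helper (now imported from
# shark.iree_utils.compile_utils above); the device string formats are an
# assumption based on the parsing logic removed here.
assert clean_device_info("AMD Radeon => vulkan://1") == ("vulkan", 1)  # index kept for vulkan/rocm
assert clean_device_info("cuda://0") == ("cuda", None)  # index dropped for other backends
assert clean_device_info("cpu-task") == ("cpu-task", None)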
model_vmfb_key = ""
@@ -161,7 +141,9 @@ def chat(
prompt_prefix,
history,
model,
device,
backend,
devices,
sharded,
precision,
download_vmfb,
config_file,
@@ -173,7 +155,8 @@ def chat(
global vicuna_model
model_name, model_path = list(map(str.strip, model.split("=>")))
device, device_id = clean_device_info(device)
device, device_id = clean_device_info(devices[0])
no_of_devices = len(devices)
from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
@@ -193,6 +176,11 @@ def chat(
get_vulkan_target_triple,
)
_extra_args = _extra_args + [
"--iree-opt-const-eval=false",
"--iree-opt-data-tiling=false",
]
if device == "vulkan":
vulkaninfo_list = get_all_vulkan_devices()
if vulkan_target_triple == "":
@@ -234,7 +222,7 @@ def chat(
)
print(f"extra args = {_extra_args}")
if model_name == "vicuna4":
if sharded:
vicuna_model = ShardedVicuna(
model_name,
hf_model_path=model_path,
@@ -243,6 +231,7 @@ def chat(
max_num_tokens=max_toks,
compressed=True,
extra_args_cmd=_extra_args,
n_devices=no_of_devices,
)
else:
# if config_file is None:
@@ -270,13 +259,14 @@ def chat(
total_time_ms = 0.001 # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, msg, exec_time in progress.tqdm(
vicuna_model.generate(prompt, cli=cli),
desc="generating response",
):
# for text, msg, exec_time in progress.tqdm(
# vicuna_model.generate(prompt, cli=cli),
# desc="generating response",
# ):
for text, msg, exec_time in vicuna_model.generate(prompt, cli=cli):
if msg is None:
if is_first:
prefill_time = exec_time
prefill_time = exec_time / 1000
is_first = False
else:
total_time_ms += exec_time
@@ -397,6 +387,16 @@ def view_json_file(file_obj):
return content
filtered_devices = dict()
def change_backend(backend):
new_choices = gr.Dropdown(
choices=filtered_devices[backend], label=f"{backend} devices"
)
return new_choices
with gr.Blocks(title="Chatbot") as stablelm_chat:
with gr.Row():
model_choices = list(
@@ -413,15 +413,22 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
# show cpu-task device first in list for chatbot
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
backend_list = ["cpu", "cuda", "vulkan", "rocm"]
for x in backend_list:
filtered_devices[x] = [y for y in supported_devices if x in y]
print(filtered_devices)
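# e.g. with hypothetical device names, after the loop above one might see:
# filtered_devices == {
#     "cpu": ["cpu-task => local-task://0"],
#     "cuda": [],
#     "vulkan": ["AMD Radeon => vulkan://0", "AMD Radeon => vulkan://1"],
#     "rocm": [],
# }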
backend = gr.Radio(
label="backend",
value="cpu",
choices=backend_list,
)
device = gr.Dropdown(
label="Device",
value=supported_devices[0]
if enabled
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
label="cpu devices",
choices=filtered_devices["cpu"],
interactive=True,
allow_custom_value=True,
# multiselect=True,
multiselect=True,
)
precision = gr.Radio(
label="Precision",
@@ -445,18 +452,23 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
value=False,
interactive=True,
)
sharded = gr.Checkbox(
label="Shard Model",
value=False,
interactive=True,
)
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(
label="Upload sharding configuration", visible=False
)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
json_view_button = gr.Button(value="View as JSON", visible=False)
json_view = gr.JSON(visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
chatbot = gr.Chatbot(height=500)
chatbot = gr.Chatbot(elem_id="chatbot")
with gr.Row():
with gr.Column():
msg = gr.Textbox(
@@ -472,6 +484,13 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
backend.change(
fn=change_backend,
inputs=[backend],
outputs=[device],
show_progress=False,
)
submit_event = msg.submit(
fn=user,
inputs=[msg, chatbot],
@@ -484,7 +503,9 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
prompt_prefix,
chatbot,
model,
backend,
device,
sharded,
precision,
download_vmfb,
config_file,
@@ -505,7 +526,9 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
prompt_prefix,
chatbot,
model,
backend,
device,
sharded,
precision,
download_vmfb,
config_file,

View File

@@ -0,0 +1,661 @@
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
from math import ceil
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
predefined_sdxl_models,
cancel_sd,
set_model_default_configs,
)
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
Text2ImageSDXLPipeline,
get_schedulers,
set_init_device_flags,
utils,
save_output_img,
prompt_examples,
Image2ImagePipeline,
)
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
)
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platform = args.iree_metal_target_platform
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
def txt2img_sdxl_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
if precision != "fp16":
print("currently we support fp16 for SDXL")
precision = "fp16"
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
# .safetensor or .chkpt on the custom model path
if model_id in get_custom_model_files():
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = model_id
if custom_vae:
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"txt2img_sdxl",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platform = init_iree_metal_target_platform
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-xl-base-1.0"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
if global_obj.get_cfg_obj().ondemand:
print("Running txt2img in memory efficient mode.")
global_obj.set_sd_obj(
Text2ImageSDXLPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=precision,
max_length=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
use_quantize=args.use_quantize,
ondemand=global_obj.get_cfg_obj().ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
text_output = ""
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None
for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], seeds[current_batch])
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Text-to-Image-SDXL",
current_batch + 1,
batch_count,
batch_size,
)
return generated_imgs, text_output, ""
theme = gr.themes.Glass(
primary_hue="slate",
secondary_hue="gray",
)
with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
show_download_button=False,
elem_id="top_logo",
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
t2i_sdxl_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
txt2img_sdxl_custom_model = gr.Dropdown(
label=f"Models",
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-xl-base-1.0",
choices=predefined_sdxl_models
+ get_custom_model_files(
custom_checkpoint_type="sdxl"
),
allow_custom_value=True,
scale=11,
)
t2i_sdxl_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
t2i_sdxl_vae_info = (
f"VAE Path: {t2i_sdxl_vae_info}"
)
custom_vae = gr.Dropdown(
label=f"VAE Models",
info=t2i_sdxl_vae_info,
elem_id="custom_model",
value="madebyollin/sdxl-vae-fp16-fix",
choices=[
None,
"madebyollin/sdxl-vae-fp16-fix",
]
+ get_custom_model_files(
"vae", custom_checkpoint_type="sdxl"
),
allow_custom_value=True,
scale=4,
)
txt2img_sdxl_png_info_img = gr.Image(
scale=1,
label="Import PNG info",
elem_id="txt2img_prompt_image",
type="pil",
visible=True,
sources=["upload"],
)
with gr.Group(elem_id="prompt_box_outer"):
txt2img_sdxl_autogen = gr.Checkbox(
label="Auto-Generate Images",
value=False,
visible=False,
)
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=2,
elem_id="prompt_box",
show_copy_button=True,
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=2,
elem_id="negative_prompt_box",
show_copy_button=True,
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
lora_weights = gr.Dropdown(
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# the value is validated on every change, so to allow 0.n values
# we must permit 0, or "0.n" can't be typed at all
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="EulerDiscrete",
choices=[
"DDIM",
"EulerAncestralDiscrete",
"EulerDiscrete",
"LCMScheduler",
],
allow_custom_value=True,
visible=True,
)
with gr.Column():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
interactive=True,
)
with gr.Row():
height = gr.Slider(
512,
1024,
value=768,
step=256,
label="Height",
visible=True,
interactive=True,
)
width = gr.Slider(
512,
1024,
value=768,
step=256,
label="Width",
visible=True,
interactive=True,
)
precision = gr.Radio(
label="Precision",
value="fp16",
choices=[
"fp16",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=77,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
with gr.Column(scale=3):
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="Guidance Scale",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=False,
visible=False,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
txt2img_sdxl_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
columns=[2],
object_fit="scale_down",
)
std_output = gr.Textbox(
value=f"{t2i_sdxl_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
)
txt2img_sdxl_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
txt2img_sdxl_sendto_img2img = gr.Button(
value="Send To Img2Img",
visible=False,
)
txt2img_sdxl_sendto_inpaint = gr.Button(
value="Send To Inpaint",
visible=False,
)
txt2img_sdxl_sendto_outpaint = gr.Button(
value="Send To Outpaint",
visible=False,
)
txt2img_sdxl_sendto_upscaler = gr.Button(
value="Send To Upscaler",
visible=False,
)
kwargs = dict(
fn=txt2img_sdxl_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
txt2img_sdxl_custom_model,
custom_vae,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_strength,
ondemand,
repeatable_seeds,
],
outputs=[txt2img_sdxl_gallery, std_output, txt2img_sdxl_status],
show_progress="minimal" if args.progress_bar else "none",
queue=True,
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Text-to-Image-SDXL", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=txt2img_sdxl_status,
concurrency_limit=1,
)
def autogen_changed(checked):
if checked:
args.autogen = True
else:
args.autogen = False
def check_last_input(prompt):
if not prompt.endswith(" "):
return True
elif not args.autogen:
return True
else:
return False
auto_gen_kwargs = dict(
fn=check_last_input,
inputs=[negative_prompt],
outputs=[txt2img_sdxl_status],
concurrency_limit=1,
)
txt2img_sdxl_autogen.change(
fn=autogen_changed,
inputs=[txt2img_sdxl_autogen],
outputs=None,
)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[
prompt_submit,
neg_prompt_submit,
generate_click,
],
)
txt2img_sdxl_png_info_img.change(
fn=import_png_metadata,
inputs=[
txt2img_sdxl_png_info_img,
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
seed,
width,
height,
txt2img_sdxl_custom_model,
lora_weights,
custom_vae,
],
outputs=[
txt2img_sdxl_png_info_img,
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
seed,
width,
height,
txt2img_sdxl_custom_model,
lora_weights,
custom_vae,
],
)
txt2img_sdxl_custom_model.change(
fn=set_model_default_configs,
inputs=[
txt2img_sdxl_custom_model,
],
outputs=[
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
width,
height,
custom_vae,
txt2img_sdxl_autogen,
],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)
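`lora_strength_changed` is imported from `common_ui_events` but its body is not included in this compare view. Given the comment on the Number input (0 must be allowed so values like 0.7 can be typed), a handler along these lines would be a reasonable guess; treat it as a hypothetical sketch rather than the repo's actual implementation.

```python
def lora_strength_changed(strength):
    # Hypothetical sketch; the real handler is defined in common_ui_events.py.
    # The Number input permits 0 only so that "0.n" can be typed, so clamp
    # whatever the user leaves behind into the 0.0-2.0 range used by the UI.
    if strength is None:
        return 1.0
    return max(0.0, min(float(strength), 2.0))
```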

View File

@@ -1,10 +1,13 @@
import json
import os
import warnings
import torch
import time
import sys
import gradio as gr
from PIL import Image
from math import ceil
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -15,6 +18,10 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
@@ -33,6 +40,34 @@ from apps.stable_diffusion.src.utils import (
resampler_list,
)
# Names of all interactive fields that can be edited by user
all_gradio_labels = [
"txt2img_custom_model",
"custom_vae",
"prompt",
"negative_prompt",
"lora_weights",
"lora_strength",
"scheduler",
"save_metadata_to_png",
"save_metadata_to_json",
"height",
"width",
"steps",
"guidance_scale",
"Low VRAM",
"use_hiresfix",
"resample_type",
"hiresfix_height",
"hiresfix_width",
"hiresfix_strength",
"batch_count",
"batch_size",
"repeatable_seeds",
"seed",
"device",
]
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platform = args.iree_metal_target_platform
@@ -59,7 +94,7 @@ def txt2img_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
use_hiresfix: bool,
@@ -106,9 +141,8 @@ def txt2img_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
@@ -124,7 +158,8 @@ def txt2img_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=None,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
if (
@@ -175,6 +210,7 @@ def txt2img_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -224,7 +260,8 @@ def txt2img_inf(
width,
device,
use_lora=args.use_lora,
use_stencil="None",
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
@@ -256,6 +293,7 @@ def txt2img_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -278,7 +316,9 @@ def txt2img_inf(
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil="None",
stencils=[],
images=None,
control_mode=None,
resample_type=resample_type,
)
total_time = time.time() - start_time
@@ -300,7 +340,92 @@ def txt2img_inf(
return generated_imgs, text_output, ""
with gr.Blocks(title="Text-to-Image") as txt2img_web:
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
# Exports the values of all user-editable fields to the settings.json file in the ui folder
def export_settings(*values):
settings_list = list(zip(all_gradio_labels, values))
settings = {}
for label, value in settings_list:
settings[label] = value
settings = {"txt2img": settings}
with open("./ui/settings.json", "w") as json_file:
json.dump(settings, json_file, indent=4)
# Loads the values of all user-editable fields from the settings.json file in the ui folder
def load_settings():
try:
with open("./ui/settings.json", "r") as json_file:
loaded_settings = json.load(json_file)["txt2img"]
except (FileNotFoundError, KeyError):
warnings.warn(
"Settings.json file not found or 'txt2img' key is missing. Using default values for fields."
)
loaded_settings = (
{}
) # json file not existing or the data wasn't saved yet
return [
loaded_settings.get(
"txt2img_custom_model",
os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-1-base",
),
loaded_settings.get(
"custom_vae",
os.path.basename(args.custom_vae) if args.custom_vae else "None",
),
loaded_settings.get("prompt", args.prompts[0]),
loaded_settings.get("negative_prompt", args.negative_prompts[0]),
loaded_settings.get("lora_weights", "None"),
loaded_settings.get("lora_strength", args.lora_strength),
loaded_settings.get("scheduler", args.scheduler),
loaded_settings.get(
"save_metadata_to_png", args.write_metadata_to_png
),
loaded_settings.get(
"save_metadata_to_json", args.save_metadata_to_json
),
loaded_settings.get("height", args.height),
loaded_settings.get("width", args.width),
loaded_settings.get("steps", args.steps),
loaded_settings.get("guidance_scale", args.guidance_scale),
loaded_settings.get("Low VRAM", args.ondemand),
loaded_settings.get("use_hiresfix", args.use_hiresfix),
loaded_settings.get("resample_type", args.resample_type),
loaded_settings.get("hiresfix_height", args.hiresfix_height),
loaded_settings.get("hiresfix_width", args.hiresfix_width),
loaded_settings.get("hiresfix_strength", args.hiresfix_strength),
loaded_settings.get("batch_count", args.batch_count),
loaded_settings.get("batch_size", args.batch_size),
loaded_settings.get("repeatable_seeds", args.repeatable_seeds),
loaded_settings.get("seed", args.seed),
loaded_settings.get("device", available_devices[0]),
]
# Loads the user's exported default settings at program start
def onload_load_settings():
loaded_data = load_settings()
structured_data = list(zip(all_gradio_labels, loaded_data))
return dict(structured_data)
default_settings = onload_load_settings()
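To make the settings file these helpers read and write concrete, here is a minimal round trip against the same `./ui/settings.json` path; the keys and values are illustrative, while the real file stores one entry per label in `all_gradio_labels` under the `"txt2img"` key.

```python
import json
import os

os.makedirs("./ui", exist_ok=True)  # export_settings assumes the ui folder exists
example = {"txt2img": {"prompt": "a shark leaping at sunset", "steps": 30}}
with open("./ui/settings.json", "w") as f:
    json.dump(example, f, indent=4)

with open("./ui/settings.json", "r") as f:
    loaded = json.load(f)["txt2img"]

print(loaded.get("steps", 50))            # 30  (present in the file)
print(loaded.get("guidance_scale", 7.5))  # 7.5 (missing, falls back to the default)
```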
with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
@@ -309,6 +434,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
value=nod_logo,
show_label=False,
interactive=False,
show_download_button=False,
elem_id="top_logo",
width=150,
height=50,
@@ -317,20 +443,20 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Column():
with gr.Row():
t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
txt2img_custom_model = gr.Dropdown(
label=f"Models",
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-1-base",
value=default_settings.get(
"txt2img_custom_model"
),
choices=get_custom_model_files()
+ predefined_models,
allow_custom_value=True,
scale=2,
scale=11,
)
# janky fix for overflowing text
t2i_vae_info = (
@@ -341,93 +467,101 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label=f"VAE Models",
info=t2i_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
value=default_settings.get("custom_vae"),
choices=["None"]
+ get_custom_model_files("vae"),
allow_custom_value=True,
scale=4,
)
txt2img_png_info_img = gr.Image(
label="Import PNG info",
elem_id="txt2img_prompt_image",
type="pil",
visible=True,
sources=["upload"],
scale=1,
)
with gr.Column(scale=1, min_width=170):
txt2img_png_info_img = gr.Image(
label="Import PNG info",
elem_id="txt2img_prompt_image",
type="pil",
tool="None",
visible=True,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
value=default_settings.get("prompt"),
lines=2,
elem_id="prompt_box",
)
# TODO: coming soon
autogen = gr.Checkbox(
label="Continuous Generation",
visible=False,
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
value=default_settings.get("negative_prompt"),
lines=2,
elem_id="negative_prompt_box",
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
t2i_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
t2i_lora_info = f"LoRA Path: {t2i_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=t2i_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
value=default_settings.get("lora_weights"),
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# the value is validated on every change, so to allow 0.n values
# we must permit 0, or "0.n" can't be typed at all
minimum=0.0,
maximum=2.0,
value=default_settings.get("lora_strength"),
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value=args.scheduler,
value=default_settings.get("scheduler"),
choices=scheduler_list,
allow_custom_value=True,
)
with gr.Column():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
value=default_settings.get(
"save_metadata_to_png"
),
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
value=default_settings.get(
"save_metadata_to_json"
),
interactive=True,
)
with gr.Row():
height = gr.Slider(
384,
768,
value=args.height,
value=default_settings.get("height"),
step=8,
label="Height",
)
width = gr.Slider(
384,
768,
value=args.width,
value=default_settings.get("width"),
step=8,
label="Width",
)
@@ -452,18 +586,22 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Row():
with gr.Column(scale=3):
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
1,
100,
value=default_settings.get("steps"),
step=1,
label="Steps",
)
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
value=default_settings.get("guidance_scale"),
step=0.1,
label="CFG Scale",
)
ondemand = gr.Checkbox(
value=args.ondemand,
value=default_settings.get("Low VRAM"),
label="Low VRAM",
interactive=True,
)
@@ -472,7 +610,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
value=default_settings.get("batch_count"),
step=1,
label="Batch Count",
interactive=True,
@@ -483,23 +621,24 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
4,
value=args.batch_size,
step=1,
label="Batch Size",
label=default_settings.get("batch_size"),
interactive=True,
visible=False,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
default_settings.get("repeatable_seeds"),
label="Repeatable Seeds",
)
with gr.Accordion(label="Hires Fix Options", open=False):
with gr.Group():
with gr.Row():
use_hiresfix = gr.Checkbox(
value=args.use_hiresfix,
value=default_settings.get("use_hiresfix"),
label="Use Hires Fix",
interactive=True,
)
resample_type = gr.Dropdown(
value=args.resample_type,
value=default_settings.get("resample_type"),
choices=resampler_list,
label="Resample Type",
allow_custom_value=False,
@@ -507,34 +646,34 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
hiresfix_height = gr.Slider(
384,
768,
value=args.hiresfix_height,
value=default_settings.get("hiresfix_height"),
step=8,
label="Hires Fix Height",
)
hiresfix_width = gr.Slider(
384,
768,
value=args.hiresfix_width,
value=default_settings.get("hiresfix_width"),
step=8,
label="Hires Fix Width",
)
hiresfix_strength = gr.Slider(
0,
1,
value=args.hiresfix_strength,
value=default_settings.get("hiresfix_strength"),
step=0.01,
label="Hires Fix Denoising Strength",
)
with gr.Row():
seed = gr.Textbox(
value=args.seed,
value=default_settings.get("seed"),
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
value=default_settings.get("device"),
choices=available_devices,
allow_custom_value=True,
)
@@ -585,6 +724,75 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
txt2img_sendto_upscaler = gr.Button(
value="SendTo Upscaler"
)
with gr.Row():
with gr.Column(scale=2):
export_defaults = gr.Button(
value="Load Default Settings"
)
export_defaults.click(
fn=load_settings,
inputs=[],
outputs=[
txt2img_custom_model,
custom_vae,
prompt,
negative_prompt,
lora_weights,
lora_strength,
scheduler,
save_metadata_to_png,
save_metadata_to_json,
height,
width,
steps,
guidance_scale,
ondemand,
use_hiresfix,
resample_type,
hiresfix_height,
hiresfix_width,
hiresfix_strength,
batch_count,
batch_size,
repeatable_seeds,
seed,
device,
],
)
with gr.Column(scale=2):
export_defaults = gr.Button(
value="Export Default Settings"
)
export_defaults.click(
fn=export_settings,
inputs=[
txt2img_custom_model,
custom_vae,
prompt,
negative_prompt,
lora_weights,
lora_strength,
scheduler,
save_metadata_to_png,
save_metadata_to_json,
height,
width,
steps,
guidance_scale,
ondemand,
use_hiresfix,
resample_type,
hiresfix_height,
hiresfix_width,
hiresfix_strength,
batch_count,
batch_size,
repeatable_seeds,
seed,
device,
],
outputs=[],
)
kwargs = dict(
fn=txt2img_inf,
@@ -607,7 +815,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
use_hiresfix,
@@ -650,7 +858,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
height,
txt2img_custom_model,
lora_weights,
lora_hf_id,
custom_vae,
],
outputs=[
@@ -665,7 +872,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
height,
txt2img_custom_model,
lora_weights,
lora_hf_id,
lora_strength,
custom_vae,
],
)
@@ -673,12 +880,12 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
# SharkEulerDiscrete doesn't work with img2img which hires_fix uses
def set_compatible_schedulers(hires_fix_selected):
if hires_fix_selected:
return gr.Dropdown.update(
return gr.Dropdown(
choices=scheduler_list_cpu_only,
value="DEISMultistep",
)
else:
return gr.Dropdown.update(
return gr.Dropdown(
choices=scheduler_list,
value="SharkEulerDiscrete",
)
@@ -689,3 +896,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
outputs=[scheduler],
queue=False,
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -3,6 +3,7 @@ import torch
import time
import gradio as gr
from PIL import Image
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -12,6 +13,10 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_upscaler_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
@@ -51,7 +56,7 @@ def upscaler_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
):
@@ -98,9 +103,8 @@ def upscaler_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
@@ -118,7 +122,8 @@ def upscaler_inf(
args.width,
device,
use_lora=args.use_lora,
use_stencil=None,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
if (
@@ -157,6 +162,7 @@ def upscaler_inf(
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -253,6 +259,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
value=nod_logo,
show_label=False,
interactive=False,
show_download_button=False,
elem_id="top_logo",
width=150,
height=50,
@@ -260,6 +267,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
upscaler_init_image = gr.Image(
label="Input Image",
type="pil",
sources=["upload"],
)
with gr.Row():
upscaler_model_info = (
f"Custom Model Path: {str(get_custom_model_path())}"
@@ -308,37 +320,32 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
lines=2,
elem_id="negative_prompt_box",
)
upscaler_init_image = gr.Image(
label="Input Image",
type="pil",
height=300,
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
upscaler_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
upscaler_lora_info = f"LoRA Path: {upscaler_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=upscaler_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# the value is validated on every change, so to allow 0.n values
# we must permit 0, or "0.n" can't be typed at all
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
@@ -515,7 +522,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
],
@@ -537,3 +544,18 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -1,10 +1,19 @@
import os
import sys
from apps.stable_diffusion.src import get_available_devices
import glob
import math
import json
import safetensors
import gradio as gr
import PIL.Image as Image
from pathlib import Path
from apps.stable_diffusion.src import args
from dataclasses import dataclass
from enum import IntEnum
from gradio.components.image_editor import EditorValue
from apps.stable_diffusion.src import get_available_devices
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
@@ -24,10 +33,20 @@ class Config:
width: int
device: str
use_lora: str
use_stencil: str
lora_strength: float
stencils: list[str]
ondemand: str # should this be expecting a bool instead?
class HSLHue(IntEnum):
RED = 0
YELLOW = 60
GREEN = 120
CYAN = 180
BLUE = 240
MAGENTA = 300
custom_model_filetypes = (
"*.ckpt",
"*.safetensors",
@@ -49,9 +68,11 @@ scheduler_list_cpu_only = [
"DPMSolverSinglestep",
"DDPM",
"HeunDiscrete",
"LCMScheduler",
]
scheduler_list = scheduler_list_cpu_only + [
"SharkEulerDiscrete",
"SharkEulerAncestralDiscrete",
]
predefined_models = [
@@ -72,6 +93,10 @@ predefined_paint_models = [
predefined_upscaler_models = [
"stabilityai/stable-diffusion-x4-upscaler",
]
predefined_sdxl_models = [
"stabilityai/sdxl-turbo",
"stabilityai/stable-diffusion-xl-base-1.0",
]
def resource_path(relative_path):
@@ -125,6 +150,12 @@ def get_custom_model_files(model="models", custom_checkpoint_type=""):
)
]
match custom_checkpoint_type:
case "sdxl":
files = [
val
for val in files
if any(x in val for x in ["XL", "xl", "Xl"])
]
case "inpainting":
files = [
val
@@ -150,17 +181,82 @@ def get_custom_model_files(model="models", custom_checkpoint_type=""):
return sorted(ckpt_files, key=str.casefold)
def get_custom_vae_or_lora_weights(weights, hf_id, model):
use_weight = ""
if weights == "None" and not hf_id:
def get_custom_vae_or_lora_weights(weights, model):
if weights == "None":
use_weight = ""
elif not hf_id:
use_weight = get_custom_model_pathfile(weights, model)
else:
use_weight = hf_id
custom_weights = get_custom_model_pathfile(str(weights), model)
if os.path.isfile(custom_weights):
use_weight = custom_weights
else:
use_weight = weights
return use_weight
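With the `hf_id` parameter gone, the resolver above works purely from the dropdown value. Roughly (exact paths depend on the local models folder):

```python
# Illustrative behaviour of the resolver above:
get_custom_vae_or_lora_weights("None", "lora")
# -> ""                                        nothing selected
get_custom_vae_or_lora_weights("my_lora.safetensors", "lora")
# -> "<models>/lora/my_lora.safetensors"       if that file exists locally
get_custom_vae_or_lora_weights("sayakpaul/sd-model-finetuned-lora-t4", "lora")
# -> "sayakpaul/sd-model-finetuned-lora-t4"    otherwise passed through, e.g. as an HF id
```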
def hsl_color(alpha: float, start, end):
b = (end - start) * (alpha if alpha > 0 else 0)
result = b + start
# Return a CSS HSL string
return f"hsl({math.floor(result)}, 80%, 35%)"
def get_lora_metadata(lora_filename):
# get the metadata from the file
filename = get_custom_model_pathfile(lora_filename, "lora")
with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
metadata = f.metadata()
# guard clause for if there isn't any metadata
if not metadata:
return None
# metadata is a dictionary of strings, the values of the keys we're
# interested in are actually json, and need to be loaded as such
tag_frequencies = json.loads(metadata.get("ss_tag_frequency", str("{}")))
dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", str("{}")))
tag_dirs = [dir for dir in tag_frequencies.keys()]
# gather the tag frequency information for all the datasets trained
all_frequencies = {}
for dataset in tag_dirs:
frequencies = sorted(
[entry for entry in tag_frequencies[dataset].items()],
reverse=True,
key=lambda x: x[1],
)
# get a figure for the total number of images processed for this dataset:
# either the img_count listed in its dataset_dirs entry, or the highest
# tag frequency if that entry doesn't exist
img_count = dataset_dirs.get(dataset, {}).get(
"img_count", frequencies[0][1]
)
# add the dataset frequencies to the overall frequencies replacing the
# frequency counts on the tags with a percentage/ratio
all_frequencies.update(
[(entry[0], entry[1] / img_count) for entry in frequencies]
)
trained_model_id = " ".join(
[
metadata.get("ss_sd_model_hash", ""),
metadata.get("ss_sd_model_name", ""),
metadata.get("ss_base_model_version", ""),
]
).strip()
# return all the tag frequencies across every dataset, sorted most frequent first
return {
"model": trained_model_id,
"frequencies": sorted(
all_frequencies.items(), reverse=True, key=lambda x: x[1]
),
}
def cancel_sd():
# Wrap in try/except, as gc can delete global_obj.sd_obj while switching models
try:
@@ -169,6 +265,116 @@ def cancel_sd():
pass
def set_model_default_configs(model_ckpt_or_id, jsonconfig=None):
import gradio as gr
config_modelname = default_config_exists(model_ckpt_or_id)
if jsonconfig:
return get_config_from_json(jsonconfig)
elif config_modelname:
return default_configs[config_modelname]
# TODO: Use HF metadata to setup pipeline if available
# elif is_valid_hf_id(model_ckpt_or_id):
# return get_HF_default_configs(model_ckpt_or_id)
else:
# We don't have default metadata to set up a good config. Do not change configs.
return [
gr.Textbox(label="Prompt", interactive=True, visible=True),
gr.Textbox(label="Negative Prompt", interactive=True),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.Checkbox(
label="Auto-Generate",
visible=False,
interactive=False,
value=False,
),
]
def get_config_from_json(model_ckpt_or_id, jsonconfig):
# TODO: make this work properly. It is currently not user-exposed.
cfgdata = json.load(jsonconfig)
return [
cfgdata["prompt_box_behavior"],
cfgdata["neg_prompt_box_behavior"],
cfgdata["steps"],
cfgdata["scheduler"],
cfgdata["guidance_scale"],
cfgdata["width"],
cfgdata["height"],
cfgdata["custom_vae"],
]
def default_config_exists(model_ckpt_or_id):
if model_ckpt_or_id in default_configs.keys():
return model_ckpt_or_id
elif "turbo" in model_ckpt_or_id.lower():
return "stabilityai/sdxl-turbo"
else:
return None
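A few illustrative lookups against the `default_configs` dictionary defined further down in this file:

```python
print(default_config_exists("stabilityai/stable-diffusion-xl-base-1.0"))
# -> "stabilityai/stable-diffusion-xl-base-1.0"   exact key match
print(default_config_exists("my-sdxl-TURBO-merge.safetensors"))
# -> "stabilityai/sdxl-turbo"                     "turbo" substring, case-insensitive
print(default_config_exists("stabilityai/stable-diffusion-2-1-base"))
# -> None                                         no default config; UI settings are left as-is
```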
def mask_editor_value_for_image_file(filepath):
image = Image.open(filepath)
mask = Image.new(mode="RGBA", size=image.size, color=(0, 0, 0, 0))
return {"background": image, "layers": [mask], "composite": image}
def mask_editor_value_for_gallery_data(gallery_data):
filepath = (
gallery_data.root[0].image.path
if len(gallery_data.root) != 0
else None
)
if filepath and os.path.isfile(filepath):
return mask_editor_value_for_image_file(filepath)
return EditorValue()
default_configs = {
"stabilityai/sdxl-turbo": [
gr.Textbox(label="", interactive=False, value=None, visible=False),
gr.Textbox(
label="Prompt",
value="masterpiece, a graceful shark leaping out of the water to catch a fish, eclipsing the sunset, epic, rays of light, silhouette",
),
gr.Slider(0, 10, value=2),
"EulerAncestralDiscrete",
gr.Slider(0, value=0),
512,
512,
"madebyollin/sdxl-vae-fp16-fix",
gr.Checkbox(
label="Auto-Generate", visible=False, interactive=True, value=False
),
],
"stabilityai/stable-diffusion-xl-base-1.0": [
gr.Textbox(label="Prompt", interactive=True, visible=True),
gr.Textbox(label="Negative Prompt", interactive=True),
40,
"EulerDiscrete",
7.5,
gr.Slider(value=768, interactive=True),
gr.Slider(value=768, interactive=True),
"madebyollin/sdxl-vae-fp16-fix",
gr.Checkbox(
label="Auto-Generate",
visible=False,
interactive=False,
value=False,
),
],
}
nodlogo_loc = resource_path("logos/nod-logo.png")
nodicon_loc = resource_path("logos/nod-icon.png")
available_devices = get_available_devices()

View File

@@ -122,20 +122,26 @@ def find_vae_from_png_metadata(
def find_lora_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
lora_hf_id = ""
) -> tuple[str, float]:
lora_custom = ""
lora_strength = 1.0
if key in metadata:
lora_file = metadata[key]
split_metadata = metadata[key].split(":")
lora_file = split_metadata[0]
if len(split_metadata) == 2:
try:
lora_strength = float(split_metadata[1])
except ValueError:
pass
lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
# If nothing had matched, check vendor/hf_model_id
if not lora_custom and lora_file.count("/"):
lora_hf_id = lora_file
lora_custom = lora_file
# LoRA input is optional, should not print or throw an error if missing
return lora_custom, lora_hf_id
return lora_custom, lora_strength
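The reworked parser now peels an optional `:strength` suffix off the `LoRA` entry in the PNG metadata instead of returning a separate HuggingFace id. A couple of illustrative inputs (the metadata values and file names here are hypothetical):

```python
find_lora_from_png_metadata("LoRA", {"LoRA": "my_style_lora.safetensors:0.65"})
# -> ("my_style_lora.safetensors", 0.65)   assuming the file exists under the lora models path
find_lora_from_png_metadata("LoRA", {"LoRA": "sayakpaul/sd-model-finetuned-lora-t4"})
# -> ("sayakpaul/sd-model-finetuned-lora-t4", 1.0)   vendor/hf_model_id fallback, default strength
find_lora_from_png_metadata("LoRA", {})
# -> ("", 1.0)   a missing LoRA entry is not an error
```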
def import_png_metadata(
@@ -150,7 +156,6 @@ def import_png_metadata(
height,
custom_model,
custom_lora,
hf_lora_id,
custom_vae,
):
try:
@@ -160,9 +165,10 @@ def import_png_metadata(
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
"Model", metadata
)
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
"LoRA", metadata
)
(
custom_lora,
custom_lora_strength,
) = find_lora_from_png_metadata("LoRA", metadata)
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
negative_prompt = metadata["Negative prompt"]
@@ -177,12 +183,8 @@ def import_png_metadata(
elif "Model" in metadata and png_hf_model_id:
custom_model = png_hf_model_id
if "LoRA" in metadata and lora_custom_model:
custom_lora = lora_custom_model
hf_lora_id = ""
if "LoRA" in metadata and lora_hf_model_id:
if "LoRA" in metadata and not custom_lora:
custom_lora = "None"
hf_lora_id = lora_hf_model_id
if "VAE" in metadata and vae_custom_model:
custom_vae = vae_custom_model
@@ -215,6 +217,6 @@ def import_png_metadata(
height,
custom_model,
custom_lora,
hf_lora_id,
custom_lora_strength,
custom_vae,
)

View File

@@ -5,11 +5,25 @@ from time import time
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
def config_gradio_tmp_imgs_folder():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
def clear_tmp_mlir():
cleanup_start = time()
print(
"Clearing .mlir temporary files from a prior run. This may take some time..."
)
mlir_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.endswith(".mlir")
]
for filename in mlir_files:
os.remove(shark_tmp + filename)
print(
f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds."
)
def clear_tmp_imgs():
# tell gradio to use a directory under shark_tmp for its temporary
# image files unless somewhere else has been set
if "GRADIO_TEMP_DIR" not in os.environ:
@@ -52,3 +66,12 @@ def config_gradio_tmp_imgs_folder():
)
else:
print("No temporary images files to clear.")
def config_tmp():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
clear_tmp_mlir()
clear_tmp_imgs()

View File

@@ -78,7 +78,10 @@ def test_loop(
os.mkdir("./test_images/golden")
get_inpaint_inputs()
hf_model_names = model_config_dicts[0].values()
tuned_options = ["--no-use_tuned", "--use_tuned"]
tuned_options = [
"--no-use_tuned",
"--use_tuned",
]
import_options = ["--import_mlir", "--no-import_mlir"]
prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
@@ -112,6 +115,8 @@ def test_loop(
and use_tune == tuned_options[1]
):
continue
elif use_tune == tuned_options[1]:
continue
command = (
[
executable, # executable is the python from the venv used to run this

View File

@@ -23,6 +23,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
value=nod_logo,
show_label=False,
interactive=False,
show_download_button=False,
elem_id="top_logo",
width=150,
height=100,

View File

@@ -22,33 +22,33 @@ This does mean however, that on a brand new fresh install of SHARK that has not
* Make sure you have suitable drivers for your graphics card installed. See the prerequisites section of the [README](https://github.com/nod-ai/SHARK#readme).
* Download the latest SHARK studio .exe from [here](https://github.com/nod-ai/SHARK/releases) or follow the instructions in the [README](https://github.com/nod-ai/SHARK#readme) for an advanced, Linux or Mac install.
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK's default of `8080`, also include both the `--api_cors_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See below if you want to run SHARK on a different port.)
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK's default of `8080`, also include both the `--api_accept_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See below if you want to run SHARK on a different port.)
```powershell
## Run the .exe in API mode, with CORS support, on the A1111 endpoint port:
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_cors_origin="*" --server_port=7860
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_accept_origin="*" --server_port=7860
## Run from the base directory of a source clone of SHARK on Windows:
.\setup_venv.ps1
python .\apps\stable_diffusion\web\index.py --api --api_cors_origin="*" --server_port=7860
python .\apps\stable_diffusion\web\index.py --api --api_accept_origin="*" --server_port=7860
## Run from the base directory of a source clone of SHARK on Linux:
./setup_venv.sh
source shark.venv/bin/activate
python ./apps/stable_diffusion/web/index.py --api --api_cors_origin="*" --server_port=7860
python ./apps/stable_diffusion/web/index.py --api --api_accept_origin="*" --server_port=7860
## An example giving improved performance on AMD cards using vulkan, that runs on the same port as A1111
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860
## The API respects most applicable SHARK command line arguments for options not specified
## or not yet implemented by the API, so there may be others you want to set; see `--help`
.\node_ai_shark_studio_20320901_2525.exe --help
## For instance, the example above, but with a custom VAE specified
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
## An example with multiple specific CORS origins
python apps/stable_diffusion/web/index.py --api --api_cors_origin="koboldcpp.example.com:7001" --api_cors_origin="koboldcpp.example.com:7002" --server_port=7860
python apps/stable_diffusion/web/index.py --api --api_accept_origin="koboldcpp.example.com:7001" --api_accept_origin="koboldcpp.example.com:7002" --server_port=7860
```
SHARK should start in server mode, and you should see something like this:

View File

@@ -6,8 +6,8 @@ requires = [
"numpy>=1.22.4",
"torch-mlir>=20230620.875",
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
"iree-compiler==20231212.*",
"iree-runtime==20231212.*",
]
build-backend = "setuptools.build_meta"

View File

@@ -26,7 +26,7 @@ sacremoses
sentencepiece
# web dependencies.
gradio
gradio==3.44.3
altair
scipy

View File

@@ -26,7 +26,7 @@ diffusers
accelerate
scipy
ftfy
gradio==3.44.3
gradio==4.12.0
altair
omegaconf
# 0.3.2 doesn't have binaries for arm64
@@ -50,4 +50,8 @@ pefile
pyinstaller
# vicuna quantization
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
# For quantized GPTQ models
optimum
auto_gptq

View File

@@ -11,8 +11,8 @@ PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.5"
backend_deps = []
if "NO_BACKEND" in os.environ.keys():
backend_deps = [
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
"iree-compiler==20231212.*",
"iree-runtime==20231212.*",
]
setup(

View File

@@ -7,13 +7,13 @@
It checks the Python version installed and installs any required build
dependencies into a Python virtual environment.
If that environment does not exist, it creates it.
.PARAMETER update-src
git pulls latest version
.PARAMETER force
removes and recreates venv to force update of all dependencies
.EXAMPLE
.\setup_venv.ps1 --force
@@ -39,11 +39,11 @@ if ($arguments -eq "--force"){
Write-Host "deactivating..."
Deactivate
}
if (Test-Path .\shark.venv\) {
if (Test-Path .\shark1.venv\) {
Write-Host "removing and recreating venv..."
Remove-Item .\shark.venv -Force -Recurse
if (Test-Path .\shark.venv\) {
Remove-Item .\shark1.venv -Force -Recurse
if (Test-Path .\shark1.venv\) {
Write-Host 'could not remove .\shark1.venv - please try running ".\setup_venv.ps1 --force" again!'
exit 1
}
@@ -83,15 +83,15 @@ if (!($PyVer -like "*3.11*") -and !($p -like "*3.11*")) # if 3.11 is not in any
Write-Host "Installing Build Dependencies"
# make sure we really use 3.11 from list, even if it's not the default.
if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark.venv\}
else {python -m venv .\shark.venv\}
.\shark.venv\Scripts\activate
if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark1.venv\}
else {python -m venv .\shark1.venv\}
.\shark1.venv\Scripts\activate
python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
pip install --pre torch-mlir torchvision torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler==20231212.* iree-runtime==20231212.*
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
Write-Host "Build and installation completed successfully"
Write-Host "Source your venv with ./shark.venv/Scripts/activate"
Write-Host "Source your venv with ./shark1.venv/Scripts/activate"

View File

@@ -5,7 +5,7 @@
# Environment variables used by the script.
# PYTHON=$PYTHON3.10 ./setup_venv.sh #pass a version of $PYTHON to use
# VENV_DIR=myshark.venv #create a venv called myshark.venv
# SKIP_VENV=1 #Don't create and activate a Python venv. Use the current environment.
# SKIP_VENV=1 #Don't create and activate a Python venv. Use the current environment.
# USE_IREE=1 #use stock IREE instead of Nod.ai's SHARK build
# IMPORTER=1 #Install importer deps
# BENCHMARK=1 #Install benchmark deps
@@ -35,7 +35,7 @@ fi
if [[ "$SKIP_VENV" != "1" ]]; then
if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
VENV_DIR=${VENV_DIR:-shark.venv}
VENV_DIR=${VENV_DIR:-shark1.venv}
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
@@ -111,7 +111,7 @@ else
fi
if [[ -z "${NO_BACKEND}" ]]; then
echo "Installing ${RUNTIME}..."
$PYTHON -m pip install --pre --upgrade --find-links ${RUNTIME} iree-compiler iree-runtime
$PYTHON -m pip install --pre --upgrade --no-index --find-links ${RUNTIME} iree-compiler==20231212.* iree-runtime==20231212.*
else
echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
fi

View File

@@ -31,60 +31,63 @@ from .benchmark_utils import *
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
print("Configuring for device:" + device)
device_uri = device.split("://")
if len(device_uri) > 1:
if device_uri[0] not in ["vulkan", "rocm"]:
print(
f"Specific device selection only supported for vulkan and rocm."
f"Proceeding with {device} as device."
)
# device_uri can be device_num or device_path.
# assuming number of devices for a single driver will be not be >99
if len(device_uri[1]) <= 2:
# expected to be device index in range 0 - 99
device_num = int(device_uri[1])
else:
# expected to be device path
device_num = device_uri[1]
else:
device_num = 0
device, device_num = clean_device_info(device)
if "cpu" in device:
from shark.iree_utils.cpu_utils import get_iree_cpu_args
data_tiling_flag = ["--iree-opt-data-tiling"]
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
u_kernel_flag = ["--iree-llvmcpu-enable-ukernels"]
stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]
return (
get_iree_cpu_args()
+ data_tiling_flag
+ u_kernel_flag
+ stack_size_flag
+ ["--iree-global-opt-enable-quantized-matmul-reassociation"]
)
if device_uri[0] == "cuda":
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
return get_iree_gpu_args()
if device_uri[0] == "vulkan":
if device == "vulkan":
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "metal":
if device == "metal":
from shark.iree_utils.metal_utils import get_iree_metal_args
return get_iree_metal_args(extra_args=extra_args)
if device_uri[0] == "rocm":
if device == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
return []
def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by Studio pipeline
# Multiple devices only supported for vulkan and rocm (as of now).
# default device must be selected for all others
device_id = None
device = (
raw_device
if "=>" not in raw_device
else raw_device.split("=>")[1].strip()
)
if "://" in device:
device, device_id = device.split("://")
if len(device_id) <= 2:
device_id = int(device_id)
if device not in ["rocm", "vulkan"]:
device_id = None
if device in ["rocm", "vulkan"] and device_id == None:
device_id = 0
return device, device_id
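A few examples of how device strings coming from the UI dropdown map through `clean_device_info`; the human-readable names are illustrative:

```python
print(clean_device_info("vulkan"))                              # ("vulkan", 0)
print(clean_device_info("vulkan://1"))                          # ("vulkan", 1)
print(clean_device_info("AMD Radeon RX 7900 XTX => rocm://0"))  # ("rocm", 0)
print(clean_device_info("cpu"))                                 # ("cpu", None)
print(clean_device_info("cuda://0"))                            # ("cuda", None)  id only kept for vulkan/rocm
```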
# Get the iree-compiler arguments given frontend.
def get_iree_frontend_args(frontend):
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
@@ -351,11 +354,15 @@ def get_iree_module(
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)
hal_device_id = haldriver.query_available_devices()[device_idx][
"device_id"
]
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
hal_device_id,
allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
config.id = hal_device_id
else:
config = get_iree_runtime_config(device)
vm_module = ireert.VmModule.from_buffer(
@@ -394,15 +401,16 @@ def load_vmfb_using_mmap(
haldriver = ireert.get_driver(device)
dl.log(f"ireert.get_driver()")
hal_device_id = haldriver.query_available_devices()[device_idx][
"device_id"
]
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
hal_device_id,
allocators=shark_args.device_allocator,
)
dl.log(f"ireert.create_device()")
config = ireert.Config(device=haldevice)
config.id = haldriver.query_available_devices()[device_idx][
"device_id"
]
config.id = hal_device_id
dl.log(f"ireert.Config()")
else:
config = get_iree_runtime_config(device)

View File

@@ -95,6 +95,7 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
print("could not execute `iree-run-module --dump_devices=rocm`")
if dump_device_info is not None:
device_num = 0 if device_num is None else device_num
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0])
if len(device_arch_pairs) > device_num: # can find arch in the list
arch_in_device_dump = device_arch_pairs[device_num][1]
@@ -103,7 +104,7 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
print(f"Found ROCm device arch : {arch_in_device_dump}")
return arch_in_device_dump
default_rocm_arch = "gfx_1100"
default_rocm_arch = "gfx1100"
print(
"Did not find ROCm architecture from `--iree-rocm-target-chip` flag"
"\n or from `iree-run-module --dump_devices=rocm` command."

View File

@@ -38,15 +38,24 @@ def get_all_vulkan_devices():
@functools.cache
def get_vulkan_device_name(device_num=0):
vulkaninfo_list = get_all_vulkan_devices()
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
if len(vulkaninfo_list) > 1:
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing device: {vulkaninfo_list[device_num]}")
return vulkaninfo_list[device_num]
if isinstance(device_num, int):
vulkaninfo_list = get_all_vulkan_devices()
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
if len(vulkaninfo_list) > 1:
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing device: vulkan://{device_num}")
vulkan_device_name = vulkaninfo_list[device_num]
else:
from iree.runtime import get_driver
vulkan_device_driver = get_driver(device_num)
vulkan_device_name = vulkan_device_driver.query_available_devices()[0]
print(vulkan_device_name)
return vulkan_device_name
def get_os_name():
@@ -130,8 +139,16 @@ def get_vulkan_target_triple(device_name):
triple = f"rdna3-780m-{system_os}"
elif all(x in device_name for x in ("AMD", "PRO", "W7900")):
triple = f"rdna3-w7900-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
elif "7600" in device_name:
triple = f"rdna3-7600-{system_os}"
elif "7700" in device_name:
triple = f"rdna3-7700-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")) and any(x in device_name for x in ("6700", "6750")):
triple = f"rdna2-unknown-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")) and "7" in device_name:
triple = f"rdna3-unknown-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
triple = f"rdna3-unknown-{system_os}"
# Intel Targets
elif any(x in device_name for x in ("A770", "A750")):
triple = f"arc-770-{system_os}"
@@ -174,6 +191,9 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
res_vulkan_flag = []
res_vulkan_flag += [
"--iree-stream-resource-max-allocation-size=3221225472"
]
vulkan_triple_flag = None
for arg in extra_args:
if "-iree-vulkan-target-triple=" in arg:
@@ -195,7 +215,9 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
@functools.cache
def get_iree_vulkan_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_validation_layers={'true' if shark_args.vulkan_validation_layers else 'false'}",
f"--vulkan_validation_layers={'true' if shark_args.vulkan_debug_utils else 'false'}",
f"--vulkan_debug_verbosity={'4' if shark_args.vulkan_debug_utils else '0'}"
f"--vulkan-robust-buffer-access={'true' if shark_args.vulkan_debug_utils else 'false'}",
]
return vulkan_runtime_flags

View File

@@ -7,7 +7,7 @@ import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from typing import List, Tuple
from io import BytesIO
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
@@ -84,7 +84,7 @@ def compile_int_precision(
weight_quant_type="asym",
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,

View File

@@ -800,15 +800,17 @@ def save_mlir(
model_name,
mlir_dialect="linalg",
frontend="torch",
dir=tempfile.gettempdir(),
dir="",
):
model_name_mlir = (
model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
)
if dir == "":
dir = tempfile.gettempdir()
dir = os.path.join(".", "shark_tmp")
mlir_path = os.path.join(dir, model_name_mlir)
print(f"saving {model_name_mlir} to {dir}")
if not os.path.exists(dir):
os.makedirs(dir)
if frontend == "torch":
with open(mlir_path, "wb") as mlir_file:
mlir_file.write(mlir_module)

View File

@@ -0,0 +1,62 @@
import unittest
from unittest.mock import mock_open, patch
from apps.stable_diffusion.web.ui.txt2img_ui import (
export_settings,
load_settings,
all_gradio_labels,
)
class TestExportSettings(unittest.TestCase):
@patch("builtins.open", new_callable=mock_open)
@patch("json.dump")
def test_export_settings(self, mock_json_dump, mock_file):
test_values = ["value1", "value2", "value3"]
expected_output = {
"txt2img": {
label: value
for label, value in zip(all_gradio_labels, test_values)
}
}
export_settings(*test_values)
mock_file.assert_called_once_with("./ui/settings.json", "w")
mock_json_dump.assert_called_once_with(
expected_output, mock_file(), indent=4
)
@patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
@patch(
"builtins.open",
new_callable=mock_open,
read_data='{"txt2img": {"some_setting": "some_value"}}',
)
def test_load_settings_file_exists(self, mock_file, mock_json_load):
mock_json_load.return_value = {
"txt2img": {
"txt2img_custom_model": "custom_model_value",
"custom_vae": "custom_vae_value",
}
}
settings = load_settings()
self.assertEqual(settings[0], "custom_model_value")
self.assertEqual(settings[1], "custom_vae_value")
@patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
@patch("builtins.open", side_effect=FileNotFoundError)
def test_load_settings_file_not_found(self, mock_file, mock_json_load):
settings = load_settings()
default_lora_weights = "None"
self.assertEqual(settings[4], default_lora_weights)
@patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
@patch("builtins.open", new_callable=mock_open, read_data="{}")
def test_load_settings_key_error(self, mock_file, mock_json_load):
mock_json_load.return_value = {}
settings = load_settings()
default_lora_weights = "None"
self.assertEqual(settings[4], default_lora_weights)

View File

@@ -1,21 +1,19 @@
-bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
-bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
-bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
-facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
-google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
-microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
-microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
-google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
-mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
+bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
+bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
+facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails during iree-compile.",""
+google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/311",""
+microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
+microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
+google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
+mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,False,False,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
 nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github.com/nod-ai/SHARK/issues/343,https://github.com/nod-ai/SHARK/issues/1487","macos"
-resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
-resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
+resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,True,True,"","macos"
+resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"
 resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
-resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,True,True,"Numerics issues, awaiting cuda-independent fp16 integration",""
 squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
-wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
-mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
+wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,True,True,"","macos"
+mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
 efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
 efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
-t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.","macos"
-t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
+t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"
+t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"
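
Each record above follows the same comma-separated layout: model name, MLIR dialect, frontend, two numeric tolerances, a config column, an optional layout transform, three flags stored as the strings "True"/"False", a free-text note (usually an issue link), and an OS on which the model is skipped. The flag semantics are interpreted by save_torch_model, diffed further below; the sketch here only illustrates the row format itself, using one row copied from the table, and the variable names are illustrative rather than taken from the repo.

import csv
import io

# One row copied from the table above; csv handles the quoted note/OS fields.
sample = 'mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,False,False,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"'

row = next(csv.reader(io.StringIO(sample)))
model_name = row[0]                                # "mobilenet_v3_small"
dialect, frontend = row[1], row[2]                 # "linalg", "torch"
flags = [field == "True" for field in row[7:10]]   # booleans arrive as "True"/"False" strings
note, skip_os = row[10], row[11]                   # issue link(s) and an OS to skip

print(model_name, flags, skip_os)                  # mobilenet_v3_small [False, False, False] macos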

View File

@@ -50,7 +50,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
         is_decompose = row[5]
         tracing_required = False if tracing_required == "False" else True
-        is_dynamic = False if is_dynamic == "False" else True
+        is_dynamic = False
         print("generating artifacts for: " + torch_model_name)
         model = None
         input = None
@@ -104,7 +104,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
                 model_name=torch_model_name,
                 mlir_type=mlir_type,
                 is_dynamic=False,
-                tracing_required=tracing_required,
+                tracing_required=True,
             )
         else:
             mlir_importer = SharkImporter(
@@ -114,7 +114,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
             )
             mlir_importer.import_debug(
                 is_dynamic=False,
-                tracing_required=tracing_required,
+                tracing_required=True,
                 dir=torch_model_dir,
                 model_name=torch_model_name,
                 mlir_type=mlir_type,
@@ -123,7 +123,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
             if is_dynamic:
                 mlir_importer.import_debug(
                     is_dynamic=True,
-                    tracing_required=tracing_required,
+                    tracing_required=True,
                     dir=torch_model_dir,
                     model_name=torch_model_name + "_dynamic",
                     mlir_type=mlir_type,
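
Taken together, the hunks above amount to one behavioural change in save_torch_model: the dynamic-shape column from the CSV is now ignored, and every import_debug call shown passes tracing_required=True, so only static-shape artifacts are generated. A self-contained summary of that flag handling follows; the function name and returned dict are illustrative, not repository code.

def import_flags_after_change(csv_is_dynamic: str, csv_tracing_required: str) -> dict:
    # Both columns still arrive as the strings "True"/"False", but neither is honoured
    # anymore: dynamic-shape imports are disabled and tracing is always requested.
    return {
        "is_dynamic": False,
        "tracing_required": True,
    }


# Even a row that asked for dynamic shapes and no tracing now imports statically with tracing.
assert import_flags_after_change("True", "False") == {
    "is_dynamic": False,
    "tracing_required": True,
}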