Compare commits

...

79 Commits

Author SHA1 Message Date
jinchen62
47ec7275e6 Fix brevitas quantize argument (#1633) 2023-07-07 11:30:31 -07:00
powderluv
3a24cff901 change binary names 2023-07-06 23:59:14 -07:00
powderluv
1f72907886 Fix the pyinstaller for chatbots (#1631) 2023-07-06 23:30:01 -07:00
Daniel Garvey
06c8aabd01 remove local-sync from webui (#1629) 2023-07-06 13:58:59 -07:00
Phaneesh Barwaria
55a12cc0c4 cpu name in device (#1628)
* show cpu name in devices

* change device order for chatbot
2023-07-06 12:00:09 -07:00
Ean Garvey
7dcbbde523 Xfail models for data tiling flag changes (#1624) 2023-07-06 06:57:17 -07:00
Abhishek Varma
1b62dc4529 [Vicuna] Revert the formatting for Brevitas op (#1626)
-- This commit reverts the formatting for Brevitas op.
-- It also excludes vicuna.py script from `black` formatter.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-07-06 06:56:17 -07:00
Daniel Garvey
c5a47887f4 Revert revert negative prompt change (#1625)
* revert default flag changes

* revert revert negative prompt change

* revert revert negative prompt change
2023-07-05 22:09:06 -07:00
Daniel Garvey
c72d0eaf87 revert default flag changes (#1622) 2023-07-05 15:43:26 -05:00
powderluv
c41f58042a Update compile_utils.py (#1617)
* Update compile_utils.py

* Update compile_utils.py

* Update compile_utils.py
2023-07-05 10:06:48 -07:00
xzuyn
043e5a5c7a fix a mistake I made, and more formatting changes, and add ++/Karras (#1619)
* fixed missing line break in `stablelm_ui.py` `start_message`
- also more formatting changes

* fix variable spelling mistake

* revert some formatting cause black wants it different

* one less line, still less than 79

* add ++, karras, and karras++ types of dpmsolver.

* black line length 79

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-05 09:00:16 -07:00
Abhishek Varma
a1b1ce935c int8 e2e for WebUI (#1620) 2023-07-05 07:08:36 -07:00
jinchen62
bc6fee1a0c Add int4/int8 vicuna (#1598) 2023-07-05 07:01:51 -07:00
xzuyn
91ab594744 minor fix, some changes, some additions, and cleaning up (#1618)
* - fix overflowing text (a janky fix)
- add DEISMultistep scheduler as an option
- set default scheduler to DEISMultistep
- set default CFG to 3.5
- set default steps to 16
- add `xzuyn/PhotoMerge` as a model option
- add 3 new example prompts (which work nicely with PhotoMerge)
- formatting

* Set DEISMultistep in the cpu_only list instead

* formatting

* formatting

* modify prompts

* resize window to 81% & 85% monitor resolution instead of (WxH / 1.0625).

* increase steps to 32 after some testing. somewhere in between 16 and 32 is best compromise on speed/quality for DEIS, so 32 steps to play it safe.

* black line length 79

* revert settings DEIS as default scheduler.

* add more schedulers & revert accidental DDIM change
- add DPMSolverSingleStep, KDPM2AncestralDiscrete, & HeunDiscrete.
- did not add `DPMSolverMultistepInverse` or `DDIMInverse` as they only output latent noise, there are a few I did not try adding yet.
- accidentally set `upscaler_ui.py` to EulerDiscrete by default last commit while reverting DEIS changes.
- add `xzuyn/PhotoMerge-inpainting` as an in or out painting model.

* black line length 79

* add help section stuff and some other changes
- list the rest of the schedulers in argument help section.
- replace mutable default arguments.
- increased default window height to 91% to remove any scrolling for the main txt2img page (tested on a 1920x1080 monitor). width is the same as its just enough to have the image output on the side instead of the bottom.
- cleanup
2023-07-04 18:51:23 -07:00
Eliasj42
4015793f84 changed method of compiling vicuna to remove first and second vicuna (#1611)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-03 12:12:43 -07:00
Ean Garvey
d63ce76dd8 Use sortable image filenames for SD outputs. (#1528) 2023-07-03 10:30:47 -07:00
Prashant Kumar
1c32915570 Add the shark compile downstream due to https://github.com/pytorch/pytorch/pull/104185#issuecomment-1615110613 (#1615) 2023-07-01 08:30:58 -07:00
Ean Garvey
6d286c0609 Enable tuning for rectangle sizes on rdna2. (#1608) 2023-06-30 22:28:24 -07:00
Stefan Kapusniak
7392b22731 UI/Web Reduce animation of default --progress_bars (#1613) 2023-06-30 21:12:10 -07:00
jinchen62
534de05791 Update precision check for vicuna (#1610) 2023-06-29 16:16:33 -05:00
Daniel Garvey
5779e8c039 int4/int8 vicuna download support (#1609)
* set task_topology_max_group to cpu_count

by default. Can be overriden with a flag of the same str

* add download for int4/int8 mlir
2023-06-29 13:35:51 -07:00
Abhishek Varma
d496053590 [SHARK] Add a compile API to use for quick testing of inference (#1606) 2023-06-28 08:40:28 -07:00
gpetters94
6274a813c9 Add unet512 support for the other StableDiffusion pipelines (#1602) 2023-06-27 12:28:57 -07:00
Gaurav Shukla
1d6a1f9f8a [vicuna] Add tokens streaming(step=3) (#1600)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-27 08:59:27 -07:00
Daniel Garvey
75672c0e28 set task_topology_max_group to cpu_count (#1594)
by default. Can be overriden with a flag of the same str
2023-06-26 14:54:06 -07:00
Prashant Kumar
74a7202173 Make the tensors contiguous. 2023-06-26 17:29:54 +05:30
Prashant Kumar
27a08735db Add the shark backend for torch.compile API. (#1596) 2023-06-26 03:53:32 -07:00
Stefan Kapusniak
eaa49cce17 UI/App - Allow text selection (#1593)
* When run in app mode on windows, allows selection of text from
non-input controls, which is the same behaviour as web mode.
2023-06-26 02:16:53 -07:00
powderluv
10657d6fb1 Disable upx 2023-06-25 07:28:52 -07:00
Stefan Kapusniak
e3ab844cd1 Fix output gallery for csv format inc. VAE & LoRA (#1591) 2023-06-24 06:20:53 -07:00
powderluv
5ce6001b41 Update stablelm_ui.py to default to fp16 2023-06-23 22:55:47 -07:00
powderluv
501d0ca52e Add sentencepiece to webui for pyinstaller 2023-06-23 22:52:06 -07:00
powderluv
b444528715 Pin torch-mlir for windows too 2023-06-23 19:19:28 -07:00
Ean Garvey
6e6c90f62b Pin torch-mlir and use local-task in OPT. (#1592) 2023-06-23 19:17:05 -07:00
AyaanShah2204
8cdb38496e Final REST API Fixes (#1590)
* fixed outpaint api and added tests

* fixed text2img api

* more elegant generator to subscriptable conversion

* final fixes
2023-06-23 16:46:47 -07:00
powderluv
726d73d6ba Revert "[vicuna] Add streaming of tokens (#1587)" (#1588)
This reverts commit 4d55e51d46.
2023-06-23 10:29:00 -07:00
Gaurav Shukla
4d55e51d46 [vicuna] Add streaming of tokens (#1587)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-23 08:20:46 -07:00
Prashant Kumar
6ef78ee7ba Add cpu compile time flags. (#1585) 2023-06-23 07:23:26 -07:00
jinchen62
4002da7161 Add int4/int8 options to chatbot webui (#1586) 2023-06-23 07:18:34 -07:00
powderluv
ecb5e8e5d8 Update txt2img_ui.py 2023-06-23 06:42:12 -07:00
PhaneeshB
28e0919321 Add AMD cpu device 2023-06-23 18:47:04 +05:30
Daniel Garvey
28f4d44a6b downloader was double downloading (#1580) 2023-06-22 18:30:27 -07:00
AyaanShah2204
97f7e79391 [Blender Integration] Fixed Inpainting REST API (#1577)
* fixed inpaint api

* added inpainting test

* fixed linter errors

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-22 16:08:26 -07:00
Nelson Sharpe
44a8f2f8db Include VAE & LoRA data into PNG metadata (#1573)
* include custom lora and vae data in png metadata

* include pycharm settings

* lint with black
2023-06-22 16:05:54 -07:00
Eliasj42
8822b9acd7 added ability to use config file to shard vicuna (#1565)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-06-22 17:40:35 -05:00
Daniel Garvey
0ca3b9fce3 fix some mmap and vicuna bugs (#1576) 2023-06-22 17:39:55 -05:00
Nithin Meganathan
045f2bb147 Add dispatch-level config file generator for manual annotation (#1566) 2023-06-22 15:11:41 -07:00
Prashant Kumar
a811b867b9 Add shark_eager mode.
-- Eager mode with step by step op compilation and execution.
2023-06-22 22:59:14 +05:30
Abhishek Varma
cdd505e2dd [SharkInference-SharkRuntime] Adds capability to mmap vmfbs
-- This commit is based on [VmModule.mmap() API](https://github.com/openxla/iree/pull/14124).
-- It thereby adds capability to mmap vmfbs in SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-22 20:43:40 +05:30
powderluv
1b0f39107c Move torch_mlir import to the top (#1574) 2023-06-21 22:31:35 -07:00
powderluv
b9b8955f74 exclude vulkan on macos 2023-06-21 22:22:27 -07:00
powderluv
6f7a85eee3 switch to metal backend for CI 2023-06-21 22:17:11 -07:00
Ranvir Singh Virk
18c8e9e51e Metal typo fix (#1572)
* fixing typos for metal changes

* black formating
2023-06-21 21:56:11 -07:00
Daniel Garvey
a202bb466a fp16 fixes for webui (#1571) 2023-06-21 20:24:02 -07:00
Ranvir Singh Virk
07c1e1d712 Adding metal_utils for iree_utils (#1561)
* Adding metal_utils for iree_utils

* Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559)

-- It also modifies the mega_test.py script

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [SD] Update unet in_channels API and add PIL metadata to spec. (#1560)

* Fix deprecation warning for unet config.

* Include PIL metadata instead of hidden imports in SD spec.

* Fixing iree-metal-target-platform

* adding metal to txt2img pipeline

* Fixing Copyright date

* removing debug prints

* black lint formating

* fixing device dump

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <avarma094@gmail.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-21 19:09:03 -07:00
Ranvir Singh Virk
18daec78c8 Added check for python version (#1570)
* Added check for python version

* Update for PYTHON_VERSION_X_Y
2023-06-21 18:56:47 -07:00
Ean Garvey
1a8e2024d6 Exclude non-square sizes from use_tuned on rdna2 (#1568) 2023-06-21 11:36:55 -05:00
AyaanShah2204
d61b6641fb Rest API: Resolved Generator Object not Subscripatable error (#1556) 2023-06-20 19:27:41 -07:00
Phaneesh Barwaria
88cc2423cc Enable Vicuna fp16 cpu (#1562)
* fix second vic mlir gen

* fp16 mlir/vmfb download from shark_tank
2023-06-20 13:43:21 -05:00
Ean Garvey
ccf944c1bd Enable tuner for upscaler unet. (#1563) 2023-06-20 13:40:13 -05:00
Ean Garvey
0def74f520 [SD] Update unet in_channels API and add PIL metadata to spec. (#1560)
* Fix deprecation warning for unet config.

* Include PIL metadata instead of hidden imports in SD spec.
2023-06-20 10:26:36 -07:00
Abhishek Varma
3fb72e192e Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559)
-- It also modifies the mega_test.py script

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-20 10:04:17 -07:00
Vivek Khandelwal
855435ee24 Fix for the user input for Falcon pipeline 2023-06-20 18:09:32 +05:30
Elias Joseph
6f9f868fc0 fixed a bug where designating device for vicuna didn't work 2023-06-20 17:09:32 +05:30
powderluv
fb865f1b99 Move to checkout@v3
This will break Windows again but we have to fix it up since the old node.js is now deprecated.
2023-06-19 18:44:36 -07:00
rprasad2
3e5c50f07b changes for tuning (#1542)
* Add tuning sizes for rdna3
2023-06-19 15:29:08 -05:00
powderluv
a544f30a8f Move mega to the shark examples (#1555) 2023-06-19 11:10:51 -07:00
Abhishek Varma
1fe56d460a [MEGABYTE] Add script to compile MEGABYTE through SHARK (#1553)
-- Usage: `python mega_test.py`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-19 11:00:35 -07:00
Vivek Khandelwal
fafd713141 Minor change to falcon pipeline 2023-06-19 22:36:32 +05:30
Vivek Khandelwal
015d0132c3 Modify falcon pipeline to add fp16 support (#1551) 2023-06-19 09:57:13 -07:00
powderluv
20ddd96ef7 unpin diffusers (#1550) 2023-06-18 13:45:55 -07:00
powderluv
ee33cfd2d1 Add PIL in main index.py (#1549)
* Add PIL in main index.py

This is to ensure pyinstaller picks it up

* Update index.py
2023-06-18 11:51:44 -07:00
Stefan Kapusniak
a3cba21d5b Fix load of unet512 vmfb fail on get of iree opts (#1546)
* Change retrieval of Iree options used when loading an existing
unet512 vmfb to look up the "unet" options rather than attempt to
find a non-existent set of options for "unet512"

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-18 06:42:20 -07:00
Stefan Kapusniak
a7b6ec4095 Fix unet512 always being used when --max_length=77 (#1547)
* Switches a few places in the SD pipeline where an assumption of
max_length=64 was being made, to using the actual max_length
as passed into the pipeline. This prevents unet512 always being
used and producing different images than previously when
--max_length=77
2023-06-18 06:41:25 -07:00
Ean Garvey
d80b087d95 Add PIL hidden imports to sd spec. (#1544)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-18 06:39:08 -07:00
Stefan Kapusniak
297a209608 Remove workarounds for gradio tempfile bugs (#1548) 2023-06-17 19:50:36 -07:00
gpetters94
b204113563 Add UNet512 (#1504)
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-06-17 03:46:25 -04:00
Chi_Liu
f60ab1f4fa Add Deberta to stablehlo in shark tank (#1545) 2023-06-16 13:24:44 -07:00
Surya Jasper
b203779462 Added Adreno target triples to vulkan_utils (#1543) 2023-06-15 16:42:59 -07:00
81 changed files with 4592 additions and 1966 deletions

View File

@@ -2,4 +2,4 @@
count = 1
show-source = 1
select = E9,F63,F7,F82
exclude = lit.cfg.py
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py

View File

@@ -54,8 +54,8 @@ jobs:
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/shark_sd.exe ./dist/nodai_shark_sd_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_sd_${{ env.package_version_ }}.exe
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
- name: Upload Release Assets
id: upload-release-assets

View File

@@ -35,6 +35,8 @@ jobs:
include:
- os: ubuntu-latest
suite: lint
- os: MacStudio
suite: metal
exclude:
- os: ubuntu-latest
suite: vulkan
@@ -46,6 +48,8 @@ jobs:
suite: cuda
- os: MacStudio
suite: cpu
- os: MacStudio
suite: vulkan
- os: icelake
suite: vulkan
- os: icelake
@@ -61,7 +65,6 @@ jobs:
steps:
- uses: actions/checkout@v3
if: matrix.os != '7950x'
- name: Set Environment Variables
if: matrix.os != '7950x'
@@ -84,9 +87,6 @@ jobs:
#cache-dependency-path: |
# **/requirements-importer.txt
# **/requirements.txt
- uses: actions/checkout@v2
if: matrix.os == '7950x'
- name: Install dependencies
if: matrix.suite == 'lint'
@@ -129,15 +129,14 @@ jobs:
# python build_tools/stable_diffusion_testing.py --device=cuda
- name: Validate Vulkan Models (MacOS)
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
if: matrix.suite == 'metal' && matrix.os == 'MacStudio'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'

4
.gitignore vendored
View File

@@ -2,6 +2,8 @@
__pycache__/
*.py[cod]
*$py.class
*.mlir
*.vmfb
# C extensions
*.so
@@ -157,7 +159,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
# vscode related
.vscode

File diff suppressed because it is too large Load Diff

View File

@@ -1,14 +1,39 @@
import torch
from transformers import AutoModelForCausalLM
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
class FirstVicuna(torch.nn.Module):
def __init__(self, model_path):
def __init__(self, model_path, precision="fp32", weight_group_size=128):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
if precision in ["int4", "int8"]:
print("First Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float32,
weight_quant_type="asym",
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
input_bit_width=None,
input_scale_type="float",
input_param_method="stats",
input_quant_type="asym",
input_quant_granularity="per_tensor",
quantize_input_zero_point=False,
seqlen=2048,
)
print("Weight quantization applied.")
def forward(self, input_ids):
op = self.model(input_ids=input_ids, use_cache=True)
@@ -22,12 +47,34 @@ class FirstVicuna(torch.nn.Module):
class SecondVicuna(torch.nn.Module):
def __init__(self, model_path):
def __init__(self, model_path, precision="fp32", weight_group_size=128):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
if precision in ["int4", "int8"]:
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float32,
weight_quant_type="asym",
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
input_bit_width=None,
input_scale_type="float",
input_param_method="stats",
input_quant_type="asym",
input_quant_granularity="per_tensor",
quantize_input_zero_point=False,
seqlen=2048,
)
print("Weight quantization applied.")
def forward(
self,

View File

@@ -62,7 +62,105 @@ class SecondVicunaLayer(torch.nn.Module):
)
class CompiledFirstVicunaLayer(torch.nn.Module):
class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers, lmhead, embedding, norm):
super().__init__()
self.model = model
assert len(layers) == len(model.model.layers)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers = layers
self.norm = norm
self.embedding = embedding
self.lmhead = lmhead
self.model.model.norm = self.norm
self.model.model.embed_tokens = self.embedding
self.model.lm_head = self.lmhead
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers
)
def forward(
self,
input_ids,
is_first=True,
past_key_values=None,
attention_mask=None,
):
return self.model.forward(
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
)
class LMHead(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states):
output = self.model(hidden_states)
return output
class LMHeadCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(self, hidden_states):
hidden_states = hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output
class VicunaNorm(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states):
output = self.model(hidden_states)
return output
class VicunaNormCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(self, hidden_states):
hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output
class VicunaEmbedding(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids):
output = self.model(input_ids)
return output
class VicunaEmbeddingCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(self, input_ids):
input_ids.detach()
output = self.model("forward", (input_ids,))
output = torch.tensor(output)
return output
class CompiledVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
@@ -76,103 +174,55 @@ class CompiledFirstVicunaLayer(torch.nn.Module):
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
class CompiledSecondVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv0 = past_key_value[0].detach()
pkv1 = past_key_value[1].detach()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
pkv0,
pkv1,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers0, layers1):
super().__init__()
self.model = model
assert len(layers0) == len(model.model.layers)
# self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers0 = layers0
self.layers1 = layers1
def forward(
self,
input_ids,
is_first=True,
past_key_values=None,
attention_mask=None,
):
if is_first:
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers0
if past_key_value is None:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
output = self.model(
"first_vicuna_forward",
(
hidden_states,
attention_mask,
position_ids,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
return self.model.forward(input_ids, attention_mask=attention_mask)
else:
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers1
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv0 = past_key_value[0].detach()
pkv1 = past_key_value[1].detach()
output = self.model(
"second_vicuna_forward",
(
hidden_states,
attention_mask,
position_ids,
pkv0,
pkv1,
),
)
return self.model.forward(
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)

View File

@@ -28,8 +28,9 @@ parser = argparse.ArgumentParser(
description="runs a falcon model",
)
parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
parser.add_argument(
"--precision", "-p", default="fp32", help="fp32, fp16, int8, int4"
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
@@ -40,7 +41,12 @@ parser.add_argument(
default=None,
help="path to falcon's mlir file",
)
parser.add_argument(
"--use_precompiled_model",
default=True,
action=argparse.BooleanOptionalAction,
help="use the precompiled vmfb",
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=False,
@@ -59,12 +65,12 @@ class Falcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="tiiuae/falcon-7b-instruct",
hf_model_path,
max_num_tokens=150,
device="cuda",
precision="fp32",
falcon_mlir_path=Path("falcon.mlir"),
falcon_vmfb_path=Path("falcon.vmfb"),
falcon_mlir_path=None,
falcon_vmfb_path=None,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_padding_length = 100
@@ -85,7 +91,7 @@ class Falcon(SharkLLMBase):
return tokenizer
def get_src_model(self):
print("Loading src model")
print("Loading src model: ", self.model_name)
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
@@ -93,9 +99,26 @@ class Falcon(SharkLLMBase):
return falcon_model
def compile_falcon(self):
vmfb = get_vmfb_from_path(self.falcon_vmfb_path, self.device, "linalg")
if vmfb is not None:
return vmfb
if args.use_precompiled_model:
if not self.falcon_vmfb_path.exists():
# Downloading VMFB from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ "_"
+ self.device
+ ".vmfb",
self.falcon_vmfb_path.absolute(),
single_file=True,
)
vmfb = get_vmfb_from_path(
self.falcon_vmfb_path, self.device, "linalg"
)
if vmfb is not None:
return vmfb
print(
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
@@ -106,27 +129,26 @@ class Falcon(SharkLLMBase):
bytecode = f.read()
else:
mlir_generated = False
if args.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
download_public_file(
"gs://shark_tank/falcon/7b/cuda/falcon.mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
)
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
if not mlir_generated:
compilation_input_ids = torch.randint(
@@ -184,6 +206,7 @@ class Falcon(SharkLLMBase):
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-spirv-index-bits=64",
],
)
print("Saved falcon vmfb at ", str(path))
@@ -192,17 +215,6 @@ class Falcon(SharkLLMBase):
return shark_module
def compile(self):
if (
not self.falcon_vmfb_path.exists()
and self.device == "cuda"
and self.precision == "fp32"
):
download_public_file(
"gs://shark_tank/falcon/7b/cuda/falcon.vmfb",
self.falcon_vmfb_path.absolute(),
single_file=True,
)
falcon_shark_model = self.compile_falcon()
return falcon_shark_model
@@ -375,6 +387,8 @@ class Falcon(SharkLLMBase):
(model_inputs["input_ids"], model_inputs["attention_mask"]),
)
)
if self.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
@@ -428,18 +442,35 @@ if __name__ == "__main__":
args = parser.parse_args()
falcon_mlir_path = (
Path("falcon.mlir")
Path(
"falcon_"
+ args.falcon_variant_to_use
+ "_"
+ args.precision
+ ".mlir"
)
if args.falcon_mlir_path is None
else Path(args.falcon_mlir_path)
)
falcon_vmfb_path = (
Path("falcon.vmfb")
Path(
"falcon_"
+ args.falcon_variant_to_use
+ "_"
+ args.precision
+ "_"
+ args.device
+ ".vmfb"
)
if args.falcon_vmfb_path is None
else Path(args.falcon_vmfb_path)
)
falcon = Falcon(
"falcon",
"falcon_" + args.falcon_variant_to_use,
hf_model_path="tiiuae/falcon-"
+ args.falcon_variant_to_use
+ "-instruct",
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,
@@ -451,11 +482,16 @@ if __name__ == "__main__":
default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
continue_execution = True
print("\n-----\nScript executing for the following config: \n")
print("Falcon Model: ", falcon.model_name)
print("Precision: ", args.precision)
print("Device: ", args.device)
while continue_execution:
use_default_prompt = input(
"\nDo you wish to use the default prompt text? True or False?: "
"\nDo you wish to use the default prompt text? Y/N ?: "
)
if use_default_prompt:
if use_default_prompt in ["Y", "y"]:
prompt = default_prompt_text
else:
prompt = input("Please enter the prompt text: ")
@@ -469,5 +505,8 @@ if __name__ == "__main__":
res_str,
)
continue_execution = input(
"\nDo you wish to run script one more time? True or False?: "
"\nDo you wish to run script one more time? Y/N ?: "
)
continue_execution = (
True if continue_execution in ["Y", "y"] else False
)

View File

@@ -1,559 +0,0 @@
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.utils import (
get_vmfb_from_path,
)
from io import BytesIO
from pathlib import Path
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
import torch
import torch_mlir
import os
class Vicuna(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
max_num_tokens=512,
device="cuda",
precision="fp32",
first_vicuna_mlir_path=Path("first_vicuna.mlir"),
second_vicuna_mlir_path=Path("second_vicuna.mlir"),
first_vicuna_vmfb_path=Path("first_vicuna.vmfb"),
second_vicuna_vmfb_path=Path("second_vicuna.vmfb"),
load_mlir_from_shark_tank=True,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
self.precision = precision
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
self.first_vicuna_mlir_path = first_vicuna_mlir_path
self.second_vicuna_mlir_path = second_vicuna_mlir_path
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path, use_fast=False
)
return tokenizer
def get_src_model(self):
kwargs = {"torch_dtype": torch.float}
vicuna_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
return vicuna_model
def compile_first_vicuna(self):
vmfb = get_vmfb_from_path(
self.first_vicuna_vmfb_path, self.device, "tm_tensor"
)
if vmfb is not None:
return vmfb
# Compilation path needs some more work before it is functional
print(
f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with"
f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}"
)
if self.first_vicuna_mlir_path.exists():
with open(self.first_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
download_public_file(
"gs://shark_tank/vicuna/unsharded/mlir/first_vicuna.mlir",
self.first_vicuna_mlir_path.absolute(),
single_file=True,
)
if self.first_vicuna_mlir_path.exists():
with open(self.first_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
)
if not mlir_generated:
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = self.tokenizer(
compilation_prompt
).input_ids
compilation_input_ids = torch.tensor(
compilation_input_ids
).reshape([1, 19])
firstVicunaCompileInput = (compilation_input_ids,)
model = FirstVicuna(self.hf_model_path)
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
model,
firstVicunaCompileInput,
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
mlir_type="torchscript",
)
del model
print(f"[DEBUG] generating torch mlir")
firstVicunaCompileInput = list(firstVicunaCompileInput)
firstVicunaCompileInput[0] = torch_mlir.TensorPlaceholder.like(
firstVicunaCompileInput[0], dynamic_axes=[1]
)
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
def remove_constant_dim(line):
if "19x" in line:
line = re.sub("19x", "?x", line)
line = re.sub(
"tensor.empty\(\)", "tensor.empty(%dim)", line
)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)",
"tensor.empty(%dim, %dim)",
line,
)
if "arith.cmpi" in line:
line = re.sub("c19", "dim", line)
if " 19," in line:
line = re.sub(" 19,", " %dim,", line)
return line
module = str(module)
new_lines = []
print(f"[DEBUG] rewriting torch_mlir file")
for line in module.splitlines():
line = remove_constant_dim(line)
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
new_lines.append(
"%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
)
if (
"%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
in line
):
continue
new_lines.append(line)
module = "\n".join(new_lines)
print(f"[DEBUG] converting to bytecode")
del new_lines
module = module.encode("UTF-8")
module = BytesIO(module)
bytecode = module.read()
del module
print(f"[DEBUG] writing mlir to file")
f_ = open(self.first_vicuna_mlir_path, "wb")
f_.write(bytecode)
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
)
path = shark_module.save_module(
self.first_vicuna_vmfb_path.parent.absolute(),
self.first_vicuna_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
print("Saved first vic vmfb at vmfb at ", str(path))
shark_module.load_module(path)
return shark_module
def compile_second_vicuna(self):
vmfb = get_vmfb_from_path(
self.second_vicuna_vmfb_path, self.device, "tm_tensor"
)
if vmfb is not None:
return vmfb
# Compilation path needs some more work before it is functional
print(
f"[DEBUG] mlir path {self.second_vicuna_mlir_path} {'exists' if self.second_vicuna_mlir_path.exists() else 'does not exist'}"
)
if self.second_vicuna_mlir_path.exists():
with open(self.second_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
download_public_file(
"gs://shark_tank/vicuna/unsharded/mlir/second_vicuna.mlir",
self.second_vicuna_mlir_path.absolute(),
single_file=True,
)
if self.second_vicuna_mlir_path.exists():
with open(self.second_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.second_vicuna_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
)
if not mlir_generated:
compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64)
pkv = tuple(
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
for _ in range(64)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
model = SecondVicuna(self.hf_model_path)
ts_graph = import_with_fx(
model,
secondVicunaCompileInput,
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
mlir_type="torchscript",
)
secondVicunaCompileInput = list(secondVicunaCompileInput)
for i in range(len(secondVicunaCompileInput)):
if i != 0:
secondVicunaCompileInput[
i
] = torch_mlir.TensorPlaceholder.like(
secondVicunaCompileInput[i], dynamic_axes=[2]
)
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
def remove_constant_dim(line):
if "c19_i64" in line:
line = re.sub("c19_i64", "dim_i64", line)
if "19x" in line:
line = re.sub("19x", "?x", line)
line = re.sub(
"tensor.empty\(\)", "tensor.empty(%dim)", line
)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)",
"tensor.empty(%dim, %dim)",
line,
)
if "arith.cmpi" in line:
line = re.sub("c19", "dim", line)
if " 19," in line:
line = re.sub(" 19,", " %dim,", line)
if "20x" in line:
line = re.sub("20x", "?x", line)
line = re.sub(
"tensor.empty\(\)", "tensor.empty(%dimp1)", line
)
if " 20," in line:
line = re.sub(" 20,", " %dimp1,", line)
return line
module_str = str(module)
new_lines = []
for line in module_str.splitlines():
if "%c19_i64 = arith.constant 19 : i64" in line:
new_lines.append("%c2 = arith.constant 2 : index")
new_lines.append(
"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128xf32>"
)
new_lines.append(
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
)
continue
if "%c2 = arith.constant 2 : index" in line:
continue
if "%c20_i64 = arith.constant 20 : i64" in line:
new_lines.append("%c1_i64 = arith.constant 1 : i64")
new_lines.append(
"%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
)
new_lines.append(
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
)
continue
line = remove_constant_dim(line)
new_lines.append(line)
module_str = "\n".join(new_lines)
bytecode = module_str.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
f_ = open(self.second_vicuna_mlir_path, "wb")
f_.write(bytecode)
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
)
path = shark_module.save_module(
self.second_vicuna_vmfb_path.parent.absolute(),
self.second_vicuna_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
print("Saved vmfb at ", str(path))
shark_module.load_module(self.second_vicuna_vmfb_path)
# self.shark_module = shark_module
return shark_module
def compile(self):
# Cannot load both the models in the memory at once
# due to memory constraints, hence on demand compilation
# is being used until the space is enough for both models
# Testing : DO NOT Download Vmfbs if not found. Modify later
# download vmfbs for A100
if (
not self.first_vicuna_vmfb_path.exists()
and self.device == "cuda"
and self.precision == "fp32"
):
download_public_file(
"gs://shark_tank/vicuna/unsharded/first_vicuna.vmfb",
self.first_vicuna_vmfb_path.absolute(),
single_file=True,
)
else:
# get first vic
# TODO: Remove after testing to avoid memory overload
# fvic_shark_model = self.compile_first_vicuna()
pass
if (
not self.second_vicuna_vmfb_path.exists()
and self.device == "cuda"
and self.precision == "fp32"
):
download_public_file(
"gs://shark_tank/vicuna/unsharded/second_vicuna.vmfb",
self.second_vicuna_vmfb_path.absolute(),
single_file=True,
)
else:
# get second vic
# TODO: Remove after testing to avoid memory overload
# svic_shark_model = self.compile_second_vicuna()
pass
# get first vic
# fvic_shark_model = self.compile_first_vicuna()
# get second vic
# svic_shark_model = self.compile_second_vicuna()
# return tuple of shark_modules
# return fvic_shark_model, svic_shark_model
return None
# return tuple of shark_modules once mem is supported
# return fvic_shark_model, svic_shark_model
def generate(self, prompt, cli=False):
# TODO: refactor for cleaner integration
import gc
res = []
res_tokens = []
params = {
"prompt": prompt,
"is_first": True,
"fv": self.compile_first_vicuna(),
}
generated_token_op = self.generate_new_token(params=params)
token = generated_token_op["token"]
logits = generated_token_op["logits"]
pkv = generated_token_op["pkv"]
detok = generated_token_op["detok"]
res.append(detok)
res_tokens.append(token)
if cli:
print(f"Assistant: {detok}", end=" ", flush=True)
# Clear First Vic from Memory (main and cuda)
del params
torch.cuda.empty_cache()
gc.collect()
sec_vic = self.compile_second_vicuna()
for _ in range(self.max_num_tokens - 2):
params = {
"prompt": None,
"is_first": False,
"logits": logits,
"pkv": pkv,
"sv": sec_vic,
}
generated_token_op = self.generate_new_token(params=params)
token = generated_token_op["token"]
logits = generated_token_op["logits"]
pkv = generated_token_op["pkv"]
detok = generated_token_op["detok"]
if token == 2:
break
res_tokens.append(token)
if detok == "<0x0A>":
res.append("\n")
if cli:
print("\n", end="", flush=True)
else:
res.append(detok)
if cli:
print(f"{detok}", end=" ", flush=True)
del sec_vic, pkv, logits
torch.cuda.empty_cache()
gc.collect()
for i in range(len(res_tokens)):
if type(res_tokens[i]) != int:
res_tokens[i] = int(res_tokens[i][0])
res_str = self.tokenizer.decode(res_tokens)
# print(f"[DEBUG] final output : \n{res_str}")
return res_str
def generate_new_token(self, params, debug=False):
def forward_first(first_vic, prompt, cache_outputs=False):
input_ids = self.tokenizer(prompt).input_ids
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
firstVicunaInput = (input_ids,)
assert first_vic is not None
output_first_vicuna = first_vic("forward", firstVicunaInput)
output_first_vicuna_tensor = torch.tensor(output_first_vicuna[1:])
logits_first_vicuna = torch.tensor(output_first_vicuna[0])
if cache_outputs:
torch.save(
logits_first_vicuna, "logits_first_vicuna_tensor.pt"
)
torch.save(
output_first_vicuna_tensor, "output_first_vicuna_tensor.pt"
)
token = torch.argmax(
torch.tensor(logits_first_vicuna)[:, -1, :], dim=1
)
return token, logits_first_vicuna, output_first_vicuna_tensor
def forward_second(sec_vic, inputs=None, load_inputs=False):
if inputs is not None:
logits = inputs[0]
pkv = inputs[1:]
elif load_inputs:
pkv = torch.load("output_first_vicuna_tensor.pt")
pkv = tuple(torch.tensor(x) for x in pkv)
logits = torch.load("logits_first_vicuna_tensor.pt")
else:
print(
"Either inputs must be given, or load_inputs must be true"
)
return None
token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
token = token.to(torch.int64).reshape([1, 1])
secondVicunaInput = (token,) + tuple(pkv)
secondVicunaOutput = sec_vic("forward", secondVicunaInput)
new_pkv = secondVicunaOutput[1:]
new_logits = secondVicunaOutput[0]
new_token = torch.argmax(torch.tensor(new_logits)[:, -1, :], dim=1)
return new_token, new_logits, new_pkv
is_first = params["is_first"]
if is_first:
prompt = params["prompt"]
fv = params["fv"]
token, logits, pkv = forward_first(
fv, # self.shark_model[0],
prompt=prompt,
cache_outputs=False,
)
else:
_logits = params["logits"]
_pkv = params["pkv"]
inputs = (_logits,) + tuple(_pkv)
sv = params["sv"]
token, logits, pkv = forward_second(
sv, # self.shark_model[1],
inputs=inputs,
load_inputs=False,
)
detok = self.tokenizer.decode(token)
if debug:
print(
f"[DEBUG] is_first: {is_first} |"
f" token : {token} | detok : {detok}"
)
ret_dict = {
"token": token,
"logits": logits,
"pkv": pkv,
"detok": detok,
}
return ret_dict
def autocomplete(self, prompt):
# use First vic alone to complete a story / prompt / sentence.
pass

View File

@@ -1,408 +0,0 @@
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
FirstVicunaLayer,
SecondVicunaLayer,
CompiledFirstVicunaLayer,
CompiledSecondVicunaLayer,
ShardedVicunaModel,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from shark.shark_importer import import_with_fx
from io import BytesIO
from pathlib import Path
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
from torch_mlir import TensorPlaceholder
import re
import torch
import torch_mlir
import os
class Vicuna(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
max_num_tokens=512,
device="cuda",
precision="fp32",
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
self.precision = precision
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path, use_fast=False
)
return tokenizer
def get_src_model(self):
kwargs = {"torch_dtype": torch.float}
vicuna_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
return vicuna_model
def write_in_dynamic_inputs0(self, module, dynamic_input_size):
new_lines = []
for line in module.splitlines():
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim", line)
new_lines.append(line)
new_module = "\n".join(new_lines)
return new_module
def write_in_dynamic_inputs1(self, module, dynamic_input_size):
new_lines = []
for line in module.splitlines():
if "dim_42 =" in line:
continue
if f"%c{dynamic_input_size}_i64 =" in line:
new_lines.append(
"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
)
new_lines.append(
f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
)
continue
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub(
"tensor.empty\(\)", "tensor.empty(%dim_42)", line
)
line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim_42\)",
"tensor.empty(%dim_42, %dim_42)",
line,
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
new_lines.append(line)
new_module = "\n".join(new_lines)
return new_module
def compile_vicuna_layer(
self,
vicuna_layer,
hidden_states,
attention_mask,
position_ids,
past_key_value0=None,
past_key_value1=None,
):
if past_key_value0 is None and past_key_value1 is None:
model_inputs = (hidden_states, attention_mask, position_ids)
else:
model_inputs = (
hidden_states,
attention_mask,
position_ids,
past_key_value0,
past_key_value1,
)
mlir_bytecode = import_with_fx(
vicuna_layer,
model_inputs,
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
mlir_type="torchscript",
)
return mlir_bytecode
def compile_to_vmfb(self, inputs, layers, is_first=True):
mlirs, modules = [], []
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
if is_first:
mlir_path = Path(f"{idx}_0.mlir")
vmfb_path = Path(f"{idx}_0.vmfb")
else:
mlir_path = Path(f"{idx}_1.mlir")
vmfb_path = Path(f"{idx}_1.vmfb")
if vmfb_path.exists():
continue
if mlir_path.exists():
# print(f"Found layer {idx} mlir")
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
else:
hidden_states_placeholder = TensorPlaceholder.like(
inputs[0], dynamic_axes=[1]
)
attention_mask_placeholder = TensorPlaceholder.like(
inputs[1], dynamic_axes=[3]
)
position_ids_placeholder = TensorPlaceholder.like(
inputs[2], dynamic_axes=[1]
)
if not is_first:
pkv0_placeholder = TensorPlaceholder.like(
inputs[3], dynamic_axes=[2]
)
pkv1_placeholder = TensorPlaceholder.like(
inputs[4], dynamic_axes=[2]
)
print(f"Compiling layer {idx} mlir")
if is_first:
ts_g = self.compile_vicuna_layer(
layer, inputs[0], inputs[1], inputs[2]
)
module = torch_mlir.compile(
ts_g,
(
hidden_states_placeholder,
inputs[1],
inputs[2],
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
else:
ts_g = self.compile_vicuna_layer(
layer,
inputs[0],
inputs[1],
inputs[2],
inputs[3],
inputs[4],
)
module = torch_mlir.compile(
ts_g,
(
inputs[0],
attention_mask_placeholder,
inputs[2],
pkv0_placeholder,
pkv1_placeholder,
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
# bytecode_stream = BytesIO()
# module.operation.write_bytecode(bytecode_stream)
# bytecode = bytecode_stream.getvalue()
if is_first:
module = self.write_in_dynamic_inputs0(str(module), 137)
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
else:
module = self.write_in_dynamic_inputs1(str(module), 138)
if idx in [0, 5, 6, 7]:
module_str = module
module_str = module_str.splitlines()
new_lines = []
for line in module_str:
if len(line) < 1000:
new_lines.append(line)
else:
new_lines.append(line[:999])
module_str = "\n".join(new_lines)
f1_ = open(f"{idx}_1_test.mlir", "w+")
f1_.write(module_str)
f1_.close()
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
f_.close()
mlirs.append(bytecode)
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
if is_first:
vmfb_path = Path(f"{idx}_0.vmfb")
if idx < 25:
device = "cpu"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module(
module_name=f"{idx}_0",
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
else:
vmfb_path = Path(f"{idx}_1.vmfb")
if idx < 25:
device = "cpu"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module(
module_name=f"{idx}_1",
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
return mlirs, modules
def get_sharded_model(self):
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess
# please don't change it
SAMPLE_INPUT_LEN = 137
vicuna_model = self.get_src_model()
placeholder_input0 = (
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
)
placeholder_input1 = (
torch.zeros([1, 1, 4096]),
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
torch.zeros([1, 1], dtype=torch.int64),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
)
layers0 = [
FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
]
_, modules0 = self.compile_to_vmfb(
placeholder_input0, layers0, is_first=True
)
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
layers1 = [
SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
]
_, modules1 = self.compile_to_vmfb(
placeholder_input1, layers1, is_first=False
)
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
sharded_model = ShardedVicunaModel(
vicuna_model, shark_layers0, shark_layers1
)
return sharded_model
def compile(self):
return self.get_sharded_model()
def generate(self, prompt, cli=False):
# TODO: refactor for cleaner integration
tokens_generated = []
_past_key_values = None
_token = None
detoks_generated = []
for iteration in range(self.max_num_tokens):
params = {
"prompt": prompt,
"is_first": iteration == 0,
"token": _token,
"past_key_values": _past_key_values,
}
generated_token_op = self.generate_new_token(params=params)
_token = generated_token_op["token"]
_past_key_values = generated_token_op["past_key_values"]
_detok = generated_token_op["detok"]
if _token == 2:
break
detoks_generated.append(_detok)
tokens_generated.append(_token)
for i in range(len(tokens_generated)):
if type(tokens_generated[i]) != int:
tokens_generated[i] = int(tokens_generated[i][0])
result_output = self.tokenizer.decode(tokens_generated)
return result_output
def generate_new_token(self, params):
is_first = params["is_first"]
if is_first:
prompt = params["prompt"]
input_ids = self.tokenizer(prompt).input_ids
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
output = self.shark_model.forward(input_ids, is_first=is_first)
else:
token = params["token"]
past_key_values = params["past_key_values"]
input_ids = [token]
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
output = self.shark_model.forward(
input_ids, past_key_values=past_key_values, is_first=is_first
)
_logits = output["logits"]
_past_key_values = output["past_key_values"]
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
_detok = self.tokenizer.decode(_token)
ret_dict = {
"token": _token,
"detok": _detok,
"past_key_values": _past_key_values,
}
print(f" token : {_token} | detok : {_detok}")
return ret_dict
def autocomplete(self, prompt):
# use First vic alone to complete a story / prompt / sentence.
pass

View File

@@ -103,6 +103,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
)
total_time = time.time() - start_time

View File

@@ -81,6 +81,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -79,6 +79,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -223,7 +223,8 @@ def lora_train(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, both must not be "
"empty.",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:

View File

@@ -17,6 +17,10 @@ from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def load_mlir_module():
if "upscaler" in args.hf_model_id:
is_upscaler = True
else:
is_upscaler = False
sd_model = SharkifyStableDiffusionModel(
args.hf_model_id,
args.ckpt_loc,
@@ -27,6 +31,7 @@ def load_mlir_module():
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
is_upscaler=is_upscaler,
use_tuned=False,
low_cpu_mem_usage=args.low_cpu_mem_usage,
return_mlir=True,

View File

@@ -61,6 +61,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -21,7 +21,7 @@ if __name__ == "__main__":
print("Flag --img_path is required.")
exit()
# When the models get uploaded, it should be default to False.
# When the models get uploaded, it should be defaulted to False.
args.import_mlir = True
cpu_scheduling = not args.scheduler.startswith("Shark")
@@ -73,6 +73,7 @@ if __name__ == "__main__":
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -14,24 +14,29 @@ datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += copy_metadata('Pillow')
datas += copy_metadata('sentencepiece')
datas += collect_data_files('tokenizers')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('opencv-python')
datas += collect_data_files('opencv_python')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('google_cloud_storage')
datas += collect_data_files('shark')
datas += collect_data_files('tkinter')
datas += collect_data_files('webview')
datas += collect_data_files('sentencepiece')
datas += collect_data_files('jsonschema')
datas += collect_data_files('jsonschema_specifications')
datas += collect_data_files('cpuinfo')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
@@ -47,6 +52,7 @@ block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("transformers") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
a = Analysis(
@@ -73,11 +79,11 @@ exe = EXE(
a.zipfiles,
a.datas,
[],
name='shark_sd',
name='nodai_shark_studio',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx=False,
upx_exclude=[],
runtime_tmpdir=None,
console=True,

View File

@@ -29,6 +29,7 @@ datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += collect_data_files('py-cpuinfo')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),

View File

@@ -45,6 +45,7 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape.append(width * mul_val)
elif "/" in shape[i]:
import math
div_val = int(shape[i].split("/")[1])
if "batch_size" in shape[i]:
new_shape.append(math.ceil(batch_size / div_val))
@@ -59,7 +60,9 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
def check_compilation(model, model_name):
if not model:
raise Exception(f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues")
raise Exception(
f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues"
)
class SharkifyStableDiffusionModel:
@@ -97,16 +100,22 @@ class SharkifyStableDiffusionModel:
if "civitai" in custom_weights:
weights_id = custom_weights.split("/")[-1]
# TODO: use model name and identify file type by civitai rest api
weights_path = str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
weights_path = (
str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
)
if not os.path.isfile(weights_path):
subprocess.run(["wget", custom_weights, "-O", weights_path])
subprocess.run(
["wget", custom_weights, "-O", weights_path]
)
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
self.custom_weights = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
custom_weights = get_path_to_diffusers_checkpoint(
custom_weights
)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":
@@ -126,7 +135,7 @@ class SharkifyStableDiffusionModel:
+ "_"
+ precision
)
print(f'use_tuned? sharkify: {use_tuned}')
print(f"use_tuned? sharkify: {use_tuned}")
self.use_tuned = use_tuned
if use_tuned:
self.model_name = self.model_name + "_tuned"
@@ -163,14 +172,24 @@ class SharkifyStableDiffusionModel:
def get_extended_name_for_all_model(self):
model_name = {}
sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
sub_model_list = [
"clip",
"unet",
"unet512",
"stencil_unet",
"vae",
"vae_encode",
"stencil_adaptor",
]
index = 0
for model in sub_model_list:
sub_model = model
model_config = self.model_name
if "vae" == model:
if self.custom_vae != "":
model_config = model_config + get_path_stem(self.custom_vae)
model_config = model_config + get_path_stem(
self.custom_vae
)
if self.base_vae:
sub_model = "base_vae"
if "stencil_adaptor" == model and self.use_stencil is not None:
@@ -197,7 +216,11 @@ class SharkifyStableDiffusionModel:
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, self.max_len, self.width, self.height, self.batch_size
shape,
self.max_len,
self.width,
self.height,
self.batch_size,
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
@@ -209,10 +232,12 @@ class SharkifyStableDiffusionModel:
sys.exit("shape isn't specified correctly.")
input_map.append(tensor)
return input_map
def get_vae_encode(self):
class VaeEncodeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
@@ -226,7 +251,11 @@ class SharkifyStableDiffusionModel:
vae_encode = VaeEncodeModel()
inputs = tuple(self.inputs["vae_encode"])
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
is_f16 = (
True
if not self.is_upscaler and self.precision == "fp16"
else False
)
shark_vae_encode, vae_encode_mlir = compile_through_fx(
vae_encode,
inputs,
@@ -243,7 +272,13 @@ class SharkifyStableDiffusionModel:
def get_vae(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae, low_cpu_mem_usage=False):
def __init__(
self,
model_id=self.model_id,
base_vae=self.base_vae,
custom_vae=self.custom_vae,
low_cpu_mem_usage=False,
):
super().__init__()
self.vae = None
if custom_vae == "":
@@ -279,7 +314,11 @@ class SharkifyStableDiffusionModel:
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
is_f16 = (
True
if not self.is_upscaler and self.precision == "fp16"
else False
)
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
@@ -303,7 +342,10 @@ class SharkifyStableDiffusionModel:
def get_controlled_unet(self):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
@@ -316,12 +358,43 @@ class SharkifyStableDiffusionModel:
self.in_channels = self.unet.in_channels
self.train(False)
def forward( self, latent, timestep, text_embedding, guidance_scale, control1,
control2, control3, control4, control5, control6, control7,
control8, control9, control10, control11, control12, control13,
def forward(
self,
latent,
timestep,
text_embedding,
guidance_scale,
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
control13,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
db_res_samples = tuple([ control1, control2, control3, control4, control5, control6, control7, control8, control9, control10, control11, control12,])
db_res_samples = tuple(
[
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
]
)
mb_res_samples = control13
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
@@ -342,7 +415,25 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True,]
input_mask = [
True,
True,
True,
False,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
]
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
inputs,
@@ -386,16 +477,23 @@ class SharkifyStableDiffusionModel:
stencil_image = torch.cat(
[stencil_image_input] * 2
) # needs to be same as controlledUNET latents
down_block_res_samples, mid_block_res_sample = self.cnet.forward(
(
down_block_res_samples,
mid_block_res_sample,
) = self.cnet.forward(
latents,
timestep,
encoder_hidden_states=text_embedding,
controlnet_cond=stencil_image,
return_dict=False,
)
return tuple(list(down_block_res_samples) + [mid_block_res_sample])
return tuple(
list(down_block_res_samples) + [mid_block_res_sample]
)
scnet = StencilControlNetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
scnet = StencilControlNetModel(
low_cpu_mem_usage=self.low_cpu_mem_usage
)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_adaptor"])
@@ -415,9 +513,14 @@ class SharkifyStableDiffusionModel:
)
return shark_cnet, cnet_mlir
def get_unet(self):
def get_unet(self, use_large=False):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
@@ -426,17 +529,26 @@ class SharkifyStableDiffusionModel:
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.in_channels = self.unet.config.in_channels
self.train(False)
if(args.attention_slicing is not None and args.attention_slicing != "none"):
if(args.attention_slicing.isdigit()):
self.unet.set_attention_slice(int(args.attention_slicing))
if (
args.attention_slicing is not None
and args.attention_slicing != "none"
):
if args.attention_slicing.isdigit():
self.unet.set_attention_slice(
int(args.attention_slicing)
)
else:
self.unet.set_attention_slice(args.attention_slicing)
# TODO: Instead of flattening the `control` try to use the list.
def forward(
self, latent, timestep, text_embedding, guidance_scale,
self,
latent,
timestep,
text_embedding,
guidance_scale,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
@@ -452,17 +564,33 @@ class SharkifyStableDiffusionModel:
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["unet512"]
)
else:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["unet"]
)
input_mask = [True, True, True, False]
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
model_name = "unet512" if use_large else "unet"
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
@@ -471,15 +599,17 @@ class SharkifyStableDiffusionModel:
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
def get_unet_upscaler(self):
def get_unet_upscaler(self, use_large=False):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
@@ -502,17 +632,27 @@ class SharkifyStableDiffusionModel:
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
)
input_mask = [True, True, True, False]
model_name = "unet512" if use_large else "unet"
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
@@ -520,7 +660,12 @@ class SharkifyStableDiffusionModel:
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
@@ -528,7 +673,9 @@ class SharkifyStableDiffusionModel:
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.text_encoder, use_lora, "text_encoder")
update_lora_weight(
self.text_encoder, use_lora, "text_encoder"
)
def forward(self, input):
return self.text_encoder(input)[0]
@@ -567,34 +714,47 @@ class SharkifyStableDiffusionModel:
vae_checkpoint = None
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
if custom_vae.endswith(".ckpt"):
vae_checkpoint = torch.load(self.custom_vae, map_location="cpu")
vae_checkpoint = torch.load(
self.custom_vae, map_location="cpu"
)
else:
vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
vae_checkpoint = safetensors.torch.load_file(
self.custom_vae, device="cpu"
)
if "state_dict" in vae_checkpoint:
vae_checkpoint = vae_checkpoint["state_dict"]
try:
vae_checkpoint = convert_original_vae(vae_checkpoint)
finally:
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
vae_dict = {
k: v
for k, v in vae_checkpoint.items()
if k[0:4] != "loss" and k not in vae_ignore_keys
}
return vae_dict
def compile_unet_variants(self, model):
def compile_unet_variants(self, model, use_large=False):
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler()
return self.get_unet_upscaler(use_large=use_large)
# TODO: Plug the experimental "int8" support at right place.
elif self.use_quantize == "int8":
from apps.stable_diffusion.src.models.opt_params import get_unet
from apps.stable_diffusion.src.models.opt_params import (
get_unet,
)
return get_unet()
else:
return self.get_unet()
return self.get_unet(use_large=use_large)
else:
return self.get_controlled_unet()
def vae_encode(self):
try:
self.inputs["vae_encode"] = self.get_input_info_for(base_models["vae_encode"])
self.inputs["vae_encode"] = self.get_input_info_for(
base_models["vae_encode"]
)
compiled_vae_encode, vae_encode_mlir = self.get_vae_encode()
check_compilation(compiled_vae_encode, "Vae Encode")
@@ -616,25 +776,35 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def unet(self):
def unet(self, use_large=False):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
compiled_unet = None
unet_inputs = base_models[model]
if self.base_model_id != "":
self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
compiled_unet, unet_mlir = self.compile_unet_variants(model)
self.inputs["unet"] = self.get_input_info_for(
unet_inputs[self.base_model_id]
)
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
)
else:
for model_id in unet_inputs:
self.base_model_id = model_id
self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
self.inputs["unet"] = self.get_input_info_for(
unet_inputs[model_id]
)
try:
compiled_unet, unet_mlir = self.compile_unet_variants(model)
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
)
except Exception as e:
print(e)
print("Retrying with a different base model configuration")
print(
"Retrying with a different base model configuration"
)
continue
# -- Once a successful compilation has taken place we'd want to store
@@ -657,7 +827,11 @@ class SharkifyStableDiffusionModel:
def vae(self):
try:
vae_input = base_models["vae"]["vae_upscaler"] if self.is_upscaler else base_models["vae"]["vae"]
vae_input = (
base_models["vae"]["vae_upscaler"]
if self.is_upscaler
else base_models["vae"]["vae"]
)
self.inputs["vae"] = self.get_input_info_for(vae_input)
is_base_vae = self.base_vae
@@ -675,7 +849,9 @@ class SharkifyStableDiffusionModel:
def controlnet(self):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(base_models["stencil_adaptor"])
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
check_compilation(compiled_stencil_adaptor, "Stencil")

View File

@@ -17,9 +17,13 @@ hf_model_variant_map = {
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
"runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
"stabilityai/stable-diffusion-2-inpainting": [
"stablediffusion",
"inpaint_v2",
],
}
# TODO: Add the quantized model as a part model_db.json.
# This is currently in experimental phase.
def get_quantize_model():
@@ -27,9 +31,12 @@ def get_quantize_model():
model_key = "unet_int8"
iree_flags = get_opt_flags("unet", precision="fp16")
if args.height != 512 and args.width != 512 and args.max_length != 77:
sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
sys.exit(
"The int8 quantized model currently requires the height and width to be 512, and max_length to be 77"
)
return bucket_key, model_key, iree_flags
def get_variant_version(hf_model_id):
return hf_model_variant_map[hf_model_id]

View File

@@ -15,6 +15,11 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -38,6 +43,11 @@ class Image2ImagePipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -135,6 +145,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
):
# prompts and negative prompts must be a list.
@@ -156,7 +167,10 @@ class Image2ImagePipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -14,6 +14,11 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -37,6 +42,11 @@ class InpaintPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -378,6 +388,7 @@ class InpaintPipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -408,7 +419,10 @@ class InpaintPipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -14,6 +14,11 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -38,6 +43,11 @@ class OutpaintPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -379,6 +389,7 @@ class OutpaintPipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -409,7 +420,10 @@ class OutpaintPipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -14,6 +14,12 @@ from diffusers import (
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -38,6 +44,12 @@ class StencilPipeline(StableDiffusionPipeline):
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -204,6 +216,7 @@ class StencilPipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
):
# Control Embedding check & conversion
@@ -230,7 +243,10 @@ class StencilPipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -13,6 +13,10 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -34,6 +38,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -81,6 +89,7 @@ class Text2ImagePipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -112,7 +121,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -17,6 +17,9 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -67,6 +70,11 @@ class UpscalerPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
low_res_scheduler: Union[
DDIMScheduler,
@@ -78,6 +86,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -168,7 +180,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.status = SD_STATE_IDLE
self.load_unet()
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
else:
self.load_unet_512()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
latent_model_input = torch.cat([latents] * 2)
@@ -182,15 +197,26 @@ class UpscalerPipeline(StableDiffusionPipeline):
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
else:
noise_pred = self.unet_512(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
end_profiling(profile_device)
noise_pred = torch.from_numpy(noise_pred)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
@@ -219,6 +245,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -243,6 +270,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -264,7 +292,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# 4. Preprocess image

View File

@@ -15,6 +15,9 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
@@ -48,6 +51,10 @@ class StableDiffusionPipeline:
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -57,6 +64,7 @@ class StableDiffusionPipeline:
self.vae = None
self.text_encoder = None
self.unet = None
self.unet_512 = None
self.model_max_length = 77
self.scheduler = scheduler
# TODO: Implement using logging python utility.
@@ -66,7 +74,8 @@ class StableDiffusionPipeline:
self.import_mlir = import_mlir
self.use_lora = use_lora
self.ondemand = ondemand
# TODO: Find a better workaround for fetching base_model_id early enough for CLIPTokenizer.
# TODO: Find a better workaround for fetching base_model_id early
# enough for CLIPTokenizer.
try:
self.tokenizer = get_tokenizer()
except:
@@ -81,7 +90,8 @@ class StableDiffusionPipeline:
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
"Warning: LoRA provided but import_mlir not specified. "
"Importing MLIR anyways."
)
self.text_encoder = self.sd_model.clip()
else:
@@ -114,6 +124,24 @@ class StableDiffusionPipeline:
del self.unet
self.unet = None
def load_unet_512(self):
if self.unet_512 is not None:
return
if self.import_mlir or self.use_lora:
self.unet_512 = self.sd_model.unet(use_large=True)
else:
try:
self.unet_512 = get_unet(use_large=True)
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.unet_512 = self.sd_model.unet(use_large=True)
def unload_unet_512(self):
del self.unet_512
self.unet_512 = None
def load_vae(self):
if self.vae is not None:
return
@@ -203,7 +231,10 @@ class StableDiffusionPipeline:
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
else:
self.load_unet_512()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
@@ -222,16 +253,28 @@ class StableDiffusionPipeline:
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
else:
noise_pred = self.unet_512(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
@@ -254,6 +297,7 @@ class StableDiffusionPipeline:
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -275,6 +319,10 @@ class StableDiffusionPipeline:
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
import_mlir: bool,
model_id: str,
@@ -359,16 +407,21 @@ class StableDiffusionPipeline:
prompt (`str` or `list(int)`):
prompt to be encoded
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
The prompt or prompts not to guide the image generation.
Ignored when not using guidance
(i.e., ignored if `guidance_scale` is less than `1`).
model_max_length (int):
SHARK: pass the max length instead of relying on pipe.tokenizer.model_max_length
SHARK: pass the max length instead of relying on
pipe.tokenizer.model_max_length
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not,
SHARK: must be set to True as we always expect neg embeddings (defaulted to True)
SHARK: must be set to True as we always expect neg embeddings
(defaulted to True)
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
SHARK: max_embeddings_multiples>1 produce a tensor shape error (defaulted to 1)
The max multiple length of prompt embeddings compared to the
max output length of text encoder.
SHARK: max_embeddings_multiples>1 produce a tensor shape error
(defaulted to 1)
num_images_per_prompt (`int`):
number of images that should be generated per prompt
SHARK: num_images_per_prompt is not used (defaulted to 1)
@@ -387,9 +440,11 @@ class StableDiffusionPipeline:
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
f"`negative_prompt`: "
f"{negative_prompt} has batch size {len(negative_prompt)}, "
f"but `prompt`: {prompt} has batch size {batch_size}. "
f"Please make sure that passed `negative_prompt` matches "
"the batch size of `prompt`."
)
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
@@ -402,16 +457,43 @@ class StableDiffusionPipeline:
)
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = text_embeddings.shape
# text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
# text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# text_embeddings = text_embeddings.repeat(
# 1,
# num_images_per_prompt,
# 1
# )
# text_embeddings = (
# text_embeddings.view(
# bs_embed * num_images_per_prompt,
# seq_len,
# -1
# )
# )
if do_classifier_free_guidance:
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = uncond_embeddings.shape
# uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
# uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# uncond_embeddings = (
# uncond_embeddings.repeat(
# 1,
# num_images_per_prompt,
# 1
# )
# )
# uncond_embeddings = (
# uncond_embeddings.view(
# bs_embed * num_images_per_prompt,
# seq_len,
# -1
# )
# )
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
if text_embeddings.shape[1] > model_max_length:
pad = (0, 0) * (len(text_embeddings.shape) - 2)
pad = pad + (0, 512 - text_embeddings.shape[1])
text_embeddings = torch.nn.functional.pad(text_embeddings, pad)
# SHARK: Report clip inference time
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
@@ -446,7 +528,8 @@ re_attention = re.compile(
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Parses a string with attention tokens and returns a list of pairs:
text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12

View File

@@ -8,6 +8,9 @@ from diffusers import (
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
@@ -38,9 +41,28 @@ def get_schedulers(model_id):
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver"
)
schedulers[
"DPMSolverMultistep++"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
)
schedulers[
"DPMSolverMultistepKarras"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
use_karras_sigmas=True,
)
schedulers[
"DPMSolverMultistepKarras++"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
algorithm_type="dpmsolver++",
use_karras_sigmas=True,
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
@@ -62,5 +84,21 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverSinglestep"
] = DPMSolverSinglestepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"KDPM2AncestralDiscrete"
] = KDPM2AncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
return schedulers

View File

@@ -37,4 +37,5 @@ from apps.stable_diffusion.src.utils.utils import (
get_generation_text_info,
update_lora_weight,
resize_stencil,
_compile_module,
)

View File

@@ -5,4 +5,7 @@
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]]
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"],
["A photo of a beach, sunset, calm, beautiful landscape, waves, water"],
["(a large body of water with snowy mountains in the background), (fog, foggy, rolling fog), (clouds, cloudy, rolling clouds), dramatic sky and landscape, extraordinary landscape, (beautiful snow capped mountain background), (forest, dirt path)"],
["a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smokes coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"]]

View File

@@ -116,7 +116,7 @@ def load_lower_configs(base_model_id=None):
else:
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
else:
if not spec or spec in ["rdna3", "sm_80"]:
if not spec or spec in ["sm_80"]:
if (
version in ["v2_1", "v2_1base"]
and args.height == 768
@@ -125,10 +125,38 @@ def load_lower_configs(base_model_id=None):
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
elif spec in ["rdna3"] and version in [
"v2_1",
"v2_1base",
"v1_4",
"v1_5",
]:
config_name = (
f"{args.annotation_model}_"
f"{version}_"
f"{args.max_length}_"
f"{args.precision}_"
f"{device}_"
f"{spec}_"
f"{args.width}x{args.height}.json"
)
elif spec in ["rdna2"] and version in ["v2_1", "v2_1base", "v1_4"]:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}_{args.width}x{args.height}.json"
config_name = (
f"{args.annotation_model}_"
f"{version}_"
f"{args.precision}_"
f"{device}_"
f"{spec}_"
f"{args.width}x{args.height}.json"
)
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
config_name = (
f"{args.annotation_model}_"
f"{version}_"
f"{args.precision}_"
f"{device}_"
f"{spec}.json"
)
full_gs_url = config_bucket + config_name
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
@@ -173,9 +201,22 @@ def dump_after_mlir(input_mlir, use_winograd):
device, device_spec_args = get_device_args()
if use_winograd:
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
"iree-linalg-ext-convert-conv2d-to-winograd))"
)
else:
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32}))"
)
dump_module = ireec.compile_str(
input_mlir,

View File

@@ -19,48 +19,56 @@ p = argparse.ArgumentParser(
)
##############################################################################
### Stable Diffusion Params
# Stable Diffusion Params
##############################################################################
p.add_argument(
"-a",
"--app",
default="txt2img",
help="which app to use, one of: txt2img, img2img, outpaint, inpaint",
help="Which app to use, one of: txt2img, img2img, outpaint, inpaint.",
)
p.add_argument(
"-p",
"--prompts",
nargs="+",
default=["cyberpunk forest by Salvador Dali"],
help="text of which images to be generated.",
default=[
"a photo taken of the front of a super-car drifting on a road near "
"mountains at high speeds with smokes coming off the tires, front "
"angle, front point of view, trees in the mountains of the "
"background, ((sharp focus))"
],
help="Text of which images to be generated.",
)
p.add_argument(
"--negative_prompts",
nargs="+",
default=["trees, green"],
help="text you don't want to see in the generated image.",
default=[
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
"blurry, ugly, blur, oversaturated, cropped"
],
help="Text you don't want to see in the generated image.",
)
p.add_argument(
"--img_path",
type=str,
help="Path to the image input for img2img/inpainting",
help="Path to the image input for img2img/inpainting.",
)
p.add_argument(
"--steps",
type=int,
default=50,
help="the no. of steps to do the sampling.",
help="The number of steps to do the sampling.",
)
p.add_argument(
"--seed",
type=int,
default=-1,
help="the seed to use. -1 for a random one.",
help="The seed to use. -1 for a random one.",
)
p.add_argument(
@@ -68,7 +76,7 @@ p.add_argument(
type=int,
default=1,
choices=range(1, 4),
help="the number of inferences to be made in a single `batch_count`.",
help="The number of inferences to be made in a single `batch_count`.",
)
p.add_argument(
@@ -76,7 +84,7 @@ p.add_argument(
type=int,
default=512,
choices=range(128, 769, 8),
help="the height of the output image.",
help="The height of the output image.",
)
p.add_argument(
@@ -84,77 +92,86 @@ p.add_argument(
type=int,
default=512,
choices=range(128, 769, 8),
help="the width of the output image.",
help="The width of the output image.",
)
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
help="the value to be used for guidance scaling.",
help="The value to be used for guidance scaling.",
)
p.add_argument(
"--noise_level",
type=int,
default=20,
help="the value to be used for noise level of upscaler.",
help="The value to be used for noise level of upscaler.",
)
p.add_argument(
"--max_length",
type=int,
default=64,
help="max length of the tokenizer output, options are 64 and 77.",
help="Max length of the tokenizer output, options are 64 and 77.",
)
p.add_argument(
"--max_embeddings_multiples",
type=int,
default=5,
help="The max multiple length of prompt embeddings compared to the max "
"output length of text encoder.",
)
p.add_argument(
"--strength",
type=float,
default=0.8,
help="the strength of change applied on the given input image for img2img",
help="The strength of change applied on the given input image for "
"img2img.",
)
##############################################################################
### Stable Diffusion Training Params
# Stable Diffusion Training Params
##############################################################################
p.add_argument(
"--lora_save_dir",
type=str,
default="models/lora/",
help="Directory to save the lora fine tuned model",
help="Directory to save the lora fine tuned model.",
)
p.add_argument(
"--training_images_dir",
type=str,
default="models/lora/training_images/",
help="Directory containing images that are an example of the prompt",
help="Directory containing images that are an example of the prompt.",
)
p.add_argument(
"--training_steps",
type=int,
default=2000,
help="The no. of steps to train",
help="The number of steps to train.",
)
##############################################################################
### Inpainting and Outpainting Params
# Inpainting and Outpainting Params
##############################################################################
p.add_argument(
"--mask_path",
type=str,
help="Path to the mask image input for inpainting",
help="Path to the mask image input for inpainting.",
)
p.add_argument(
"--inpaint_full_res",
default=False,
action=argparse.BooleanOptionalAction,
help="If inpaint only masked area or whole picture",
help="If inpaint only masked area or whole picture.",
)
p.add_argument(
@@ -162,7 +179,7 @@ p.add_argument(
type=int,
default=32,
choices=range(0, 257, 4),
help="Number of pixels for only masked padding",
help="Number of pixels for only masked padding.",
)
p.add_argument(
@@ -170,7 +187,7 @@ p.add_argument(
type=int,
default=128,
choices=range(8, 257, 8),
help="Number of expended pixels for one direction for outpainting",
help="Number of expended pixels for one direction for outpainting.",
)
p.add_argument(
@@ -178,89 +195,92 @@ p.add_argument(
type=int,
default=8,
choices=range(0, 65),
help="Number of blur pixels for outpainting",
help="Number of blur pixels for outpainting.",
)
p.add_argument(
"--left",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend left for outpainting",
help="If expend left for outpainting.",
)
p.add_argument(
"--right",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend right for outpainting",
help="If expend right for outpainting.",
)
p.add_argument(
"--top",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend top for outpainting",
help="If expend top for outpainting.",
)
p.add_argument(
"--bottom",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend bottom for outpainting",
help="If expend bottom for outpainting.",
)
p.add_argument(
"--noise_q",
type=float,
default=1.0,
help="Fall-off exponent for outpainting (lower=higher detail) (min=0.0, max=4.0)",
help="Fall-off exponent for outpainting (lower=higher detail) "
"(min=0.0, max=4.0).",
)
p.add_argument(
"--color_variation",
type=float,
default=0.05,
help="Color variation for outpainting (min=0.0, max=1.0)",
help="Color variation for outpainting (min=0.0, max=1.0).",
)
##############################################################################
### Model Config and Usage Params
# Model Config and Usage Params
##############################################################################
p.add_argument(
"--device", type=str, default="vulkan", help="device to run the model."
"--device", type=str, default="vulkan", help="Device to run the model."
)
p.add_argument(
"--precision", type=str, default="fp16", help="precision to run the model."
"--precision", type=str, default="fp16", help="Precision to run the model."
)
p.add_argument(
"--import_mlir",
default=False,
action=argparse.BooleanOptionalAction,
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
help="Imports the model from torch module to shark_module otherwise "
"downloads the model from shark_tank.",
)
p.add_argument(
"--load_vmfb",
default=True,
action=argparse.BooleanOptionalAction,
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
help="Attempts to load the model from a precompiled flat-buffer "
"and compiles + saves it if not found.",
)
p.add_argument(
"--save_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="saves the compiled flatbuffer to the local directory",
help="Saves the compiled flat-buffer to the local directory.",
)
p.add_argument(
"--use_tuned",
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
help="Download and use the tuned version of the model if available.",
)
p.add_argument(
@@ -274,28 +294,34 @@ p.add_argument(
"--scheduler",
type=str,
default="SharkEulerDiscrete",
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
"DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
"DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
"DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, "
"HeunDiscrete].",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="specify the format in which output image is save. Supported options: jpg / png",
help="Specify the format in which output image is save. "
"Supported options: jpg / png.",
)
p.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory path to save the output images and json",
help="Directory path to save the output images and json.",
)
p.add_argument(
"--batch_count",
type=int,
default=1,
help="number of batch to be generated with random seeds in single execution",
help="Number of batch to be generated with random seeds in "
"single execution.",
)
p.add_argument(
@@ -309,7 +335,8 @@ p.add_argument(
"--custom_vae",
type=str,
default="",
help="HuggingFace repo-id or path to SD model's checkpoint whose Vae needs to be plugged in.",
help="HuggingFace repo-id or path to SD model's checkpoint whose VAE "
"needs to be plugged in.",
)
p.add_argument(
@@ -323,14 +350,15 @@ p.add_argument(
"--low_cpu_mem_usage",
default=False,
action=argparse.BooleanOptionalAction,
help="Use the accelerate package to reduce cpu memory consumption",
help="Use the accelerate package to reduce cpu memory consumption.",
)
p.add_argument(
"--attention_slicing",
type=str,
default="none",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', or an integer)",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', "
"or an integer).",
)
p.add_argument(
@@ -343,209 +371,233 @@ p.add_argument(
"--use_lora",
type=str,
default="",
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
help="Use standalone LoRA weight using a HF ID or a checkpoint "
"file (~3 MB).",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
help="Runs the quantized version of stable diffusion model. "
"This is currently in experimental phase. "
"Currently, only runs the stable-diffusion-2-1-base model in "
"int8 quantization.",
)
p.add_argument(
"--ondemand",
default=False,
action=argparse.BooleanOptionalAction,
help="Load and unload models for low VRAM",
help="Load and unload models for low VRAM.",
)
##############################################################################
### IREE - Vulkan supported flags
# IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree_vulkan_target_triple",
type=str,
default="",
help="Specify target triple for vulkan",
help="Specify target triple for vulkan.",
)
p.add_argument(
"--iree_metal_target_platform",
type=str,
default="",
help="Specify target triple for metal.",
)
p.add_argument(
"--vulkan_debug_utils",
default=False,
action=argparse.BooleanOptionalAction,
help="Profiles vulkan device and collects the .rdc info",
help="Profiles vulkan device and collects the .rdc info.",
)
p.add_argument(
"--vulkan_large_heap_block_size",
default="2073741824",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
help="Flag for setting VMA preferredLargeHeapBlockSize for "
"vulkan device, default is 4G.",
)
p.add_argument(
"--vulkan_validation_layers",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for disabling vulkan validation layers when benchmarking",
help="Flag for disabling vulkan validation layers when benchmarking.",
)
##############################################################################
### Misc. Debug and Optimization flags
# Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="use the default scheduler precompiled into the model if available",
help="Use the default scheduler precompiled into the model if available.",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
help="Specify where to save downloaded shark_tank artifacts. "
"If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
help="When enabled call amdllpc to get ISA dumps. "
"Use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help='dispatches to return benchamrk data on. use "All" for all, and None for none.',
help="Dispatches to return benchmark data on. "
'Use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
help="Directory where you want to store dispatch data "
'generated with "--dispatch_benchmarks".',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for inserting debug frames between iterations for use with rgp.",
help="Flag for inserting debug frames between iterations "
"for use with rgp.",
)
p.add_argument(
"--hide_steps",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for hiding the details of iteration/sec for each step.",
help="Flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="flag setting warmup count for clip and vae [>= 0].",
help="Flag setting warmup count for CLIP and VAE [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
help="Flag to clear all mlir and vmfb from common locations. "
"Recompiling will take several minutes.",
)
p.add_argument(
"--save_metadata_to_json",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save a generation information json file with the image.",
help="Flag for whether or not to save a generation information "
"json file with the image.",
)
p.add_argument(
"--write_metadata_to_png",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
help="Flag for whether or not to save generation information in "
"PNG chunk text to generated images.",
)
p.add_argument(
"--import_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="if import_mlir is True, saves mlir via the debug option in shark importer. Does nothing if import_mlir is false (the default)",
help="If import_mlir is True, saves mlir via the debug option "
"in shark importer. Does nothing if import_mlir is false (the default).",
)
##############################################################################
### Web UI flags
# Web UI flags
##############################################################################
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the progress bar animation during image generation",
help="Flag for removing the progress bar animation during "
"image generation.",
)
p.add_argument(
"--ckpt_dir",
type=str,
default="",
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
help="Path to directory where all .ckpts are stored in order to populate "
"them in the web UI.",
)
# TODO: replace API flag when these can be run together
p.add_argument(
"--ui",
type=str,
default="app" if os.name == "nt" else "web",
help="one of: [api, app, web]",
help="One of: [api, app, web].",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
help="Flag for generating a public URL.",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
help="Flag for setting server port.",
)
p.add_argument(
"--api",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for enabling rest API",
help="Flag for enabling rest API.",
)
p.add_argument(
"--output_gallery",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the output gallery tab, and avoid exposing images under --output_dir in the UI",
help="Flag for removing the output gallery tab, and avoid exposing "
"images under --output_dir in the UI.",
)
p.add_argument(
"--output_gallery_followlinks",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether the output gallery tab in the UI should follow symlinks when listing subdirectorys under --output_dir",
help="Flag for whether the output gallery tab in the UI should "
"follow symlinks when listing subdirectories under --output_dir.",
)
##############################################################################
### SD model auto-annotation flags
# SD model auto-annotation flags
##############################################################################
p.add_argument(
"--annotation_output",
type=path_expand,
default="./",
help="Directory to save the annotated mlir file",
help="Directory to save the annotated mlir file.",
)
p.add_argument(
@@ -559,31 +611,31 @@ p.add_argument(
"--save_annotation",
default=False,
action=argparse.BooleanOptionalAction,
help="Save annotated mlir file",
help="Save annotated mlir file.",
)
##############################################################################
### SD model auto-tuner flags
# SD model auto-tuner flags
##############################################################################
p.add_argument(
"--tuned_config_dir",
type=path_expand,
default="./",
help="Directory to save the tuned config file",
help="Directory to save the tuned config file.",
)
p.add_argument(
"--num_iters",
type=int,
default=400,
help="Number of iterations for tuning",
help="Number of iterations for tuning.",
)
p.add_argument(
"--search_op",
type=str,
default="all",
help="Op to be optimized, options are matmul, bmm, conv and all",
help="Op to be optimized, options are matmul, bmm, conv and all.",
)

View File

@@ -18,6 +18,7 @@ from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
from shark.iree_utils.metal_utils import get_metal_target_triple
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
@@ -31,6 +32,7 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
import requests
from io import BytesIO
from omegaconf import OmegaConf
from cpuinfo import get_cpu_info
def get_extended_name(model_name):
@@ -47,6 +49,7 @@ def get_vmfb_path_name(model_name):
def _load_vmfb(shark_module, vmfb_path, model, precision):
model = "vae" if "base_vae" in model or "vae_encode" in model else model
model = "unet" if "stencil" in model else model
model = "unet" if "unet512" in model else model
precision = "fp32" if "clip" in model else precision
extra_args = get_opt_flags(model, precision)
shark_module.load_module(vmfb_path, extra_args=extra_args)
@@ -78,7 +81,9 @@ def _compile_module(shark_module, model_name, extra_args=[]):
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
def get_shark_model(tank_url, model_name, extra_args=None):
if extra_args is None:
extra_args = []
from shark.parser import shark_args
# Set local shark_tank cache directory.
@@ -110,12 +115,15 @@ def compile_through_fx(
save_dir=tempfile.gettempdir(),
debug=False,
generate_vmfb=True,
extra_args=[],
extra_args=None,
base_model_id=None,
model_name=None,
precision=None,
return_mlir=False,
device=None,
):
if extra_args is None:
extra_args = []
if not return_mlir and model_name is not None:
vmfb_path = get_vmfb_path_name(extended_model_name)
if os.path.isfile(vmfb_path):
@@ -145,7 +153,10 @@ def compile_through_fx(
if use_tuned:
if "vae" in extended_model_name.split("_")[0]:
args.annotation_model = "vae"
if "unet" in model_name.split("_")[0]:
if (
"unet" in model_name.split("_")[0]
or "unet_512" in model_name.split("_")[0]
):
args.annotation_model = "unet"
mlir_module = sd_model_annotation(
mlir_module, extended_model_name, base_model_id
@@ -153,7 +164,7 @@ def compile_through_fx(
shark_module = SharkInference(
mlir_module,
device=args.device,
device=args.device if device is None else device,
mlir_dialect="tm_tensor",
)
if generate_vmfb:
@@ -198,13 +209,15 @@ def get_device_mapping(driver, key_combination=3):
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for device name.
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired combination of name/path.
dict: map to possible device names user can input mapped to desired
combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
@@ -218,7 +231,7 @@ def get_device_mapping(driver, key_combination=3):
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
return dev_dict["name"], f"{driver}://{dev_dict['path']}"
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
@@ -231,10 +244,12 @@ def get_device_mapping(driver, key_combination=3):
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
"""Gives the appropriate device data (supported name/path) for user
selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for device name.
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
@@ -242,7 +257,8 @@ def map_device_to_name_path(device, key_combination=3):
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
str / tuple: returns the mapping str or tuple of mapping str for
the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
@@ -265,10 +281,21 @@ def set_init_device_flags():
if triple is not None:
args.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
f"Found device {device_name}. Using target triple "
f"{args.iree_vulkan_target_triple}."
)
elif "cuda" in args.device:
args.device = "cuda"
elif "metal" in args.device:
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_metal_target_platform:
triple = get_metal_target_triple(device_name)
if triple is not None:
args.iree_metal_target_platform = triple
print(
f"Found device {device_name}. Using target triple "
f"{args.iree_metal_target_platform}."
)
elif "cpu" in args.device:
args.device = "cpu"
@@ -293,13 +320,24 @@ def set_init_device_flags():
if (
args.precision != "fp16"
or args.height not in [512, 768]
or (args.height == 512 and args.width != 512)
or (args.height == 768 and args.width != 768)
or (args.height == 512 and args.width not in [512, 768])
or (args.height == 768 and args.width not in [512, 768])
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
):
args.use_tuned = False
elif (
args.height != args.width
and "rdna2" in args.iree_vulkan_target_triple
and base_model_id
not in [
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
]
):
args.use_tuned = False
elif base_model_id not in [
"Linaqruf/anything-v3.0",
"dreamlike-art/dreamlike-diffusion-1.0",
@@ -354,7 +392,8 @@ def set_init_device_flags():
if args.use_tuned:
print(
f"Using tuned models for {base_model_id}(fp16) on device {args.device}."
f"Using tuned models for {base_model_id}(fp16) on "
f"device {args.device}."
)
else:
print("Tuned models are currently not supported for this setting.")
@@ -412,8 +451,12 @@ def get_available_devices():
except:
print(f"{driver_name} devices are not available.")
else:
cpu_name = get_cpu_info()["brand_raw"]
for i, device in enumerate(device_list_dict):
device_list.append(f"{device['name']} => {driver_name}://{i}")
device_name = (
cpu_name if device["name"] == "default" else device["name"]
)
device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
@@ -421,9 +464,14 @@ def get_available_devices():
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("device => cpu")
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
return available_devices
@@ -500,10 +548,10 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
from_safetensors = (
True if custom_weights.lower().endswith(".safetensors") else False
)
# EMA weights usually yield higher quality images for inference but non-EMA weights have
# been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
# weight extraction or not.
# EMA weights usually yield higher quality images for inference but
# non-EMA weights have been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if
# they want to go for EMA weight extraction or not.
extract_ema = False
print(
"Loading diffusers' pipeline from original stable diffusion checkpoint"
@@ -524,7 +572,10 @@ def convert_original_vae(vae_checkpoint):
for key in list(vae_checkpoint.keys()):
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
config_url = (
"https://raw.githubusercontent.com/CompVis/stable-diffusion/"
"main/configs/stable-diffusion/v1-inference.yaml"
)
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file)
vae_config = create_vae_diffusers_config(original_config, image_size=512)
@@ -645,7 +696,7 @@ def update_lora_weight(model, use_lora, model_name):
# `fetch_and_update_base_model_id` is a resource utility function which
# helps maintaining mapping of the model to run with its base model.
# helps to maintain mapping of the model to run with its base model.
# If `base_model` is "", then this function tries to fetch the base model
# info for the `model_to_run`.
def fetch_and_update_base_model_id(model_to_run, base_model=""):
@@ -662,13 +713,15 @@ def fetch_and_update_base_model_id(model_to_run, base_model=""):
return base_model
elif base_model == "":
return base_model
# Update JSON data to contain an entry mapping model_to_run with base_model.
# Update JSON data to contain an entry mapping model_to_run with
# base_model.
json_data.update(data)
with open(variants_path, "w", encoding="utf-8") as jsonFile:
json.dump(json_data, jsonFile)
# Generate and return a new seed if the provided one is not in the supported range (including -1)
# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed):
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
@@ -687,7 +740,8 @@ def clear_all():
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
# Temporary workaround of deleting yaml files to incorporate
# diffusers' pipeline.
# TODO: Remove this once we have better weight updation logic.
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
for yaml in inference_yaml:
@@ -716,7 +770,9 @@ def get_generated_imgs_todays_subdir() -> str:
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed, extra_info={}):
def save_output_img(output_img, img_seed, extra_info=None):
if extra_info is None:
extra_info = {}
generated_imgs_path = Path(
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
)
@@ -724,14 +780,20 @@ def save_output_img(output_img, img_seed, extra_info={}):
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
out_img_name = (
f"{prompt_slice}_{img_seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
)
out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
img_model = args.hf_model_id
if args.ckpt_loc:
img_model = Path(os.path.basename(args.ckpt_loc)).stem
img_vae = None
if args.custom_vae:
img_vae = Path(os.path.basename(args.custom_vae)).stem
img_lora = None
if args.use_lora:
img_lora = Path(os.path.basename(args.use_lora)).stem
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
@@ -742,17 +804,30 @@ def save_output_img(output_img, img_seed, extra_info={}):
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
f"{args.prompts[0]}"
f"\nNegative prompt: {args.negative_prompts[0]}"
f"\nSteps: {args.steps},"
f"Sampler: {args.scheduler}, "
f"CFG scale: {args.guidance_scale}, "
f"Seed: {img_seed},"
f"Size: {args.width}x{args.height}, "
f"Model: {img_model}, "
f"VAE: {img_vae}, "
f"LoRA: {img_lora}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if args.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not supported yet."
"Image saved as png instead. Supported formats: png / jpg"
f"[ERROR] Format {args.output_img_format} is not "
f"supported yet. Image saved as png instead."
f"Supported formats: png / jpg"
)
# To be as low-impact as possible to the existing CSV format, we append
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
# importance for each data point. Something to consider.
new_entry = {
"VARIANT": img_model,
"SCHEDULER": args.scheduler,
@@ -766,12 +841,17 @@ def save_output_img(output_img, img_seed, extra_info={}):
"WIDTH": args.width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
"VAE": img_vae,
"LORA": img_lora,
}
new_entry.update(extra_info)
with open(csv_path, "a", encoding="utf-8") as csv_obj:
csv_mode = "a" if os.path.isfile(csv_path) else "w"
with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
if csv_mode == "w":
dictwriter_obj.writeheader()
dictwriter_obj.writerow(new_entry)
csv_obj.close()
@@ -785,16 +865,27 @@ def save_output_img(output_img, img_seed, extra_info={}):
def get_generation_text_info(seeds, device):
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch_count={args.batch_count}, batch_size={args.batch_size}, max_length={args.max_length}"
text_output += (
f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, " f"device={device}"
text_output += (
f"\nsteps={args.steps}, "
f"guidance_scale={args.guidance_scale}, "
f"seed={seeds}"
)
text_output += (
f"\nsize={args.height}x{args.width}, "
f"batch_count={args.batch_count}, "
f"batch_size={args.batch_size}, "
f"max_length={args.max_length}"
)
return text_output
# For stencil, the input image can be of any size but we need to ensure that
# it conforms with our model contraints :-
# For stencil, the input image can be of any size, but we need to ensure that
# it conforms with our model constraints :-
# Both width and height should be in the range of [128, 768] and multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.

View File

@@ -1,7 +1,13 @@
from multiprocessing import Process, freeze_support
import os
import sys
import transformers # ensures inclusion in pysintaller exe generation
if sys.platform == "darwin":
# import before IREE to avoid torch-MLIR library issues
import torch_mlir
import shutil
import PIL, transformers, sentencepiece # ensures inclusion in pysintaller exe generation
from apps.stable_diffusion.src import args, clear_all
import apps.stable_diffusion.web.utils.global_obj as global_obj
@@ -20,11 +26,16 @@ def launch_app(address):
window = Tk()
# getting screen width and height of display
width = window.winfo_screenwidth()
height = window.winfo_screenheight()
# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio", url=address, width=width, height=height
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False)
@@ -38,6 +49,7 @@ if __name__ == "__main__":
img2img_api,
upscaler_api,
inpaint_api,
outpaint_api,
)
from fastapi import FastAPI, APIRouter
import uvicorn
@@ -49,23 +61,25 @@ if __name__ == "__main__":
app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
# app.add_api_route(
# "/sdapi/v1/outpaint", outpaint_api, methods=["post"]
# )
app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
app.include_router(APIRouter())
uvicorn.run(app, host="127.0.0.1", port=args.server_port)
sys.exit(0)
import gradio as gr
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
from apps.stable_diffusion.web.utils.gradio_configs import (
clear_gradio_tmp_imgs_folder,
config_gradio_tmp_imgs_folder,
)
config_gradio_tmp_imgs_folder()
import gradio as gr
# Create custom models folders if they don't exist
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
# Clear all gradio tmp images from the last session
clear_gradio_tmp_imgs_folder()
# Create custom models folders if they don't exist
create_custom_models_folders()
def resource_path(relative_path):

View File

@@ -231,6 +231,16 @@ footer {
display: none;
}
/* reduced animation load when generating */
.generating {
animation-play-state: paused !important;
}
/* better clarity when progress bars are minimal */
.meta-text {
background-color: var(--block-label-background-fill);
}
/* output gallery tab */
.output_parameters_dataframe tbody td {
font-size: small;

View File

@@ -104,7 +104,8 @@ def img2img_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -132,7 +133,8 @@ def img2img_inf(
image, width, height = resize_stencil(image)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
f"Shark schedulers are not supported. Switching to EulerDiscrete "
f"scheduler"
)
args.scheduler = "EulerDiscrete"
cpu_scheduling = not args.scheduler.startswith("Shark")
@@ -249,6 +251,7 @@ def img2img_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
)
seeds.append(img_seed)
@@ -307,7 +310,9 @@ def img2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = img2img_inf(
@@ -340,6 +345,10 @@ def img2img_api(
lora_hf_id="",
ondemand=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
@@ -362,8 +371,14 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
i2i_model_info = (str(get_custom_model_path())).replace(
"\\", "\n\\"
)
i2i_model_info = f"Custom Model Path: {i2i_model_info}"
img2img_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
label=f"Models",
info=i2i_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -374,13 +389,23 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
)
img2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3, "
"https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
i2i_vae_info = (str(get_custom_model_path("vae"))).replace(
"\\", "\n\\"
)
i2i_vae_info = f"VAE Path: {i2i_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label=f"Custom VAE Models",
info=i2i_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -392,13 +417,13 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
@@ -470,15 +495,24 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
i2i_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=i2i_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -601,7 +635,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
elem_id="gallery",
).style(columns=[2], object_fit="contain")
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
@@ -645,7 +680,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
ondemand,
],
outputs=[img2img_gallery, std_output, img2img_status],
show_progress=args.progress_bar,
show_progress="minimal" if args.progress_bar else "none",
)
status_kwargs = dict(

View File

@@ -92,7 +92,8 @@ def inpaint_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -204,6 +205,7 @@ def inpaint_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
seeds.append(img_seed)
total_time = time.time() - start_time
@@ -257,7 +259,9 @@ def inpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
init_image = decode_base64_to_image(InputData["image"])
mask = decode_base64_to_image(InputData["mask"])
@@ -278,7 +282,7 @@ def inpaint_api(
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
else "stabilityai/stable-diffusion-2-inpainting",
custom_vae="None",
precision="fp16",
device=available_devices[0],
@@ -289,6 +293,10 @@ def inpaint_api(
lora_hf_id="",
ondemand=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
@@ -311,8 +319,16 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
inpaint_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
inpaint_model_info = (
f"Custom Model Path: {inpaint_model_info}"
)
inpaint_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
label=f"Models",
info=inpaint_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -325,13 +341,23 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
)
inpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
"https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
inpaint_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
inpaint_vae_info = f"VAE Path: {inpaint_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label=f"Custom VAE Models",
info=inpaint_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -343,13 +369,13 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
@@ -362,15 +388,24 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
inpaint_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
inpaint_lora_info = f"LoRA Path: {inpaint_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=inpaint_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -500,7 +535,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
elem_id="gallery",
).style(columns=[2], object_fit="contain")
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
@@ -545,7 +581,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
ondemand,
],
outputs=[inpaint_gallery, std_output, inpaint_status],
show_progress=args.progress_bar,
show_progress="minimal" if args.progress_bar else "none",
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Inpaint", 0, bc, bs),

View File

@@ -31,8 +31,16 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
# janky fix for overflowing text
train_lora_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
train_lora_model_info = (
f"Custom Model Path: {train_lora_model_info}"
)
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
label=f"Models",
info=train_lora_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -43,22 +51,33 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
placeholder="Select 'None' in the Models "
"dropdown on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
# janky fix for overflowing text
train_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
train_lora_info = f"LoRA Path: {train_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights to initialize weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA weights to initialize weights",
info=train_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use a "
"standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID to initialize weights",
lines=3,
@@ -74,7 +93,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
with gr.Accordion(label="Advanced Options", open=False):
@@ -215,7 +234,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
),
],
outputs=[std_output],
show_progress=args.progress_bar,
show_progress="minimal" if args.progress_bar else "none",
)
prompt_submit = prompt.submit(**kwargs)

View File

@@ -19,7 +19,10 @@ def get_hf_list(num_of_models=20):
def get_civit_list(num_of_models=50):
path = f"https://civitai.com/api/v1/models?limit={num_of_models}&types=Checkpoint"
path = (
f"https://civitai.com/api/v1/models?limit="
f"{num_of_models}&types=Checkpoint"
)
headers = {"Content-Type": "application/json"}
raw_json = requests.get(path, headers=headers).json()
models = list(raw_json.items())[0][1]
@@ -79,7 +82,7 @@ with gr.Blocks() as model_web:
type="value",
label="Model Source",
)
model_numebr = gr.Slider(
model_number = gr.Slider(
1,
100,
value=10,
@@ -111,9 +114,9 @@ with gr.Blocks() as model_web:
modelmanager_sendto_outpaint = gr.Button(value="SendTo Outpaint")
modelmanager_sendto_upscaler = gr.Button(value="SendTo Upscaler")
def get_model_list(model_source, model_numebr):
def get_model_list(model_source, model_number):
if model_source == "Hugging Face":
hf_model_list = get_hf_list(model_numebr)
hf_model_list = get_hf_list(model_number)
models = []
for model in hf_model_list:
# TODO: add model info
@@ -124,7 +127,7 @@ with gr.Blocks() as model_web:
gr.Row.update(visible=True),
)
elif model_source == "Civitai":
civit_model_list = get_civit_list(model_numebr)
civit_model_list = get_civit_list(model_number)
models = []
for model in civit_model_list:
image = get_image_from_model(model)
@@ -148,7 +151,7 @@ with gr.Blocks() as model_web:
get_model_btn.click(
fn=get_model_list,
inputs=[model_source, model_numebr],
inputs=[model_source, model_number],
outputs=[
hf_models,
civit_models,

View File

@@ -29,7 +29,6 @@ from apps.stable_diffusion.src.utils import (
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -92,7 +91,8 @@ def outpaint_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -211,6 +211,7 @@ def outpaint_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
seeds.append(img_seed)
total_time = time.time() - start_time
@@ -264,7 +265,9 @@ def outpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = outpaint_inf(
@@ -287,7 +290,7 @@ def outpaint_api(
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
else "stabilityai/stable-diffusion-2-inpainting",
custom_vae="None",
precision="fp16",
device=available_devices[0],
@@ -298,6 +301,10 @@ def outpaint_api(
lora_hf_id="",
ondemand=False,
)
# Convert Generator to Subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
@@ -320,8 +327,16 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
outpaint_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
outpaint_model_info = (
f"Custom Model Path: {outpaint_model_info}"
)
outpaint_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
label=f"Models",
info=outpaint_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -334,13 +349,23 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
)
outpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
"https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
outpaint_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
outpaint_vae_info = f"VAE Path: {outpaint_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label=f"Custom VAE Models",
info=outpaint_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -352,13 +377,13 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
@@ -368,15 +393,24 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
outpaint_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
outpaint_lora_info = f"LoRA Path: {outpaint_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=outpaint_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -528,7 +562,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
elem_id="gallery",
).style(columns=[2], object_fit="contain")
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
@@ -573,7 +608,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
ondemand,
],
outputs=[outpaint_gallery, std_output, outpaint_status],
show_progress=args.progress_bar,
show_progress="minimal" if args.progress_bar else "none",
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Outpaint", 0, bc, bs),

View File

@@ -9,9 +9,6 @@ from apps.stable_diffusion.src.utils import (
get_generated_imgs_todays_subdir,
)
from apps.stable_diffusion.web.ui.utils import nodlogo_loc
from apps.stable_diffusion.web.utils.gradio_configs import (
gradio_tmp_galleries_folder,
)
from apps.stable_diffusion.web.utils.metadata import displayable_metadata
# -- Functions for file, directory and image info querying
@@ -41,14 +38,14 @@ def output_subdirs() -> list[str]:
)
]
# It is less confusing to always including the subdir that will take any images generated
# today even if it doesn't exist yet
# It is less confusing to always including the subdir that will take any
# images generated today even if it doesn't exist yet
if get_generated_imgs_todays_subdir() not in relative_paths:
relative_paths.append(get_generated_imgs_todays_subdir())
# sort subdirectories so that that the date named ones we probably created in this or
# previous sessions come first, sorted with the most recent first. Other subdirs are listed
# after.
# sort subdirectories so that the date named ones we probably
# created in this or previous sessions come first, sorted with the most
# recent first. Other subdirs are listed after.
generated_paths = sorted(
[path for path in relative_paths if path.isnumeric()], reverse=True
)
@@ -63,26 +60,14 @@ def output_subdirs() -> list[str]:
return result_paths
# clear zero length temporary files that gradio 3.22.0 buggily creates
# TODO: remove once gradio is upgraded to or past 3.32.0
def clear_zero_length_temps():
zero_length_temps = [
os.path.join(root, file)
for root, dirs, files in os.walk(gradio_tmp_galleries_folder)
for file in files
if os.path.getsize(os.path.join(root, file)) == 0
]
for file in zero_length_temps:
os.remove(file)
# --- Define UI layout for Gradio
with gr.Blocks() as outputgallery_web:
nod_logo = Image.open(nodlogo_loc)
with gr.Row(elem_id="outputgallery_gallery"):
# needed to workaround gradio issue: https://github.com/gradio-app/gradio/issues/2907
# needed to workaround gradio issue:
# https://github.com/gradio-app/gradio/issues/2907
dev_null = gr.Textbox("", visible=False)
gallery_files = gr.State(value=[])
@@ -105,7 +90,6 @@ with gr.Blocks() as outputgallery_web:
visible=False,
show_label=True,
).style(columns=4)
gallery.DEFAULT_TEMP_DIR = gradio_tmp_galleries_folder
with gr.Column(scale=4):
with gr.Box():
@@ -179,7 +163,6 @@ with gr.Blocks() as outputgallery_web:
# --- Event handlers
def on_clear_gallery():
clear_zero_length_temps()
return [
gr.Gallery.update(
value=[],
@@ -210,16 +193,20 @@ with gr.Blocks() as outputgallery_web:
]
def on_refresh(current_subdir: str) -> list:
# get an up to date subdirectory list
# get an up-to-date subdirectory list
refreshed_subdirs = output_subdirs()
# get the images using either the current subdirectory or the most recent valid one
# get the images using either the current subdirectory or the most
# recent valid one
new_subdir = (
current_subdir
if current_subdir in refreshed_subdirs
else refreshed_subdirs[0]
)
new_images = outputgallery_filenames(new_subdir)
new_label = f"{len(new_images)} images in {os.path.join(output_dir, new_subdir)}"
new_label = (
f"{len(new_images)} images in "
f"{os.path.join(output_dir, new_subdir)}"
)
return [
gr.Dropdown.update(
@@ -238,18 +225,22 @@ with gr.Blocks() as outputgallery_web:
]
def on_new_image(subdir, subdir_paths, status) -> list:
# prevent error triggered when an image generates before the tab has even been selected
# prevent error triggered when an image generates before the tab
# has even been selected
subdir_paths = (
subdir_paths
if len(subdir_paths) > 0
else [get_generated_imgs_todays_subdir()]
)
# only update if the current subdir is the most recent one as new images only go there
# only update if the current subdir is the most recent one as
# new images only go there
if subdir_paths[0] == subdir:
clear_zero_length_temps()
new_images = outputgallery_filenames(subdir)
new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)} - {status}"
new_label = (
f"{len(new_images)} images in "
f"{os.path.join(output_dir, subdir)} - {status}"
)
return [
new_images,
@@ -264,11 +255,13 @@ with gr.Blocks() as outputgallery_web:
),
]
else:
# otherwise change nothing, (only untyped gradio gr.update() does this)
# otherwise change nothing,
# (only untyped gradio gr.update() does this)
return [gr.update(), gr.update(), gr.update()]
def on_select_image(images: list[str], evt: gr.SelectData) -> list:
# evt.index is an index into the full list of filenames for the current subdirectory
# evt.index is an index into the full list of filenames for
# the current subdirectory
filename = images[evt.index]
params = displayable_metadata(filename)
@@ -286,7 +279,8 @@ with gr.Blocks() as outputgallery_web:
def on_outputgallery_filename_change(filename: str) -> list:
exists = filename != "None" and os.path.exists(filename)
return [
# disable or enable each of the sendto button based on whether an image is selected
# disable or enable each of the sendto button based on whether
# an image is selected
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
@@ -295,14 +289,16 @@ with gr.Blocks() as outputgallery_web:
gr.Button.update(interactive=exists),
]
# The time first our tab is selected we need to do an initial refresh to populate
# the subdirectory select box and the images from the most recent subdirectory.
# The time first our tab is selected we need to do an initial refresh
# to populate the subdirectory select box and the images from the most
# recent subdirectory.
#
# We do it at this point rather than setting this up in the controls' definitions
# as when you refresh the browser you always get what was *initially* set, which
# won't include any new subdirectories or images that might have created since
# the application was started. Doing it this way means a browser refresh/reload
# always gets the most up to date data.
# We do it at this point rather than setting this up in the controls'
# definitions as when you refresh the browser you always get what was
# *initially* set, which won't include any new subdirectories or images
# that might have created since the application was started. Doing it
# this way means a browser refresh/reload always gets the most
# up-to-date data.
def on_select_tab(subdir_paths):
if len(subdir_paths) == 0:
return on_refresh("")
@@ -316,11 +312,11 @@ with gr.Blocks() as outputgallery_web:
gr.update(),
)
# Unfortunately as of gradio 3.22.0 gr.update against Galleries doesn't support
# things set with .style, nor the elem_classes kwarg so we have to directly set
# things up via JavaScript if we want the client to take notice of any of our
# changes to the number of columns after it decides to put them back to the
# original number when we change something
# Unfortunately as of gradio 3.22.0 gr.update against Galleries
# doesn't support things set with .style, nor the elem_classes kwarg, so
# we have to directly set things up via JavaScript if we want the client
# to take notice of our changes to the number of columns after it
# decides to put them back to the original number when we change something
def js_set_columns_in_browser(timeout_length):
return f"""
(new_cols) => {{
@@ -337,32 +333,36 @@ with gr.Blocks() as outputgallery_web:
# --- Wire handlers up to the actions
# - Many actions reset the number of columns shown in the gallery on the browser end,
# so we have to set them back to what we think they should be after the initial
# action.
# - None of the actions on this tab trigger inference, and we want the user to be able
# to do them whilst other tabs have ongoing inference running. Waiting in the queue
# behind inference jobs would mean the UI can't fully respond until the inference tasks
# complete, hence queue=False on all of these.
# Many actions reset the number of columns shown in the gallery on the
# browser end, so we have to set them back to what we think they should
# be after the initial action.
#
# None of the actions on this tab trigger inference, and we want the
# user to be able to do them whilst other tabs have ongoing inference
# running. Waiting in the queue behind inference jobs would mean the UI
# can't fully respond until the inference tasks complete,
# hence queue=False on all of these.
set_gallery_columns_immediate = dict(
fn=None,
inputs=[image_columns],
# gradio blanks the UI on Chrome on Linux on gallery select if I don't put an output here
# gradio blanks the UI on Chrome on Linux on gallery select if
# I don't put an output here
outputs=[dev_null],
_js=js_set_columns_in_browser(0),
queue=False,
)
# setting columns after selecting a gallery item needs a real timeout length for the
# number of columns to actually be applied. Not really sure why, maybe something has
# to finish animating?
# setting columns after selecting a gallery item needs a real
# timeout length for the number of columns to actually be applied.
# Not really sure why, maybe something has to finish animating?
set_gallery_columns_delayed = dict(
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
)
# clearing images when we need to completely change what's in the gallery avoids current
# images being shown replacing piecemeal and prevents weirdness and errors if the user
# selects an image during the replacement phase.
# clearing images when we need to completely change what's in the
# gallery avoids current images being shown replacing piecemeal and
# prevents weirdness and errors if the user selects an image during the
# replacement phase.
clear_gallery = dict(
fn=on_clear_gallery,
inputs=None,

View File

@@ -7,12 +7,17 @@ from transformers import (
)
from apps.stable_diffusion.web.ui.utils import available_devices
start_message = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
start_message = (
"<|SYSTEM|># StableLM Tuned (Alpha version)"
"\n- StableLM is a helpful and harmless open-source AI language model "
"developed by StabilityAI."
"\n- StableLM is excited to be able to help the user, but will refuse "
"to do anything that could be considered harmful to the user."
"\n- StableLM is more than just an information source, StableLM is also "
"able to write poetry, short stories, and make jokes."
"\n- StableLM will refuse to participate in anything that "
"could harm a human."
)
def user(message, history):
@@ -25,7 +30,11 @@ sharded_model = 0
vicuna_model = 0
start_message_vicuna = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
start_message_vicuna = (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
)
past_key_values = None
@@ -35,23 +44,27 @@ def chat(curr_system_message, history, model, device, precision):
global past_key_values
global vicuna_model
if "vicuna" in model:
from apps.language_models.src.pipelines.vicuna_pipeline import (
Vicuna,
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)
curr_system_message = start_message_vicuna
if vicuna_model == 0:
first_vic_vmfb_path = Path("first_vicuna.vmfb")
second_vic_vmfb_path = Path("second_vicuna.vmfb")
if "cuda" in device:
device = "cuda"
vicuna_model = Vicuna(
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
else:
print("unrecognized device")
vicuna_model = UnshardedVicuna(
"vicuna",
hf_model_path=model,
device=device,
precision=precision,
first_vicuna_vmfb_path=first_vic_vmfb_path,
second_vicuna_vmfb_path=second_vic_vmfb_path,
)
messages = curr_system_message + "".join(
[
@@ -61,16 +74,11 @@ def chat(curr_system_message, history, model, device, precision):
)
prompt = messages.strip()
print("prompt = ", prompt)
sentence = vicuna_model.generate(prompt)
partial_text = ""
for new_text in sentence.split(" "):
# print(new_text)
partial_text += new_text + " "
for partial_text in vicuna_model.generate(prompt):
history[-1][1] = partial_text
# Yield an empty string to cleanup the message textbox and the updated conversation history
yield history
history[-1][1] = sentence
return history
# else Model is StableLM
@@ -85,7 +93,8 @@ def chat(curr_system_message, history, model, device, precision):
"StableLM"
) # pass elements from UI as required
# Construct the input message string for the model by concatenating the current system message and conversation history
# Construct the input message string for the model by concatenating the
# current system message and conversation history
if len(curr_system_message.split()) > 160:
print("clearing context")
curr_system_message = start_message
@@ -105,7 +114,8 @@ def chat(curr_system_message, history, model, device, precision):
# print(new_text)
partial_text += new_text
history[-1][1] = partial_text
# Yield an empty string to cleanup the message textbox and the updated conversation history
# Yield an empty string to clean up the message textbox and the updated
# conversation history
yield history
return words_list
@@ -120,10 +130,12 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
"TheBloke/vicuna-7B-1.1-HF",
],
)
supported_devices = [
device for device in available_devices if "cuda" in device
]
supported_devices = available_devices
enabled = len(supported_devices) > 0
# show cpu-task device first in list for chatbot
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
print(supported_devices)
device = gr.Dropdown(
label="Device",
value=supported_devices[0]
@@ -134,8 +146,10 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
)
precision = gr.Radio(
label="Precision",
value="fp32",
value="fp16",
choices=[
"int4",
"int8",
"fp16",
"fp32",
],

View File

@@ -34,6 +34,7 @@ from apps.stable_diffusion.src.utils import (
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platform = args.iree_metal_target_platform
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
@@ -86,7 +87,8 @@ def txt2img_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -137,6 +139,7 @@ def txt2img_inf(
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platform = init_iree_metal_target_platform
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
@@ -193,6 +196,7 @@ def txt2img_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
seeds.append(img_seed)
total_time = time.time() - start_time
@@ -235,7 +239,9 @@ def txt2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
res = txt2img_inf(
InputData["prompt"],
@@ -262,6 +268,10 @@ def txt2img_api(
lora_hf_id="",
ondemand=False,
)
# Convert Generator to Subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
@@ -286,8 +296,16 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
# janky fix for overflowing text
t2i_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
t2i_model_info = (
f"Custom Model Path: {t2i_model_info}"
)
txt2img_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
label=f"Models",
info=t2i_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -298,13 +316,21 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
)
txt2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
placeholder="Select 'None' in the dropdown "
"on the left and enter model ID here.",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL.",
lines=3,
)
# janky fix for overflowing text
t2i_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
t2i_vae_info = f"VAE Path: {t2i_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label=f"VAE Models",
info=t2i_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -325,26 +351,35 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
t2i_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
t2i_lora_info = f"LoRA Path: {t2i_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=t2i_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -475,7 +510,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
elem_id="gallery",
).style(columns=[2], object_fit="contain")
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
@@ -517,7 +553,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
ondemand,
],
outputs=[txt2img_gallery, std_output, txt2img_status],
show_progress=args.progress_bar,
show_progress="minimal" if args.progress_bar else "none",
)
status_kwargs = dict(
@@ -550,6 +586,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
height,
txt2img_custom_model,
txt2img_hf_model_id,
lora_weights,
lora_hf_id,
custom_vae,
],
outputs=[
txt2img_png_info_img,
@@ -563,5 +602,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
height,
txt2img_custom_model,
txt2img_hf_model_id,
lora_weights,
lora_hf_id,
custom_vae,
],
)

View File

@@ -88,7 +88,8 @@ def upscaler_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -202,6 +203,7 @@ def upscaler_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
@@ -225,10 +227,22 @@ def upscaler_inf(
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += (
f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, " f"device={device}"
text_output += (
f"\nsteps={steps}, "
f"noise_level={noise_level}, "
f"guidance_scale={guidance_scale}, "
f"seed={seeds}"
)
text_output += (
f"\nsize={height}x{width}, "
f"batch_count={batch_count}, "
f"batch_size={batch_size}, "
f"max_length={args.max_length}"
)
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
@@ -269,7 +283,9 @@ def upscaler_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = upscaler_inf(
@@ -299,6 +315,9 @@ def upscaler_api(
lora_hf_id="",
ondemand=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
@@ -321,8 +340,16 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
upscaler_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
upscaler_model_info = (
f"Custom Model Path: {upscaler_model_info}"
)
upscaler_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
label=f"Models",
info=upscaler_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -335,13 +362,23 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
)
upscaler_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3, "
"https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
upscaler_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
upscaler_vae_info = f"VAE Path: {upscaler_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label=f"Custom VAE Models",
info=upscaler_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -353,13 +390,13 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
@@ -369,15 +406,24 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
upscaler_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
upscaler_lora_info = f"LoRA Path: {upscaler_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=upscaler_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -508,7 +554,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
elem_id="gallery",
).style(columns=[2], object_fit="contain")
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
@@ -550,7 +597,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
ondemand,
],
outputs=[upscaler_gallery, std_output, upscaler_status],
show_progress=args.progress_bar,
show_progress="minimal" if args.progress_bar else "none",
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Upscaler", 0, bc, bs),

View File

@@ -39,8 +39,16 @@ scheduler_list_cpu_only = [
"LMSDiscrete",
"KDPM2Discrete",
"DPMSolverMultistep",
"DPMSolverMultistep++",
"DPMSolverMultistepKarras",
"DPMSolverMultistepKarras++",
"EulerDiscrete",
"EulerAncestralDiscrete",
"DEISMultistep",
"KDPM2AncestralDiscrete",
"DPMSolverSinglestep",
"DDPM",
"HeunDiscrete",
]
scheduler_list = scheduler_list_cpu_only + [
"SharkEulerDiscrete",
@@ -50,6 +58,7 @@ predefined_models = [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"xzuyn/PhotoMerge",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
@@ -58,6 +67,7 @@ predefined_models = [
predefined_paint_models = [
"runwayml/stable-diffusion-inpainting",
"stabilityai/stable-diffusion-2-inpainting",
"xzuyn/PhotoMerge-inpainting",
]
predefined_upscaler_models = [
"stabilityai/stable-diffusion-x4-upscaler",
@@ -79,7 +89,8 @@ def create_custom_models_folders():
else:
if not os.path.isdir(args.ckpt_dir):
sys.exit(
f"Invalid --ckpt_dir argument, {args.ckpt_dir} folder does not exists."
f"Invalid --ckpt_dir argument, "
f"{args.ckpt_dir} folder does not exists."
)
for root in dir:
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)

View File

@@ -1,60 +1,54 @@
import os
import shutil
import tempfile
import gradio
from time import time
gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
gradio_tmp_galleries_folder = os.path.join(gradio_tmp_imgs_folder, "galleries")
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
# Clear all gradio tmp images
def clear_gradio_tmp_imgs_folder():
if not os.path.exists(gradio_tmp_imgs_folder):
return
def config_gradio_tmp_imgs_folder():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
# tell gradio to use a directory under shark_tmp for its temporary
# image files unless somewhere else has been set
if "GRADIO_TEMP_DIR" not in os.environ:
os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
# clear all gradio tmp files created by generation galleries
print(
"Clearing gradio temporary image files from a prior run. This may take some time..."
f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
+ "You may change this by setting the GRADIO_TEMP_DIR environment variable."
)
image_files = [
filename
for filename in os.listdir(gradio_tmp_imgs_folder)
if os.path.isfile(os.path.join(gradio_tmp_imgs_folder, filename))
and filename.startswith("tmp")
and filename.endswith(".png")
]
if len(image_files) > 0:
# Clear all gradio tmp images from the last session
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
cleanup_start = time()
for filename in image_files:
os.remove(gradio_tmp_imgs_folder + filename)
print(
f"Clearing generation temporary image files took {time() - cleanup_start:4f} seconds"
"Clearing gradio UI temporary image files from a prior run. This may take some time..."
)
else:
print("no generation temporary files to clear")
# Clear all gradio tmp files created by output galleries
if os.path.exists(gradio_tmp_galleries_folder):
cleanup_start = time()
shutil.rmtree(gradio_tmp_galleries_folder, ignore_errors=True)
shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
print(
f"Clearing output gallery temporary image files took {time() - cleanup_start:4f} seconds"
f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
)
# older SHARK versions had to workaround gradio bugs and stored things differently
else:
print("no output gallery temporary files to clear")
# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder
def save_pil_to_file(pil_image, dir=None):
if not os.path.exists(gradio_tmp_imgs_folder):
os.mkdir(gradio_tmp_imgs_folder)
file_obj = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=gradio_tmp_imgs_folder
)
pil_image.save(file_obj)
return file_obj
# Register save_pil_to_file override
gradio.processing_utils.save_pil_to_file = save_pil_to_file
image_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.startswith("tmp")
and filename.endswith(".png")
]
if len(image_files) > 0:
print(
"Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
)
cleanup_start = time()
for filename in image_files:
os.remove(shark_tmp + filename)
print(
f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
)
else:
print("No temporary images files to clear.")

View File

@@ -11,21 +11,35 @@ def has_csv(image_filename: str) -> bool:
return os.path.exists(csv_path(image_filename))
def parse_csv(image_filename: str):
# We use a reader instead of a DictReader here for images_details.csv files due to the lack of
# headers, and then match up the return list for each row with our guess at which column format
# the file is using.
def matching_filename(image_filename: str, row):
# we assume the final column of the csv has the original filename with full path and match that
# against the image_filename. We then exclude the filename from the output, hence the -1's.
# against the image_filename if we are given a list. Otherwise we assume a dict and and take
# the value of the OUTPUT key
return os.path.basename(image_filename) in (
row[-1] if isinstance(row, list) else row["OUTPUT"]
)
def parse_csv(image_filename: str):
csv_filename = csv_path(image_filename)
matches = [
humanize(row)
for row in csv.reader(open(csv_filename, "r", newline=""))
if row
and humanizable(row)
and os.path.basename(image_filename) in row[-1]
]
with open(csv_filename, "r", newline="") as csv_file:
# We use a reader or DictReader here for images_details.csv depending on whether we think it
# has headers or not. Having headers means less guessing of the format.
has_header = csv.Sniffer().has_header(csv_file.read(2048))
csv_file.seek(0)
reader = (
csv.DictReader(csv_file) if has_header else csv.reader(csv_file)
)
matches = [
# we rely on humanize and humanizable to work out the parsing of the individual .csv rows
humanize(row)
for row in reader
if row
and (has_header or humanizable(row))
and matching_filename(image_filename, row)
]
return matches[0] if matches else {}

View File

@@ -50,7 +50,22 @@ PARAMS_FORMATS = {
},
}
PARAMS_FORMAT_LONGEST = PARAMS_FORMATS[max(PARAMS_FORMATS.keys())]
PARAMS_FORMAT_CURRENT = {
"VARIANT": "Model",
"VAE": "VAE",
"LORA": "LoRA",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"HEIGHT": "Height",
"WIDTH": "Width",
"MAX_LENGTH": "Max Length",
"OUTPUT": "Filename",
}
def compact(metadata: dict) -> dict:
@@ -97,19 +112,20 @@ def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
)
# For dictionaries we try to use the matching length parameter format if
# available, otherwise we use the longest. Then we swap keys in the
# metadata that match keys in the format for the friendlier name that we
# have set in the format value
# available, otherwise we just use the current format which is assumed to
# have everything currently known about. Then we swap keys in the metadata
# that match keys in the format for the friendlier name that we have set
# in the format value
if isinstance(metadata, dict):
if humanizable(metadata, includes_filename):
format = PARAMS_FORMATS[lookup_key]
else:
format = PARAMS_FORMAT_LONGEST
format = PARAMS_FORMAT_CURRENT
return {
format[key]: value
for (key, value) in metadata.items()
if key in format.keys()
format[key]: metadata[key]
for key in format.keys()
if key in metadata.keys() and metadata[key]
}
raise TypeError("Can only humanize parameter lists or dictionaries")

View File

@@ -62,6 +62,82 @@ def parse_generation_parameters(x: str):
return res
def try_find_model_base_from_png_metadata(
file: str, folder: str = "models"
) -> str:
custom = ""
# Remove extension from file info
if file.endswith(".safetensors") or file.endswith(".ckpt"):
file = Path(file).stem
# Check for the file name match with one of the local ckpt or safetensors files
if Path(get_custom_model_pathfile(file + ".ckpt", folder)).is_file():
custom = file + ".ckpt"
if Path(
get_custom_model_pathfile(file + ".safetensors", folder)
).is_file():
custom = file + ".safetensors"
return custom
def find_model_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
png_hf_id = ""
png_custom = ""
if key in metadata:
model_file = metadata[key]
png_custom = try_find_model_base_from_png_metadata(model_file)
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
if model_file in predefined_models:
png_custom = model_file
# If nothing had matched, check vendor/hf_model_id
if not png_custom and model_file.count("/"):
png_hf_id = model_file
# No matching model was found
if not png_custom and not png_hf_id:
print(
"Import PNG info: Unable to find a matching model for %s"
% model_file
)
return png_custom, png_hf_id
def find_vae_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> str:
vae_custom = ""
if key in metadata:
vae_file = metadata[key]
vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae")
# VAE input is optional, should not print or throw an error if missing
return vae_custom
def find_lora_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
lora_hf_id = ""
lora_custom = ""
if key in metadata:
lora_file = metadata[key]
lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
# If nothing had matched, check vendor/hf_model_id
if not lora_custom and lora_file.count("/"):
lora_hf_id = lora_file
# LoRA input is optional, should not print or throw an error if missing
return lora_custom, lora_hf_id
def import_png_metadata(
pil_data,
prompt,
@@ -74,40 +150,21 @@ def import_png_metadata(
height,
custom_model,
hf_model_id,
custom_lora,
hf_lora_id,
custom_vae,
):
try:
png_info = pil_data.info["parameters"]
metadata = parse_generation_parameters(png_info)
png_hf_model_id = ""
png_custom_model = ""
if "Model" in metadata:
# Remove extension from model info
if metadata["Model"].endswith(".safetensors") or metadata[
"Model"
].endswith(".ckpt"):
metadata["Model"] = Path(metadata["Model"]).stem
# Check for the model name match with one of the local ckpt or safetensors files
if Path(
get_custom_model_pathfile(metadata["Model"] + ".ckpt")
).is_file():
png_custom_model = metadata["Model"] + ".ckpt"
if Path(
get_custom_model_pathfile(metadata["Model"] + ".safetensors")
).is_file():
png_custom_model = metadata["Model"] + ".safetensors"
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
if metadata["Model"] in predefined_models:
png_custom_model = metadata["Model"]
# If nothing had matched, check vendor/hf_model_id
if not png_custom_model and metadata["Model"].count("/"):
png_hf_model_id = metadata["Model"]
# No matching model was found
if not png_custom_model and not png_hf_model_id:
print(
"Import PNG info: Unable to find a matching model for %s"
% metadata["Model"]
)
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
"Model", metadata
)
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
"LoRA", metadata
)
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
negative_prompt = metadata["Negative prompt"]
steps = int(metadata["Steps"])
@@ -115,12 +172,24 @@ def import_png_metadata(
seed = int(metadata["Seed"])
width = float(metadata["Size-1"])
height = float(metadata["Size-2"])
if "Model" in metadata and png_custom_model:
custom_model = png_custom_model
hf_model_id = ""
if "Model" in metadata and png_hf_model_id:
custom_model = "None"
hf_model_id = png_hf_model_id
if "LoRA" in metadata and lora_custom_model:
custom_lora = lora_custom_model
hf_lora_id = ""
if "LoRA" in metadata and lora_hf_model_id:
custom_lora = "None"
hf_lora_id = lora_hf_model_id
if "VAE" in metadata and vae_custom_model:
custom_vae = vae_custom_model
if "Prompt" in metadata:
prompt = metadata["Prompt"]
if "Sampler" in metadata:
@@ -149,4 +218,7 @@ def import_png_metadata(
height,
custom_model,
hf_model_id,
custom_lora,
hf_lora_id,
custom_vae,
)

View File

@@ -14,4 +14,4 @@ build-backend = "setuptools.build_meta"
[tool.black]
line-length = 79
include = '\.pyi?$'
exclude = "apps/language_models/scripts/vicuna.py"

View File

@@ -16,7 +16,7 @@ parameterized
# Add transformers, diffusers and scipy since it most commonly used
transformers
diffusers @ git+https://github.com/huggingface/diffusers@e47459c80f6f6a5a1c19d32c3fd74edf94f47aa2
diffusers
scipy
ftfy
gradio==3.34.0
@@ -29,7 +29,11 @@ pytorch_lightning # for runwayml models
tk
pywebview
sentencepiece
py-cpuinfo
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
pyinstaller
# low precision vicuna
brevitas @ git+https://github.com/Xilinx/brevitas.git@llm

243
rest_api_tests/api_test.py Normal file
View File

@@ -0,0 +1,243 @@
import requests
from PIL import Image
import base64
from io import BytesIO
def upscaler_test():
# Define values here
prompt = ""
negative_prompt = ""
seed = 2121991605
height = 512
width = 512
steps = 50
noise_level = 10
cfg_scale = 7
image_path = r"./rest_api_tests/dog.png"
# Converting Image to base64
img_file = open(image_path, "rb")
init_images = [
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
]
url = "http://127.0.0.1:8080/sdapi/v1/upscaler"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"seed": seed,
"height": height,
"width": width,
"steps": steps,
"noise_level": noise_level,
"cfg_scale": cfg_scale,
"init_images": init_images,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"response from server was : {res.status_code}")
def img2img_test():
# Define values here
prompt = "Paint a rabbit riding on the dog"
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
seed = 2121991605
height = 512
width = 512
steps = 50
denoising_strength = 0.75
cfg_scale = 7
image_path = r"./rest_api_tests/dog.png"
# Converting Image to Base64
img_file = open(image_path, "rb")
init_images = [
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
]
url = "http://127.0.0.1:8080/sdapi/v1/img2img"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"init_images": init_images,
"height": height,
"width": width,
"steps": steps,
"denoising_strength": denoising_strength,
"cfg_scale": cfg_scale,
"seed": seed,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"response from server was : {res.status_code}")
# NOTE Uncomment below to save the picture
# print("Extracting response object")
# response_obj = res.json()
# img_b64 = response_obj.get("images", [False])[0] or response_obj.get(
# "image"
# )
# img_b2 = base64.b64decode(img_b64.replace("data:image/png;base64,", ""))
# im_file = BytesIO(img_b2)
# response_img = Image.open(im_file)
# print("Saving Response Image to: response_img")
# response_img.save(r"rest_api_tests/response_img.png")
def inpainting_test():
prompt = "Paint a rabbit riding on the dog"
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
seed = 2121991605
height = 512
width = 512
steps = 50
noise_level = 10
cfg_scale = 7
is_full_res = False
full_res_padding = 32
image_path = r"./rest_api_tests/dog.png"
img_file = open(image_path, "rb")
image = (
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
)
img_file = open(image_path, "rb")
mask = (
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
)
url = "http://127.0.0.1:8080/sdapi/v1/inpaint"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"image": image,
"mask": mask,
"height": height,
"width": width,
"steps": steps,
"noise_level": noise_level,
"cfg_scale": cfg_scale,
"seed": seed,
"is_full_res": is_full_res,
"full_res_padding": full_res_padding,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"[Inpainting] response from server was : {res.status_code}")
def outpainting_test():
prompt = "Paint a rabbit riding on the dog"
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
seed = 2121991605
height = 512
width = 512
steps = 50
cfg_scale = 7
color_variation = 0.2
noise_q = 0.2
directions = ["up", "down", "right", "left"]
pixels = 32
mask_blur = 64
image_path = r"./rest_api_tests/dog.png"
# Converting Image to Base64
img_file = open(image_path, "rb")
init_images = [
"data:image/png;base64," + base64.b64encode(img_file.read()).decode()
]
url = "http://127.0.0.1:8080/sdapi/v1/outpaint"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"seed": seed,
"height": height,
"width": width,
"steps": steps,
"cfg_scale": cfg_scale,
"color_variation": color_variation,
"noise_q": noise_q,
"directions": directions,
"pixels": pixels,
"mask_blur": mask_blur,
"init_images": init_images,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"[Outpaint] response from server was : {res.status_code}")
def txt2img_test():
prompt = "Paint a rabbit in a top hate"
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
seed = 2121991605
height = 512
width = 512
steps = 50
cfg_scale = 7
url = "http://127.0.0.1:8080/sdapi/v1/txt2img"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"seed": seed,
"height": height,
"width": width,
"steps": steps,
"cfg_scale": cfg_scale,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"[txt2img] response from server was : {res.status_code}")
if __name__ == "__main__":
txt2img_test()
img2img_test()
upscaler_test()
inpainting_test()
outpainting_test()

BIN
rest_api_tests/dog.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 KiB

View File

@@ -39,7 +39,7 @@ setup(
install_requires=[
"numpy",
"PyYAML",
"torch-mlir>=20221021.633",
"torch-mlir==20230620.875",
]
+ backend_deps,
)

View File

@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --pre torch-mlir==20230620.875 torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html

View File

@@ -27,6 +27,11 @@ PYTHON_VERSION_X_Y=`${PYTHON} -c 'import sys; version=sys.version_info[:2]; prin
echo "Python: $PYTHON"
echo "Python version: $PYTHON_VERSION_X_Y"
if [ "$PYTHON_VERSION_X_Y" != "3.11" ]; then
echo "Error: Python version 3.11 is required."
exit 1
fi
if [[ "$SKIP_VENV" != "1" ]]; then
if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
@@ -83,7 +88,7 @@ if [ "$torch_mlir_bin" = true ]; then
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
else
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
$PYTHON -m pip install --pre torch-mlir==20230620.875 -f https://llvm.github.io/torch-mlir/package-index/
if [ $? -eq 0 ];then
echo "Successfully Installed torch-mlir"
else
@@ -154,3 +159,5 @@ if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then
echo "${Green}Before running examples activate venv with:"
echo " ${Green}source $VENV_DIR/bin/activate"
fi
$PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@llm

View File

@@ -0,0 +1,28 @@
import importlib
import logging
from torch._dynamo import register_backend
log = logging.getLogger(__name__)
@register_backend
def shark(model, inputs, *, options):
try:
from shark.dynamo_backend.utils import SharkBackend
except ImportError:
log.exception(
"Unable to import SHARK - High Performance Machine Learning Distribution"
"Please install the right version of SHARK that matches the PyTorch version being used. "
"Refer to https://github.com/nod-ai/SHARK/ for details."
)
raise
return SharkBackend(model, inputs, options)
def has_shark():
try:
importlib.import_module("shark")
return True
except ImportError:
return False

View File

@@ -0,0 +1,154 @@
import functools
from typing import List, Optional
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.compile_utils import strip_overloads
from shark.shark_inference import SharkInference
from torch._decomp import get_decompositions
from torch.func import functionalize
import io
import torch_mlir
# TODO: Control decompositions.
def default_decompositions():
return get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
torch.ops.aten.native_layer_norm,
torch.ops.aten.masked_fill.Tensor,
torch.ops.aten.masked_fill.Scalar,
]
)
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
return len(node_arg) == 0
return False
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
class SharkBackend:
def __init__(
self, fx_g: torch.fx.GraphModule, inputs: tuple, options: dict
):
self.fx_g = fx_g
self.inputs = inputs
self.shark_module = None
self.device: str = options.get("device", "cpu")
self.was_unwrapped: bool = False
self.none_indices: list = []
self._modify_fx_g()
self.compile()
def _modify_fx_g(self):
self.none_indices = _remove_nones(self.fx_g)
self.was_unwrapped = _unwrap_single_tuple_return(self.fx_g)
def compile(self):
gm = make_fx(
functionalize(self.fx_g),
decomposition_table=default_decompositions(),
)(*self.inputs)
gm.graph.set_codegen(torch.fx.graph.CodeGen())
gm.recompile()
strip_overloads(gm)
ts_g = torch.jit.script(gm)
mlir_module = torch_mlir.compile(
ts_g, self.inputs, output_type="linalg-on-tensors"
)
bytecode_stream = io.BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
from shark.shark_inference import SharkInference
shark_module = SharkInference(
mlir_module=bytecode,
device=self.device,
mlir_dialect="tm_tensor",
)
shark_module.compile(extra_args=[])
self.shark_module = shark_module
def __call__(self, *inputs):
np_inputs = [x.contiguous().detach().cpu().numpy() for x in inputs]
np_outs = self.shark_module("forward", np_inputs)
if self.was_unwrapped:
np_outs = [
np_outs,
]
if not isinstance(np_outs, list):
res = torch.from_numpy(np_outs)
return res
result = [torch.from_numpy(x) for x in np_outs]
for r_in in self.none_indices:
result.insert(r_in, None)
result = tuple(result)
return result

View File

@@ -1,70 +1,25 @@
import torch
import torch_mlir
import torch._dynamo as torchdynamo
from shark.sharkdynamo.utils import make_shark_compiler
import shark
import warnings, logging
warnings.simplefilter("ignore")
torchdynamo.config.log_level = logging.ERROR
def foo(x, a):
if x.shape[0] > 3:
return x + a
else:
return x + 3
torchdynamo.reset()
shark_options = {"device": "cpu"}
compiled = torch.compile(foo, backend="shark", options=shark_options)
input = torch.ones(4)
@torchdynamo.optimize(
make_shark_compiler(use_tracing=False, device="cuda", verbose=False)
)
def foo(t):
return 2 * t
x = compiled(input, input)
example_input = torch.rand((2, 3))
x = foo(example_input)
print(x)
input = torch.ones(3)
torchdynamo.reset()
x = compiled(input, input)
@torchdynamo.optimize(
make_shark_compiler(use_tracing=False, device="cuda", verbose=False)
)
def foo(a, b):
x = a / (a + 1)
if b.sum() < 0:
b = b * -1
return x * b
print(foo(torch.rand((2, 3)), -torch.rand((2, 3))))
torchdynamo.reset()
@torchdynamo.optimize(
make_shark_compiler(use_tracing=False, device="cuda", verbose=True)
)
def foo(a):
for i in range(10):
a += 1.0
return a
print(foo(torch.rand((1, 2))))
torchdynamo.reset()
@torchdynamo.optimize(
make_shark_compiler(use_tracing=False, device="cuda", verbose=True)
)
def test_unsupported_types(t, y):
return t, 2 * y
str_input = "hello"
tensor_input = torch.randn(2)
print(test_unsupported_types(str_input, tensor_input))
print(x)

View File

@@ -0,0 +1,72 @@
import torch
import torch_mlir
from shark.shark_inference import SharkInference
from shark.shark_compile import shark_compile_through_fx
from MEGABYTE_pytorch import MEGABYTE
import os
class MegaModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = MEGABYTE(
num_tokens=16000, # number of tokens
dim=(
512,
256,
), # transformer model dimension (512 for coarsest, 256 for fine in this example)
max_seq_len=(
1024,
4,
), # sequence length for global and then local. this can be more than 2
depth=(
6,
4,
), # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's
dim_head=64, # dimension per head
heads=8, # number of attention heads
flash_attn=True, # use flash attention
)
def forward(self, input):
return self.model(input)
megaModel = MegaModel()
inputs = [torch.randint(0, 16000, (1, 1024, 4))]
# CURRENTLY IT BAILS OUT HERE BECAUSE OF MISSING OP LOWERINGS :-
# 1. aten.alias
shark_module, _ = shark_compile_through_fx(
model=megaModel,
inputs=inputs,
extended_model_name="mega_shark",
is_f16=False,
f16_input_mask=None,
save_dir=os.getcwd(),
debug=False,
generate_or_load_vmfb=True,
extra_args=[],
device="cuda",
mlir_dialect="tm_tensor",
)
# logits = model(x)
def print_output_info(output, msg):
print("\n", msg)
print("\n\t", output.shape)
ans = shark_module("forward", inputs)
print_output_info(torch.from_numpy(ans), "SHARK's output")
ans = megaModel.forward(*inputs)
print_output_info(ans, "ORIGINAL Model's output")
# and sample from the logits accordingly
# or you can use the generate function
# NEED TO LOOK AT THIS LATER IF REQUIRED IN SHARK.
# sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)

View File

@@ -1,4 +1,4 @@
# Copyright 2020 The Nod Team. All rights reserved.
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -63,10 +63,11 @@ def get_supported_device_list():
_IREE_DEVICE_MAP = {
"cpu": "local-task",
"cpu-task": "local-task",
"AMD-AIE": "local-task",
"cpu-sync": "local-sync",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "vulkan",
"metal": "metal",
"rocm": "rocm",
"intel-gpu": "level_zero",
}
@@ -81,10 +82,11 @@ def iree_target_map(device):
_IREE_TARGET_MAP = {
"cpu": "llvm-cpu",
"cpu-task": "llvm-cpu",
"AMD-AIE": "llvm-cpu",
"cpu-sync": "llvm-cpu",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "vulkan",
"metal": "metal",
"rocm": "rocm",
"intel-gpu": "opencl-spirv",
}
@@ -101,11 +103,13 @@ def check_device_drivers(device):
subprocess.check_output("nvidia-smi")
except Exception:
return True
elif device in ["metal", "vulkan"]:
elif device in ["vulkan"]:
try:
subprocess.check_output("vulkaninfo")
except Exception:
return True
elif device == "metal":
return False
elif device in ["intel-gpu"]:
try:
subprocess.check_output(["dpkg", "-L", "intel-level-zero-gpu"])

View File

@@ -1,4 +1,4 @@
# Copyright 2020 The Nod Team. All rights reserved.
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,11 +14,14 @@
import iree.runtime as ireert
import iree.compiler as ireec
from shark.iree_utils._common import iree_device_map, iree_target_map
from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args
from shark.iree_utils.benchmark_utils import *
from shark.parser import shark_args
import numpy as np
import os
import re
import tempfile
from pathlib import Path
# Get the iree-compile arguments given device.
@@ -38,17 +41,32 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] == "cpu":
from shark.iree_utils.cpu_utils import get_iree_cpu_args
return get_iree_cpu_args()
data_tiling_flag = ["--iree-flow-enable-data-tiling"]
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]
return (
get_iree_cpu_args()
+ data_tiling_flag
+ u_kernel_flag
+ stack_size_flag
)
if device_uri[0] == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
return get_iree_gpu_args()
if device_uri[0] in ["metal", "vulkan"]:
if device_uri[0] == "vulkan":
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "metal":
from shark.iree_utils.metal_utils import get_iree_metal_args
return get_iree_metal_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
@@ -175,8 +193,10 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
vmfb_file.close()
config = get_iree_runtime_config(device)
vm_module = ireert.VmModule.from_flatbuffer(
config.vm_instance, flatbuffer_blob
vm_module = ireert.VmModule.from_buffer(
config.vm_instance,
flatbuffer_blob,
warn_if_copy=False,
)
benchmark_cl = build_benchmark_args_non_tensor_input(
@@ -307,8 +327,8 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
vm_module = ireert.VmModule.from_flatbuffer(
config.vm_instance, flatbuffer_blob
vm_module = ireert.VmModule.from_buffer(
config.vm_instance, flatbuffer_blob, warn_if_copy=False
)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(vm_module)
@@ -316,6 +336,63 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
return ModuleCompiled, config
def load_vmfb_using_mmap(
flatbuffer_blob_or_path, device: str, device_idx: int = None
):
instance = ireert.VmInstance()
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device_by_uri(
device,
allocators=[],
)
# First get configs.
if device_idx is not None:
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
if "task" in device:
print(
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
)
for flag in get_iree_cpu_rt_args():
ireert.flags.parse_flags(flag)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
# (This would arise if we're invoking `load_module` from a SharkInference obj)
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
# (This would arise if we're invoking `compile` from a SharkInference obj)
temp_file_to_unlink = None
if isinstance(flatbuffer_blob_or_path, Path):
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
if (
isinstance(flatbuffer_blob_or_path, str)
and ".vmfb" in flatbuffer_blob_or_path
):
vmfb_file_path = flatbuffer_blob_or_path
mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(mmaped_vmfb)
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
else:
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf.write(flatbuffer_blob_or_path)
tf.flush()
vmfb_file_path = tf.name
temp_file_to_unlink = vmfb_file_path
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
return mmaped_vmfb, config, temp_file_to_unlink
def get_iree_compiled_module(
module,
device: str,
@@ -323,19 +400,58 @@ def get_iree_compiled_module(
model_config_path: str = None,
extra_args: list = [],
device_idx: int = None,
mmap: bool = False,
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, frontend, model_config_path, extra_args
)
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
temp_file_to_unlink = None
# TODO: Currently mmap=True control flow path has been switched off for mmap.
# Got to find a cleaner way to unlink/delete the temporary file since
# we're setting delete=False when creating NamedTemporaryFile. That's why
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
if mmap:
print(f"Will load the compiled module as a mmapped temporary file")
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_blob, device, device_idx
)
else:
vmfb, config = get_iree_module(
flatbuffer_blob, device, device_idx=device_idx
)
ret_params = {
"vmfb": vmfb,
"config": config,
"temp_file_to_unlink": temp_file_to_unlink,
}
return ret_params
def load_flatbuffer(flatbuffer_path: str, device: str, device_idx: int = None):
with open(os.path.join(flatbuffer_path), "rb") as f:
flatbuffer_blob = f.read()
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
def load_flatbuffer(
flatbuffer_path: str,
device: str,
device_idx: int = None,
mmap: bool = False,
):
temp_file_to_unlink = None
if mmap:
print(f"Loading flatbuffer at {flatbuffer_path} as a mmapped file")
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_path, device, device_idx
)
else:
with open(os.path.join(flatbuffer_path), "rb") as f:
flatbuffer_blob = f.read()
vmfb, config = get_iree_module(
flatbuffer_blob, device, device_idx=device_idx
)
ret_params = {
"vmfb": vmfb,
"config": config,
"temp_file_to_unlink": temp_file_to_unlink,
}
return ret_params
def export_iree_module_to_vmfb(

View File

@@ -16,6 +16,7 @@
import subprocess
import platform
from shark.parser import shark_args
def get_cpu_count():
@@ -44,4 +45,18 @@ def get_iree_cpu_args():
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
raise Exception(error_message)
print(f"Target triple found:{target_triple}")
return [f"--iree-llvmcpu-target-triple={target_triple}"]
return [
f"--iree-llvmcpu-target-triple={target_triple}",
]
# Get iree runtime flags for cpu
def get_iree_cpu_rt_args():
default = get_cpu_count()
default = default if default <= 8 else default - 2
cpu_count = (
default
if shark_args.task_topology_max_group_count is None
else shark_args.task_topology_max_group_count
)
return [f"--task_topology_max_group_count={cpu_count}"]

View File

@@ -0,0 +1,121 @@
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# All the iree_vulkan related functionalities go here.
from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
def get_metal_device_name(device_num=0):
iree_device_dump = run_cmd("iree-run-module --dump_devices")
iree_device_dump = iree_device_dump[0].split("\n\n")
metal_device_list = [
s.split("\n#")[2] for s in iree_device_dump if "--device=metal" in s
]
if len(metal_device_list) == 0:
raise ValueError("No device name found in device dump!")
if len(metal_device_list) > 1:
print("Following devices found:")
for i, dname in enumerate(metal_device_list):
print(f"{i}. {dname}")
print(f"Choosing device: {metal_device_list[device_num]}")
return metal_device_list[device_num]
def get_os_name():
if platform.startswith("linux"):
return "linux"
elif platform == "darwin":
return "macos"
elif platform == "win32":
return "windows"
else:
print("Cannot detect OS type, defaulting to linux.")
return "linux"
def get_metal_target_triple(device_name):
"""This method provides a target triple str for specified vulkan device.
Args:
device_name (str): name of the hardware device to be used with vulkan
Returns:
str or None: target triple or None if no match found for given name
"""
# Apple Targets
if all(x in device_name for x in ("Apple", "M1")):
triple = "m1-moltenvk-macos"
elif all(x in device_name for x in ("Apple", "M2")):
triple = "m1-moltenvk-macos"
else:
triple = None
return triple
def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
for flag in extra_args:
if "-iree-metal-target-platform=" in flag:
print(f"Using target triple {flag.split('=')[1]}")
return None
if device_name == "" or device_name == [] or device_name is None:
metal_device = get_metal_device_name(device_num=device_num)
else:
metal_device = device_name
triple = get_metal_target_triple(metal_device)
if triple is not None:
print(
f"Found metal device {metal_device}. Using metal target triple {triple}"
)
return f"-iree-metal-target-platform={triple}"
print(
"""Optimized kernel for your target device is not added yet.
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
or pull up an issue."""
)
print(f"Target : {metal_device}")
return None
def get_iree_metal_args(device_num=0, extra_args=[]):
# res_metal_flag = ["--iree-flow-demote-i64-to-i32"]
res_metal_flag = []
metal_triple_flag = None
for arg in extra_args:
if "-iree-metal-target-platform=" in arg:
print(f"Using target triple {arg} from command line args")
metal_triple_flag = arg
break
if metal_triple_flag is None:
metal_triple_flag = get_metal_triple_flag(
device_num=device_num, extra_args=extra_args
)
if metal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(metal_triple_flag)
res_metal_flag.append(vulkan_target_env)
return res_metal_flag
def set_iree_metal_runtime_flags(flags):
for flag in flags:
ireert.flags.parse_flags(flag)
return

View File

@@ -136,7 +136,7 @@ def get_vendor(triple):
return "Intel"
if arch in ["turing", "ampere", "pascal"]:
return "NVIDIA"
if arch == "ardeno":
if arch == "adreno":
return "Qualcomm"
if arch == "cpu":
if product == "swiftshader":

View File

@@ -114,6 +114,11 @@ def get_vulkan_target_triple(device_name):
# Intel Targets
elif any(x in device_name for x in ("A770", "A750")):
triple = f"arc-770-{system_os}"
# Adreno Targets
elif all(x in device_name for x in ("Adreno", "740")):
triple = f"adreno-a740-{system_os}"
else:
triple = None
return triple

View File

@@ -119,5 +119,11 @@ parser.add_argument(
"to augment the base device allocator",
choices=["debug", "caching"],
)
parser.add_argument(
"--task_topology_max_group_count",
type=str,
default=None,
help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count",
)
shark_args, unknown = parser.parse_known_args()

99
shark/shark_compile.py Normal file
View File

@@ -0,0 +1,99 @@
import os
import tempfile
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
shark_module = None
if os.path.isfile(vmfb_path):
shark_module = SharkInference(
None,
device=device,
mlir_dialect=mlir_dialect,
)
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
return shark_module
def compile_module(
shark_module, extended_model_name, generate_vmfb, extra_args=[]
):
if generate_vmfb:
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
if os.path.isfile(vmfb_path):
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
else:
print(
"No vmfb found. Compiling and saving to {}".format(vmfb_path)
)
path = shark_module.save_module(
os.getcwd(), extended_model_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
shark_module.compile(extra_args)
return shark_module
def shark_compile_through_fx(
model,
inputs,
extended_model_name,
is_f16=False,
f16_input_mask=None,
save_dir=tempfile.gettempdir(),
debug=False,
generate_or_load_vmfb=True,
extra_args=[],
device=None,
mlir_dialect="tm_tensor",
):
if generate_or_load_vmfb:
shark_module = load_vmfb(
extended_model_name=extended_model_name,
device=device,
mlir_dialect=mlir_dialect,
extra_args=extra_args,
)
if shark_module:
return (
shark_module,
None,
)
from shark.parser import shark_args
if "cuda" in device:
shark_args.enable_tf32 = True
(
mlir_module,
_,
) = import_with_fx(
model=model,
inputs=inputs,
is_f16=is_f16,
f16_input_mask=f16_input_mask,
debug=debug,
model_name=extended_model_name,
save_dir=save_dir,
)
shark_module = SharkInference(
mlir_module,
device=device,
mlir_dialect=mlir_dialect,
)
return (
compile_module(
shark_module,
extended_model_name,
generate_vmfb=generate_or_load_vmfb,
extra_args=extra_args,
),
mlir_module,
)

View File

@@ -60,12 +60,15 @@ def download_public_file(
else:
continue
destination_filename = os.path.join(destination_folder_name, blob_name)
if os.path.isdir(destination_filename):
continue
with open(destination_filename, "wb") as f:
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
storage_client.download_blob_to_file(blob, file_obj)
else:
destination_filename = os.path.join(
destination_folder_name, blob_name
)
if os.path.isdir(destination_filename):
continue
with open(destination_filename, "wb") as f:
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
storage_client.download_blob_to_file(blob, file_obj)
input_type_to_np_dtype = {

View File

@@ -0,0 +1,206 @@
from typing import Any, Dict, List, Tuple
from collections import defaultdict
from shark.shark_importer import import_with_fx
import torchvision.models as models
import copy
import io
import numpy as np
import sys
import torch
import torch.fx
from torch.fx.node import Node
from typing import Dict
import torch_mlir
def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
mlir_module = torch_mlir.compile(
fx_g, inputs, output_type="linalg-on-tensors"
)
bytecode_stream = io.BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
from shark.shark_inference import SharkInference
shark_module = SharkInference(
mlir_module=bytecode,
device=device,
mlir_dialect="tm_tensor",
)
shark_module.compile(extra_args=[])
return shark_module
def _make_single_op_gm(node, captured_val, compiled_graph):
"""Make a GraphModule that just executes the given node."""
g = torch.fx.Graph()
env = {}
inputs = []
for arg in node.args:
if arg and hasattr(arg, "name"):
env[arg.name] = g.placeholder(arg.name)
if isinstance(captured_val[arg.name], (list, tuple)):
for val in captured_val[arg.name]:
inputs.append(val)
else:
inputs.append(captured_val[arg.name])
call = g.node_copy(node, lambda n: env[n.name])
g.output(call)
g.lint()
single_node = torch.fx.GraphModule(torch.nn.Module(), g)
compiled_module = shark_backend(single_node, inputs)
compiled_graph[node.name] = {
"module": compiled_module,
"inputs": [i for i in env],
"result": None,
}
return
def compiled_graph(gm: torch.fx.GraphModule, attr_info):
compiled_graph = {}
g = gm.graph
for node in g.nodes:
if node.op == "call_function":
if not (
node.target in [torch.ops.aten.empty]
or node.name.startswith("getitem")
):
_make_single_op_gm(node, attr_info, compiled_graph)
# Currently torch.aten.empty has an compilation issue, so running natively.
elif node.target in [torch.ops.aten.empty]:
compiled_graph[node.name] = {
"target": node.target,
"args": node.args,
"kwargs": node.kwargs,
"result": None,
}
# Get item is a simple case takes a tuple and return the tensor at a particular index.
elif node.name.startswith("getitem"):
compiled_graph[node.name] = {
"input": node.args[0].name,
"pos": node.args[1],
"result": None,
}
return compiled_graph
class ShapeProp:
"""
Shape propagation. This class takes a `GraphModule`.
Then, its `propagate` method executes the `GraphModule`
node-by-node with the given arguments. As each operation
executes, the ShapeProp class stores away the shape and
element type for the output values of each operation on
the `shape` and `dtype` attributes of the operation's
`Node`.
"""
def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())
def propagate(self, *args):
args_iter = iter(args)
env: Dict[str, Node] = {}
def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
def fetch_attr(target: str):
target_atoms = target.split(".")
attr_itr = self.mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(
f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}"
)
attr_itr = getattr(attr_itr, atom)
return attr_itr
for node in self.graph.nodes:
if node.op == "placeholder":
result = next(args_iter)
elif node.op == "get_attr":
result = fetch_attr(node.target)
elif node.op == "call_function":
result = node.target(
*load_arg(node.args), **load_arg(node.kwargs)
)
elif node.op == "call_method":
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)
elif node.op == "call_module":
result = self.modules[node.target](
*load_arg(node.args), **load_arg(node.kwargs)
)
# This is the only code specific to shape propagation.
# you can delete this `if` branch and this becomes
# a generic GraphModule interpreter.
if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype
env[node.name] = result
return env
# return load_arg(self.graph.result)
resnet18 = models.resnet18(pretrained=True)
resnet18.train(False)
input = (torch.randn(1, 3, 224, 224),)
print(resnet18(input[0]))
fx_graph = import_with_fx(resnet18, input, mlir_type="fx")
shape_prop = ShapeProp(fx_graph)
x = shape_prop.propagate(input[0])
shark_graph = compiled_graph(fx_graph, x)
for key in shark_graph:
if key.startswith("getitem"):
input_val = shark_graph[key]["input"]
pos = shark_graph[key]["pos"]
if input_val not in shark_graph:
shark_graph[key]["result"] = x[input_val][pos].detach()
else:
shark_graph[key]["result"] = shark_graph[input_val]["result"][
pos
].detach()
elif key.startswith("empty"):
operator = shark_graph[key]["target"]
args = shark_graph[key]["args"]
kwargs = shark_graph[key]["kwargs"]
shark_graph[key]["result"] = operator(*args, **kwargs).detach()
else:
input_val = shark_graph[key]["inputs"]
input_tensors = []
for input in input_val:
if input not in shark_graph:
input_tensors.append(x[input].detach())
else:
input_tensors.append(shark_graph[input]["result"])
val = shark_graph[key]["module"]("forward", input_tensors)
if isinstance(val, (tuple, list)):
list_val = []
for v in val:
list_val.append(torch.from_numpy(v))
shark_graph[key]["result"] = list_val
else:
shark_graph[key]["result"] = torch.from_numpy(val)
print(shark_graph)

View File

@@ -1,5 +1,8 @@
import re
import json
from collections import OrderedDict
import torch_mlir
from iree.compiler import compile_str
from shark.shark_importer import import_with_fx, get_f16_inputs
class GenerateConfigFile:
@@ -7,7 +10,9 @@ class GenerateConfigFile:
self,
model,
num_sharding_stages: int,
sharding_stages_id: list[str] = None,
sharding_stages_id: list[str],
model_input=None,
config_file_path="model_config.json",
):
self.model = model
self.num_sharding_stages = num_sharding_stages
@@ -15,8 +20,67 @@ class GenerateConfigFile:
assert self.num_sharding_stages == len(
self.sharding_stages_id
), "Number of sharding stages should be equal to the list of their ID"
self.model_input = model_input
self.config_file_path = config_file_path
def generate_json(self):
def split_into_dispatches(
self,
backend,
fx_tracing_required=True,
f16_model=False,
torch_mlir_tracing=False,
):
graph_for_compilation = self.model
if fx_tracing_required:
graph_for_compilation = import_with_fx(
self.model,
self.model_input,
is_f16=f16_model,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
module = torch_mlir.compile(
graph_for_compilation,
(self.model_input),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=torch_mlir_tracing,
verbose=False,
)
module = module.operation.get_asm(large_elements_limit=4)
compiled_module_str = str(
compile_str(
str(module),
target_backends=[backend],
extra_args=[
"--compile-to=flow",
"--mlir-elide-elementsattrs-if-larger=4",
],
)
)
substring_start_idx = [
m.start()
for m in re.finditer("flow.dispatch @", compiled_module_str)
]
dispatch_list = dict()
# dispatch_no is the 'i'th index of a dispatch out of n total dispatches of a model
# dispatch_id is the unique id of a dispatch, multiple instances of the same dispatch
# can occur in a model
for dispatch_no, substring_idx in enumerate(substring_start_idx):
dispatch_idx = (
compiled_module_str[substring_idx:]
.split(":")[0]
.split("@")[-1]
)
key = "dispatch_no_" + str(dispatch_no)
dispatch_list[key] = {n: "None" for n in self.sharding_stages_id}
dispatch_list[key]["dispatch_id"] = dispatch_idx
self.generate_json(dispatch_list)
def split_into_layers(self):
model_dictionary = dict()
for name, m in self.model.named_modules():
@@ -34,5 +98,8 @@ class GenerateConfigFile:
layer_dict = {n: "None" for n in self.sharding_stages_id}
model_dictionary[name] = layer_dict
with open("model_config.json", "w") as outfile:
json.dump(model_dictionary, outfile)
self.generate_json(model_dictionary)
def generate_json(self, artifacts):
with open(self.config_file_path, "w") as outfile:
json.dump(artifacts, outfile)

View File

@@ -353,7 +353,7 @@ def add_upcast(fx_g):
fx_g.graph.lint()
def transform_fx(fx_g):
def transform_fx(fx_g, quantized=False):
import torch
kwargs_dict = {
@@ -366,10 +366,24 @@ def transform_fx(fx_g):
}
for node in fx_g.graph.nodes:
if node.op == "call_function":
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
if quantized:
continue
if node.target in [
torch.ops.aten.arange,
torch.ops.aten.empty,
torch.ops.aten.zeros,
torch.ops.aten.zeros_like,
]:
if node.kwargs.get("dtype") == torch.float32:
node.kwargs = kwargs_dict
@@ -426,17 +440,6 @@ def transform_fx(fx_g):
new_node.args = (node,)
new_node.kwargs = {"dtype": torch.float16}
# aten.empty should be filled with zeros.
if node.target in [torch.ops.aten.empty]:
with fx_g.graph.inserting_after(node):
new_node = fx_g.graph.call_function(
torch.ops.aten.zero_,
args=(node,),
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
# Required for cuda debugging.
# for node in fx_g.graph.nodes:
# if node.op == "call_function":
@@ -485,6 +488,7 @@ def flatten_training_input(inputs):
return tuple(flattened_input)
# TODO: get rid of is_f16 by using precision
# Applies fx conversion to the model and imports the mlir.
def import_with_fx(
model,
@@ -499,10 +503,28 @@ def import_with_fx(
mlir_type="linalg",
is_dynamic=False,
tracing_required=False,
precision="fp32",
):
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from brevitas_examples.llm.llm_quant.export import (
block_quant_layer_level_manager,
)
from brevitas_examples.llm.llm_quant.export import (
brevitas_layer_export_mode,
)
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
LinearWeightBlockQuantHandlerFwd,
)
from brevitas_examples.llm.llm_quant.export import replace_call_fn_target
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
matmul_rhs_group_quant_placeholder,
)
from brevitas.backport.fx.experimental.proxy_tensor import (
make_fx as brevitas_make_fx,
)
golden_values = None
if debug:
@@ -510,24 +532,97 @@ def import_with_fx(
golden_values = model(*inputs)
except:
golden_values = None
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
# TODO: Control the decompositions.
fx_g = make_fx(
model,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
torch.ops.aten.native_layer_norm,
]
),
)(*inputs)
decomps_list = [
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
torch.ops.aten.native_layer_norm,
torch.ops.aten.masked_fill.Tensor,
torch.ops.aten.masked_fill.Scalar,
]
if precision in ["int4", "int8"]:
export_context_manager = brevitas_layer_export_mode
export_class = block_quant_layer_level_manager(
export_handlers=[LinearWeightBlockQuantHandlerFwd]
)
with export_context_manager(model, export_class):
fx_g = brevitas_make_fx(
model,
decomposition_table=get_decompositions(decomps_list),
)(*inputs)
transform_fx(fx_g, quantized=True)
replace_call_fn_target(
fx_g,
src=matmul_rhs_group_quant_placeholder,
target=torch.ops.brevitas.matmul_rhs_group_quant,
)
fx_g.recompile()
removed_none_indexes = _remove_nones(fx_g)
was_unwrapped = _unwrap_single_tuple_return(fx_g)
else:
fx_g = make_fx(
model,
decomposition_table=get_decompositions(decomps_list),
)(*inputs)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
@@ -552,6 +647,9 @@ def import_with_fx(
add_upcast(fx_g)
fx_g.recompile()
if mlir_type == "fx":
return fx_g
if training:
change_fx_graph_return_to_tuple(fx_g)
inputs = flatten_training_input(inputs)

View File

@@ -48,6 +48,8 @@ class SharkInference:
Refer to {https://mlir.llvm.org/docs/Dialects/}
is_benchmark: bool
Whether this SharkInference module should be benchmark-enabled.
mmap: bool
Whether to load/run vmfb using mmap. It's `True` by default.
Methods
-------
@@ -70,6 +72,7 @@ class SharkInference:
dispatch_benchmark: str = None,
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
device_idx: int = None,
mmap: bool = True,
):
self.mlir_module = mlir_module
self.device = shark_args.device if device == "none" else device
@@ -88,6 +91,7 @@ class SharkInference:
)
self.shark_runner = None
self.mmap = mmap
def compile(self, extra_args=[]):
if self.dispatch_benchmarks is not None:
@@ -201,12 +205,14 @@ class SharkInference:
compile_vmfb=False,
extra_args=extra_args,
)
(
self.shark_runner.iree_compilation_module,
self.shark_runner.iree_config,
) = load_flatbuffer(
params = load_flatbuffer(
path,
self.device,
self.device_idx,
mmap=self.mmap,
)
self.shark_runner.iree_compilation_module = params["vmfb"]
self.shark_runner.iree_config = params["config"]
self.shark_runner.temp_file_to_unlink = params["temp_file_to_unlink"]
del params
return

View File

@@ -85,16 +85,17 @@ class SharkRunner:
if compile_vmfb == True:
# Compile the module to get the .vmfb.
(
self.iree_compilation_module,
self.iree_config,
) = get_iree_compiled_module(
params = get_iree_compiled_module(
self.mlir_module,
self.device,
self.mlir_dialect,
extra_args=self.extra_args,
device_idx=self.device_idx,
)
self.iree_compilation_module = params["vmfb"]
self.iree_config = params["config"]
self.temp_file_to_unlink = params["temp_file_to_unlink"]
del params
def run(self, function_name, inputs: tuple, send_to_host=False):
return get_results(

View File

@@ -1,11 +0,0 @@
1. Install torchdynamo
- `git clone https://github.com/pytorch/torchdynamo.git`
- `cd torchdynamo`
- `python -m pip install -r requirements.txt`
- `python setup.py develop`
2. Install functorch
- `python -m pip install -v "git+https://github.com/pytorch/pytorch.git@$(python -c "import torch.version; print(torch.version.git_version)")#subdirectory=functorch"`
3. Run examples.
- `python shark/examples/shark_dynamo/basic_examples.py`

View File

@@ -1,163 +0,0 @@
import functools
import time
from typing import List, Optional
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.compile_utils import strip_overloads
from shark.shark_inference import SharkInference
from torch._decomp import get_decompositions
import torch_mlir
# TODO: Control decompositions.
def default_decompositions():
return get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
)
def timeit(*, append_time_to: Optional[List] = None):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
start_time = time.time_ns()
result = func(*args, **kwargs)
end_time = time.time_ns()
if append_time_to is not None:
append_time_to.append(end_time - start_time)
return result
return wrapper
return decorator
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
return len(node_arg) == 0
return False
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def make_shark_compiler(use_tracing: bool, device: str, verbose=False):
def compiler(
fx_graph: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
):
"""Compile GraphModule using torch-mlir + SHARK."""
if verbose:
print("Compiling graph...")
if _returns_nothing(fx_graph):
return fx_graph
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
fx_graph = make_fx(
fx_graph, decomposition_table=default_decompositions()
)(*example_inputs)
strip_overloads(fx_graph)
if verbose:
print("torch.fx graph:")
print(fx_graph.graph)
ts_compiler = torch.jit.trace if use_tracing else torch.jit.script
ts_graph = ts_compiler(fx_graph, example_inputs)
if verbose:
torch_mlir_module = torch_mlir.compile(
ts_graph,
example_inputs,
output_type=torch_mlir.OutputType.TORCH,
)
print("\n\ntorch-mlir backend contract graph:")
print(torch_mlir_module)
linalg_module = torch_mlir.compile(
ts_graph,
example_inputs,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
)
import io
bytecode_stream = io.BytesIO()
linalg_module.operation.write_bytecode(bytecode_stream)
mlir_module = bytecode_stream.getvalue()
shark_module = SharkInference(
mlir_module, mlir_dialect="linalg", device=device
)
shark_module.compile()
def forward(*inputs):
result = shark_module("forward", inputs)
result = tuple() if result is None else result
return (result,) if was_unwrapped else result
return forward
return compiler
def check_results(compiled_results, eager_results):
for compiled_result, eager_result in zip(compiled_results, eager_results):
if not torch.allclose(
compiled_result.to("cpu"), eager_result.to("cpu"), atol=1e-5
):
print("Compiled result does not match eager result")
return
print("Compiled result matches eager result!")
def print_time_stats(times):
times_tensor = torch.tensor(times)
def quantile_ms(q):
return torch.quantile(times_tensor.to(float), q).item() / 1e6
print(f"Median: {quantile_ms(0.5)} ms")
print(f"10%ile: {quantile_ms(0.1)} ms")
print(f"90%ile: {quantile_ms(0.9)} ms")
print(f"Total: {torch.sum(times_tensor) / 1e6} ms")
print()

View File

@@ -25,18 +25,18 @@ google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,True,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github.com/nod-ai/SHARK/issues/343,https://github.com/nod-ai/SHARK/issues/1487","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,False,True,"",""
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
efficientnet-v2-s,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b0,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
efficientnet_b7,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
gpt2,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"","macos"
1 resnet50 stablehlo tf 1e-2 1e-3 default nhcw-nhwc False False False macos
25 microsoft/beit-base-patch16-224-pt22k-ft22k linalg torch 1e-2 1e-3 default nhcw-nhwc False True False https://github.com/nod-ai/SHARK/issues/390 macos
26 microsoft/MiniLM-L12-H384-uncased linalg torch 1e-2 1e-3 default None False True False
27 google/mobilebert-uncased linalg torch 1e-2 1e-3 default None False True False https://github.com/nod-ai/SHARK/issues/344 macos
28 mobilenet_v3_small linalg torch 1e-1 1e-2 default nhcw-nhwc False True True True https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487 macos
29 nvidia/mit-b0 linalg torch 1e-2 1e-3 default None True True True https://github.com/nod-ai/SHARK/issues/343,https://github.com/nod-ai/SHARK/issues/1487 macos
30 resnet101 linalg torch 1e-2 1e-3 default nhcw-nhwc/img2col False True False False macos
31 resnet18 linalg torch 1e-2 1e-3 default None True True False macos
32 resnet50 linalg torch 1e-2 1e-3 default nhcw-nhwc False False False macos
33 resnet50_fp16 linalg torch 1e-2 1e-2 default nhcw-nhwc/img2col True False True
34 squeezenet1_0 linalg torch 1e-2 1e-3 default nhcw-nhwc False False False macos
35 wide_resnet50_2 linalg torch 1e-2 1e-3 default nhcw-nhwc/img2col False True False False macos
36 efficientnet-v2-s stablehlo tf 1e-02 1e-3 default nhcw-nhwc False False False macos
37 mnasnet1_0 linalg torch 1e-2 1e-3 default nhcw-nhwc True True True macos
38 efficientnet_b0 linalg torch 1e-2 1e-3 default nhcw-nhwc True True True https://github.com/nod-ai/SHARK/issues/1487 macos
39 efficientnet_b7 linalg torch 1e-2 1e-3 default nhcw-nhwc False True True True https://github.com/nod-ai/SHARK/issues/1487 macos
40 efficientnet_b0 stablehlo tf 1e-2 1e-3 default nhcw-nhwc False False False
41 efficientnet_b7 stablehlo tf 1e-2 1e-3 default nhcw-nhwc False False False Fails on MacOS builder, VK device lost macos
42 gpt2 stablehlo tf 1e-2 1e-3 default None True False False macos

View File

@@ -12,8 +12,8 @@ from transformers import AutoTokenizer, OPTForCausalLM
OPT_MODEL = "opt-1.3b"
OPT_FS_NAME = "opt-1_3b"
MAX_SEQUENCE_LENGTH = 30
MAX_NEW_TOKENS = 20
MAX_SEQUENCE_LENGTH = 128
MAX_NEW_TOKENS = 60
def create_module(model_name, tokenizer, device):
@@ -110,13 +110,13 @@ if __name__ == "__main__":
"facebook/" + OPT_MODEL, use_fast=False
)
vmfb_path = (
f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-sync.vmfb"
f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-task.vmfb"
)
opt_shark_module = SharkInference(mlir_module=None, device="cpu-sync")
opt_shark_module = SharkInference(mlir_module=None, device="cpu-task")
if os.path.isfile(vmfb_path):
opt_shark_module.load_module(vmfb_path)
else:
vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-sync")
vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-task")
opt_shark_module.load_module(vmfb_path)
while True:
try:

View File

@@ -24,4 +24,5 @@ bert-large-uncased,True,hf,True,linalg,False,330M,"nlp;bert-variant;transformer-
bert-base-uncased,True,hf,False,stablehlo,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
gpt2,True,hf_causallm,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
facebook/opt-125m,True,hf,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
1 model_name use_tracing model_type dynamic mlir_type decompose param_count tags notes
24 bert-base-uncased True hf False stablehlo False 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
25 gpt2 True hf_causallm False stablehlo True 125M nlp;transformer-encoder -
26 facebook/opt-125m True hf False stablehlo True 125M nlp;transformer-encoder -
27 distilgpt2 True hf False stablehlo True 88M nlp;transformer-encoder -
28 microsoft/deberta-v3-base True hf False stablehlo True 88M nlp;transformer-encoder -