Compare commits


38 Commits

Author SHA1 Message Date
PhaneeshB
eab2194ca1 fix time calc for sharded 2023-12-06 01:20:47 +05:30
PhaneeshB
93f583f0be fix device_idx for non-layer vmfbs 2023-12-06 01:20:47 +05:30
PhaneeshB
e5ed167f03 mmap shards + disable sharing of device arrays across devices 2023-12-06 01:20:47 +05:30
Elias Joseph
051ba5de63 improved sharded performance and fixed issue with lmhead on rocm 2023-12-06 01:20:47 +05:30
Ean Garvey
6384780d16 Fixes to llama2 cpu compilation and studio UI, schedulers (#2013)
* Fix some issues with defaults

Fixes to llama2 cpu compilation (turns off data tiling for old argmax
mode)

---------

Co-authored-by: Max Dawkins <max.dawkins@gmail.com>
2023-12-05 11:19:19 -05:00
gpetters94
db0c53ae59 Fix zoedepth (#2010) 2023-12-05 04:31:50 -05:00
Ean Garvey
ce9ce3a7c8 (SD) Fix schedulers and multi-controlnet. (#2006)
* (SD) Fixes schedulers if receiving noise preds as numpy arrays

* Fix schedulers and stencil name

* Multicontrolnet fixes
2023-12-05 03:29:18 -06:00
Ean Garvey
d72da3801f (Studio) Update gradio and multicontrolnet UI. (#2001)
* (Studio) Update gradio and multicontrolnet UI.

* Fixes for outputgallery, exe build

* Fix image return types.

* Update Gradio to 4.7.1

* Fix send buttons and hiresfix

* Various bugfixes and SDXL additions.

* More UI fixes and txt2img_sdxl presets.

* Enable SDXL-Turbo and custom models, custom VAE for SDXL

* img2img ui tweaks
2023-12-04 12:37:51 -06:00
Eliasj42
9c50edc664 fixed functionality of sharded vicuna/llama2 (#1982)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-12-04 09:11:52 -08:00
Abhishek Varma
a1b7110550 [SDXL] Add SDXL pipeline to SHARK (#1941)
* [SDXL] Add SDXL pipeline to SHARK

-- This commit adds SDXL pipeline to SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* (SDXL) Fix --ondemand and vae scale factor use, and fix VAE flags.

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-12-02 03:15:15 -06:00
gpetters94
ff15fd74f6 Add multicontrolnet (#1958) 2023-12-01 13:51:20 -06:00
gpetters94
552b2c3ee3 Add controlmode (#1957) 2023-12-01 13:04:47 -06:00
Ean Garvey
795fc33001 Update default compilation flags for data tiling. (#2000)
* Update default CPU compilation flags.

c5a6cdc8dd

52eb7e9b82

tweak CPU iree-compile flags to match upstream changes.

* Add an option for data tiling on SD models.
2023-11-30 17:05:37 -06:00
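The flag changes referenced above come down to the extra arguments handed to iree-compile. Below is a minimal sketch of what a data-tiling toggle could look like, assuming upstream IREE flag spellings (`--iree-opt-data-tiling`, `--iree-llvmcpu-target-cpu-features`); the helper itself is illustrative, not code from this diff.
```
# Hypothetical sketch: building a CPU flag list with a data-tiling toggle.
# The helper name and defaults are assumptions; the flag spellings follow
# upstream IREE conventions and may differ from what this repo actually passes.
def get_iree_cpu_flags(enable_data_tiling: bool = True) -> list[str]:
    flags = ["--iree-llvmcpu-target-cpu-features=host"]
    # Data tiling repacks matmul operands into tile-friendly layouts; the
    # commits above turn it off for the old argmax mode and make it optional
    # for SD models.
    flags.append(
        f"--iree-opt-data-tiling={'true' if enable_data_tiling else 'false'}"
    )
    return flags

print(get_iree_cpu_flags(enable_data_tiling=False))
```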
gpetters94
2910841fe6 Fix an importer issue on Linux (#1986) 2023-11-30 10:50:33 -06:00
Vivek Khandelwal
396a054856 Fix Sharded Falcon-180b 2023-11-30 21:51:57 +05:30
Vivek Khandelwal
5c66948d4f Fix unsharded Falcon pipeline 2023-11-30 21:51:57 +05:30
Ean Garvey
ed3dda94c0 Cleanup xfails in pytest suite. (#1995) 2023-11-29 23:16:15 -06:00
Quinn Dawkins
d31d28b082 [SD] Add flag to collapse reduction dims pre dispatch formation (#1999) 2023-11-30 00:09:17 -05:00
Evan Ruttenberg
78c607e1d3 Fix typo in default_rocm_arch (#1998) 2023-11-29 20:40:56 -05:00
Vivek Khandelwal
666e601dd9 Remove sharding support for non-180B falcon variants 2023-11-27 13:45:13 +05:30
Vivek Khandelwal
ca58908e5b Add Falcon-GPTQ Support for 2-way sharding 2023-11-27 13:45:13 +05:30
Jakub Kuderski
1f5b39f56e [vicuna.py] Add option to enable tracing (#1993)
This makes the program wait for the tracy profiler to connect before exiting
and flushes profiling data after each token.

I don't know how to select the tracy iree-runtime variant
programmatically -- instead, print an error and exit.
2023-11-24 12:25:03 -08:00
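As a rough illustration of the behavior this commit describes (not the actual vicuna.py change): tracing is gated behind a flag, the process is made to wait for Tracy before exiting, and the script errors out rather than trying to pick the tracy runtime variant programmatically. The flag name, the `TRACY_NO_EXIT` approach, and the import check are assumptions.
```
import argparse
import os
import sys

parser = argparse.ArgumentParser()
# Hypothetical flag name; the real option in vicuna.py may differ.
parser.add_argument("--enable_tracing", action="store_true")
args = parser.parse_args()

if args.enable_tracing:
    # Assumption: a Tracy-instrumented runtime honors the standard
    # TRACY_NO_EXIT variable and waits for the profiler before exiting.
    os.environ.setdefault("TRACY_NO_EXIT", "1")
    # We cannot switch to the tracy iree-runtime variant programmatically,
    # so just tell the user and bail out if no runtime is importable at all.
    try:
        import iree.runtime  # noqa: F401
    except ImportError:
        print("Tracing requested but no iree runtime is installed; "
              "install the tracy-enabled runtime variant and retry.")
        sys.exit(1)
```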
Jakub Kuderski
2da31c4109 [vicuna.py] Rework benchmark statistics calculation (#1992)
- Move statistics out of the main loop
- Add 'end-to-end' numbers
- Switch the main display unit from s to ms
- Start measuring time at 0

The new print format looks like this:
```
Number of iterations: 5
Num tokens: 1 (prompt), 512 (generated), 513 (total)
Prefill: avg. 0.01 ms (stdev 0.00), avg. 97.99 tokens/s
Decode: avg. 4840.44 ms (stdev 28.80), avg. 97.99 tokens/s
Decode end-2-end: avg. 85.78 tokens/s (w/o prompt), avg. 95.98 (w/ prompt)
```
2023-11-23 12:04:03 -05:00
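The numbers in the sample output can be derived from per-iteration timings measured from 0. Here is a self-contained sketch of the arithmetic, with illustrative function and argument names (this is not the vicuna.py implementation):
```
import statistics

def summarize(num_generated: int, prefill_ms: list[float], decode_ms: list[float]) -> None:
    """Derive avg/stdev and tokens/s from per-iteration timings (all in ms)."""
    avg_prefill = statistics.mean(prefill_ms)
    avg_decode = statistics.mean(decode_ms)

    decode_tok_s = 1000.0 * num_generated / avg_decode               # w/o prompt
    e2e_tok_s = 1000.0 * num_generated / (avg_prefill + avg_decode)  # w/ prompt

    print(f"Number of iterations: {len(decode_ms)}")
    print(f"Prefill: avg. {avg_prefill:.2f} ms (stdev {statistics.stdev(prefill_ms):.2f})")
    print(f"Decode: avg. {avg_decode:.2f} ms (stdev {statistics.stdev(decode_ms):.2f}), "
          f"avg. {decode_tok_s:.2f} tokens/s")
    print(f"Decode end-2-end: avg. {e2e_tok_s:.2f} tokens/s (w/ prompt)")

summarize(512, prefill_ms=[9.8, 10.4, 10.1], decode_ms=[4810.0, 4845.5, 4862.3])
```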
Ean Garvey
da50a16242 Create specified dir if needed during save_mlir and fix vulkan device fetching without URI/ID (#1989) 2023-11-23 01:01:41 -06:00
Stefan Kapusniak
ce38d49f05 Add .mlir to startup shark_tmp cleanup (#1991)
* Add .mlir to the files that are deleted from `./shark_tmp` when studio
is started.
* refactor/rename existing gradio temp file cleanup on startup to be
consistent with a general `./shark_tmp` cleanup
2023-11-22 14:34:28 -06:00
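A minimal sketch of the startup cleanup described above, assuming everything lives under `./shark_tmp` and that both gradio temp files and `.mlir` dumps should be removed; the helper name and glob patterns are illustrative.
```
from pathlib import Path

def cleanup_shark_tmp(tmp_dir: str = "shark_tmp",
                      patterns: tuple[str, ...] = ("*.mlir", "gradio/*")) -> None:
    """Delete leftover temporary files from a previous Studio run."""
    root = Path(tmp_dir)
    if not root.exists():
        return
    for pattern in patterns:
        for path in root.glob(pattern):
            if path.is_file():
                path.unlink(missing_ok=True)

cleanup_shark_tmp()
```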
PhaneeshB
2f780f0d38 quick fix rocm None device 2023-11-22 21:17:25 +05:30
Ean Garvey
d051c3a4a7 Use clean_device_info() by default and don't write .mlir to /tmp/ (#1984)
* Move clean_device_info to compile_utils

* Update compile_utils.py

* Fix .mlir writes for some user-level permissions

* Fix cases where full URI is given

* Fix conditionals.

* Fix device path handling in vulkan utils.
2023-11-20 13:10:31 -06:00
Ean Garvey
1b11c82c9d Small UI tweaks for chatbot, fix torchvision requirements (#1988)
- add torchvision to setup_venv.ps1 -- we need this for the torchvision::nms that is now a dependency of controlnet features.
- Don't have bad flashy orange updates when using the chatbot
- Don't limit the height of the chatbot -- there are mixed opinions and solutions around this one. I think the default (400) is just way too small and LLMs generate plenty enough to justify matching the output.
2023-11-21 00:09:10 +05:30
gpetters94
80a33d427f Save intermediate values of controlnet (#1981) 2023-11-17 19:05:41 -05:00
Stefan Kapusniak
4125a26294 API/Docs: Fix incorrect cors arguments listing (#1983)
* Replace `api_cors_origin` in the api/koboldcpp doc with the correct
 `api_accept_origin`
2023-11-17 12:29:01 -06:00
Ean Garvey
905d0103ff Revert "Re-enable SD tunings without matmuls. (#1976)" (#1979)
This reverts commit 70817bb50a.
2023-11-17 23:44:33 +05:30
Stefan Kapusniak
192b3b2c61 UI: Output gallery cleanups (#1959)
* Work around a gradio bug that causes the parameters frame to always show
scrollbars.
* Remove the original funky method of setting the number of image
columns in the gallery using _fn= javascript events. The version
of gradio we now have pinned allows setting the property
on the gallery directly, and it also doesn't keep resetting the columns
when other events fire.
2023-11-15 22:20:42 -06:00
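With the pinned gradio version, the column count can be set as a property on the gallery itself rather than through javascript event handlers. A hedged example of the direct approach; the label and layout are illustrative:
```
import gradio as gr

with gr.Blocks() as demo:
    # Set `columns` on the Gallery directly; no javascript event handlers
    # needed, and the value is not reset when unrelated events fire.
    output_gallery = gr.Gallery(label="Output images", columns=4)

demo.launch()
```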
Stefan Kapusniak
8f9adc4a2a UI: Display top tag frequencies for selected LoRA (#1972)
* Adds a function to webui utils to read metadata from
.safetensors LoRA files and do limited parsing of the format written
out by the Kohya SS scripts (https://github.com/kohya-ss/sd-scripts)
to get tag frequency and trained model information.
* Adds a new common_ui_events.py file for gradio event handlers
needed for multiple UI tabs, and adds an event handler bound to
the change event of the LoRA selection boxes that outputs HTML
to display the LoRA tag frequency and model information.
* Adds an HTML gradio control to each of the SD tabs to show the
LoRA model name, and most frequently trained tags.
* Bind the change event of the LoRA selection box on each tab
to our new event handler, with the output set to the relevant HTML
control.
2023-11-15 22:19:54 -06:00
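For reference, `.safetensors` files store a JSON header whose optional `__metadata__` dict is where the Kohya SS scripts put training information such as per-dataset tag frequencies. Below is a small sketch of reading that metadata without loading any tensors; the `ss_tag_frequency` and `ss_sd_model_name` keys follow the Kohya convention, the file path is hypothetical, and error handling is kept minimal.
```
import json
import struct

def read_safetensors_metadata(path: str) -> dict:
    """Return the __metadata__ dict from a .safetensors file (may be empty)."""
    with open(path, "rb") as f:
        # First 8 bytes: little-endian uint64 length of the JSON header.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    return header.get("__metadata__", {})

def top_tags(metadata: dict, limit: int = 10) -> list[tuple[str, int]]:
    """Aggregate Kohya-style ss_tag_frequency data into the most frequent tags."""
    counts: dict[str, int] = {}
    # ss_tag_frequency is stored as a JSON string: {dataset_name: {tag: count}}.
    for tags in json.loads(metadata.get("ss_tag_frequency", "{}")).values():
        for tag, count in tags.items():
            counts[tag] = counts.get(tag, 0) + count
    return sorted(counts.items(), key=lambda kv: kv[1], reverse=True)[:limit]

meta = read_safetensors_metadata("my_lora.safetensors")  # hypothetical path
print(meta.get("ss_sd_model_name", "unknown base model"), top_tags(meta))
```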
Ean Garvey
70817bb50a Re-enable SD tunings without matmuls. (#1976) 2023-11-15 20:42:53 -06:00
jinchen62
dd37c26d36 Update brevitas quant api (#1975) 2023-11-15 10:04:07 -08:00
PhaneeshB
a708879c6c fix iree version mismatch 2023-11-15 01:24:42 +05:30
Ean Garvey
bb1b49eb6f Add --no-index to setup_venv.sh runtime pip install. 2023-11-14 21:44:20 +05:30
Ean Garvey
f6d41affd9 (SHARK Studio) Add Turbine-based llm chatbot. (#1933)
* Dan shark studio (#1970)

* Fix issue in Falcon-GPTQ

* initial webui and llama2

---------

Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>

* Fix formatting.

---------

Co-authored-by: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2023-11-14 09:56:28 -06:00
63 changed files with 4851 additions and 1174 deletions

View File

@@ -112,7 +112,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
pytest --benchmark=native --update_tank -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
python build_tools/vicuna_testing.py
@@ -121,9 +121,9 @@ jobs:
if: matrix.suite == 'cuda'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
pytest --benchmark=native --update_tank -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
@@ -144,10 +144,10 @@ jobs:
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
pytest --update_tank -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan --no-exit_on_fail
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'

View File

@@ -25,7 +25,7 @@ from apps.stable_diffusion.src import args
# Brevitas
from typing import List, Tuple
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
@@ -101,7 +101,7 @@ class H2OGPTModel(torch.nn.Module):
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=128,

File diff suppressed because it is too large.

View File

@@ -69,91 +69,7 @@ class CompiledLMHeadEmbeddingLayer(torch.nn.Module):
return torch.tensor(new_hidden_states)
class DecoderLayer(torch.nn.Module):
def __init__(self, decoder_layer_model, falcon_variant):
super().__init__()
self.model = decoder_layer_model
def forward(self, hidden_states, attention_mask):
output = self.model.forward(
hidden_states=hidden_states,
alibi=None,
attention_mask=attention_mask,
use_cache=True,
)
return (output[0], output[1][0], output[1][1])
class CompiledDecoderLayer(torch.nn.Module):
def __init__(
self, layer_id, device_idx, falcon_variant, device, precision
):
super().__init__()
self.layer_id = layer_id
self.device_index = device_idx
self.falcon_variant = falcon_variant
self.device = device
self.precision = precision
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
alibi: torch.Tensor = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
import gc
torch.cuda.empty_cache()
gc.collect()
from pathlib import Path
from apps.language_models.utils import get_vmfb_from_path
self.falcon_vmfb_path = Path(
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
)
print("vmfb path for layer: ", self.falcon_vmfb_path)
self.model = get_vmfb_from_path(
self.falcon_vmfb_path,
self.device,
"linalg",
device_id=self.device_index,
)
if self.model is None:
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")
else:
new_hidden_states, pkv1, pkv2 = self.model(
"forward",
(
hidden_states,
attention_mask,
),
)
del self.model
return tuple(
[
torch.tensor(new_hidden_states),
tuple(
[
torch.tensor(pkv1),
torch.tensor(pkv2),
]
),
]
)
class EightDecoderLayer(torch.nn.Module):
class FourWayShardingDecoderLayer(torch.nn.Module):
def __init__(self, decoder_layer_model, falcon_variant):
super().__init__()
self.model = decoder_layer_model
@@ -175,163 +91,78 @@ class EightDecoderLayer(torch.nn.Module):
outputs[-1][1],
)
)
if self.falcon_variant == "7b":
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
)
elif self.falcon_variant == "40b":
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
)
elif self.falcon_variant == "180b":
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
(new_pkv150, new_pkv151),
(new_pkv160, new_pkv161),
(new_pkv170, new_pkv171),
(new_pkv180, new_pkv181),
(new_pkv190, new_pkv191),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
new_pkv150,
new_pkv151,
new_pkv160,
new_pkv161,
new_pkv170,
new_pkv171,
new_pkv180,
new_pkv181,
new_pkv190,
new_pkv191,
)
else:
raise ValueError(
"Unsupported Falcon variant: ", self.falcon_variant
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
(new_pkv150, new_pkv151),
(new_pkv160, new_pkv161),
(new_pkv170, new_pkv171),
(new_pkv180, new_pkv181),
(new_pkv190, new_pkv191),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
new_pkv150,
new_pkv151,
new_pkv160,
new_pkv161,
new_pkv170,
new_pkv171,
new_pkv180,
new_pkv181,
new_pkv190,
new_pkv191,
)
return result
class CompiledEightDecoderLayer(torch.nn.Module):
class CompiledFourWayShardingDecoderLayer(torch.nn.Module):
def __init__(
self, layer_id, device_idx, falcon_variant, device, precision
self, layer_id, device_idx, falcon_variant, device, precision, model
):
super().__init__()
self.layer_id = layer_id
@@ -339,12 +170,14 @@ class CompiledEightDecoderLayer(torch.nn.Module):
self.falcon_variant = falcon_variant
self.device = device
self.precision = precision
self.model = model
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
alibi: torch.Tensor = None,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
@@ -354,24 +187,12 @@ class CompiledEightDecoderLayer(torch.nn.Module):
torch.cuda.empty_cache()
gc.collect()
from pathlib import Path
from apps.language_models.utils import get_vmfb_from_path
self.falcon_vmfb_path = Path(
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
)
print("vmfb path for layer: ", self.falcon_vmfb_path)
self.model = get_vmfb_from_path(
self.falcon_vmfb_path,
self.device,
"linalg",
device_id=self.device_index,
)
if self.model is None:
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.detach().numpy()
attention_mask = attention_mask.to(torch.float32).detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")
@@ -383,196 +204,452 @@ class CompiledEightDecoderLayer(torch.nn.Module):
attention_mask,
),
)
del self.model
if self.falcon_variant == "7b":
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
(
torch.tensor(output[31]),
torch.tensor(output[32]),
),
(
torch.tensor(output[33]),
torch.tensor(output[34]),
),
(
torch.tensor(output[35]),
torch.tensor(output[36]),
),
(
torch.tensor(output[37]),
torch.tensor(output[38]),
),
(
torch.tensor(output[39]),
torch.tensor(output[40]),
),
)
return result
class TwoWayShardingDecoderLayer(torch.nn.Module):
def __init__(self, decoder_layer_model, falcon_variant):
super().__init__()
self.model = decoder_layer_model
self.falcon_variant = falcon_variant
def forward(self, hidden_states, attention_mask):
new_pkvs = []
for layer in self.model:
outputs = layer(
hidden_states=hidden_states,
alibi=None,
attention_mask=attention_mask,
use_cache=True,
)
elif self.falcon_variant == "40b":
result = (
torch.tensor(output[0]),
hidden_states = outputs[0]
new_pkvs.append(
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
)
elif self.falcon_variant == "180b":
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
(
torch.tensor(output[31]),
torch.tensor(output[32]),
),
(
torch.tensor(output[33]),
torch.tensor(output[34]),
),
(
torch.tensor(output[35]),
torch.tensor(output[36]),
),
(
torch.tensor(output[37]),
torch.tensor(output[38]),
),
(
torch.tensor(output[39]),
torch.tensor(output[40]),
),
outputs[-1][0],
outputs[-1][1],
)
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
(new_pkv150, new_pkv151),
(new_pkv160, new_pkv161),
(new_pkv170, new_pkv171),
(new_pkv180, new_pkv181),
(new_pkv190, new_pkv191),
(new_pkv200, new_pkv201),
(new_pkv210, new_pkv211),
(new_pkv220, new_pkv221),
(new_pkv230, new_pkv231),
(new_pkv240, new_pkv241),
(new_pkv250, new_pkv251),
(new_pkv260, new_pkv261),
(new_pkv270, new_pkv271),
(new_pkv280, new_pkv281),
(new_pkv290, new_pkv291),
(new_pkv300, new_pkv301),
(new_pkv310, new_pkv311),
(new_pkv320, new_pkv321),
(new_pkv330, new_pkv331),
(new_pkv340, new_pkv341),
(new_pkv350, new_pkv351),
(new_pkv360, new_pkv361),
(new_pkv370, new_pkv371),
(new_pkv380, new_pkv381),
(new_pkv390, new_pkv391),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
new_pkv150,
new_pkv151,
new_pkv160,
new_pkv161,
new_pkv170,
new_pkv171,
new_pkv180,
new_pkv181,
new_pkv190,
new_pkv191,
new_pkv200,
new_pkv201,
new_pkv210,
new_pkv211,
new_pkv220,
new_pkv221,
new_pkv230,
new_pkv231,
new_pkv240,
new_pkv241,
new_pkv250,
new_pkv251,
new_pkv260,
new_pkv261,
new_pkv270,
new_pkv271,
new_pkv280,
new_pkv281,
new_pkv290,
new_pkv291,
new_pkv300,
new_pkv301,
new_pkv310,
new_pkv311,
new_pkv320,
new_pkv321,
new_pkv330,
new_pkv331,
new_pkv340,
new_pkv341,
new_pkv350,
new_pkv351,
new_pkv360,
new_pkv361,
new_pkv370,
new_pkv371,
new_pkv380,
new_pkv381,
new_pkv390,
new_pkv391,
)
return result
class CompiledTwoWayShardingDecoderLayer(torch.nn.Module):
def __init__(
self, layer_id, device_idx, falcon_variant, device, precision, model
):
super().__init__()
self.layer_id = layer_id
self.device_index = device_idx
self.falcon_variant = falcon_variant
self.device = device
self.precision = precision
self.model = model
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
import gc
torch.cuda.empty_cache()
gc.collect()
if self.model is None:
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.to(torch.float32).detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")
else:
raise ValueError(
"Unsupported Falcon variant: ", self.falcon_variant
output = self.model(
"forward",
(
hidden_states,
attention_mask,
),
)
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
(
torch.tensor(output[31]),
torch.tensor(output[32]),
),
(
torch.tensor(output[33]),
torch.tensor(output[34]),
),
(
torch.tensor(output[35]),
torch.tensor(output[36]),
),
(
torch.tensor(output[37]),
torch.tensor(output[38]),
),
(
torch.tensor(output[39]),
torch.tensor(output[40]),
),
(
torch.tensor(output[41]),
torch.tensor(output[42]),
),
(
torch.tensor(output[43]),
torch.tensor(output[44]),
),
(
torch.tensor(output[45]),
torch.tensor(output[46]),
),
(
torch.tensor(output[47]),
torch.tensor(output[48]),
),
(
torch.tensor(output[49]),
torch.tensor(output[50]),
),
(
torch.tensor(output[51]),
torch.tensor(output[52]),
),
(
torch.tensor(output[53]),
torch.tensor(output[54]),
),
(
torch.tensor(output[55]),
torch.tensor(output[56]),
),
(
torch.tensor(output[57]),
torch.tensor(output[58]),
),
(
torch.tensor(output[59]),
torch.tensor(output[60]),
),
(
torch.tensor(output[61]),
torch.tensor(output[62]),
),
(
torch.tensor(output[63]),
torch.tensor(output[64]),
),
(
torch.tensor(output[65]),
torch.tensor(output[66]),
),
(
torch.tensor(output[67]),
torch.tensor(output[68]),
),
(
torch.tensor(output[69]),
torch.tensor(output[70]),
),
(
torch.tensor(output[71]),
torch.tensor(output[72]),
),
(
torch.tensor(output[73]),
torch.tensor(output[74]),
),
(
torch.tensor(output[75]),
torch.tensor(output[76]),
),
(
torch.tensor(output[77]),
torch.tensor(output[78]),
),
(
torch.tensor(output[79]),
torch.tensor(output[80]),
),
)
return result

View File

@@ -5,7 +5,7 @@ from typing import List, Any
from transformers import StoppingCriteria
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
@@ -37,7 +37,7 @@ class VisionModel(torch.nn.Module):
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -52,7 +52,7 @@ class VisionModel(torch.nn.Module):
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -93,7 +93,7 @@ class FirstLlamaModel(torch.nn.Module):
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -157,7 +157,7 @@ class SecondLlamaModel(torch.nn.Module):
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,

View File

@@ -24,7 +24,9 @@ class FirstVicuna(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -36,7 +38,7 @@ class FirstVicuna(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -79,7 +81,9 @@ class SecondVicuna7B(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -91,7 +95,7 @@ class SecondVicuna7B(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -329,7 +333,9 @@ class SecondVicuna13B(torch.nn.Module):
torch.float32 if accumulates == "fp32" else torch.float16
)
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -341,7 +347,7 @@ class SecondVicuna13B(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -627,7 +633,9 @@ class SecondVicuna70B(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -639,7 +647,7 @@ class SecondVicuna70B(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,

View File

@@ -24,7 +24,9 @@ class FirstVicunaGPU(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -36,7 +38,7 @@ class FirstVicunaGPU(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -78,7 +80,9 @@ class SecondVicuna7BGPU(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -90,7 +94,7 @@ class SecondVicuna7BGPU(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -327,7 +331,9 @@ class SecondVicuna13BGPU(torch.nn.Module):
torch.float32 if accumulates == "fp32" else torch.float16
)
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -339,7 +345,7 @@ class SecondVicuna13BGPU(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -625,7 +631,9 @@ class SecondVicuna70BGPU(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -637,7 +645,7 @@ class SecondVicuna70BGPU(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,

View File

@@ -1,4 +1,5 @@
import torch
import time
class FirstVicunaLayer(torch.nn.Module):
@@ -110,9 +111,11 @@ class LMHeadCompiled(torch.nn.Module):
self.model = shark_module
def forward(self, hidden_states):
hidden_states = hidden_states.detach()
hidden_states_sample = hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output
@@ -136,8 +139,9 @@ class VicunaNormCompiled(torch.nn.Module):
hidden_states.detach()
except:
pass
output = self.model("forward", (hidden_states,))
output = self.model("forward", (hidden_states,), send_to_host=True)
output = torch.tensor(output)
return output
@@ -158,8 +162,9 @@ class VicunaEmbeddingCompiled(torch.nn.Module):
def forward(self, input_ids):
input_ids.detach()
output = self.model("forward", (input_ids,))
output = self.model("forward", (input_ids,), send_to_host=True)
output = torch.tensor(output)
return output
@@ -178,9 +183,10 @@ class CompiledVicunaLayer(torch.nn.Module):
use_cache=True,
):
if past_key_value is None:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
# hidden_states = hidden_states.detach()
# attention_mask = attention_mask.detach()
# position_ids = position_ids.detach()
output = self.model(
"first_vicuna_forward",
(
@@ -188,11 +194,17 @@ class CompiledVicunaLayer(torch.nn.Module):
attention_mask,
position_ids,
),
send_to_host=True,
)
### send_to_host=True
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
### send_to_host=False
# output0 = output[0]
# output1 = output[1]
# output2 = output[2]
return (
output0,
@@ -202,11 +214,12 @@ class CompiledVicunaLayer(torch.nn.Module):
),
)
else:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv0 = past_key_value[0].detach()
pkv1 = past_key_value[1].detach()
# hidden_states = hidden_states.detach()
# attention_mask = attention_mask.detach()
# position_ids = position_ids.detach()
# pkv0 = past_key_value[0].detach()
pkv0 = past_key_value[0]
pkv1 = past_key_value[1]
output = self.model(
"second_vicuna_forward",
(
@@ -216,11 +229,17 @@ class CompiledVicunaLayer(torch.nn.Module):
pkv0,
pkv1,
),
send_to_host=True,
)
### send_to_host=True
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
### send_to_host=False
# output0 = output[0]
# output1 = output[1]
# output2 = output[2]
return (
output0,

View File

@@ -6,10 +6,10 @@ from apps.language_models.src.model_wrappers.falcon_sharded_model import (
CompiledLNFEmbeddingLayer,
LMHeadEmbeddingLayer,
CompiledLMHeadEmbeddingLayer,
DecoderLayer,
EightDecoderLayer,
CompiledDecoderLayer,
CompiledEightDecoderLayer,
FourWayShardingDecoderLayer,
TwoWayShardingDecoderLayer,
CompiledFourWayShardingDecoderLayer,
CompiledTwoWayShardingDecoderLayer,
ShardedFalconModel,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
@@ -94,6 +94,13 @@ parser.add_argument(
action=argparse.BooleanOptionalAction,
help="Run model as sharded",
)
parser.add_argument(
"--num_shards",
type=int,
default=4,
choices=[2, 4],
help="Number of shards.",
)
class ShardedFalcon(SharkLLMBase):
@@ -122,6 +129,10 @@ class ShardedFalcon(SharkLLMBase):
--hf_auth_token flag. You can ask for the access to the model
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
)
if args.sharded and "180b" not in self.model_name:
raise ValueError("Sharding supported only for Falcon-180B")
self.hf_auth_token = hf_auth_token
self.max_padding_length = 100
self.device = device
@@ -131,7 +142,7 @@ class ShardedFalcon(SharkLLMBase):
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.src_model = self.get_src_model()
self.shark_model = self.compile(compressed=args.compressed)
self.shark_model = self.compile()
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
@@ -146,20 +157,17 @@ class ShardedFalcon(SharkLLMBase):
def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {
"torch_dtype": torch.float,
"torch_dtype": torch.float32,
"trust_remote_code": True,
"token": self.hf_auth_token,
}
if self.precision == "int4":
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
if self.precision == "int4":
falcon_model = falcon_model.to(torch.float32)
return falcon_model
def compile_layer(
@@ -225,7 +233,7 @@ class ShardedFalcon(SharkLLMBase):
elif layer_id in ["ln_f", "lm_head"]:
f16_input_mask = [True]
elif "_" in layer_id or type(layer_id) == int:
f16_input_mask = [True, False]
f16_input_mask = [True, True]
else:
raise ValueError("Unsupported layer: ", layer_id)
@@ -288,28 +296,16 @@ class ShardedFalcon(SharkLLMBase):
return shark_module, device_idx
def compile(self, compressed=False):
def compile(self):
sample_input_ids = torch.zeros([100], dtype=torch.int64)
sample_attention_mask = torch.zeros(
[1, 1, 100, 100], dtype=torch.float32
)
num_group_layers = 1
if "7b" in self.model_name:
num_in_features = 4544
if compressed:
num_group_layers = 8
elif "40b" in self.model_name:
num_in_features = 8192
if compressed:
num_group_layers = 15
else:
num_in_features = 14848
sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
if compressed:
num_group_layers = 20
num_group_layers = int(
20 * (4 / args.num_shards)
) # 4 is the number of default shards
sample_hidden_states = torch.zeros(
[1, 100, num_in_features], dtype=torch.float32
[1, 100, 14848], dtype=torch.float32
)
# Determine number of available devices
@@ -319,6 +315,10 @@ class ShardedFalcon(SharkLLMBase):
haldriver = ireert.get_driver(self.device)
num_devices = len(haldriver.query_available_devices())
if num_devices < 2:
raise ValueError(
"Cannot run Falcon-180B on a single ROCM device."
)
lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
print("Compiling Layer lm_head")
@@ -326,7 +326,9 @@ class ShardedFalcon(SharkLLMBase):
lm_head,
[sample_hidden_states],
"lm_head",
device_idx=0 % num_devices if self.device == "rocm" else None,
device_idx=(0 % num_devices) % args.num_shards
if self.device == "rocm"
else None,
)
shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
@@ -338,7 +340,9 @@ class ShardedFalcon(SharkLLMBase):
word_embedding,
[sample_input_ids],
"word_embeddings",
device_idx=1 % num_devices if self.device == "rocm" else None,
device_idx=(1 % num_devices) % args.num_shards
if self.device == "rocm"
else None,
)
shark_word_embedding = CompiledWordEmbeddingsLayer(
shark_word_embedding
@@ -350,7 +354,9 @@ class ShardedFalcon(SharkLLMBase):
ln_f,
[sample_hidden_states],
"ln_f",
device_idx=2 % num_devices if self.device == "rocm" else None,
device_idx=(2 % num_devices) % args.num_shards
if self.device == "rocm"
else None,
)
shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
@@ -360,24 +366,21 @@ class ShardedFalcon(SharkLLMBase):
):
device_idx = i % num_devices if self.device == "rocm" else None
layer_id = i
pytorch_class = DecoderLayer
compiled_class = CompiledDecoderLayer
if compressed:
layer_id = (
str(i * num_group_layers)
+ "_"
+ str((i + 1) * num_group_layers)
)
pytorch_class = EightDecoderLayer
compiled_class = CompiledEightDecoderLayer
layer_id = (
str(i * num_group_layers)
+ "_"
+ str((i + 1) * num_group_layers)
)
pytorch_class = FourWayShardingDecoderLayer
compiled_class = CompiledFourWayShardingDecoderLayer
if args.num_shards == 2:
pytorch_class = TwoWayShardingDecoderLayer
compiled_class = CompiledTwoWayShardingDecoderLayer
print("Compiling Layer {}".format(layer_id))
if compressed:
layer_i = self.src_model.transformer.h[
i * num_group_layers : (i + 1) * num_group_layers
]
else:
layer_i = self.src_model.transformer.h[i]
layer_i = self.src_model.transformer.h[
i * num_group_layers : (i + 1) * num_group_layers
]
pytorch_layer_i = pytorch_class(
layer_i, args.falcon_variant_to_use
@@ -388,13 +391,13 @@ class ShardedFalcon(SharkLLMBase):
layer_id,
device_idx=device_idx,
)
del shark_module
shark_layer_i = compiled_class(
layer_id,
device_idx,
args.falcon_variant_to_use,
self.device,
self.precision,
shark_module,
)
shark_layers.append(shark_layer_i)
@@ -668,20 +671,17 @@ class UnshardedFalcon(SharkLLMBase):
def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {
"torch_dtype": torch.float,
"torch_dtype": torch.float32,
"trust_remote_code": True,
"token": self.hf_auth_token,
}
if self.precision == "int4":
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
if self.precision == "int4":
falcon_model = falcon_model.to(torch.float32)
return falcon_model
def compile(self):

View File

@@ -132,7 +132,7 @@ import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from typing import List, Tuple
from io import BytesIO
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl

View File

@@ -105,6 +105,7 @@ def main():
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
control_mode=args.control_mode,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -0,0 +1,96 @@
import torch
import time
from apps.stable_diffusion.src import (
args,
Text2ImageSDXLPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
def main():
if args.clear_all:
clear_all()
# TODO: prompt_embeds and text_embeds form base_model.json requires fixing
args.precision = "fp16"
args.height = 1024
args.width = 1024
args.max_length = 77
args.scheduler = "DDIM"
print(
"Using default supported configuration for SDXL :-\nprecision=fp16, width*height= 1024*1024, max_length=77 and scheduler=DDIM"
)
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
)
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
for current_batch in range(args.batch_count):
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
args.negative_prompts,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += (
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
)
text_output += (
f"seed={seeds[current_batch]}, size={args.height}x{args.width}"
)
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
# TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -19,6 +19,9 @@ a = Analysis(
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
module_collection_mode={
'gradio': 'py', # Collect gradio package as source .py files
},
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

View File

@@ -31,6 +31,7 @@ datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += copy_metadata("gradio")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
@@ -75,6 +76,7 @@ datas += [
# hidden imports for pyinstaller
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
hiddenimports += [
x for x in collect_submodules("diffusers") if "tests" not in x
]
@@ -85,4 +87,4 @@ hiddenimports += [
if not any(kw in x for kw in blacklist)
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]
hiddenimports += ["iree._runtime"]

View File

@@ -9,6 +9,7 @@ from apps.stable_diffusion.src.utils import (
)
from apps.stable_diffusion.src.pipelines import (
Text2ImagePipeline,
Text2ImageSDXLPipeline,
Image2ImagePipeline,
InpaintPipeline,
OutpaintPipeline,

View File

@@ -1,5 +1,5 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from transformers import CLIPTextModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection
from collections import defaultdict
from pathlib import Path
import torch
@@ -24,6 +24,8 @@ from apps.stable_diffusion.src.utils import (
get_stencil_model_id,
update_lora_weight,
)
from shark.shark_downloader import download_public_file
from shark.shark_inference import SharkInference
# These shapes are parameter dependent.
@@ -55,6 +57,10 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape.append(math.ceil(height / div_val))
elif "width" in shape[i]:
new_shape.append(math.ceil(width / div_val))
elif "+" in shape[i]:
# Currently this case only hits for SDXL. So, in case any other
# case requires this operator, change this.
new_shape.append(height + width)
else:
new_shape.append(shape[i])
return new_shape
@@ -67,6 +73,70 @@ def check_compilation(model, model_name):
)
def shark_compile_after_ir(
module_name,
device,
vmfb_path,
precision,
ir_path=None,
):
if ir_path:
print(f"[DEBUG] mlir found at {ir_path.absolute()}")
module = SharkInference(
mlir_module=ir_path,
device=device,
mlir_dialect="tm_tensor",
)
print(f"Will get extra flag for {module_name} and precision = {precision}")
path = module.save_module(
vmfb_path.parent.absolute(),
vmfb_path.stem,
extra_args=get_opt_flags(module_name, precision=precision),
)
print(f"Saved {module_name} vmfb at {path}")
module.load_module(path)
return module
def process_vmfb_ir_sdxl(extended_model_name, model_name, device, precision):
name_split = extended_model_name.split("_")
if "vae" in model_name:
name_split[5] = "fp32"
extended_model_name_for_vmfb = "_".join(name_split)
extended_model_name_for_mlir = "_".join(name_split[:-1])
vmfb_path = Path(extended_model_name_for_vmfb + ".vmfb")
if "vulkan" in device:
_device = args.iree_vulkan_target_triple
_device = _device.replace("-", "_")
vmfb_path = Path(extended_model_name_for_vmfb + f"_vulkan.vmfb")
if vmfb_path.exists():
shark_module = SharkInference(
None,
device=device,
mlir_dialect="tm_tensor",
)
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=[])
return shark_module, None
mlir_path = Path(extended_model_name_for_mlir + ".mlir")
if not mlir_path.exists():
print(f"Looking into gs://shark_tank/SDXL/mlir/{mlir_path.name}")
download_public_file(
f"gs://shark_tank/SDXL/mlir/{mlir_path.name}",
mlir_path.absolute(),
single_file=True,
)
if mlir_path.exists():
return (
shark_compile_after_ir(
model_name, device, vmfb_path, precision, mlir_path
),
None,
)
return None, None
class SharkifyStableDiffusionModel:
def __init__(
self,
@@ -86,13 +156,15 @@ class SharkifyStableDiffusionModel:
generate_vmfb: bool = True,
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
is_sdxl: bool = False,
stencils: list[str] = [],
use_lora: str = "",
use_quantize: str = None,
return_mlir: bool = False,
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.is_sdxl = is_sdxl
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
@@ -144,7 +216,7 @@ class SharkifyStableDiffusionModel:
self.low_cpu_mem_usage = low_cpu_mem_usage
self.is_inpaint = is_inpaint
self.is_upscaler = is_upscaler
self.use_stencil = get_stencil_model_id(use_stencil)
self.stencils = [get_stencil_model_id(x) for x in stencils]
if use_lora != "":
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
self.use_lora = use_lora
@@ -175,14 +247,15 @@ class SharkifyStableDiffusionModel:
model_name = {}
sub_model_list = [
"clip",
"clip2",
"unet",
"unet512",
"stencil_unet",
"stencil_unet_512",
"vae",
"vae_encode",
"stencil_adaptor",
"stencil_adaptor_512",
"stencil_adapter",
"stencil_adapter_512",
]
index = 0
for model in sub_model_list:
@@ -195,10 +268,19 @@ class SharkifyStableDiffusionModel:
)
if self.base_vae:
sub_model = "base_vae"
if "stencil_adaptor" == model and self.use_stencil is not None:
model_config = model_config + get_path_stem(self.use_stencil)
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
if "stencil_adapter" in model:
stencil_names = []
for i, stencil in enumerate(self.stencils):
if stencil is not None:
cnet_config = model_config + stencil.split("_")[-1]
stencil_names.append(
get_extended_name(sub_model + cnet_config)
)
model_name[model] = stencil_names
else:
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
return model_name
def check_params(self, max_len, width, height):
@@ -342,6 +424,105 @@ class SharkifyStableDiffusionModel:
)
return shark_vae, vae_mlir
def get_vae_sdxl(self):
# TODO: Remove this after convergence with shark_tank. This should just be part of
# opt_params.py.
shark_module_or_none = process_vmfb_ir_sdxl(
self.model_name["vae"], "vae", args.device, self.precision
)
if shark_module_or_none[0]:
return shark_module_or_none
class VaeModel(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
base_vae=self.base_vae,
custom_vae=self.custom_vae,
low_cpu_mem_usage=False,
):
super().__init__()
self.vae = None
if custom_vae == "":
print(f"Loading default vae, with target {model_id}")
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif not isinstance(custom_vae, dict):
precision = "fp16" if "fp16" in custom_vae else None
print(f"Loading custom vae, with target {custom_vae}")
if os.path.exists(custom_vae):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
custom_vae = "/".join(
[
custom_vae.split("/")[-2].split("\\")[-1],
custom_vae.split("/")[-1],
]
)
print("Using hub to get custom vae")
try:
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
variant=precision,
)
except:
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
print(f"Loading custom vae, with state {custom_vae}")
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.vae.load_state_dict(custom_vae)
self.base_vae = base_vae
def forward(self, latents):
image = self.vae.decode(latents / 0.13025, return_dict=False)[
0
]
return image
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
# Make sure the VAE is in float32 mode, as it overflows in float16 as per SDXL
# pipeline.
if not self.custom_vae:
is_f16 = False
elif "16" in self.custom_vae:
is_f16 = True
else:
is_f16 = False
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
shark_vae, vae_mlir = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
extended_model_name=self.model_name["vae"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae, vae_mlir
def get_controlled_unet(self, use_large=False):
class ControlledUnetModel(torch.nn.Module):
def __init__(
@@ -380,25 +561,54 @@ class SharkifyStableDiffusionModel:
control11,
control12,
control13,
scale1,
scale2,
scale3,
scale4,
scale5,
scale6,
scale7,
scale8,
scale9,
scale10,
scale11,
scale12,
scale13,
):
# TODO: Average pooling
db_res_samples = [
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
]
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
db_res_samples = tuple(
[
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
control1 * scale1,
control2 * scale2,
control3 * scale3,
control4 * scale4,
control5 * scale5,
control6 * scale6,
control7 * scale7,
control8 * scale8,
control9 * scale9,
control10 * scale10,
control11 * scale11,
control12 * scale12,
]
)
mb_res_samples = control13
mb_res_samples = control13 * scale13
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents,
@@ -446,6 +656,19 @@ class SharkifyStableDiffusionModel:
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
]
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
@@ -462,17 +685,19 @@ class SharkifyStableDiffusionModel:
)
return shark_controlled_unet, controlled_unet_mlir
def get_control_net(self, use_large=False):
def get_control_net(self, stencil_id, use_large=False):
stencil_id = get_stencil_model_id(stencil_id)
adapter_id, base_model_safe_id, ext_model_name, desired_name = (None, None, None, None)
print(f"Importing ControlNet adapter from {stencil_id}")
class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.use_stencil, low_cpu_mem_usage=False
):
def __init__(self, model_id=stencil_id, low_cpu_mem_usage=False):
super().__init__()
self.cnet = ControlNetModel.from_pretrained(
model_id,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.cnet.in_channels
self.in_channels = self.cnet.config.in_channels
self.train(False)
def forward(
@@ -481,6 +706,19 @@ class SharkifyStableDiffusionModel:
timestep,
text_embedding,
stencil_image_input,
acc1,
acc2,
acc3,
acc4,
acc5,
acc6,
acc7,
acc8,
acc9,
acc10,
acc11,
acc12,
acc13,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
# TODO: guidance is NOT NEEDED here; change in `get_input_info` later.
@@ -502,6 +740,20 @@ class SharkifyStableDiffusionModel:
)
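# Note on the return value below: the first 13 tensors are this ControlNet's raw down-block and mid-block residuals, and the remaining 13 are the running accumulators (acc_i + residual) so that several ControlNets can be chained.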
return tuple(
list(down_block_res_samples) + [mid_block_res_sample]
) + (
acc1 + down_block_res_samples[0],
acc2 + down_block_res_samples[1],
acc3 + down_block_res_samples[2],
acc4 + down_block_res_samples[3],
acc5 + down_block_res_samples[4],
acc6 + down_block_res_samples[5],
acc7 + down_block_res_samples[6],
acc8 + down_block_res_samples[7],
acc9 + down_block_res_samples[8],
acc10 + down_block_res_samples[9],
acc11 + down_block_res_samples[10],
acc12 + down_block_res_samples[11],
acc13 + mid_block_res_sample,
)
scnet = StencilControlNetModel(
@@ -509,7 +761,23 @@ class SharkifyStableDiffusionModel:
)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_adaptor"])
inputs = tuple(self.inputs["stencil_adapter"])
model_name = "stencil_adapter_512" if use_large else "stencil_adapter"
ext_model_name = self.model_name[model_name]
if isinstance(ext_model_name, list):
for i in ext_model_name:
if stencil_id.split("_")[-1] in i:
desired_name = i
print(f"Multi-CN: compiling model {i}")
else:
continue
if desired_name:
ext_model_name = desired_name
else:
raise Exception(
f"Could not find extended configuration for {stencil_id}"
)
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
@@ -517,21 +785,15 @@ class SharkifyStableDiffusionModel:
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
*inputs[3:],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
)
else:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor"]
)
input_mask = [True, True, True, True]
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
save_dir = os.path.join(self.sharktank_dir, ext_model_name)
input_mask = [True, True, True, True] + ([True] * 13)
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
inputs,
extended_model_name=self.model_name[model_name],
extended_model_name=ext_model_name,
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
@@ -688,6 +950,101 @@ class SharkifyStableDiffusionModel:
)
return shark_unet, unet_mlir
def get_unet_sdxl(self):
# TODO: Remove this after convergence with shark_tank. This should just be part of
# opt_params.py.
shark_module_or_none = process_vmfb_ir_sdxl(
self.model_name["unet"], "unet", args.device, self.precision
)
if shark_module_or_none[0]:
return shark_module_or_none
class UnetModel(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
):
super().__init__()
try:
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
variant="fp16",
)
except:
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if (
args.attention_slicing is not None
and args.attention_slicing != "none"
):
if args.attention_slicing.isdigit():
self.unet.set_attention_slice(
int(args.attention_slicing)
)
else:
self.unet.set_attention_slice(args.attention_slicing)
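# This wrapper bakes classifier-free guidance into the compiled module: `latent` is expected to already hold the concatenated uncond/cond batch, one UNet forward is run, and the prediction is split and combined with `guidance_scale` before a single noise prediction is returned.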
def forward(
self,
latent,
timestep,
prompt_embeds,
text_embeds,
time_ids,
guidance_scale,
):
added_cond_kwargs = {
"text_embeds": text_embeds,
"time_ids": time_ids,
}
noise_pred = self.unet.forward(
latent,
timestep,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=None,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
input_mask = [True, True, True, True, True, True]
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(
@@ -735,6 +1092,78 @@ class SharkifyStableDiffusionModel:
)
return shark_clip, clip_mlir
def get_clip_sdxl(self, clip_index=1):
if clip_index == 1:
extended_model_name = self.model_name["clip"]
model_name = "clip"
else:
extended_model_name = self.model_name["clip2"]
model_name = "clip2"
# TODO: Remove this after convergence with shark_tank. This should just be part of
# opt_params.py.
shark_module_or_none = process_vmfb_ir_sdxl(
extended_model_name, f"clip", args.device, self.precision
)
if shark_module_or_none[0]:
return shark_module_or_none
class CLIPText(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
clip_index=1,
):
super().__init__()
if clip_index == 1:
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
self.text_encoder = (
CLIPTextModelWithProjection.from_pretrained(
model_id,
subfolder="text_encoder_2",
low_cpu_mem_usage=low_cpu_mem_usage,
)
)
def forward(self, input):
prompt_embeds = self.text_encoder(
input,
output_hidden_states=True,
)
# We are always interested only in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
return prompt_embeds, pooled_prompt_embeds
clip_model = CLIPText(
low_cpu_mem_usage=self.low_cpu_mem_usage, clip_index=clip_index
)
save_dir = os.path.join(self.sharktank_dir, extended_model_name)
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_clip, clip_mlir = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
extended_model_name=extended_model_name,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
base_model_id=self.base_model_id,
model_name="clip",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_clip, clip_mlir
def process_custom_vae(self):
custom_vae = self.custom_vae.lower()
if not custom_vae.endswith((".ckpt", ".safetensors")):
@@ -767,7 +1196,9 @@ class SharkifyStableDiffusionModel:
}
return vae_dict
def compile_unet_variants(self, model, use_large=False):
def compile_unet_variants(self, model, use_large=False, base_model=""):
if self.is_sdxl:
return self.get_unet_sdxl()
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler(use_large=use_large)
@@ -809,9 +1240,28 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def sdxl_clip(self):
try:
self.inputs["clip"] = self.get_input_info_for(
base_models["sdxl_clip"]
)
compiled_clip, clip_mlir = self.get_clip_sdxl(clip_index=1)
compiled_clip2, clip_mlir2 = self.get_clip_sdxl(clip_index=2)
check_compilation(compiled_clip, "Clip")
check_compilation(compiled_clip, "Clip2")
if self.return_mlir:
return clip_mlir, clip_mlir2
return compiled_clip, compiled_clip2
except Exception as e:
sys.exit(e)
def unet(self, use_large=False):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
stencil_count = 0
for stencil in self.stencils:
if stencil is not None:
stencil_count += 1
model = "stencil_unet" if stencil_count > 0 else "unet"
compiled_unet = None
unet_inputs = base_models[model]
@@ -820,7 +1270,7 @@ class SharkifyStableDiffusionModel:
unet_inputs[self.base_model_id]
)
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
model, use_large=use_large, base_model=self.base_model_id
)
else:
for model_id in unet_inputs:
@@ -831,7 +1281,7 @@ class SharkifyStableDiffusionModel:
try:
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
model, use_large=use_large, base_model=model_id
)
except Exception as e:
print(e)
@@ -870,7 +1320,10 @@ class SharkifyStableDiffusionModel:
is_base_vae = self.base_vae
if self.is_upscaler:
self.base_vae = True
compiled_vae, vae_mlir = self.get_vae()
if self.is_sdxl:
compiled_vae, vae_mlir = self.get_vae_sdxl()
else:
compiled_vae, vae_mlir = self.get_vae()
self.base_vae = is_base_vae
check_compilation(compiled_vae, "Vae")
@@ -880,18 +1333,18 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def controlnet(self, use_large=False):
def controlnet(self, stencil_id, use_large=False):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
self.inputs["stencil_adapter"] = self.get_input_info_for(
base_models["stencil_adapter"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
use_large=use_large
compiled_stencil_adapter, controlnet_mlir = self.get_control_net(
stencil_id, use_large=use_large
)
check_compilation(compiled_stencil_adaptor, "Stencil")
check_compilation(compiled_stencil_adapter, "Stencil")
if self.return_mlir:
return controlnet_mlir
return compiled_stencil_adaptor
return compiled_stencil_adapter
except Exception as e:
sys.exit(e)

View File

@@ -123,8 +123,11 @@ def get_clip():
return get_shark_model(bucket, model_name, iree_flags)
def get_tokenizer():
def get_tokenizer(subfolder="tokenizer", hf_model_id=None):
if hf_model_id is not None:
args.hf_model_id = hf_model_id
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id, subfolder="tokenizer"
args.hf_model_id, subfolder=subfolder
)
return tokenizer

View File

@@ -1,6 +1,9 @@
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
Text2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img_sdxl import (
Text2ImageSDXLPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
Image2ImagePipeline,
)

View File

@@ -158,8 +158,10 @@ class Image2ImagePipeline(StableDiffusionPipeline):
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
stencils,
images,
resample_type,
control_mode,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):

View File

@@ -55,28 +55,47 @@ class StencilPipeline(StableDiffusionPipeline):
import_mlir: bool,
use_lora: str,
ondemand: bool,
controlnet_names: list[str],
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.controlnet = None
self.controlnet_512 = None
self.controlnet = [None] * len(controlnet_names)
self.controlnet_512 = [None] * len(controlnet_names)
self.controlnet_id = [str] * len(controlnet_names)
self.controlnet_512_id = [str] * len(controlnet_names)
self.controlnet_names = controlnet_names
def load_controlnet(self):
if self.controlnet is not None:
def load_controlnet(self, index, model_name):
if model_name is None:
return
self.controlnet = self.sd_model.controlnet()
def unload_controlnet(self):
del self.controlnet
self.controlnet = None
def load_controlnet_512(self):
if self.controlnet_512 is not None:
if (
self.controlnet[index] is not None
and self.controlnet_id[index] is not None
and self.controlnet_id[index] == model_name
):
return
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
self.controlnet_id[index] = model_name
self.controlnet[index] = self.sd_model.controlnet(model_name)
def unload_controlnet_512(self):
del self.controlnet_512
self.controlnet_512 = None
def unload_controlnet(self, index):
del self.controlnet[index]
self.controlnet_id[index] = None
self.controlnet[index] = None
def load_controlnet_512(self, index, model_name):
if (
self.controlnet_512[index] is not None
and self.controlnet_512_id[index] == model_name
):
return
self.controlnet_512_id[index] = model_name
self.controlnet_512[index] = self.sd_model.controlnet(
model_name, use_large=True
)
def unload_controlnet_512(self, index):
del self.controlnet_512[index]
self.controlnet_512_id[index] = None
self.controlnet_512[index] = None
def prepare_latents(
self,
@@ -111,8 +130,9 @@ class StencilPipeline(StableDiffusionPipeline):
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
stencil_hints=[None],
controlnet_conditioning_scale: float = 1.0,
control_mode="Balanced", # Prompt, Balanced, or Controlnet
mask=None,
masked_image_latents=None,
return_all_latents=False,
@@ -121,12 +141,18 @@ class StencilPipeline(StableDiffusionPipeline):
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
assert control_mode in ["Prompt", "Balanced", "Controlnet"]
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
self.load_controlnet()
else:
self.load_unet_512()
self.load_controlnet_512()
for i, name in enumerate(self.controlnet_names):
if text_embeddings.shape[1] <= self.model_max_length:
self.load_controlnet(i, name)
else:
self.load_controlnet_512(i, name)
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
@@ -149,33 +175,93 @@ class StencilPipeline(StableDiffusionPipeline):
).to(dtype)
else:
latent_model_input_1 = latent_model_input
if text_embeddings.shape[1] <= self.model_max_length:
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
else:
control = self.controlnet_512(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
# Multicontrolnet
width = latent_model_input_1.shape[2]
height = latent_model_input_1.shape[3]
dtype = latent_model_input_1.dtype
control_acc = (
[torch.zeros((2, 320, height, width), dtype=dtype)] * 3
+ [
torch.zeros(
(2, 320, int(height / 2), int(width / 2)), dtype=dtype
)
]
+ [
torch.zeros(
(2, 640, int(height / 2), int(width / 2)), dtype=dtype
)
]
* 2
+ [
torch.zeros(
(2, 640, int(height / 4), int(width / 4)), dtype=dtype
)
]
+ [
torch.zeros(
(2, 1280, int(height / 4), int(width / 4)), dtype=dtype
)
]
* 2
+ [
torch.zeros(
(2, 1280, int(height / 8), int(width / 8)), dtype=dtype
)
]
* 4
)
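# control_acc starts as 13 zero tensors shaped like the 12 down-block residuals plus the mid-block residual; each ControlNet call adds its residuals into these accumulators so several stencils can be combined.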
for i, controlnet_hint in enumerate(stencil_hints):
if controlnet_hint is None:
continue
if text_embeddings.shape[1] <= self.model_max_length:
control = self.controlnet[i](
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
*control_acc,
),
send_to_host=False,
)
else:
control = self.controlnet_512[i](
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
*control_acc,
),
send_to_host=False,
)
control_acc = control[13:]
control = control[:13]
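# The compiled ControlNet returns 26 tensors: the first 13 are the residuals passed on to the UNet, the last 13 are the updated accumulators threaded into the next ControlNet call.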
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
dtype = latents.dtype
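# Control-mode weighting: "Balanced" scales every residual by 1.0; "Prompt" decays the scales geometrically (1.0, 0.825, 0.825**2 ~= 0.68, ...) so deeper residuals contribute less; "Controlnet" scales every residual by the CFG guidance_scale.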
if control_mode == "Balanced":
control_scale = [
torch.tensor(1.0, dtype=dtype) for _ in range(len(control))
]
elif control_mode == "Prompt":
control_scale = [
torch.tensor(0.825**x, dtype=dtype)
for x in range(len(control))
]
elif control_mode == "Controlnet":
control_scale = [
torch.tensor(float(guidance_scale), dtype=dtype)
for _ in range(len(control))
]
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
@@ -197,6 +283,19 @@ class StencilPipeline(StableDiffusionPipeline):
control[10],
control[11],
control[12],
control_scale[0],
control_scale[1],
control_scale[2],
control_scale[3],
control_scale[4],
control_scale[5],
control_scale[6],
control_scale[7],
control_scale[8],
control_scale[9],
control_scale[10],
control_scale[11],
control_scale[12],
),
send_to_host=False,
)
@@ -222,6 +321,19 @@ class StencilPipeline(StableDiffusionPipeline):
control[10],
control[11],
control[12],
control_scale[0],
control_scale[1],
control_scale[2],
control_scale[3],
control_scale[4],
control_scale[5],
control_scale[6],
control_scale[7],
control_scale[8],
control_scale[9],
control_scale[10],
control_scale[11],
control_scale[12],
),
send_to_host=False,
)
@@ -245,8 +357,9 @@ class StencilPipeline(StableDiffusionPipeline):
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
self.unload_controlnet()
self.unload_controlnet_512()
for i in range(len(self.controlnet_names)):
self.unload_controlnet(i)
self.unload_controlnet_512(i)
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -272,14 +385,29 @@ class StencilPipeline(StableDiffusionPipeline):
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
stencils,
stencil_images,
resample_type,
control_mode,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.
controlnet_hint = controlnet_hint_conversion(
image, use_stencil, height, width, dtype, num_images_per_prompt=1
)
# controlnet_hint = controlnet_hint_conversion(
# image, use_stencil, height, width, dtype, num_images_per_prompt=1
# )
stencil_hints = []
for i, stencil in enumerate(stencils):
image = stencil_images[i]
stencil_hints.append(
controlnet_hint_conversion(
image,
stencil,
height,
width,
dtype,
num_images_per_prompt=1,
)
)
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
@@ -327,7 +455,8 @@ class StencilPipeline(StableDiffusionPipeline):
total_timesteps=final_timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
controlnet_hint=controlnet_hint,
control_mode=control_mode,
stencil_hints=stencil_hints,
)
# Img latents -> PIL images

View File

@@ -18,7 +18,10 @@ from diffusers import (
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.schedulers import (
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)

View File

@@ -0,0 +1,220 @@
import torch
import numpy as np
from random import randint
from typing import Union
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import (
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Text2ImageSDXLPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
is_fp32_vae: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.is_fp32_vae = is_fp32_vae
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def _get_add_time_ids(
self, original_size, crops_coords_top_left, target_size, dtype
):
add_time_ids = list(
original_size + crops_coords_top_left + target_size
)
# self.unet.config.addition_time_embed_dim IS 256.
# self.text_encoder_2.config.projection_dim IS 1280.
passed_add_embed_dim = 256 * len(add_time_ids) + 1280
expected_add_embed_dim = 2816
# self.unet.add_embedding.linear_1.in_features IS 2816.
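# Sanity arithmetic: original_size + crops_coords_top_left + target_size yields 6 scalar time ids, and 256 * 6 + 1280 = 2816, matching the expected embedding width.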
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
return add_time_ids
def generate_images(
self,
prompts,
neg_prompts,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the initial latent noise. Also handle out-of-range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents.
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings.
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt_sdxl(
prompt=prompts,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
negative_prompt=neg_prompts,
)
# Prepare timesteps.
self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps
# Prepare added time ids & embeddings.
original_size = (height, width)
target_size = (height, width)
crops_coords_top_left = (0, 0)
add_text_embeds = pooled_prompt_embeds
add_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
dtype=prompt_embeds.dtype,
)
prompt_embeds = torch.cat(
[negative_prompt_embeds, prompt_embeds], dim=0
)
add_text_embeds = torch.cat(
[negative_pooled_prompt_embeds, add_text_embeds], dim=0
)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
add_text_embeds = add_text_embeds.to(dtype)
add_time_ids = add_time_ids.repeat(batch_size * 1, 1)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(dtype)
prompt_embeds = prompt_embeds.to(dtype)
add_time_ids = add_time_ids.to(dtype)
# Get Image latents.
latents = self.produce_img_latents_sdxl(
init_latents,
timesteps,
add_text_embeds,
add_time_ids,
prompt_embeds,
cpu_scheduling,
guidance_scale,
dtype,
)
# Img latents -> PIL images.
all_imgs = []
self.load_vae()
for i in range(0, latents.shape[0], batch_size):
imgs = self.decode_latents_sdxl(
latents[i : i + batch_size], is_fp32_vae=self.is_fp32_vae
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -20,7 +20,10 @@ from diffusers import (
HeunDiscreteScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.schedulers import (
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae,
@@ -33,6 +36,8 @@ from apps.stable_diffusion.src.utils import (
end_profiling,
)
import sys
import gc
from typing import List, Optional
SD_STATE_IDLE = "idle"
SD_STATE_CANCEL = "cancel"
@@ -50,6 +55,7 @@ class StableDiffusionPipeline:
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
@@ -60,20 +66,23 @@ class StableDiffusionPipeline:
import_mlir: bool,
use_lora: str,
ondemand: bool,
is_f32_vae: bool = False,
):
self.vae = None
self.text_encoder = None
self.text_encoder_2 = None
self.unet = None
self.unet_512 = None
self.model_max_length = 77
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
self.status = SD_STATE_IDLE
self.sd_model = sd_model
self.scheduler = scheduler
self.import_mlir = import_mlir
self.use_lora = use_lora
self.ondemand = ondemand
self.is_f32_vae = is_f32_vae
# TODO: Find a better workaround for fetching base_model_id early
# enough for CLIPTokenizer.
try:
@@ -106,6 +115,34 @@ class StableDiffusionPipeline:
del self.text_encoder
self.text_encoder = None
def load_clip_sdxl(self):
if self.text_encoder and self.text_encoder_2:
return
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. "
"Importing MLIR anyways."
)
self.text_encoder, self.text_encoder_2 = self.sd_model.sdxl_clip()
else:
try:
# TODO: Fix this for SDXL
self.text_encoder = get_clip()
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
(
self.text_encoder,
self.text_encoder_2,
) = self.sd_model.sdxl_clip()
def unload_clip_sdxl(self):
del self.text_encoder, self.text_encoder_2
self.text_encoder = None
self.text_encoder_2 = None
def load_unet(self):
if self.unet is not None:
return
@@ -159,6 +196,182 @@ class StableDiffusionPipeline:
def unload_vae(self):
del self.vae
self.vae = None
gc.collect()
def encode_prompt_sdxl(
self,
prompt: str,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
hf_model_id: Optional[
str
] = "stabilityai/stable-diffusion-xl-base-1.0",
):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Define tokenizers and text encoders
self.tokenizer_2 = get_tokenizer("tokenizer_2", hf_model_id)
self.load_clip_sdxl()
tokenizers = (
[self.tokenizer, self.tokenizer_2]
if self.tokenizer is not None
else [self.tokenizer_2]
)
text_encoders = (
[self.text_encoder, self.text_encoder_2]
if self.text_encoder is not None
else [self.text_encoder_2]
)
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt]
for prompt, tokenizer, text_encoder in zip(
prompts, tokenizers, text_encoders
):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(
prompt, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[
-1
] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(
untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
)
print(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
text_encoder_output = text_encoder("forward", (text_input_ids,))
prompt_embeds = torch.from_numpy(text_encoder_output[0])
pooled_prompt_embeds = torch.from_numpy(text_encoder_output[1])
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
# get unconditional embeddings for classifier free guidance
zero_out_negative_prompt = (
negative_prompt is None
and self.config.force_zeros_for_empty_prompt
)
if (
do_classifier_free_guidance
and negative_prompt_embeds is None
and zero_out_negative_prompt
):
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(
pooled_prompt_embeds
)
elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(
negative_prompt
):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt, negative_prompt_2]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(
uncond_tokens, tokenizers, text_encoders
):
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_encoder_output = text_encoder(
"forward", (uncond_input.input_ids,)
)
negative_prompt_embeds = torch.from_numpy(
text_encoder_output[0]
)
negative_pooled_prompt_embeds = torch.from_numpy(
text_encoder_output[1]
)
negative_prompt_embeds_list.append(negative_prompt_embeds)
negative_prompt_embeds = torch.concat(
negative_prompt_embeds_list, dim=-1
)
if self.ondemand:
self.unload_clip_sdxl()
gc.collect()
# TODO: Look into dtype for text_encoder_2!
prompt_embeds = prompt_embeds.to(dtype=torch.float16)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(
bs_embed * num_images_per_prompt, seq_len, -1
)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=torch.float32)
negative_prompt_embeds = negative_prompt_embeds.repeat(
1, num_images_per_prompt, 1
)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, -1
)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(
1, num_images_per_prompt
).view(bs_embed * num_images_per_prompt, -1)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
1, num_images_per_prompt
).view(bs_embed * num_images_per_prompt, -1)
return (
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
@@ -186,6 +399,7 @@ class StableDiffusionPipeline:
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
gc.collect()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings
@@ -298,6 +512,8 @@ class StableDiffusionPipeline:
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
gc.collect()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -306,6 +522,96 @@ class StableDiffusionPipeline:
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def produce_img_latents_sdxl(
self,
latents,
total_timesteps,
add_text_embeds,
add_time_ids,
prompt_embeds,
cpu_scheduling,
guidance_scale,
dtype,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
self.status = SD_STATE_IDLE
step_time_sum = 0
extra_step_kwargs = {"generator": None}
self.load_unet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
# expand the latents if we are doing classifier free guidance
if isinstance(latents, np.ndarray):
latents = torch.tensor(latents)
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
prompt_embeds,
add_text_embeds,
add_time_ids,
guidance_scale,
),
send_to_host=True,
)
if not isinstance(latents, torch.Tensor):
latents = torch.from_numpy(latents).to("cpu")
noise_pred = torch.from_numpy(noise_pred).to("cpu")
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
)[0]
latents = latents.detach().numpy()
noise_pred = noise_pred.detach().numpy()
step_time = (time.time() - step_start_time) * 1000
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
if self.ondemand:
self.unload_unet()
gc.collect()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
return latents
def decode_latents_sdxl(self, latents, is_fp32_vae):
# Latents are in the UNet dtype here, so cast to float32 if the VAE runs in fp32
if is_fp32_vae:
print("Casting latents to float32 for VAE")
latents = latents.to(torch.float32)
images = self.vae("forward", (latents,))
images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1)
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
return pil_images
@classmethod
def from_pretrained(
cls,
@@ -338,7 +644,8 @@ class StableDiffusionPipeline:
ondemand: bool,
low_cpu_mem_usage: bool = False,
debug: bool = False,
use_stencil: str = None,
stencils: list[str] = [],
# stencil_images: list[Image] = []
use_lora: str = "",
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
@@ -355,6 +662,7 @@ class StableDiffusionPipeline:
"OutpaintPipeline",
]
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"]
sd_model = SharkifyStableDiffusionModel(
model_id,
@@ -371,7 +679,8 @@ class StableDiffusionPipeline:
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
use_stencil=use_stencil,
is_sdxl=is_sdxl,
stencils=stencils,
use_lora=use_lora,
use_quantize=use_quantize,
)
@@ -386,6 +695,21 @@ class StableDiffusionPipeline:
ondemand,
)
if cls.__name__ == "StencilPipeline":
return cls(
scheduler, sd_model, import_mlir, use_lora, ondemand, stencils
)
if cls.__name__ == "Text2ImageSDXLPipeline":
is_fp32_vae = True if "16" not in custom_vae else False
return cls(
scheduler,
sd_model,
import_mlir,
use_lora,
ondemand,
is_fp32_vae,
)
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
# #####################################################
@@ -498,9 +822,10 @@ class StableDiffusionPipeline:
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
gc.collect()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings.numpy()
return text_embeddings.numpy().astype(np.float16)
from typing import List, Optional, Union

View File

@@ -1,4 +1,7 @@
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers

View File

@@ -1,4 +1,5 @@
from diffusers import (
LCMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
DDPMScheduler,
@@ -15,9 +16,21 @@ from diffusers import (
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
SharkEulerAncestralDiscreteScheduler,
)
def get_schedulers(model_id):
# TODO: Robust scheduler setup on pipeline creation -- if we don't
# set batch_size here, the SHARK schedulers will
# compile with batch size = 1 regardless of whether the model
# outputs latents of a larger batch size, e.g. SDXL.
# However, obviously, searching for whether the base model ID
# contains "xl" is not very robust.
batch_size = 2 if "xl" in model_id.lower() else 1
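# A hedged note: batch_size 2 matches the CFG-doubled latent batch the SDXL pipeline hands to scale_model_input, so the compiled scheduler modules see the shapes they were built for.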
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_id,
@@ -39,6 +52,10 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
@@ -84,6 +101,12 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers[
"SharkEulerAncestralDiscrete"
] = SharkEulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverSinglestep"
] = DPMSolverSinglestepScheduler.from_pretrained(
@@ -100,5 +123,6 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
schedulers["SharkEulerDiscrete"].compile(batch_size)
schedulers["SharkEulerAncestralDiscrete"].compile(batch_size)
return schedulers

View File

@@ -0,0 +1,251 @@
import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
EulerAncestralDiscreteScheduler,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.configuration_utils import register_to_config
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_shark_model,
args,
)
import torch
class SharkEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
super().__init__(
num_train_timesteps,
beta_start,
beta_end,
beta_schedule,
trained_betas,
prediction_type,
timestep_spacing,
steps_offset,
)
# TODO: make it dynamic so we don't have to worry about batch size
self.batch_size = None
self.init_input_shape = None
def compile(self, batch_size=1):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
device = args.device.split(":", 1)[0].strip()
self.batch_size = batch_size
model_input = {
"eulera": {
"output": torch.randn(
batch_size, 4, args.height // 8, args.width // 8
),
"latent": torch.randn(
batch_size, 4, args.height // 8, args.width // 8
),
"sigma": torch.tensor(1).to(torch.float32),
"sigma_from": torch.tensor(1).to(torch.float32),
"sigma_to": torch.tensor(1).to(torch.float32),
"noise": torch.randn(
batch_size, 4, args.height // 8, args.width // 8
),
},
}
example_latent = model_input["eulera"]["latent"]
example_output = model_input["eulera"]["output"]
example_noise = model_input["eulera"]["noise"]
if args.precision == "fp16":
example_latent = example_latent.half()
example_output = example_output.half()
example_noise = example_noise.half()
example_sigma = model_input["eulera"]["sigma"]
example_sigma_from = model_input["eulera"]["sigma_from"]
example_sigma_to = model_input["eulera"]["sigma_to"]
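# The three small torch modules below re-express the EulerAncestralDiscreteScheduler math (input scaling plus the epsilon and v-prediction step updates) as standalone graphs so they can be compiled through compile_through_fx and run on the target device.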
class ScalingModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, latent, sigma):
return latent / ((sigma**2 + 1) ** 0.5)
class SchedulerStepEpsilonModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(
self, noise_pred, latent, sigma, sigma_from, sigma_to, noise
):
sigma_up = (
sigma_to**2
* (sigma_from**2 - sigma_to**2)
/ sigma_from**2
) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
dt = sigma_down - sigma
pred_original_sample = latent - sigma * noise_pred
derivative = (latent - pred_original_sample) / sigma
prev_sample = latent + derivative * dt
return prev_sample + noise * sigma_up
class SchedulerStepVPredictionModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(
self, noise_pred, sigma, sigma_from, sigma_to, latent, noise
):
sigma_up = (
sigma_to**2
* (sigma_from**2 - sigma_to**2)
/ sigma_from**2
) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
dt = sigma_down - sigma
pred_original_sample = noise_pred * (
-sigma / (sigma**2 + 1) ** 0.5
) + (latent / (sigma**2 + 1))
derivative = (latent - pred_original_sample) / sigma
prev_sample = latent + derivative * dt
return prev_sample + noise * sigma_up
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
def _import(self):
scaling_model = ScalingModel()
self.scaling_model, _ = compile_through_fx(
model=scaling_model,
inputs=(example_latent, example_sigma),
extended_model_name=f"euler_a_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
pred_type_model_dict = {
"epsilon": SchedulerStepEpsilonModel(),
"v_prediction": SchedulerStepVPredictionModel(),
}
step_model = pred_type_model_dict[self.config.prediction_type]
self.step_model, _ = compile_through_fx(
step_model,
(
example_output,
example_latent,
example_sigma,
example_sigma_from,
example_sigma_to,
example_noise,
),
extended_model_name=f"euler_a_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
if args.import_mlir:
_import(self)
else:
try:
self.scaling_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_a_scale_model_input_" + args.precision,
iree_flags,
)
self.step_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_a_step_"
+ self.config.prediction_type
+ args.precision,
iree_flags,
)
except:
print(
"failed to download model, falling back and using import_mlir"
)
args.import_mlir = True
_import(self)
def scale_model_input(self, sample, timestep):
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
return self.scaling_model(
"forward",
(
sample,
sigma,
),
send_to_host=False,
)
def step(
self,
noise_pred,
timestep,
latent,
generator: Optional[torch.Generator] = None,
return_dict: Optional[bool] = False,
):
step_inputs = []
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
sigma_from = self.sigmas[self.step_index]
sigma_to = self.sigmas[self.step_index + 1]
noise = randn_tensor(
torch.Size(noise_pred.shape),
dtype=torch.float16,
device="cpu",
generator=generator,
)
step_inputs = [
noise_pred,
latent,
sigma,
sigma_from,
sigma_to,
noise,
]
# TODO: deal with dynamic inputs in turbine flow.
# update step index since we're done with the variable and will return with compiled module output.
self._step_index += 1
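# The step module was compiled for a fixed batch size (e.g. 2 for SDXL); if the incoming tensors are smaller, duplicate noise_pred, latent and noise along the batch dimension so the shapes match the compiled graph.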
if noise_pred.shape[0] < self.batch_size:
for i in [0, 1, 5]:
try:
step_inputs[i] = torch.tensor(step_inputs[i])
except:
step_inputs[i] = torch.tensor(step_inputs[i].to_host())
step_inputs[i] = torch.cat(
(step_inputs[i], step_inputs[i]), axis=0
)
return self.step_model(
"forward",
tuple(step_inputs),
send_to_host=True,
)
return self.step_model(
"forward",
tuple(step_inputs),
send_to_host=False,
)

View File

@@ -2,12 +2,9 @@ import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.configuration_utils import register_to_config
from apps.stable_diffusion.src.utils import (
compile_through_fx,
@@ -27,6 +24,13 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
interpolation_type: str = "linear",
use_karras_sigmas: bool = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
timestep_spacing: str = "linspace",
timestep_type: str = "discrete",
steps_offset: int = 0,
):
super().__init__(
num_train_timesteps,
@@ -35,20 +39,29 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
beta_schedule,
trained_betas,
prediction_type,
interpolation_type,
use_karras_sigmas,
sigma_min,
sigma_max,
timestep_spacing,
timestep_type,
steps_offset,
)
# TODO: make it dynamic so we don't have to worry about batch size
self.batch_size = None
def compile(self):
def compile(self, batch_size=1):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
BATCH_SIZE = args.batch_size
device = args.device.split(":", 1)[0].strip()
self.batch_size = batch_size
model_input = {
"euler": {
"latent": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
batch_size, 4, args.height // 8, args.width // 8
),
"output": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
batch_size, 4, args.height // 8, args.width // 8
),
"sigma": torch.tensor(1).to(torch.float32),
"dt": torch.tensor(1).to(torch.float32),
@@ -70,12 +83,32 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
def forward(self, latent, sigma):
return latent / ((sigma**2 + 1) ** 0.5)
class SchedulerStepModel(torch.nn.Module):
class SchedulerStepEpsilonModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, noise_pred, sigma_hat, latent, dt):
pred_original_sample = latent - sigma_hat * noise_pred
derivative = (latent - pred_original_sample) / sigma_hat
return latent + derivative * dt
class SchedulerStepSampleModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, noise_pred, sigma_hat, latent, dt):
pred_original_sample = noise_pred
derivative = (latent - pred_original_sample) / sigma_hat
return latent + derivative * dt
class SchedulerStepVPredictionModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, noise_pred, sigma, latent, dt):
pred_original_sample = latent - sigma * noise_pred
pred_original_sample = noise_pred * (
-sigma / (sigma**2 + 1) ** 0.5
) + (latent / (sigma**2 + 1))
derivative = (latent - pred_original_sample) / sigma
return latent + derivative * dt
@@ -90,16 +123,22 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
self.scaling_model, _ = compile_through_fx(
model=scaling_model,
inputs=(example_latent, example_sigma),
extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
extended_model_name=f"euler_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
step_model = SchedulerStepModel()
pred_type_model_dict = {
"epsilon": SchedulerStepEpsilonModel(),
"v_prediction": SchedulerStepVPredictionModel(),
"sample": SchedulerStepSampleModel(),
"original_sample": SchedulerStepSampleModel(),
}
step_model = pred_type_model_dict[self.config.prediction_type]
self.step_model, _ = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
extended_model_name=f"euler_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_"
+ args.precision,
extra_args=iree_flags,
)
@@ -109,6 +148,11 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
else:
try:
step_model_type = (
"sample"
if "sample" in self.config.prediction_type
else self.config.prediction_type
)
self.scaling_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_scale_model_input_" + args.precision,
@@ -116,7 +160,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
)
self.step_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_step_" + args.precision,
"euler_step_" + step_model_type + args.precision,
iree_flags,
)
except:
@@ -138,15 +182,57 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
send_to_host=False,
)
def step(self, noise_pred, timestep, latent):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
dt = self.sigmas[step_index + 1] - sigma
def step(
self,
noise_pred,
timestep,
latent,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: Optional[bool] = False,
):
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
gamma = (
min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
if s_tmin <= sigma <= s_tmax
else 0.0
)
sigma_hat = sigma * (gamma + 1)
noise_pred = (
torch.from_numpy(noise_pred)
if isinstance(noise_pred, np.ndarray)
else noise_pred
)
if gamma > 0:
noise = randn_tensor(
torch.Size(noise_pred.shape),
dtype=torch.float16,
device="cpu",
generator=generator,
)
eps = noise * s_noise
latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5
if self.config.prediction_type == "v_prediction":
sigma_hat = sigma
dt = self.sigmas[self.step_index + 1] - sigma_hat
return self.step_model(
"forward",
(
noise_pred,
sigma,
sigma_hat,
latent,
dt,
),

View File

@@ -8,6 +8,15 @@
"dtype":"i64"
}
},
"sdxl_clip": {
"token" : {
"shape" : [
"1*batch_size",
"max_len"
],
"dtype":"i64"
}
},
"vae_encode": {
"image" : {
"shape" : [
@@ -179,9 +188,95 @@
"shape": [2],
"dtype": "i64"
}
},
"stabilityai/sdxl-turbo": {
"latents": {
"shape": [
"2*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"prompt_embeds": {
"shape": [
"2*batch_size",
"max_len",
2048
],
"dtype": "f32"
},
"text_embeds": {
"shape": [
"2*batch_size",
1280
],
"dtype": "f32"
},
"time_ids": {
"shape": [
"2*batch_size",
6
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 1,
"dtype": "f32"
}
},
"stabilityai/stable-diffusion-xl-base-1.0": {
"latents": {
"shape": [
"2*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"prompt_embeds": {
"shape": [
"2*batch_size",
"max_len",
2048
],
"dtype": "f32"
},
"text_embeds": {
"shape": [
"2*batch_size",
1280
],
"dtype": "f32"
},
"time_ids": {
"shape": [
"2*batch_size",
6
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 1,
"dtype": "f32"
}
}
},
"stencil_adaptor": {
"stencil_adapter": {
"latents": {
"shape": [
"1*batch_size",
@@ -208,6 +303,58 @@
"controlnet_hint": {
"shape": [1, 3, "8*height", "8*width"],
"dtype": "f32"
},
"acc1": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"acc2": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"acc3": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"acc4": {
"shape": [2, 320, "height/2", "width/2"],
"dtype": "f32"
},
"acc5": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"acc6": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"acc7": {
"shape": [2, 640, "height/4", "width/4"],
"dtype": "f32"
},
"acc8": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"acc9": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"acc10": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"acc11": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"acc12": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"acc13": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
}
},
"stencil_unet": {
@@ -290,7 +437,59 @@
"control13": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"scale1": {
"shape": 1,
"dtype": "f32"
},
"scale2": {
"shape": 1,
"dtype": "f32"
},
"scale3": {
"shape": 1,
"dtype": "f32"
},
"scale4": {
"shape": 1,
"dtype": "f32"
},
"scale5": {
"shape": 1,
"dtype": "f32"
},
"scale6": {
"shape": 1,
"dtype": "f32"
},
"scale7": {
"shape": 1,
"dtype": "f32"
},
"scale8": {
"shape": 1,
"dtype": "f32"
},
"scale9": {
"shape": 1,
"dtype": "f32"
},
"scale10": {
"shape": 1,
"dtype": "f32"
},
"scale11": {
"shape": 1,
"dtype": "f32"
},
"scale12": {
"shape": 1,
"dtype": "f32"
},
"scale13": {
"shape": 1,
"dtype": "f32"
}
}
}
}
}

View File

@@ -59,24 +59,28 @@
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
"--iree-opt-data-tiling=False"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
"--iree-opt-data-tiling=False"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
"--iree-opt-data-tiling=False"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
"--iree-opt-data-tiling=False"
]
}
}

View File

@@ -1,4 +1,5 @@
[["A high tech solarpunk utopia in the Amazon rainforest"],
["Astrophotography, the shark nebula, nebula with a tiny shark-like cloud in the middle in the middle, hubble telescope, vivid colors"],
["A pikachu fine dining with a view to the Eiffel Tower"],
["A mecha robot in a favela in expressionist style"],
["an insect robot preparing a delicious meal"],

View File

@@ -85,7 +85,7 @@ p.add_argument(
"--height",
type=int,
default=512,
choices=range(128, 769, 8),
choices=range(128, 1025, 8),
help="The height of the output image.",
)
@@ -93,7 +93,7 @@ p.add_argument(
"--width",
type=int,
default=512,
choices=range(128, 769, 8),
choices=range(128, 1025, 8),
help="The width of the output image.",
)
@@ -420,6 +420,13 @@ p.add_argument(
help="Enable the stencil feature.",
)
p.add_argument(
"--control_mode",
choices=["Prompt", "Balanced", "Controlnet"],
default="Balanced",
help="How Controlnet injection should be prioritized.",
)
p.add_argument(
"--use_lora",
type=str,
@@ -460,6 +467,13 @@ p.add_argument(
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)
p.add_argument(
"--autogen",
type=bool,
default="False",
help="Only used for a gradio workaround.",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################
@@ -587,6 +601,13 @@ p.add_argument(
help="Controls constant folding in iree-compile for all SD models.",
)
p.add_argument(
"--data_tiling",
default=False,
action=argparse.BooleanOptionalAction,
help="Controls data tiling in iree-compile for all SD models.",
)
##############################################################################
# Web UI flags
##############################################################################

View File

@@ -1,6 +1,10 @@
import numpy as np
from PIL import Image
import torch
import os
from pathlib import Path
import torchvision
import time
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
@@ -10,6 +14,33 @@ from apps.stable_diffusion.src.utils.stencils import (
stencil = {}
def save_img(img):
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)
subdir = Path(
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
)
os.makedirs(subdir, exist_ok=True)
if isinstance(img, Image.Image):
img.save(
os.path.join(
subdir, "controlnet_" + str(int(time.time())) + ".png"
)
)
elif isinstance(img, np.ndarray):
img = Image.fromarray(img)
img.save(os.path.join(subdir, str(int(time.time())) + ".png"))
else:
converter = torchvision.transforms.ToPILImage()
for i in img:
converter(i).save(
os.path.join(subdir, str(int(time.time())) + ".png")
)
def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
@@ -48,10 +79,12 @@ def controlnet_hint_shaping(
)
return controlnet_hint
else:
raise ValueError(
f"Acceptble shape of `stencil` are any of ({channels}, {height}, {width}),"
+ f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
+ f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
return controlnet_hint_shaping(
Image.fromarray(controlnet_hint.detach().numpy()),
height,
width,
dtype,
num_images_per_prompt,
)
elif isinstance(controlnet_hint, np.ndarray):
# np.ndarray: acceptable shape is any of hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_prompt)
@@ -78,29 +111,36 @@ def controlnet_hint_shaping(
) # b h w c -> b c h w
return controlnet_hint
else:
raise ValueError(
f"Acceptble shape of `stencil` are any of ({width}, {channels}), "
+ f"({height}, {width}, {channels}), "
+ f"(1, {height}, {width}, {channels}) or "
+ f"({num_images_per_prompt}, {channels}, {height}, {width}) but is {controlnet_hint.shape}"
)
elif isinstance(controlnet_hint, Image.Image):
if controlnet_hint.size == (width, height):
controlnet_hint = controlnet_hint.convert(
"RGB"
) # make sure 3 channel RGB format
controlnet_hint = np.array(controlnet_hint) # to numpy
controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
return controlnet_hint_shaping(
controlnet_hint, height, width, num_images_per_prompt
Image.fromarray(controlnet_hint),
height,
width,
dtype,
num_images_per_prompt,
)
elif isinstance(controlnet_hint, Image.Image):
controlnet_hint = controlnet_hint.convert(
"RGB"
) # make sure 3 channel RGB format
if controlnet_hint.size == (width, height):
controlnet_hint = np.array(controlnet_hint).astype(
np.float16
) # to numpy
controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
return
else:
raise ValueError(
f"Acceptable image size of `stencil` is ({width}, {height}) but is {controlnet_hint.size}"
)
(hint_w, hint_h) = controlnet_hint.size
left = int((hint_w - width) / 2)
right = left + height
controlnet_hint = controlnet_hint.crop((left, 0, right, hint_h))
controlnet_hint = controlnet_hint.resize((width, height))
return controlnet_hint_shaping(
controlnet_hint, height, width, dtype, num_images_per_prompt
)
else:
raise ValueError(
f"Acceptable type of `stencil` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
f"Acceptible controlnet input types are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
)
@@ -110,16 +150,22 @@ def controlnet_hint_conversion(
controlnet_hint = None
match use_stencil:
case "canny":
print("Detecting edge with canny")
print(
"Converting controlnet hint to edge detection mask with canny preprocessor."
)
controlnet_hint = hint_canny(image)
case "openpose":
print("Detecting human pose")
print(
"Detecting human pose in controlnet hint with openpose preprocessor."
)
controlnet_hint = hint_openpose(image)
case "scribble":
print("Working with scribble")
print("Using your scribble as a controlnet hint.")
controlnet_hint = hint_scribble(image)
case "zoedepth":
print("Working with ZoeDepth")
print(
"Converting controlnet hint to a depth mapping with ZoeDepth."
)
controlnet_hint = hint_zoedepth(image)
case _:
return None
@@ -161,6 +207,7 @@ def hint_canny(
detected_map = stencil["canny"](
input_image, low_threshold, high_threshold
)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map
@@ -176,6 +223,7 @@ def hint_openpose(
stencil["openpose"] = OpenposeDetector()
detected_map, _ = stencil["openpose"](input_image)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map
@@ -187,6 +235,7 @@ def hint_scribble(image: Image.Image):
detected_map = np.zeros_like(input_image, dtype=np.uint8)
detected_map[np.min(input_image, axis=2) < 127] = 255
save_img(detected_map)
return detected_map
@@ -199,5 +248,6 @@ def hint_zoedepth(image: Image.Image):
stencil["depth"] = ZoeDetector()
detected_map = stencil["depth"](input_image)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map
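The hint_* helpers above all go through the module-level stencil dict so that each detector (CannyDetector, OpenposeDetector, ZoeDetector) is constructed once per process and reused. The same lazy-cache pattern in isolation (get_detector is an illustrative name, not a function in this diff):

_detectors = {}

def get_detector(name, factory):
    # construct the preprocessor on first use, reuse it afterwards
    if name not in _detectors:
        _detectors[name] = factory()
    return _detectors[name]

# usage, assuming the detector classes imported at the top of the file:
# edges = get_detector("canny", CannyDetector)(input_image, 100, 200)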

View File

@@ -30,9 +30,15 @@ class ZoeDetector:
pretrained=False,
force_reload=False,
)
model.load_state_dict(
torch.load(modelpath, map_location=model.device)["model"]
)
# Hack to fix the ZoeDepth import issue
model_keys = model.state_dict().keys()
loaded_dict = torch.load(modelpath, map_location=model.device)["model"]
loaded_keys = loaded_dict.keys()
for key in loaded_keys - model_keys:
loaded_dict.pop(key)
model.load_state_dict(loaded_dict)
model.eval()
self.model = model
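The workaround above drops checkpoint entries whose keys the model does not declare before calling load_state_dict, which avoids unexpected-key errors from the upstream ZoeDepth checkpoint. The same idea as a standalone sketch, with torch.nn.Linear standing in for the real model and the checkpoint built inline:

import torch

model = torch.nn.Linear(4, 2)

# pretend checkpoint: the model's own weights plus one stale extra entry
loaded_dict = dict(model.state_dict())
loaded_dict["stale.extra.key"] = torch.zeros(1)

# keep only keys the model actually declares, then load as usual
model_keys = model.state_dict().keys()
for key in loaded_dict.keys() - model_keys:
    loaded_dict.pop(key)
model.load_state_dict(loaded_dict)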

View File

@@ -118,7 +118,7 @@ def compile_through_fx(
is_f16=False,
f16_input_mask=None,
use_tuned=False,
save_dir=tempfile.gettempdir(),
save_dir="",
debug=False,
generate_vmfb=True,
extra_args=None,
@@ -541,6 +541,8 @@ def get_opt_flags(model, precision="fp16"):
iree_flags.append(
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
)
if args.data_tiling == False:
iree_flags.append("--iree-opt-data-tiling=False")
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
iree_flags += opt_flags[model][is_tuned][precision][
@@ -563,6 +565,10 @@ def get_opt_flags(model, precision="fp16"):
iree_flags += opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
][device]
if "vae" not in model:
# Due to lack of support for multi-reduce, we always collapse reduction
# dims before dispatch formation right now.
iree_flags += ["--iree-flow-collapse-reduction-dims"]
return iree_flags
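Two flag decisions above are data-dependent: data tiling stays disabled unless the new --data_tiling option is set, and every non-VAE submodel gets --iree-flow-collapse-reduction-dims until multi-reduce support lands. A hedged sketch of just that branching, with a stand-in args namespace instead of the project's parsed arguments:

from types import SimpleNamespace

args = SimpleNamespace(data_tiling=False)

def opt_flags(model):
    iree_flags = []
    if not args.data_tiling:
        iree_flags.append("--iree-opt-data-tiling=False")
    if "vae" not in model:
        # multi-reduce is not supported yet, so collapse reduction dims
        # before dispatch formation
        iree_flags.append("--iree-flow-collapse-reduction-dims")
    return iree_flags

print(opt_flags("unet"))  # both flags
print(opt_flags("vae"))   # only the data-tiling flag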

View File

@@ -19,6 +19,9 @@ a = Analysis(
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
module_collection_mode={
'gradio': 'py', # Collect gradio package as source .py files
},
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

View File

@@ -75,11 +75,11 @@ if __name__ == "__main__":
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
from apps.stable_diffusion.web.utils.gradio_configs import (
config_gradio_tmp_imgs_folder,
from apps.stable_diffusion.web.utils.tmp_configs import (
config_tmp,
)
config_gradio_tmp_imgs_folder()
config_tmp()
import gradio as gr
# Create custom models folders if they don't exist
@@ -97,8 +97,6 @@ if __name__ == "__main__":
)
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_custom_model,
@@ -109,6 +107,16 @@ if __name__ == "__main__":
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
# SDXL
txt2img_sdxl_web,
txt2img_sdxl_custom_model,
txt2img_sdxl_gallery,
txt2img_sdxl_png_info_img,
txt2img_sdxl_status,
txt2img_sdxl_sendto_img2img,
txt2img_sdxl_sendto_inpaint,
txt2img_sdxl_sendto_outpaint,
txt2img_sdxl_sendto_upscaler,
# h2ogpt_upload,
# h2ogpt_web,
img2img_web,
@@ -145,7 +153,7 @@ if __name__ == "__main__":
upscaler_sendto_outpaint,
# lora_train_web,
# model_web,
# model_config_web,
model_config_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
@@ -159,6 +167,7 @@ if __name__ == "__main__":
outputgallery_watch,
outputgallery_filename,
outputgallery_sendto_txt2img,
outputgallery_sendto_txt2img_sdxl,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,
@@ -172,7 +181,7 @@ if __name__ == "__main__":
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
gr.Tabs(selected=selectedid),
),
inputs,
outputs,
@@ -183,7 +192,7 @@ if __name__ == "__main__":
lambda x: (
"None",
x,
gr.Tabs.update(selected=selectedid),
gr.Tabs(selected=selectedid),
),
inputs,
outputs,
@@ -193,12 +202,14 @@ if __name__ == "__main__":
button.click(
lambda x: (
x,
gr.Tabs.update(selected=selectedid),
gr.Tabs(selected=selectedid),
),
inputs,
outputs,
)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="SHARK AI Studio"
) as sd_web:
@@ -235,6 +246,7 @@ if __name__ == "__main__":
inpaint_status,
outpaint_status,
upscaler_status,
txt2img_sdxl_status,
]
)
# with gr.TabItem(label="Model Manager", id=6):
@@ -243,16 +255,18 @@ if __name__ == "__main__":
# lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=8):
stablelm_chat.render()
# with gr.TabItem(
# label="Generate Sharding Config (Experimental)", id=9
# ):
# model_config_web.render()
with gr.TabItem(label="MultiModal (Experimental)", id=10):
minigpt4_web.render()
# with gr.TabItem(
# label="Generate Sharding Config (Experimental)", id=9
# ):
# model_config_web.render()
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
# minigpt4_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
with gr.TabItem(label="Text-to-Image (SDXL)", id=13):
txt2img_sdxl_web.render()
actual_port = app.usable_port()
if actual_port != args.server_port:
@@ -391,6 +405,12 @@ if __name__ == "__main__":
[outputgallery_filename],
[upscaler_init_image, tabs],
)
register_outputgallery_button(
outputgallery_sendto_txt2img_sdxl,
0,
[outputgallery_filename],
[txt2img_sdxl_png_info_img, tabs],
)
register_modelmanager_button(
modelmanager_sendto_txt2img,
0,

View File

@@ -10,6 +10,18 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.txt2img_sdxl_ui import (
txt2img_sdxl_inf,
txt2img_sdxl_web,
txt2img_sdxl_custom_model,
txt2img_sdxl_gallery,
txt2img_sdxl_status,
txt2img_sdxl_png_info_img,
txt2img_sdxl_sendto_img2img,
txt2img_sdxl_sendto_inpaint,
txt2img_sdxl_sendto_outpaint,
txt2img_sdxl_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_inf,
img2img_web,
@@ -76,6 +88,7 @@ from apps.stable_diffusion.web.ui.outputgallery_ui import (
outputgallery_watch,
outputgallery_filename,
outputgallery_sendto_txt2img,
outputgallery_sendto_txt2img_sdxl,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,

View File

@@ -0,0 +1,55 @@
from apps.stable_diffusion.web.ui.utils import (
HSLHue,
hsl_color,
get_lora_metadata,
)
# Returns HTML showing the most frequent tags used when a LoRA was trained,
# taken from the metadata of its .safetensors file.
def lora_changed(lora_file):
# tag frequency ratio at which a tag gets the full starting hue
TAG_COLOR_THRESHOLD = 0.55
# tag frequency percentage, above which a tag is displayed
TAG_DISPLAY_THRESHOLD = 0.65
# template for the html used to display a tag
TAG_HTML_TEMPLATE = '<span class="lora-tag" style="border: 1px solid {color};">{tag}</span>'
if lora_file == "None":
return ["<div><i>No LoRA selected</i></div>"]
elif not lora_file.lower().endswith(".safetensors"):
return [
"<div><i>Only metadata queries for .safetensors files are currently supported</i></div>"
]
else:
metadata = get_lora_metadata(lora_file)
if metadata:
frequencies = metadata["frequencies"]
return [
"".join(
[
f'<div class="lora-model">Trained against weights in: {metadata["model"]}</div>'
]
+ [
TAG_HTML_TEMPLATE.format(
color=hsl_color(
(tag[1] - TAG_COLOR_THRESHOLD)
/ (1 - TAG_COLOR_THRESHOLD),
start=HSLHue.RED,
end=HSLHue.GREEN,
),
tag=tag[0],
)
for tag in frequencies
if tag[1] > TAG_DISPLAY_THRESHOLD
],
)
]
elif metadata is None:
return [
"<div><i>This LoRA does not publish tag frequency metadata</i></div>"
]
else:
return [
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
]
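The two thresholds above mean a tag is rendered only when its frequency ratio exceeds 0.65, and its border colour is interpolated from red towards green starting at 0.55. A worked sketch of that mapping; RED and GREEN mirror the HSLHue values, and the formula matches the hsl_color helper added to web/ui/utils.py later in this changeset:

import math

RED, GREEN = 0, 120
TAG_COLOR_THRESHOLD = 0.55

def hsl_color(alpha, start, end):
    # same formula as the utils.py helper in this changeset
    b = (end - start) * (alpha if alpha > 0 else 0)
    return f"hsl({math.floor(b + start)}, 80%, 35%)"

for freq in (0.55, 0.80, 1.0):
    alpha = (freq - TAG_COLOR_THRESHOLD) / (1 - TAG_COLOR_THRESHOLD)
    print(freq, hsl_color(alpha, start=RED, end=GREEN))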

View File

@@ -105,6 +105,18 @@ body {
background-color: var(--background-fill-primary);
}
.generating.svelte-zlszon.svelte-zlszon {
border: none;
}
.generating {
border: none !important;
}
#chatbot {
height: 100% !important;
}
/* display in full width for desktop devices */
@media (min-width: 1536px)
{
@@ -246,10 +258,39 @@ footer {
background-color: var(--block-label-background-fill);
}
/* lora tag pills */
.lora-tags {
border: 1px solid var(--border-color-primary);
color: var(--block-info-text-color) !important;
padding: var(--block-padding);
}
.lora-tag {
display: inline-block;
height: 2em;
color: rgb(212 212 212) !important;
margin-right: 5pt;
margin-bottom: 5pt;
padding: 2pt 5pt;
border-radius: 5pt;
white-space: nowrap;
}
.lora-model {
margin-bottom: var(--spacing-lg);
color: var(--block-info-text-color) !important;
line-height: var(--line-sm);
}
/* output gallery tab */
.output_parameters_dataframe table.table {
/* works around a gradio bug that always shows scrollbars */
overflow: clip auto;
}
.output_parameters_dataframe tbody td {
font-size: small;
line-height: var(--line-xs)
line-height: var(--line-xs);
}
.output_icon_button {

View File

@@ -5,6 +5,13 @@ import gradio as gr
import PIL
from math import ceil
from PIL import Image
from gradio.components.image_editor import (
Brush,
Eraser,
EditorData,
EditorValue,
)
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -14,6 +21,7 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
@@ -29,6 +37,11 @@ from apps.stable_diffusion.src.utils import (
get_generation_text_info,
resampler_list,
)
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
ZoeDetector,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
import numpy as np
@@ -58,7 +71,6 @@ def img2img_inf(
precision: str,
device: str,
max_length: int,
use_stencil: str,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
@@ -66,6 +78,9 @@ def img2img_inf(
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
control_mode: str,
stencils: list,
images: list,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -87,14 +102,23 @@ def img2img_inf(
args.img_path = "not none"
args.ondemand = ondemand
if image_dict is None:
for i, stencil in enumerate(stencils):
if images[i] is None and stencil is not None:
return
if images[i] is not None:
if isinstance(images[i], dict):
images[i] = images[i]["composite"]
images[i] = images[i].convert("RGB")
if image_dict is None and images[0] is None:
return None, "An Initial Image is required"
if use_stencil == "scribble":
image = image_dict["mask"].convert("RGB")
elif isinstance(image_dict, PIL.Image.Image):
if isinstance(image_dict, PIL.Image.Image):
image = image_dict.convert("RGB")
else:
elif image_dict:
image = image_dict["image"].convert("RGB")
else:
# TODO: enable t2i + controlnets
image = None
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
@@ -121,10 +145,11 @@ def img2img_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
use_stencil = None if use_stencil == "None" else use_stencil
args.use_stencil = use_stencil
if use_stencil is not None:
args.scheduler = "DDIM"
stencil_count = 0
for stencil in stencils:
if stencil is not None:
stencil_count += 1
if stencil_count > 0:
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, width, height = resize_stencil(image)
elif "Shark" in args.scheduler:
@@ -148,7 +173,7 @@ def img2img_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=use_stencil,
stencils=stencils,
ondemand=ondemand,
)
if (
@@ -170,12 +195,12 @@ def img2img_inf(
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
else "stabilityai/stable-diffusion-1-5-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(args.scheduler)
if use_stencil is not None:
if stencil_count > 0:
args.use_tuned = False
global_obj.set_sd_obj(
StencilPipeline.from_pretrained(
@@ -192,7 +217,7 @@ def img2img_inf(
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
stencils=stencils,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
@@ -249,8 +274,10 @@ def img2img_inf(
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
stencils,
images,
resample_type=resample_type,
control_mode=control_mode,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
@@ -270,12 +297,17 @@ def img2img_inf(
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Image-to-Image", current_batch + 1, batch_count, batch_size
)
), stencils, images
return generated_imgs, text_output, ""
return generated_imgs, text_output, "", stencils, images
with gr.Blocks(title="Image-to-Image") as img2img_web:
# Stencils
# TODO: Add more stencils here
STENCIL_COUNT = 2
stencils = gr.State([None] * STENCIL_COUNT)
images = gr.State([None] * STENCIL_COUNT)
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
@@ -340,75 +372,282 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
# TODO: make this import image prompt info if it exists
img2img_init_image = gr.Image(
label="Input Image",
source="upload",
tool="sketch",
type="pil",
height=300,
height=512,
interactive=True,
)
with gr.Accordion(label="Stencil Options", open=False):
with gr.Accordion(label="Multistencil Options", open=False):
choices = [
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
]
def cnet_preview(
model, input_image, index, stencils, images
):
images[index] = input_image
stencils[index] = model
match model:
case "canny":
canny = CannyDetector()
result = canny(
np.array(input_image["composite"]),
100,
200,
)
return (
Image.fromarray(result),
stencils,
images,
)
case "openpose":
openpose = OpenposeDetector()
result = openpose(
np.array(input_image["composite"])
)
print(result)
# TODO: This is just an empty canvas, need to draw the candidates (which are in result[1])
return (
Image.fromarray(result[0]),
stencils,
images,
)
case "zoedepth":
zoedepth = ZoeDetector()
result = zoedepth(
np.array(input_image["composite"])
)
return (
Image.fromarray(result),
stencils,
images,
)
case "scribble":
return (
input_image["composite"],
stencils,
images,
)
case _:
return (None, stencils, images)
def create_canvas(width, height):
data = Image.fromarray(
np.zeros(
shape=(height, width, 3),
dtype=np.uint8,
)
+ 255
)
img_dict = {
"background": data,
"layers": [data],
"composite": None,
}
return EditorValue(img_dict)
def update_cn_input(model, width, height):
if model == "scribble":
return [
gr.ImageEditor(
visible=True,
interactive=True,
show_label=False,
image_mode="RGB",
type="pil",
value=create_canvas(width, height),
brush=Brush(
colors=["#000000"], color_mode="fixed"
),
),
gr.Image(
visible=True,
show_label=False,
interactive=False,
show_download_button=False,
),
gr.Slider(visible=True),
gr.Slider(visible=True),
gr.Button(visible=True),
]
else:
return [
gr.ImageEditor(
visible=True,
image_mode="RGB",
type="pil",
interactive=True,
value=None,
),
gr.Image(
visible=True,
show_label=False,
interactive=True,
show_download_button=False,
),
gr.Slider(visible=False),
gr.Slider(visible=False),
gr.Button(visible=False),
]
with gr.Row():
use_stencil = gr.Dropdown(
elem_id="stencil_model",
label="Stencil model",
value="None",
choices=[
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
with gr.Column():
cnet_1 = gr.Button(
value="Generate controlnet input"
)
cnet_1_model = gr.Dropdown(
label="Controlnet 1",
value="None",
choices=choices,
)
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
make_canvas = gr.Button(
value="Make Canvas!",
visible=False,
)
cnet_1_image = gr.ImageEditor(
visible=False,
image_mode="RGB",
interactive=True,
show_label=False,
type="pil",
)
cnet_1_output = gr.Image(
visible=True, show_label=False
)
cnet_1_model.input(
update_cn_input,
[cnet_1_model, canvas_width, canvas_height],
[
cnet_1_image,
cnet_1_output,
canvas_width,
canvas_height,
make_canvas,
],
)
def show_canvas(choice):
if choice == "scribble":
return (
gr.Slider.update(visible=True),
gr.Slider.update(visible=True),
gr.Button.update(visible=True),
)
else:
return (
gr.Slider.update(visible=False),
gr.Slider.update(visible=False),
gr.Button.update(visible=False),
)
def create_canvas(w, h):
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
make_canvas.click(
update_cn_input,
[cnet_1_model, canvas_width, canvas_height],
[
cnet_1_image,
cnet_1_output,
canvas_width,
canvas_height,
make_canvas,
],
)
cnet_1.click(
fn=(
lambda a, b, s, i: cnet_preview(a, b, 0, s, i)
),
inputs=[
cnet_1_model,
cnet_1_image,
stencils,
images,
],
outputs=[cnet_1_output, stencils, images],
)
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=1,
with gr.Column():
cnet_2 = gr.Button(
value="Generate controlnet input"
)
cnet_2_model = gr.Dropdown(
label="Controlnet 2",
value="None",
choices=choices,
)
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
make_canvas = gr.Button(
value="Make Canvas!",
visible=False,
)
cnet_2_image = gr.ImageEditor(
visible=False,
image_mode="RGB",
interactive=True,
show_label=False,
type="pil",
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
cnet_2_output = gr.Image(
visible=True, show_label=False
)
create_button = gr.Button(
label="Start",
value="Open drawing canvas!",
visible=False,
)
create_button.click(
fn=create_canvas,
inputs=[canvas_width, canvas_height],
outputs=[img2img_init_image],
)
use_stencil.change(
fn=show_canvas,
inputs=use_stencil,
outputs=[canvas_width, canvas_height, create_button],
cnet_2_model.select(
update_cn_input,
[cnet_2_model, canvas_width, canvas_height],
[
cnet_2_image,
cnet_2_output,
canvas_width,
canvas_height,
make_canvas,
],
)
make_canvas.click(
update_cn_input,
[cnet_2_model, canvas_width, canvas_height],
[
cnet_2_image,
cnet_2_output,
canvas_width,
canvas_height,
make_canvas,
],
)
cnet_2.click(
fn=(
lambda a, b, s, i: cnet_preview(a, b, 1, s, i)
),
inputs=[
cnet_2_model,
cnet_2_image,
stencils,
images,
],
outputs=[cnet_2_output, stencils, images],
)
control_mode = gr.Radio(
choices=["Prompt", "Balanced", "Controlnet"],
value="Balanced",
label="Control Mode",
)
with gr.Accordion(label="LoRA Options", open=False):
@@ -436,6 +675,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -610,7 +854,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
precision,
device,
max_length,
use_stencil,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
@@ -618,8 +861,17 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
ondemand,
repeatable_seeds,
resample_type,
control_mode,
stencils,
images,
],
outputs=[
img2img_gallery,
std_output,
img2img_status,
stencils,
images,
],
outputs=[img2img_gallery, std_output, img2img_status],
show_progress="minimal" if args.progress_bar else "none",
)
@@ -638,3 +890,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)
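Each controlnet column above pins its slot index into cnet_preview with a per-column lambda (index 0 for the first column, 1 for the second). A hedged alternative sketch using a small closure factory, which keeps the Gradio inputs list identical for every column; bind_slot is an illustrative name, not code from this diff:

def bind_slot(index, handler):
    # returns a 4-argument callable suitable for gr.Button.click(fn=...)
    return lambda model, image, stencils, images: handler(
        model, image, index, stencils, images
    )

# cnet_1.click(fn=bind_slot(0, cnet_preview), inputs=[...], outputs=[...])
# cnet_2.click(fn=bind_slot(1, cnet_preview), inputs=[...], outputs=[...])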

View File

@@ -4,6 +4,7 @@ import time
import sys
import gradio as gr
from PIL import Image
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -13,6 +14,7 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_paint_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
@@ -120,7 +122,7 @@ def inpaint_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=None,
stencils=[],
ondemand=ondemand,
)
if (
@@ -288,8 +290,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
inpaint_init_image = gr.Image(
label="Masked Image",
source="upload",
tool="sketch",
sources="upload",
type="pil",
height=350,
)
@@ -319,6 +320,11 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -518,3 +524,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -104,7 +104,6 @@ with gr.Blocks() as model_web:
civit_models = gr.Gallery(
label="Civitai Model Gallery",
value=None,
interactive=True,
visible=False,
)

View File

@@ -3,9 +3,8 @@ import torch
import time
import gradio as gr
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -122,7 +121,7 @@ def outpaint_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=None,
stencils=[],
ondemand=ondemand,
)
if (
@@ -323,6 +322,11 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -546,3 +550,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -91,11 +91,11 @@ with gr.Blocks() as outputgallery_web:
value=gallery_files.value,
visible=False,
show_label=True,
columns=2,
columns=4,
)
with gr.Column(scale=4):
with gr.Box():
with gr.Group():
with gr.Row():
with gr.Column(
scale=15,
@@ -152,6 +152,7 @@ with gr.Blocks() as outputgallery_web:
wrap=True,
elem_classes="output_parameters_dataframe",
value=[["Status", "No image selected"]],
interactive=True,
)
with gr.Accordion(label="Send To", open=True):
@@ -162,6 +163,12 @@ with gr.Blocks() as outputgallery_web:
elem_classes="outputgallery_sendto",
size="sm",
)
outputgallery_sendto_txt2img_sdxl = gr.Button(
value="Txt2Img XL",
interactive=False,
elem_classes="outputgallery_sendto",
size="sm",
)
outputgallery_sendto_img2img = gr.Button(
value="Img2Img",
@@ -195,15 +202,18 @@ with gr.Blocks() as outputgallery_web:
def on_clear_gallery():
return [
gr.Gallery.update(
gr.Gallery(
value=[],
visible=False,
),
gr.Image.update(
gr.Image(
visible=True,
),
]
def on_image_columns_change(columns):
return gr.Gallery(columns=columns)
def on_select_subdir(subdir) -> list:
# evt.value is the subdirectory name
new_images = outputgallery_filenames(subdir)
@@ -212,12 +222,12 @@ with gr.Blocks() as outputgallery_web:
)
return [
new_images,
gr.Gallery.update(
gr.Gallery(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image.update(
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
@@ -251,16 +261,16 @@ with gr.Blocks() as outputgallery_web:
)
return [
gr.Dropdown.update(
gr.Dropdown(
choices=refreshed_subdirs,
value=new_subdir,
),
refreshed_subdirs,
new_images,
gr.Gallery.update(
gr.Gallery(
value=new_images, label=new_label, visible=len(new_images) > 0
),
gr.Image.update(
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
@@ -286,12 +296,12 @@ with gr.Blocks() as outputgallery_web:
return [
new_images,
gr.Gallery.update(
gr.Gallery(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image.update(
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
@@ -329,12 +339,12 @@ with gr.Blocks() as outputgallery_web:
return [
# disable or enable each of the sendto button based on whether
# an image is selected
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button(interactive=exists),
gr.Button(interactive=exists),
gr.Button(interactive=exists),
gr.Button(interactive=exists),
gr.Button(interactive=exists),
gr.Button(interactive=exists),
]
# The first time our tab is selected we need to do an initial refresh
@@ -365,53 +375,6 @@ with gr.Blocks() as outputgallery_web:
gr.update(),
)
# Unfortunately as of gradio 3.34.0 gr.update against Galleries doesn't
# support things set with .style, nor the elem_classes kwarg, so we have
# to directly set things up via JavaScript if we want the client to take
# notice of our changes to the number of columns after it decides to put
# them back to the original number when we change something
def js_set_columns_in_browser(timeout_length):
return f"""
(new_cols) => {{
setTimeout(() => {{
required_style = "auto ".repeat(new_cols).trim();
gallery = document.querySelector('#outputgallery_gallery .grid-container');
if (gallery) {{
gallery.style.gridTemplateColumns = required_style
}}
}}, {timeout_length});
return []; // prevents console error from gradio
}}
"""
# --- Wire handlers up to the actions
# Many actions reset the number of columns shown in the gallery on the
# browser end, so we have to set them back to what we think they should
# be after the initial action.
#
# None of the actions on this tab trigger inference, and we want the
# user to be able to do them whilst other tabs have ongoing inference
# running. Waiting in the queue behind inference jobs would mean the UI
# can't fully respond until the inference tasks complete,
# hence queue=False on all of these.
set_gallery_columns_immediate = dict(
fn=None,
inputs=[image_columns],
# gradio blanks the UI on Chrome on Linux on gallery select if
# I don't put an output here
outputs=[dev_null],
_js=js_set_columns_in_browser(0),
queue=False,
)
# setting columns after selecting a gallery item needs a real
# timeout length for the number of columns to actually be applied.
# Not really sure why, maybe something has to finish animating?
set_gallery_columns_delayed = dict(
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
)
# clearing images when we need to completely change what's in the
# gallery avoids current images being shown replacing piecemeal and
# prevents weirdness and errors if the user selects an image during the
@@ -423,38 +386,42 @@ with gr.Blocks() as outputgallery_web:
queue=False,
)
image_columns.change(**set_gallery_columns_immediate)
subdirectories.select(**clear_gallery).then(
on_select_subdir,
[subdirectories],
[gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)
open_subdir.click(
on_open_subdir, inputs=[subdirectories], queue=False
).then(**set_gallery_columns_immediate)
open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False)
refresh.click(**clear_gallery).then(
on_refresh,
[subdirectories],
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)
image_columns.change(
fn=on_image_columns_change,
inputs=[image_columns],
outputs=[gallery],
queue=False,
)
gallery.select(
on_select_image,
[gallery_files],
[outputgallery_filename, image_parameters],
queue=False,
).then(**set_gallery_columns_delayed)
)
outputgallery_filename.change(
on_outputgallery_filename_change,
[outputgallery_filename],
[
outputgallery_sendto_txt2img,
outputgallery_sendto_txt2img_sdxl,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,
@@ -477,7 +444,7 @@ with gr.Blocks() as outputgallery_web:
open_subdir,
],
queue=False,
).then(**set_gallery_columns_immediate)
)
# We should have been passed a list of components on other tabs that update
# when a new image has generated on that tab, so set things up so the user
@@ -489,4 +456,4 @@ with gr.Blocks() as outputgallery_web:
inputs=[subdirectories, subdirectory_paths, component],
outputs=[gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)
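Most edits in this file follow one Gradio 4 migration rule: event handlers no longer return gr.Component.update(...) dictionaries but fresh component instances carrying only the properties to change. A minimal before/after sketch, assuming the Gradio 4.x pinned elsewhere in this changeset:

import gradio as gr

# Gradio 3.x style (removed above):
# return gr.Gallery.update(value=[], visible=False)

# Gradio 4.x style (added above): return a component instance instead
def on_clear():
    return gr.Gallery(value=[], visible=False)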

View File

@@ -6,6 +6,7 @@ from transformers import (
AutoModelForCausalLM,
)
from apps.stable_diffusion.web.ui.utils import available_devices
from shark.iree_utils.compile_utils import clean_device_info
from datetime import datetime as dt
import json
import sys
@@ -132,27 +133,6 @@ def get_default_config():
c.split_into_layers()
def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by LLM pipeline
# Multiple devices only supported for vulkan and rocm (as of now).
# default device must be selected for all others
device_id = None
device = (
raw_device
if "=>" not in raw_device
else raw_device.split("=>")[1].strip()
)
if "://" in device:
device, device_id = device.split("://")
device_id = int(device_id) # using device index in webui
if device not in ["rocm", "vulkan"]:
device_id = None
return device, device_id
model_vmfb_key = ""
@@ -193,6 +173,13 @@ def chat(
get_vulkan_target_triple,
)
_extra_args = _extra_args + [
"--iree-global-opt-enable-quantized-matmul-reassociation",
"--iree-llvmcpu-enable-quantized-matmul-reassociation",
"--iree-opt-const-eval=false",
"--iree-opt-data-tiling=false",
]
if device == "vulkan":
vulkaninfo_list = get_all_vulkan_devices()
if vulkan_target_triple == "":
@@ -270,10 +257,11 @@ def chat(
total_time_ms = 0.001 # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, msg, exec_time in progress.tqdm(
vicuna_model.generate(prompt, cli=cli),
desc="generating response",
):
# for text, msg, exec_time in progress.tqdm(
# vicuna_model.generate(prompt, cli=cli),
# desc="generating response",
# ):
for text, msg, exec_time in vicuna_model.generate(prompt, cli=cli):
if msg is None:
if is_first:
prefill_time = exec_time
@@ -451,12 +439,12 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
config_file = gr.File(
label="Upload sharding configuration", visible=False
)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
json_view_button = gr.Button(value="View as JSON", visible=False)
json_view = gr.JSON(visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
chatbot = gr.Chatbot(height=500)
chatbot = gr.Chatbot(elem_id="chatbot")
with gr.Row():
with gr.Column():
msg = gr.Textbox(

View File

@@ -0,0 +1,649 @@
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
from math import ceil
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
predefined_sdxl_models,
cancel_sd,
set_model_default_configs,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
Text2ImageSDXLPipeline,
get_schedulers,
set_init_device_flags,
utils,
save_output_img,
prompt_examples,
Image2ImagePipeline,
)
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
)
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platform = args.iree_metal_target_platform
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
def txt2img_sdxl_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
if precision != "fp16":
print("currently we support fp16 for SDXL")
precision = "fp16"
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
# .safetensor or .chkpt on the custom model path
if model_id in get_custom_model_files():
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = model_id
if custom_vae:
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"txt2img_sdxl",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
stencils=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platform = init_iree_metal_target_platform
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-xl-base-1.0"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
if global_obj.get_cfg_obj().ondemand:
print("Running txt2img in memory efficient mode.")
global_obj.set_sd_obj(
Text2ImageSDXLPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=precision,
max_length=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=global_obj.get_cfg_obj().ondemand,
)
)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
text_output = ""
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None
for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], seeds[current_batch])
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Text-to-Image-SDXL",
current_batch + 1,
batch_count,
batch_size,
)
return generated_imgs, text_output, ""
theme = gr.themes.Glass(
primary_hue="slate",
secondary_hue="gray",
)
with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
t2i_sdxl_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
txt2img_sdxl_custom_model = gr.Dropdown(
label=f"Models",
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-xl-base-1.0",
choices=predefined_sdxl_models
+ get_custom_model_files(
custom_checkpoint_type="sdxl"
),
allow_custom_value=True,
scale=2,
)
t2i_sdxl_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
t2i_sdxl_vae_info = (
f"VAE Path: {t2i_sdxl_vae_info}"
)
custom_vae = gr.Dropdown(
label=f"VAE Models",
info=t2i_sdxl_vae_info,
elem_id="custom_model",
value="None",
choices=[
None,
"madebyollin/sdxl-vae-fp16-fix",
]
+ get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Column(scale=1, min_width=170):
txt2img_sdxl_png_info_img = gr.Image(
label="Import PNG info",
elem_id="txt2img_prompt_image",
type="pil",
visible=True,
)
with gr.Group(elem_id="prompt_box_outer"):
txt2img_sdxl_autogen = gr.Checkbox(
label="Auto-Generate Images",
value=False,
visible=False,
)
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=2,
elem_id="prompt_box",
show_copy_button=True,
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=2,
elem_id="negative_prompt_box",
show_copy_button=True,
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
t2i_sdxl_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
t2i_sdxl_lora_info = f"LoRA Path: {t2i_sdxl_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=t2i_sdxl_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value=args.scheduler,
choices=[
"DDIM",
"EulerAncestralDiscrete",
"EulerDiscrete",
"LCMScheduler",
],
allow_custom_value=False,
visible=True,
)
with gr.Column():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
interactive=True,
)
with gr.Row():
height = gr.Slider(
512,
1024,
value=1024,
step=256,
label="Height",
visible=True,
interactive=True,
)
width = gr.Slider(
512,
1024,
value=1024,
step=256,
label="Width",
visible=True,
interactive=True,
)
precision = gr.Radio(
label="Precision",
value="fp16",
choices=[
"fp16",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=77,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
with gr.Column(scale=3):
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="Guidance Scale",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=False,
visible=False,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
txt2img_sdxl_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
columns=[2],
object_fit="scale_down",
)
std_output = gr.Textbox(
value=f"{t2i_sdxl_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
)
txt2img_sdxl_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
txt2img_sdxl_sendto_img2img = gr.Button(
value="Send To Img2Img",
visible=False,
)
txt2img_sdxl_sendto_inpaint = gr.Button(
value="Send To Inpaint",
visible=False,
)
txt2img_sdxl_sendto_outpaint = gr.Button(
value="Send To Outpaint",
visible=False,
)
txt2img_sdxl_sendto_upscaler = gr.Button(
value="Send To Upscaler",
visible=False,
)
kwargs = dict(
fn=txt2img_sdxl_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
txt2img_sdxl_custom_model,
custom_vae,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
ondemand,
repeatable_seeds,
],
outputs=[txt2img_sdxl_gallery, std_output, txt2img_sdxl_status],
show_progress="minimal" if args.progress_bar else "none",
queue=True,
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Text-to-Image-SDXL", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=txt2img_sdxl_status,
concurrency_limit=1,
)
def autogen_changed(checked):
if checked:
args.autogen = True
else:
args.autogen = False
def check_last_input(prompt):
if not prompt.endswith(" "):
return True
elif not args.autogen:
return True
else:
return False
auto_gen_kwargs = dict(
fn=check_last_input,
inputs=[negative_prompt],
outputs=[txt2img_sdxl_status],
concurrency_limit=1,
)
txt2img_sdxl_autogen.change(
fn=autogen_changed,
inputs=[txt2img_sdxl_autogen],
outputs=None,
)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[
prompt_submit,
neg_prompt_submit,
generate_click,
],
)
txt2img_sdxl_png_info_img.change(
fn=import_png_metadata,
inputs=[
txt2img_sdxl_png_info_img,
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
seed,
width,
height,
txt2img_sdxl_custom_model,
lora_weights,
lora_hf_id,
custom_vae,
],
outputs=[
txt2img_sdxl_png_info_img,
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
seed,
width,
height,
txt2img_sdxl_custom_model,
lora_weights,
lora_hf_id,
custom_vae,
],
)
txt2img_sdxl_custom_model.change(
fn=set_model_default_configs,
inputs=[
txt2img_sdxl_custom_model,
],
outputs=[
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
width,
height,
custom_vae,
txt2img_sdxl_autogen,
],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)
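txt2img_sdxl_inf above rebuilds the SDXL pipeline only when the Config dataclass it assembles differs from the one cached in global_obj, so repeated generations with unchanged settings reuse the compiled pipeline. A hedged sketch of that caching idea in isolation (fields trimmed, and build_pipeline stands in for Text2ImageSDXLPipeline.from_pretrained):

from dataclasses import dataclass

@dataclass
class Config:
    model_id: str
    precision: str
    height: int
    width: int

_cached_cfg = None
_cached_pipe = None

def get_pipeline(cfg, build_pipeline):
    global _cached_cfg, _cached_pipe
    # dataclasses compare field by field, so any changed setting misses the cache
    if _cached_pipe is None or cfg != _cached_cfg:
        _cached_pipe = build_pipeline(cfg)
        _cached_cfg = cfg
    return _cached_pipe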

View File

@@ -5,6 +5,7 @@ import sys
import gradio as gr
from PIL import Image
from math import ceil
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -15,6 +16,7 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
@@ -124,7 +126,7 @@ def txt2img_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=None,
stencils=[],
ondemand=ondemand,
)
if (
@@ -224,7 +226,7 @@ def txt2img_inf(
width,
device,
use_lora=args.use_lora,
use_stencil="None",
stencils=[],
ondemand=ondemand,
)
@@ -278,7 +280,8 @@ def txt2img_inf(
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil="None",
stencils=[],
control_mode=None,
resample_type=resample_type,
)
total_time = time.time() - start_time
@@ -300,7 +303,17 @@ def txt2img_inf(
return generated_imgs, text_output, ""
with gr.Blocks(title="Text-to-Image") as txt2img_web:
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
@@ -354,7 +367,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Import PNG info",
elem_id="txt2img_prompt_image",
type="pil",
tool="None",
visible=True,
)
@@ -365,6 +377,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
lines=2,
elem_id="prompt_box",
)
# TODO: coming soon
autogen = gr.Checkbox(
label="Continuous Generation",
visible=False,
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
@@ -396,6 +413,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -673,12 +695,12 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
# SharkEulerDiscrete doesn't work with img2img which hires_fix uses
def set_compatible_schedulers(hires_fix_selected):
if hires_fix_selected:
return gr.Dropdown.update(
return gr.Dropdown(
choices=scheduler_list_cpu_only,
value="DEISMultistep",
)
else:
return gr.Dropdown.update(
return gr.Dropdown(
choices=scheduler_list,
value="SharkEulerDiscrete",
)
@@ -689,3 +711,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
outputs=[scheduler],
queue=False,
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -3,6 +3,7 @@ import torch
import time
import gradio as gr
from PIL import Image
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -12,6 +13,7 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_upscaler_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
@@ -118,7 +120,7 @@ def upscaler_inf(
args.width,
device,
use_lora=args.use_lora,
use_stencil=None,
stencils=[],
ondemand=ondemand,
)
if (
@@ -340,6 +342,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -537,3 +544,10 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -1,10 +1,17 @@
import os
import sys
from apps.stable_diffusion.src import get_available_devices
import glob
import math
import json
import safetensors
import gradio as gr
from pathlib import Path
from apps.stable_diffusion.src import args
from dataclasses import dataclass
from enum import IntEnum
from apps.stable_diffusion.src import get_available_devices
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
@@ -24,10 +31,19 @@ class Config:
width: int
device: str
use_lora: str
use_stencil: str
stencils: list[str]
ondemand: str # should this be expecting a bool instead?
class HSLHue(IntEnum):
RED = 0
YELLOW = 60
GREEN = 120
CYAN = 180
BLUE = 240
MAGENTA = 300
custom_model_filetypes = (
"*.ckpt",
"*.safetensors",
@@ -49,9 +65,11 @@ scheduler_list_cpu_only = [
"DPMSolverSinglestep",
"DDPM",
"HeunDiscrete",
"LCMScheduler",
]
scheduler_list = scheduler_list_cpu_only + [
"SharkEulerDiscrete",
"SharkEulerAncestralDiscrete",
]
predefined_models = [
@@ -72,6 +90,10 @@ predefined_paint_models = [
predefined_upscaler_models = [
"stabilityai/stable-diffusion-x4-upscaler",
]
predefined_sdxl_models = [
"stabilityai/sdxl-turbo",
"stabilityai/stable-diffusion-xl-base-1.0",
]
def resource_path(relative_path):
@@ -125,6 +147,12 @@ def get_custom_model_files(model="models", custom_checkpoint_type=""):
)
]
match custom_checkpoint_type:
case "sdxl":
files = [
val
for val in files
if any(x in val for x in ["XL", "xl", "Xl"])
]
case "inpainting":
files = [
val
@@ -161,6 +189,69 @@ def get_custom_vae_or_lora_weights(weights, hf_id, model):
return use_weight
def hsl_color(alpha: float, start, end):
b = (end - start) * (alpha if alpha > 0 else 0)
result = b + start
# Return a CSS HSL string
return f"hsl({math.floor(result)}, 80%, 35%)"
def get_lora_metadata(lora_filename):
# get the metadata from the file
filename = get_custom_model_pathfile(lora_filename, "lora")
with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
metadata = f.metadata()
# guard clause for if there isn't any metadata
if not metadata:
return None
# metadata is a dictionary of strings, the values of the keys we're
# interested in are actually json, and need to be loaded as such
tag_frequencies = json.loads(metadata.get("ss_tag_frequency", str("{}")))
dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", str("{}")))
tag_dirs = [dir for dir in tag_frequencies.keys()]
# gather the tag frequency information for all the datasets trained
all_frequencies = {}
for dataset in tag_dirs:
frequencies = sorted(
[entry for entry in tag_frequencies[dataset].items()],
reverse=True,
key=lambda x: x[1],
)
# get a figure for the total number of images processed for this dataset
# either the img_count listed in its dataset_dirs entry, or the highest
# tag frequency count if that entry doesn't exist
img_count = dataset_dirs.get(dataset, {}).get(
"img_count", frequencies[0][1]
)
# add the dataset frequencies to the overall frequencies replacing the
# frequency counts on the tags with a percentage/ratio
all_frequencies.update(
[(entry[0], entry[1] / img_count) for entry in frequencies]
)
trained_model_id = " ".join(
[
metadata.get("ss_sd_model_hash", ""),
metadata.get("ss_sd_model_name", ""),
metadata.get("ss_base_model_version", ""),
]
).strip()
# return the trained-model id and all tag frequencies across datasets, sorted by ratio
return {
"model": trained_model_id,
"frequencies": sorted(
all_frequencies.items(), reverse=True, key=lambda x: x[1]
),
}
def cancel_sd():
# Try catch it, as gc can delete global_obj.sd_obj while switching model
try:
@@ -169,6 +260,99 @@ def cancel_sd():
pass
def set_model_default_configs(model_ckpt_or_id, jsonconfig=None):
import gradio as gr
config_modelname = default_config_exists(model_ckpt_or_id)
if jsonconfig:
return get_config_from_json(jsonconfig)
elif config_modelname:
return default_configs[config_modelname]
# TODO: Use HF metadata to setup pipeline if available
# elif is_valid_hf_id(model_ckpt_or_id):
# return get_HF_default_configs(model_ckpt_or_id)
else:
# We don't have default metadata to set up a good config. Do not change configs.
return [
gr.Textbox(label="Prompt", interactive=True, visible=True),
gr.Textbox(label="Negative Prompt", interactive=True),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.Checkbox(
label="Auto-Generate",
visible=False,
interactive=False,
value=False,
),
]
def get_config_from_json(model_ckpt_or_id, jsonconfig):
# TODO: make this work properly. It is currently not user-exposed.
cfgdata = json.load(jsonconfig)
return [
cfgdata["prompt_box_behavior"],
cfgdata["neg_prompt_box_behavior"],
cfgdata["steps"],
cfgdata["scheduler"],
cfgdata["guidance_scale"],
cfgdata["width"],
cfgdata["height"],
cfgdata["custom_vae"],
]
def default_config_exists(model_ckpt_or_id):
if model_ckpt_or_id in [
"stabilityai/sdxl-turbo",
"stabilityai/stable_diffusion-xl-base-1.0",
]:
return model_ckpt_or_id
elif "turbo" in model_ckpt_or_id.lower():
return "stabilityai/sdxl-turbo"
else:
return None
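A small sketch of how the lookup above resolves checkpoint names; the custom checkpoint filename below is made up for illustration.
```python
# Illustrative only; the custom checkpoint name is hypothetical.
default_config_exists("stabilityai/sdxl-turbo")          # -> "stabilityai/sdxl-turbo"
default_config_exists("some_TURBO_merge.safetensors")    # -> "stabilityai/sdxl-turbo" (matched on "turbo")
default_config_exists("runwayml/stable-diffusion-v1-5")  # -> None (no default config)
```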
default_configs = {
"stabilityai/sdxl-turbo": [
gr.Textbox(label="", interactive=False, value=None, visible=False),
gr.Textbox(
label="Prompt",
value="masterpiece, a graceful shark leaping out of the water to catch a fish, eclipsing the sunset, epic, rays of light, silhouette",
),
gr.Slider(0, 10, value=2),
gr.Dropdown(value="EulerAncestralDiscrete"),
gr.Slider(0, value=0),
512,
512,
"madebyollin/sdxl-vae-fp16-fix",
gr.Checkbox(
label="Auto-Generate", visible=False, interactive=True, value=False
),
],
"stabilityai/stable-diffusion-xl-base-1.0": [
gr.Textbox(label="Prompt", interactive=True, visible=True),
gr.Textbox(label="Negative Prompt", interactive=True),
40,
"EulerDiscrete",
7.5,
gr.Slider(value=768, interactive=True),
gr.Slider(value=768, interactive=True),
"madebyollin/sdxl-vae-fp16-fix",
gr.Checkbox(
label="Auto-Generate",
visible=False,
interactive=False,
value=False,
),
],
}
nodlogo_loc = resource_path("logos/nod-logo.png")
nodicon_loc = resource_path("logos/nod-icon.png")
available_devices = get_available_devices()

View File

@@ -5,11 +5,25 @@ from time import time
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
def config_gradio_tmp_imgs_folder():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
def clear_tmp_mlir():
cleanup_start = time()
print(
"Clearing .mlir temporary files from a prior run. This may take some time..."
)
mlir_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.endswith(".mlir")
]
for filename in mlir_files:
os.remove(shark_tmp + filename)
print(
f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds."
)
def clear_tmp_imgs():
# tell gradio to use a directory under shark_tmp for its temporary
# image files unless somewhere else has been set
if "GRADIO_TEMP_DIR" not in os.environ:
@@ -52,3 +66,12 @@ def config_gradio_tmp_imgs_folder():
)
else:
print("No temporary images files to clear.")
def config_tmp():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
clear_tmp_mlir()
clear_tmp_imgs()
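A minimal usage sketch, assuming this helper is called once at web-app startup.
```python
# Sketch: called once at startup. Creates ./shark_tmp if needed, clears stale
# .mlir files from prior runs, then configures/prunes Gradio's temp image folder.
config_tmp()
```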

View File

@@ -78,7 +78,10 @@ def test_loop(
os.mkdir("./test_images/golden")
get_inpaint_inputs()
hf_model_names = model_config_dicts[0].values()
tuned_options = ["--no-use_tuned", "--use_tuned"]
tuned_options = [
"--no-use_tuned",
"--use_tuned",
]
import_options = ["--import_mlir", "--no-import_mlir"]
prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
@@ -112,6 +115,8 @@ def test_loop(
and use_tune == tuned_options[1]
):
continue
elif use_tune == tuned_options[1]:
continue
command = (
[
executable, # executable is the python from the venv used to run this

View File

@@ -22,33 +22,33 @@ This does mean however, that on a brand new fresh install of SHARK that has not
* Make sure you have suitable drivers for your graphics card installed. See the prerequisties section of the [README](https://github.com/nod-ai/SHARK#readme).
* Download the latest SHARK studio .exe from [here](https://github.com/nod-ai/SHARK/releases) or follow the instructions in the [README](https://github.com/nod-ai/SHARK#readme) for an advanced, Linux or Mac install.
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK's default of `8080`, also include both the `--api_cors_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See below if you want to run SHARK on a different port.)
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK's default of `8080`, also include both the `--api_accept_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See below if you want to run SHARK on a different port.)
```powershell
## Run the .exe in API mode, with CORS support, on the A1111 endpoint port:
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_cors_origin="*" --server_port=7860
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_accept_origin="*" --server_port=7860
## Run from the base directory of a source clone of SHARK on Windows:
.\setup_venv.ps1
python .\apps\stable_diffusion\web\index.py --api --api_cors_origin="*" --server_port=7860
python .\apps\stable_diffusion\web\index.py --api --api_accept_origin="*" --server_port=7860
## Run from the base directory of a source clone of SHARK on Linux:
./setup_venv.sh
source shark.venv/bin/activate
python ./apps/stable_diffusion/web/index.py --api --api_cors_origin="*" --server_port=7860
python ./apps/stable_diffusion/web/index.py --api --api_accept_origin="*" --server_port=7860
## An example giving improved performance on AMD cards using vulkan, that runs on the same port as A1111
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860
## The api respects most applicable SHARK command line arguments for options not specified here
## or not yet implemented by the API, so there may be others you want to set, as listed in `--help`
.\node_ai_shark_studio_20320901_2525.exe --help
## For instance, the example above, but with a custom VAE specified
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
## An example with multiple specific CORS origins
python apps/stable_diffusion/web/index.py --api --api_cors_origin="koboldcpp.example.com:7001" --api_cors_origin="koboldcpp.example.com:7002" --server_port=7860
python apps/stable_diffusion/web/index.py --api --api_accept_origin="koboldcpp.example.com:7001" --api_accept_origin="koboldcpp.example.com:7002" --server_port=7860
```
SHARK should start in server mode, and you should see something like this:

View File

@@ -26,7 +26,7 @@ sacremoses
sentencepiece
# web dependencies.
gradio
gradio==3.44.3
altair
scipy

View File

@@ -26,7 +26,7 @@ diffusers
accelerate
scipy
ftfy
gradio==3.44.3
gradio==4.7.1
altair
omegaconf
# 0.3.2 doesn't have binaries for arm64
@@ -50,4 +50,8 @@ pefile
pyinstaller
# vicuna quantization
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
# For quantized GPTQ models
optimum
auto_gptq

View File

@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --pre torch-mlir torchvision torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html

View File

@@ -111,7 +111,7 @@ else
fi
if [[ -z "${NO_BACKEND}" ]]; then
echo "Installing ${RUNTIME}..."
$PYTHON -m pip install --pre --upgrade --find-links ${RUNTIME} iree-compiler iree-runtime
$PYTHON -m pip install --pre --upgrade --no-index --find-links ${RUNTIME} iree-compiler iree-runtime
else
echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
fi

View File

@@ -31,60 +31,64 @@ from .benchmark_utils import *
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
print("Configuring for device:" + device)
device_uri = device.split("://")
if len(device_uri) > 1:
if device_uri[0] not in ["vulkan", "rocm"]:
print(
f"Specific device selection only supported for vulkan and rocm."
f"Proceeding with {device} as device."
)
# device_uri can be device_num or device_path.
# assuming the number of devices for a single driver will not be >99
if len(device_uri[1]) <= 2:
# expected to be device index in range 0 - 99
device_num = int(device_uri[1])
else:
# expected to be device path
device_num = device_uri[1]
else:
device_num = 0
device, device_num = clean_device_info(device)
if "cpu" in device:
from shark.iree_utils.cpu_utils import get_iree_cpu_args
data_tiling_flag = ["--iree-opt-data-tiling"]
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
u_kernel_flag = ["--iree-llvmcpu-enable-ukernels"]
stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]
return (
get_iree_cpu_args()
+ data_tiling_flag
+ u_kernel_flag
+ stack_size_flag
+ ["--iree-global-opt-enable-quantized-matmul-reassociation"]
)
if device_uri[0] == "cuda":
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
return get_iree_gpu_args()
if device_uri[0] == "vulkan":
if device == "vulkan":
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "metal":
if device == "metal":
from shark.iree_utils.metal_utils import get_iree_metal_args
return get_iree_metal_args(extra_args=extra_args)
if device_uri[0] == "rocm":
if device == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
return []
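A hedged usage sketch: the calls below only show which branch a given device string takes; the backend-specific helpers they reach query real hardware, so the exact flags vary by machine.
```python
# Illustrative device strings; the returned flags depend on the local CPU/GPU.
cpu_flags = get_iree_device_args("cpu")        # CPU branch: data tiling, ukernels, stack-size flags
vulkan_flags = get_iree_device_args("vulkan://0", extra_args=[])  # vulkan flags for device index 0
rocm_flags = get_iree_device_args("rocm")      # rocm flags for the default device (index 0)
```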
def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by Studio pipeline
# Multiple devices only supported for vulkan and rocm (as of now).
# default device must be selected for all others
device_id = None
device = (
raw_device
if "=>" not in raw_device
else raw_device.split("=>")[1].strip()
)
if "://" in device:
device, device_id = device.split("://")
if len(device_id) <= 2:
device_id = int(device_id)
if device not in ["rocm", "vulkan"]:
device_id = None
if device in ["rocm", "vulkan"] and device_id == None:
device_id = 0
return device, device_id
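A quick sketch of the normalization `clean_device_info` performs; the device names are illustrative.
```python
# Examples of the (device, device_id) pairs produced above; names are illustrative.
clean_device_info("vulkan://1")                          # ("vulkan", 1)
clean_device_info("AMD Radeon RX 7900 XTX => rocm://0")  # ("rocm", 0)
clean_device_info("vulkan")                              # ("vulkan", 0)  -- defaults to index 0
clean_device_info("cuda")                                # ("cuda", None) -- no per-device selection
```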
# Get the iree-compiler arguments given frontend.
def get_iree_frontend_args(frontend):
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
@@ -351,11 +355,15 @@ def get_iree_module(
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)
hal_device_id = haldriver.query_available_devices()[device_idx][
"device_id"
]
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
hal_device_id,
allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
config.id = hal_device_id
else:
config = get_iree_runtime_config(device)
vm_module = ireert.VmModule.from_buffer(
@@ -394,15 +402,16 @@ def load_vmfb_using_mmap(
haldriver = ireert.get_driver(device)
dl.log(f"ireert.get_driver()")
hal_device_id = haldriver.query_available_devices()[device_idx][
"device_id"
]
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
hal_device_id,
allocators=shark_args.device_allocator,
)
dl.log(f"ireert.create_device()")
config = ireert.Config(device=haldevice)
config.id = haldriver.query_available_devices()[device_idx][
"device_id"
]
config.id = hal_device_id
dl.log(f"ireert.Config()")
else:
config = get_iree_runtime_config(device)

View File

@@ -95,6 +95,7 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
print("could not execute `iree-run-module --dump_devices=rocm`")
if dump_device_info is not None:
device_num = 0 if device_num is None else device_num
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0])
if len(device_arch_pairs) > device_num: # can find arch in the list
arch_in_device_dump = device_arch_pairs[device_num][1]
@@ -103,7 +104,7 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
print(f"Found ROCm device arch : {arch_in_device_dump}")
return arch_in_device_dump
default_rocm_arch = "gfx_1100"
default_rocm_arch = "gfx1100"
print(
"Did not find ROCm architecture from `--iree-rocm-target-chip` flag"
"\n or from `iree-run-module --dump_devices=rocm` command."

View File

@@ -38,15 +38,24 @@ def get_all_vulkan_devices():
@functools.cache
def get_vulkan_device_name(device_num=0):
vulkaninfo_list = get_all_vulkan_devices()
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
if len(vulkaninfo_list) > 1:
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing device: {vulkaninfo_list[device_num]}")
return vulkaninfo_list[device_num]
if isinstance(device_num, int):
vulkaninfo_list = get_all_vulkan_devices()
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
if len(vulkaninfo_list) > 1:
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing device: vulkan://{device_num}")
vulkan_device_name = vulkaninfo_list[device_num]
else:
from iree.runtime import get_driver
vulkan_device_driver = get_driver(device_num)
vulkan_device_name = vulkan_device_driver.query_available_devices()[0]
print(vulkan_device_name)
return vulkan_device_name
def get_os_name():
@@ -174,6 +183,9 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
res_vulkan_flag = []
res_vulkan_flag += [
"--iree-stream-resource-max-allocation-size=3221225472"
]
vulkan_triple_flag = None
for arg in extra_args:
if "-iree-vulkan-target-triple=" in arg:
@@ -195,7 +207,9 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
@functools.cache
def get_iree_vulkan_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_validation_layers={'true' if shark_args.vulkan_validation_layers else 'false'}",
f"--vulkan_validation_layers={'true' if shark_args.vulkan_debug_utils else 'false'}",
f"--vulkan_debug_verbosity={'4' if shark_args.vulkan_debug_utils else '0'}"
f"--vulkan-robust-buffer-access={'true' if shark_args.vulkan_debug_utils else 'false'}",
]
return vulkan_runtime_flags

View File

@@ -7,7 +7,7 @@ import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from typing import List, Tuple
from io import BytesIO
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
@@ -84,7 +84,7 @@ def compile_int_precision(
weight_quant_type="asym",
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,

View File

@@ -800,15 +800,17 @@ def save_mlir(
model_name,
mlir_dialect="linalg",
frontend="torch",
dir=tempfile.gettempdir(),
dir="",
):
model_name_mlir = (
model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
)
if dir == "":
dir = tempfile.gettempdir()
dir = os.path.join(".", "shark_tmp")
mlir_path = os.path.join(dir, model_name_mlir)
print(f"saving {model_name_mlir} to {dir}")
if not os.path.exists(dir):
os.makedirs(dir)
if frontend == "torch":
with open(mlir_path, "wb") as mlir_file:
mlir_file.write(mlir_module)

View File

@@ -1,21 +1,19 @@
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails during iree-compile.",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,False,False,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github.com/nod-ai/SHARK/issues/343,https://github.com/nod-ai/SHARK/issues/1487","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,True,True,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,True,True,"Numerics issues, awaiting cuda-independent fp16 integration",""
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,True,True,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.","macos"
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"

View File

@@ -50,7 +50,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
is_decompose = row[5]
tracing_required = False if tracing_required == "False" else True
is_dynamic = False if is_dynamic == "False" else True
is_dynamic = False
print("generating artifacts for: " + torch_model_name)
model = None
input = None
@@ -104,7 +104,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
model_name=torch_model_name,
mlir_type=mlir_type,
is_dynamic=False,
tracing_required=tracing_required,
tracing_required=True,
)
else:
mlir_importer = SharkImporter(
@@ -114,7 +114,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
)
mlir_importer.import_debug(
is_dynamic=False,
tracing_required=tracing_required,
tracing_required=True,
dir=torch_model_dir,
model_name=torch_model_name,
mlir_type=mlir_type,
@@ -123,7 +123,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
if is_dynamic:
mlir_importer.import_debug(
is_dynamic=True,
tracing_required=tracing_required,
tracing_required=True,
dir=torch_model_dir,
model_name=torch_model_name + "_dynamic",
mlir_type=mlir_type,