Compare commits

..

53 Commits

Author SHA1 Message Date
Ean Garvey
ee0233e370 Fix formatting. 2023-11-13 20:01:28 -06:00
Daniel Garvey
a3deeec870 Dan shark studio (#1970)
* Fix issue in Falcon-GPTQ

* initial webui and llama2

---------

Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2023-11-13 19:07:28 -06:00
Stefan Kapusniak
c2163488d8 SD/UI Restrict hires fix/img2img resamplers/schedulers (#1955)
* Restrict resamplers for img2img and high res fix to the ones that
PIL.Image actually supports, since it uses that to do the resampling.
Removed: Antialias, Affine, Cubic. Added: Hamming.
* Set the list of available schedulers to CPU-only when high res fix
is selected in the web UI. Set the list to all schedulers when high res fix
is deselected.
* Put the hires fix in its own Accordion in the txt2img UI instead of
grouping it with Advanced Options.
2023-11-13 16:08:24 -06:00
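As a quick way to see which resamplers the restriction above allows, one can enumerate the filters Pillow itself exposes. This is a minimal sketch assuming Pillow >= 9.1, where the filters are available as the `Image.Resampling` enum:

```
from PIL import Image

# The resampling filters PIL actually implements; anything offered in the UI
# outside this set cannot be used for img2img / high res fix resizing.
print(sorted(member.name for member in Image.Resampling))
```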
PhaneeshB
54bff4611d fix cli rocm device selection 2023-11-13 23:35:55 +05:30
PhaneeshB
11510d5111 add intra rocm vmfb differentiator 2023-11-13 23:35:55 +05:30
PhaneeshB
32cab73a29 add iree-rocm-target-chip only if added by user 2023-11-13 23:35:55 +05:30
PhaneeshB
392bade0bf enable non default rocm device selection for webui 2023-11-13 23:35:55 +05:30
Stefan Kapusniak
91df5f0613 API/Docs: Fix an image link in koboldcpp doc (#1954)
* Fix the image link for the koboldcpp style button pointing to the
dialog image rather than the button image.
2023-11-13 11:14:29 -06:00
dependabot[bot]
df20cf9c8a Bump langchain in /apps/language_models/langchain (#1968)
Bumps [langchain](https://github.com/langchain-ai/langchain) from 0.0.325 to 0.0.329.
- [Release notes](https://github.com/langchain-ai/langchain/releases)
- [Commits](https://github.com/langchain-ai/langchain/compare/v0.0.325...v0.0.329)

---
updated-dependencies:
- dependency-name: langchain
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-12 19:46:00 -08:00
Ean Garvey
c4a908c3ea Pin pydantic to 2.4.1 in requirements (#1967)
pyinstaller-hooks-contrib doesn't treat beta releases of pydantic as greater than 2.0.0, so if you have a beta installed it looks for a `compile` attribute that only exists in pydantic versions older than 2.0.0.
2023-11-10 21:34:52 -06:00
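For context on why a pre-release can trip a version gate like the one described above, here is a small sketch using the `packaging` library (an assumption; the actual check inside pyinstaller-hooks-contrib may differ): PEP 440 orders a beta above 2.0.0, but specifier matching excludes pre-releases by default.

```
from packaging.specifiers import SpecifierSet
from packaging.version import Version

spec = SpecifierSet(">=2.0.0")
print(Version("2.5.0b1") > Version("2.0.0"))        # True: PEP 440 ordering
print(spec.contains("2.5.0b1"))                     # False: pre-releases excluded by default
print(spec.contains("2.5.0b1", prereleases=True))   # True once pre-releases are allowed
```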
Stefan Kapusniak
6285430d8a UI: Fix webui launch on non-Windows (#1963)
* Moves the imports of winreg and Tk into the functions that use them,
with winreg behind a guard clause. This should hopefully mean that if
you're not on Windows or not using `ui=app` we won't trip over either
of these due to them not being there.
2023-11-10 16:38:32 -06:00
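A minimal sketch of the guard-clause import pattern described above (names are illustrative, not the actual SHARK code): the Windows-only module is imported inside the function, behind a platform check, so merely importing the module on Linux/macOS never fails.

```
import sys

def read_registry_value(path, name):
    # Only reach for winreg on Windows; on other platforms report "not found".
    if sys.platform != "win32":
        return None
    import winreg  # guarded import: safe because we only get here on Windows

    with winreg.OpenKey(winreg.HKEY_CURRENT_USER, path) as key:
        value, _type = winreg.QueryValueEx(key, name)
        return value
```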
PhaneeshB
51afe19e20 fix rocm arch selection 2023-11-10 13:22:51 +05:30
Ean Garvey
31005bcf73 Don't require vulkan installation to query devices. (#1953) 2023-11-09 14:46:44 -06:00
dependabot[bot]
f41ad87ef6 Bump langchain in /apps/language_models/langchain (#1926)
Bumps [langchain](https://github.com/langchain-ai/langchain) from 0.0.202 to 0.0.325.
- [Release notes](https://github.com/langchain-ai/langchain/releases)
- [Commits](https://github.com/langchain-ai/langchain/compare/v0.0.202...v0.0.325)

---
updated-dependencies:
- dependency-name: langchain
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-09 11:03:47 -06:00
dependabot[bot]
d811524a00 Bump pypdf from 3.12.2 to 3.17.0 in /apps/language_models/langchain (#1929)
Bumps [pypdf](https://github.com/py-pdf/pypdf) from 3.12.2 to 3.17.0.
- [Release notes](https://github.com/py-pdf/pypdf/releases)
- [Changelog](https://github.com/py-pdf/pypdf/blob/main/CHANGELOG.md)
- [Commits](https://github.com/py-pdf/pypdf/compare/3.12.2...3.17.0)

---
updated-dependencies:
- dependency-name: pypdf
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-09 11:02:43 -06:00
Sungsoon Cho
51e1bd1c5d (OPT) Fix typo in the message; s/reponse/response (#1920) 2023-11-09 11:00:48 -06:00
Phaneesh Barwaria
db89b1bdc1 Fix MacOS web execution flow (#1899)
* fix metal device path for chatbot

* single device remove indexing

* lint fix
2023-11-09 10:59:29 -06:00
Huang Qi
2754e2e257 Fix wrong parameter index passed to 'compile_module_to_flatbuffer' (#1921)
compile_str is always False in compile_module_to_flatbuffer since there
is a parameter 'model_name' before 'debug'.

This issue is related to https://github.com/nod-ai/SHARK/pull/1863.

With this fix we can use the MLIR model buffer in RAM to run inference.
2023-11-09 10:58:05 -06:00
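To illustrate the failure mode described above, here is a hedged sketch (the signature is illustrative, not the real SHARK one) of how inserting a new positional parameter before `debug` silently shifts a caller's arguments so that `compile_str` keeps its default:

```
def compile_module_to_flatbuffer(module, device, frontend, model_name,
                                 debug=False, compile_str=False):
    # Illustrative stand-in: just report what compile_str ended up being.
    return compile_str

# A call site written before `model_name` existed, passing everything positionally:
print(compile_module_to_flatbuffer("module", "cpu", "torch", False, True))
# -> False: the intended compile_str=True landed in `debug`,
#    and compile_str silently kept its default value.
```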
PhaneeshB
ab0e870c43 fix vicuna cli vulkan 2023-11-09 22:27:13 +05:30
Stefan Kapusniak
fb30e8c226 UI: Fix some webui launch corner cases (#1952)
* On Windows, insist on the presence of webview2 as the embeddable
browser for `ui=app`. If we can't find it, effectively switch back to
`ui=web`. This should prevent pywebview trying to use MSHTML whilst
saying it's deprecated; apparently we are too much for poor old IE11.
* Add webview2 runtime droppings to .gitignore.
* If we can't bind to args.server_port, get another suitable port from
the OS and advise the user in the UI that we did this.
* Make `ui=web` mode use 'SHARK AI Studio' as its title. This makes it
consistent with `ui=app`.
* Replace the generic gradio favicon with a nod swirl one instead.
2023-11-09 10:53:28 -06:00
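The port-fallback behaviour described above can be sketched roughly like this (illustrative only; the actual SHARK implementation may differ): try to bind the requested port, and if that fails, bind port 0 so the OS hands back a free one to report to the user.

```
import socket

def usable_port(preferred: int, host: str = "127.0.0.1") -> int:
    # Try the requested port first; if it's taken, let the OS pick a free one
    # (port 0) and return that so the UI can tell the user which port was used.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind((host, preferred))
            return preferred
        except OSError:
            sock.bind((host, 0))
            return sock.getsockname()[1]
```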
Ean Garvey
a07d542400 (Studio) Disable SD tunings and sub-model downloads (#1944)
* Sets --no-use_tuned and --import_mlir as defaults in SHARK Studio.
2023-11-07 15:55:30 -06:00
Stefan Kapusniak
ad55cb696f SD/API: Add missing A1111 APIs to Shark to support koboldcpp image generation (#1924)
* SD/API: Add missing a1111 API features for Koboldcpp

* Refactors SD api functions into their own file
* Adds the following apis implemented by a1111 as needed by koboldcpp:
   - adds /sdapi/v1/sd-models (lists available models)
   - adds /sdapi/v1/options (only the bare minimum needed)
* Adds optional CORS support, use the '--api_accept_origin' command line
argument to activate and configure.
* Extends existing APIs to include optional sampler/scheduler selection
* Extends /sdapi/v1/txt2img to recognise the method used by koboldcpp
to select the model.
* Where possible, take values not provided in the API request from the
existing relevant command line parameters rather than hardcoding
them.
* Return a 400 response when a request doesn't have required properties.
* Changed default schedulers and models for some APIs to ones that
actually seem to work.
* Update api_test.py to include the new APIs.
* Update api_test.py to include a '--verbose' command line option.

* SD/API: Take more API values from args

* Take LoRA from '--use_lora' command line arg if specified
* Take device from '--device' command line arg if specified (substring
match, so a short name such as 'vulkan://0' should work)

* SD/API: add more endpoints and pydantic typing

* Mount the whole of /sdapi from index.py as a FastAPI application,
rather than each endpoint individually
* Add the following additional API endpoints:
  * /sdapi/v1/samplers
  * /sdapi/v1/cmd-flags
* Make scheduler/sampler selection checking and fallback much more
robust.
* Support aliasing some A1111 scheduler/sampler names to the diffusers
ones we are using.
* Expand response /sdapi/v1/options to add a few more things.
* Split non-api functions and variables into their own utils.py file.
* Support 'n_iter' request property and the return of multiple images
from generation endpoints. Equivalent of '--batch_count'; batch_size
is still hardcoded at 1.
* Include (some) hires_fix request properties in txt2img endpoint
* Rework endpoints using pydantic model classes for better request
validation and so we get much improved swagger api docs at
/sdapi/docs and redoc at /sdapi/redoc

* SD/API Delete commented out code from index.py

* Delete some code that is no longer needed by the SD API in index.py
(and one line sdapi_v1.py) that I'd previously only commented out.

* SD/UI: Add shark_sd_koboldcpp.md document

* Add documentation on how to set up Koboldcpp with SHARK
* Link this and the existing blender set up document from the main
README.md

* SD/API Improve stencil options in img2img endpoint

In /sdapi/v1/img2img:
  * Add zoedepth to the controlnet use_stencil options
  * Require and use second image as stencil mask for controlnet scribble
2023-11-06 15:20:19 -06:00
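A short client-side sketch of the A1111-compatible endpoints this commit describes (the server address, port and accepted payload fields are assumptions; consult /sdapi/docs on a running instance for the authoritative schema):

```
import requests

BASE = "http://127.0.0.1:8080"  # assumed SHARK server address/port

# List the models the server knows about.
print(requests.get(f"{BASE}/sdapi/v1/sd-models").json())

# Minimal txt2img request; field names follow the A1111 convention, and the
# exact accepted set may differ from this sketch.
resp = requests.post(
    f"{BASE}/sdapi/v1/txt2img",
    json={"prompt": "a photo of a red fox", "steps": 20, "n_iter": 1},
)
resp.raise_for_status()
print(len(resp.json().get("images", [])), "image(s) returned")
```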
Jakub Kuderski
488a172292 [vicuna.py] Allow to pass extra arguments to iree-compile (#1935)
Add a new flag `-Xiree_compile` to forward extra compiler arguments to
`iree-compile`. This flag can be set multiple times to pass more than
one extra argument.
2023-11-06 12:12:34 -05:00
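A minimal sketch of how the repeated flag accumulates with argparse's `append` action (the flag values below are just examples of generic MLIR options, not a recommendation):

```
import argparse

parser = argparse.ArgumentParser()
# Each occurrence of the flag appends one extra argument for iree-compile.
parser.add_argument("--Xiree_compile", action="append", default=[])

args = parser.parse_args(
    ["--Xiree_compile=--mlir-print-ir-after-all",
     "--Xiree_compile=--mlir-timing"]
)
print(args.Xiree_compile)  # ['--mlir-print-ir-after-all', '--mlir-timing']
```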
Stanley Winata
500c4f2306 [compile utils] Fix ROCM to not expect config.id as a default. (#1939) 2023-11-06 08:44:53 -08:00
Vivek Khandelwal
92b694db4d Add support for Falcon-40b-GPTQ 2023-11-06 19:49:19 +05:30
Vivek Khandelwal
322874f7f9 Fix issue in Falcon-GPTQ 2023-11-03 11:48:36 +05:30
Ean Garvey
5001db3415 Add 7800xt to target triples explicitly. (#1928) 2023-11-01 17:11:45 -05:00
Vivek Khandelwal
71846344a2 Add sharded Falcon-GPTQ support
This commit adds support for sharded Falcon-7b-GPTQ and
Falcon-180B-GPTQ. It also adds support for 4-way sharding
of the Falcon model on the ROCm device.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-11-01 12:11:44 +05:30
gpetters94
72e27c96fc Add ZoeDepth (#1834)
* Add ZoeDepth

* Add einops to Studio imports.

* Specify ref for forked torch.hub repos.

* Unpin timm.

---------

Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: Ean Garvey <garveyej@gmail.com>
2023-10-30 11:57:45 -05:00
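The "Specify ref for forked torch.hub repos" bullet refers to torch.hub's `owner/repo:ref` syntax; a hedged sketch follows (the repo and entrypoint names here are illustrative, not necessarily the forks SHARK uses):

```
import torch

# Pin an explicit ref (branch or tag) so a moving default branch cannot
# silently change the hub model code that gets downloaded.
model = torch.hub.load("isl-org/ZoeDepth:main", "ZoeD_N", pretrained=True)
```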
PhaneeshB
7963abb8ec remove caching for rocm args 2023-10-29 07:07:57 +05:30
Ean Garvey
98244232dd Add smoothquant OPT to examples. (#1922) 2023-10-27 12:32:12 -05:00
PhaneeshB
679a452139 fix calls and remove unused imports for check_device_drivers 2023-10-27 10:30:40 +05:30
PhaneeshB
72c0a8abc8 remove dependency on external commands for driver installation check 2023-10-27 10:30:40 +05:30
Vivek Khandelwal
ea920f2955 Add sharded Falcon support 2023-10-26 21:53:25 +05:30
Phaneesh Barwaria
486202377a update dependency on rocm/hip info command (#1900)
* add support for rocm flags

* add rocm target flag to chat args

* rm rocm libs dependency message
2023-10-26 15:18:25 +05:30
Sungsoon Cho
0c38c33d0a Add opt_causallm_samples.py. (#1916) 2023-10-25 11:52:51 -05:00
Ean Garvey
841773fa32 Updates to opt_causallm example (#1905)
* Updates to opt_causallm example

* Fixup opt_perf_comparison.py

* Use same filenames across opt examples.
2023-10-24 10:54:39 -07:00
Stefan Kapusniak
0361db46f9 SD: Fix unet untuned opt_flags (#1912)
* correct my sloppy copy/paste for the untuned unet default compilation
flags that introduced an extra 'detach' into what should have been
'iree-global-opt-convert-1x1-filter-conv2d-to-matmul'
2023-10-24 12:47:33 -05:00
xzuyn
a012433ffd Save hiresfix info if used (#1914) 2023-10-24 12:45:10 -05:00
xzuyn
5061193da3 Move Generate, Randomize Seed, & Stop Batch to same positions as txt2img (#1915) 2023-10-24 12:44:39 -05:00
xzuyn
bff48924be LLaMa 2 Chat template fix (#1913) 2023-10-23 18:51:15 -05:00
Stefan Kapusniak
825b36cbdd Fix MLIR Textual PassPipeline Error (#1910) 2023-10-22 07:39:52 -07:00
Stefan Kapusniak
134441957d SD - Fix civitai download on Windows +improvements (#1907) 2023-10-21 11:17:41 -07:00
Stefan Kapusniak
7cd14fdc47 SD/UI: Use a single model selection box on UI tabs (#1906)
* Allow a huggingface model id or civitai download url to be entered
directly in the main model selection dropdown on SD tabs
* Remove separate textbox for entering huggingface model id or civitai
download url on SD Tabs
* Remove 'None' option from the model selection dropdown (no longer
needed) on SD tabs
* Update png metadata drop zone on txt2img tab to work with a single
argument for model selection
* Update UI generate functions on SD tabs to work with single argument
model selection
* Update API code for changes to the UI generate functions
* Move info about the custom model path to the logging textarea on SD
tabs
2023-10-21 10:06:05 -07:00
Ean Garvey
e6cb5cef57 Add --additional_runtime_args option and use in OPT example. (#1855)
* Add --additional_runtime_args option and use in OPT example.

Fix the func name. (#1838)

Co-authored-by: Sungsoon Cho <sungsoon.cho@gmail.com>
2023-10-19 13:29:39 -05:00
Huang Qi
66abee8e5b SharkInference: Fix various examples and README.md (#1903)
Following https://github.com/nod-ai/SHARK/pull/708, remove the 'func_name'
parameter from SharkInference.
2023-10-19 09:28:36 -05:00
Ean Garvey
4797bb89f5 Stringify path for ireec.compile_file (#1901)
* Stringify path for ireec.compile_file

* Update test-models.yml
2023-10-18 14:59:23 -05:00
Vivek Khandelwal
205e57683a Modify Falcon-180b-GPTQ sharded pipeline 2023-10-17 20:26:01 +05:30
Vivek Khandelwal
2866d665ee Fix Sharded Falcon-180b-GPTQ Pipeline 2023-10-17 20:26:01 +05:30
Stefan Kapusniak
71d25ec5d8 SD: Fix repeatable seeds when initial seed is random (#1893) 2023-10-14 22:50:42 -07:00
Vivek Khandelwal
202ffff67b Add support for sharded Falcon model 2023-10-13 22:05:10 +05:30
Ean Garvey
0b77059628 Add matmul reassociation flags (#1891) 2023-10-12 20:12:37 -05:00
Stefan Kapusniak
a208302bb9 Fix repeatable seeds consistency over batch counts (#1889)
* Set the input seed for the random number generator when
generating repeatable seeds to exclude any negative numbers
in the parsed seed input. This makes the seeds generated for
different batch counts consistent when they have the same
input for the initial seed or set of seeds.
2023-10-12 17:15:19 -05:00
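The idea in this fix can be sketched as follows (illustrative of the approach only, not the SHARK implementation): seed the generator exclusively from the non-negative seeds the user supplied, so the expanded list is the same no matter how large the batch count is.

```
import random

def expand_seeds(parsed_seeds, batch_count):
    # Derive extra seeds only from the non-negative ones the user provided,
    # so the same initial seed input expands identically for any batch_count.
    rng = random.Random(str([s for s in parsed_seeds if s >= 0]))
    seeds = []
    for i in range(batch_count):
        if i < len(parsed_seeds) and parsed_seeds[i] >= 0:
            seeds.append(parsed_seeds[i])
        else:
            seeds.append(rng.randint(0, 2**32 - 1))
    return seeds
```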
70 changed files with 4614 additions and 1194 deletions


@@ -137,7 +137,8 @@ jobs:
source shark.venv/bin/activate
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
# disabled due to a low-visibility memory issue with pytest on macos.
# pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'

.gitignore

@@ -182,7 +182,7 @@ generated_imgs/
# Custom model related artefacts
variants.json
models/
/models/
# models folder
apps/stable_diffusion/web/models/
@@ -199,3 +199,6 @@ apps/stable_diffusion/web/EBWebView/
# Llama2 tokenizer configs
llama2_tokenizer_configs/
# Webview2 runtime artefacts
EBWebView/


@@ -254,7 +254,6 @@ if you want to instead incorporate this into a python script, you can pass the `
```
shark_module = SharkInference(
mlir_model,
func_name,
device=args.device,
mlir_dialect="tm_tensor",
dispatch_benchmarks="all",
@@ -297,7 +296,7 @@ torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)
# SharkInference accepts mlir in linalg, mhlo, and tosa dialect.
from shark.shark_inference import SharkInference
shark_module = SharkInference(torch_mlir, func_name, device="cpu", mlir_dialect="linalg")
shark_module = SharkInference(torch_mlir, device="cpu", mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward((input))
@@ -320,12 +319,17 @@ mhlo_ir = r"""builtin.module {
arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)
shark_module = SharkInference(mhlo_ir, func_name="forward", device="cpu", mlir_dialect="mhlo")
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
shark_module.compile()
result = shark_module.forward((arg0, arg1))
```
</details>
## Examples Using the REST API
* [Setting up SHARK for use with Blender](./docs/shark_sd_blender.md)
* [Setting up SHARK for use with Koboldcpp](./docs/shark_sd_koboldcpp.md)
## Supported and Validated Models
SHARK is maintained to support the latest innovations in ML Models:


@@ -65,8 +65,8 @@ tiktoken==0.4.0
openai==0.27.8
# optional for chat with PDF
langchain==0.0.202
pypdf==3.12.2
langchain==0.0.329
pypdf==3.17.0
# avoid textract, requires old six
#textract==1.6.5


@@ -137,6 +137,12 @@ parser.add_argument(
default="",
help="Specify target triple for vulkan.",
)
parser.add_argument(
"--Xiree_compile",
action='append',
default=[],
help="Extra command line arguments passed to the IREE compiler. This can be specified multiple times to pass multiple arguments."
)
# Microbenchmarking options.
parser.add_argument(
@@ -1253,6 +1259,7 @@ class UnshardedVicuna(VicunaBase):
max_num_tokens=512,
min_num_tokens=0,
device="cpu",
device_id=None,
vulkan_target_triple="",
precision="int8",
vicuna_mlir_path=None,
@@ -1263,7 +1270,6 @@ class UnshardedVicuna(VicunaBase):
download_vmfb=False,
cache_vicunas=False,
extra_args_cmd=[],
device_id=None,
debug=False,
) -> None:
super().__init__(
@@ -1282,9 +1288,7 @@ class UnshardedVicuna(VicunaBase):
print(f"[DEBUG] hf model name: {self.hf_model_path}")
self.max_sequence_length = 256
self.min_num_tokens = min_num_tokens
self.device = device
self.vulkan_target_triple = vulkan_target_triple
self.device_id = device_id
self.precision = precision
self.download_vmfb = download_vmfb
self.vicuna_vmfb_path = vicuna_vmfb_path
@@ -1293,26 +1297,53 @@ class UnshardedVicuna(VicunaBase):
self.low_device_memory = low_device_memory
self.weight_group_size = weight_group_size
self.debug = debug
# Sanity check for device, device_id pair
if "://" in device:
if device_id is not None:
print("[ERR] can't have both full device path and a device id.\n"
f"Device : {device} | device_id : {device_id}\n"
"proceeding with given Device ignoring device_id")
self.device, self.device_id = device.split("://")
if len(self.device_id) < 2:
self.device_id = int(self.device_id)
else:
self.device, self.device_id = device, device_id
if self.vicuna_mlir_path == None:
self.vicuna_mlir_path = self.get_model_path()
if self.vicuna_vmfb_path == None:
self.vicuna_vmfb_path = self.get_model_path(suffix="vmfb")
self.tokenizer = self.get_tokenizer()
self.cache_vicunas = cache_vicunas
self.compile()
def get_model_path(self, suffix="mlir"):
safe_device = self.device.split("-")[0]
safe_device = safe_device.split("://")[0]
if suffix in ["mlirbc", "mlir"]:
return Path(f"{self.model_name}_{self.precision}.{suffix}")
target_triple = ""
if self.vulkan_target_triple != "":
target_triple = "_"
target_triple += "_".join(self.vulkan_target_triple.split("-")[:-1])
# Need to distinguish between multiple vmfbs of the same model
# compiled for different devices of the same driver
# Driver - Differentiator
# Vulkan - target_triple
# ROCm - device_arch
differentiator = ""
if "vulkan" == self.device:
target_triple = ""
if self.vulkan_target_triple != "":
target_triple = "_"
target_triple += "_".join(self.vulkan_target_triple.split("-")[:-1])
differentiator = target_triple
elif "rocm" == self.device:
from shark.iree_utils.gpu_utils import get_rocm_device_arch
device_arch = get_rocm_device_arch(self.device_id if self.device_id is not None else 0, self.extra_args)
differentiator = '_' + device_arch
return Path(
f"{self.model_name}_{self.precision}_{safe_device}{target_triple}.{suffix}"
f"{self.model_name}_{self.precision}_{safe_device}{differentiator}.{suffix}"
)
def get_tokenizer(self):
@@ -1407,8 +1438,8 @@ class UnshardedVicuna(VicunaBase):
elif "llama2_70b" in self.model_name:
pkv_tensor_shape = "tensor<1x8x?x128x"
else:
pkv_tensor_shape = "tensor<1x?x32x128x"
if self.device!="cpu:" : #precision in ["fp16", "int4", "int8"]:
pkv_tensor_shape = "tensor<1x32x?x128x"
if self.precision in ["fp16", "int4", "int8"]:
pkv_tensor_shape += "f16>"
else:
pkv_tensor_shape += "f32>"
@@ -1416,9 +1447,9 @@ class UnshardedVicuna(VicunaBase):
while module:
line = module.pop(0)
if "%c19_i64 = arith.constant 19 : i64" in line:
new_lines.append("%c2 = arith.constant 1 : index")
new_lines.append("%c2 = arith.constant 2 : index")
new_lines.append(
f"%dim_4_int = tensor.dim %arg1, %c1 : {pkv_tensor_shape}"
f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
)
new_lines.append(
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
@@ -1480,8 +1511,6 @@ class UnshardedVicuna(VicunaBase):
mlir_generated = True
break
print(self.device)
print(self.device=="cpu")
if not mlir_generated:
print(f"[DEBUG] mlir not found")
@@ -1509,7 +1538,7 @@ class UnshardedVicuna(VicunaBase):
model = FirstVicuna(
self.hf_model_path,
self.precision,
"fp32",
"fp32" if self.device=="cpu" else "fp16",
self.weight_group_size,
self.model_name,
self.hf_auth_token,
@@ -1518,13 +1547,13 @@ class UnshardedVicuna(VicunaBase):
model = FirstVicunaGPU(
self.hf_model_path,
self.precision,
"fp16",
"fp32" if self.device=="cpu" else "fp16",
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
print(f"[DEBUG] generating torchscript graph")
is_f16 = self.device!="cpu"
is_f16 = self.precision in ["fp16", "int4"]
ts_graph = import_with_fx(
model,
firstVicunaCompileInput,
@@ -1607,7 +1636,7 @@ class UnshardedVicuna(VicunaBase):
dim1 = 32
total_tuple = 64
pkv = tuple(
(torch.zeros([1, 19, dim1, 128], dtype=torch.float32))
(torch.zeros([1, dim1, 19, 128], dtype=torch.float32))
for _ in range(total_tuple)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
@@ -1668,7 +1697,7 @@ class UnshardedVicuna(VicunaBase):
self.hf_auth_token,
)
print(f"[DEBUG] generating torchscript graph")
is_f16 = self.device!="cpu"
is_f16 = self.precision in ["fp16", "int4"]
ts_graph = import_with_fx(
model,
secondVicunaCompileInput,
@@ -1678,7 +1707,7 @@ class UnshardedVicuna(VicunaBase):
mlir_type="torchscript",
)
del model
if self.device != "cpu":
if self.precision in ["fp16", "int4"]:
secondVicunaCompileInput = get_f16_inputs(
secondVicunaCompileInput,
True,
@@ -1688,7 +1717,7 @@ class UnshardedVicuna(VicunaBase):
for i in range(len(secondVicunaCompileInput)):
if i != 0:
secondVicunaCompileInput[i] = torch_mlir.TensorPlaceholder.like(
secondVicunaCompileInput[i], dynamic_axes=[1]
secondVicunaCompileInput[i], dynamic_axes=[2]
)
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
print(f"[DEBUG] generating torch mlir")
@@ -1741,15 +1770,14 @@ class UnshardedVicuna(VicunaBase):
)
combined_module = save_mlir(
combined_module,
model_name="self.vicuna_mlir_path",
model_name="combined_llama",
mlir_dialect="tm_tensor",
dir=str(os.getcwd()),
dir=self.vicuna_mlir_path,
)
del first_module, second_module
print(self.device)
if "rocm" in self.device:
self.device = "rocm"
print(f"Compiling for device : {self.device}"
f"{'://' + str(self.device_id) if self.device_id is not None else ''}")
shark_module = SharkInference(
mlir_module=combined_module,
device=self.device,
@@ -1912,7 +1940,8 @@ def create_prompt(model_name, history):
if __name__ == "__main__":
args, unknown = parser.parse_known_args()
_extra_args = []
_extra_args = list(args.Xiree_compile)
device_id = None
# Process vulkan target triple.
# TODO: This feature should just be in a common utils for other LLMs and in general
@@ -1968,6 +1997,7 @@ if __name__ == "__main__":
max_num_tokens=max_tokens,
min_num_tokens=min_tokens,
device=args.device,
vulkan_target_triple=vulkan_target_triple,
precision=args.precision,
vicuna_mlir_path=vic_mlir_path,
vicuna_vmfb_path=vic_vmfb_path,


@@ -0,0 +1,598 @@
import torch
from typing import Optional, Tuple
class WordEmbeddingsLayer(torch.nn.Module):
def __init__(self, word_embedding_layer):
super().__init__()
self.model = word_embedding_layer
def forward(self, input_ids):
output = self.model.forward(input=input_ids)
return output
class CompiledWordEmbeddingsLayer(torch.nn.Module):
def __init__(self, compiled_word_embedding_layer):
super().__init__()
self.model = compiled_word_embedding_layer
def forward(self, input_ids):
input_ids = input_ids.detach().numpy()
new_input_ids = self.model("forward", input_ids)
new_input_ids = new_input_ids.reshape(
[1, new_input_ids.shape[0], new_input_ids.shape[1]]
)
return torch.tensor(new_input_ids)
class LNFEmbeddingLayer(torch.nn.Module):
def __init__(self, ln_f):
super().__init__()
self.model = ln_f
def forward(self, hidden_states):
output = self.model.forward(input=hidden_states)
return output
class CompiledLNFEmbeddingLayer(torch.nn.Module):
def __init__(self, ln_f):
super().__init__()
self.model = ln_f
def forward(self, hidden_states):
hidden_states = hidden_states.detach().numpy()
new_hidden_states = self.model("forward", (hidden_states,))
return torch.tensor(new_hidden_states)
class LMHeadEmbeddingLayer(torch.nn.Module):
def __init__(self, embedding_layer):
super().__init__()
self.model = embedding_layer
def forward(self, hidden_states):
output = self.model.forward(input=hidden_states)
return output
class CompiledLMHeadEmbeddingLayer(torch.nn.Module):
def __init__(self, lm_head):
super().__init__()
self.model = lm_head
def forward(self, hidden_states):
hidden_states = hidden_states.detach().numpy()
new_hidden_states = self.model("forward", (hidden_states,))
return torch.tensor(new_hidden_states)
class DecoderLayer(torch.nn.Module):
def __init__(self, decoder_layer_model, falcon_variant):
super().__init__()
self.model = decoder_layer_model
def forward(self, hidden_states, attention_mask):
output = self.model.forward(
hidden_states=hidden_states,
alibi=None,
attention_mask=attention_mask,
use_cache=True,
)
return (output[0], output[1][0], output[1][1])
class CompiledDecoderLayer(torch.nn.Module):
def __init__(
self, layer_id, device_idx, falcon_variant, device, precision
):
super().__init__()
self.layer_id = layer_id
self.device_index = device_idx
self.falcon_variant = falcon_variant
self.device = device
self.precision = precision
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
alibi: torch.Tensor = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
import gc
torch.cuda.empty_cache()
gc.collect()
from pathlib import Path
from apps.language_models.utils import get_vmfb_from_path
self.falcon_vmfb_path = Path(
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
)
print("vmfb path for layer: ", self.falcon_vmfb_path)
self.model = get_vmfb_from_path(
self.falcon_vmfb_path,
self.device,
"linalg",
device_id=self.device_index,
)
if self.model is None:
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")
else:
new_hidden_states, pkv1, pkv2 = self.model(
"forward",
(
hidden_states,
attention_mask,
),
)
del self.model
return tuple(
[
torch.tensor(new_hidden_states),
tuple(
[
torch.tensor(pkv1),
torch.tensor(pkv2),
]
),
]
)
class EightDecoderLayer(torch.nn.Module):
def __init__(self, decoder_layer_model, falcon_variant):
super().__init__()
self.model = decoder_layer_model
self.falcon_variant = falcon_variant
def forward(self, hidden_states, attention_mask):
new_pkvs = []
for layer in self.model:
outputs = layer(
hidden_states=hidden_states,
alibi=None,
attention_mask=attention_mask,
use_cache=True,
)
hidden_states = outputs[0]
new_pkvs.append(
(
outputs[-1][0],
outputs[-1][1],
)
)
if self.falcon_variant == "7b":
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
)
elif self.falcon_variant == "40b":
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
)
elif self.falcon_variant == "180b":
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
(new_pkv150, new_pkv151),
(new_pkv160, new_pkv161),
(new_pkv170, new_pkv171),
(new_pkv180, new_pkv181),
(new_pkv190, new_pkv191),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
new_pkv150,
new_pkv151,
new_pkv160,
new_pkv161,
new_pkv170,
new_pkv171,
new_pkv180,
new_pkv181,
new_pkv190,
new_pkv191,
)
else:
raise ValueError(
"Unsupported Falcon variant: ", self.falcon_variant
)
return result
class CompiledEightDecoderLayer(torch.nn.Module):
def __init__(
self, layer_id, device_idx, falcon_variant, device, precision
):
super().__init__()
self.layer_id = layer_id
self.device_index = device_idx
self.falcon_variant = falcon_variant
self.device = device
self.precision = precision
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
alibi: torch.Tensor = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
import gc
torch.cuda.empty_cache()
gc.collect()
from pathlib import Path
from apps.language_models.utils import get_vmfb_from_path
self.falcon_vmfb_path = Path(
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
)
print("vmfb path for layer: ", self.falcon_vmfb_path)
self.model = get_vmfb_from_path(
self.falcon_vmfb_path,
self.device,
"linalg",
device_id=self.device_index,
)
if self.model is None:
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")
else:
output = self.model(
"forward",
(
hidden_states,
attention_mask,
),
)
del self.model
if self.falcon_variant == "7b":
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
)
elif self.falcon_variant == "40b":
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
)
elif self.falcon_variant == "180b":
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
(
torch.tensor(output[31]),
torch.tensor(output[32]),
),
(
torch.tensor(output[33]),
torch.tensor(output[34]),
),
(
torch.tensor(output[35]),
torch.tensor(output[36]),
),
(
torch.tensor(output[37]),
torch.tensor(output[38]),
),
(
torch.tensor(output[39]),
torch.tensor(output[40]),
),
)
else:
raise ValueError(
"Unsupported Falcon variant: ", self.falcon_variant
)
return result
class ShardedFalconModel:
def __init__(self, model, layers, word_embeddings, ln_f, lm_head):
super().__init__()
self.model = model
self.model.transformer.h = torch.nn.modules.container.ModuleList(
layers
)
self.model.transformer.word_embeddings = word_embeddings
self.model.transformer.ln_f = ln_f
self.model.lm_head = lm_head
def forward(
self,
input_ids,
attention_mask=None,
):
return self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
).logits[:, -1, :]


@@ -52,8 +52,8 @@ class FirstVicuna(torch.nn.Module):
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0].transpose(1,2))
return_vals.append(item[1].transpose(1,2))
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
@@ -295,9 +295,6 @@ class SecondVicuna7B(torch.nn.Module):
i64,
),
)
past_key_values = [(x[0].transpose(1,2), x[0].transpose(1,2)) for x in past_key_values]
past_key_values = tuple(past_key_values)
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
)
@@ -306,8 +303,8 @@ class SecondVicuna7B(torch.nn.Module):
return_vals.append(token)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0].transpose(1,2))
return_vals.append(item[1].transpose(1,2))
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)


@@ -1,4 +1,17 @@
from apps.language_models.src.model_wrappers.falcon_model import FalconModel
from apps.language_models.src.model_wrappers.falcon_sharded_model import (
WordEmbeddingsLayer,
CompiledWordEmbeddingsLayer,
LNFEmbeddingLayer,
CompiledLNFEmbeddingLayer,
LMHeadEmbeddingLayer,
CompiledLMHeadEmbeddingLayer,
DecoderLayer,
EightDecoderLayer,
CompiledDecoderLayer,
CompiledEightDecoderLayer,
ShardedFalconModel,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.utils import (
get_vmfb_from_path,
@@ -16,12 +29,13 @@ from transformers.generation import (
StoppingCriteriaList,
)
import copy
import time
import re
import torch
import torch_mlir
import os
import argparse
import gc
parser = argparse.ArgumentParser(
prog="falcon runner",
@@ -31,6 +45,12 @@ parser = argparse.ArgumentParser(
parser.add_argument(
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
)
parser.add_argument(
"--compressed",
default=False,
action=argparse.BooleanOptionalAction,
help="Do the compression of sharded layers",
)
parser.add_argument(
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
)
@@ -67,9 +87,16 @@ parser.add_argument(
default=None,
help="Specify your own huggingface authentication token for falcon-180B model.",
)
parser.add_argument(
"-s",
"--sharded",
default=False,
action=argparse.BooleanOptionalAction,
help="Run model as sharded",
)
class Falcon(SharkLLMBase):
class ShardedFalcon(SharkLLMBase):
def __init__(
self,
model_name,
@@ -85,6 +112,532 @@ class Falcon(SharkLLMBase):
super().__init__(model_name, hf_model_path, max_num_tokens)
print("hf_model_path: ", self.hf_model_path)
if (
"180b" in self.model_name
and precision != "int4"
and hf_auth_token == None
):
raise ValueError(
""" HF auth token required for falcon-180b. Pass it using
--hf_auth_token flag. You can ask for the access to the model
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
)
self.hf_auth_token = hf_auth_token
self.max_padding_length = 100
self.device = device
self.precision = precision
self.falcon_vmfb_path = falcon_vmfb_path
self.falcon_mlir_path = falcon_mlir_path
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.src_model = self.get_src_model()
self.shark_model = self.compile(compressed=args.compressed)
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path,
trust_remote_code=True,
token=self.hf_auth_token,
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 11
return tokenizer
def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {
"torch_dtype": torch.float,
"trust_remote_code": True,
"token": self.hf_auth_token,
}
if self.precision == "int4":
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
if self.precision == "int4":
falcon_model = falcon_model.to(torch.float32)
return falcon_model
def compile_layer(
self, layer, falconCompileInput, layer_id, device_idx=None
):
self.falcon_mlir_path = Path(
f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}.mlir"
)
self.falcon_vmfb_path = Path(
f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}_{self.device}.vmfb"
)
if args.use_precompiled_model:
if not self.falcon_vmfb_path.exists():
# Downloading VMFB from shark_tank
print(f"[DEBUG] Trying to download vmfb from shark_tank")
download_public_file(
f"gs://shark_tank/falcon/sharded/falcon_{args.falcon_variant_to_use}/vmfb/"
+ str(self.falcon_vmfb_path),
self.falcon_vmfb_path.absolute(),
single_file=True,
)
vmfb = get_vmfb_from_path(
self.falcon_vmfb_path,
self.device,
"linalg",
device_id=device_idx,
)
if vmfb is not None:
return vmfb, device_idx
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
if self.falcon_mlir_path.exists():
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
print(
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
)
if args.load_mlir_from_shark_tank:
# Downloading MLIR from shark_tank
print(f"[DEBUG] Trying to download mlir from shark_tank")
download_public_file(
f"gs://shark_tank/falcon/sharded/falcon_{args.falcon_variant_to_use}/mlir/"
+ str(self.falcon_mlir_path),
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
print(
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
)
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
if not mlir_generated:
print(f"[DEBUG] generating MLIR locally")
if layer_id == "word_embeddings":
f16_input_mask = [False]
elif layer_id in ["ln_f", "lm_head"]:
f16_input_mask = [True]
elif "_" in layer_id or type(layer_id) == int:
f16_input_mask = [True, False]
else:
raise ValueError("Unsupported layer: ", layer_id)
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
layer,
falconCompileInput,
is_f16=True,
f16_input_mask=f16_input_mask,
mlir_type="torchscript",
is_gptq=True,
)
del layer
print(f"[DEBUG] generating torch mlir")
module = torch_mlir.compile(
ts_graph,
falconCompileInput,
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
print(f"[DEBUG] converting to bytecode")
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
del module
f_ = open(self.falcon_mlir_path, "wb")
f_.write(bytecode)
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
f_.close()
del bytecode
shark_module = SharkInference(
mlir_module=self.falcon_mlir_path,
device=self.device,
mlir_dialect="linalg",
device_idx=device_idx,
)
path = shark_module.save_module(
self.falcon_vmfb_path.parent.absolute(),
self.falcon_vmfb_path.stem,
extra_args=[
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
]
+ [
"--iree-llvmcpu-use-fast-min-max-ops",
]
if self.precision == "int4"
else [],
debug=self.debug,
)
print("Saved falcon vmfb at ", str(path))
shark_module.load_module(path)
return shark_module, device_idx
def compile(self, compressed=False):
sample_input_ids = torch.zeros([100], dtype=torch.int64)
sample_attention_mask = torch.zeros(
[1, 1, 100, 100], dtype=torch.float32
)
num_group_layers = 1
if "7b" in self.model_name:
num_in_features = 4544
if compressed:
num_group_layers = 8
elif "40b" in self.model_name:
num_in_features = 8192
if compressed:
num_group_layers = 15
else:
num_in_features = 14848
sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
if compressed:
num_group_layers = 20
sample_hidden_states = torch.zeros(
[1, 100, num_in_features], dtype=torch.float32
)
# Determine number of available devices
num_devices = 1
if self.device == "rocm":
import iree.runtime as ireert
haldriver = ireert.get_driver(self.device)
num_devices = len(haldriver.query_available_devices())
lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
print("Compiling Layer lm_head")
shark_lm_head, _ = self.compile_layer(
lm_head,
[sample_hidden_states],
"lm_head",
device_idx=0 % num_devices if self.device == "rocm" else None,
)
shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
word_embedding = WordEmbeddingsLayer(
self.src_model.transformer.word_embeddings
)
print("Compiling Layer word_embeddings")
shark_word_embedding, _ = self.compile_layer(
word_embedding,
[sample_input_ids],
"word_embeddings",
device_idx=1 % num_devices if self.device == "rocm" else None,
)
shark_word_embedding = CompiledWordEmbeddingsLayer(
shark_word_embedding
)
ln_f = LNFEmbeddingLayer(self.src_model.transformer.ln_f)
print("Compiling Layer ln_f")
shark_ln_f, _ = self.compile_layer(
ln_f,
[sample_hidden_states],
"ln_f",
device_idx=2 % num_devices if self.device == "rocm" else None,
)
shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
shark_layers = []
for i in range(
int(len(self.src_model.transformer.h) / num_group_layers)
):
device_idx = i % num_devices if self.device == "rocm" else None
layer_id = i
pytorch_class = DecoderLayer
compiled_class = CompiledDecoderLayer
if compressed:
layer_id = (
str(i * num_group_layers)
+ "_"
+ str((i + 1) * num_group_layers)
)
pytorch_class = EightDecoderLayer
compiled_class = CompiledEightDecoderLayer
print("Compiling Layer {}".format(layer_id))
if compressed:
layer_i = self.src_model.transformer.h[
i * num_group_layers : (i + 1) * num_group_layers
]
else:
layer_i = self.src_model.transformer.h[i]
pytorch_layer_i = pytorch_class(
layer_i, args.falcon_variant_to_use
)
shark_module, device_idx = self.compile_layer(
pytorch_layer_i,
[sample_hidden_states, sample_attention_mask],
layer_id,
device_idx=device_idx,
)
del shark_module
shark_layer_i = compiled_class(
layer_id,
device_idx,
args.falcon_variant_to_use,
self.device,
self.precision,
)
shark_layers.append(shark_layer_i)
sharded_model = ShardedFalconModel(
self.src_model,
shark_layers,
shark_word_embedding,
shark_ln_f,
shark_lm_head,
)
return sharded_model
def generate(self, prompt):
model_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.max_padding_length,
add_special_tokens=False,
return_tensors="pt",
)
model_inputs["prompt_text"] = prompt
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs.get("attention_mask", None)
# Allow empty prompts
if input_ids.shape[1] == 0:
input_ids = None
attention_mask = None
generate_kwargs = {
"max_length": self.max_num_tokens,
"do_sample": True,
"top_k": 10,
"num_return_sequences": 1,
"eos_token_id": 11,
}
generate_kwargs["input_ids"] = input_ids
generate_kwargs["attention_mask"] = attention_mask
generation_config_ = GenerationConfig.from_model_config(
self.src_model.config
)
generation_config = copy.deepcopy(generation_config_)
model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
stopping_criteria = StoppingCriteriaList()
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
model_kwargs,
) = self.src_model._prepare_model_inputs(
None, generation_config.bos_token_id, model_kwargs
)
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else model_kwargs.pop("input_ids")
)
self.logits_processor = self.src_model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids.shape[-1],
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
self.stopping_criteria = self.src_model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria,
)
self.logits_warper = self.src_model._get_logits_warper(
generation_config
)
(
self.input_ids,
self.model_kwargs,
) = self.src_model._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=self.src_model.config.is_encoder_decoder, # False
**model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id_tensor = (
torch.tensor(eos_token_id) if eos_token_id is not None else None
)
self.pad_token_id = generation_config.pad_token_id
self.eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
self.scores = (
() if (return_dict_in_generate and output_scores) else None
)
# keep track of which sequences are already finished
self.unfinished_sequences = torch.ones(
input_ids.shape[0], dtype=torch.long, device=input_ids.device
)
all_text = prompt
start = time.time()
count = 0
for i in range(self.max_num_tokens - 1):
count = count + 1
next_token = self.generate_new_token()
new_word = self.tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
all_text = all_text + new_word
print(f"{new_word}", end="", flush=True)
print(f"{all_text}", end="", flush=True)
# if eos_token was found in one sentence, set sentence to finished
if self.eos_token_id_tensor is not None:
self.unfinished_sequences = self.unfinished_sequences.mul(
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
.ne(self.eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
if (
self.unfinished_sequences.max() == 0
or self.stopping_criteria(input_ids, self.scores)
):
break
end = time.time()
print(
"\n\nTime taken is {:.2f} seconds/token\n".format(
(end - start) / count
)
)
torch.cuda.empty_cache()
gc.collect()
return all_text
def generate_new_token(self):
model_inputs = self.src_model.prepare_inputs_for_generation(
self.input_ids, **self.model_kwargs
)
outputs = self.shark_model.forward(
input_ids=model_inputs["input_ids"],
attention_mask=model_inputs["attention_mask"],
)
if self.precision in ["fp16", "int4"]:
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
next_token_scores = self.logits_processor(
self.input_ids, next_token_logits
)
next_token_scores = self.logits_warper(
self.input_ids, next_token_scores
)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if self.eos_token_id is not None:
if self.pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = (
next_token * self.unfinished_sequences
+ self.pad_token_id * (1 - self.unfinished_sequences)
)
self.input_ids = torch.cat(
[self.input_ids, next_token[:, None]], dim=-1
)
self.model_kwargs["past_key_values"] = None
if "attention_mask" in self.model_kwargs:
attention_mask = self.model_kwargs["attention_mask"]
self.model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
self.input_ids = self.input_ids[:, 1:]
self.model_kwargs["attention_mask"] = self.model_kwargs[
"attention_mask"
][:, 1:]
return next_token
class UnshardedFalcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="tiiuae/falcon-7b-instruct",
hf_auth_token: str = "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",
max_num_tokens=150,
device="cuda",
precision="fp32",
falcon_mlir_path=None,
falcon_vmfb_path=None,
debug=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
print("hf_model_path: ", self.hf_model_path)
if "180b" in self.model_name and hf_auth_token == None:
raise ValueError(
""" HF auth token required for falcon-180b. Pass it using
@@ -123,7 +676,7 @@ class Falcon(SharkLLMBase):
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
kwargs["device_map"] = "cpu"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
@@ -381,7 +934,11 @@ class Falcon(SharkLLMBase):
all_text = prompt
start = time.time()
count = 0
for i in range(self.max_num_tokens - 1):
count = count + 1
next_token = self.generate_new_token()
new_word = self.tokenizer.decode(
next_token.cpu().numpy(),
@@ -408,6 +965,13 @@ class Falcon(SharkLLMBase):
):
break
end = time.time()
print(
"\n\nTime taken is {:.2f} seconds/token\n".format(
(end - start) / count
)
)
torch.cuda.empty_cache()
gc.collect()
@@ -519,16 +1083,22 @@ if __name__ == "__main__":
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
)
falcon = Falcon(
model_name="falcon_" + args.falcon_variant_to_use,
hf_model_path=hf_model_path_value,
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,
falcon_vmfb_path=falcon_vmfb_path,
)
import gc
if not args.sharded:
falcon = UnshardedFalcon(
model_name="falcon_" + args.falcon_variant_to_use,
hf_model_path=hf_model_path_value,
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,
falcon_vmfb_path=falcon_vmfb_path,
)
else:
falcon = ShardedFalcon(
model_name="falcon_" + args.falcon_variant_to_use,
hf_model_path=hf_model_path_value,
device=args.device,
precision=args.precision,
)
default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
continue_execution = True


@@ -0,0 +1,91 @@
from turbine_models.custom_models import stateless_llama
from shark.iree_utils.compile_utils import get_iree_compiled_module
from apps.shark_studio.api.utils import get_resource_path
import iree.runtime as ireert
import gc
import torch
llm_model_map = {
"llama2_7b": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"stop_token": 2,
"max_tokens": 4096,
}
}
class LanguageModel:
def __init__(
self, model_name, hf_auth_token=None, device=None, precision="fp32"
):
print(llm_model_map[model_name])
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.torch_ir, self.tokenizer = llm_model_map[model_name][
"initializer"
](self.hf_model_name, hf_auth_token, compile_to="torch")
self.tempfile_name = get_resource_path("llm.torch.tempfile")
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
del self.torch_ir
gc.collect()
self.device = device
self.precision = precision
self.max_tokens = llm_model_map[model_name]["max_tokens"]
self.iree_module_dict = None
self.compile()
def compile(self) -> None:
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
self.iree_module_dict = get_iree_compiled_module(
self.tempfile_name, device=self.device, frontend="torch"
)
# TODO: delete the temp file
def chat(self, prompt):
history = []
for iter in range(self.max_tokens):
input_tensor = self.tokenizer(
prompt, return_tensors="pt"
).input_ids
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"], input_tensor
)
]
if iter == 0:
token = torch.tensor(
self.iree_module_dict["vmfb"]["run_initialize"](
*device_inputs
).to_host()[0][0]
)
else:
token = torch.tensor(
self.iree_module_dict["vmfb"]["run_forward"](
*device_inputs
).to_host()[0][0]
)
history.append(token)
yield self.tokenizer.decode(history)
if token == llm_model_map["llama2_7b"]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
history[i] = int(history[i])
result_output = self.tokenizer.decode(history)
yield result_output
if __name__ == "__main__":
lm = LanguageModel(
"llama2_7b",
hf_auth_token="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",
device="cpu-task",
)
print("model loaded")
for i in lm.chat("Hello, I am a robot."):
print(i)


@@ -0,0 +1,14 @@
import os
import sys
def get_available_devices():
return ["cpu-task"]
def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)


@@ -0,0 +1,428 @@
from multiprocessing import Process, freeze_support
import os
import sys
import logging
from ui.chat import chat_element
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
# import before IREE to avoid MLIR library issues
import torch_mlir
# import PIL, transformers, sentencepiece  # ensures inclusion in pyinstaller exe generation
# from apps.stable_diffusion.src import args, clear_all
# import apps.stable_diffusion.web.utils.global_obj as global_obj
def launch_app(address):
from tkinter import Tk
import webview
window = Tk()
# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False, storage_path=os.getcwd())
if __name__ == "__main__":
# if args.debug:
logging.basicConfig(level=logging.DEBUG)
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
# if args.api or "api" in args.ui.split(","):
# from apps.stable_diffusion.web.ui import (
# txt2img_api,
# img2img_api,
# upscaler_api,
# inpaint_api,
# outpaint_api,
# llm_chat_api,
# )
#
# from fastapi import FastAPI, APIRouter
# import uvicorn
#
# # init global sd pipeline and config
# global_obj._init()
#
# app = FastAPI()
# app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
# app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
# app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
# app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
# app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
#
# # chat APIs needed for compatibility with multiple extensions using OpenAI API
# app.add_api_route(
# "/v1/chat/completions", llm_chat_api, methods=["post"]
# )
# app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
# app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
# app.add_api_route("/completions", llm_chat_api, methods=["post"])
# app.add_api_route(
# "/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
# )
# app.include_router(APIRouter())
# uvicorn.run(app, host="0.0.0.0", port=args.server_port)
# sys.exit(0)
#
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
# from apps.stable_diffusion.web.utils.gradio_configs import (
# config_gradio_tmp_imgs_folder,
# )
# config_gradio_tmp_imgs_folder()
import gradio as gr
# Create custom models folders if they don't exist
# from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
# create_custom_models_folders()
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
# from apps.stable_diffusion.web.ui import (
# txt2img_web,
# txt2img_custom_model,
# txt2img_gallery,
# txt2img_png_info_img,
# txt2img_status,
# txt2img_sendto_img2img,
# txt2img_sendto_inpaint,
# txt2img_sendto_outpaint,
# txt2img_sendto_upscaler,
## h2ogpt_upload,
## h2ogpt_web,
# img2img_web,
# img2img_custom_model,
# img2img_gallery,
# img2img_init_image,
# img2img_status,
# img2img_sendto_inpaint,
# img2img_sendto_outpaint,
# img2img_sendto_upscaler,
# inpaint_web,
# inpaint_custom_model,
# inpaint_gallery,
# inpaint_init_image,
# inpaint_status,
# inpaint_sendto_img2img,
# inpaint_sendto_outpaint,
# inpaint_sendto_upscaler,
# outpaint_web,
# outpaint_custom_model,
# outpaint_gallery,
# outpaint_init_image,
# outpaint_status,
# outpaint_sendto_img2img,
# outpaint_sendto_inpaint,
# outpaint_sendto_upscaler,
# upscaler_web,
# upscaler_custom_model,
# upscaler_gallery,
# upscaler_init_image,
# upscaler_status,
# upscaler_sendto_img2img,
# upscaler_sendto_inpaint,
# upscaler_sendto_outpaint,
## lora_train_web,
## model_web,
## model_config_web,
# hf_models,
# modelmanager_sendto_txt2img,
# modelmanager_sendto_img2img,
# modelmanager_sendto_inpaint,
# modelmanager_sendto_outpaint,
# modelmanager_sendto_upscaler,
# stablelm_chat,
# minigpt4_web,
# outputgallery_web,
# outputgallery_tab_select,
# outputgallery_watch,
# outputgallery_filename,
# outputgallery_sendto_txt2img,
# outputgallery_sendto_img2img,
# outputgallery_sendto_inpaint,
# outputgallery_sendto_outpaint,
# outputgallery_sendto_upscaler,
# )
# init global sd pipeline and config
# global_obj._init()
def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_modelmanager_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
"None",
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_outputgallery_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
# have a unique id that doesn't clash with any of the other tabs,
# and that the order in the code here is the order they should
# appear in the ui, as the id value doesn't determine the order.
# Where possible, avoid changing the id of any tab that is the
# destination of one of the 'send to' buttons. If you do have to change
# that id, make sure you update the relevant register_button_click calls
# further down with the new id.
# with gr.TabItem(label="Text-to-Image", id=0):
# txt2img_web.render()
# with gr.TabItem(label="Image-to-Image", id=1):
# img2img_web.render()
# with gr.TabItem(label="Inpainting", id=2):
# inpaint_web.render()
# with gr.TabItem(label="Outpainting", id=3):
# outpaint_web.render()
# with gr.TabItem(label="Upscaler", id=4):
# upscaler_web.render()
# if args.output_gallery:
# with gr.TabItem(label="Output Gallery", id=5) as og_tab:
# outputgallery_web.render()
# # extra output gallery configuration
# outputgallery_tab_select(og_tab.select)
# outputgallery_watch(
# [
# txt2img_status,
# img2img_status,
# inpaint_status,
# outpaint_status,
# upscaler_status,
# ]
# )
## with gr.TabItem(label="Model Manager", id=6):
## model_web.render()
## with gr.TabItem(label="LoRA Training (Experimental)", id=7):
## lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=0):
chat_element.render()
## with gr.TabItem(
## label="Generate Sharding Config (Experimental)", id=9
## ):
## model_config_web.render()
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
# minigpt4_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
# send to buttons
# register_button_click(
# txt2img_sendto_img2img,
# 1,
# [txt2img_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_inpaint,
# 2,
# [txt2img_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_outpaint,
# 3,
# [txt2img_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_upscaler,
# 4,
# [txt2img_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_inpaint,
# 2,
# [img2img_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_outpaint,
# 3,
# [img2img_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_upscaler,
# 4,
# [img2img_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_img2img,
# 1,
# [inpaint_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_outpaint,
# 3,
# [inpaint_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_upscaler,
# 4,
# [inpaint_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_img2img,
# 1,
# [outpaint_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_inpaint,
# 2,
# [outpaint_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_upscaler,
# 4,
# [outpaint_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_img2img,
# 1,
# [upscaler_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_inpaint,
# 2,
# [upscaler_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_outpaint,
# 3,
# [upscaler_gallery],
# [outpaint_init_image, tabs],
# )
# if args.output_gallery:
# register_outputgallery_button(
# outputgallery_sendto_txt2img,
# 0,
# [outputgallery_filename],
# [txt2img_png_info_img, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_img2img,
# 1,
# [outputgallery_filename],
# [img2img_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_inpaint,
# 2,
# [outputgallery_filename],
# [inpaint_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_outpaint,
# 3,
# [outputgallery_filename],
# [outpaint_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_upscaler,
# 4,
# [outputgallery_filename],
# [upscaler_init_image, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_txt2img,
# 0,
# [hf_models],
# [txt2img_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_img2img,
# 1,
# [hf_models],
# [img2img_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_inpaint,
# 2,
# [hf_models],
# [inpaint_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_outpaint,
# 3,
# [hf_models],
# [outpaint_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_upscaler,
# 4,
# [hf_models],
# [upscaler_custom_model, tabs],
# )
sd_web.queue()
# if args.ui == "app":
# t = Process(
# target=launch_app, args=[f"http://localhost:{args.server_port}"]
# )
# t.start()
sd_web.launch(
share=True,
inbrowser=True,
server_name="0.0.0.0",
server_port=11911, # args.server_port,
)
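
The commented-out block near the top of this file sketches how the REST API mode is meant to be wired up. A minimal, hedged version of that wiring, kept separate from the gradio launch; the import path and port are assumptions, and only the chat endpoints defined in ui/chat.py are mounted here:

from fastapi import FastAPI
import uvicorn

from ui.chat import llm_chat_api  # assumed import, mirroring chat_element above

api = FastAPI()
# OpenAI-compatible chat endpoints, as described in the commented-out section.
api.add_api_route("/v1/chat/completions", llm_chat_api, methods=["post"])
api.add_api_route("/v1/completions", llm_chat_api, methods=["post"])

uvicorn.run(api, host="0.0.0.0", port=8080)  # port is illustrative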

View File

View File

@@ -0,0 +1,517 @@
import gradio as gr
import os
from pathlib import Path
from datetime import datetime as dt
import json
import sys
from apps.shark_studio.api.utils import (
get_available_devices,
)
from apps.shark_studio.api.llm import (
llm_model_map,
LanguageModel,
)
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
language_model = None
# NOTE: Each `model_name` should have its own start message
start_message = {
"llama2_7b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_13b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_70b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence "
"assistant. The assistant gives helpful, detailed, and "
"polite answers to the user's questions.\n"
),
}
def create_prompt(model_name, history, prompt_prefix):
return ""
system_message = ""
if prompt_prefix:
system_message = start_message[model_name]
if "llama2" in model_name:
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
conversation = "".join(
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
)
if prompt_prefix:
msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
else:
msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
elif model_name in ["vicuna"]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)
msg = system_message + conversation
msg = msg.strip()
return msg
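# --- Illustration (not part of the original file) ----------------------------
example_history = [["Hi there", "Hello! How can I help?"], ["Tell me a joke", ""]]
# With prompt_prefix=True, the llama2 branch above (currently short-circuited
# by the early return) would format example_history roughly as:
#   [INST] <<SYS>>
#   {system_message}
#   <</SYS>>
#
#   Hi there [/INST] Hello! How can I help?  [INST] Tell me a joke [/INST]
# i.e. the system prompt is folded into the first [INST] block and each later
# turn becomes "[INST] user [/INST] assistant ".
# ------------------------------------------------------------------------------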
def get_default_config():
return False  # NOTE: early return; the config generation below is currently unreachable
import torch
from transformers import AutoTokenizer
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
from apps.language_models.src.model_wrappers.vicuna_model import (
CombinedModel,
)
from shark.shark_generate_model_config import GenerateConfigFile
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
c.split_into_layers()
# model_vmfb_key = ""
def chat_fn(
prompt_prefix,
history,
model,
device,
precision,
download_vmfb,
config_file,
cli=False,
progress=gr.Progress(),
):
global language_model
if language_model is None:
language_model = LanguageModel(
model, device=device, precision=precision
)
language_model.chat(prompt_prefix)
return "", ""
global past_key_values
global model_vmfb_key
device_id = None
model_name, model_path = list(map(str.strip, model.split("=>")))
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "rocm" in device:
device = "rocm"
else:
print("unrecognized device")
from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
from apps.stable_diffusion.src import args
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
model_vmfb_key = new_model_vmfb_key
max_toks = 128 if model_name == "codegen" else 512
# get iree flags that need to be overridden, from commandline args
_extra_args = []
# vulkan target triple
vulkan_target_triple = args.iree_vulkan_target_triple
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
get_vulkan_target_triple,
)
if device == "vulkan":
vulkaninfo_list = get_all_vulkan_devices()
if vulkan_target_triple == "":
# We already have the device_id extracted via WebUI, so we directly use
# that to find the target triple.
vulkan_target_triple = get_vulkan_target_triple(
vulkaninfo_list[device_id]
)
_extra_args.append(
f"-iree-vulkan-target-triple={vulkan_target_triple}"
)
if "rdna" in vulkan_target_triple:
flags_to_add = [
"--iree-spirv-index-bits=64",
]
_extra_args = _extra_args + flags_to_add
if device_id is None:
id = 0
for device in vulkaninfo_list:
target_triple = get_vulkan_target_triple(
vulkaninfo_list[id]
)
if target_triple == vulkan_target_triple:
device_id = id
break
id += 1
assert (
device_id
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
print(f"Will use vulkan target triple : {vulkan_target_triple}")
elif "rocm" in device:
# add iree rocm flags
_extra_args.append(
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
)
print(f"extra args = {_extra_args}")
if model_name == "vicuna4":
vicuna_model = ShardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
compressed=True,
extra_args_cmd=_extra_args,
)
else:
# if config_file is None:
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
vulkan_target_triple=vulkan_target_triple,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=download_vmfb,
load_mlir_from_shark_tank=True,
extra_args_cmd=_extra_args,
device_id=device_id,
)
if vicuna_model is None:
sys.exit("Unable to instantiate the model object, exiting.")
prompt = create_prompt(model_name, history, prompt_prefix)
partial_text = ""
token_count = 0
total_time_ms = 0.001 # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, msg, exec_time in progress.tqdm(
vicuna_model.generate(prompt, cli=cli),
desc="generating response",
):
if msg is None:
if is_first:
prefill_time = exec_time
is_first = False
else:
total_time_ms += exec_time
token_count += 1
partial_text += text + " "
history[-1][1] = partial_text
yield history, f"Prefill: {prefill_time:.2f}"
elif "formatted" in msg:
history[-1][1] = text
tokens_per_sec = (token_count / total_time_ms) * 1000
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
else:
sys.exit(
"unexpected message from the vicuna generate call, exiting."
)
return history, ""
def llm_chat_api(InputData: dict):
return None  # NOTE: early return; the API implementation below is currently unreachable
print(f"Input keys : {InputData.keys()}")
# print(f"model : {InputData['model']}")
is_chat_completion_api = (
"messages" in InputData.keys()
) # else it is the legacy `completion` api
# For Debugging input data from API
# if is_chat_completion_api:
# print(f"message -> role : {InputData['messages'][0]['role']}")
# print(f"message -> content : {InputData['messages'][0]['content']}")
# else:
# print(f"prompt : {InputData['prompt']}")
# print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now
global vicuna_model
model_name = (
InputData["model"] if "model" in InputData.keys() else "codegen"
)
model_path = llm_model_map[model_name]
device = "cpu-task"
precision = "fp16"
max_toks = (
None
if "max_tokens" not in InputData.keys()
else InputData["max_tokens"]
)
if max_toks is None:
max_toks = 128 if model_name == "codegen" else 512
# make it working for codegen first
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)
device_id = None
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
else:
print("unrecognized device")
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=True,
load_mlir_from_shark_tank=True,
device_id=device_id,
)
# TODO: add role dict for different models
if is_chat_completion_api:
# TODO: add functionality for multiple messages
prompt = create_prompt(
model_name, [(InputData["messages"][0]["content"], "")]
)
else:
prompt = InputData["prompt"]
print("prompt = ", prompt)
res = vicuna_model.generate(prompt)
res_op = None
for op in res:
res_op = op
if is_chat_completion_api:
choices = [
{
"index": 0,
"message": {
"role": "assistant",
"content": res_op, # since we are yeilding the result
},
"finish_reason": "stop", # or length
}
]
else:
choices = [
{
"text": res_op,
"index": 0,
"logprobs": None,
"finish_reason": "stop", # or length
}
]
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
return {
"id": end_time,
"object": "chat.completion"
if is_chat_completion_api
else "text_completion",
"created": int(end_time),
"choices": choices,
}
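# --- Illustration (not part of the original file) ----------------------------
# Example request/response bodies for the handler above, once the early
# return is removed; values are illustrative only.
#
# Chat-completion style request (is_chat_completion_api == True):
#   {"model": "codegen", "messages": [{"role": "user", "content": "def fib(n):"}]}
#
# Legacy completion style request:
#   {"model": "codegen", "prompt": "def fib(n):", "max_tokens": 128}
#
# Shape of the returned dict:
#   {"id": "<timestamp>", "object": "chat.completion" or "text_completion",
#    "created": <int timestamp>, "choices": [...]}
# ------------------------------------------------------------------------------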
def view_json_file(file_obj):
content = ""
with open(file_obj.name, "r") as fopen:
content = fopen.read()
return content
with gr.Blocks(title="Chat") as chat_element:
with gr.Row():
model_choices = list(llm_model_map.keys())
model = gr.Dropdown(
label="Select Model",
value=model_choices[0],
choices=model_choices,
allow_custom_value=True,
)
supported_devices = get_available_devices()
enabled = True
if len(supported_devices) == 0:
supported_devices = ["cpu-task"]
supported_devices = [x for x in supported_devices if "sync" not in x]
device = gr.Dropdown(
label="Device",
value=supported_devices[0],
choices=supported_devices,
interactive=enabled,
allow_custom_value=True,
)
precision = gr.Radio(
label="Precision",
value="int4",
choices=[
# "int4",
# "int8",
# "fp16",
"fp32",
],
visible=False,
)
tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Column():
download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available",
value=True,
interactive=True,
)
prompt_prefix = gr.Checkbox(
label="Add System Prompt",
value=False,
interactive=True,
)
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="Chat Message Box",
placeholder="Chat Message Box",
show_label=False,
interactive=enabled,
container=False,
)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit", interactive=enabled)
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(
label="Upload sharding configuration", visible=False
)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
submit_event = msg.submit(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat_fn,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
submit_click_event = submit.click(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat_fn,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)
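
For reference, the Chatbot history that flows through user() and chat_fn above is a list of [user_message, bot_message] pairs. A small sketch of one submit cycle, under the assumption that the streaming branch of chat_fn is active:

# user() appends the new message with an empty reply slot...
history = []
_, history = user("Hello, I am a robot.", history)
assert history == [["Hello, I am a robot.", ""]]
# ...and chat_fn then repeatedly yields (history, status_text) with
# history[-1][1] growing as tokens arrive.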

View File

@@ -53,6 +53,7 @@ datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")
datas += collect_data_files("cv2")
datas += collect_data_files("einops")
datas += [
("src/utils/resources/prompts.json", "resources"),
("src/utils/resources/model_db.json", "resources"),

View File

@@ -8,6 +8,7 @@ import traceback
import subprocess
import sys
import os
import requests
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_opt_flags,
@@ -16,6 +17,7 @@ from apps.stable_diffusion.src.utils import (
preprocessCKPT,
convert_original_vae,
get_path_to_diffusers_checkpoint,
get_civitai_checkpoint,
fetch_and_update_base_model_id,
get_path_stem,
get_extended_name,
@@ -94,21 +96,19 @@ class SharkifyStableDiffusionModel:
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
self.custom_weights = custom_weights.strip()
self.use_quantize = use_quantize
if custom_weights != "":
if "civitai" in custom_weights:
weights_id = custom_weights.split("/")[-1]
# TODO: use model name and identify file type by civitai rest api
weights_path = (
str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
)
if not os.path.isfile(weights_path):
subprocess.run(
["wget", custom_weights, "-O", weights_path]
)
if custom_weights.startswith("https://civitai.com/api/"):
# download the checkpoint from civitai if we don't already have it
weights_path = get_civitai_checkpoint(custom_weights)
# act as if we were given the local file as custom_weights originally
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
self.custom_weights = weights_path
# needed to ensure webui sets the correct model name metadata
args.ckpt_loc = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
@@ -116,6 +116,7 @@ class SharkifyStableDiffusionModel:
custom_weights = get_path_to_diffusers_checkpoint(
custom_weights
)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":

View File

@@ -29,6 +29,10 @@ from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
from apps.stable_diffusion.src.utils import (
resamplers,
resampler_list,
)
class Image2ImagePipeline(StableDiffusionPipeline):
@@ -91,26 +95,12 @@ class Image2ImagePipeline(StableDiffusionPipeline):
# TODO: process with variable HxW combos
# Pre-process image
if resample_type == "Lanczos":
resample_type = Image.LANCZOS
elif resample_type == "Nearest Neighbor":
resample_type = Image.NEAREST
elif resample_type == "Bilinear":
resample_type = Image.BILINEAR
elif resample_type == "Bicubic":
resample_type = Image.BICUBIC
elif resample_type == "Adaptive":
resample_type = Image.ADAPTIVE
elif resample_type == "Antialias":
resample_type = Image.ANTIALIAS
elif resample_type == "Box":
resample_type = Image.BOX
elif resample_type == "Affine":
resample_type = Image.AFFINE
elif resample_type == "Cubic":
resample_type = Image.CUBIC
else: # Fallback to Lanczos
resample_type = Image.LANCZOS
resample_type = (
resamplers[resample_type]
if resample_type in resampler_list
# Fallback to Lanczos
else Image.Resampling.LANCZOS
)
image = image.resize((width, height), resample=resample_type)
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)

View File

@@ -41,3 +41,8 @@ from apps.stable_diffusion.src.utils.utils import (
resize_stencil,
_compile_module,
)
from apps.stable_diffusion.src.utils.civitai import get_civitai_checkpoint
from apps.stable_diffusion.src.utils.resamplers import (
resamplers,
resampler_list,
)

View File

@@ -0,0 +1,42 @@
import re
import requests
from apps.stable_diffusion.src.utils.stable_args import args
from pathlib import Path
from tqdm import tqdm
def get_civitai_checkpoint(url: str):
with requests.get(url, allow_redirects=True, stream=True) as response:
response.raise_for_status()
# civitai api returns the filename in the content disposition
base_filename = re.findall(
'"([^"]*)"', response.headers["Content-Disposition"]
)[0]
destination_path = (
Path.cwd() / (args.ckpt_dir or "models") / base_filename
)
# we don't have this model downloaded yet
if not destination_path.is_file():
print(
f"downloading civitai model from {url} to {destination_path}"
)
size = int(response.headers["content-length"], 0)
progress_bar = tqdm(total=size, unit="iB", unit_scale=True)
with open(destination_path, "wb") as f:
for chunk in response.iter_content(chunk_size=65536):
f.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
# we already have this model downloaded
else:
print(f"civitai model already downloaded to {destination_path}")
response.close()
return destination_path.as_posix()
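
A usage sketch for the helper above; the model-version id in the URL is a placeholder, and in the webui the call is actually driven by the civitai branch of SharkifyStableDiffusionModel shown earlier:

# Download (or reuse) a checkpoint via the civitai download API and get back
# the local checkpoint path for the rest of the pipeline.
ckpt_path = get_civitai_checkpoint(
    "https://civitai.com/api/download/models/000000"  # placeholder id
)
print(f"checkpoint available at {ckpt_path}")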

View File

@@ -0,0 +1,12 @@
import PIL.Image as Image
resamplers = {
"Lanczos": Image.Resampling.LANCZOS,
"Nearest Neighbor": Image.Resampling.NEAREST,
"Bilinear": Image.Resampling.BILINEAR,
"Bicubic": Image.Resampling.BICUBIC,
"Hamming": Image.Resampling.HAMMING,
"Box": Image.Resampling.BOX,
}
resampler_list = resamplers.keys()
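
A small sketch of how this table is consumed, mirroring the lookup-with-fallback added to Image2ImagePipeline above; image is assumed to be a PIL.Image:

# Resolve a UI string such as "Bicubic" to a PIL resampling filter, falling
# back to Lanczos for anything unknown (e.g. a value from an old saved config).
def resolve_resampler(name: str):
    return resamplers.get(name, Image.Resampling.LANCZOS)

# image = image.resize((width, height), resample=resolve_resampler("Bicubic"))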

View File

@@ -11,12 +11,12 @@
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
@@ -28,7 +28,7 @@
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
},
@@ -37,7 +37,7 @@
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}
@@ -45,12 +45,12 @@
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}

View File

@@ -203,8 +203,8 @@ def dump_after_mlir(input_mlir, use_winograd):
if use_winograd:
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
"iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
"iree-linalg-ext-convert-conv2d-to-winograd))"
@@ -212,8 +212,8 @@ def dump_after_mlir(input_mlir, use_winograd):
else:
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
"iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32}))"
)

View File

@@ -2,6 +2,8 @@ import argparse
import os
from pathlib import Path
from apps.stable_diffusion.src.utils.resamplers import resampler_list
def path_expand(s):
return Path(s).expanduser().resolve()
@@ -168,17 +170,7 @@ p.add_argument(
"--resample_type",
type=str,
default="Nearest Neighbor",
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
choices=resampler_list,
help="The resample type to use when resizing an image before being run "
"through stable diffusion.",
)
@@ -253,28 +245,30 @@ p.add_argument(
"--left",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend left for outpainting.",
help="If extend left for outpainting.",
)
p.add_argument(
"--right",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend right for outpainting.",
help="If extend right for outpainting.",
)
p.add_argument(
"--up",
"--top",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend top for outpainting.",
help="If extend top for outpainting.",
)
p.add_argument(
"--down",
"--bottom",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend bottom for outpainting.",
help="If extend bottom for outpainting.",
)
p.add_argument(
@@ -306,7 +300,7 @@ p.add_argument(
p.add_argument(
"--import_mlir",
default=False,
default=True,
action=argparse.BooleanOptionalAction,
help="Imports the model from torch module to shark_module otherwise "
"downloads the model from shark_tank.",
@@ -329,7 +323,7 @@ p.add_argument(
p.add_argument(
"--use_tuned",
default=True,
default=False,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available.",
)
@@ -422,7 +416,7 @@ p.add_argument(
p.add_argument(
"--use_stencil",
choices=["canny", "openpose", "scribble"],
choices=["canny", "openpose", "scribble", "zoedepth"],
help="Enable the stencil feature.",
)
@@ -641,6 +635,18 @@ p.add_argument(
help="Flag for enabling rest API.",
)
p.add_argument(
"--api_accept_origin",
action="append",
type=str,
help="An origin to be accepted by the REST api for Cross Origin"
"Resource Sharing (CORS). Use multiple times for multiple origins, "
'or use --api_accept_origin="*" to accept all origins. If no origins '
"are set no CORS headers will be returned by the api. Use, for "
"instance, if you need to access the REST api from Javascript running "
"in a web browser.",
)
p.add_argument(
"--debug",
default=False,
@@ -725,6 +731,18 @@ p.add_argument(
help="Specifies whether the docuchat's web version is running or not.",
)
##############################################################################
# rocm Flags
##############################################################################
p.add_argument(
"--iree_rocm_target_chip",
type=str,
default="",
help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Use `hipinfo` "
"or `iree-run-module --dump_devices=rocm` or `hipinfo` to get desired arch name",
)
args, unknown = p.parse_known_args()
if args.import_debug:
os.environ["IREE_SAVE_TEMPS"] = os.path.join(

View File

@@ -1,2 +1,3 @@
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector
from apps.stable_diffusion.src.utils.stencils.zoe import ZoeDetector

View File

@@ -4,6 +4,7 @@ import torch
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
ZoeDetector,
)
stencil = {}
@@ -117,6 +118,9 @@ def controlnet_hint_conversion(
case "scribble":
print("Working with scribble")
controlnet_hint = hint_scribble(image)
case "zoedepth":
print("Working with ZoeDepth")
controlnet_hint = hint_zoedepth(image)
case _:
return None
controlnet_hint = controlnet_hint_shaping(
@@ -127,7 +131,7 @@ def controlnet_hint_conversion(
stencil_to_model_id_map = {
"canny": "lllyasviel/control_v11p_sd15_canny",
"depth": "lllyasviel/control_v11p_sd15_depth",
"zoedepth": "lllyasviel/control_v11f1p_sd15_depth",
"hed": "lllyasviel/sd-controlnet-hed",
"mlsd": "lllyasviel/control_v11p_sd15_mlsd",
"normal": "lllyasviel/control_v11p_sd15_normalbae",
@@ -184,3 +188,16 @@ def hint_scribble(image: Image.Image):
detected_map = np.zeros_like(input_image, dtype=np.uint8)
detected_map[np.min(input_image, axis=2) < 127] = 255
return detected_map
# Stencil 4. Depth (Only Zoe Preprocessing)
def hint_zoedepth(image: Image.Image):
with torch.no_grad():
input_image = np.array(image)
if not "depth" in stencil:
stencil["depth"] = ZoeDetector()
detected_map = stencil["depth"](input_image)
detected_map = HWC3(detected_map)
return detected_map

View File

@@ -0,0 +1,58 @@
import numpy as np
import torch
from pathlib import Path
import requests
from einops import rearrange
remote_model_path = (
"https://huggingface.co/lllyasviel/Annotators/resolve/main/ZoeD_M12_N.pt"
)
class ZoeDetector:
def __init__(self):
cwd = Path.cwd()
ckpt_path = Path(cwd, "stencil_annotator")
ckpt_path.mkdir(parents=True, exist_ok=True)
modelpath = ckpt_path / "ZoeD_M12_N.pt"
# NOTE: there is no existing-file check here, so the checkpoint is
# re-downloaded every time a ZoeDetector is constructed.
with requests.get(remote_model_path, stream=True) as r:
r.raise_for_status()
with open(modelpath, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
model = torch.hub.load(
"monorimet/ZoeDepth:torch_update",
"ZoeD_N",
pretrained=False,
force_reload=False,
)
model.load_state_dict(
torch.load(modelpath, map_location=model.device)["model"]
)
model.eval()
self.model = model
def __call__(self, input_image):
assert input_image.ndim == 3
image_depth = input_image
with torch.no_grad():
image_depth = torch.from_numpy(image_depth).float()
image_depth = image_depth / 255.0
image_depth = rearrange(image_depth, "h w c -> 1 c h w")
depth = self.model.infer(image_depth)
depth = depth[0, 0].cpu().numpy()
vmin = np.percentile(depth, 2)
vmax = np.percentile(depth, 85)
depth -= vmin
depth /= vmax - vmin
depth = 1.0 - depth
depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)
return depth_image
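
A usage sketch for the detector above, matching how hint_zoedepth calls it from the controlnet utilities; the input path is a placeholder:

import numpy as np
from PIL import Image

detector = ZoeDetector()  # NOTE: fetches ZoeD_M12_N.pt on construction
rgb = np.array(Image.open("input.png").convert("RGB"))  # placeholder image
depth_hint = detector(rgb)  # uint8 depth map suitable as a controlnet hint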

View File

@@ -477,7 +477,14 @@ def get_available_devices():
f"{device_name} => {driver_name.replace('local', 'cpu')}"
)
else:
device_list.append(f"{device_name} => {driver_name}://{i}")
# for drivers that expose a single device, let the default
# device be selected without any index suffix
if len(device_list_dict) == 1:
device_list.append(f"{device_name} => {driver_name}")
else:
device_list.append(
f"{device_name} => {driver_name}://{i}"
)
return device_list
set_iree_runtime_flags()
@@ -804,11 +811,12 @@ def batch_seeds(
seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))
if repeatable:
# set seed for the rng based on what we have so far
saved_random_state = random_getstate()
if all(seed < 0 for seed in seeds):
seeds[0] = sanitize_seed(seeds[0])
seed_random(str(seeds))
# set seed for the rng based on what we have so far
saved_random_state = random_getstate()
seed_random(str([n for n in seeds if n > -1]))
# generate any seeds that are unspecified
seeds = [sanitize_seed(seed) for seed in seeds]
@@ -894,6 +902,13 @@ def save_output_img(output_img, img_seed, extra_info=None):
pngInfo = PngImagePlugin.PngInfo()
if args.write_metadata_to_png:
# Using a conditional expression caused problems, so setting a new
# variable for now.
if args.use_hiresfix:
png_size_text = f"{args.hiresfix_width}x{args.hiresfix_height}"
else:
png_size_text = f"{args.width}x{args.height}"
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}"
@@ -902,7 +917,7 @@ def save_output_img(output_img, img_seed, extra_info=None):
f"Sampler: {args.scheduler}, "
f"CFG scale: {args.guidance_scale}, "
f"Seed: {img_seed},"
f"Size: {args.width}x{args.height}, "
f"Size: {png_size_text}, "
f"Model: {img_model}, "
f"VAE: {img_vae}, "
f"LoRA: {img_lora}",
@@ -929,8 +944,10 @@ def save_output_img(output_img, img_seed, extra_info=None):
"CFG_SCALE": args.guidance_scale,
"PRECISION": args.precision,
"STEPS": args.steps,
"HEIGHT": args.height,
"WIDTH": args.width,
"HEIGHT": args.height
if not args.use_hiresfix
else args.hiresfix_height,
"WIDTH": args.width if not args.use_hiresfix else args.hiresfix_width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
"VAE": img_vae,
@@ -968,6 +985,10 @@ def get_generation_text_info(seeds, device):
)
text_output += (
f"\nsize={args.height}x{args.width}, "
if not args.use_hiresfix
else f"\nsize={args.hiresfix_height}x{args.hiresfix_width}, "
)
text_output += (
f"batch_count={args.batch_count}, "
f"batch_size={args.batch_size}, "
f"max_length={args.max_length}"

View File

@@ -0,0 +1 @@
from apps.stable_diffusion.web.api.sdapi_v1 import sdapi

View File

@@ -0,0 +1,579 @@
import os
from collections import defaultdict
from enum import Enum
from fastapi import FastAPI
from pydantic import BaseModel, Field, conlist, model_validator
from apps.stable_diffusion.web.api.utils import (
frozen_args,
sampler_aliases,
encode_pil_to_base64,
decode_base64_to_image,
get_model_from_request,
get_scheduler_from_request,
get_lora_params,
get_device,
GenerationInputData,
GenerationResponseData,
)
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_files,
get_custom_model_pathfile,
predefined_models,
predefined_paint_models,
predefined_upscaler_models,
scheduler_list,
)
from apps.stable_diffusion.web.ui.txt2img_ui import txt2img_inf
from apps.stable_diffusion.web.ui.img2img_ui import img2img_inf
from apps.stable_diffusion.web.ui.inpaint_ui import inpaint_inf
from apps.stable_diffusion.web.ui.outpaint_ui import outpaint_inf
from apps.stable_diffusion.web.ui.upscaler_ui import upscaler_inf
sdapi = FastAPI()
# Rest API: /sdapi/v1/sd-models (lists available models)
class AppParam(str, Enum):
txt2img = "txt2img"
img2img = "img2img"
inpaint = "inpaint"
outpaint = "outpaint"
upscaler = "upscaler"
@sdapi.get(
"/v1/sd-models",
summary="lists available models",
description=(
"This is all the models that this server currently knows about.\n "
"Models listed may still have a compilation and build pending that "
"will be triggered the first time they are used."
),
)
def sd_models_api(app: AppParam = frozen_args.app):
match app:
case "inpaint" | "outpaint":
checkpoint_type = "inpainting"
predefined = predefined_paint_models
case "upscaler":
checkpoint_type = "upscaler"
predefined = predefined_upscaler_models
case _:
checkpoint_type = ""
predefined = predefined_models
return [
{
"title": model_file,
"model_name": model_file,
"hash": None,
"sha256": None,
"filename": get_custom_model_pathfile(model_file),
"config": None,
}
for model_file in get_custom_model_files(
custom_checkpoint_type=checkpoint_type
)
] + [
{
"title": model,
"model_name": model,
"hash": None,
"sha256": None,
"filename": None,
"config": None,
}
for model in predefined
]
# Rest API: /sdapi/v1/samplers (lists schedulers)
@sdapi.get(
"/v1/samplers",
summary="lists available schedulers/samplers",
description=(
"These are all the Schedulers defined and available. Not "
"every scheduler is compatible with all apis. Aliases are "
"equivalent samplers in A1111 if they are known."
),
)
def sd_samplers_api():
reverse_sampler_aliases = defaultdict(list)
for key, value in sampler_aliases.items():
reverse_sampler_aliases[value].append(key)
return (
{
"name": scheduler,
"aliases": reverse_sampler_aliases.get(scheduler, []),
"options": {},
}
for scheduler in scheduler_list
)
# Rest API: /sdapi/v1/options (lists application level options)
@sdapi.get(
"/v1/options",
summary="lists current settings of application level options",
description=(
"A subset of the command line arguments set at startup renamed "
"to correspond to the A1111 naming. Only a small subset of A1111 "
"options are returned."
),
)
def options_api():
# This is mostly just enough to support what Koboldcpp wants, with a
# few other things that seemed obvious
return {
"samples_save": True,
"samples_format": frozen_args.output_img_format,
"sd_model_checkpoint": os.path.basename(frozen_args.ckpt_loc)
if frozen_args.ckpt_loc
else frozen_args.hf_model_id,
"sd_lora": frozen_args.use_lora,
"sd_vae": frozen_args.custom_vae or "Automatic",
"enable_pnginfo": frozen_args.write_metadata_to_png,
}
# Rest API: /sdapi/v1/cmd-flags (lists command line argument settings)
@sdapi.get(
"/v1/cmd-flags",
summary="lists the command line arguments value that were set on startup.",
)
def cmd_flags_api():
return vars(frozen_args)
# Rest API: /sdapi/v1/txt2img (Text to image)
class ModelOverrideSettings(BaseModel):
sd_model_checkpoint: str = get_model_from_request(
fallback_model="stabilityai/stable-diffusion-2-1-base"
)
class Txt2ImgInputData(GenerationInputData):
enable_hr: bool = frozen_args.use_hiresfix
hr_resize_y: int = Field(
default=frozen_args.hiresfix_height, ge=128, le=768, multiple_of=8
)
hr_resize_x: int = Field(
default=frozen_args.hiresfix_width, ge=128, le=768, multiple_of=8
)
override_settings: ModelOverrideSettings = None
@sdapi.post(
"/v1/txt2img",
summary="Does text to image generation",
response_model=GenerationResponseData,
)
def txt2img_api(InputData: Txt2ImgInputData):
model_id = get_model_from_request(
InputData,
fallback_model="stabilityai/stable-diffusion-2-1-base",
)
scheduler = get_scheduler_from_request(
InputData, "txt2img_hires" if InputData.enable_hr else "txt2img"
)
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
print(
f"Prompt: {InputData.prompt}, "
f"Negative Prompt: {InputData.negative_prompt}, "
f"Seed: {InputData.seed},"
f"Model: {model_id}, "
f"Scheduler: {scheduler}. "
)
res = txt2img_inf(
InputData.prompt,
InputData.negative_prompt,
InputData.height,
InputData.width,
InputData.steps,
InputData.cfg_scale,
InputData.seed,
batch_count=InputData.n_iter,
batch_size=1,
scheduler=scheduler,
model_id=model_id,
custom_vae=frozen_args.custom_vae or "None",
precision="fp16",
device=get_device(frozen_args.device),
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
use_hiresfix=InputData.enable_hr,
hiresfix_height=InputData.hr_resize_y,
hiresfix_width=InputData.hr_resize_x,
hiresfix_strength=frozen_args.hiresfix_strength,
resample_type=frozen_args.resample_type,
)
# Since we're not streaming we just want the last generator result
for items_so_far in res:
items = items_so_far
return {
"images": encode_pil_to_base64(items[0]),
"parameters": {},
"info": items[1],
}
# Rest API: /sdapi/v1/img2img (Image to image)
class StencilParam(str, Enum):
canny = "canny"
openpose = "openpose"
scribble = "scribble"
zoedepth = "zoedepth"
class Img2ImgInputData(GenerationInputData):
init_images: conlist(str, min_length=1, max_length=2)
denoising_strength: float = frozen_args.strength
use_stencil: StencilParam = frozen_args.use_stencil
override_settings: ModelOverrideSettings = None
@model_validator(mode="after")
def check_image_supplied_for_scribble_stencil(self) -> "Img2ImgInputData":
if (
self.use_stencil == StencilParam.scribble
and len(self.init_images) < 2
):
raise ValueError(
"a second image must be supplied for the controlnet:scribble stencil"
)
return self
@sdapi.post(
"/v1/img2img",
summary="Does image to image generation",
response_model=GenerationResponseData,
)
def img2img_api(
InputData: Img2ImgInputData,
):
model_id = get_model_from_request(
InputData,
fallback_model="stabilityai/stable-diffusion-2-1-base",
)
scheduler = get_scheduler_from_request(InputData, "img2img")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.init_images[0])
mask_image = (
decode_base64_to_image(InputData.init_images[1])
if len(InputData.init_images) > 1
else None
)
print(
f"Prompt: {InputData.prompt}, "
f"Negative Prompt: {InputData.negative_prompt}, "
f"Seed: {InputData.seed}, "
f"Model: {model_id}, "
f"Scheduler: {scheduler}."
)
res = img2img_inf(
InputData.prompt,
InputData.negative_prompt,
{"image": init_image, "mask": mask_image},
InputData.height,
InputData.width,
InputData.steps,
InputData.denoising_strength,
InputData.cfg_scale,
InputData.seed,
batch_count=InputData.n_iter,
batch_size=1,
scheduler=scheduler,
model_id=model_id,
custom_vae=frozen_args.custom_vae or "None",
precision="fp16",
device=get_device(frozen_args.device),
max_length=frozen_args.max_length,
use_stencil=InputData.use_stencil,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
resample_type=frozen_args.resample_type,
)
# Since we're not streaming we just want the last generator result
for items_so_far in res:
items = items_so_far
return {
"images": encode_pil_to_base64(items[0]),
"parameters": {},
"info": items[1],
}
# Rest API: /sdapi/v1/inpaint (Inpainting)
class PaintModelOverideSettings(BaseModel):
sd_model_checkpoint: str = get_model_from_request(
checkpoint_type="inpainting",
fallback_model="stabilityai/stable-diffusion-2-inpainting",
)
class InpaintInputData(GenerationInputData):
image: str = Field(description="Base64 encoded input image")
mask: str = Field(description="Base64 encoded mask image")
is_full_res: bool = False # Is this setting backwards in the UI?
full_res_padding: int = Field(default=32, ge=0, le=256, multiple_of=4)
denoising_strength: float = frozen_args.strength
use_stencil: StencilParam = frozen_args.use_stencil
override_settings: PaintModelOverideSettings = None
@sdapi.post(
"/v1/inpaint",
summary="Does inpainting generation on an image",
response_model=GenerationResponseData,
)
def inpaint_api(
InputData: InpaintInputData,
):
model_id = get_model_from_request(
InputData,
checkpoint_type="inpainting",
fallback_model="stabilityai/stable-diffusion-2-inpainting",
)
scheduler = get_scheduler_from_request(InputData, "inpaint")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.image)
mask = decode_base64_to_image(InputData.mask)
print(
f"Prompt: {InputData.prompt}, "
f"Negative Prompt: {InputData.negative_prompt}, "
f"Seed: {InputData.seed}, "
f"Model: {model_id}, "
f"Scheduler: {scheduler}."
)
res = inpaint_inf(
InputData.prompt,
InputData.negative_prompt,
{"image": init_image, "mask": mask},
InputData.height,
InputData.width,
InputData.is_full_res,
InputData.full_res_padding,
InputData.steps,
InputData.cfg_scale,
InputData.seed,
batch_count=InputData.n_iter,
batch_size=1,
scheduler=scheduler,
model_id=model_id,
custom_vae=frozen_args.custom_vae or "None",
precision="fp16",
device=get_device(frozen_args.device),
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)
# Since we're not streaming we just want the last generator result
for items_so_far in res:
items = items_so_far
return {
"images": encode_pil_to_base64(items[0]),
"parameters": {},
"info": items[1],
}
# Rest API: /sdapi/v1/outpaint (Outpainting)
class DirectionParam(str, Enum):
left = "left"
right = "right"
up = "up"
down = "down"
class OutpaintInputData(GenerationInputData):
init_images: list[str]
pixels: int = Field(
default=frozen_args.pixels, ge=8, le=256, multiple_of=8
)
mask_blur: int = Field(default=frozen_args.mask_blur, ge=0, le=64)
directions: set[DirectionParam] = [
direction
for direction in ["left", "right", "up", "down"]
if vars(frozen_args)[direction]
]
noise_q: float = frozen_args.noise_q
color_variation: float = frozen_args.color_variation
override_settings: PaintModelOverideSettings = None
@sdapi.post(
"/v1/outpaint",
summary="Does outpainting generation on an image",
response_model=GenerationResponseData,
)
def outpaint_api(
InputData: OutpaintInputData,
):
model_id = get_model_from_request(
InputData,
checkpoint_type="inpainting",
fallback_model="stabilityai/stable-diffusion-2-inpainting",
)
scheduler = get_scheduler_from_request(InputData, "outpaint")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.init_images[0])
print(
f"Prompt: {InputData.prompt}, "
f"Negative Prompt: {InputData.negative_prompt}, "
f"Seed: {InputData.seed}, "
f"Model: {model_id}, "
f"Scheduler: {scheduler}."
)
res = outpaint_inf(
InputData.prompt,
InputData.negative_prompt,
init_image,
InputData.pixels,
InputData.mask_blur,
InputData.directions,
InputData.noise_q,
InputData.color_variation,
InputData.height,
InputData.width,
InputData.steps,
InputData.cfg_scale,
InputData.seed,
batch_count=InputData.n_iter,
batch_size=1,
scheduler=scheduler,
model_id=model_id,
custom_vae=frozen_args.custom_vae or "None",
precision="fp16",
device=get_device(frozen_args.device),
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)
# Since we're not streaming we just want the last generator result
for items_so_far in res:
items = items_so_far
return {
"images": encode_pil_to_base64(items[0]),
"parameters": {},
"info": items[1],
}
# Rest API: /sdapi/v1/upscaler (Upscaling)
class UpscalerModelOverideSettings(BaseModel):
sd_model_checkpoint: str = get_model_from_request(
checkpoint_type="upscaler",
fallback_model="stabilityai/stable-diffusion-x4-upscaler",
)
class UpscalerInputData(GenerationInputData):
init_images: list[str] = Field(
description="Base64 encoded image to upscale"
)
noise_level: int = frozen_args.noise_level
override_settings: UpscalerModelOverideSettings = None
@sdapi.post(
"/v1/upscaler",
summary="Does image upscaling",
response_model=GenerationResponseData,
)
def upscaler_api(
InputData: UpscalerInputData,
):
model_id = get_model_from_request(
InputData,
checkpoint_type="upscaler",
fallback_model="stabilityai/stable-diffusion-x4-upscaler",
)
scheduler = get_scheduler_from_request(InputData, "upscaler")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.init_images[0])
print(
f"Prompt: {InputData.prompt}, "
f"Negative Prompt: {InputData.negative_prompt}, "
f"Seed: {InputData.seed}, "
f"Model: {model_id}, "
f"Scheduler: {scheduler}."
)
res = upscaler_inf(
InputData.prompt,
InputData.negative_prompt,
init_image,
InputData.height,
InputData.width,
InputData.steps,
InputData.noise_level,
InputData.cfg_scale,
InputData.seed,
batch_count=InputData.n_iter,
batch_size=1,
scheduler=scheduler,
model_id=model_id,
custom_vae=frozen_args.custom_vae or "None",
precision="fp16",
device=get_device(frozen_args.device),
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)
# Since we're not streaming we just want the last generator result
for items_so_far in res:
items = items_so_far
return {
"images": encode_pil_to_base64(items[0]),
"parameters": {},
"info": items[1],
}
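
A hedged client-side sketch for the txt2img route above (the other routes follow the same pattern). The host, port and prompt are placeholders; the response fields match GenerationResponseData from web/api/utils.py:

import base64
import requests

# Minimal client for POST /sdapi/v1/txt2img; assumes the server was started
# with the REST API enabled and is listening on localhost:8080 (illustrative).
payload = {
    "prompt": "a photo of a corgi wearing sunglasses",
    "negative_prompt": "",
    "width": 512,
    "height": 512,
    "steps": 20,
    "cfg_scale": 7.5,
    "seed": -1,
}
resp = requests.post("http://localhost:8080/sdapi/v1/txt2img", json=payload)
resp.raise_for_status()
data = resp.json()
with open("txt2img_out.png", "wb") as f:
    f.write(base64.b64decode(data["images"][0]))
print(data["info"])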

View File

@@ -0,0 +1,211 @@
import base64
import pickle
from argparse import Namespace
from fastapi.exceptions import HTTPException
from io import BytesIO
from PIL import Image
from pydantic import BaseModel, Field
from apps.stable_diffusion.src import args
from apps.stable_diffusion.web.ui.utils import (
available_devices,
get_custom_model_files,
predefined_models,
predefined_paint_models,
predefined_upscaler_models,
scheduler_list,
scheduler_list_cpu_only,
)
# Probably overly cautious, but try to ensure we only use the starting
# args in each api call, as the code does `args.<whatever> = <changed_value>`
# in lots of places and in testing, it seemed to me, these changes leaked
# into subsequent api calls.
# Roundtripping through pickle for deepcopy, there is probably a better way
frozen_args = Namespace(**(pickle.loads(pickle.dumps(vars(args)))))
# an attempt to map some of the A1111 sampler names to scheduler names
# https://github.com/huggingface/diffusers/issues/4167 is where the
# (not so obvious) ones come from
sampler_aliases = {
# a1111/onnx (these point to diffusers classes in A1111)
"pndm": "PNDM",
"heun": "HeunDiscrete",
"ddim": "DDIM",
"ddpm": "DDPM",
"euler": "EulerDiscrete",
"euler-ancestral": "EulerAncestralDiscrete",
"dpm": "DPMSolverMultistep",
# a1111/k_diffusion (the obvious ones)
"Euler a": "EulerAncestralDiscrete",
"Euler": "EulerDiscrete",
"LMS": "LMSDiscrete",
"Heun": "HeunDiscrete",
# a1111/k_diffusion (not so obvious)
"DPM++ 2M": "DPMSolverMultistep",
"DPM++ 2M Karras": "DPMSolverMultistepKarras",
"DPM++ 2M SDE": "DPMSolverMultistep++",
"DPM++ 2M SDE Karras": "DPMSolverMultistepKarras++",
"DPM2": "KDPM2Discrete",
"DPM2 a": "KDPM2AncestralDiscrete",
}
allowed_schedulers = {
"txt2img": {
"schedulers": scheduler_list,
"fallback": "SharkEulerDiscrete",
},
"txt2img_hires": {
"schedulers": scheduler_list_cpu_only,
"fallback": "DEISMultistep",
},
"img2img": {
"schedulers": scheduler_list_cpu_only,
"fallback": "EulerDiscrete",
},
"inpaint": {
"schedulers": scheduler_list_cpu_only,
"fallback": "DDIM",
},
"outpaint": {
"schedulers": scheduler_list_cpu_only,
"fallback": "DDIM",
},
"upscaler": {
"schedulers": scheduler_list_cpu_only,
"fallback": "DDIM",
},
}
# base pydantic model for sd generation apis
class GenerationInputData(BaseModel):
prompt: str = ""
negative_prompt: str = ""
hf_model_id: str | None = None
height: int = Field(
default=frozen_args.height, ge=128, le=768, multiple_of=8
)
width: int = Field(
default=frozen_args.width, ge=128, le=768, multiple_of=8
)
sampler_name: str = frozen_args.scheduler
cfg_scale: float = Field(default=frozen_args.guidance_scale, ge=1)
steps: int = Field(default=frozen_args.steps, ge=1, le=100)
seed: int = frozen_args.seed
n_iter: int = Field(default=frozen_args.batch_count)
class GenerationResponseData(BaseModel):
images: list[str] = Field(description="Generated images, Base64 encoded")
properties: dict = {}
info: str
# image encoding/decoding
def encode_pil_to_base64(images: list[Image.Image]):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if frozen_args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif frozen_args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
def decode_base64_to_image(encoding: str):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=400, detail="Invalid encoded image")
# get valid sd models/vaes/schedulers etc.
def get_predefined_models(custom_checkpoint_type: str):
match custom_checkpoint_type:
case "inpainting":
return predefined_paint_models
case "upscaler":
return predefined_upscaler_models
case _:
return predefined_models
def get_model_from_request(
request_data=None,
checkpoint_type: str = "",
fallback_model: str = "",
):
model = None
if request_data:
if request_data.hf_model_id:
model = request_data.hf_model_id
elif request_data.override_settings:
model = request_data.override_settings.sd_model_checkpoint
# if the request didn't specify a model try the command line args
result = model or frozen_args.ckpt_loc or frozen_args.hf_model_id
# make sure whatever we have is a valid model for the checkpoint type
if result in get_custom_model_files(
custom_checkpoint_type=checkpoint_type
) + get_predefined_models(checkpoint_type):
return result
# if not return what was specified as the fallback
else:
return fallback_model
def get_scheduler_from_request(
request_data: GenerationInputData, operation: str
):
allowed = allowed_schedulers[operation]
requested = request_data.sampler_name
requested = sampler_aliases.get(requested, requested)
return (
requested
if requested in allowed["schedulers"]
else allowed["fallback"]
)
def get_lora_params(use_lora: str):
# TODO: since the inference functions in the webui, which we are
# still calling into for the api, jam these back together again before
# handing them off to the pipeline, we should remove this nonsense
# and unify their selection in the UI and command line args proper
if use_lora in get_custom_model_files("lora"):
return (use_lora, "")
return ("None", use_lora)
def get_device(device_str: str):
# first substring match in the list available devices, with first
# device when none are matched
return next(
(device for device in available_devices if device_str in device),
available_devices[0],
)
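
For reference, here is a minimal standalone sketch of how the sampler alias and fallback resolution above behaves. The scheduler lists and mappings below are stand-ins for illustration only, not the real `sampler_aliases`, `scheduler_list`, or `scheduler_list_cpu_only` values from the web UI utils.

```python
# Hypothetical stand-in data: the real mappings live in the module shown above.
sampler_aliases = {
    "Euler a": "EulerAncestralDiscrete",
    "DPM++ 2M": "DPMSolverMultistep",
}
allowed_schedulers = {
    "img2img": {
        "schedulers": ["EulerDiscrete", "EulerAncestralDiscrete", "DDIM"],
        "fallback": "EulerDiscrete",
    },
}

def resolve_scheduler(sampler_name: str, operation: str) -> str:
    # Map an A1111 sampler name to a scheduler, falling back when unsupported.
    allowed = allowed_schedulers[operation]
    requested = sampler_aliases.get(sampler_name, sampler_name)
    return requested if requested in allowed["schedulers"] else allowed["fallback"]

print(resolve_scheduler("Euler a", "img2img"))  # EulerAncestralDiscrete
print(resolve_scheduler("UniPC", "img2img"))    # EulerDiscrete (fallback)
```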

View File

@@ -1,7 +1,8 @@
from multiprocessing import Process, freeze_support
from multiprocessing import freeze_support
import os
import sys
import logging
import apps.stable_diffusion.web.utils.app as app
if sys.platform == "darwin":
# import before IREE to avoid torch-MLIR library issues
@@ -21,26 +22,6 @@ if args.clear_all:
clear_all()
def launch_app(address):
from tkinter import Tk
import webview
window = Tk()
# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False, storage_path=os.getcwd())
if __name__ == "__main__":
if args.debug:
logging.basicConfig(level=logging.DEBUG)
@@ -48,39 +29,47 @@ if __name__ == "__main__":
freeze_support()
if args.api or "api" in args.ui.split(","):
from apps.stable_diffusion.web.ui import (
txt2img_api,
img2img_api,
upscaler_api,
inpaint_api,
outpaint_api,
llm_chat_api,
)
from apps.stable_diffusion.web.api import sdapi
from fastapi import FastAPI, APIRouter
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
# init global sd pipeline and config
global_obj._init()
app = FastAPI()
app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
api = FastAPI()
api.mount("/sdapi/", sdapi)
# chat APIs needed for compatibility with multiple extensions using OpenAI API
app.add_api_route(
api.add_api_route(
"/v1/chat/completions", llm_chat_api, methods=["post"]
)
app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
app.add_api_route("/completions", llm_chat_api, methods=["post"])
app.add_api_route(
api.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
api.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
api.add_api_route("/completions", llm_chat_api, methods=["post"])
api.add_api_route(
"/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
)
app.include_router(APIRouter())
uvicorn.run(app, host="0.0.0.0", port=args.server_port)
api.include_router(APIRouter())
# deal with CORS requests if CORS accept origins are set
if args.api_accept_origin:
print(
f"API Configured for CORS. Accepting origins: { args.api_accept_origin }"
)
api.add_middleware(
CORSMiddleware,
allow_origins=args.api_accept_origin,
allow_methods=["GET", "POST"],
allow_headers=["*"],
)
else:
print("API not configured for CORS")
uvicorn.run(api, host="0.0.0.0", port=args.server_port)
sys.exit(0)
# Setup to use shark_tmp for gradio's temporary image files and clear any
@@ -94,7 +83,10 @@ if __name__ == "__main__":
import gradio as gr
# Create custom models folders if they don't exist
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
from apps.stable_diffusion.web.ui.utils import (
create_custom_models_folders,
nodicon_loc,
)
create_custom_models_folders()
@@ -110,7 +102,6 @@ if __name__ == "__main__":
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_custom_model,
txt2img_hf_model_id,
txt2img_gallery,
txt2img_png_info_img,
txt2img_status,
@@ -122,7 +113,6 @@ if __name__ == "__main__":
# h2ogpt_web,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
img2img_gallery,
img2img_init_image,
img2img_status,
@@ -131,7 +121,6 @@ if __name__ == "__main__":
img2img_sendto_upscaler,
inpaint_web,
inpaint_custom_model,
inpaint_hf_model_id,
inpaint_gallery,
inpaint_init_image,
inpaint_status,
@@ -140,7 +129,6 @@ if __name__ == "__main__":
inpaint_sendto_upscaler,
outpaint_web,
outpaint_custom_model,
outpaint_hf_model_id,
outpaint_gallery,
outpaint_init_image,
outpaint_status,
@@ -149,7 +137,6 @@ if __name__ == "__main__":
outpaint_sendto_upscaler,
upscaler_web,
upscaler_custom_model,
upscaler_hf_model_id,
upscaler_gallery,
upscaler_init_image,
upscaler_status,
@@ -213,7 +200,7 @@ if __name__ == "__main__":
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
css=dark_theme, analytics_enabled=False, title="SHARK AI Studio"
) as sd_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
@@ -267,6 +254,15 @@ if __name__ == "__main__":
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
actual_port = app.usable_port()
if actual_port != args.server_port:
sd_web.load(
fn=lambda: gr.Info(
f"Port {args.server_port} is in use by another application. "
f"Shark is running on port {actual_port} instead."
)
)
# send to buttons
register_button_click(
txt2img_sendto_img2img,
@@ -399,42 +395,38 @@ if __name__ == "__main__":
modelmanager_sendto_txt2img,
0,
[hf_models],
[txt2img_custom_model, txt2img_hf_model_id, tabs],
[txt2img_custom_model, tabs],
)
register_modelmanager_button(
modelmanager_sendto_img2img,
1,
[hf_models],
[img2img_custom_model, img2img_hf_model_id, tabs],
[img2img_custom_model, tabs],
)
register_modelmanager_button(
modelmanager_sendto_inpaint,
2,
[hf_models],
[inpaint_custom_model, inpaint_hf_model_id, tabs],
[inpaint_custom_model, tabs],
)
register_modelmanager_button(
modelmanager_sendto_outpaint,
3,
[hf_models],
[outpaint_custom_model, outpaint_hf_model_id, tabs],
[outpaint_custom_model, tabs],
)
register_modelmanager_button(
modelmanager_sendto_upscaler,
4,
[hf_models],
[upscaler_custom_model, upscaler_hf_model_id, tabs],
[upscaler_custom_model, tabs],
)
sd_web.queue()
if args.ui == "app":
t = Process(
target=launch_app, args=[f"http://localhost:{args.server_port}"]
)
t.start()
sd_web.launch(
share=args.share,
inbrowser=args.ui == "web",
inbrowser=not app.launch(actual_port),
server_name="0.0.0.0",
server_port=args.server_port,
server_port=actual_port,
favicon_path=nodicon_loc,
)

View File

@@ -1,9 +1,7 @@
from apps.stable_diffusion.web.ui.txt2img_ui import (
txt2img_inf,
txt2img_api,
txt2img_web,
txt2img_custom_model,
txt2img_hf_model_id,
txt2img_gallery,
txt2img_png_info_img,
txt2img_status,
@@ -14,10 +12,8 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
)
from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_inf,
img2img_api,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
img2img_gallery,
img2img_init_image,
img2img_status,
@@ -27,10 +23,8 @@ from apps.stable_diffusion.web.ui.img2img_ui import (
)
from apps.stable_diffusion.web.ui.inpaint_ui import (
inpaint_inf,
inpaint_api,
inpaint_web,
inpaint_custom_model,
inpaint_hf_model_id,
inpaint_gallery,
inpaint_init_image,
inpaint_status,
@@ -40,10 +34,8 @@ from apps.stable_diffusion.web.ui.inpaint_ui import (
)
from apps.stable_diffusion.web.ui.outpaint_ui import (
outpaint_inf,
outpaint_api,
outpaint_web,
outpaint_custom_model,
outpaint_hf_model_id,
outpaint_gallery,
outpaint_init_image,
outpaint_status,
@@ -53,10 +45,8 @@ from apps.stable_diffusion.web.ui.outpaint_ui import (
)
from apps.stable_diffusion.web.ui.upscaler_ui import (
upscaler_inf,
upscaler_api,
upscaler_web,
upscaler_custom_model,
upscaler_hf_model_id,
upscaler_gallery,
upscaler_init_image,
upscaler_status,

View File

@@ -5,9 +5,6 @@ import gradio as gr
import PIL
from math import ceil
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -30,6 +27,7 @@ from apps.stable_diffusion.src import (
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
resampler_list,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
import numpy as np
@@ -55,8 +53,7 @@ def img2img_inf(
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
@@ -103,21 +100,17 @@ def img2img_inf(
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
# .safetensor or .chkpt on the custom model path
if model_id in get_custom_model_files():
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = custom_model
args.hf_model_id = model_id
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
@@ -282,88 +275,6 @@ def img2img_inf(
return generated_imgs, text_output, ""
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Img2Img Rest API.
def img2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = img2img_inf(
InputData["prompt"],
InputData["negative_prompt"],
init_image,
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["denoising_strength"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
use_stencil=InputData["use_stencil"]
if "use_stencil" in InputData.keys()
else "None",
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
resample_type="Lanczos",
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
@@ -382,32 +293,19 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
i2i_model_info = (str(get_custom_model_path())).replace(
"\\", "\n\\"
i2i_model_info = (
f"Custom Model Path: {str(get_custom_model_path())}"
)
i2i_model_info = f"Custom Model Path: {i2i_model_info}"
img2img_custom_model = gr.Dropdown(
label=f"Models",
info=i2i_model_info,
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-1-base",
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
choices=get_custom_model_files() + predefined_models,
allow_custom_value=True,
)
img2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3, "
"https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
scale=2,
)
# janky fix for overflowing text
i2i_vae_info = (str(get_custom_model_path("vae"))).replace(
@@ -423,6 +321,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -453,8 +352,13 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
elem_id="stencil_model",
label="Stencil model",
value="None",
choices=["None", "canny", "openpose", "scribble"],
allow_custom_value=True,
choices=[
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
],
)
def show_canvas(choice):
@@ -583,17 +487,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
choices=resampler_list,
label="Resample Type",
allow_custom_value=True,
)
@@ -656,16 +550,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -677,13 +561,26 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
value=f"{i2i_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
lines=2,
elem_id="std_output",
show_label=False,
)
img2img_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
img2img_sendto_outpaint = gr.Button(
@@ -709,7 +606,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
batch_size,
scheduler,
img2img_custom_model,
img2img_hf_model_id,
custom_vae,
precision,
device,

View File

@@ -4,9 +4,6 @@ import time
import sys
import gradio as gr
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -53,8 +50,7 @@ def inpaint_inf(
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
@@ -89,21 +85,17 @@ def inpaint_inf(
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
# .safetensor or .chkpt on the custom model path
if model_id in get_custom_model_files(custom_checkpoint_type="inpainting"):
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = custom_model
args.hf_model_id = model_id
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
@@ -228,86 +220,6 @@ def inpaint_inf(
return generated_imgs, text_output
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Inpaint Rest API.
def inpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
init_image = decode_base64_to_image(InputData["image"])
mask = decode_base64_to_image(InputData["mask"])
res = inpaint_inf(
InputData["prompt"],
InputData["negative_prompt"],
{"image": init_image, "mask": mask},
InputData["height"],
InputData["width"],
InputData["is_full_res"],
InputData["full_res_padding"],
InputData["steps"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-inpainting",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
@@ -327,35 +239,21 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Row():
# janky fix for overflowing text
inpaint_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
inpaint_model_info = (
f"Custom Model Path: {inpaint_model_info}"
f"Custom Model Path: {str(get_custom_model_path())}"
)
inpaint_custom_model = gr.Dropdown(
label=f"Models",
info=inpaint_model_info,
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-inpainting",
choices=["None"]
+ get_custom_model_files(
choices=get_custom_model_files(
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
allow_custom_value=True,
)
inpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
"https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
scale=2,
)
# janky fix for overflowing text
inpaint_vae_info = (
@@ -371,6 +269,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -533,16 +432,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -554,14 +443,26 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
value=f"{inpaint_model_info}\n"
"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
lines=2,
elem_id="std_output",
show_label=False,
)
inpaint_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
inpaint_sendto_outpaint = gr.Button(
@@ -588,7 +489,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
batch_size,
scheduler,
inpaint_custom_model,
inpaint_hf_model_id,
custom_vae,
precision,
device,

Binary image file added (16 KiB); contents not shown.

View File

@@ -53,8 +53,7 @@ def outpaint_inf(
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
@@ -88,21 +87,17 @@ def outpaint_inf(
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
# .safetensor or .chkpt on the custom model path
if model_id in get_custom_model_files(custom_checkpoint_type="inpainting"):
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = custom_model
args.hf_model_id = model_id
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
@@ -233,88 +228,6 @@ def outpaint_inf(
return generated_imgs, text_output, ""
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Inpaint Rest API.
def outpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = outpaint_inf(
InputData["prompt"],
InputData["negative_prompt"],
init_image,
InputData["pixels"],
InputData["mask_blur"],
InputData["directions"],
InputData["noise_q"],
InputData["color_variation"],
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-inpainting",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Convert Generator to Subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
@@ -332,37 +245,22 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
outpaint_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
outpaint_model_info = (
f"Custom Model Path: {outpaint_model_info}"
f"Custom Model Path: {str(get_custom_model_path())}"
)
outpaint_custom_model = gr.Dropdown(
label=f"Models",
info=outpaint_model_info,
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-inpainting",
choices=["None"]
+ get_custom_model_files(
choices=get_custom_model_files(
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
allow_custom_value=True,
)
outpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
"https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
scale=2,
)
# janky fix for overflowing text
outpaint_vae_info = (
@@ -378,8 +276,8 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
@@ -561,16 +459,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -582,13 +470,26 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
value=f"{outpaint_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
lines=2,
elem_id="std_output",
show_label=False,
)
outpaint_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -616,7 +517,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
batch_size,
scheduler,
outpaint_custom_model,
outpaint_hf_model_id,
custom_vae,
precision,
device,

View File

@@ -32,36 +32,39 @@ model_map = {
# NOTE: Each `model_name` should have its own start message
start_message = {
"llama2_7b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_13b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_70b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
"A chat between a curious user and an artificial intelligence "
"assistant. The assistant gives helpful, detailed, and "
"polite answers to the user's questions.\n"
),
}
@@ -77,7 +80,10 @@ def create_prompt(model_name, history, prompt_prefix):
conversation = "".join(
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
)
msg = f"{B_INST} {B_SYS} {system_message} {E_SYS} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
if prompt_prefix:
msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
else:
msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
elif model_name in ["vicuna"]:
conversation = "".join(
[
@@ -126,6 +132,27 @@ def get_default_config():
c.split_into_layers()
def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by LLM pipeline
# Multiple devices only supported for vulkan and rocm (as of now).
# default device must be selected for all others
device_id = None
device = (
raw_device
if "=>" not in raw_device
else raw_device.split("=>")[1].strip()
)
if "://" in device:
device, device_id = device.split("://")
device_id = int(device_id) # using device index in webui
if device not in ["rocm", "vulkan"]:
device_id = None
return device, device_id
model_vmfb_key = ""
@@ -145,21 +172,8 @@ def chat(
global model_vmfb_key
global vicuna_model
device_id = None
model_name, model_path = list(map(str.strip, model.split("=>")))
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "rocm" in device:
device = "rocm"
else:
print("unrecognized device")
device, device_id = clean_device_info(device)
from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
@@ -210,8 +224,15 @@ def chat(
assert (
device_id
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
print(f"Will use vulkan target triple : {vulkan_target_triple}")
print(f"Will use target triple : {vulkan_target_triple}")
elif "rocm" in device:
# add iree rocm flags
if args.iree_rocm_target_chip != "":
_extra_args.append(
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
)
print(f"extra args = {_extra_args}")
if model_name == "vicuna4":
vicuna_model = ShardedVicuna(
@@ -310,17 +331,7 @@ def llm_chat_api(InputData: dict):
device_id = None
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
else:
print("unrecognized device")
device, device_id = clean_device_info(device)
vicuna_model = UnshardedVicuna(
model_name,

View File

@@ -5,15 +5,13 @@ import sys
import gradio as gr
from PIL import Image
from math import ceil
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
scheduler_list_cpu_only,
predefined_models,
cancel_sd,
)
@@ -32,6 +30,7 @@ from apps.stable_diffusion.src import (
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
resampler_list,
)
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
@@ -52,8 +51,7 @@ def txt2img_inf(
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
@@ -91,21 +89,17 @@ def txt2img_inf(
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
# .safetensor or .chkpt on the custom model path
if model_id in get_custom_model_files():
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = custom_model
args.hf_model_id = model_id
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
@@ -145,6 +139,11 @@ def txt2img_inf(
args.max_length = max_length
args.height = height
args.width = width
args.use_hiresfix = use_hiresfix
args.hiresfix_height = hiresfix_height
args.hiresfix_width = hiresfix_width
args.hiresfix_strength = hiresfix_strength
args.resample_type = resample_type
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platform = init_iree_metal_target_platform
@@ -301,75 +300,6 @@ def txt2img_inf(
return generated_imgs, text_output, ""
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Text2Img Rest API.
def txt2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
res = txt2img_inf(
InputData["prompt"],
InputData["negative_prompt"],
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
use_hiresfix=False,
hiresfix_height=512,
hiresfix_width=512,
hiresfix_strength=0.6,
resample_type="Nearest Neighbor",
)
# Convert Generator to Subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
@@ -389,33 +319,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
# janky fix for overflowing text
t2i_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
t2i_model_info = (
f"Custom Model Path: {t2i_model_info}"
)
t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
txt2img_custom_model = gr.Dropdown(
label=f"Models",
info=t2i_model_info,
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-1-base",
choices=["None"]
+ get_custom_model_files()
choices=get_custom_model_files()
+ predefined_models,
allow_custom_value=True,
)
txt2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the dropdown "
"on the left and enter model ID here.",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL.",
lines=3,
scale=2,
)
# janky fix for overflowing text
t2i_vae_info = (
@@ -432,6 +347,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
choices=["None"]
+ get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Column(scale=1, min_width=170):
txt2img_png_info_img = gr.Image(
@@ -551,50 +467,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Low VRAM",
interactive=True,
)
with gr.Group():
with gr.Row():
use_hiresfix = gr.Checkbox(
value=args.use_hiresfix,
label="Use Hires Fix",
interactive=True,
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
label="Resample Type",
allow_custom_value=True,
)
hiresfix_height = gr.Slider(
384,
768,
value=args.hiresfix_height,
step=8,
label="Hires Fix Height",
)
hiresfix_width = gr.Slider(
384,
768,
value=args.hiresfix_width,
step=8,
label="Hires Fix Width",
)
hiresfix_strength = gr.Slider(
0,
1,
value=args.hiresfix_strength,
step=0.01,
label="Hires Fix Denoising Strength",
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
@@ -618,6 +490,41 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Accordion(label="Hires Fix Options", open=False):
with gr.Group():
with gr.Row():
use_hiresfix = gr.Checkbox(
value=args.use_hiresfix,
label="Use Hires Fix",
interactive=True,
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=resampler_list,
label="Resample Type",
allow_custom_value=False,
)
hiresfix_height = gr.Slider(
384,
768,
value=args.hiresfix_height,
step=8,
label="Hires Fix Height",
)
hiresfix_width = gr.Slider(
384,
768,
value=args.hiresfix_width,
step=8,
label="Hires Fix Width",
)
hiresfix_strength = gr.Slider(
0,
1,
value=args.hiresfix_strength,
step=0.01,
label="Hires Fix Denoising Strength",
)
with gr.Row():
seed = gr.Textbox(
value=args.seed,
@@ -649,7 +556,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
value=f"{t2i_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
@@ -692,7 +600,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
batch_size,
scheduler,
txt2img_custom_model,
txt2img_hf_model_id,
custom_vae,
precision,
device,
@@ -742,7 +649,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
width,
height,
txt2img_custom_model,
txt2img_hf_model_id,
lora_weights,
lora_hf_id,
custom_vae,
@@ -758,9 +664,28 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
width,
height,
txt2img_custom_model,
txt2img_hf_model_id,
lora_weights,
lora_hf_id,
custom_vae,
],
)
# SharkEulerDiscrete doesn't work with img2img which hires_fix uses
def set_compatible_schedulers(hires_fix_selected):
if hires_fix_selected:
return gr.Dropdown.update(
choices=scheduler_list_cpu_only,
value="DEISMultistep",
)
else:
return gr.Dropdown.update(
choices=scheduler_list,
value="SharkEulerDiscrete",
)
use_hiresfix.change(
fn=set_compatible_schedulers,
inputs=[use_hiresfix],
outputs=[scheduler],
queue=False,
)
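
As an aside, a minimal standalone Gradio sketch of the same pattern used above (swapping a dropdown's choices when a checkbox changes), assuming Gradio 3.x where `gr.Dropdown.update` is available, as used in this diff. The scheduler names are just the defaults mentioned here.

```python
import gradio as gr

def set_compatible_schedulers(hires_fix_selected: bool):
    # Restrict the scheduler choices while hires fix is enabled.
    if hires_fix_selected:
        return gr.Dropdown.update(
            choices=["DEISMultistep", "DDIM"], value="DEISMultistep"
        )
    return gr.Dropdown.update(
        choices=["SharkEulerDiscrete", "DEISMultistep", "DDIM"],
        value="SharkEulerDiscrete",
    )

with gr.Blocks() as demo:
    use_hiresfix = gr.Checkbox(label="Use Hires Fix")
    scheduler = gr.Dropdown(
        choices=["SharkEulerDiscrete", "DEISMultistep", "DDIM"],
        value="SharkEulerDiscrete",
    )
    use_hiresfix.change(
        fn=set_compatible_schedulers,
        inputs=[use_hiresfix],
        outputs=[scheduler],
        queue=False,
    )

demo.launch()
```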

View File

@@ -3,9 +3,6 @@ import torch
import time
import gradio as gr
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -46,8 +43,7 @@ def upscaler_inf(
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
@@ -85,21 +81,17 @@ def upscaler_inf(
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
# .safetensor or .chkpt on the custom model path
if model_id in get_custom_model_files(custom_checkpoint_type="upscaler"):
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = custom_model
args.hf_model_id = model_id
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
@@ -252,83 +244,6 @@ def upscaler_inf(
yield generated_imgs, text_output, ""
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Upscaler Rest API.
def upscaler_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = upscaler_inf(
InputData["prompt"],
InputData["negative_prompt"],
init_image,
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["noise_level"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Upscaler") as upscaler_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
@@ -346,37 +261,22 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
upscaler_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
upscaler_model_info = (
f"Custom Model Path: {upscaler_model_info}"
f"Custom Model Path: {str(get_custom_model_path())}"
)
upscaler_custom_model = gr.Dropdown(
label=f"Models",
info=upscaler_model_info,
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-x4-upscaler",
choices=["None"]
+ get_custom_model_files(
choices=get_custom_model_files(
custom_checkpoint_type="upscaler"
)
+ predefined_upscaler_models,
allow_custom_value=True,
)
upscaler_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3, "
"https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
scale=2,
)
# janky fix for overflowing text
upscaler_vae_info = (
@@ -392,6 +292,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -553,16 +454,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -574,14 +465,26 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
value=f"{upscaler_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
lines=2,
elem_id="std_output",
show_label=False,
)
upscaler_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -605,7 +508,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
batch_size,
scheduler,
upscaler_custom_model,
upscaler_hf_model_id,
custom_vae,
precision,
device,

View File

@@ -170,4 +170,5 @@ def cancel_sd():
nodlogo_loc = resource_path("logos/nod-logo.png")
nodicon_loc = resource_path("logos/nod-icon.png")
available_devices = get_available_devices()

View File

@@ -0,0 +1,105 @@
import os
import sys
import webview
import webview.util
import socket
from contextlib import closing
from multiprocessing import Process
from apps.stable_diffusion.src import args
def webview2_installed():
if sys.platform != "win32":
return False
# On windows we want to ensure we have MS webview2 available so we don't fall back
# to MSHTML (aka ye olde Internet Explorer) which is deprecated by pywebview, and
# apparently causes SHARK not to load in properly.
# Checking these registry entries is how Microsoft says to detect a webview2 installation:
# https://learn.microsoft.com/en-us/microsoft-edge/webview2/concepts/distribution
import winreg
path = r"SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}"
# only way can find if a registry entry even exists is to try and open it
try:
# check for an all user install
with winreg.OpenKey(
winreg.HKEY_LOCAL_MACHINE,
path,
0,
winreg.KEY_QUERY_VALUE | winreg.KEY_WOW64_64KEY,
) as registry_key:
value, type = winreg.QueryValueEx(registry_key, "pv")
# if it didn't exist, we want to continue on...
except WindowsError:
try:
# ...to check for a current user install
with winreg.OpenKey(
winreg.HKEY_CURRENT_USER,
path,
0,
winreg.KEY_QUERY_VALUE | winreg.KEY_WOW64_64KEY,
) as registry_key:
value, type = winreg.QueryValueEx(registry_key, "pv")
except WindowsError:
value = None
finally:
return (value is not None) and value != "" and value != "0.0.0.0"
def window(address):
from tkinter import Tk
window = Tk()
# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False, storage_path=os.getcwd())
def usable_port():
# Make sure we can actually use the port given in args.server_port. If
# not ask the OS for a port and return that as our port to use.
port = args.server_port
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
try:
sock.bind(("0.0.0.0", port))
except OSError:
with closing(
socket.socket(socket.AF_INET, socket.SOCK_STREAM)
) as sock:
sock.bind(("0.0.0.0", 0))
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return sock.getsockname()[1]
return port
def launch(port):
# setup to launch as an app if app mode has been requested and we're able
# to do it, answering whether we succeeded.
if args.ui == "app" and (sys.platform != "win32" or webview2_installed()):
try:
t = Process(target=window, args=[f"http://localhost:{port}"])
t.start()
return True
except webview.util.WebViewException:
return False
else:
return False
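
For what it's worth, a standalone sketch of the port-probing idea in `usable_port` above: try to bind the preferred port, and if that fails let the OS hand back a free ephemeral port. The preferred port number here is just an example.

```python
import socket
from contextlib import closing

def pick_port(preferred: int) -> int:
    # Bind the preferred port to prove it is free; otherwise ask the OS for one.
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        try:
            sock.bind(("0.0.0.0", preferred))
            return preferred
        except OSError:
            with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as fb:
                fb.bind(("0.0.0.0", 0))
                return fb.getsockname()[1]

print(pick_port(8080))
```

Note that, as with the original, the probed socket is closed before the server re-binds the port, so in principle another process could grab it in between.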

View File

@@ -149,7 +149,6 @@ def import_png_metadata(
width,
height,
custom_model,
hf_model_id,
custom_lora,
hf_lora_id,
custom_vae,
@@ -175,10 +174,8 @@ def import_png_metadata(
if "Model" in metadata and png_custom_model:
custom_model = png_custom_model
hf_model_id = ""
if "Model" in metadata and png_hf_model_id:
custom_model = "None"
hf_model_id = png_hf_model_id
elif "Model" in metadata and png_hf_model_id:
custom_model = png_hf_model_id
if "LoRA" in metadata and lora_custom_model:
custom_lora = lora_custom_model
@@ -217,7 +214,6 @@ def import_png_metadata(
width,
height,
custom_model,
hf_model_id,
custom_lora,
hf_lora_id,
custom_vae,

View File

@@ -129,12 +129,12 @@ pytest_benchmark_param = pytest.mark.parametrize(
pytest.param(True, "cpu", marks=pytest.mark.skip),
pytest.param(
False,
"gpu",
"cuda",
marks=pytest.mark.skipif(
check_device_drivers("gpu"), reason="nvidia-smi not found"
check_device_drivers("cuda"), reason="nvidia-smi not found"
),
),
pytest.param(True, "gpu", marks=pytest.mark.skip),
pytest.param(True, "cuda", marks=pytest.mark.skip),
pytest.param(
False,
"vulkan",

docs/shark_sd_koboldcpp.md (new file, 140 lines)
View File

@@ -0,0 +1,140 @@
# Overview
In [1.47.2](https://github.com/LostRuins/koboldcpp/releases/tag/v1.47.2), [Koboldcpp](https://github.com/LostRuins/koboldcpp) added AUTOMATIC1111 integration for image generation. Since SHARK implements a small subset of the A1111 REST API, you can also use SHARK for this. This document gives a starting point for getting this working.
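For example, once SHARK is running with `--api` (see the setup section below), a txt2img request can be made directly against the A1111-style endpoint. This is a minimal sketch, not an official client: the host and port assume the Koboldcpp-friendly `--server_port=7860` setup described later, the payload fields follow the `GenerationInputData` model elsewhere in this changeset, and a `seed` of `-1` is assumed to mean a random seed.
```python
import json
import urllib.request

payload = {
    "prompt": "a lighthouse on a cliff at sunset, highly detailed",
    "negative_prompt": "blurry, low quality",
    "height": 512,
    "width": 512,
    "steps": 20,
    "cfg_scale": 7.5,
    "seed": -1,
}
req = urllib.request.Request(
    "http://127.0.0.1:7860/sdapi/v1/txt2img",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    result = json.load(resp)

# "images" is a list of Base64-encoded images; "info" carries generation metadata.
print(len(result["images"]), result["info"][:80])
```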
## In Action
![preview](https://user-images.githubusercontent.com/121311569/280557602-bb97bad0-fdf5-4922-a2cc-4f327f2760db.jpg)
## Memory considerations
Since both Koboldcpp and SHARK will use VRAM on your graphics card(s), running both at the same time on the same card will impose extra limitations on the model size you can fully offload to the video card in Koboldcpp. For me, on an RX 7900 XTX on Windows with 24 GiB of VRAM, the limit was about a 13 billion parameter model with Q5_K_M quantisation.
## Performance Considerations
When using SHARK for image generation, especially with Koboldcpp, be aware that it is currently designed to pay a large upfront cost in time compiling and tuning the model you select, in exchange for optimal per-image generation times. You will need to judge whether this trade-off is worth it for your OS and hardware combination.
It means that the first time you run a particular Stable Diffusion model with a particular combination of image size, LoRA, and VAE, SHARK will spend *many minutes* - even on a beefy machine with a very fast graphics card and lots of memory - building that model combination just so it can save it to disk. It may even have to download the model first if it doesn't already have it locally. Once it has built a model combination for your hardware, it shouldn't need to do so again until you upgrade to a newer SHARK version, install different drivers, or change your graphics hardware. It will simply upload the files it generated the first time to your graphics card and proceed from there.
This does mean, however, that on a brand new install of SHARK, or with a model you haven't selected before, the first image Koboldcpp requests may look like it is *never* going to finish and that the whole process has broken. Be forewarned, make yourself a cup of coffee, and expect a lot of messages about compilation and tuning from SHARK in the terminal you ran it from.
## Setup SHARK and prerequisites:
* Make sure you have suitable drivers for your graphics card installed. See the prerequisites section of the [README](https://github.com/nod-ai/SHARK#readme).
* Download the latest SHARK Studio .exe from [here](https://github.com/nod-ai/SHARK/releases) or follow the instructions in the [README](https://github.com/nod-ai/SHARK#readme) for an advanced, Linux, or Mac install.
* Run SHARK from a terminal/PowerShell with the `--api` flag. Since Koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK's default of `8080`, also include the `--api_cors_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See [below](#connecting-to-shark-on-a-different-address-or-port) if you want to run SHARK on a different port.)
```powershell
## Run the .exe in API mode, with CORS support, on the A1111 endpoint port:
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_cors_origin="*" --server_port=7860
## Run from the base directory of a source clone of SHARK on Windows:
.\setup_venv.ps1
python .\apps\stable_diffusion\web\index.py --api --api_cors_origin="*" --server_port=7860
## Run from the base directory of a source clone of SHARK on Linux:
./setup_venv.sh
source shark.venv/bin/activate
python ./apps/stable_diffusion/web/index.py --api --api_cors_origin="*" --server_port=7860
## An example giving improved performance on AMD cards using vulkan, that runs on the same port as A1111
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860
## Since the api respects most applicable SHARK command line arguments for options not specified,
## or currently unimplemented by API, there might be some you want to set, as listed in `--help`
.\node_ai_shark_studio_20320901_2525.exe --help
## For instance, the example above, but with a custom VAE specified
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
## An example with multiple specific CORS origins
python apps/stable_diffusion/web/index.py --api --api_cors_origin="koboldcpp.example.com:7001" --api_cors_origin="koboldcpp.example.com:7002" --server_port=7860
```
SHARK should start in server mode, and you should see something like this:
![SHARK API startup](https://user-images.githubusercontent.com/121311569/280556294-c3f7fc1a-c8e2-467d-afe6-365638d6823a.png)
* Note: When running in API mode with `--api`, the .exe will not function as a web UI. Thus, the address or port shown in the terminal output will only be useful for API requests.
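Before pointing Koboldcpp at SHARK, you can sanity-check that the API is actually listening by hitting one of the A1111-style endpoints it exposes, such as `/sdapi/v1/sd-models`. A minimal sketch, assuming SHARK was started on this machine with `--server_port=7860` (adjust the host and port if you launched it differently):
```python
# Minimal check that the SHARK A1111-style API is up.
# Assumes SHARK was started with --api --server_port=7860 on this machine.
import requests

base_url = "http://127.0.0.1:7860"
resp = requests.get(f"{base_url}/sdapi/v1/sd-models", timeout=60)
print(f"status: {resp.status_code} {resp.reason}")
if resp.status_code == 200:
    # Each entry should describe a Stable Diffusion model SHARK can serve.
    for model in resp.json():
        print(model)
```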
## Configure Koboldcpp for local image generation:
* Get the latest [Koboldcpp](https://github.com/LostRuins/koboldcpp/releases) if you don't already have it. If you have a recent AMD card that has ROCm HIP [support for Windows](https://rocmdocs.amd.com/en/latest/release/windows_support.html#windows-supported-gpus) or [support for Linux](https://rocmdocs.amd.com/en/latest/release/gpu_os_support.html#linux-supported-gpus), you'll likely prefer [YellowRosecx's ROCm fork](https://github.com/YellowRoseCx/koboldcpp-rocm).
* Start Koboldcpp in another terminal/PowerShell and set up your model configuration. Refer to the [Koboldcpp README](https://github.com/YellowRoseCx/koboldcpp-rocm) for more details on how to do this if this is your first time using Koboldcpp.
* Once the main UI has loaded in your browser, click the settings button, go to the Advanced tab, and then choose *Local A1111* from the generate images dropdown:
![Settings button location](https://user-images.githubusercontent.com/121311569/280556246-10692d79-e89f-4fdf-87ba-82f3d78ed49d.png)
![Advanced Settings with 'Local A1111' location](https://user-images.githubusercontent.com/121311569/280556234-6ebc8ba7-1469-442a-93a7-5626a094ddf1.png)
*If you get an error here, see the section [below](#connecting-to-shark-on-a-different-address-or-port).*
* A list of Stable Diffusion models available to your SHARK instance should now be listed in the box below *generate images*. The default value will usually be set to `stabilityai/stable-diffusion-2-1-base`. Choose the model you want to use for image generation from the list (but see [performance considerations](#performance-considerations)).
* You should now be ready to generate images, either by clicking the 'Add Img' button above the text entry box:
![Add Image Button](https://user-images.githubusercontent.com/121311569/280556161-846c7883-4a83-4458-a56a-bd9f93ca354c.png)
...or by selecting the 'Autogenerate' option in the settings:
![Setting the autogenerate images option](https://user-images.githubusercontent.com/121311569/280556230-ae221a46-ba68-499b-a519-c8f290bbbeae.png)
*I often find that even if I have selected autogenerate, I still have to click 'Add Img' once to get things started.*
* There is one final piece of image generation configuration within Koboldcpp you might want to do. This is also in the generate images section of advanced settings, where there is a somewhat unobtrusive 'style' button:
![Selecting the 'styles' button](https://user-images.githubusercontent.com/121311569/280556694-55cd1c55-a059-4b54-9293-63d66a32368e.png)
This will bring up a dialog box where you can enter a short text that will be sent as a prefix to the prompt SHARK receives:
![Entering extra image styles](https://user-images.githubusercontent.com/121311569/280556172-4aab9794-7a77-46d7-bdda-43df570ad19a.png)
## Connecting to SHARK on a different address or port
If you didn't set the port with `--server_port=7860` when starting SHARK, or you are running it on a different machine on your network from the one running Koboldcpp (or the one running Koboldcpp's kdlite client frontend), then you very likely got the following error:
![Can't find the A1111 endpoint error](https://user-images.githubusercontent.com/121311569/280555857-601f53dc-35e9-4027-9180-baa61d2393ba.png)
As long as SHARK is running correctly, this means you need to set the URL and port to the correct values in Koboldcpp. For instance, to set the port on which Koboldcpp looks for an image generator to SHARK's default port of 8080:
* Select the cog icon in the Generate Images section of Advanced settings:
![Selecting the endpoint cog](https://user-images.githubusercontent.com/121311569/280555866-4287ecc5-f29f-4c03-8f5a-abeaf31b0442.png)
* Then edit the port number at the end of the URL in the 'A1111 Endpoint Selection' dialog box to read 8080:
![Changing the endpoint port](https://user-images.githubusercontent.com/121311569/280556170-f8848b7b-6fc9-4cf7-80eb-5c312f332fd9.png)
* When running SHARK on a different machine, you will similarly need to change the host part of the endpoint URL to the hostname or IP address where SHARK is running:
![Changing the endpoint hostname](https://user-images.githubusercontent.com/121311569/280556167-c6541dea-0f85-417a-b661-fdf4dc40d05f.png)
## Examples
Here's how Koboldcpp shows an image being requested:
![An image being generated](https://user-images.githubusercontent.com/121311569/280556210-bb1c9efd-79ac-478e-b726-b25b82ef2186.png)
The generated image in context in story mode:
![A generated image](https://user-images.githubusercontent.com/121311569/280556179-4e9f3752-f349-4cba-bc6a-f85f8dc79b10.jpg)
And the same image when clicked on:
![A selected image](https://user-images.githubusercontent.com/121311569/280556216-2ca4c0a4-3889-4ef5-8a09-30084fb34081.jpg)
## Where to find the images in SHARK
Even though Koboldcpp requests images at a size of 512x512, it resizes them to 256x256, converts them to `.jpeg`, and only shows them at 200x200 in the main text window. It does this so it can save them compactly embedded in your story as a `data:` URI.
However, the images at their original size are saved by SHARK in its `output_dir`, which is usually a folder named for the current date inside the `generated_imgs` folder in the SHARK installation directory.
You can browse these either using the Output Gallery tab within the SHARK web UI:
![SHARK web ui output gallery tab](https://user-images.githubusercontent.com/121311569/280556582-9303ca85-2594-4a8c-97a2-fbd72337980b.jpg)
...or by browsing to the `output_dir` in your operating system's file manager:
![SHARK output directory subfolder in Windows File Explorer](https://user-images.githubusercontent.com/121311569/280556297-66173030-2324-415c-a236-ef3fcd73e6ed.jpg)

View File

@@ -17,6 +17,7 @@ pytest-forked
Pillow
parameterized
#shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main
# Add transformers, diffusers and scipy since it most commonly used
tokenizers==0.13.3
transformers
@@ -41,10 +42,12 @@ tiktoken # for codegen
joblib # for langchain
timm # for MiniGPT4
langchain
einops # for zoedepth
pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
pyinstaller
# vicuna quantization
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea

View File

@@ -4,7 +4,7 @@ import base64
from io import BytesIO
def upscaler_test():
def upscaler_test(verbose=False):
# Define values here
prompt = ""
negative_prompt = ""
@@ -44,10 +44,17 @@ def upscaler_test():
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"response from server was : {res.status_code}")
print(
f"[upscaler] response from server was : {res.status_code} {res.reason}"
)
if verbose or res.status_code != 200:
print(
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
)
def img2img_test():
def img2img_test(verbose=False):
# Define values here
prompt = "Paint a rabbit riding on the dog"
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
@@ -87,7 +94,16 @@ def img2img_test():
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"response from server was : {res.status_code}")
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(
f"[img2img] response from server was : {res.status_code} {res.reason}"
)
if verbose or res.status_code != 200:
print(
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
)
# NOTE Uncomment below to save the picture
@@ -103,7 +119,7 @@ def img2img_test():
# response_img.save(r"rest_api_tests/response_img.png")
def inpainting_test():
def inpainting_test(verbose=False):
prompt = "Paint a rabbit riding on the dog"
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
seed = 2121991605
@@ -150,10 +166,17 @@ def inpainting_test():
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"[Inpainting] response from server was : {res.status_code}")
print(
f"[inpaint] response from server was : {res.status_code} {res.reason}"
)
if verbose or res.status_code != 200:
print(
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
)
def outpainting_test():
def outpainting_test(verbose=False):
prompt = "Paint a rabbit riding on the dog"
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
seed = 2121991605
@@ -200,10 +223,17 @@ def outpainting_test():
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"[Outpaint] response from server was : {res.status_code}")
print(
f"[outpaint] response from server was : {res.status_code} {res.reason}"
)
if verbose or res.status_code != 200:
print(
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
)
def txt2img_test():
def txt2img_test(verbose=False):
prompt = "Paint a rabbit in a top hate"
negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft"
seed = 2121991605
@@ -232,12 +262,119 @@ def txt2img_test():
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
print(f"[txt2img] response from server was : {res.status_code}")
print(
f"[txt2img] response from server was : {res.status_code} {res.reason}"
)
if verbose or res.status_code != 200:
print(
f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n"
)
def sd_models_test(verbose=False):
url = "http://127.0.0.1:8080/sdapi/v1/sd-models"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
res = requests.get(url=url, headers=headers, timeout=1000)
print(
f"[sd_models] response from server was : {res.status_code} {res.reason}"
)
if verbose or res.status_code != 200:
print(f"\n{res.json() if res.status_code == 200 else res.content}\n")
def sd_samplers_test(verbose=False):
url = "http://127.0.0.1:8080/sdapi/v1/samplers"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
res = requests.get(url=url, headers=headers, timeout=1000)
print(
f"[sd_samplers] response from server was : {res.status_code} {res.reason}"
)
if verbose or res.status_code != 200:
print(f"\n{res.json() if res.status_code == 200 else res.content}\n")
def options_test(verbose=False):
url = "http://127.0.0.1:8080/sdapi/v1/options"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
res = requests.get(url=url, headers=headers, timeout=1000)
print(
f"[options] response from server was : {res.status_code} {res.reason}"
)
if verbose or res.status_code != 200:
print(f"\n{res.json() if res.status_code == 200 else res.content}\n")
def cmd_flags_test(verbose=False):
url = "http://127.0.0.1:8080/sdapi/v1/cmd-flags"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
res = requests.get(url=url, headers=headers, timeout=1000)
print(
f"[cmd-flags] response from server was : {res.status_code} {res.reason}"
)
if verbose or res.status_code != 200:
print(f"\n{res.json() if res.status_code == 200 else res.content}\n")
if __name__ == "__main__":
txt2img_test()
img2img_test()
upscaler_test()
inpainting_test()
outpainting_test()
import argparse
parser = argparse.ArgumentParser(
description=(
"Exercises the Stable Diffusion REST API of Shark. Make sure "
"Shark is running in API mode on 127.0.0.1:8080 before running"
"this script."
),
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help=(
"also display selected info from the JSON response for "
"successful requests"
),
)
args = parser.parse_args()
sd_models_test(args.verbose)
sd_samplers_test(args.verbose)
options_test(args.verbose)
cmd_flags_test(args.verbose)
txt2img_test(args.verbose)
img2img_test(args.verbose)
upscaler_test(args.verbose)
inpainting_test(args.verbose)
outpainting_test(args.verbose)

View File

@@ -177,7 +177,7 @@ def compile_through_fx(model, inputs, mlir_loc=None):
mlir_model = str(module)
func_name = "forward"
shark_module = SharkInference(
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
mlir_model, device=args.device, mlir_dialect="linalg"
)
shark_module.compile()

View File

@@ -54,7 +54,7 @@ if __name__ == "__main__":
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=False, tracing_required=False
)
shark_module = SharkInference(minilm_mlir, func_name, mlir_dialect="mhlo")
shark_module = SharkInference(minilm_mlir, mlir_dialect="mhlo")
shark_module.compile()
output_idx = 0
data_idx = 1

View File

@@ -6,7 +6,7 @@ mlir_model, func_name, inputs, golden_out = download_model(
)
shark_module = SharkInference(
mlir_model, func_name, device="cpu", mlir_dialect="tm_tensor"
mlir_model, device="cpu", mlir_dialect="tm_tensor"
)
shark_module.compile()
result = shark_module.forward(inputs)

View File

@@ -13,9 +13,7 @@ arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)
print("Running shark on cpu backend")
shark_module = SharkInference(
mhlo_ir, function_name="forward", device="cpu", mlir_dialect="mhlo"
)
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
# Generate the random inputs and feed into the graph.
x = shark_module.generate_random_inputs()
@@ -23,15 +21,11 @@ shark_module.compile()
print(shark_module.forward(x))
print("Running shark on cuda backend")
shark_module = SharkInference(
mhlo_ir, function_name="forward", device="cuda", mlir_dialect="mhlo"
)
shark_module = SharkInference(mhlo_ir, device="cuda", mlir_dialect="mhlo")
shark_module.compile()
print(shark_module.forward(x))
print("Running shark on vulkan backend")
shark_module = SharkInference(
mhlo_ir, function_name="forward", device="vulkan", mlir_dialect="mhlo"
)
shark_module = SharkInference(mhlo_ir, device="vulkan", mlir_dialect="mhlo")
shark_module.compile()
print(shark_module.forward(x))

View File

@@ -8,9 +8,7 @@ mlir_model, func_name, inputs, golden_out = download_model(
)
shark_module = SharkInference(
mlir_model, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module = SharkInference(mlir_model, device="cpu", mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward(inputs)
print("The obtained result via shark is: ", result)

View File

@@ -33,7 +33,7 @@ mlir_importer = SharkImporter(
print(golden_out)
shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
shark_module = SharkInference(vision_mlir, mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward((input,))
print("Obtained result", result)

View File

@@ -49,9 +49,7 @@ module = torch_mlir.compile(
mlir_model = module
func_name = "forward"
shark_module = SharkInference(
mlir_model, func_name, device="cuda", mlir_dialect="linalg"
)
shark_module = SharkInference(mlir_model, device="cuda", mlir_dialect="linalg")
shark_module.compile()

View File

@@ -360,7 +360,7 @@ mlir_importer = SharkImporter(
)
shark_module = SharkInference(
dlrm_mlir, func_name, device="vulkan", mlir_dialect="linalg"
dlrm_mlir, device="vulkan", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward(input_dlrm)

View File

@@ -294,7 +294,7 @@ def test_dlrm() -> None:
)
shark_module = SharkInference(
dlrm_mlir, func_name, device="cpu", mlir_dialect="linalg"
dlrm_mlir, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward(inputs)

View File

@@ -33,7 +33,7 @@ mlir_importer = SharkImporter(
tracing_required=False
)
shark_module = SharkInference(vision_mlir, func_name, mlir_dialect="linalg")
shark_module = SharkInference(vision_mlir, mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward((input,))
np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)

View File

@@ -7,7 +7,7 @@ mlir_model, func_name, inputs, golden_out = download_model(
)
shark_module = SharkInference(
mlir_model, func_name, device="vulkan", mlir_dialect="linalg"
mlir_model, device="vulkan", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward(inputs)

View File

@@ -19,9 +19,12 @@ import sys
import subprocess
def run_cmd(cmd, debug=False):
def run_cmd(cmd, debug=False, raise_err=False):
"""
Inputs: cli command string.
Inputs:
cmd : cli command string.
debug : if True, prints debug info
raise_err : if True, raise exception to caller
"""
if debug:
print("IREE run command: \n\n")
@@ -39,8 +42,11 @@ def run_cmd(cmd, debug=False):
stderr = result.stderr.decode()
return stdout, stderr
except subprocess.CalledProcessError as e:
print(e.output)
sys.exit(f"Exiting program due to error running {cmd}")
if raise_err:
raise Exception from e
else:
print(e.output)
sys.exit(f"Exiting program due to error running {cmd}")
def iree_device_map(device):
@@ -95,38 +101,31 @@ _IREE_TARGET_MAP = {
# Finds whether the required drivers are installed for the given device.
@functools.cache
def check_device_drivers(device):
"""Checks necessary drivers present for gpu and vulkan devices"""
"""
Checks necessary drivers present for gpu and vulkan devices
False => drivers present!
"""
if "://" in device:
device = device.split("://")[0]
if device == "cuda":
try:
subprocess.check_output("nvidia-smi")
except Exception:
return True
elif device in ["vulkan"]:
try:
subprocess.check_output("vulkaninfo")
except Exception:
return True
elif device == "metal":
return False
elif device in ["intel-gpu"]:
try:
subprocess.check_output(["dpkg", "-L", "intel-level-zero-gpu"])
return False
except Exception:
return True
elif device == "cpu":
return False
elif device == "rocm":
try:
if sys.platform == "win32":
subprocess.check_output("hipinfo")
else:
subprocess.check_output("rocminfo")
except Exception:
return True
from iree.runtime import get_driver
device_mapped = iree_device_map(device)
try:
_ = get_driver(device_mapped)
except ValueError as ve:
print(
f"[ERR] device `{device}` not registered with IREE. "
"Ensure IREE is configured for use with this device.\n"
f"Full Error: \n {repr(ve)}"
)
return True
except RuntimeError as re:
print(
f"[ERR] Failed to get driver for {device} with error:\n{repr(re)}"
)
return True
# Unknown device. We assume drivers are installed.
return False
@@ -134,11 +133,32 @@ def check_device_drivers(device):
# Installation info for the missing device drivers.
def device_driver_info(device):
if device == "cuda":
return "nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in"
elif device in ["metal", "vulkan"]:
return "vulkaninfo not found, Install from https://vulkan.lunarg.com/sdk/home or your distribution"
elif device == "rocm":
return "rocm info not found. Please install rocm"
device_driver_err_map = {
"cuda": {
"debug": "Try `nvidia-smi` on system to check.",
"solution": " from https://www.nvidia.in/Download/index.aspx?lang=en-in for your system.",
},
"vulkan": {
"debug": "Try `vulkaninfo` on system to check.",
"solution": " from https://vulkan.lunarg.com/sdk/home for your distribution.",
},
"metal": {
"debug": "Check if Bare metal is supported and enabled on your system.",
"solution": ".",
},
"rocm": {
"debug": f"Try `{'hip' if sys.platform == 'win32' else 'rocm'}info` on system to check.",
"solution": " from https://rocm.docs.amd.com/en/latest/rocm.html for your system.",
},
}
if device in device_driver_err_map:
err_msg = (
f"Required drivers for {device} not found. {device_driver_err_map[device]['debug']} "
f"Please install the required drivers{device_driver_err_map[device]['solution']} "
f"For further assistance please reach out to the community on discord [https://discord.com/invite/RUqY2h2s9u]"
f" and/or file a bug at https://github.com/nod-ai/SHARK/issues"
)
return err_msg
else:
return f"{device} is not supported."

View File

@@ -16,7 +16,6 @@ import numpy as np
import os
import re
import tempfile
import time
from pathlib import Path
import iree.runtime as ireert
@@ -34,16 +33,24 @@ def get_iree_device_args(device, extra_args=[]):
print("Configuring for device:" + device)
device_uri = device.split("://")
if len(device_uri) > 1:
if device_uri[0] not in ["vulkan"]:
if device_uri[0] not in ["vulkan", "rocm"]:
print(
f"Specific device selection only supported for vulkan now."
f"Specific device selection only supported for vulkan and rocm."
f"Proceeding with {device} as device."
)
device_num = device_uri[1]
# device_uri can be device_num or device_path.
# assuming number of devices for a single driver will be not be >99
if len(device_uri[1]) <= 2:
# expected to be device index in range 0 - 99
device_num = int(device_uri[1])
else:
# expected to be device path
device_num = device_uri[1]
else:
device_num = 0
if device_uri[0] == "cpu":
if "cpu" in device:
from shark.iree_utils.cpu_utils import get_iree_cpu_args
data_tiling_flag = ["--iree-opt-data-tiling"]
@@ -55,6 +62,7 @@ def get_iree_device_args(device, extra_args=[]):
+ data_tiling_flag
+ u_kernel_flag
+ stack_size_flag
+ ["--iree-global-opt-enable-quantized-matmul-reassociation"]
)
if device_uri[0] == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
@@ -73,7 +81,7 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args()
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
return []
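With the change above, a device string can carry an explicit selector, e.g. `rocm://1` for the second ROCm device or `vulkan://<device-path>` for a device path; a selector of one or two characters is treated as a numeric index, anything longer as a path. A small illustrative re-statement of that convention (not the actual SHARK code path):
```python
# Illustrative re-statement of the device URI convention used above.
def split_device_uri(device: str):
    parts = device.split("://")
    if len(parts) == 1:
        return parts[0], 0           # no selector -> device index 0
    driver, selector = parts[0], parts[1]
    # one or two characters -> assumed to be a device index in range 0-99
    device_num = int(selector) if len(selector) <= 2 else selector
    return driver, device_num

print(split_device_uri("rocm"))        # ('rocm', 0)
print(split_device_uri("rocm://1"))    # ('rocm', 1)
print(split_device_uri("vulkan://a-very-long-device-path"))  # path kept as a string
```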
@@ -311,6 +319,8 @@ def compile_module_to_flatbuffer(
input_type = "tosa"
elif frontend in ["tm_tensor"]:
input_type = ireec.InputType.TM_TENSOR
elif frontend in ["torch", "pytorch"]:
input_type = "torch"
if compile_str:
flatbuffer_blob = ireec.compile_str(
@@ -322,7 +332,7 @@ def compile_module_to_flatbuffer(
else:
assert os.path.isfile(module)
flatbuffer_blob = ireec.compile_file(
module,
str(module),
input_type=input_type,
target_backends=[iree_target_map(device)],
extra_args=args,
@@ -331,8 +341,12 @@ def compile_module_to_flatbuffer(
return flatbuffer_blob
def get_iree_module(flatbuffer_blob, device, device_idx=None):
def get_iree_module(
flatbuffer_blob, device, device_idx=None, rt_flags: list = []
):
# Returns the compiled module and the configs.
for flag in rt_flags:
ireert.flags.parse_flag(flag)
if device_idx is not None:
device = iree_device_map(device)
print("registering device id: ", device_idx)
@@ -354,9 +368,22 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
def load_vmfb_using_mmap(
flatbuffer_blob_or_path, device: str, device_idx: int = None
flatbuffer_blob_or_path,
device: str,
device_idx: int = None,
rt_flags: list = [],
):
print(f"Loading module {flatbuffer_blob_or_path}...")
if "task" in device:
print(
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
)
for flag in get_iree_cpu_rt_args():
rt_flags.append(flag)
for flag in rt_flags:
print(flag)
ireert.flags.parse_flags(flag)
if "rocm" in device:
device = "rocm"
with DetailLogger(timeout=2.5) as dl:
@@ -373,6 +400,9 @@ def load_vmfb_using_mmap(
)
dl.log(f"ireert.create_device()")
config = ireert.Config(device=haldevice)
config.id = haldriver.query_available_devices()[device_idx][
"device_id"
]
dl.log(f"ireert.Config()")
else:
config = get_iree_runtime_config(device)
@@ -383,6 +413,7 @@ def load_vmfb_using_mmap(
)
for flag in get_iree_cpu_rt_args():
ireert.flags.parse_flags(flag)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
@@ -402,6 +433,8 @@ def load_vmfb_using_mmap(
)
dl.log(f"mmap {flatbuffer_blob_or_path}")
ctx = ireert.SystemContext(config=config)
for flag in shark_args.additional_runtime_args:
ireert.flags.parse_flags(flag)
dl.log(f"ireert.SystemContext created")
if "vulkan" in device:
# Vulkan pipeline creation consumes significant amount of time.
@@ -428,6 +461,7 @@ def get_iree_compiled_module(
frontend: str = "torch",
model_config_path: str = None,
extra_args: list = [],
rt_flags: list = [],
device_idx: int = None,
mmap: bool = False,
debug: bool = False,
@@ -435,13 +469,13 @@ def get_iree_compiled_module(
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
module,
device,
frontend,
model_config_path,
extra_args,
debug,
compile_str,
module=module,
device=device,
frontend=frontend,
model_config_path=model_config_path,
extra_args=extra_args,
debug=debug,
compile_str=compile_str,
)
temp_file_to_unlink = None
# TODO: Currently mmap=True control flow path has been switched off for mmap.
@@ -450,11 +484,14 @@ def get_iree_compiled_module(
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
if mmap:
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_blob, device, device_idx
flatbuffer_blob, device, device_idx, rt_flags
)
else:
vmfb, config = get_iree_module(
flatbuffer_blob, device, device_idx=device_idx
flatbuffer_blob,
device,
device_idx=device_idx,
rt_flags=rt_flags,
)
ret_params = {
"vmfb": vmfb,
@@ -469,17 +506,21 @@ def load_flatbuffer(
device: str,
device_idx: int = None,
mmap: bool = False,
rt_flags: list = [],
):
temp_file_to_unlink = None
if mmap:
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_path, device, device_idx
flatbuffer_path, device, device_idx, rt_flags
)
else:
with open(os.path.join(flatbuffer_path), "rb") as f:
flatbuffer_blob = f.read()
vmfb, config = get_iree_module(
flatbuffer_blob, device, device_idx=device_idx
flatbuffer_blob,
device,
device_idx=device_idx,
rt_flags=rt_flags,
)
ret_params = {
"vmfb": vmfb,
@@ -502,13 +543,13 @@ def export_iree_module_to_vmfb(
):
# Compiles the module given specs and saves it as .vmfb file.
flatbuffer_blob = compile_module_to_flatbuffer(
module,
device,
mlir_dialect,
model_config_path,
extra_args,
debug,
compile_str,
module=module,
device=device,
frontend=mlir_dialect,
model_config_path=model_config_path,
extra_args=extra_args,
debug=debug,
compile_str=compile_str,
)
if module_name is None:
device_name = (
@@ -544,10 +585,17 @@ def get_results(
frontend="torch",
send_to_host=True,
debug_timeout: float = 5.0,
device: str = None,
):
"""Runs a .vmfb file given inputs and config and returns output."""
with DetailLogger(debug_timeout) as dl:
device_inputs = []
if device == "rocm" and hasattr(config, "id"):
haldriver = ireert.get_driver("rocm")
haldevice = haldriver.create_device(
config.id,
allocators=shark_args.device_allocator,
)
for input_array in input:
dl.log(f"Load to device: {input_array.shape}")
device_inputs.append(
@@ -584,7 +632,7 @@ def get_results(
def get_iree_runtime_config(device):
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
if device == "metal" and shark_args.device_allocator == "caching":
if "metal" in device and shark_args.device_allocator == "caching":
print(
"[WARNING] metal devices can not have a `caching` allocator."
"\nUsing default allocator `None`"
@@ -592,7 +640,9 @@ def get_iree_runtime_config(device):
haldevice = haldriver.create_device_by_uri(
device,
# metal devices have a failure with caching allocators atm. blcking this util it gets fixed upstream.
allocators=shark_args.device_allocator if device != "metal" else None,
allocators=shark_args.device_allocator
if "metal" not in device
else None,
)
config = ireert.Config(device=haldevice)
return config

View File

@@ -18,7 +18,11 @@ import functools
import iree.runtime as ireert
import ctypes
import sys
from subprocess import CalledProcessError
from shark.parser import shark_args
from shark.iree_utils._common import run_cmd
# TODO: refactor to rocm and cuda utils
# Get the default gpu args given the architecture.
@@ -39,56 +43,85 @@ def get_iree_gpu_args():
return []
# Get the default gpu args given the architecture.
@functools.cache
def get_iree_rocm_args():
ireert.flags.FUNCTION_INPUT_VALIDATION = False
# get arch from hipinfo.
import os
import re
import subprocess
def check_rocm_device_arch_in_args(extra_args):
# Check if the target arch flag for rocm device present in extra_args
for flag in extra_args:
if "iree-rocm-target-chip" in flag:
flag_arch = flag.split("=")[1]
return flag_arch
return None
if sys.platform == "win32":
if "HIP_PATH" in os.environ:
rocm_path = os.environ["HIP_PATH"]
print(f"Found a ROCm installation at {rocm_path}.")
else:
print("Failed to find ROCM_PATH. Defaulting to C:\\AMD\\ROCM\\5.5")
rocm_path = "C:\\AMD\\ROCM\\5.5"
else:
if "ROCM_PATH" in os.environ:
rocm_path = os.environ["ROCM_PATH"]
print(f"Found a ROCm installation at {rocm_path}.")
else:
print("Failed to find ROCM_PATH. Defaulting to /opt/rocm")
rocm_path = "/opt/rocm/"
try:
if sys.platform == "win32":
rocm_arch = re.search(
r"gfx\d{3,}",
subprocess.check_output("hipinfo", shell=True, text=True),
).group(0)
else:
rocm_arch = re.match(
r".*(gfx\w+)",
subprocess.check_output(
"rocminfo | grep -i 'gfx'", shell=True, text=True
),
).group(1)
print(f"Found rocm arch {rocm_arch}...")
except:
def get_rocm_device_arch(device_num=0, extra_args=[]):
# ROCM Device Arch selection:
# 1 : User given device arch using `--iree-rocm-target-chip` flag
# 2 : Device arch from `iree-run-module --dump_devices=rocm` for device on index <device_num>
# 3 : default arch : gfx1100
arch_in_flag = check_rocm_device_arch_in_args(extra_args)
if arch_in_flag is not None:
print(
"Failed to find ROCm architecture from hipinfo / rocminfo. Defaulting to gfx1100."
f"User Specified rocm target device arch from flag : {arch_in_flag} will be used"
)
rocm_arch = "gfx1100"
return arch_in_flag
bc_path = os.path.join(rocm_path, "amdgcn", "bitcode")
return [
f"--iree-rocm-target-chip={rocm_arch}",
"--iree-rocm-link-bc=true",
f"--iree-rocm-bc-dir={bc_path}",
]
arch_in_device_dump = None
# get rocm arch from iree dump devices
def get_devices_info_from_dump(dump):
from os import linesep
dump_clean = list(
filter(
lambda s: "--device=rocm" in s or "gpu-arch-name:" in s,
dump.split(linesep),
)
)
arch_pairs = [
(
dump_clean[i].split("=")[1].strip(),
dump_clean[i + 1].split(":")[1].strip(),
)
for i in range(0, len(dump_clean), 2)
]
return arch_pairs
dump_device_info = None
try:
dump_device_info = run_cmd(
"iree-run-module --dump_devices=rocm", raise_err=True
)
except Exception as e:
print("could not execute `iree-run-module --dump_devices=rocm`")
if dump_device_info is not None:
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0])
if len(device_arch_pairs) > device_num: # can find arch in the list
arch_in_device_dump = device_arch_pairs[device_num][1]
if arch_in_device_dump is not None:
print(f"Found ROCm device arch : {arch_in_device_dump}")
return arch_in_device_dump
default_rocm_arch = "gfx_1100"
print(
"Did not find ROCm architecture from `--iree-rocm-target-chip` flag"
"\n or from `iree-run-module --dump_devices=rocm` command."
f"\nUsing {default_rocm_arch} as ROCm arch for compilation."
)
return default_rocm_arch
# Get the default gpu args given the architecture.
def get_iree_rocm_args(device_num=0, extra_args=[]):
ireert.flags.FUNCTION_INPUT_VALIDATION = False
rocm_flags = ["--iree-rocm-link-bc=true"]
if check_rocm_device_arch_in_args(extra_args) is None:
rocm_arch = get_rocm_device_arch(device_num, extra_args)
rocm_flags.append(f"--iree-rocm-target-chip={rocm_arch}")
return rocm_flags
# Some constants taken from cuda.h
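With the refactor above, the ROCm target architecture is resolved in priority order: an explicit `--iree-rocm-target-chip` in `extra_args`, then the arch reported by `iree-run-module --dump_devices=rocm` for the selected device index, then a hard-coded default. A hedged usage sketch (import path as used earlier in this diff; the arch value shown is only an example):
```python
# Sketch of ROCm compile-flag resolution, per the diff above.
from shark.iree_utils.gpu_utils import get_iree_rocm_args

# User pins the arch explicitly: only the link-bc flag is added,
# since the target-chip flag is already present in extra_args.
flags = get_iree_rocm_args(device_num=0,
                           extra_args=["--iree-rocm-target-chip=gfx90a"])
print(flags)  # ['--iree-rocm-link-bc=true']

# No arch given: probed via `iree-run-module --dump_devices=rocm`,
# falling back to the default arch if the probe fails.
print(get_iree_rocm_args(device_num=1))
```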

View File

@@ -27,9 +27,12 @@ from shark.parser import shark_args
def get_all_vulkan_devices():
from iree.runtime import get_driver
driver = get_driver("vulkan")
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
try:
driver = get_driver("vulkan")
device_list_src = driver.query_available_devices()
except:
device_list_src = {}
return [d["name"] for d in device_list_src]
@@ -68,6 +71,8 @@ def get_vulkan_target_triple(device_name):
Returns:
str or None: target triple or None if no match found for given name
"""
# TODO: Replace this with a dict or something smarter.
system_os = get_os_name()
# Apple Targets
if all(x in device_name for x in ("Apple", "M1")):
@@ -117,6 +122,8 @@ def get_vulkan_target_triple(device_name):
# Amd Targets
# Linux: Radeon RX 7900 XTX
# Windows: AMD Radeon RX 7900 XTX
elif all(x in device_name for x in ("RX", "7800")):
triple = f"rdna3-7800-{system_os}"
elif all(x in device_name for x in ("RX", "7900")):
triple = f"rdna3-7900-{system_os}"
elif all(x in device_name for x in ("Radeon", "780M")):

View File

@@ -26,7 +26,7 @@ class SplitStrToListAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
del parser, option_string
setattr(namespace, self.dest, shlex.split(values[0]))
setattr(namespace, self.dest, shlex.split(" "))
parser = argparse.ArgumentParser(description="SHARK runner.")
@@ -44,6 +44,13 @@ parser.add_argument(
action=SplitStrToListAction,
help="Additional arguments to pass to the compiler. These are appended as the last arguments.",
)
parser.add_argument(
"--additional_runtime_args",
default=list(),
nargs=1,
action=SplitStrToListAction,
help="Additional arguments to pass to the IREE runtime. These are appended as the last arguments.",
)
parser.add_argument(
"--enable_tf32",
type=bool,

View File

@@ -73,6 +73,7 @@ class SharkInference:
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
device_idx: int = None,
mmap: bool = True,
rt_flags: list = [],
):
self.mlir_module = mlir_module
if mlir_module is not None:
@@ -100,6 +101,7 @@ class SharkInference:
self.shark_runner = None
self.mmap = mmap
self.rt_flags = rt_flags
def compile(self, extra_args=[]):
if self.dispatch_benchmarks is not None:
@@ -134,6 +136,7 @@ class SharkInference:
self.mlir_dialect,
extra_args=extra_args,
device_idx=self.device_idx,
rt_flags=self.rt_flags,
)
if self.dispatch_benchmarks is not None:
@@ -147,11 +150,15 @@ class SharkInference:
# inputs are considered to be tuple of np.array.
def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
return self.shark_runner.run(function_name, inputs, send_to_host)
return self.shark_runner.run(
function_name, inputs, send_to_host, device=self.device
)
# forward function.
def forward(self, inputs: tuple, send_to_host=True):
return self.shark_runner.run("forward", inputs, send_to_host)
return self.shark_runner.run(
"forward", inputs, send_to_host, device=self.device
)
# Get all function names defined within the compiled module.
def get_functions_in_module(self):
@@ -220,12 +227,14 @@ class SharkInference:
device=self.device,
compile_vmfb=False,
extra_args=extra_args,
rt_flags=self.rt_flags,
)
params = load_flatbuffer(
path,
self.device,
self.device_idx,
mmap=self.mmap,
rt_flags=self.rt_flags,
)
self.shark_runner.iree_compilation_module = params["vmfb"]
self.shark_runner.iree_config = params["config"]
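The new `rt_flags` parameter threads IREE runtime flags from `SharkInference` down to module load time; the OPT examples later in this diff use it to pass an `--executable_plugin=...` flag. A minimal sketch of that pattern (the plugin path, vmfb name, and input shapes below are placeholders):
```python
# Sketch: passing IREE runtime flags through SharkInference.
import numpy as np
from shark.shark_inference import SharkInference

rt_flags = ["--executable_plugin=/path/to/plugin.so"]  # hypothetical path
shark_module = SharkInference(
    mlir_module=None, device="cpu-task", rt_flags=rt_flags
)
shark_module.load_module("some_model_cpu.vmfb")  # hypothetical vmfb

# Inputs are a tuple of numpy arrays; shapes/dtypes depend on the model.
inputs = (
    np.zeros((1, 32), dtype=np.int64),  # e.g. token ids
    np.ones((1, 32), dtype=np.int64),   # e.g. attention mask
)
logits = shark_module("forward", inputs)
```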

View File

@@ -72,6 +72,7 @@ class SharkRunner:
extra_args: list = [],
compile_vmfb: bool = True,
device_idx: int = None,
rt_flags: list = [],
):
self.mlir_module = mlir_module
if self.mlir_module is not None:
@@ -86,6 +87,7 @@ class SharkRunner:
self.mlir_dialect = mlir_dialect
self.extra_args = extra_args
self.device_idx = device_idx
self.rt_flags = rt_flags
if check_device_drivers(self.device):
print(device_driver_info(self.device))
@@ -99,6 +101,7 @@ class SharkRunner:
self.mlir_dialect,
extra_args=self.extra_args,
device_idx=self.device_idx,
rt_flags=self.rt_flags,
compile_str=self.compile_str,
)
self.iree_compilation_module = params["vmfb"]
@@ -106,7 +109,9 @@ class SharkRunner:
self.temp_file_to_unlink = params["temp_file_to_unlink"]
del params
def run(self, function_name, inputs: tuple, send_to_host=False):
def run(
self, function_name, inputs: tuple, send_to_host=False, device=None
):
return get_results(
self.iree_compilation_module,
function_name,
@@ -114,6 +119,7 @@ class SharkRunner:
self.iree_config,
self.mlir_dialect,
send_to_host,
device=device,
)
# Get all function names defined within the compiled module.

View File

@@ -1,4 +1,3 @@
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_model
from shark.parser import shark_args

View File

@@ -1,30 +1,25 @@
import argparse
import os
import torch
import numpy as np
from shark_opt_wrapper import OPTForCausalLMModel
from shark.iree_utils._common import (
check_device_drivers,
device_driver_info,
)
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM
OPT_MODEL = "opt-1.3b"
OPT_FS_NAME = "opt-1_3b"
MAX_SEQUENCE_LENGTH = 128
MAX_NEW_TOKENS = 60
from typing import Iterable
def create_module(model_name, tokenizer, device):
opt_base_model = OPTForCausalLM.from_pretrained("facebook/" + model_name)
def create_module(model_name, tokenizer, device, args):
opt_base_model = OPTForCausalLM.from_pretrained(
model_name, allow_mismatched_sizes=True
)
opt_base_model.eval()
opt_model = OPTForCausalLMModel(opt_base_model)
encoded_inputs = tokenizer(
"What is the meaning of life?",
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
max_length=args.max_seq_len,
return_tensors="pt",
)
inputs = (
@@ -33,8 +28,11 @@ def create_module(model_name, tokenizer, device):
)
# np.save("model_inputs_0.npy", inputs[0])
# np.save("model_inputs_1.npy", inputs[1])
opt_fs_name = "-".join(
"_".join(args.model_name.split("/")[1].split("-")).split(".")
)
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
mlir_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch.mlir"
if os.path.isfile(mlir_path):
print(f"Found .mlir from {mlir_path}")
else:
@@ -42,7 +40,7 @@ def create_module(model_name, tokenizer, device):
model=opt_model,
inputs=inputs,
is_f16=False,
model_name=OPT_FS_NAME,
model_name=opt_fs_name,
return_str=True,
)
with open(mlir_path, "w") as f:
@@ -57,7 +55,7 @@ def create_module(model_name, tokenizer, device):
is_benchmark=False,
)
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
vmfb_name = f"{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu"
shark_module.save_module(module_name=vmfb_name, debug=False)
vmfb_path = vmfb_name + ".vmfb"
return vmfb_path
@@ -71,11 +69,11 @@ def shouldStop(tokens):
return False
def generate_new_token(shark_model, tokenizer, new_text):
def generate_new_token(shark_module, tokenizer, new_text, max_seq_len: int):
model_inputs = tokenizer(
new_text,
padding="max_length",
max_length=MAX_SEQUENCE_LENGTH,
max_length=max_seq_len,
truncation=True,
return_tensors="pt",
)
@@ -84,7 +82,7 @@ def generate_new_token(shark_model, tokenizer, new_text):
model_inputs["attention_mask"],
)
sum_attentionmask = torch.sum(model_inputs.attention_mask)
output = shark_model("forward", inputs)
output = shark_module("forward", inputs)
output = torch.FloatTensor(output[0])
next_toks = torch.topk(output, 1)
stop_generation = False
@@ -104,39 +102,96 @@ def generate_new_token(shark_model, tokenizer, new_text):
return ret_dict
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--max-seq-len", type=int, default=32)
parser.add_argument(
"--model-name",
help="Model name",
type=str,
choices=[
"facebook/opt-125m",
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-6.7b",
"mit-han-lab/opt-125m-smoothquant",
"mit-han-lab/opt-1.3b-smoothquant",
"mit-han-lab/opt-2.7b-smoothquant",
"mit-han-lab/opt-6.7b-smoothquant",
"mit-han-lab/opt-13b-smoothquant",
],
default="facebook/opt-1.3b",
)
parser.add_argument(
"--recompile",
help="If set, recompiles MLIR -> .vmfb",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument(
"--plugin-path",
help="path to executable plugin",
type=str,
default=None,
)
args = parser.parse_args()
print("args={}".format(args))
return args
def generate_tokens(
opt_shark_module: "SharkInference",
tokenizer,
input_text: str,
max_output_len: int,
print_intermediate_results: True,
) -> Iterable[str]:
words_list = []
new_text = input_text
try:
for _ in range(max_output_len):
generated_token_op = generate_new_token(
opt_shark_module, tokenizer, new_text, max_output_len
)
detok = generated_token_op["detok"]
if generated_token_op["stop_generation"]:
break
if print_intermediate_results:
print(detok, end="", flush=True)
words_list.append(detok)
if detok == "":
break
new_text += detok
except KeyboardInterrupt as e:
print("Exiting token generation.")
return words_list
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(
"facebook/" + OPT_MODEL, use_fast=False
args = parse_args()
if "smoothquant" in args.model_name:
token_model_name = f"facebook/opt-{args.model_name.split('-')[3]}"
else:
token_model_name = args.model_name
tokenizer = AutoTokenizer.from_pretrained(token_model_name, use_fast=False)
opt_fs_name = "-".join(
"_".join(args.model_name.split("/")[1].split("-")).split(".")
)
vmfb_path = (
f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-task.vmfb"
vmfb_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu.vmfb"
if args.plugin_path is not None:
rt_flags = [f"--executable_plugin={args.plugin_path}"]
else:
rt_flags = []
opt_shark_module = SharkInference(
mlir_module=None, device="cpu-task", rt_flags=rt_flags
)
opt_shark_module = SharkInference(mlir_module=None, device="cpu-task")
if os.path.isfile(vmfb_path):
opt_shark_module.load_module(vmfb_path)
else:
vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-task")
vmfb_path = create_module(args.model_name, tokenizer, "cpu-task", args)
opt_shark_module.load_module(vmfb_path)
while True:
try:
new_text = input("Give me a sentence to complete:")
new_text_init = new_text
words_list = []
for i in range(MAX_NEW_TOKENS):
generated_token_op = generate_new_token(
opt_shark_module, tokenizer, new_text
)
detok = generated_token_op["detok"]
stop_generation = generated_token_op["stop_generation"]
if stop_generation:
break
print(detok, end="", flush=True)
words_list.append(detok)
if detok == "":
break
new_text = new_text + detok
except KeyboardInterrupt:
print("Exiting program.")
break
input_text = input("Give me a sentence to complete:")
generate_tokens(
opt_shark_module, tokenizer, input_text, args.max_seq_len
)

View File

@@ -0,0 +1,74 @@
import argparse
import os
import opt_causallm
import opt_util
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, OPTForCausalLM
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--max-seq-len", type=int, default=32)
parser.add_argument(
"--model-name",
help="Model name",
type=str,
choices=[
"facebook/opt-125m",
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-6.7b",
],
default="facebook/opt-1.3b",
)
parser.add_argument(
"--recompile",
help="If set, recompiles MLIR -> .vmfb",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument(
"--plugin-path",
help="path to executable plugin",
type=str,
default=None,
)
args = parser.parse_args()
print("args={}".format(args))
return args
if __name__ == "__main__":
args = parse_args()
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
opt_fs_name = "-".join(
"_".join(args.model_name.split("/")[1].split("-")).split(".")
)
vmfb_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu.vmfb"
if args.plugin_path is not None:
rt_flags = [f"--executable_plugin={args.plugin_path}"]
else:
rt_flags = []
opt_shark_module = SharkInference(
mlir_module=None, device="cpu-task", rt_flags=rt_flags
)
if os.path.isfile(vmfb_path):
opt_shark_module.load_module(vmfb_path)
else:
vmfb_path = opt_causallm.create_module(
args.model_name, tokenizer, "cpu-task", args
)
opt_shark_module.load_module(vmfb_path)
for prompt in opt_util.PROMPTS:
print("\n\nprompt: {}".format(prompt))
response = opt_causallm.generate_tokens(
opt_shark_module,
tokenizer,
prompt,
args.max_seq_len,
print_intermediate_results=False,
)
print("response: {}".format("".join(response)))

View File

@@ -19,12 +19,16 @@ import json
import os
import psutil
import time
import numpy as np
from typing import Tuple
from opt_util import PROMPTS
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM
from shark_opt_wrapper import OPTForCausalLMModel
from shark.parser import shark_args
import iree.compiler as ireec
DEVICE = "cpu"
PLATFORM_SHARK = "shark"
@@ -41,19 +45,6 @@ REPORT_LOAD_VIRTUAL_MEMORY_MB = "load_virtual_MB"
REPORT_RUN_PHYSICAL_MEMORY_MB = "run_physical_MB"
REPORT_RUN_VIRTUAL_MEMORY_MB = "run_virtual_MB"
PROMPTS = [
"What is the meaning of life?",
"Tell me something you don't know.",
"What does Xilinx do?",
"What is the mass of earth?",
"What is a poem?",
"What is recursion?",
"Tell me a one line joke.",
"Who is Gilgamesh?",
"Tell me something about cryptocurrency.",
"How did it all begin?",
]
ModelWrapper = collections.namedtuple("ModelWrapper", ["model", "tokenizer"])
@@ -63,14 +54,15 @@ def get_memory_info():
return process.memory_info()
def create_vmfb_module(
def import_mlir_module(
model_name: str,
tokenizer,
device: str,
max_seq_len: int,
recompile_shark: bool,
):
opt_base_model = OPTForCausalLM.from_pretrained(model_name)
opt_base_model = OPTForCausalLM.from_pretrained(
model_name, ignore_mismatched_sizes=True
)
opt_base_model.eval()
opt_model = OPTForCausalLMModel(opt_base_model)
encoded_inputs = tokenizer(
@@ -87,6 +79,27 @@ def create_vmfb_module(
# np.save("model_inputs_0.npy", inputs[0])
# np.save("model_inputs_1.npy", inputs[1])
opt_fs_name = get_opt_fs_name(model_name)
mlir_path = f"./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir"
(model_mlir, func_name) = import_with_fx(
model=opt_model,
inputs=inputs,
is_f16=False,
model_name=opt_fs_name,
return_str=True,
)
with open(mlir_path, "w") as f:
f.write(model_mlir)
print(f"Saved mlir at {mlir_path}")
def create_vmfb_module(
model_name: str,
tokenizer,
device: str,
max_seq_len: int,
recompile_shark: bool,
):
opt_fs_name = get_opt_fs_name(model_name)
mlir_path = f"./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir"
# If MLIR has already been loaded and recompilation is not requested, use
@@ -96,49 +109,49 @@ def create_vmfb_module(
# compilation time can be correctly measured only when MLIR has already been
# loaded.
assert not recompile_shark or has_mlir
if has_mlir:
with open(mlir_path, "r") as f:
model_mlir = f.read()
print(f"Loaded .mlir from {mlir_path}")
else:
(model_mlir, func_name) = import_with_fx(
model=opt_model,
inputs=inputs,
is_f16=False,
model_name=opt_fs_name,
return_str=True,
if not has_mlir:
import_mlir_module(
model_name,
tokenizer,
device,
max_seq_len,
)
with open(mlir_path, "w") as f:
f.write(model_mlir)
print(f"Saved mlir at {mlir_path}")
shark_module = SharkInference(
model_mlir,
mlir_path,
device=device,
mlir_dialect="tm_tensor",
is_benchmark=False,
rt_flags=[],
)
vmfb_name = (
f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels"
)
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}"
shark_module.save_module(module_name=vmfb_name)
vmfb_path = vmfb_name + ".vmfb"
return vmfb_path
def load_shark_model(
model_name: str, max_seq_len: int, recompile_shark: bool
model_name: str,
token_model_name: str,
max_seq_len: int,
recompile_shark: bool,
plugin_path: str = [],
) -> ModelWrapper:
opt_fs_name = get_opt_fs_name(model_name)
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels.vmfb"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}.vmfb"
tokenizer = AutoTokenizer.from_pretrained(token_model_name, use_fast=False)
if recompile_shark or not os.path.isfile(vmfb_name):
print(f"vmfb not found. compiling and saving to {vmfb_name}")
create_vmfb_module(
model_name, tokenizer, DEVICE, max_seq_len, recompile_shark
)
shark_module = SharkInference(mlir_module=None, device="cpu-task")
if plugin_path is not None:
rt_flags = [f"--executable_plugin={plugin_path}"]
else:
rt_flags = []
shark_module = SharkInference(
mlir_module=None, device="cpu-task", rt_flags=rt_flags
)
shark_module.load_module(vmfb_name)
return ModelWrapper(model=shark_module, tokenizer=tokenizer)
@@ -148,10 +161,12 @@ def run_shark_model(model_wrapper: ModelWrapper, tokens):
return model_wrapper.model("forward", tokens)
def load_huggingface_model(model_name: str) -> ModelWrapper:
def load_huggingface_model(
model_name: str, token_model_name: str
) -> ModelWrapper:
return ModelWrapper(
model=OPTForCausalLM.from_pretrained(model_name),
tokenizer=AutoTokenizer.from_pretrained(model_name),
tokenizer=AutoTokenizer.from_pretrained(token_model_name),
)
@@ -167,11 +182,14 @@ def save_json(data, filename):
def collect_huggingface_logits(
model_name: str, max_seq_len: int, to_save_json: bool
model_name: str,
token_model_name: str,
max_seq_len: int,
to_save_json: bool,
) -> Tuple[float, float]:
# Load
t0 = time.time()
model_wrapper = load_huggingface_model(model_name)
model_wrapper = load_huggingface_model(model_name, token_model_name)
load_time = time.time() - t0
print("--- Took {} seconds to load Huggingface.".format(load_time))
load_memory_info = get_memory_info()
@@ -215,13 +233,17 @@ def collect_huggingface_logits(
def collect_shark_logits(
model_name: str,
token_model_name: str,
max_seq_len: int,
recompile_shark: bool,
to_save_json: bool,
plugin_path: str,
) -> Tuple[float, float]:
# Load
t0 = time.time()
model_wrapper = load_shark_model(model_name, max_seq_len, recompile_shark)
model_wrapper = load_shark_model(
model_name, token_model_name, max_seq_len, recompile_shark, plugin_path
)
load_time = time.time() - t0
print("--- Took {} seconds to load Shark.".format(load_time))
load_memory_info = get_memory_info()
@@ -302,6 +324,11 @@ def parse_args():
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-6.7b",
"mit-han-lab/opt-125m-smoothquant",
"mit-han-lab/opt-1.3b-smoothquant",
"mit-han-lab/opt-2.7b-smoothquant",
"mit-han-lab/opt-6.7b-smoothquant",
"mit-han-lab/opt-13b-smoothquant",
],
default="facebook/opt-1.3b",
)
@@ -318,6 +345,18 @@ def parse_args():
choices=[PLATFORM_SHARK, PLATFORM_HUGGINGFACE],
default=PLATFORM_SHARK,
)
parser.add_argument(
"--plugin-path",
help="path to executable plugin",
type=str,
default=None,
)
parser.add_argument(
"--token-model-name",
help="HF ID to create tokenizer.",
type=str,
default=None,
)
args = parser.parse_args()
print("args={}".format(args))
return args
@@ -325,16 +364,28 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
if args.token_model_name == None:
if "smoothquant" in args.model_name:
args.token_model_name = (
f"facebook/opt-{args.model_name.split('-')[3]}"
)
else:
args.token_model_name = args.model_name
if args.platform == PLATFORM_SHARK:
shark_report = collect_shark_logits(
args.model_name,
args.token_model_name,
args.max_seq_len,
args.recompile_shark,
args.save_json,
args.plugin_path,
)
print("# Summary: {}".format(json.dumps(shark_report)))
else:
huggingface_report = collect_huggingface_logits(
args.model_name, args.max_seq_len, args.save_json
args.model_name,
args.token_model_name,
args.max_seq_len,
args.save_json,
)
print("# Summary: {}".format(json.dumps(huggingface_report)))

View File

@@ -0,0 +1,12 @@
PROMPTS = [
"What is the meaning of life?",
"Tell me something you don't know.",
"What does Xilinx do?",
"What is the mass of earth?",
"What is a poem?",
"What is recursion?",
"Tell me a one line joke.",
"Who is Gilgamesh?",
"Tell me something about cryptocurrency.",
"How did it all begin?",
]

View File

@@ -1,4 +1,3 @@
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_model
from tank.test_utils import get_valid_test_params, shark_test_name_func

View File

@@ -44,7 +44,7 @@ class TapasBaseModuleTest(unittest.TestCase):
self.module_tester.create_and_check_module(dynamic, device)
@pytest.mark.skipif(
check_device_drivers("cuda"), reason=device_driver_info("gpu")
check_device_drivers("cuda"), reason=device_driver_info("cuda")
)
def test_module_static_cuda(self):
dynamic = False