Compare commits

..

230 Commits

Author SHA1 Message Date
Vivek Khandelwal
666e601dd9 Remove sharding support for non-180B falcon variants 2023-11-27 13:45:13 +05:30
Vivek Khandelwal
ca58908e5b Add Falcon-GPTQ Support for 2-way sharding 2023-11-27 13:45:13 +05:30
Jakub Kuderski
1f5b39f56e [vicuna.py] Add option to enable tracing (#1993)
This makes the program wait for tracy profiler to connect before exiting
and flush profiling data after each token.

I don't know how to select the tracy iree-runtime variant
programatically -- instead, print an error and exit.
2023-11-24 12:25:03 -08:00
Jakub Kuderski
2da31c4109 [vicuna.py] Rework benchmark statistics calculation (#1992)
- Move statistics out of the main loop
- Add 'end-to-end' numbers
- Switch the main display unit from s to ms
- Start measuring time at 0

The new print format looks like this:
```
Number of iterations: 5
Num tokens: 1 (prompt), 512 (generated), 513 (total)
Prefill: avg. 0.01 ms (stdev 0.00), avg. 97.99 tokens/s
Decode: avg. 4840.44 ms (stdev 28.80), avg. 97.99 tokens/s
Decode end-2-end: avg. 85.78 tokens/s (w/o prompt), avg. 95.98 (w/ prompt)
```
2023-11-23 12:04:03 -05:00
Ean Garvey
da50a16242 Create specified dir if needed during save_mlir and fix vulkan device fetching without URI/ID (#1989) 2023-11-23 01:01:41 -06:00
Stefan Kapusniak
ce38d49f05 Add .mlir to startup shark_tmp cleanup (#1991)
* Add .mlir to the fiiles that are deleted from `./shark_tmp` when studio
is started.
* refactor/rename existing gradio temp file cleanup on startup to be
consistent with a general `./shark_tmp` cleanup
2023-11-22 14:34:28 -06:00
PhaneeshB
2f780f0d38 quick fix rocm None device 2023-11-22 21:17:25 +05:30
Ean Garvey
d051c3a4a7 Use clean_device_info() by default and don't write .mlir to /tmp/ (#1984)
* Move clean_device_info to compile_utils

* Update compile_utils.py

* Fix .mlir writes for some user-level permissions

* Fix cases where full URI is given

* Fix conditionals.

* Fix device path handling in vulkan utils.
2023-11-20 13:10:31 -06:00
Ean Garvey
1b11c82c9d Small UI tweaks for chatbot, fix torchvision requirements (#1988)
- add torchvision to setup_venv.ps1 -- we need this for the torchvision::nms that is now a dependency of controlnet features.
- Don't have bad flashy orange updates when using the chatbot
- Don't limit the height of the chatbot -- there's mixed opinions and solutions around this one. I think the default (400) is just way too small and LLMs generate plenty enough to justify matching the output.
2023-11-21 00:09:10 +05:30
gpetters94
80a33d427f Save intermediate values of controlnet (#1981) 2023-11-17 19:05:41 -05:00
Stefan Kapusniak
4125a26294 API/Docs: Fix incorrect cors arguments listing (#1983)
* Replace `api_cors_origin` in the api/koboldcpp doc, with the correct
 `api_accept_origin`
2023-11-17 12:29:01 -06:00
Ean Garvey
905d0103ff Revert "Re-enable SD tunings without matmuls. (#1976)" (#1979)
This reverts commit 70817bb50a.
2023-11-17 23:44:33 +05:30
Stefan Kapusniak
192b3b2c61 UI: Output galllery cleanups (#1959)
* Workaround gradio bug that causes the parameters frame to always show
scrollbars.
* Remove the original funky method of setting the number of image
columns in the gallery using _fn= javacript events. The version
of gradio we now have pinned allows doing this by setting the property
on the gallery directly and also doesn't keep resetting the columns on
other events being fired.
2023-11-15 22:20:42 -06:00
Stefan Kapusniak
8f9adc4a2a UI: Display top tag frequencies for selected LoRA (#1972)
* Adds a function to webui utils to read metadata from
.safetensors LoRA files. and do limiting parsing of the format written
out by the Kohya SS scripts (https://github.com/kohya-ss/sd-scripts)
to get tag frequency and trained model information.
* Adds a new common_ui_events.py file for gradio event handlers
needed for multiple UI tabs, and adds an event handler for binding to
the change event of the LoRA selection boxes, that outputs HTML
to display the LoRA tag frequency and model information.
* Adds an HTML gradio control to each of the SD tabs to show the
LoRA model name, and most frequently trained tags.
* Bind the change event of the LoRA selection box on each tab
to our new event handler, with the output set to the relevant HTML
control.
2023-11-15 22:19:54 -06:00
Ean Garvey
70817bb50a Re-enable SD tunings without matmuls. (#1976) 2023-11-15 20:42:53 -06:00
jinchen62
dd37c26d36 Update brevitas quant api (#1975) 2023-11-15 10:04:07 -08:00
PhaneeshB
a708879c6c fix iree version mismatch 2023-11-15 01:24:42 +05:30
Ean Garvey
bb1b49eb6f Add --no-index to setup_venv.sh runtime pip install. 2023-11-14 21:44:20 +05:30
Ean Garvey
f6d41affd9 (SHARK Studio) Add Turbine-based llm chatbot. (#1933)
* Dan shark studio (#1970)

* Fix issue in Falcon-GPTQ

* initial webui and llama2

---------

Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>

* Fix formatting.

---------

Co-authored-by: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2023-11-14 09:56:28 -06:00
Stefan Kapusniak
c2163488d8 SD/UI Restrict hires fix/img2img resamplers/schedulers (#1955)
* Restrict resamplers for img2img and high res fix to the ones that
PIL.Image actually supports, since it uses that to di the resampling.
Removed: Antialias, Affine, Cubic. Added: Hamming.
* Set list of available schedulers to CPU only when high res fix
is selected in the web ui. Set list to all schdulers when high res fix
is deselected.
* Put hi res fix in its own Accordian in the txt2img UI instead of
grouping it with Advanced Options.
2023-11-13 16:08:24 -06:00
PhaneeshB
54bff4611d fix cli rocm device selection 2023-11-13 23:35:55 +05:30
PhaneeshB
11510d5111 add intra rocm vmfb differentiator 2023-11-13 23:35:55 +05:30
PhaneeshB
32cab73a29 add iree-rocm-target-chip only if added by user 2023-11-13 23:35:55 +05:30
PhaneeshB
392bade0bf enable non default rocm device selection for webui 2023-11-13 23:35:55 +05:30
Stefan Kapusniak
91df5f0613 API/Docs: Fix an image link in koboldcpp doc (#1954)
* Fix the image link for the koboldcpp style button pointing to the
dialog image rather than the button image.
2023-11-13 11:14:29 -06:00
dependabot[bot]
df20cf9c8a Bump langchain in /apps/language_models/langchain (#1968)
Bumps [langchain](https://github.com/langchain-ai/langchain) from 0.0.325 to 0.0.329.
- [Release notes](https://github.com/langchain-ai/langchain/releases)
- [Commits](https://github.com/langchain-ai/langchain/compare/v0.0.325...v0.0.329)

---
updated-dependencies:
- dependency-name: langchain
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-12 19:46:00 -08:00
Ean Garvey
c4a908c3ea Pin pydantic to 2.4.1 in requirements (#1967)
pyinstaller-hooks-contrib doesn't see beta versions of pydantic as versions greater than 2.0.0, and so it looks for an attribute `compile` only available in versions older than 2.0.0 if you have a beta version of pydantic.
2023-11-10 21:34:52 -06:00
Stefan Kapusniak
6285430d8a UI: Fix webui launch on non-Windows (#1963)
* Moves the imports of winreg and Tk, into the functions that use them,
with winreg behind a guard clause. This should hopefully mean that if
you're not on Window or not using `ui=app` we won't trip over either
of these due to them not being there.
2023-11-10 16:38:32 -06:00
PhaneeshB
51afe19e20 fix rocm arch selection 2023-11-10 13:22:51 +05:30
Ean Garvey
31005bcf73 Don't require vulkan installation to query devices. (#1953) 2023-11-09 14:46:44 -06:00
dependabot[bot]
f41ad87ef6 Bump langchain in /apps/language_models/langchain (#1926)
Bumps [langchain](https://github.com/langchain-ai/langchain) from 0.0.202 to 0.0.325.
- [Release notes](https://github.com/langchain-ai/langchain/releases)
- [Commits](https://github.com/langchain-ai/langchain/compare/v0.0.202...v0.0.325)

---
updated-dependencies:
- dependency-name: langchain
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-09 11:03:47 -06:00
dependabot[bot]
d811524a00 Bump pypdf from 3.12.2 to 3.17.0 in /apps/language_models/langchain (#1929)
Bumps [pypdf](https://github.com/py-pdf/pypdf) from 3.12.2 to 3.17.0.
- [Release notes](https://github.com/py-pdf/pypdf/releases)
- [Changelog](https://github.com/py-pdf/pypdf/blob/main/CHANGELOG.md)
- [Commits](https://github.com/py-pdf/pypdf/compare/3.12.2...3.17.0)

---
updated-dependencies:
- dependency-name: pypdf
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-09 11:02:43 -06:00
Sungsoon Cho
51e1bd1c5d (OPT) Fix typo in the message; s/reponse/response (#1920) 2023-11-09 11:00:48 -06:00
Phaneesh Barwaria
db89b1bdc1 Fix MacOS web execution flow (#1899)
* fix metal device path for chatbot

* single device remove indexing

* lint fix
2023-11-09 10:59:29 -06:00
Huang Qi
2754e2e257 Fix wrong parameter index passed to 'compile_module_to_flatbuffer' (#1921)
compile_str is always False in compile_module_to_flatbuffer since there
is a parameter 'model_name' before 'debug'.

This issue is relative to https://github.com/nod-ai/SHARK/pull/1863.

Then we can use mlir model buffer in RAM to run inference.
2023-11-09 10:58:05 -06:00
PhaneeshB
ab0e870c43 fix vicuna cli vulkan 2023-11-09 22:27:13 +05:30
Stefan Kapusniak
fb30e8c226 UI: Fix some webui launch corner cases (#1952)
* On windows insist on the presence of webview2 as the embeddable
browser for `ui=app`. If we can't find it, effectively switch back to
`ui=web`. This should prevent pywebview trying to use MSHTML, whilst
saying its deprecated, and apparently we are too much for poor old IE11
* Add webview2 runtime droppings to .gitignore.
* If we can't bind to args.server_port get another suitable port from
the OS and advise the user that we did this in the UI.
* Make `ui=web` mode use 'SHARK AI Studio' as its title. This makes it
consistent with `ui=app`.
* Replace the generic gradio favicon with a nod swirl one instead.
2023-11-09 10:53:28 -06:00
Ean Garvey
a07d542400 (Studio) Disable SD tunings and sub-model downloads (#1944)
* sets --no-use_tuned and --import_mlir as defaults in SHARK Studio.
2023-11-07 15:55:30 -06:00
Stefan Kapusniak
ad55cb696f SD/API: Add missing A1111 APIs to Shark to support koboldcpp image generation (#1924)
* SD/API: Add missing a1111 API features for Koboldcpp

* Refactors SD api functions into their own file
* Adds the following apis implemented by a1111 as needed by koboldcpp:
   - adds /sdapi/v1/sd-models (lists available models)
   - adds /sdapi/v1/options (only the bare minimum needed)
* Adds optional CORS support, use the '--api_accept_origin' command line
argument to activate and configure.
* Extends existing APIs to include optional sampler/scheduler selection
* Extends /sdapi/v1/textimg to recognise the method used by koboldcpp
to select the model.
* Where possible take values not provided to the API in the request from
the existing relevant command line parameters rather than hardcoding
them.
* return a 400 response when a request doesn't have required properties.
* changed default schedulers and models for some apis to ones that
actually seem to work.
* Update api_test.py to include the new APIs.
* Update api_test.py to include a '--verbose' command line option.

* SD/API: Take more API values from args

* Take LoRA from '--use_lora' command line arg if specified
* Take device from '--device' command line arg if specified (substring
match, so a short name such as 'vulkan://0' should work)

* SD/API: add more endpoints and pydantic typing

* Mount the whole of /sdapi from index.py as a FastAPI application,
rather than each endpoint individually
* Add the following additional API endpoints:
  * /sdapi/v1/samplers
  * /sdapi/v1/cmd-flags
* Make scheduler/sampler selection checking and fallback much more
robust.
* Support aliasing some A1111 scheduler/sampler names to the diffusers
ones we are using.
* Expand response /sdapi/v1/options to add a few more things.
* Split non-api functions and variables into their own utils.py file.
* Support 'n_iter' request property and the return of multiple images
from generation endpoints. Equivalent of '--batch_count', batch_size
is stil hardcoded at 1
* Include (some) hires_fix request properties in txt2img endpoint
* Rework endpoints using pydantic model classes for better request
validation and so we get much improved swagger api docs at
/sdapi/docs and redoc at /sdapi/redoc

* SD/API Delete commented out code from index.py

* Delete some code that is no longer needed by the SD API in index.py
(and one line sdapi_v1.py) that I'd previously only commented out.

* SD/UI: Add shark_sd_koboldcpp.md document

* Add documentation on how to set up Koboldcpp with SHARK
* Link this and the existing blender set up document from the main
README.md

* SD/API Improve stencil options in img2img endpoint

In /sdapi/v1/img2img:
  * Add zoedepth to the controlnet use_stencil options
  * Require and use second image as stencil mask for controlnet scribble
2023-11-06 15:20:19 -06:00
Jakub Kuderski
488a172292 [vicuna.py] Allow to pass extra arguments to iree-compile (#1935)
Add a new flag `-Xiree_compile` to forward extra compiler arguments to
`iree-compile`. This flag can be set multiple times to pass more than
one extra argument.
2023-11-06 12:12:34 -05:00
Stanley Winata
500c4f2306 [compile utils] Fix ROCM to not expect config.id as a default. (#1939) 2023-11-06 08:44:53 -08:00
Vivek Khandelwal
92b694db4d Add support for Falcon-40b-GPTQ 2023-11-06 19:49:19 +05:30
Vivek Khandelwal
322874f7f9 Fix issue in Falcon-GPTQ 2023-11-03 11:48:36 +05:30
Ean Garvey
5001db3415 Add 7800xt to target triples explicitly. (#1928) 2023-11-01 17:11:45 -05:00
Vivek Khandelwal
71846344a2 Add sharded Falcon-GPTQ support
This commit adds the support for sharded Falcon-7b-GPTQ and
Falcon-180B-GPTQ. This commit also adds the support for 4-way
sharding of the Falcon model for the device ROCM.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-11-01 12:11:44 +05:30
gpetters94
72e27c96fc Add ZoeDepth (#1834)
* Add ZoeDepth

* Add einops to Studio imports.

* Specify ref for forked torch.hub repos.

* Unpin timm.

---------

Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: Ean Garvey <garveyej@gmail.com>
2023-10-30 11:57:45 -05:00
PhaneeshB
7963abb8ec remove caching for rocm args 2023-10-29 07:07:57 +05:30
Ean Garvey
98244232dd Add smoothquant OPT to examples. (#1922) 2023-10-27 12:32:12 -05:00
PhaneeshB
679a452139 fix calls and remove unused imports for check_device_drivers 2023-10-27 10:30:40 +05:30
PhaneeshB
72c0a8abc8 remove dependency on external commands for driver installation check 2023-10-27 10:30:40 +05:30
Vivek Khandelwal
ea920f2955 Add sharded Falcon support 2023-10-26 21:53:25 +05:30
Phaneesh Barwaria
486202377a update dependency on rocm/hip info command (#1900)
* add support for rocm flags

* add rocm target flag to chat args

* rm rocm libs dependency message
2023-10-26 15:18:25 +05:30
Sungsoon Cho
0c38c33d0a Add opt_causallm_samples.py. (#1916) 2023-10-25 11:52:51 -05:00
Ean Garvey
841773fa32 Updates to opt_causallm example (#1905)
* Updates to opt_causallm example

* Fixup opt_perf_comparison.py

* Use same filenames across opt examples.
2023-10-24 10:54:39 -07:00
Stefan Kapusniak
0361db46f9 SD: Fix unet untuned opt_flags (#1912)
* correct my sloppy copy/paste for the untuned unet default compilation
flags that introduced an extra 'detach' into what should have been
'iree-global-opt-convert-1x1-filter-conv2d-to-matmul'
2023-10-24 12:47:33 -05:00
xzuyn
a012433ffd Save hiresfix info if used (#1914) 2023-10-24 12:45:10 -05:00
xzuyn
5061193da3 Move Generate, Randomize Seed, & Stop Batch to same positions as txt2img (#1915) 2023-10-24 12:44:39 -05:00
xzuyn
bff48924be LLaMa 2 Chat template fix (#1913) 2023-10-23 18:51:15 -05:00
Stefan Kapusniak
825b36cbdd Fix MLIR Textual PassPipeline Error (#1910) 2023-10-22 07:39:52 -07:00
Stefan Kapusniak
134441957d SD - Fix civitai download on Windows +improvements (#1907) 2023-10-21 11:17:41 -07:00
Stefan Kapusniak
7cd14fdc47 SD/UI: Use a single model selection box on UI tabs (#1906)
* Allow entry of a huggingface model id or civitai download url to be
done in the main model selection dropdown on SD tabs
* Remove separate textbox for entering huggingface model id or civitai
download url on SD Tabs
* Remove 'None' option from the model selection dropdown (no longer
needed) on SD tabs
* Update png metadata drop zone on txt2img tab to work with a single
argument for model selection
* Update UI generate functions on SD tabs to work with single argument
model selection
* Update API code for changes to the UI generate functions
* Move info about the custom model path to the logging textarea on SD
tabs
2023-10-21 10:06:05 -07:00
Ean Garvey
e6cb5cef57 Add --additional_runtime_args option and use in OPT example. (#1855)
* Add --additional_runtime_args option and use in OPT example.

Fix the func name. (#1838)

Co-authored-by: Sungsoon Cho <sungsoon.cho@gmail.com>
2023-10-19 13:29:39 -05:00
Huang Qi
66abee8e5b SharkInference: Fix various examples and README.md (#1903)
Follow https://github.com/nod-ai/SHARK/pull/708, remove parameter 'func_name'
for SharkInference.
2023-10-19 09:28:36 -05:00
Ean Garvey
4797bb89f5 Stringify path for ireec.compile_file (#1901)
* Stringify path for ireec.compile_file

* Update test-models.yml
2023-10-18 14:59:23 -05:00
Vivek Khandelwal
205e57683a Modify Falcon-180b-GPTQ sharded pipeline 2023-10-17 20:26:01 +05:30
Vivek Khandelwal
2866d665ee Fix Sharded Falcon-180b-GPTQ Pipeline 2023-10-17 20:26:01 +05:30
Stefan Kapusniak
71d25ec5d8 SD: Fix repeatable seeds when intial seed is random (#1893) 2023-10-14 22:50:42 -07:00
Vivek Khandelwal
202ffff67b Add support for sharded Falcon model 2023-10-13 22:05:10 +05:30
Ean Garvey
0b77059628 Add matmul reassociation flags (#1891) 2023-10-12 20:12:37 -05:00
Stefan Kapusniak
a208302bb9 Fix repeatable seeds consistency over batch counts (#1889)
* Set the input seed for the random number generator when
generating repeatable seeds to exclude any negative numbers
in the parsed seed input.  The makes seeds generated for
different batch counts consistent where they have the same
input for the initial seed or set of seeds.
2023-10-12 17:15:19 -05:00
Vivek Khandelwal
b83d32fafe Fix Falcon GPTQ Pipeline 2023-10-11 20:09:32 +05:30
Vivek Khandelwal
0a618e1863 Add support for Falcon GPTQ 2023-10-11 10:47:48 +05:30
Phaneesh Barwaria
a731eb6ed4 Macos fixes (#1883)
* fix venv setup for MacOS

* allow stream fuse binding on mac

* clean iree metal args
2023-10-09 23:36:12 -07:00
Ean Garvey
2004d16945 Revert "[SDXL] Add SDXL pipeline to SHARK (#1731)" (#1882)
This reverts commit 9f0a421764.
2023-10-09 18:01:44 -07:00
Gaurav Shukla
6e409bfb77 fix else if syntax error
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-10 06:23:56 +05:30
Gaurav Shukla
77727d149c [warning] Fix dropdown warning
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-10 05:18:43 +05:30
Ean Garvey
66f6e79d68 Split CPU/GPU definitions conditionally outside of torch contexts. (#1879) 2023-10-09 16:46:41 -07:00
Ean Garvey
3b825579a7 (LLaMa-2) Point to int4 + f32 acc .mlir for cpu (#1878)
- fixes some issues with non-system prompt invocation

Co-authored-by: Gaurav Shukla <gauravshukla789@gmail.com>
2023-10-09 14:37:35 -05:00
Abhishek Varma
9f0a421764 [SDXL] Add SDXL pipeline to SHARK (#1731)
-- This commit adds SDXL pipeline to SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-10-09 13:01:37 -05:00
Gaurav Shukla
c28682110c [chatbot] Flag to add system prompt
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-09 22:17:39 +05:30
Ean Garvey
caf6cc5d8f Switch most compile flows to use ireec.compile_file. (#1863)
* Switch most compile flows to use ireec.compile_file.

* re-add input type to compile_str path.

* Check if mlir_module exists before checking if it's a path or pyobject.

* Fix some save_dir cases
2023-10-06 23:04:43 -05:00
Ean Garvey
8614a18474 Remove tf dependencies from importer path. (#1874)
* Remove tf dependencies from import path.

* Fix formatting.
2023-10-06 12:27:12 -07:00
Jakub Kuderski
86c1c0c215 Add aggregate statistics to microbenchmark (#1871)
Print averaged results at the end of all iterations. Increase the
default number of iterations to 5.

Example:
```
Number of iterations: 5
Prefill: avg. 0.03 s, stddev 0.00
Decode: avg. 43.34 tokens/s, stdev 0.13
```

Also remove the -2 in the number of generated tokens -- I did not find
any evidence we need it.
2023-10-06 10:03:07 -07:00
Daniel Garvey
8bb364bcb8 enforce fp32 accumulates for cpu (#1873) 2023-10-06 11:34:49 -05:00
Daniel Garvey
7abddd01ec argmax inside model + brevitas pin (#1872) 2023-10-05 20:15:21 -07:00
Abhishek Varma
2a451fa0c7 [Llama2] Add a standalone utility for dynamic and combining IRs
-- This script adds a standalone utility for converting Llama IRs
   to dynamic and combining them as well.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-10-05 20:01:06 +05:30
Jakub Kuderski
9c4610b9da Add microbenchmark mode to vicuna CLI (#1864)
Add flags to enable a non-internactive mode for microbenchmarking llama
models. In this mode, the system and user prompts are specified with CLI
flags, and the number of generated tokens and iterations is fixed.

Also move the stats below the response and trim any response blankspace.
2023-10-05 00:12:08 -04:00
powderluv
a38cc9d216 Update vulkan_utils.py for Radeon 780m igpu (#1866) 2023-10-04 20:33:07 -07:00
Jakub Kuderski
1c382449ec [vulkan] Print note about module load times. NFC. (#1862)
Print a note ahead of a potentially long inactivity to set the right expectations.

Separately, we should add progress to the UI and make this loading faster.
2023-10-03 17:27:27 -04:00
Gaurav Shukla
7cc9b3f8e8 [llama cli] Fix llama cli
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-03 20:39:53 +05:30
Gaurav Shukla
e54517e967 [UI] Disable config generator, lora train and model manager (#1858)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-02 22:34:40 -07:00
Ean Garvey
326327a799 Collect pipeline submodules for diffusers ckpt preprocessing. (#1859) 2023-10-03 00:29:28 -04:00
Ean Garvey
785b65c7b0 Add flag for specifying device-local caching allocator heap key. (#1856) 2023-10-03 00:28:39 -04:00
Sungsoon Cho
0d16c81687 Remove unused import. (#1857) 2023-10-02 11:36:08 -05:00
Vivek Khandelwal
8dd7850c69 Add Falcon-GPTQ support 2023-10-02 16:39:57 +05:30
Gaurav Shukla
e930ba85b4 [os] Remove os dependency from vmfb naming (#1854)
Also fixes a small ui issue for chatbot.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 12:38:17 -05:00
Gaurav Shukla
cd732e7a38 [chatbot] split execution time to prefill and decode
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 13:18:03 +05:30
Gaurav Shukla
8e0f8b3227 [ui] Update chatbot UI
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 13:18:03 +05:30
Gaurav Shukla
b8210ef796 [chatbot] Re-instantiate the chatbot object if device id changes
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 13:18:03 +05:30
PhaneeshB
94594542a9 remove use of vulkaninfo 2023-09-28 21:57:00 +05:30
Gaurav Shukla
82f833e87d [vulkan] Update vmfb naming
Update vmfb naming for vulkan devices in order to resolve naming
conflicts in the presence of multiple vulkan devices.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-28 14:52:11 +05:30
Vivek Khandelwal
c9d6870105 Modify falcon pipeline for 180b support 2023-09-28 12:39:35 +05:30
Jakub Kuderski
4fec03a6cc [vulkan] Switch from coop matrix NV to KHR (#1848) 2023-09-27 21:43:37 -04:00
harsh-nod
9a27f51378 Deprecate inference directory
This patch removes the inference directory that was no longer being used.
2023-09-27 14:29:00 -07:00
Abhishek Varma
ad1a0f35ff Fix misdirection while saving vmfb
-- Currently SHARK suggests that vmfb has been saved, while
    that is not the case and no vmfb is generated. 
    This creates a misdirection for IR/vmfbs which are of larger
    size.
-- This commit therefore fixes that misdirection.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-27 16:25:29 +05:30
Nelson Sharpe
6773278ec2 Fix checkpoint_path unexpected argument (#1832) 2023-09-24 14:17:52 -07:00
Abhishek Varma
9a0efffcca [Llama2] Fix wrong Vulkan device ID + Add Vulkan compile flags
-- This commit fixes the wrong Vulkan device being selected during
   runtime.
-- It also adds couple of IREE compilation flags to target specific
   Vulkan device.
-- It also changes the Vulkan device listing to be more in tune with
   lowering control flow.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-22 22:24:18 +05:30
gpetters94
61c6f153d9 Switch to keras-nightly to fix a Linux issue (#1835) 2023-09-21 12:33:45 -04:00
Phaneesh Barwaria
effd42e8f5 pin gradio to v3.44.3 2023-09-21 17:33:43 +05:30
Sungsoon Cho
b5fbb1a8a0 Rename the func arg save_json to avoid name collision. (#1837)
* Rename the func arg save_json to avoid name collision.

* black formatted.
2023-09-19 17:29:27 -05:00
Quinn Dawkins
ded74d09cd [vicuna.py] Keep past key values on device (#1836)
The past key values are only used within the models themselves and can
be kept on device. For vulkan int4, this gives 44 tok/s (for the first
prompt) and settles at around 26 tok/s on 7900xtx.
2023-09-19 18:17:41 -04:00
Boian Petkantchin
79267931c1 Add argument --additional_compile_args (#1119)
This allows to pass more arguemnts to the IREE compiler
Example:
python my-app.py --additional_compile_args="--mlir-pretty-debuginfo --mlir-timing"

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-09-19 11:26:03 -05:00
zjgarvey
9eceba69b7 local_tank_cache included into clear_all (#1833) 2023-09-18 00:27:23 -05:00
Ean Garvey
ca609afb6a Update README.md (#1830) 2023-09-14 10:33:57 -05:00
Gaurav Shukla
11bdce9790 [flags] Fix vulkan runtime flags as vma is dropped from iree (#1831) 2023-09-14 08:58:59 -05:00
Ean Garvey
684943a4a6 (SD) Fix tokenizers imports in pyinstaller builds. (#1828)
* Fix tokenizers metadata.

* (SD) Disable VAE lowering configs (rdna3) and add versioned tunings.

* Update sd_annotation.py

* (SD) Add cv2 to spec.

* Update stencil pipeline with the new img2img arg.
2023-09-12 12:23:48 -05:00
PhaneeshB
b817bb8455 add roles for llama2 2023-09-12 10:59:28 +05:30
Ean Garvey
780f520f02 Fix vk.target_env extensions and remove redundant SD imports. (#1826)
* Remove redundant IREE runtime imports.

* Fix vulkan target env extensions.
2023-09-11 13:42:52 -05:00
Dom
c61b6f8d65 Code refactoring (#1817)
* use join

* fix bug

* further code optimizations

---------

Co-authored-by: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
2023-09-11 11:30:56 -05:00
Abhishek Varma
c854208d49 [Llama2] Prefetch llama2 tokenizer configs (#1824)
-- This commit prefetches llama2 tokenizer configs from shark_tank.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-08 11:29:54 -07:00
Gaurav Shukla
c5dcfc1f13 [vicuna] Exit when mlir is not present in shark tank (#1825)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-08 10:30:29 -07:00
Abhishek Varma
bde63ee8ae Add logging feature in WebUI (#1821) 2023-09-08 05:48:05 -07:00
Vivek Khandelwal
9681d494eb Update decomp list and shark trainer for DLRM 2023-09-06 21:24:50 +05:30
Gaurav Shukla
ede6bf83e2 [vicuna] Disabling the IR generation path
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-06 20:13:17 +05:30
Ean Garvey
2c2693fb7d Fix torchvision versioning in Linux importer setup. (#1809) 2023-09-05 12:57:03 -05:00
Vivek Khandelwal
1d31b2b2c6 Fix StableHLO Compilation flag 2023-09-05 21:32:33 +05:30
Gaurav Shukla
d2f64eefa3 [chatbot] Remove few outdated models from list (#1814) 2023-09-04 09:26:32 -07:00
Abhishek Varma
87ae14b6ff [SD] Add sdpfa decomposition + update IREE flag
-- This commit adds Scaled Dot Product Flash Attention's decomposition
   in shark_importer.
-- It also updates `iree-flow-enable-data-tiling` to `iree-opt-data-tiling`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-04 18:03:53 +05:30
Phaneesh Barwaria
1ccafa1fc1 fix llama2-70b rewrite tensor dim 2023-09-01 17:27:06 +05:30
jinchen62
4c3d8a0a7f Enable downloading vmfb/mlir for webui (#1807) 2023-08-31 11:05:47 -07:00
jinchen62
3601dc7c3b Fix llama2 13b combined ir (#1803) 2023-08-28 11:34:44 -07:00
Daniel Garvey
671881cf87 Llama2 70b (#1783)
* llama2 70b IR gen

* fix IR sec llama2 + debug

* llama270b

---------

Co-authored-by: PhaneeshB <b.phaneesh@gmail.com>
2023-08-25 23:04:28 -07:00
Gaurav Shukla
4e9be6be59 [chatbot] Add debug as class attribute (#1799)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-25 21:46:29 -07:00
Ean Garvey
9c8cbaf498 Add support for ROCM (Windows) in Studio + compile utils (#1770)
* WIP: MSVC ROCM support for SHARK Studio

* Make get_iree_rocm_args platform-agnostic.

* Update stable_args.py

* Update rocm arg handling in SD utils

* Guard quantization imports.

Co-authored-by: jam https://github.com/jammm
2023-08-25 20:56:05 -07:00
Ean Garvey
9e348a114e Revert changes process_skipfiles.py (#1798)
Keeps a small typo fix but reverts the rest of changes to this file from 450c231171
2023-08-25 15:31:49 -07:00
jinchen62
51f90a4d56 Update conversion passes for brevitas quant op (#1795) 2023-08-25 17:28:07 -05:00
Abhishek Varma
310d5d0a49 Fix llama2 13b crashing + add spec file for CLI execution of Llama (#1797)
* [Llama2] Add a fix for Llama2 13B downloading/crashing

-- This commit fixes downloading/crashing of llama2 13B on wrong
   .mlir file.
-- Also adds support for downloading vmfb from shark_tank in CLI.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [llama2] Add a spec file to run Llama/Vicuna CLI exe

-- This commit adds a spec file to run Llama/Vicuna CLI exe.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-08-25 09:36:09 -05:00
Ean Garvey
9697981004 Pipe through a debug option to iree compile utils. (#1796)
* Update compile_utils.py

* Pipe through a flag to toggle debug options in compile utils.

* Update SharkLLMBase.py
2023-08-25 07:11:11 -07:00
Ean Garvey
450c231171 Add tokenizers to requirements.txt (#1790)
* Add tokenizers to requirements and pin version

* Update process_skipfiles.py
2023-08-24 19:44:04 -05:00
Ean Garvey
07f6f4a2f7 Add a short README for the OPT examples and small tweaks. (#1793)
* Small changes to OPT example.

* Update opt README.

* Add a few modes to batch script.

* Update README.md
2023-08-24 17:26:11 -07:00
jinchen62
610813c72f Add iree flag to strip assertions (#1791) 2023-08-24 10:51:19 -07:00
Ean Garvey
8e3860c9e6 Remove flags that are default in upstream IREE (#1785)
* Remove index bits flags now set by default

* Update shark_studio_imports.py
2023-08-24 11:57:54 -05:00
xzuyn
e37d6720eb Add Hires Fix (#1787)
* improper test hiresfix

* add sliders & use `clear_cache`

* add resample choices & fix step adjustment

* add step adjustment to img2img

* add resample options to img2img

* simplify hiresfix
- import `img2img_inf` from `img2img_ui.py` instead of just copying it into `txt2img_ui.py`

* set `hri` to None after using

* add more resample types, and don't show output until hiresfix is done

* cleaner implementation

* ran black

* ran black again with jupyter dependencies
2023-08-24 09:01:41 -07:00
Vivek Khandelwal
16160d9a7d Fix combine mlir script 2023-08-24 19:10:49 +05:30
Sungsoon Cho
79075a1a07 Opt perf (#1786)
* Define command line args, model-name, max-seq-len, platform, etc.

* Add usage example.

* Add opt_perf_comparision_batch.py.

* Use shlex instead.
2023-08-24 08:33:12 -05:00
Abhishek Varma
db990826d3 Add Llama2 13B int4 fp16 support (#1784)
Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-08-23 10:00:32 -07:00
gpetters94
7ee3e4ba5d Add stencil_unet_512 support (#1778)
This should fix any remaining issues with stencils and long prompts.
2023-08-22 12:23:46 -04:00
Vivek Khandelwal
05889a8fe1 Add LLaMa2-int4-fp16 support (#1782) 2023-08-22 07:45:50 -07:00
jinchen62
b87efe7686 Fix venv setup for brevitas (#1779) 2023-08-21 11:58:51 -07:00
gpetters94
82b462de3a Fix stencils for long prompts (#1777) 2023-08-19 00:26:51 -07:00
Daniel Garvey
d8f0f7bade replace public with private (#1776)
unload footguns
2023-08-18 14:22:46 -07:00
gpetters94
79bd0b84a1 Fix an issue with diffusers>0.19.3 (#1775) 2023-08-18 14:06:06 -04:00
jinchen62
8738571d1e Adapt the change of brevitas custom op name (#1772) 2023-08-17 14:24:43 -07:00
Gaurav Shukla
a4c354ce54 [version] Pin diffusers==0.19.3
Once the latest works with LORA train, unpin it.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 21:27:10 +05:30
Gaurav Shukla
cc53efa89f [cli] Fix chatbot cli
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 21:27:10 +05:30
Gaurav Shukla
9ae8bc921e [chatbot] Fix chatbot cli and webview warning
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 21:27:10 +05:30
Gaurav Shukla
32eb78f0f9 [chatbot] Fix switching parameters in chatbot
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 19:14:17 +05:30
Ean Garvey
cb509343d9 Fix pytest benchmarks and shark_tank generation. (#1632)
- fix setup_venv.sh for benchmarks/imports etc.
- fix torch benchmarks in SharkBenchmarkRunner
- generate SD artifacts using build_tools/stable_diffusion_testing.py and --import_mlir
- decouple SD gen from tank/generate_sharktank for now
2023-08-16 17:48:47 -05:00
powderluv
6da391c9b1 update signtool to use /fd certHash 2023-08-15 15:11:40 -07:00
Ean Garvey
9dee7ae652 fix tkinter window (#1766) 2023-08-15 13:23:09 -07:00
Ean Garvey
343dfd901c Update SHARK-Runtime links to SRT (#1765)
* Update nightly.yml

* Update setup_venv.ps1

* Update CMakeLists.txt

* Update shark_iree_profiling.md

* Update setup_venv.sh

* Update README.md

* Update .gitmodules

* Update CMakeLists.txt

* Update README.md

* fix signtool flags

* Update nightly.yml

* Update benchmark_utils.py

* uncomment tkinter launch
2023-08-15 12:40:44 -07:00
Ean Garvey
57260b9c37 (Studio) Add hf-hub to pyinstaller metadata (#1761) 2023-08-14 23:01:50 -05:00
Ean Garvey
18e7d2d061 Enable vae tunings for rdna3. (#1764) 2023-08-14 21:00:14 -07:00
Stanley Winata
51a1009796 Add Forward method to SHARKRunner and fix examples. (#1756) 2023-08-14 19:20:37 -07:00
Daniel Garvey
045c3c3852 enable iree-opt-const-expr-hoisting in vicuna (#1742)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-08-14 18:43:42 -07:00
Ean Garvey
0139dd58d9 Specify max allocation size in IREE compile args. (#1760) 2023-08-14 15:43:09 -05:00
Ean Garvey
c96571855a prevents recompiles for cuda benchmarks + update benchmark_module path (#1759)
* xfail resnet50_fp16

* Fix cuda benchmarks and prevent recompilation.
2023-08-14 15:30:32 -05:00
PhaneeshB
4f61d69d86 add support passing iree flags for LLMs 2023-08-15 00:22:56 +05:30
Phaneesh Barwaria
531d447768 set default allocator for metal device creation (#1755) 2023-08-14 06:17:52 -07:00
Vivek Khandelwal
16f46f8de9 Update langchain_requirements.txt 2023-08-14 14:32:19 +05:30
Vivek Khandelwal
c4723f469f Update langchain_requirements.txt 2023-08-14 14:32:19 +05:30
Vivek Khandelwal
d804f45a61 Update langchain_requirements.txt 2023-08-14 14:32:19 +05:30
Vivek Khandelwal
d22177f936 Update requirements.txt 2023-08-14 14:32:19 +05:30
George Petterson
75e68f02f4 Remove CUDNN 2023-08-14 14:32:19 +05:30
Gaurav Shukla
4dc9c59611 [chatbot] Add tokens generated per second (#1753) 2023-08-13 11:25:41 -07:00
Gaurav Shukla
18801dcabc [chat] Update chatbot ui
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-13 18:39:22 +05:30
Gaurav Shukla
3c577f7168 [vicuna] fix shard config generator script (#1747)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-10 11:26:03 -07:00
Stefan Kapusniak
f5e4fa6ffe UI/Web - Revert tab order (#1724)
* Revert ui tab order

* Reverts the tab order, so that SD, LLM, and Experimental are grouped
together again as far as is possible.
* Labelled "Generate Sharding Config" as experimental as pressing the
'Get Model Config' errors for me.

* Fix formatting in index.py
2023-08-10 11:25:36 -07:00
powderluv
48de445325 Enable caching and disable vma (#1746)
* Enable caching allocator by default

Going to toggle VMA off too and this is required for performance.  Will have to monitor in the wild reports.

* Disable VMA

Disable VMA
2023-08-10 10:49:44 -07:00
Gaurav Shukla
8e90f1b81a [vicuna] add default config in case of sharded vicuna
Signed-Off-by: Gaurav Shukla<gaurav@nod-labs.com>
2023-08-10 21:28:08 +05:30
Vivek Khandelwal
e8c1203be2 Fix vicuna script (#1745) 2023-08-10 06:11:14 -07:00
Vivek Khandelwal
e4d7abb519 Final patch for fixing Langchain token streaming issue (#1744) 2023-08-09 10:09:41 -07:00
powderluv
96185c9dc1 pin safetensors to 0.3.1 (#1740) 2023-08-08 19:24:44 -07:00
powderluv
bc22a81925 re-enable constant folding (#1739)
Tested and works well. (modulo unrelated driver issue)
2023-08-08 17:17:38 -07:00
Eliasj42
5203679f1f Bandaid fix 2 (#1728)
* download all mlirs

* fixed install method

* download all mlirs (#1727)

Co-authored-by: Elias Joseph <elias@nod-labs.com>

* added taggs

* fix name check for file existence

* Remove SD from all_models.csv (#1706)

Removes SD from pytests as it has its own test suite.

* gpt_langchain.py fixes for pydantic (#1722)

* removed dead code

---------

Co-authored-by: Elias Joseph <elias@nod-labs.com>
Co-authored-by: PhaneeshB <b.phaneesh@gmail.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: Stefan Kapusniak <121311569+one-lithe-rune@users.noreply.github.com>
2023-08-08 12:14:57 -05:00
Vivek Khandelwal
bf073f8f37 [Langchain] Expand pipelines to fix token streaming issue 2023-08-08 10:27:23 +05:30
Stella Laurenzo
cec6eda6b4 Optimize device enumeration overhead and log details on long operations. (#1734)
* Optimize device enumeration overhead and log details on long operations.

* Various fixes to add `@functools.cache` to what should be one time, expensive, device enumeration and setup activities. Cuts several seconds off of initialization on my machine.
* Add detailed tracing to actual invocations if they exceed a certain timeout or have an exception.
* Add detailed tracing to loading status.
* By default detail logging is only printed if an operation takes an excessive amount of time. All logging/timing can be printed by setting the variable `$env:SHARK_DETAIL_TRACE = "1"`

* Remove cache from unhashable functions
2023-08-07 17:20:53 -07:00
Stella Laurenzo
9e37e03741 Clearly differentiate phases of loading modules to better understand if things are taking a long time. (#1733) 2023-08-07 14:03:12 -07:00
Stefan Kapusniak
9b8c4401b5 gpt_langchain.py fixes for pydantic (#1722) 2023-08-07 00:55:38 -07:00
Ean Garvey
a9f95a218b Remove SD from all_models.csv (#1706)
Removes SD from pytests as it has its own test suite.
2023-08-05 15:55:52 -05:00
PhaneeshB
872bd72d0b fix name check for file existence 2023-08-05 21:33:53 +05:30
Eliasj42
fd1c4db5d0 download all mlirs (#1727)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-08-04 18:22:06 -05:00
Daniel Garvey
759664bb48 add py files to pyinstaller for shark (#1723) 2023-08-04 14:10:43 -07:00
Daniel Garvey
14fd0cdd87 add missing subprocess import (#1721) 2023-08-04 15:15:22 -05:00
Daniel Garvey
a57eccc997 fix lint (#1720) 2023-08-04 14:54:33 -05:00
Daniel Garvey
a686d7d89f temporarily disable langchain stuff in webui (#1719)
its breaking the exe
2023-08-04 12:48:06 -07:00
Eliasj42
ed484b8253 added functionality for int8 vicuna and 4 shards (#1712)
combined vicuna_4_shards.py and vicuna.py to reduce code duplication

Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-08-04 14:05:05 -05:00
gpetters94
7fe57ebaaf Add vector database and add support on the web UI (#1699) 2023-08-04 13:47:19 -04:00
Nithin Meganathan
c287fd2be8 Add GPU ID's in model_confg.json by default for manual annotation (#1718) 2023-08-04 12:46:27 -05:00
Gaurav Shukla
51ec1a1360 [vicuna] Integrate sharded vicuna in web (#1717)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-04 11:46:53 -05:00
Gaurav Shukla
bd30044c0b [Shard] Add sharding generation in shark studio
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-04 21:51:14 +05:30
Ean Garvey
c9de2729b2 Add flag for toggling constant folding. (#1714) 2023-08-04 04:55:52 -07:00
Vivek Khandelwal
a5b13fcc2f [Langchain] Patch for fixing streaming of tokens (#1709) 2023-08-03 10:06:49 -07:00
Stefan Kapusniak
6bb329c4af Unsharded Vicuna: Fix Memory Error compiling mlir for lmsys/vicuna-7b-v1.3 fp16 with 64 GiB (#1702) 2023-08-01 06:07:56 -07:00
Vivek Khandelwal
98fb6c52df Expand pipelines to fix streaming of tokens 2023-07-31 22:11:01 +05:30
Stefan Kapusniak
206c1b70f4 UI/Web: Reorder tabs to separate SD and LLM (#1701)
Shuffle the tabs around so that:

* All the SD tabs are together
* All the LLM tabs are together
* All the experimental tabs are together
2023-07-29 22:25:30 -04:00
PhaneeshB
cdb037ee54 use shark_args for vulkan debug utils flag 2023-07-30 07:54:26 +05:30
PhaneeshB
ce2fd84538 fix cpu device name for SharkStudio 2023-07-30 07:54:26 +05:30
PhaneeshB
4684afad34 update upscalar example 2023-07-28 21:06:28 +05:30
PhaneeshB
8d65456b7a Move vulkan runtime flags to shark_args 2023-07-28 21:06:28 +05:30
PhaneeshB
d6759a852b add vulkan vma alloc flag 2023-07-28 21:06:28 +05:30
Daniel Garvey
ab57af43c1 Couple of fixes for vicuna.py (#1696)
* mega vicuna merge pt 2

* add fallback to ensure compile is called
2023-07-27 15:53:05 -07:00
jinchen62
4d5c55dd9f Fix vicuna script (#1697) 2023-07-27 17:24:26 -05:00
Vivek Khandelwal
07399ad65c [Langchain] Remove unused code (#1698) 2023-07-27 11:59:54 -05:00
Vivek Khandelwal
776a9c2293 Fix for Langchain (#1694)
For CPU, remove max time stopping criteria
Fix web UI issue
2023-07-26 09:00:23 -07:00
Eliasj42
9d399eb988 fixed bug where device_idx was hardcoded (#1693)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-07-25 19:00:13 -05:00
Vivek Khandelwal
927b662aa7 Add Langchain SHARK Compilation support for all paths 2023-07-25 22:15:42 +05:30
Abhishek Varma
47f8a79c75 [MiniGPT4] Add MiniGPT4 to SHARK (#1554)
* [MiniGPT4] Add MiniGPT4 to SHARK

-- This is the first installment of MiniGPT4 in SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* Add int8 support for MiniGPT4

-- This commit adds int8 support for MiniGPT4.

Signed-off-by: Abhishek Varma <abhishek@nod-lab.com>

* Update .spec for MiniGPT4's config files

* black format MiniGPT4

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Signed-off-by: Abhishek Varma <abhishek@nod-lab.com>
2023-07-25 09:42:27 -07:00
Stefan Kapusniak
289f983f41 SD - Implement seed arrays for batch runs (#1690)
* SD Scripts and UI tabs that support batch_count can now take a
string containing a JSON array, or a list of integers, as their seed
input.
* Each batch in a run will now take the seed specified at the
corresponding array index if one exists. If there is no seed at
that index, the seed value will be treated as -1 and a random
seed will be assigned at that position. If an integer rather than
a list or json array has been, everything works as before.
* UI seed input controls are now Textboxes with info lines about
the seed formats allowed.
* UI error handling updated to be more helpful if the seed input is
invalid.
2023-07-24 19:22:34 -07:00
Daniel Garvey
453e46562f mega vicuna merge pt 2 (#1685) 2023-07-24 12:42:20 -05:00
Gaurav Shukla
5497af1f56 [config] Add support for uploading sharding config file in chatbot (#1689)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-07-24 10:18:03 -07:00
Vivek Khandelwal
f3cb63fc9c Fix Langchain multiple device isssue (#1688) 2023-07-24 08:03:46 -07:00
Vivek Khandelwal
d7092aafaa Fix multiple issue for Langchain
This commit fixes the following issue for the Langchain:
1.) Web UI not able to fetch results.
2.) For each query model getting reloaded.
3.) SHARK module not using user provided device and precision.
4.) Create a class for main Langchain code.
5.) Misc issues
2023-07-21 21:56:27 +05:30
Vivek Khandelwal
a415f3f70e Fix Langchain Prompt issue and add web UI support (#1682) 2023-07-21 06:36:55 -07:00
Vivek Khandelwal
c292e5c9d7 Add Langchain CPU support and update requirements 2023-07-20 18:53:34 +05:30
Vivek Khandelwal
03c4d9e171 Add support for Llama-2-70b for web and cli, and for hf_auth_token 2023-07-20 14:57:48 +05:30
jinchen62
3662224c04 Update brevitas requirement (#1677)
also clean up useless args

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-19 22:03:32 -07:00
Vivek Khandelwal
db3f222933 Revert "Add Llama2 70B option in CLI and WebUI (#1673)" (#1679)
This reverts commit 41e5088908.
2023-07-19 22:02:48 -07:00
Stefan Kapusniak
68b3021325 Fixes cosmetic problems with Gradio 3.37.0 (#1676)
* Fix nod-ai logo having a white border
* Fix control labels having a black background
* Remove extra lower border below Save Prompt checkboxes in Txt2Img UI
2023-07-19 17:28:53 -07:00
AyaanShah2204
336469154d added copy-metadata for pyyaml (#1678) 2023-07-19 17:27:25 -07:00
163 changed files with 19563 additions and 13478 deletions

View File

@@ -2,4 +2,4 @@
count = 1
show-source = 1
select = E9,F63,F7,F82
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py

View File

@@ -51,11 +51,11 @@ jobs:
run: |
./setup_venv.ps1
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
- name: Upload Release Assets
id: upload-release-assets
@@ -104,7 +104,7 @@ jobs:
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
@@ -144,7 +144,7 @@ jobs:
source shark.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models

View File

@@ -112,9 +112,10 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
pytest --benchmark=native --update_tank -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
python build_tools/vicuna_testing.py
- name: Validate Models on NVIDIA GPU
if: matrix.suite == 'cuda'
@@ -122,7 +123,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
pytest --benchmark=native --update_tank -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
@@ -136,7 +137,8 @@ jobs:
source shark.venv/bin/activate
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
# disabled due to a low-visibility memory issue with pytest on macos.
# pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
@@ -144,8 +146,8 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
pytest --update_tank -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan --no-exit_on_fail
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'

11
.gitignore vendored
View File

@@ -182,7 +182,7 @@ generated_imgs/
# Custom model related artefacts
variants.json
models/
/models/
# models folder
apps/stable_diffusion/web/models/
@@ -193,3 +193,12 @@ stencil_annotator/
# For DocuChat
apps/language_models/langchain/user_path/
db_dir_UserData
# Embeded browser cache and other
apps/stable_diffusion/web/EBWebView/
# Llama2 tokenizer configs
llama2_tokenizer_configs/
# Webview2 runtime artefacts
EBWebView/

2
.gitmodules vendored
View File

@@ -1,4 +1,4 @@
[submodule "inference/thirdparty/shark-runtime"]
path = inference/thirdparty/shark-runtime
url =https://github.com/nod-ai/SHARK-Runtime.git
url =https://github.com/nod-ai/SRT.git
branch = shark-06032022

View File

@@ -10,7 +10,7 @@ High Performance Machine Learning Distribution
<summary>Prerequisites - Drivers </summary>
#### Install your Windows hardware drivers
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
@@ -170,7 +170,7 @@ python -m pip install --upgrade pip
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
```shell
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```
### Run shark tank model tests.
@@ -254,7 +254,6 @@ if you want to instead incorporate this into a python script, you can pass the `
```
shark_module = SharkInference(
mlir_model,
func_name,
device=args.device,
mlir_dialect="tm_tensor",
dispatch_benchmarks="all",
@@ -297,7 +296,7 @@ torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)
# SharkInference accepts mlir in linalg, mhlo, and tosa dialect.
from shark.shark_inference import SharkInference
shark_module = SharkInference(torch_mlir, func_name, device="cpu", mlir_dialect="linalg")
shark_module = SharkInference(torch_mlir, device="cpu", mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward((input))
@@ -320,12 +319,17 @@ mhlo_ir = r"""builtin.module {
arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)
shark_module = SharkInference(mhlo_ir, func_name="forward", device="cpu", mlir_dialect="mhlo")
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
shark_module.compile()
result = shark_module.forward((arg0, arg1))
```
</details>
## Examples Using the REST API
* [Setting up SHARK for use with Blender](./docs/shark_sd_blender.md)
* [Setting up SHARK for use with Koboldcpp](./docs/shark_sd_koboldcpp.md)
## Supported and Validated Models
SHARK is maintained to support the latest innovations in ML Models:

View File

@@ -5,6 +5,7 @@
1.) Install all the dependencies by running:
```shell
pip install -r apps/language_models/langchain/langchain_requirements.txt
sudo apt-get install -y libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
```
2.) Create a folder named `user_path` in `apps/language_models/langchain/` directory.

View File

@@ -2,7 +2,7 @@ import copy
import torch
from evaluate_params import eval_func_param_names
from gen import get_score_model, get_model, evaluate, check_locals
from gen import Langchain
from prompter import non_hf_types
from utils import clear_torch_cache, NullContext, get_kwargs
@@ -87,7 +87,7 @@ def run_cli( # for local function:
# unique to this function:
cli_loop=None,
):
check_locals(**locals())
Langchain.check_locals(**locals())
score_model = "" # FIXME: For now, so user doesn't have to pass
n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
@@ -98,16 +98,20 @@ def run_cli( # for local function:
from functools import partial
# get score model
smodel, stokenizer, sdevice = get_score_model(
smodel, stokenizer, sdevice = Langchain.get_score_model(
reward_type=True,
**get_kwargs(
get_score_model, exclude_names=["reward_type"], **locals()
Langchain.get_score_model,
exclude_names=["reward_type"],
**locals()
)
)
model, tokenizer, device = get_model(
model, tokenizer, device = Langchain.get_model(
reward_type=False,
**get_kwargs(get_model, exclude_names=["reward_type"], **locals())
**get_kwargs(
Langchain.get_model, exclude_names=["reward_type"], **locals()
)
)
model_dict = dict(
base_model=base_model,
@@ -121,11 +125,11 @@ def run_cli( # for local function:
model_state.update(model_dict)
my_db_state = [None]
fun = partial(
evaluate,
Langchain.evaluate,
model_state,
my_db_state,
**get_kwargs(
evaluate,
Langchain.evaluate,
exclude_names=["model_state", "my_db_state"]
+ eval_func_param_names,
**locals()

View File

@@ -1,402 +0,0 @@
import inspect
import os
import traceback
import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from evaluate_params import eval_func_param_names, eval_extra_columns
from gen import get_context, get_score_model, get_model, evaluate, check_locals
from prompter import Prompter
from utils import clear_torch_cache, NullContext, get_kwargs
def run_eval( # for local function:
base_model=None,
lora_weights=None,
inference_server=None,
prompt_type=None,
prompt_dict=None,
debug=None,
chat=False,
chat_context=None,
stream_output=None,
eval_filename=None,
eval_prompts_only_num=None,
eval_prompts_only_seed=None,
eval_as_output=None,
examples=None,
memory_restriction_level=None,
# for get_model:
score_model=None,
load_8bit=None,
load_4bit=None,
load_half=None,
load_gptq=None,
use_safetensors=None,
infer_devices=None,
tokenizer_base_model=None,
gpu_id=None,
local_files_only=None,
resume_download=None,
use_auth_token=None,
trust_remote_code=None,
offload_folder=None,
compile_model=None,
# for evaluate args beyond what's already above, or things that are always dynamic and locally created
temperature=None,
top_p=None,
top_k=None,
num_beams=None,
max_new_tokens=None,
min_new_tokens=None,
early_stopping=None,
max_time=None,
repetition_penalty=None,
num_return_sequences=None,
do_sample=None,
langchain_mode=None,
langchain_action=None,
top_k_docs=None,
chunk=None,
chunk_size=None,
document_choice=None,
# for evaluate kwargs:
src_lang=None,
tgt_lang=None,
concurrency_count=None,
save_dir=None,
sanitize_bot_response=None,
model_state0=None,
max_max_new_tokens=None,
is_public=None,
max_max_time=None,
raise_generate_gpu_exceptions=None,
load_db_if_exists=None,
dbs=None,
user_path=None,
detect_user_path_changes_every_query=None,
use_openai_embedding=None,
use_openai_model=None,
hf_embedding_model=None,
db_type=None,
n_jobs=None,
first_para=None,
text_limit=None,
verbose=None,
cli=None,
reverse_docs=None,
use_cache=None,
auto_reduce_chunks=None,
max_chunks=None,
model_lock=None,
force_langchain_evaluate=None,
model_state_none=None,
):
check_locals(**locals())
if eval_prompts_only_num > 0:
np.random.seed(eval_prompts_only_seed)
example1 = examples[-1] # pick reference example
examples = []
responses = []
if eval_filename is None:
# override default examples with shareGPT ones for human-level eval purposes only
eval_filename = (
"ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json"
)
if not os.path.isfile(eval_filename):
os.system(
"wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s"
% eval_filename
)
import json
data = json.load(open(eval_filename, "rt"))
# focus on data that starts with human, else likely chopped from other data
turn_start = 0 # odd in general
data = [
x
for x in data
if len(x["conversations"]) > turn_start + 1
and x["conversations"][turn_start]["from"] == "human"
and x["conversations"][turn_start + 1]["from"] == "gpt"
]
for i in sorted(
np.random.randint(0, len(data), size=eval_prompts_only_num)
):
assert data[i]["conversations"][turn_start]["from"] == "human"
instruction = data[i]["conversations"][turn_start]["value"]
assert (
data[i]["conversations"][turn_start + 1]["from"] == "gpt"
)
output = data[i]["conversations"][turn_start + 1]["value"]
examplenew = example1.copy()
assert (
not chat
), "No gradio must use chat=False, uses nochat instruct"
examplenew[
eval_func_param_names.index("instruction_nochat")
] = instruction
examplenew[
eval_func_param_names.index("iinput_nochat")
] = "" # no input
examplenew[
eval_func_param_names.index("context")
] = get_context(chat_context, prompt_type)
examples.append(examplenew)
responses.append(output)
else:
# get data, assume in correct format: json of rows of dict of instruction and output
# only instruction is required
import json
data = json.load(open(eval_filename, "rt"))
for i in sorted(
np.random.randint(0, len(data), size=eval_prompts_only_num)
):
examplenew = example1.copy()
instruction = data[i]["instruction"]
output = data[i].get("output", "") # not required
assert (
not chat
), "No gradio must use chat=False, uses nochat instruct"
examplenew[
eval_func_param_names.index("instruction_nochat")
] = instruction
examplenew[
eval_func_param_names.index("iinput_nochat")
] = "" # no input
examplenew[
eval_func_param_names.index("context")
] = get_context(chat_context, prompt_type)
examples.append(examplenew)
responses.append(output)
num_examples = len(examples)
scoring_path = "scoring"
os.makedirs(scoring_path, exist_ok=True)
if eval_as_output:
used_base_model = "gpt35"
used_lora_weights = ""
used_inference_server = ""
else:
used_base_model = str(base_model.split("/")[-1])
used_lora_weights = str(lora_weights.split("/")[-1])
used_inference_server = str(inference_server.split("/")[-1])
eval_out_filename = "df_scores_%s_%s_%s_%s_%s_%s_%s.parquet" % (
num_examples,
eval_prompts_only_num,
eval_prompts_only_seed,
eval_as_output,
used_base_model,
used_lora_weights,
used_inference_server,
)
eval_out_filename = os.path.join(scoring_path, eval_out_filename)
# torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
device = "cpu" if n_gpus == 0 else "cuda"
context_class = NullContext if n_gpus > 1 or n_gpus == 0 else torch.device
with context_class(device):
# ensure was set right above before examples generated
assert (
not stream_output
), "stream_output=True does not make sense with example loop"
import time
from functools import partial
# get score model
smodel, stokenizer, sdevice = get_score_model(
reward_type=True,
**get_kwargs(
get_score_model, exclude_names=["reward_type"], **locals()
)
)
if not eval_as_output:
model, tokenizer, device = get_model(
reward_type=False,
**get_kwargs(
get_model, exclude_names=["reward_type"], **locals()
)
)
model_dict = dict(
base_model=base_model,
tokenizer_base_model=tokenizer_base_model,
lora_weights=lora_weights,
inference_server=inference_server,
prompt_type=prompt_type,
prompt_dict=prompt_dict,
)
model_state = dict(model=model, tokenizer=tokenizer, device=device)
model_state.update(model_dict)
my_db_state = [None]
fun = partial(
evaluate,
model_state,
my_db_state,
**get_kwargs(
evaluate,
exclude_names=["model_state", "my_db_state"]
+ eval_func_param_names,
**locals()
)
)
else:
assert eval_prompts_only_num > 0
def get_response(*args, exi=0):
# assumes same ordering of examples and responses
yield responses[exi]
fun = get_response
t0 = time.time()
score_dump = []
score_avg = 0
score_median = 0
for exi, ex in enumerate(examples):
clear_torch_cache()
instruction = ex[eval_func_param_names.index("instruction_nochat")]
iinput = ex[eval_func_param_names.index("iinput_nochat")]
context = ex[eval_func_param_names.index("context")]
clear_torch_cache()
print("")
print("START" + "=" * 100)
print(
"Question: %s %s"
% (instruction, ("input=%s" % iinput if iinput else ""))
)
print("-" * 105)
# fun yields as generator, so have to iterate over it
# Also means likely do NOT want --stream_output=True, else would show all generations
t1 = time.time()
gener = (
fun(*tuple(ex), exi=exi) if eval_as_output else fun(*tuple(ex))
)
for res_fun in gener:
res = res_fun["response"]
extra = res_fun["sources"]
print(res)
if smodel:
score_with_prompt = False
if score_with_prompt:
data_point = dict(
instruction=instruction,
input=iinput,
context=context,
)
prompter = Prompter(
prompt_type,
prompt_dict,
debug=debug,
chat=chat,
stream_output=stream_output,
)
prompt = prompter.generate_prompt(data_point)
else:
# just raw input and output
if eval_prompts_only_num > 0:
# only our own examples have this filled at moment
assert iinput in [
None,
"",
], iinput # should be no iinput
if not (chat_context and prompt_type == "human_bot"):
assert context in [
None,
"",
], context # should be no context
prompt = instruction
if memory_restriction_level > 0:
cutoff_len = (
768 if memory_restriction_level <= 2 else 512
)
else:
cutoff_len = tokenizer.model_max_length
inputs = stokenizer(
prompt,
res,
return_tensors="pt",
truncation=True,
max_length=cutoff_len,
)
try:
score = (
torch.sigmoid(smodel(**inputs).logits[0].float())
.cpu()
.detach()
.numpy()[0]
)
except torch.cuda.OutOfMemoryError as e:
print(
"GPU OOM 1: question: %s answer: %s exception: %s"
% (prompt, res, str(e)),
flush=True,
)
traceback.print_exc()
score = 0.0
clear_torch_cache()
except (Exception, RuntimeError) as e:
if (
"Expected all tensors to be on the same device"
in str(e)
or "expected scalar type Half but found Float"
in str(e)
or "probability tensor contains either" in str(e)
or "cublasLt ran into an error!" in str(e)
):
print(
"GPU error: question: %s answer: %s exception: %s"
% (prompt, res, str(e)),
flush=True,
)
traceback.print_exc()
score = 0.0
clear_torch_cache()
else:
raise
score_dump.append(ex + [prompt, res, score])
# dump every score in case abort
df_scores = pd.DataFrame(
score_dump,
columns=eval_func_param_names + eval_extra_columns,
)
df_scores.to_parquet(eval_out_filename, index=False)
# plot histogram so far
plt.figure(figsize=(10, 10))
plt.hist(df_scores["score"], bins=20)
score_avg = np.mean(df_scores["score"])
score_median = np.median(df_scores["score"])
print(
"SCORE %s: %s So far: AVG: %s MEDIAN: %s"
% (exi, score, score_avg, score_median),
flush=True,
)
plt.title(
"Score avg: %s median: %s" % (score_avg, score_median)
)
plt.savefig(eval_out_filename.replace(".parquet", ".png"))
plt.close()
print("END" + "=" * 102)
print("")
t2 = time.time()
print(
"Time taken for example: %s Time taken so far: %.4f about %.4g per example"
% (t2 - t1, t2 - t0, (t2 - t0) / (1 + exi))
)
t1 = time.time()
print(
"Total time taken: %.4f about %.4g per example"
% (t1 - t0, (t1 - t0) / num_examples)
)
print(
"Score avg: %s median: %s" % (score_avg, score_median), flush=True
)
return eval_out_filename

View File

@@ -0,0 +1,846 @@
from __future__ import annotations
from typing import (
Any,
Mapping,
Optional,
Dict,
List,
Sequence,
Tuple,
Union,
Protocol,
)
import inspect
import json
import warnings
from pathlib import Path
import yaml
from abc import ABC, abstractmethod
import langchain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.question_answering import stuff_prompt
from langchain.prompts.base import BasePromptTemplate
from langchain.docstore.document import Document
from langchain.callbacks.manager import (
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.load.serializable import Serializable
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
from langchain.input import get_colored_text
from langchain.load.dump import dumpd
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import LLMResult, PromptValue
from pydantic import Extra, Field, root_validator, validator
def _get_verbosity() -> bool:
return langchain.verbose
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
"""Format a document into a string based on a prompt template."""
base_info = {"page_content": doc.page_content}
base_info.update(doc.metadata)
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:
required_metadata = [
iv for iv in prompt.input_variables if iv != "page_content"
]
raise ValueError(
f"Document prompt requires documents to have metadata variables: "
f"{required_metadata}. Received document with missing metadata: "
f"{list(missing_metadata)}."
)
document_info = {k: base_info[k] for k in prompt.input_variables}
return prompt.format(**document_info)
class Chain(Serializable, ABC):
"""Base interface that all chains should implement."""
memory: Optional[BaseMemory] = None
callbacks: Callbacks = Field(default=None, exclude=True)
callback_manager: Optional[BaseCallbackManager] = Field(
default=None, exclude=True
)
verbose: bool = Field(
default_factory=_get_verbosity
) # Whether to print the response text
tags: Optional[List[str]] = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def _chain_type(self) -> str:
raise NotImplementedError("Saving not supported for this chain type.")
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@property
@abstractmethod
def input_keys(self) -> List[str]:
"""Input keys this chain expects."""
@property
@abstractmethod
def output_keys(self) -> List[str]:
"""Output keys this chain expects."""
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Check that all inputs are present."""
missing_keys = set(self.input_keys).difference(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
missing_keys = set(self.output_keys).difference(outputs)
if missing_keys:
raise ValueError(f"Missing some output keys: {missing_keys}")
@abstractmethod
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
include_run_info: Whether to include run info in the response. Defaults
to False.
"""
input_docs = inputs["input_documents"]
missing_keys = set(self.input_keys).difference(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose, tags, self.tags
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
)
if "is_first" in inputs.keys() and not inputs["is_first"]:
run_manager_ = run_manager
input_list = [inputs]
stop = None
prompts = []
for inputs in input_list:
selected_inputs = {
k: inputs[k] for k in self.prompt.input_variables
}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if run_manager_:
run_manager_.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
prompt_strings = [p.to_string() for p in prompts]
prompts = prompt_strings
callbacks = run_manager_.get_child() if run_manager_ else None
tags = None
"""Run the LLM on the given prompt and input."""
# If string is passed in directly no errors will be raised but outputs will
# not make sense.
if not isinstance(prompts, list):
raise ValueError(
"Argument 'prompts' is expected to be of type List[str], received"
f" argument of type {type(prompts)}."
)
params = self.llm.dict()
params["stop"] = stop
options = {"stop": stop}
disregard_cache = self.llm.cache is not None and not self.llm.cache
callback_manager = CallbackManager.configure(
callbacks,
self.llm.callbacks,
self.llm.verbose,
tags,
self.llm.tags,
)
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.llm.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
run_manager_ = callback_manager.on_llm_start(
dumpd(self),
prompts,
invocation_params=params,
options=options,
)
generations = []
for prompt in prompts:
inputs_ = prompt
num_workers = None
batch_size = None
if num_workers is None:
if self.llm.pipeline._num_workers is None:
num_workers = 0
else:
num_workers = self.llm.pipeline._num_workers
if batch_size is None:
if self.llm.pipeline._batch_size is None:
batch_size = 1
else:
batch_size = self.llm.pipeline._batch_size
preprocess_params = {}
generate_kwargs = {}
preprocess_params.update(generate_kwargs)
forward_params = generate_kwargs
postprocess_params = {}
# Fuse __init__ params and __call__ params without modifying the __init__ ones.
preprocess_params = {
**self.llm.pipeline._preprocess_params,
**preprocess_params,
}
forward_params = {
**self.llm.pipeline._forward_params,
**forward_params,
}
postprocess_params = {
**self.llm.pipeline._postprocess_params,
**postprocess_params,
}
self.llm.pipeline.call_count += 1
if (
self.llm.pipeline.call_count > 10
and self.llm.pipeline.framework == "pt"
and self.llm.pipeline.device.type == "cuda"
):
warnings.warn(
"You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a"
" dataset",
UserWarning,
)
model_inputs = self.llm.pipeline.preprocess(
inputs_, **preprocess_params
)
model_outputs = self.llm.pipeline.forward(
model_inputs, **forward_params
)
model_outputs["process"] = False
return model_outputs
output = LLMResult(generations=generations)
run_manager_.on_llm_end(output)
if run_manager_:
output.run = RunInfo(run_id=run_manager_.run_id)
response = output
outputs = [
# Get the text of the top generated string.
{self.output_key: generation[0].text}
for generation in response.generations
][0]
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
else:
_run_manager = (
run_manager or CallbackManagerForChainRun.get_noop_manager()
)
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {
k: v for k, v in inputs.items() if k != self.input_key
}
doc_strings = [
format_document(doc, self.document_prompt) for doc in docs
]
# Join the documents together to put them in the prompt.
inputs = {
k: v
for k, v in other_keys.items()
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(
doc_strings
)
inputs["is_first"] = False
inputs["input_documents"] = input_docs
# Call predict on the LLM.
output = self.llm_chain(inputs, callbacks=_run_manager.get_child())
if "process" in output.keys() and not output["process"]:
return output
output = output[self.llm_chain.output_key]
extra_return_dict = {}
extra_return_dict[self.output_key] = output
outputs = extra_return_dict
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
"""Validate and prep outputs."""
self._validate_outputs(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
def prep_inputs(
self, inputs: Union[Dict[str, Any], Any]
) -> Dict[str, str]:
"""Validate and prep inputs."""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(
self.memory.memory_variables
)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list]
def run(
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key. Got {self.output_keys}."
)
if args and not kwargs:
if len(args) != 1:
raise ValueError(
"`run` supports only one positional argument."
)
return self(args[0], callbacks=callbacks, tags=tags)[
self.output_keys[0]
]
if kwargs and not args:
return self(kwargs, callbacks=callbacks, tags=tags)[
self.output_keys[0]
]
if not kwargs and not args:
raise ValueError(
"`run` supported with either positional arguments or keyword arguments,"
" but none were provided."
)
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of chain."""
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
_dict = super().dict()
_dict["_type"] = self._chain_type
return _dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the chain.
Args:
file_path: Path to file to save the chain to.
Example:
.. code-block:: python
chain.save(file_path="path/chain.yaml")
"""
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
chain_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(chain_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(chain_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
class BaseCombineDocumentsChain(Chain, ABC):
"""Base interface for chains combining documents."""
input_key: str = "input_documents" #: :meta private:
output_key: str = "output_text" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return output key.
:meta private:
"""
return [self.output_key]
def prompt_length(
self, docs: List[Document], **kwargs: Any
) -> Optional[int]:
"""Return the prompt length given the documents passed in.
Returns None if the method does not depend on the prompt length.
"""
return None
def _call(
self,
inputs: Dict[str, List[Document]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = (
run_manager or CallbackManagerForChainRun.get_noop_manager()
)
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
doc_strings = [
format_document(doc, self.document_prompt) for doc in docs
]
# Join the documents together to put them in the prompt.
inputs = {
k: v
for k, v in other_keys.items()
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(
doc_strings
)
# Call predict on the LLM.
output, extra_return_dict = (
self.llm_chain(inputs, callbacks=_run_manager.get_child())[
self.llm_chain.output_key
],
{},
)
extra_return_dict[self.output_key] = output
return extra_return_dict
from pydantic import BaseModel
class Generation(Serializable):
"""Output of a single generation."""
text: str
"""Generated text output."""
generation_info: Optional[Dict[str, Any]] = None
"""Raw generation info response from the provider"""
"""May include things like reason for finishing (e.g. in OpenAI)"""
# TODO: add log probs
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
class LLMChain(Chain):
"""Chain to run queries against LLMs.
Example:
.. code-block:: python
from langchain import LLMChain, OpenAI, PromptTemplate
prompt_template = "Tell me a {adjective} joke"
prompt = PromptTemplate(
input_variables=["adjective"], template=prompt_template
)
llm = LLMChain(llm=OpenAI(), prompt=prompt)
"""
@property
def lc_serializable(self) -> bool:
return True
prompt: BasePromptTemplate
"""Prompt object to use."""
llm: BaseLanguageModel
output_key: str = "text" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return self.prompt.input_variables
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
prompts, stop = self.prep_prompts([inputs], run_manager=run_manager)
response = self.llm.generate_prompt(
prompts,
stop,
callbacks=run_manager.get_child() if run_manager else None,
)
return self.create_outputs(response)[0]
def prep_prompts(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
if "stop" in input_list[0]:
stop = input_list[0]["stop"]
prompts = []
for inputs in input_list:
selected_inputs = {
k: inputs[k] for k in self.prompt.input_variables
}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
return prompts, stop
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
{"input_list": input_list},
)
try:
response = self.generate(input_list, run_manager=run_manager)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise e
outputs = self.create_outputs(response)
run_manager.on_chain_end({"outputs": outputs})
return outputs
def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
"""Create outputs from response."""
return [
# Get the text of the top generated string.
{self.output_key: generation[0].text}
for generation in response.generations
]
def predict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, Any]]:
"""Call predict and then parse the results."""
result = self.predict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
def apply_and_parse(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
result = self.apply(input_list, callbacks=callbacks)
return self._parse_result(result)
def _parse_result(
self, result: List[Dict[str, str]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
if self.prompt.output_parser is not None:
return [
self.prompt.output_parser.parse(res[self.output_key])
for res in result
]
else:
return result
@property
def _chain_type(self) -> str:
return "llm_chain"
@classmethod
def from_string(cls, llm: BaseLanguageModel, template: str) -> LLMChain:
"""Create LLMChain from LLM and template."""
prompt_template = PromptTemplate.from_template(template)
return cls(llm=llm, prompt=prompt_template)
def _get_default_document_prompt() -> PromptTemplate:
return PromptTemplate(
input_variables=["page_content"], template="{page_content}"
)
class StuffDocumentsChain(BaseCombineDocumentsChain):
"""Chain that combines documents by stuffing into context."""
llm_chain: LLMChain
"""LLM wrapper to use after formatting documents."""
document_prompt: BasePromptTemplate = Field(
default_factory=_get_default_document_prompt
)
"""Prompt to use to format each document."""
document_variable_name: str
"""The variable name in the llm_chain to put the documents in.
If only one variable in the llm_chain, this need not be provided."""
document_separator: str = "\n\n"
"""The string with which to join the formatted documents"""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def get_default_document_variable_name(cls, values: Dict) -> Dict:
"""Get default document variable name, if not provided."""
llm_chain_variables = values["llm_chain"].prompt.input_variables
if "document_variable_name" not in values:
if len(llm_chain_variables) == 1:
values["document_variable_name"] = llm_chain_variables[0]
else:
raise ValueError(
"document_variable_name must be provided if there are "
"multiple llm_chain_variables"
)
else:
if values["document_variable_name"] not in llm_chain_variables:
raise ValueError(
f"document_variable_name {values['document_variable_name']} was "
f"not found in llm_chain input_variables: {llm_chain_variables}"
)
return values
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
# Format each document according to the prompt
doc_strings = [
format_document(doc, self.document_prompt) for doc in docs
]
# Join the documents together to put them in the prompt.
inputs = {
k: v
for k, v in kwargs.items()
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(
doc_strings
)
return inputs
def prompt_length(
self, docs: List[Document], **kwargs: Any
) -> Optional[int]:
"""Get the prompt length by formatting the prompt."""
inputs = self._get_inputs(docs, **kwargs)
prompt = self.llm_chain.prompt.format(**inputs)
return self.llm_chain.llm.get_num_tokens(prompt)
@property
def _chain_type(self) -> str:
return "stuff_documents_chain"
class LoadingCallable(Protocol):
"""Interface for loading the combine documents chain."""
def __call__(
self, llm: BaseLanguageModel, **kwargs: Any
) -> BaseCombineDocumentsChain:
"""Callable to load the combine documents chain."""
def _load_stuff_chain(
llm: BaseLanguageModel,
prompt: Optional[BasePromptTemplate] = None,
document_variable_name: str = "context",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> StuffDocumentsChain:
_prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
llm_chain = LLMChain(
llm=llm,
prompt=_prompt,
verbose=verbose,
callback_manager=callback_manager,
callbacks=callbacks,
)
# TODO: document prompt
return StuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name=document_variable_name,
verbose=verbose,
callback_manager=callback_manager,
**kwargs,
)
def load_qa_chain(
llm: BaseLanguageModel,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> BaseCombineDocumentsChain:
"""Load question answering chain.
Args:
llm: Language Model to use in the chain.
chain_type: Type of document combining chain to use. Should be one of "stuff",
"map_reduce", "map_rerank", and "refine".
verbose: Whether chains should be run in verbose mode or not. Note that this
applies to all chains that make up the final chain.
callback_manager: Callback manager to use for the chain.
Returns:
A chain to use for question answering.
"""
loader_mapping: Mapping[str, LoadingCallable] = {
"stuff": _load_stuff_chain,
}
if chain_type not in loader_mapping:
raise ValueError(
f"Got unsupported chain type: {chain_type}. "
f"Should be one of {loader_mapping.keys()}"
)
return loader_mapping[chain_type](
llm, verbose=verbose, callback_manager=callback_manager, **kwargs
)

View File

@@ -1,283 +0,0 @@
import os
import json
import shutil
import subprocess
import torch
from peft import PeftModel
from transformers import PreTrainedModel
def do_export():
BASE_MODEL = "h2oai/h2ogpt-oasst1-512-12b"
LORA_WEIGHTS = "h2ogpt-oasst1-512-12b.h2oaih2ogpt-oig-oasst1-instruct-cleaned-v3.1_epochs.805b8e8eff369207340a5a6f90f3c833f9731254.2"
OUTPUT_NAME = "h2ogpt-oig-oasst1-512-12b"
BASE_MODEL = "EleutherAI/pythia-12b-deduped"
LORA_WEIGHTS = "pythia-12b-deduped.h2oaiopenassistant_oasst1_h2ogpt_graded.3_epochs.2ccf687ea3f3f3775a501838e81c1a0066430455.4"
OUTPUT_NAME = "h2ogpt-oasst1-512-12b"
BASE_MODEL = "tiiuae/falcon-40b"
LORA_WEIGHTS = "falcon-40b.h2oaiopenassistant_oasst1_h2ogpt.1_epochs.894d8450d35c180cd03222a45658d04c15b78d4b.9"
OUTPUT_NAME = "h2ogpt-oasst1-2048-falcon-40b"
# BASE_MODEL = 'decapoda-research/llama-65b-hf'
# LORA_WEIGHTS = 'llama-65b-hf.h2oaiopenassistant_oasst1_h2ogpt_graded.1_epochs.113510499324f0f007cbec9d9f1f8091441f2469.3'
# OUTPUT_NAME = "h2ogpt-research-oasst1-llama-65b"
model = os.getenv("MODEL")
# for testing
if model:
BASE_MODEL = "tiiuae/falcon-7b"
LORA_WEIGHTS = model + ".lora"
OUTPUT_NAME = model
llama_type = "llama" in BASE_MODEL
as_pytorch = False # False -> HF
from loaders import get_loaders
model_loader, tokenizer_loader = get_loaders(
model_name=BASE_MODEL, reward_type=False, llama_type=llama_type
)
tokenizer = tokenizer_loader.from_pretrained(
BASE_MODEL,
local_files_only=False,
resume_download=True,
)
tokenizer.save_pretrained(OUTPUT_NAME)
base_model = model_loader(
BASE_MODEL,
load_in_8bit=False,
trust_remote_code=True,
torch_dtype=torch.float16,
device_map={"": "cpu"},
)
print(base_model)
if llama_type:
layers = base_model.model.layers
first_weight = layers[0].self_attn.q_proj.weight
else:
if any(
[x in BASE_MODEL.lower() for x in ["pythia", "h2ogpt", "gpt-neox"]]
):
layers = base_model.gpt_neox.base_model.layers
first_weight = layers[0].attention.query_key_value.weight
elif any([x in BASE_MODEL.lower() for x in ["falcon"]]):
first_weight = base_model.transformer.h._modules[
"0"
].self_attention.query_key_value.weight
else:
layers = base_model.transformer.base_model.h
first_weight = layers[0].attn.q_proj.weight
first_weight_old = first_weight.clone()
lora_model = PeftModel.from_pretrained(
base_model,
LORA_WEIGHTS,
device_map={"": "cpu"},
torch_dtype=torch.float16,
)
assert torch.allclose(first_weight_old, first_weight)
# merge weights TODO: include all lora_target_modules, not just default ones
if llama_type:
lora_model = lora_model.merge_and_unload()
# for layer in lora_model.base_model.model.model.layers:
# layer.self_attn.q_proj.merge_weights = True
# layer.self_attn.k_proj.merge_weights = True
# layer.self_attn.v_proj.merge_weights = True
# layer.self_attn.o_proj.merge_weights = True
else:
if any(
[x in BASE_MODEL.lower() for x in ["pythia", "h2ogpt", "gpt-neox"]]
):
for layer in lora_model.base_model.gpt_neox.base_model.layers:
layer.attention.query_key_value.merge_weights = True
else:
lora_model.merge_and_unload()
# for layer in lora_model.base_model.transformer.base_model.h:
# layer.attn.q_proj.merge_weights = True
# layer.attn.v_proj.merge_weights = True
lora_model.train(False)
# did we do anything?
assert not torch.allclose(first_weight_old, first_weight)
lora_model_sd = lora_model.state_dict()
if as_pytorch:
# FIXME - might not be generic enough still
params = {
"dim": base_model.config.hidden_size,
"n_heads": base_model.config.num_attention_heads,
"n_layers": base_model.config.num_hidden_layers,
"norm_eps": base_model.config.layer_norm_eps,
"vocab_size": base_model.config.vocab_size,
}
n_layers = params["n_layers"]
n_heads = params["n_heads"]
dim = params["dim"]
dims_per_head = dim // n_heads
base = 10000.0
inv_freq = 1.0 / (
base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
)
def permute(w):
return (
w.view(n_heads, dim // n_heads // 2, 2, dim)
.transpose(1, 2)
.reshape(dim, dim)
)
def unpermute(w):
return (
w.view(n_heads, 2, dim // n_heads // 2, dim)
.transpose(1, 2)
.reshape(dim, dim)
)
def translate_state_dict_key(k):
if "gpt-neoxt" in BASE_MODEL.lower():
k = k.replace("gpt_neox.model.", "")
else:
k = k.replace("base_model.model.", "")
if k == "model.embed_tokens.weight":
return "tok_embeddings.weight"
elif k == "model.norm.weight":
return "norm.weight"
elif k == "lm_head.weight":
return "output.weight"
elif k.startswith("model.layers."):
layer = k.split(".")[2]
if k.endswith(".self_attn.q_proj.weight"):
return f"layers.{layer}.attention.wq.weight"
elif k.endswith(".self_attn.k_proj.weight"):
return f"layers.{layer}.attention.wk.weight"
elif k.endswith(".self_attn.v_proj.weight"):
return f"layers.{layer}.attention.wv.weight"
elif k.endswith(".self_attn.o_proj.weight"):
return f"layers.{layer}.attention.wo.weight"
elif k.endswith(".mlp.gate_proj.weight"):
return f"layers.{layer}.feed_forward.w1.weight"
elif k.endswith(".mlp.down_proj.weight"):
return f"layers.{layer}.feed_forward.w2.weight"
elif k.endswith(".mlp.up_proj.weight"):
return f"layers.{layer}.feed_forward.w3.weight"
elif k.endswith(".input_layernorm.weight"):
return f"layers.{layer}.attention_norm.weight"
elif k.endswith(".post_attention_layernorm.weight"):
return f"layers.{layer}.ffn_norm.weight"
elif k.endswith("rotary_emb.inv_freq") or "lora" in k:
return None
else:
print(layer, k)
raise NotImplementedError
else:
print(k)
raise NotImplementedError
new_state_dict = {}
for k, v in lora_model_sd.items():
new_k = translate_state_dict_key(k)
if new_k is not None:
if "wq" in new_k or "wk" in new_k:
new_state_dict[new_k] = unpermute(v)
else:
new_state_dict[new_k] = v
os.makedirs("./ckpt", exist_ok=True)
torch.save(new_state_dict, "./ckpt/consolidated.00.pth")
with open("./ckpt/params.json", "w") as f:
json.dump(params, f)
else:
deloreanized_sd = {
k.replace("base_model.model.", ""): v
for k, v in lora_model_sd.items()
if "lora" not in k
}
base_model.config.custom_pipelines = {
"text-generation": {
"impl": "h2oai_pipeline.H2OTextGenerationPipeline",
"pt": "AutoModelForCausalLM",
}
}
PreTrainedModel.save_pretrained(
base_model,
OUTPUT_NAME,
state_dict=deloreanized_sd,
# max_shard_size="5GB",
)
do_copy(OUTPUT_NAME)
test_copy()
def do_copy(OUTPUT_NAME):
dest_file = os.path.join(OUTPUT_NAME, "h2oai_pipeline.py")
shutil.copyfile("src/h2oai_pipeline.py", dest_file)
os.system("""sed -i 's/from enums.*//g' %s""" % dest_file)
os.system("""sed -i 's/from stopping.*//g' %s""" % dest_file)
os.system("""sed -i 's/from prompter.*//g' %s""" % dest_file)
os.system(
"""cat %s|grep -v "from enums import PromptType" >> %s"""
% ("src/enums.py", dest_file)
)
os.system(
"""cat %s|grep -v "from enums import PromptType" >> %s"""
% ("src/prompter.py", dest_file)
)
os.system(
"""cat %s|grep -v "from enums import PromptType" >> %s"""
% ("src/stopping.py", dest_file)
)
TEST_OUTPUT_NAME = "test_output"
def test_copy():
if os.path.isdir(TEST_OUTPUT_NAME):
shutil.rmtree(TEST_OUTPUT_NAME)
os.makedirs(TEST_OUTPUT_NAME, exist_ok=False)
do_copy(TEST_OUTPUT_NAME)
shutil.copy("src/export_hf_checkpoint.py", TEST_OUTPUT_NAME)
os.environ["DO_COPY_TEST"] = "1"
os.chdir(TEST_OUTPUT_NAME)
output = subprocess.check_output(["python", "export_hf_checkpoint.py"])
print(output)
def inner_test_copy():
"""
pytest -s -v export_hf_checkpoint.py::test_copy
:return:
"""
# test imports
# below supposed to look bad in pycharm, don't fix!
from h2oai_pipeline import (
get_stopping,
get_prompt,
H2OTextGenerationPipeline,
)
assert get_stopping
assert get_prompt
assert H2OTextGenerationPipeline
if __name__ == "__main__":
if os.getenv("DO_COPY_TEST"):
inner_test_copy()
else:
do_export()
# uncomment for raw isolated test, but test is done every time for each export now
# test_copy()

File diff suppressed because it is too large Load Diff

View File

@@ -34,7 +34,7 @@ from enums import (
LangChainMode,
)
from evaluate_params import gen_hyper
from gen import get_model, SEED
from gen import Langchain, SEED
from prompter import non_hf_types, PromptType, Prompter
from utils import (
wrapped_partial,
@@ -44,7 +44,6 @@ from utils import (
makedirs,
get_url,
flatten_list,
get_device,
ProgressParallel,
remove,
hash_file,
@@ -88,10 +87,11 @@ from langchain.document_loaders import (
UnstructuredExcelLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
from langchain.chains.question_answering import load_qa_chain
from expanded_pipelines import load_qa_chain
from langchain.docstore.document import Document
from langchain import PromptTemplate, HuggingFaceTextGenInference
from langchain.vectorstores import Chroma
from apps.stable_diffusion.src import args
def get_db(
@@ -371,8 +371,8 @@ def get_embedding(
# to ensure can fork without deadlock
from langchain.embeddings import HuggingFaceEmbeddings
device, torch_dtype, context_class = get_device_dtype()
model_kwargs = dict(device=device)
torch_dtype, context_class = get_dtype()
model_kwargs = dict(device=args.device)
if "instructor" in hf_embedding_model:
encode_kwargs = {"normalize_embeddings": True}
embedding = HuggingFaceInstructEmbeddings(
@@ -436,7 +436,7 @@ class GradioInference(LLM):
chat_client: bool = False
return_full_text: bool = True
stream: bool = False
stream_output: bool = Field(False, alias="stream")
sanitize_bot_response: bool = False
prompter: Any = None
@@ -481,7 +481,7 @@ class GradioInference(LLM):
# so server should get prompt_type or '', not plain
# This is good, so gradio server can also handle stopping.py conditions
# this is different than TGI server that uses prompter to inject prompt_type prompting
stream_output = self.stream
stream_output = self.stream_output
gr_client = self.client
client_langchain_mode = "Disabled"
client_langchain_action = LangChainAction.QUERY.value
@@ -596,7 +596,7 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
inference_server_url: str = ""
timeout: int = 300
headers: dict = None
stream: bool = False
stream_output: bool = Field(False, alias="stream")
sanitize_bot_response: bool = False
prompter: Any = None
tokenizer: Any = None
@@ -663,7 +663,7 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
# lower bound because client is re-used if multi-threading
self.client.timeout = max(300, self.timeout)
if not self.stream:
if not self.stream_output:
res = self.client.generate(
prompt,
**gen_server_kwargs,
@@ -852,7 +852,7 @@ def get_llm(
top_p=top_p,
# typical_p=top_p,
callbacks=callbacks if stream_output else None,
stream=stream_output,
stream_output=stream_output,
prompter=prompter,
tokenizer=tokenizer,
client=hf_client,
@@ -907,7 +907,7 @@ def get_llm(
# model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
# model_name = 'h2oai/h2ogpt-oasst1-512-20b'
inference_server = ""
model, tokenizer, device = get_model(
model, tokenizer, _ = Langchain.get_model(
load_8bit=True,
base_model=model_name,
inference_server=inference_server,
@@ -974,17 +974,15 @@ def get_llm(
return llm, model_name, streamer, prompt_type
def get_device_dtype():
def get_dtype():
# torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
import torch
n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
device = "cpu" if n_gpus == 0 else "cuda"
# from utils import NullContext
# context_class = NullContext if n_gpus > 1 or n_gpus == 0 else context_class
context_class = torch.device
torch_dtype = torch.float16 if device == "cuda" else torch.float32
return device, torch_dtype, context_class
torch_dtype = torch.float16 if args.device == "cuda" else torch.float32
return torch_dtype, context_class
def get_wiki_data(
@@ -1715,7 +1713,7 @@ def path_to_docs(
caption_loader
and not isinstance(caption_loader, (bool, str))
and caption_loader.device != "cpu"
or get_device() == "cuda"
or args.device == "cuda"
):
# to avoid deadlocks, presume was preloaded and so can't fork due to cuda context
n_jobs_image = 1
@@ -2512,8 +2510,7 @@ def _run_qa_db(
formatted_doc_chunks = "\n\n".join(
[get_url(x) + "\n\n" + x.page_content for x in docs]
)
yield formatted_doc_chunks, ""
return
return formatted_doc_chunks, ""
if not docs and langchain_action in [
LangChainAction.SUMMARIZE_MAP.value,
LangChainAction.SUMMARIZE_ALL.value,
@@ -2525,8 +2522,7 @@ def _run_qa_db(
else "No documents to summarize."
)
extra = ""
yield ret, extra
return
return ret, extra
if not docs and langchain_mode not in [
LangChainMode.DISABLED.value,
LangChainMode.CHAT_LLM.value,
@@ -2538,8 +2534,7 @@ def _run_qa_db(
else "No documents to query."
)
extra = ""
yield ret, extra
return
return ret, extra
if chain is None and model_name not in non_hf_types:
# here if no docs at all and not HF type
@@ -2549,32 +2544,17 @@ def _run_qa_db(
# context stuff similar to used in evaluate()
import torch
device, torch_dtype, context_class = get_device_dtype()
torch_dtype, context_class = get_dtype()
with torch.no_grad():
have_lora_weights = lora_weights not in [no_lora_str, "", None]
context_class_cast = (
NullContext
if device == "cpu" or have_lora_weights
if args.device == "cpu" or have_lora_weights
else torch.autocast
)
with context_class_cast(device):
with context_class_cast(args.device):
answer = chain()
if not use_context:
ret = answer["output_text"]
extra = ""
yield ret, extra
elif answer is not None:
ret, extra = get_sources_answer(
query,
answer,
scores,
show_rank,
answer_with_sources,
verbose=verbose,
)
yield ret, extra
return
return answer
def get_similarity_chain(
@@ -2960,56 +2940,8 @@ def get_similarity_chain(
template=template,
)
chain = load_qa_chain(llm, prompt=prompt)
else:
# only if use_openai_model = True, unused normally except in testing
chain = load_qa_with_sources_chain(llm)
if not use_context:
chain_kwargs = dict(input_documents=[], question=query)
else:
chain_kwargs = dict(input_documents=docs, question=query)
chain_kwargs = dict(input_documents=docs, question=query)
target = wrapped_partial(chain, chain_kwargs)
elif langchain_action in [
LangChainAction.SUMMARIZE_MAP.value,
LangChainAction.SUMMARIZE_REFINE,
LangChainAction.SUMMARIZE_ALL.value,
]:
from langchain.chains.summarize import load_summarize_chain
if langchain_action == LangChainAction.SUMMARIZE_MAP.value:
prompt = PromptTemplate(
input_variables=["text"], template=template
)
chain = load_summarize_chain(
llm,
chain_type="map_reduce",
map_prompt=prompt,
combine_prompt=prompt,
return_intermediate_steps=True,
)
target = wrapped_partial(
chain, {"input_documents": docs}
) # , return_only_outputs=True)
elif langchain_action == LangChainAction.SUMMARIZE_ALL.value:
assert use_template
prompt = PromptTemplate(
input_variables=["text"], template=template
)
chain = load_summarize_chain(
llm,
chain_type="stuff",
prompt=prompt,
return_intermediate_steps=True,
)
target = wrapped_partial(chain)
elif langchain_action == LangChainAction.SUMMARIZE_REFINE.value:
chain = load_summarize_chain(
llm, chain_type="refine", return_intermediate_steps=True
)
target = wrapped_partial(chain)
else:
raise RuntimeError(
"No such langchain_action=%s" % langchain_action
)
else:
raise RuntimeError("No such langchain_action=%s" % langchain_action)

File diff suppressed because it is too large Load Diff

View File

@@ -1,225 +0,0 @@
from __future__ import annotations
from typing import Iterable
from gradio.themes.soft import Soft
from gradio.themes import Color, Size
from gradio.themes.utils import colors, sizes, fonts
h2o_yellow = Color(
name="yellow",
c50="#fffef2",
c100="#fff9e6",
c200="#ffecb3",
c300="#ffe28c",
c400="#ffd659",
c500="#fec925",
c600="#e6ac00",
c700="#bf8f00",
c800="#a67c00",
c900="#664d00",
c950="#403000",
)
h2o_gray = Color(
name="gray",
c50="#f8f8f8",
c100="#e5e5e5",
c200="#cccccc",
c300="#b2b2b2",
c400="#999999",
c500="#7f7f7f",
c600="#666666",
c700="#4c4c4c",
c800="#333333",
c900="#191919",
c950="#0d0d0d",
)
text_xsm = Size(
name="text_xsm",
xxs="4px",
xs="5px",
sm="6px",
md="7px",
lg="8px",
xl="10px",
xxl="12px",
)
spacing_xsm = Size(
name="spacing_xsm",
xxs="1px",
xs="1px",
sm="1px",
md="2px",
lg="3px",
xl="5px",
xxl="7px",
)
radius_xsm = Size(
name="radius_xsm",
xxs="1px",
xs="1px",
sm="1px",
md="2px",
lg="3px",
xl="5px",
xxl="7px",
)
class H2oTheme(Soft):
def __init__(
self,
*,
primary_hue: colors.Color | str = h2o_yellow,
secondary_hue: colors.Color | str = h2o_yellow,
neutral_hue: colors.Color | str = h2o_gray,
spacing_size: sizes.Size | str = sizes.spacing_md,
radius_size: sizes.Size | str = sizes.radius_md,
text_size: sizes.Size | str = sizes.text_lg,
font: fonts.Font
| str
| Iterable[fonts.Font | str] = (
fonts.GoogleFont("Montserrat"),
"ui-sans-serif",
"system-ui",
"sans-serif",
),
font_mono: fonts.Font
| str
| Iterable[fonts.Font | str] = (
fonts.GoogleFont("IBM Plex Mono"),
"ui-monospace",
"Consolas",
"monospace",
),
):
super().__init__(
primary_hue=primary_hue,
secondary_hue=secondary_hue,
neutral_hue=neutral_hue,
spacing_size=spacing_size,
radius_size=radius_size,
text_size=text_size,
font=font,
font_mono=font_mono,
)
super().set(
link_text_color="#3344DD",
link_text_color_hover="#3344DD",
link_text_color_visited="#3344DD",
link_text_color_dark="#74abff",
link_text_color_hover_dark="#a3c8ff",
link_text_color_active_dark="#a3c8ff",
link_text_color_visited_dark="#74abff",
button_primary_text_color="*neutral_950",
button_primary_text_color_dark="*neutral_950",
button_primary_background_fill="*primary_500",
button_primary_background_fill_dark="*primary_500",
block_label_background_fill="*primary_500",
block_label_background_fill_dark="*primary_500",
block_label_text_color="*neutral_950",
block_label_text_color_dark="*neutral_950",
block_title_text_color="*neutral_950",
block_title_text_color_dark="*neutral_950",
block_background_fill_dark="*neutral_950",
body_background_fill="*neutral_50",
body_background_fill_dark="*neutral_900",
background_fill_primary_dark="*block_background_fill",
block_radius="0 0 8px 8px",
checkbox_label_text_color_selected_dark="#000000",
)
class SoftTheme(Soft):
def __init__(
self,
*,
primary_hue: colors.Color | str = colors.indigo,
secondary_hue: colors.Color | str = colors.indigo,
neutral_hue: colors.Color | str = colors.gray,
spacing_size: sizes.Size | str = sizes.spacing_md,
radius_size: sizes.Size | str = sizes.radius_md,
text_size: sizes.Size | str = sizes.text_md,
font: fonts.Font
| str
| Iterable[fonts.Font | str] = (
fonts.GoogleFont("Montserrat"),
"ui-sans-serif",
"system-ui",
"sans-serif",
),
font_mono: fonts.Font
| str
| Iterable[fonts.Font | str] = (
fonts.GoogleFont("IBM Plex Mono"),
"ui-monospace",
"Consolas",
"monospace",
),
):
super().__init__(
primary_hue=primary_hue,
secondary_hue=secondary_hue,
neutral_hue=neutral_hue,
spacing_size=spacing_size,
radius_size=radius_size,
text_size=text_size,
font=font,
font_mono=font_mono,
)
h2o_logo = (
'<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" width="100%" height="100%"'
' viewBox="0 0 600.28 600.28"><defs><style>.cls-1{fill:#fec925;}.cls-2{fill:#161616;}.cls-3{fill:'
'#54585a;}</style></defs><g id="Fill-1"><rect class="cls-1" width="600.28" height="600.28" '
'rx="23.24"/></g><path class="cls-2" d="M174.33,246.06v92.78H152.86v-38H110.71v38H89.24V246.06h21.'
'47v36.58h42.15V246.06Z"/><path class="cls-2" d="M259.81,321.34v17.5H189.7V324.92l35.78-33.8c8.22-7.'
"82,9.68-12.59,9.68-17.09,0-7.29-5-11.53-14.85-11.53-7.95,0-14.71,3-19.21,9.27L185.46,261.7c7.15-10"
'.47,20.14-17.23,36.84-17.23,20.68,0,34.46,10.6,34.46,27.44,0,9-2.52,17.22-15.51,29.29l-21.33,20.14Z"'
'/><path class="cls-2" d="M268.69,292.45c0-27.57,21.47-48,50.76-48s50.76,20.28,50.76,48-21.6,48-50.'
"76,48S268.69,320,268.69,292.45Zm79.78,0c0-17.63-12.46-29.69-29-29.69s-29,12.06-29,29.69,12.46,29.69"
',29,29.69S348.47,310.08,348.47,292.45Z"/><path class="cls-3" d="M377.23,326.91c0-7.69,5.7-12.73,12.'
'85-12.73s12.86,5,12.86,12.73a12.86,12.86,0,1,1-25.71,0Z"/><path class="cls-3" d="M481.4,298.15v40.'
"69H462.05V330c-3.84,6.49-11.27,9.94-21.74,9.94-16.7,0-26.64-9.28-26.64-21.61,0-12.59,8.88-21.34,30."
"62-21.34h16.43c0-8.87-5.3-14-16.43-14-7.55,0-15.37,2.51-20.54,6.62l-7.43-14.44c7.82-5.57,19.35-8."
"62,30.75-8.62C468.81,266.47,481.4,276.54,481.4,298.15Zm-20.68,18.16V309H446.54c-9.67,0-12.72,3.57-"
'12.72,8.35,0,5.16,4.37,8.61,11.66,8.61C452.37,326,458.34,322.8,460.72,316.31Z"/><path class="cls-3"'
' d="M497.56,246.06c0-6.49,5.17-11.53,12.86-11.53s12.86,4.77,12.86,11.13c0,6.89-5.17,11.93-12.86,'
'11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
)
def get_h2o_title(title, description):
# NOTE: Check full width desktop, smallest width browser desktop, iPhone browsers to ensure no overlap etc.
return f"""<div style="float:left; justify-content:left; height: 80px; width: 195px; margin-top:0px">
{description}
</div>
<div style="display:flex; justify-content:center; margin-bottom:30px; margin-right:330px;">
<div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
<h1 style="line-height:60px">{title}</h1>
</div>
<div style="float:right; height: 80px; width: 80px; margin-top:-100px">
<img src="https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png">
</div>
"""
def get_simple_title(title, description):
return f"""{description}<h1 align="center"> {title}</h1>"""
def get_dark_js():
return """() => {
if (document.querySelectorAll('.dark').length) {
document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
} else {
document.querySelector('body').classList.add('dark');
}
}"""

View File

@@ -1,53 +0,0 @@
def get_css(kwargs) -> str:
if kwargs["h2ocolors"]:
css_code = """footer {visibility: hidden;}
body{background:linear-gradient(#f5f5f5,#e5e5e5);}
body.dark{background:linear-gradient(#000000,#0d0d0d);}
"""
else:
css_code = """footer {visibility: hidden}"""
css_code += make_css_base()
return css_code
def make_css_base() -> str:
return """
@import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
body.dark{#warning {background-color: #555555};}
#small_btn {
margin: 0.6em 0em 0.55em 0;
max-width: 20em;
min-width: 5em !important;
height: 5em;
font-size: 14px !important;
}
#prompt-form {
border: 1px solid var(--primary-500) !important;
}
#prompt-form.block {
border-radius: var(--block-radius) !important;
}
#prompt-form textarea {
border: 1px solid rgb(209, 213, 219);
}
#prompt-form label > div {
margin-top: 4px;
}
button.primary:hover {
background-color: var(--primary-600) !important;
transition: .2s;
}
#prompt-form-area {
margin-bottom: 2.5rem;
}
.chatsmall chatbot {font-size: 10px !important}
"""

View File

@@ -1,185 +0,0 @@
import os
import math
import gradio as gr
def make_chatbots(output_label0, output_label0_model2, **kwargs):
text_outputs = []
chat_kwargs = []
for model_state_lock in kwargs["model_states"]:
if os.environ.get("DEBUG_MODEL_LOCK"):
model_name = (
model_state_lock["base_model"]
+ " : "
+ model_state_lock["inference_server"]
)
else:
model_name = model_state_lock["base_model"]
output_label = f"h2oGPT [{model_name}]"
min_width = (
250
if kwargs["gradio_size"] in ["small", "large", "medium"]
else 160
)
chat_kwargs.append(
dict(
label=output_label,
visible=kwargs["model_lock"],
elem_classes="chatsmall",
height=kwargs["height"] or 400,
min_width=min_width,
)
)
if kwargs["model_lock_columns"] == -1:
kwargs["model_lock_columns"] = len(kwargs["model_states"])
if kwargs["model_lock_columns"] is None:
kwargs["model_lock_columns"] = 3
ncols = kwargs["model_lock_columns"]
if kwargs["model_states"] == 0:
nrows = 0
else:
nrows = math.ceil(
len(kwargs["model_states"]) / kwargs["model_lock_columns"]
)
if kwargs["model_lock_columns"] == 0:
# not using model_lock
pass
elif nrows <= 1:
with gr.Row():
for chat_kwargs1, model_state_lock in zip(
chat_kwargs, kwargs["model_states"]
):
text_outputs.append(gr.Chatbot(**chat_kwargs1))
elif nrows == kwargs["model_states"]:
with gr.Row():
for chat_kwargs1, model_state_lock in zip(
chat_kwargs, kwargs["model_states"]
):
text_outputs.append(gr.Chatbot(**chat_kwargs1))
elif nrows == 2:
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if mii >= len(kwargs["model_states"]) / 2:
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if mii < len(kwargs["model_states"]) / 2:
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
elif nrows == 3:
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if mii >= 1 * len(kwargs["model_states"]) / 3:
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if (
mii < 1 * len(kwargs["model_states"]) / 3
or mii >= 2 * len(kwargs["model_states"]) / 3
):
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if mii < 2 * len(kwargs["model_states"]) / 3:
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
elif nrows >= 4:
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if mii >= 1 * len(kwargs["model_states"]) / 4:
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if (
mii < 1 * len(kwargs["model_states"]) / 4
or mii >= 2 * len(kwargs["model_states"]) / 4
):
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if (
mii < 2 * len(kwargs["model_states"]) / 4
or mii >= 3 * len(kwargs["model_states"]) / 4
):
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
with gr.Row():
for mii, (chat_kwargs1, model_state_lock) in enumerate(
zip(chat_kwargs, kwargs["model_states"])
):
if mii < 3 * len(kwargs["model_states"]) / 4:
continue
text_outputs.append(gr.Chatbot(**chat_kwargs1))
with gr.Row():
text_output = gr.Chatbot(
label=output_label0,
visible=not kwargs["model_lock"],
height=kwargs["height"] or 400,
)
text_output2 = gr.Chatbot(
label=output_label0_model2,
visible=False and not kwargs["model_lock"],
height=kwargs["height"] or 400,
)
return text_output, text_output2, text_outputs
def make_prompt_form(kwargs, LangChainMode):
if kwargs["langchain_mode"] != LangChainMode.DISABLED.value:
extra_prompt_form = ". For summarization, empty submission uses first top_k_docs documents."
else:
extra_prompt_form = ""
if kwargs["input_lines"] > 1:
instruction_label = (
"Shift-Enter to Submit, Enter for more lines%s" % extra_prompt_form
)
else:
instruction_label = (
"Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
)
with gr.Row(): # elem_id='prompt-form-area'):
with gr.Column(scale=50):
instruction = gr.Textbox(
lines=kwargs["input_lines"],
label="Ask anything",
placeholder=instruction_label,
info=None,
elem_id="prompt-form",
container=True,
)
with gr.Row():
submit = gr.Button(
value="Submit", variant="primary", scale=0, size="sm"
)
stop_btn = gr.Button(
value="Stop", variant="secondary", scale=0, size="sm"
)
return instruction, submit, stop_btn

View File

@@ -1,13 +1,13 @@
import os
from apps.stable_diffusion.src.utils.utils import _compile_module
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
from io import BytesIO
import torch_mlir
from stopping import get_stopping
from prompter import Prompter, PromptType
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
from transformers.generation import (
GenerationConfig,
LogitsProcessorList,
@@ -20,93 +20,468 @@ import gc
from pathlib import Path
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx, save_mlir
from apps.stable_diffusion.src import args
# Brevitas
from typing import List, Tuple
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
# fmt: off
def quantmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
return [lhs[0], rhs[0]]
else:
raise ValueError("Input shapes not supported.")
def quantmatmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype
def quantmatmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
return
brevitas_matmul_rhs_group_quant_library = [
quantmatmul_rhs_group_quant〡shape,
quantmatmul_rhs_group_quant〡dtype,
quantmatmul_rhs_group_quant〡has_value_semantics]
# fmt: on
global_device = "cuda"
global_precision = "fp16"
if not args.run_docuchat_web:
args.device = global_device
args.precision = global_precision
tensor_device = "cpu" if args.device == "cpu" else "cuda"
class H2OGPTModel(torch.nn.Module):
def __init__(self, device, precision):
super().__init__()
torch_dtype = (
torch.float32
if precision == "fp32" or device == "cpu"
else torch.float16
)
device_map = {"": "cpu"} if device == "cpu" else {"": 0}
model_kwargs = {
"local_files_only": False,
"torch_dtype": torch_dtype,
"resume_download": True,
"use_auth_token": False,
"trust_remote_code": True,
"offload_folder": "offline_folder",
"device_map": device_map,
}
config = AutoConfig.from_pretrained(
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
use_auth_token=False,
trust_remote_code=True,
offload_folder="offline_folder",
)
self.model = AutoModelForCausalLM.from_pretrained(
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
config=config,
**model_kwargs,
)
if precision in ["int4", "int8"]:
print("Applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
self.model.transformer.h,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=128,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(self, input_ids, attention_mask):
input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": None,
"use_cache": True,
}
output = self.model(
**input_dict,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
return output.logits[:, -1, :]
class H2OGPTSHARKModel(torch.nn.Module):
def __init__(self):
super().__init__()
model_name = "h2ogpt_falcon_7b"
path_str = (
model_name + "_" + args.precision + "_" + args.device + ".vmfb"
extended_model_name = (
model_name + "_" + args.precision + "_" + args.device
)
vmfb_path = Path(path_str)
path_str = model_name + "_" + args.precision + ".mlir"
mlir_path = Path(path_str)
vmfb_path = Path(extended_model_name + ".vmfb")
mlir_path = Path(model_name + "_" + args.precision + ".mlir")
shark_module = None
need_to_compile = False
if not vmfb_path.exists():
if args.device == "cuda" and args.precision in ["fp16", "fp32"]:
# Downloading VMFB from shark_tank
print("Downloading vmfb from shark tank.")
need_to_compile = True
# Downloading VMFB from shark_tank
print("Trying to download pre-compiled vmfb from shark tank.")
download_public_file(
"gs://shark_tank/langchain/" + str(vmfb_path),
vmfb_path.absolute(),
single_file=True,
)
if vmfb_path.exists():
print(
"Pre-compiled vmfb downloaded from shark tank successfully."
)
need_to_compile = False
if need_to_compile:
if not mlir_path.exists():
print("Trying to download pre-generated mlir from shark tank.")
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/langchain/" + path_str,
vmfb_path.absolute(),
"gs://shark_tank/langchain/" + str(mlir_path),
mlir_path.absolute(),
single_file=True,
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/langchain/"
+ model_name
+ "_"
+ args.precision
+ ".mlir",
mlir_path.absolute(),
single_file=True,
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
raise ValueError(
f"MLIR not found at {mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
shark_module = SharkInference(
mlir_module=bytecode,
device=args.device,
mlir_dialect="linalg",
)
print(f"[DEBUG] generating vmfb.")
shark_module = _compile_module(shark_module, vmfb_path, [])
print("Saved newly generated vmfb.")
# Generating the mlir
bytecode = self.get_bytecode(tensor_device, args.precision)
shark_module = SharkInference(
mlir_module=bytecode,
device=args.device,
mlir_dialect="linalg",
)
print(f"[DEBUG] generating vmfb.")
shark_module = _compile_module(
shark_module, extended_model_name, []
)
print("Saved newly generated vmfb.")
if shark_module is None:
if vmfb_path.exists():
print("Compiled vmfb found. Loading it from: ", vmfb_path)
shark_module = SharkInference(
None, device=global_device, mlir_dialect="linalg"
None, device=args.device, mlir_dialect="linalg"
)
shark_module.load_module(vmfb_path)
shark_module.load_module(str(vmfb_path))
print("Compiled vmfb loaded successfully.")
else:
raise ValueError("Unable to download/generate a vmfb.")
self.model = shark_module
def get_bytecode(self, device, precision):
h2ogpt_model = H2OGPTModel(device, precision)
compilation_input_ids = torch.randint(
low=1, high=10000, size=(1, 400)
).to(device=device)
compilation_attention_mask = torch.ones(1, 400, dtype=torch.int64).to(
device=device
)
h2ogptCompileInput = (
compilation_input_ids,
compilation_attention_mask,
)
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
h2ogpt_model,
h2ogptCompileInput,
is_f16=False,
precision=precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
del h2ogpt_model
del self.src_model
print(f"[DEBUG] generating torch mlir")
if precision in ["int4", "int8"]:
from torch_mlir.compiler_utils import (
run_pipeline_with_repro_report,
)
module = torch_mlir.compile(
ts_graph,
[*h2ogptCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module,
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
module = torch_mlir.compile(
ts_graph,
[*h2ogptCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
print(f"[DEBUG] converting to bytecode")
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
del module
bytecode = save_mlir(
bytecode,
model_name=f"h2ogpt_{precision}",
frontend="torch",
)
return bytecode
def forward(self, input_ids, attention_mask):
result = torch.from_numpy(
self.model(
"forward",
(input_ids.to(device="cpu"), attention_mask.to(device="cpu")),
)
).to(device=global_device)
).to(device=tensor_device)
return result
h2ogpt_model = H2OGPTSHARKModel()
def decode_tokens(tokenizer, res_tokens):
for i in range(len(res_tokens)):
if type(res_tokens[i]) != int:
res_tokens[i] = int(res_tokens[i][0])
res_str = tokenizer.decode(res_tokens, skip_special_tokens=True)
return res_str
def generate_token(h2ogpt_shark_model, model, tokenizer, **generate_kwargs):
del generate_kwargs["max_time"]
generate_kwargs["input_ids"] = generate_kwargs["input_ids"].to(
device=tensor_device
)
generate_kwargs["attention_mask"] = generate_kwargs["attention_mask"].to(
device=tensor_device
)
truncated_input_ids = []
stopping_criteria = generate_kwargs["stopping_criteria"]
generation_config_ = GenerationConfig.from_model_config(model.config)
generation_config = copy.deepcopy(generation_config_)
model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
stopping_criteria = (
stopping_criteria
if stopping_criteria is not None
else StoppingCriteriaList()
)
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
model_kwargs,
) = model._prepare_model_inputs(
None, generation_config.bos_token_id, model_kwargs
)
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else model_kwargs.pop("input_ids")
)
input_ids_seq_length = input_ids.shape[-1]
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length
)
logits_processor = model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
stopping_criteria = model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria,
)
logits_warper = model._get_logits_warper(generation_config)
(
input_ids,
model_kwargs,
) = model._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=model.config.is_encoder_decoder, # False
**model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = (
torch.tensor(eos_token_id).to(device=tensor_device)
if eos_token_id is not None
else None
)
pad_token_id = generation_config.pad_token_id
eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
# keep track of which sequences are already finished
unfinished_sequences = torch.ones(
input_ids.shape[0],
dtype=torch.long,
device=input_ids.device,
)
timesRan = 0
import time
start = time.time()
print("\n")
res_tokens = []
while True:
model_inputs = model.prepare_inputs_for_generation(
input_ids, **model_kwargs
)
outputs = h2ogpt_shark_model.forward(
model_inputs["input_ids"], model_inputs["attention_mask"]
)
if args.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = next_token * unfinished_sequences + pad_token_id * (
1 - unfinished_sequences
)
input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1)
model_kwargs["past_key_values"] = None
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
truncated_input_ids.append(input_ids[:, 0])
input_ids = input_ids[:, 1:]
model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, 1:]
new_word = tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
res_tokens.append(next_token)
if new_word == "<0x0A>":
print("\n", end="", flush=True)
else:
print(f"{new_word}", end=" ", flush=True)
part_str = decode_tokens(tokenizer, res_tokens)
yield part_str
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_token.tile(eos_token_id_tensor.shape[0], 1)
.ne(eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
if unfinished_sequences.max() == 0 or stopping_criteria(
input_ids, scores
):
break
timesRan = timesRan + 1
end = time.time()
print(
"\n\nTime taken is {:.2f} seconds/token\n".format(
(end - start) / timesRan
)
)
torch.cuda.empty_cache()
gc.collect()
res_str = decode_tokens(tokenizer, res_tokens)
yield res_str
def pad_or_truncate_inputs(
@@ -118,14 +493,14 @@ def pad_or_truncate_inputs(
num_add_token = max_padding_length - inp_shape[1]
padded_input_ids = torch.cat(
[
torch.tensor([[11] * num_add_token]).to(device=global_device),
torch.tensor([[11] * num_add_token]).to(device=tensor_device),
input_ids,
],
dim=1,
)
padded_attention_mask = torch.cat(
[
torch.tensor([[0] * num_add_token]).to(device=global_device),
torch.tensor([[0] * num_add_token]).to(device=tensor_device),
attention_mask,
],
dim=1,
@@ -319,232 +694,6 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
)
return records
def generate_new_token(self):
model_inputs = self.model.prepare_inputs_for_generation(
self.input_ids, **self.model_kwargs
)
outputs = h2ogpt_model.forward(
model_inputs["input_ids"], model_inputs["attention_mask"]
)
if global_precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
next_token_scores = self.logits_processor(
self.input_ids, next_token_logits
)
next_token_scores = self.logits_warper(
self.input_ids, next_token_scores
)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if self.eos_token_id is not None:
if self.pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = (
next_token * self.unfinished_sequences
+ self.pad_token_id * (1 - self.unfinished_sequences)
)
self.input_ids = torch.cat(
[self.input_ids, next_token[:, None]], dim=-1
)
self.model_kwargs["past_key_values"] = None
if "attention_mask" in self.model_kwargs:
attention_mask = self.model_kwargs["attention_mask"]
self.model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
self.truncated_input_ids.append(self.input_ids[:, 0])
self.input_ids = self.input_ids[:, 1:]
self.model_kwargs["attention_mask"] = self.model_kwargs[
"attention_mask"
][:, 1:]
return next_token
def generate_token(self, **generate_kwargs):
self.truncated_input_ids = []
generation_config_ = GenerationConfig.from_model_config(
self.model.config
)
generation_config = copy.deepcopy(generation_config_)
self.model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
self.stopping_criteria = (
self.stopping_criteria
if self.stopping_criteria is not None
else StoppingCriteriaList()
)
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
self.model_kwargs,
) = self.model._prepare_model_inputs(
None, generation_config.bos_token_id, self.model_kwargs
)
batch_size = inputs_tensor.shape[0]
self.model_kwargs[
"output_attentions"
] = generation_config.output_attentions
self.model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
self.model_kwargs["use_cache"] = generation_config.use_cache
self.input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else self.model_kwargs.pop("input_ids")
)
input_ids_seq_length = self.input_ids.shape[-1]
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length
)
self.logits_processor = self.model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
self.stopping_criteria = self.model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=self.stopping_criteria,
)
self.logits_warper = self.model._get_logits_warper(generation_config)
(
self.input_ids,
self.model_kwargs,
) = self.model._expand_inputs_for_generation(
input_ids=self.input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=self.model.config.is_encoder_decoder, # False
**self.model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id_tensor = (
torch.tensor(eos_token_id).to(device=global_device)
if eos_token_id is not None
else None
)
self.pad_token_id = generation_config.pad_token_id
self.eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
output_attentions = generation_config.output_attentions # False
output_hidden_states = generation_config.output_hidden_states # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
self.scores = (
() if (return_dict_in_generate and output_scores) else None
)
decoder_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
cross_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
decoder_hidden_states = (
() if (return_dict_in_generate and output_hidden_states) else None
)
# keep track of which sequences are already finished
self.unfinished_sequences = torch.ones(
self.input_ids.shape[0],
dtype=torch.long,
device=self.input_ids.device,
)
timesRan = 0
import time
start = time.time()
print("\n")
while True:
next_token = self.generate_new_token()
new_word = self.tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
print(f"{new_word}", end="", flush=True)
# if eos_token was found in one sentence, set sentence to finished
if self.eos_token_id_tensor is not None:
self.unfinished_sequences = self.unfinished_sequences.mul(
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
.ne(self.eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
if (
self.unfinished_sequences.max() == 0
or self.stopping_criteria(self.input_ids, self.scores)
):
break
timesRan = timesRan + 1
end = time.time()
print(
"\n\nTime taken is {:.2f} seconds/token\n".format(
(end - start) / timesRan
)
)
self.input_ids = torch.cat(
[
torch.tensor(self.truncated_input_ids)
.to(device=global_device)
.unsqueeze(dim=0),
self.input_ids,
],
dim=-1,
)
torch.cuda.empty_cache()
gc.collect()
return self.input_ids
def _forward(self, model_inputs, **generate_kwargs):
if self.can_stop:
stopping_criteria = get_stopping(
@@ -604,32 +753,13 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
input_ids, attention_mask = pad_or_truncate_inputs(
input_ids, attention_mask, max_padding_length=max_padding_length
)
self.stopping_criteria = generate_kwargs["stopping_criteria"]
generated_sequence = self.generate_token(
input_ids=input_ids,
attention_mask=attention_mask,
**generate_kwargs,
)
out_b = generated_sequence.shape[0]
if self.framework == "pt":
generated_sequence = generated_sequence.reshape(
in_b, out_b // in_b, *generated_sequence.shape[1:]
)
elif self.framework == "tf":
from transformers import is_tf_available
if is_tf_available():
import tensorflow as tf
generated_sequence = tf.reshape(
generated_sequence,
(in_b, out_b // in_b, *generated_sequence.shape[1:]),
)
else:
raise ValueError("TF not avaialble.")
return {
"generated_sequence": generated_sequence,
return_dict = {
"model": self.model,
"tokenizer": self.tokenizer,
"input_ids": input_ids,
"prompt_text": prompt_text,
"attention_mask": attention_mask,
"attention_mask": attention_mask,
}
return_dict = {**return_dict, **generate_kwargs}
return return_dict

View File

@@ -1,12 +1,10 @@
# for generate (gradio server) and finetune
datasets==2.13.0
sentencepiece==0.1.99
gradio==3.35.2
huggingface_hub==0.15.1
huggingface_hub==0.16.4
appdirs==1.4.4
fire==0.5.0
docutils==0.20.1
torch==2.0.1
evaluate==0.4.0
rouge_score==0.1.2
sacrebleu==2.3.1
@@ -19,7 +17,8 @@ matplotlib==3.7.1
loralib==0.1.1
bitsandbytes==0.39.0
accelerate==0.20.3
git+https://github.com/huggingface/peft.git@0b62b4378b4ce9367932c73540349da9a41bdea8
peft==0.4.0
# 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026)
transformers==4.30.2
tokenizers==0.13.3
APScheduler==3.10.1
@@ -35,7 +34,7 @@ tensorboard==2.13.0
neptune==1.2.0
# for gradio client
gradio_client==0.2.7
gradio_client==0.2.10
beautifulsoup4==4.12.2
markdown==3.4.3
@@ -45,8 +44,9 @@ pytest-xdist==3.2.1
nltk==3.8.1
textstat==0.7.3
# pandoc==2.3
#pypandoc==1.11
pypandoc_binary==1.11
pypandoc==1.11; sys_platform == "darwin" and platform_machine == "arm64"
pypandoc_binary==1.11; platform_machine == "x86_64"
pypandoc_binary==1.11; sys_platform == "win32"
openpyxl==3.1.2
lm_dataformat==0.0.20
bioc==2.0
@@ -63,3 +63,58 @@ text-generation==0.6.0
tiktoken==0.4.0
# optional: for OpenAI endpoint or embeddings (requires key)
openai==0.27.8
# optional for chat with PDF
langchain==0.0.329
pypdf==3.17.0
# avoid textract, requires old six
#textract==1.6.5
# for HF embeddings
sentence_transformers==2.2.2
# local vector db
chromadb==0.3.25
# server vector db
#pymilvus==2.2.8
# weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
# unstructured==0.8.1
# strong support for images
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
unstructured[local-inference]==0.7.4
#pdf2image==1.16.3
#pytesseract==0.3.10
pillow
pdfminer.six==20221105
urllib3
requests_file
#pdf2image==1.16.3
#pytesseract==0.3.10
tabulate==0.9.0
# FYI pandoc already part of requirements.txt
# JSONLoader, but makes some trouble for some users
# jq==1.4.1
# to check licenses
# Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
pip-licenses==4.3.0
# weaviate vector db
weaviate-client==3.22.1
gpt4all==1.0.5
llama-cpp-python==0.1.73
arxiv==1.4.8
pymupdf==1.22.5 # AGPL license
# extract-msg==0.41.1 # GPL3
# sometimes unstructured fails, these work in those cases. See https://github.com/h2oai/h2ogpt/issues/320
playwright==1.36.0
# requires Chrome binary to be in path
selenium==4.10.0

View File

@@ -1,5 +1,4 @@
import os
import fire
from gpt_langchain import (
path_to_docs,
@@ -202,7 +201,3 @@ def make_db_main(
if verbose:
print("DONE", flush=True)
return db, collection_name
if __name__ == "__main__":
fire.Fire(make_db_main)

View File

@@ -0,0 +1,442 @@
from pathlib import Path
import argparse
from argparse import RawTextHelpFormatter
import re, gc
"""
This script can be used as a standalone utility to convert IRs to dynamic + combine them.
Following are the various ways this script can be used :-
a. To convert a single Linalg IR to dynamic IR:
--dynamic --first_ir_path=<PATH TO FIRST IR>
b. To convert two Linalg IRs to dynamic IR:
--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>
c. To combine two Linalg IRs into one:
--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
d. To convert both IRs into dynamic as well as combine the IRs:
--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
NOTE: For dynamic you'll also need to provide the following set of flags:-
i. For First Llama : --dynamic_input_size (DEFAULT: 19)
ii. For Second Llama: --model_name (DEFAULT: llama2_7b)
--precision (DEFAULT: 'int4')
You may use --save_dynamic to also save the dynamic IR in option d above.
Else for option a. and b. the dynamic IR(s) will get saved by default.
"""
def combine_mlir_scripts(
first_vicuna_mlir,
second_vicuna_mlir,
output_name,
return_ir=True,
):
print(f"[DEBUG] combining first and second mlir")
print(f"[DEBUG] output_name = {output_name}")
maps1 = []
maps2 = []
constants = set()
f1 = []
f2 = []
print(f"[DEBUG] processing first vicuna mlir")
first_vicuna_mlir = first_vicuna_mlir.splitlines()
while first_vicuna_mlir:
line = first_vicuna_mlir.pop(0)
if re.search("#map\d*\s*=", line):
maps1.append(line)
elif re.search("arith.constant", line):
constants.add(line)
elif not re.search("module", line):
line = re.sub("forward", "first_vicuna_forward", line)
f1.append(line)
f1 = f1[:-1]
del first_vicuna_mlir
gc.collect()
for i, map_line in enumerate(maps1):
map_var = map_line.split(" ")[0]
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line)
maps1[i] = map_line
f1 = [
re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line)
for func_line in f1
]
print(f"[DEBUG] processing second vicuna mlir")
second_vicuna_mlir = second_vicuna_mlir.splitlines()
while second_vicuna_mlir:
line = second_vicuna_mlir.pop(0)
if re.search("#map\d*\s*=", line):
maps2.append(line)
elif "global_seed" in line:
continue
elif re.search("arith.constant", line):
constants.add(line)
elif not re.search("module", line):
line = re.sub("forward", "second_vicuna_forward", line)
f2.append(line)
f2 = f2[:-1]
del second_vicuna_mlir
gc.collect()
for i, map_line in enumerate(maps2):
map_var = map_line.split(" ")[0]
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line)
maps2[i] = map_line
f2 = [
re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line)
for func_line in f2
]
module_start = 'module attributes {torch.debug_module_name = "_lambda"} {'
module_end = "}"
global_vars = []
vnames = []
global_var_loading1 = []
global_var_loading2 = []
print(f"[DEBUG] processing constants")
counter = 0
constants = list(constants)
while constants:
constant = constants.pop(0)
vname, vbody = constant.split("=")
vname = re.sub("%", "", vname)
vname = vname.strip()
vbody = re.sub("arith.constant", "", vbody)
vbody = vbody.strip()
if len(vbody.split(":")) < 2:
print(constant)
vdtype = vbody.split(":")[-1].strip()
fixed_vdtype = vdtype
if "c1_i64" in vname:
print(constant)
counter += 1
if counter == 2:
counter = 0
print("detected duplicate")
continue
vnames.append(vname)
if "true" not in vname:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
)
else:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : i1"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
)
new_f1, new_f2 = [], []
print(f"[DEBUG] processing f1")
for line in f1:
if "func.func" in line:
new_f1.append(line)
for global_var in global_var_loading1:
new_f1.append(global_var)
else:
new_f1.append(line)
print(f"[DEBUG] processing f2")
for line in f2:
if "func.func" in line:
new_f2.append(line)
for global_var in global_var_loading2:
if (
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
in global_var
):
print(global_var)
new_f2.append(global_var)
else:
new_f2.append(line)
f1 = new_f1
f2 = new_f2
del new_f1
del new_f2
gc.collect()
print(
[
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x
for x in [maps1, maps2, global_vars, f1, f2]
]
)
# doing it this way rather than assembling the whole string
# to prevent OOM with 64GiB RAM when encoding the file.
print(f"[DEBUG] Saving mlir to {output_name}")
with open(output_name, "w+") as f_:
f_.writelines(line + "\n" for line in maps1)
f_.writelines(line + "\n" for line in maps2)
f_.writelines(line + "\n" for line in [module_start])
f_.writelines(line + "\n" for line in global_vars)
f_.writelines(line + "\n" for line in f1)
f_.writelines(line + "\n" for line in f2)
f_.writelines(line + "\n" for line in [module_end])
del maps1
del maps2
del module_start
del global_vars
del f1
del f2
del module_end
gc.collect()
if return_ir:
print(f"[DEBUG] Reading combined mlir back in")
with open(output_name, "rb") as f:
return f.read()
def write_in_dynamic_inputs0(module, dynamic_input_size):
print("[DEBUG] writing dynamic inputs to first vicuna")
# Current solution for ensuring mlir files support dynamic inputs
# TODO: find a more elegant way to implement this
new_lines = []
module = module.splitlines()
while module:
line = module.pop(0)
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim", line)
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
new_lines.append("%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>")
if "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" in line:
continue
new_lines.append(line)
return "\n".join(new_lines)
def write_in_dynamic_inputs1(module, model_name, precision):
print("[DEBUG] writing dynamic inputs to second vicuna")
def remove_constant_dim(line):
if "c19_i64" in line:
line = re.sub("c19_i64", "dim_i64", line)
if "19x" in line:
line = re.sub("19x", "?x", line)
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)",
"tensor.empty(%dim, %dim)",
line,
)
if "arith.cmpi" in line:
line = re.sub("c19", "dim", line)
if " 19," in line:
line = re.sub(" 19,", " %dim,", line)
if "x20x" in line or "<20x" in line:
line = re.sub("20x", "?x", line)
line = re.sub("tensor.empty\(\)", "tensor.empty(%dimp1)", line)
if " 20," in line:
line = re.sub(" 20,", " %dimp1,", line)
return line
module = module.splitlines()
new_lines = []
# Using a while loop and the pop method to avoid creating a copy of module
if "llama2_13b" in model_name:
pkv_tensor_shape = "tensor<1x40x?x128x"
elif "llama2_70b" in model_name:
pkv_tensor_shape = "tensor<1x8x?x128x"
else:
pkv_tensor_shape = "tensor<1x32x?x128x"
if precision in ["fp16", "int4", "int8"]:
pkv_tensor_shape += "f16>"
else:
pkv_tensor_shape += "f32>"
while module:
line = module.pop(0)
if "%c19_i64 = arith.constant 19 : i64" in line:
new_lines.append("%c2 = arith.constant 2 : index")
new_lines.append(
f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
)
new_lines.append(
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
)
continue
if "%c2 = arith.constant 2 : index" in line:
continue
if "%c20_i64 = arith.constant 20 : i64" in line:
new_lines.append("%c1_i64 = arith.constant 1 : i64")
new_lines.append("%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64")
new_lines.append(
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
)
continue
line = remove_constant_dim(line)
new_lines.append(line)
return "\n".join(new_lines)
def save_dynamic_ir(ir_to_save, output_file):
if not ir_to_save:
return
# We only get string output from the dynamic conversion utility.
from contextlib import redirect_stdout
with open(output_file, "w") as f:
with redirect_stdout(f):
print(ir_to_save)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="llama ir utility",
description="\tThis script can be used as a standalone utility to convert IRs to dynamic + combine them.\n"
+ "\tFollowing are the various ways this script can be used :-\n"
+ "\t\ta. To convert a single Linalg IR to dynamic IR:\n"
+ "\t\t\t--dynamic --first_ir_path=<PATH TO FIRST IR>\n"
+ "\t\tb. To convert two Linalg IRs to dynamic IR:\n"
+ "\t\t\t--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>\n"
+ "\t\tc. To combine two Linalg IRs into one:\n"
+ "\t\t\t--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n"
+ "\t\td. To convert both IRs into dynamic as well as combine the IRs:\n"
+ "\t\t\t--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n\n"
+ "\tNOTE: For dynamic you'll also need to provide the following set of flags:-\n"
+ "\t\t i. For First Llama : --dynamic_input_size (DEFAULT: 19)\n"
+ "\t\tii. For Second Llama: --model_name (DEFAULT: llama2_7b)\n"
+ "\t\t\t--precision (DEFAULT: 'int4')\n"
+ "\t You may use --save_dynamic to also save the dynamic IR in option d above.\n"
+ "\t Else for option a. and b. the dynamic IR(s) will get saved by default.\n",
formatter_class=RawTextHelpFormatter,
)
parser.add_argument(
"--precision",
"-p",
default="int4",
choices=["fp32", "fp16", "int8", "int4"],
help="Precision of the concerned IR",
)
parser.add_argument(
"--model_name",
type=str,
default="llama2_7b",
choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
help="Specify which model to run.",
)
parser.add_argument(
"--first_ir_path",
default=None,
help="path to first llama mlir file",
)
parser.add_argument(
"--second_ir_path",
default=None,
help="path to second llama mlir file",
)
parser.add_argument(
"--dynamic_input_size",
type=int,
default=19,
help="Specify the static input size to replace with dynamic dim.",
)
parser.add_argument(
"--dynamic",
default=False,
action=argparse.BooleanOptionalAction,
help="Converts the IR(s) to dynamic",
)
parser.add_argument(
"--save_dynamic",
default=False,
action=argparse.BooleanOptionalAction,
help="Save the individual IR(s) after converting to dynamic",
)
parser.add_argument(
"--combine",
default=False,
action=argparse.BooleanOptionalAction,
help="Converts the IR(s) to dynamic",
)
args, unknown = parser.parse_known_args()
dynamic = args.dynamic
combine = args.combine
assert (
dynamic or combine
), "neither `dynamic` nor `combine` flag is turned on"
first_ir_path = args.first_ir_path
second_ir_path = args.second_ir_path
assert first_ir_path or second_ir_path, "no input ir has been provided"
if combine:
assert (
first_ir_path and second_ir_path
), "you will need to provide both IRs to combine"
precision = args.precision
model_name = args.model_name
dynamic_input_size = args.dynamic_input_size
save_dynamic = args.save_dynamic
print(f"Dynamic conversion utility is turned {'ON' if dynamic else 'OFF'}")
print(f"Combining IR utility is turned {'ON' if combine else 'OFF'}")
if dynamic and not combine:
save_dynamic = True
first_ir = None
first_dynamic_ir_name = None
second_ir = None
second_dynamic_ir_name = None
if first_ir_path:
first_dynamic_ir_name = f"{Path(first_ir_path).stem}_dynamic"
with open(first_ir_path, "r") as f:
first_ir = f.read()
if second_ir_path:
second_dynamic_ir_name = f"{Path(second_ir_path).stem}_dynamic"
with open(second_ir_path, "r") as f:
second_ir = f.read()
if dynamic:
first_ir = (
write_in_dynamic_inputs0(first_ir, dynamic_input_size)
if first_ir
else None
)
second_ir = (
write_in_dynamic_inputs1(second_ir, model_name, precision)
if second_ir
else None
)
if save_dynamic:
save_dynamic_ir(first_ir, f"{first_dynamic_ir_name}.mlir")
save_dynamic_ir(second_ir, f"{second_dynamic_ir_name}.mlir")
if combine:
combine_mlir_scripts(
first_ir,
second_ir,
f"{model_name}_{precision}.mlir",
return_ir=False,
)

View File

@@ -46,6 +46,7 @@ def compile_stableLM(
model_vmfb_name,
device="cuda",
precision="fp32",
debug=False,
):
from shark.shark_inference import SharkInference
@@ -92,7 +93,7 @@ def compile_stableLM(
shark_module.compile()
path = shark_module.save_module(
vmfb_path.parent.absolute(), vmfb_path.stem
vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug
)
print("Saved vmfb at ", str(path))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,94 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import collect_submodules
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += copy_metadata('huggingface-hub')
datas += copy_metadata('sentencepiece')
datas += copy_metadata("pyyaml")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('opencv-python')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('py-cpuinfo')
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")
binaries = []
block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
a = Analysis(
['scripts/vicuna.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_llama_cli',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -0,0 +1,673 @@
import torch
from typing import Optional, Tuple
class WordEmbeddingsLayer(torch.nn.Module):
def __init__(self, word_embedding_layer):
super().__init__()
self.model = word_embedding_layer
def forward(self, input_ids):
output = self.model.forward(input=input_ids)
return output
class CompiledWordEmbeddingsLayer(torch.nn.Module):
def __init__(self, compiled_word_embedding_layer):
super().__init__()
self.model = compiled_word_embedding_layer
def forward(self, input_ids):
input_ids = input_ids.detach().numpy()
new_input_ids = self.model("forward", input_ids)
new_input_ids = new_input_ids.reshape(
[1, new_input_ids.shape[0], new_input_ids.shape[1]]
)
return torch.tensor(new_input_ids)
class LNFEmbeddingLayer(torch.nn.Module):
def __init__(self, ln_f):
super().__init__()
self.model = ln_f
def forward(self, hidden_states):
output = self.model.forward(input=hidden_states)
return output
class CompiledLNFEmbeddingLayer(torch.nn.Module):
def __init__(self, ln_f):
super().__init__()
self.model = ln_f
def forward(self, hidden_states):
hidden_states = hidden_states.detach().numpy()
new_hidden_states = self.model("forward", (hidden_states,))
return torch.tensor(new_hidden_states)
class LMHeadEmbeddingLayer(torch.nn.Module):
def __init__(self, embedding_layer):
super().__init__()
self.model = embedding_layer
def forward(self, hidden_states):
output = self.model.forward(input=hidden_states)
return output
class CompiledLMHeadEmbeddingLayer(torch.nn.Module):
def __init__(self, lm_head):
super().__init__()
self.model = lm_head
def forward(self, hidden_states):
hidden_states = hidden_states.detach().numpy()
new_hidden_states = self.model("forward", (hidden_states,))
return torch.tensor(new_hidden_states)
class FourWayShardingDecoderLayer(torch.nn.Module):
def __init__(self, decoder_layer_model, falcon_variant):
super().__init__()
self.model = decoder_layer_model
self.falcon_variant = falcon_variant
def forward(self, hidden_states, attention_mask):
new_pkvs = []
for layer in self.model:
outputs = layer(
hidden_states=hidden_states,
alibi=None,
attention_mask=attention_mask,
use_cache=True,
)
hidden_states = outputs[0]
new_pkvs.append(
(
outputs[-1][0],
outputs[-1][1],
)
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
(new_pkv150, new_pkv151),
(new_pkv160, new_pkv161),
(new_pkv170, new_pkv171),
(new_pkv180, new_pkv181),
(new_pkv190, new_pkv191),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
new_pkv150,
new_pkv151,
new_pkv160,
new_pkv161,
new_pkv170,
new_pkv171,
new_pkv180,
new_pkv181,
new_pkv190,
new_pkv191,
)
return result
class CompiledFourWayShardingDecoderLayer(torch.nn.Module):
def __init__(
self, layer_id, device_idx, falcon_variant, device, precision, model
):
super().__init__()
self.layer_id = layer_id
self.device_index = device_idx
self.falcon_variant = falcon_variant
self.device = device
self.precision = precision
self.model = model
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
alibi: torch.Tensor = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
import gc
torch.cuda.empty_cache()
gc.collect()
if self.model is None:
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")
else:
output = self.model(
"forward",
(
hidden_states,
attention_mask,
),
)
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
(
torch.tensor(output[31]),
torch.tensor(output[32]),
),
(
torch.tensor(output[33]),
torch.tensor(output[34]),
),
(
torch.tensor(output[35]),
torch.tensor(output[36]),
),
(
torch.tensor(output[37]),
torch.tensor(output[38]),
),
(
torch.tensor(output[39]),
torch.tensor(output[40]),
),
)
return result
class TwoWayShardingDecoderLayer(torch.nn.Module):
def __init__(self, decoder_layer_model, falcon_variant):
super().__init__()
self.model = decoder_layer_model
self.falcon_variant = falcon_variant
def forward(self, hidden_states, attention_mask):
new_pkvs = []
for layer in self.model:
outputs = layer(
hidden_states=hidden_states,
alibi=None,
attention_mask=attention_mask,
use_cache=True,
)
hidden_states = outputs[0]
new_pkvs.append(
(
outputs[-1][0],
outputs[-1][1],
)
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
(new_pkv150, new_pkv151),
(new_pkv160, new_pkv161),
(new_pkv170, new_pkv171),
(new_pkv180, new_pkv181),
(new_pkv190, new_pkv191),
(new_pkv200, new_pkv201),
(new_pkv210, new_pkv211),
(new_pkv220, new_pkv221),
(new_pkv230, new_pkv231),
(new_pkv240, new_pkv241),
(new_pkv250, new_pkv251),
(new_pkv260, new_pkv261),
(new_pkv270, new_pkv271),
(new_pkv280, new_pkv281),
(new_pkv290, new_pkv291),
(new_pkv300, new_pkv301),
(new_pkv310, new_pkv311),
(new_pkv320, new_pkv321),
(new_pkv330, new_pkv331),
(new_pkv340, new_pkv341),
(new_pkv350, new_pkv351),
(new_pkv360, new_pkv361),
(new_pkv370, new_pkv371),
(new_pkv380, new_pkv381),
(new_pkv390, new_pkv391),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
new_pkv150,
new_pkv151,
new_pkv160,
new_pkv161,
new_pkv170,
new_pkv171,
new_pkv180,
new_pkv181,
new_pkv190,
new_pkv191,
new_pkv200,
new_pkv201,
new_pkv210,
new_pkv211,
new_pkv220,
new_pkv221,
new_pkv230,
new_pkv231,
new_pkv240,
new_pkv241,
new_pkv250,
new_pkv251,
new_pkv260,
new_pkv261,
new_pkv270,
new_pkv271,
new_pkv280,
new_pkv281,
new_pkv290,
new_pkv291,
new_pkv300,
new_pkv301,
new_pkv310,
new_pkv311,
new_pkv320,
new_pkv321,
new_pkv330,
new_pkv331,
new_pkv340,
new_pkv341,
new_pkv350,
new_pkv351,
new_pkv360,
new_pkv361,
new_pkv370,
new_pkv371,
new_pkv380,
new_pkv381,
new_pkv390,
new_pkv391,
)
return result
class CompiledTwoWayShardingDecoderLayer(torch.nn.Module):
def __init__(
self, layer_id, device_idx, falcon_variant, device, precision, model
):
super().__init__()
self.layer_id = layer_id
self.device_index = device_idx
self.falcon_variant = falcon_variant
self.device = device
self.precision = precision
self.model = model
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
alibi: torch.Tensor = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
import gc
torch.cuda.empty_cache()
gc.collect()
if self.model is None:
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")
else:
output = self.model(
"forward",
(
hidden_states,
attention_mask,
),
)
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
(
torch.tensor(output[31]),
torch.tensor(output[32]),
),
(
torch.tensor(output[33]),
torch.tensor(output[34]),
),
(
torch.tensor(output[35]),
torch.tensor(output[36]),
),
(
torch.tensor(output[37]),
torch.tensor(output[38]),
),
(
torch.tensor(output[39]),
torch.tensor(output[40]),
),
(
torch.tensor(output[41]),
torch.tensor(output[42]),
),
(
torch.tensor(output[43]),
torch.tensor(output[44]),
),
(
torch.tensor(output[45]),
torch.tensor(output[46]),
),
(
torch.tensor(output[47]),
torch.tensor(output[48]),
),
(
torch.tensor(output[49]),
torch.tensor(output[50]),
),
(
torch.tensor(output[51]),
torch.tensor(output[52]),
),
(
torch.tensor(output[53]),
torch.tensor(output[54]),
),
(
torch.tensor(output[55]),
torch.tensor(output[56]),
),
(
torch.tensor(output[57]),
torch.tensor(output[58]),
),
(
torch.tensor(output[59]),
torch.tensor(output[60]),
),
(
torch.tensor(output[61]),
torch.tensor(output[62]),
),
(
torch.tensor(output[63]),
torch.tensor(output[64]),
),
(
torch.tensor(output[65]),
torch.tensor(output[66]),
),
(
torch.tensor(output[67]),
torch.tensor(output[68]),
),
(
torch.tensor(output[69]),
torch.tensor(output[70]),
),
(
torch.tensor(output[71]),
torch.tensor(output[72]),
),
(
torch.tensor(output[73]),
torch.tensor(output[74]),
),
(
torch.tensor(output[75]),
torch.tensor(output[76]),
),
(
torch.tensor(output[77]),
torch.tensor(output[78]),
),
(
torch.tensor(output[79]),
torch.tensor(output[80]),
),
)
return result
class ShardedFalconModel:
def __init__(self, model, layers, word_embeddings, ln_f, lm_head):
super().__init__()
self.model = model
self.model.transformer.h = torch.nn.modules.container.ModuleList(
layers
)
self.model.transformer.word_embeddings = word_embeddings
self.model.transformer.ln_f = ln_f
self.model.lm_head = lm_head
def forward(
self,
input_ids,
attention_mask=None,
):
return self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
).logits[:, -1, :]

View File

@@ -0,0 +1,503 @@
import torch
import dataclasses
from enum import auto, Enum
from typing import List, Any
from transformers import StoppingCriteria
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
class LayerNorm(torch.nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class VisionModel(torch.nn.Module):
def __init__(
self,
ln_vision,
visual_encoder,
precision="fp32",
weight_group_size=128,
):
super().__init__()
self.ln_vision = ln_vision
self.visual_encoder = visual_encoder
if precision in ["int4", "int8"]:
print("Vision Model applying weight quantization to ln_vision")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
self.ln_vision,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
print(
"Vision Model applying weight quantization to visual_encoder"
)
quantize_model(
self.visual_encoder,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(self, image):
image_embeds = self.ln_vision(self.visual_encoder(image))
return image_embeds
class QformerBertModel(torch.nn.Module):
def __init__(self, qformer_bert):
super().__init__()
self.qformer_bert = qformer_bert
def forward(self, query_tokens, image_embeds, image_atts):
query_output = self.qformer_bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
return query_output.last_hidden_state
class FirstLlamaModel(torch.nn.Module):
def __init__(self, model, precision="fp32", weight_group_size=128):
super().__init__()
self.model = model
print("SHARK: Loading LLAMA Done")
if precision in ["int4", "int8"]:
print("First Llama applying weight quantization")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
self.model,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(self, inputs_embeds, position_ids, attention_mask):
print("************************************")
print(
"inputs_embeds: ",
inputs_embeds.shape,
" dtype: ",
inputs_embeds.dtype,
)
print(
"position_ids: ",
position_ids.shape,
" dtype: ",
position_ids.dtype,
)
print(
"attention_mask: ",
attention_mask.shape,
" dtype: ",
attention_mask.dtype,
)
print("************************************")
config = {
"inputs_embeds": inputs_embeds,
"position_ids": position_ids,
"past_key_values": None,
"use_cache": True,
"attention_mask": attention_mask,
}
output = self.model(
**config,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
return_vals = []
return_vals.append(output.logits)
temp_past_key_values = output.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
class SecondLlamaModel(torch.nn.Module):
def __init__(self, model, precision="fp32", weight_group_size=128):
super().__init__()
self.model = model
print("SHARK: Loading LLAMA Done")
if precision in ["int4", "int8"]:
print("Second Llama applying weight quantization")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
self.model,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(
self,
input_ids,
position_ids,
attention_mask,
i1,
i2,
i3,
i4,
i5,
i6,
i7,
i8,
i9,
i10,
i11,
i12,
i13,
i14,
i15,
i16,
i17,
i18,
i19,
i20,
i21,
i22,
i23,
i24,
i25,
i26,
i27,
i28,
i29,
i30,
i31,
i32,
i33,
i34,
i35,
i36,
i37,
i38,
i39,
i40,
i41,
i42,
i43,
i44,
i45,
i46,
i47,
i48,
i49,
i50,
i51,
i52,
i53,
i54,
i55,
i56,
i57,
i58,
i59,
i60,
i61,
i62,
i63,
i64,
):
print("************************************")
print("input_ids: ", input_ids.shape, " dtype: ", input_ids.dtype)
print(
"position_ids: ",
position_ids.shape,
" dtype: ",
position_ids.dtype,
)
print(
"attention_mask: ",
attention_mask.shape,
" dtype: ",
attention_mask.dtype,
)
print("past_key_values: ", i1.shape, i2.shape, i63.shape, i64.shape)
print("past_key_values dtype: ", i1.dtype)
print("************************************")
config = {
"input_ids": input_ids,
"position_ids": position_ids,
"past_key_values": (
(i1, i2),
(
i3,
i4,
),
(
i5,
i6,
),
(
i7,
i8,
),
(
i9,
i10,
),
(
i11,
i12,
),
(
i13,
i14,
),
(
i15,
i16,
),
(
i17,
i18,
),
(
i19,
i20,
),
(
i21,
i22,
),
(
i23,
i24,
),
(
i25,
i26,
),
(
i27,
i28,
),
(
i29,
i30,
),
(
i31,
i32,
),
(
i33,
i34,
),
(
i35,
i36,
),
(
i37,
i38,
),
(
i39,
i40,
),
(
i41,
i42,
),
(
i43,
i44,
),
(
i45,
i46,
),
(
i47,
i48,
),
(
i49,
i50,
),
(
i51,
i52,
),
(
i53,
i54,
),
(
i55,
i56,
),
(
i57,
i58,
),
(
i59,
i60,
),
(
i61,
i62,
),
(
i63,
i64,
),
),
"use_cache": True,
"attention_mask": attention_mask,
}
output = self.model(
**config,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
return_vals = []
return_vals.append(output.logits)
temp_past_key_values = output.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
skip_next: bool = False
conv_id: Any = None
def get_prompt(self):
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep
for role, message in self.messages:
if message:
ret += role + ": " + message + self.sep
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def append_message(self, role, message):
self.messages.append([role, message])
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
conv_id=self.conv_id,
)
def dict(self):
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
"conv_id": self.conv_id,
}
class StoppingCriteriaSub(StoppingCriteria):
def __init__(self, stops=[], encounters=1):
super().__init__()
self.stops = stops
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
for stop in self.stops:
if torch.all((stop == input_ids[0][-len(stop) :])).item():
return True
return False
CONV_VISION = Conversation(
system="Give the following image: <Img>ImageContent</Img>. "
"You will be able to see the image once I provide it to you. Please answer my questions.",
roles=("Human", "Assistant"),
messages=[],
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)

View File

@@ -0,0 +1,876 @@
import argparse
import json
import re
from io import BytesIO
from pathlib import Path
from tqdm import tqdm
from typing import List, Optional, Tuple, Union
import numpy as np
import iree.runtime
import itertools
import subprocess
import torch
import torch_mlir
from torch_mlir import TensorPlaceholder
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
LlamaPreTrainedModel,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
FirstVicunaLayer,
SecondVicunaLayer,
CompiledVicunaLayer,
ShardedVicunaModel,
LMHead,
LMHeadCompiled,
VicunaEmbedding,
VicunaEmbeddingCompiled,
VicunaNorm,
VicunaNormCompiled,
)
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna7B,
)
from apps.language_models.utils import (
get_vmfb_from_path,
)
from shark.shark_downloader import download_public_file
from shark.shark_importer import get_f16_inputs
from shark.shark_inference import SharkInference
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaRMSNorm,
_make_causal_mask,
_expand_mask,
)
from torch import nn
from time import time
class LlamaModel(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
LlamaDecoderLayer(config)
for _ in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
t1 = time()
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = (
use_cache if use_cache is not None else self.config.use_cache
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = (
seq_length_with_past + past_key_values_length
)
if position_ids is None:
device = (
input_ids.device
if input_ids is not None
else inputs_embeds.device
)
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.compressedlayers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (
past_key_values[8 * idx : 8 * (idx + 1)]
if past_key_values is not None
else None
)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer.forward(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[1:],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
try:
hidden_states = np.asarray(hidden_states, hidden_states.dtype)
except:
_ = 10
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
next_cache = tuple(itertools.chain.from_iterable(next_cache))
print(f"Token generated in {time() - t1} seconds")
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class EightLayerLayerSV(torch.nn.Module):
def __init__(self, layers):
super().__init__()
assert len(layers) == 8
self.layers = layers
def forward(
self,
hidden_states,
attention_mask,
position_ids,
pkv00,
pkv01,
pkv10,
pkv11,
pkv20,
pkv21,
pkv30,
pkv31,
pkv40,
pkv41,
pkv50,
pkv51,
pkv60,
pkv61,
pkv70,
pkv71,
):
pkvs = [
(pkv00, pkv01),
(pkv10, pkv11),
(pkv20, pkv21),
(pkv30, pkv31),
(pkv40, pkv41),
(pkv50, pkv51),
(pkv60, pkv61),
(pkv70, pkv71),
]
new_pkvs = []
for layer, pkv in zip(self.layers, pkvs):
outputs = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=(
pkv[0],
pkv[1],
),
use_cache=True,
)
hidden_states = outputs[0]
new_pkvs.append(
(
outputs[-1][0],
outputs[-1][1],
)
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
) = new_pkvs
return (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
)
class EightLayerLayerFV(torch.nn.Module):
def __init__(self, layers):
super().__init__()
assert len(layers) == 8
self.layers = layers
def forward(self, hidden_states, attention_mask, position_ids):
new_pkvs = []
for layer in self.layers:
outputs = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=None,
use_cache=True,
)
hidden_states = outputs[0]
new_pkvs.append(
(
outputs[-1][0],
outputs[-1][1],
)
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
) = new_pkvs
return (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
)
class CompiledEightLayerLayerSV(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
(
(pkv00, pkv01),
(pkv10, pkv11),
(pkv20, pkv21),
(pkv30, pkv31),
(pkv40, pkv41),
(pkv50, pkv51),
(pkv60, pkv61),
(pkv70, pkv71),
) = past_key_value
pkv00 = pkv00.detatch()
pkv01 = pkv01.detatch()
pkv10 = pkv10.detatch()
pkv11 = pkv11.detatch()
pkv20 = pkv20.detatch()
pkv21 = pkv21.detatch()
pkv30 = pkv30.detatch()
pkv31 = pkv31.detatch()
pkv40 = pkv40.detatch()
pkv41 = pkv41.detatch()
pkv50 = pkv50.detatch()
pkv51 = pkv51.detatch()
pkv60 = pkv60.detatch()
pkv61 = pkv61.detatch()
pkv70 = pkv70.detatch()
pkv71 = pkv71.detatch()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
pkv00,
pkv01,
pkv10,
pkv11,
pkv20,
pkv21,
pkv30,
pkv31,
pkv40,
pkv41,
pkv50,
pkv51,
pkv60,
pkv61,
pkv70,
pkv71,
),
send_to_host=False,
)
return (
output[0],
(output[1][0], output[1][1]),
(output[2][0], output[2][1]),
(output[3][0], output[3][1]),
(output[4][0], output[4][1]),
(output[5][0], output[5][1]),
(output[6][0], output[6][1]),
(output[7][0], output[7][1]),
(output[8][0], output[8][1]),
)
def forward_compressed(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = (
input_ids.device if input_ids is not None else inputs_embeds.device
)
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.compressedlayers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (
past_key_values[8 * idx : 8 * (idx + 1)]
if past_key_values is not None
else None
)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (
layer_outputs[2 if output_attentions else 1],
)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class CompiledEightLayerLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value=None,
output_attentions=False,
use_cache=True,
):
t2 = time()
if past_key_value is None:
try:
hidden_states = np.asarray(hidden_states, hidden_states.dtype)
except:
pass
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
t1 = time()
output = self.model(
"first_vicuna_forward",
(hidden_states, attention_mask, position_ids),
send_to_host=False,
)
output2 = (
output[0],
(
output[1],
output[2],
),
(
output[3],
output[4],
),
(
output[5],
output[6],
),
(
output[7],
output[8],
),
(
output[9],
output[10],
),
(
output[11],
output[12],
),
(
output[13],
output[14],
),
(
output[15],
output[16],
),
)
return output2
else:
(
(pkv00, pkv01),
(pkv10, pkv11),
(pkv20, pkv21),
(pkv30, pkv31),
(pkv40, pkv41),
(pkv50, pkv51),
(pkv60, pkv61),
(pkv70, pkv71),
) = past_key_value
try:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv00 = pkv00.detach()
pkv01 = pkv01.detach()
pkv10 = pkv10.detach()
pkv11 = pkv11.detach()
pkv20 = pkv20.detach()
pkv21 = pkv21.detach()
pkv30 = pkv30.detach()
pkv31 = pkv31.detach()
pkv40 = pkv40.detach()
pkv41 = pkv41.detach()
pkv50 = pkv50.detach()
pkv51 = pkv51.detach()
pkv60 = pkv60.detach()
pkv61 = pkv61.detach()
pkv70 = pkv70.detach()
pkv71 = pkv71.detach()
except:
x = 10
t1 = time()
if type(hidden_states) == iree.runtime.array_interop.DeviceArray:
hidden_states = np.array(hidden_states, hidden_states.dtype)
hidden_states = torch.tensor(hidden_states)
hidden_states = hidden_states.detach()
output = self.model(
"second_vicuna_forward",
(
hidden_states,
attention_mask,
position_ids,
pkv00,
pkv01,
pkv10,
pkv11,
pkv20,
pkv21,
pkv30,
pkv31,
pkv40,
pkv41,
pkv50,
pkv51,
pkv60,
pkv61,
pkv70,
pkv71,
),
send_to_host=False,
)
print(f"{time() - t1}")
del pkv00
del pkv01
del pkv10
del pkv11
del pkv20
del pkv21
del pkv30
del pkv31
del pkv40
del pkv41
del pkv50
del pkv51
del pkv60
del pkv61
del pkv70
del pkv71
output2 = (
output[0],
(
output[1],
output[2],
),
(
output[3],
output[4],
),
(
output[5],
output[6],
),
(
output[7],
output[8],
),
(
output[9],
output[10],
),
(
output[11],
output[12],
),
(
output[13],
output[14],
),
(
output[15],
output[16],
),
)
return output2

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -66,7 +66,7 @@ class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers, lmhead, embedding, norm):
super().__init__()
self.model = model
assert len(layers) == len(model.model.layers)
# assert len(layers) == len(model.model.layers)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers = layers
@@ -132,7 +132,10 @@ class VicunaNormCompiled(torch.nn.Module):
self.model = shark_module
def forward(self, hidden_states):
hidden_states.detach()
try:
hidden_states.detach()
except:
pass
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output

View File

@@ -3,7 +3,10 @@ from abc import ABC, abstractmethod
class SharkLLMBase(ABC):
def __init__(
self, model_name, hf_model_path=None, max_num_tokens=512
self,
model_name,
hf_model_path=None,
max_num_tokens=512,
) -> None:
self.model_name = model_name
self.hf_model_path = hf_model_path

View File

@@ -1,4 +1,17 @@
from apps.language_models.src.model_wrappers.falcon_model import FalconModel
from apps.language_models.src.model_wrappers.falcon_sharded_model import (
WordEmbeddingsLayer,
CompiledWordEmbeddingsLayer,
LNFEmbeddingLayer,
CompiledLNFEmbeddingLayer,
LMHeadEmbeddingLayer,
CompiledLMHeadEmbeddingLayer,
FourWayShardingDecoderLayer,
TwoWayShardingDecoderLayer,
CompiledFourWayShardingDecoderLayer,
CompiledTwoWayShardingDecoderLayer,
ShardedFalconModel,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.utils import (
get_vmfb_from_path,
@@ -7,30 +20,39 @@ from io import BytesIO
from pathlib import Path
from contextlib import redirect_stdout
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig
from transformers.generation import (
GenerationConfig,
LogitsProcessorList,
StoppingCriteriaList,
)
import copy
import time
import re
import torch
import torch_mlir
import os
import argparse
import gc
parser = argparse.ArgumentParser(
prog="falcon runner",
description="runs a falcon model",
)
parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
parser.add_argument(
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
)
parser.add_argument(
"--compressed",
default=False,
action=argparse.BooleanOptionalAction,
help="Do the compression of sharded layers",
)
parser.add_argument(
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
@@ -49,7 +71,7 @@ parser.add_argument(
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=False,
default=True,
action=argparse.BooleanOptionalAction,
help="download precompile mlir from shark tank",
)
@@ -59,32 +81,74 @@ parser.add_argument(
action=argparse.BooleanOptionalAction,
help="Run model in cli mode",
)
parser.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication token for falcon-180B model.",
)
parser.add_argument(
"-s",
"--sharded",
default=False,
action=argparse.BooleanOptionalAction,
help="Run model as sharded",
)
parser.add_argument(
"--num_shards",
type=int,
default=4,
choices=[2, 4],
help="Number of shards.",
)
class Falcon(SharkLLMBase):
class ShardedFalcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path,
hf_model_path="tiiuae/falcon-7b-instruct",
hf_auth_token: str = None,
max_num_tokens=150,
device="cuda",
precision="fp32",
falcon_mlir_path=None,
falcon_vmfb_path=None,
debug=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
print("hf_model_path: ", self.hf_model_path)
if (
"180b" in self.model_name
and precision != "int4"
and hf_auth_token == None
):
raise ValueError(
""" HF auth token required for falcon-180b. Pass it using
--hf_auth_token flag. You can ask for the access to the model
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
)
if args.sharded and "180b" not in self.model_name:
raise ValueError("Sharding supported only for Falcon-180B")
self.hf_auth_token = hf_auth_token
self.max_padding_length = 100
self.device = device
self.precision = precision
self.falcon_vmfb_path = falcon_vmfb_path
self.falcon_mlir_path = falcon_mlir_path
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
self.src_model = self.get_src_model()
self.shark_model = self.compile()
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path, trust_remote_code=True
self.hf_model_path,
trust_remote_code=True,
token=self.hf_auth_token,
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 11
@@ -92,13 +156,536 @@ class Falcon(SharkLLMBase):
def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
kwargs = {
"torch_dtype": torch.float32,
"trust_remote_code": True,
"token": self.hf_auth_token,
}
if self.precision == "int4":
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["device_map"] = "cpu"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
return falcon_model
def compile_falcon(self):
def compile_layer(
self, layer, falconCompileInput, layer_id, device_idx=None
):
self.falcon_mlir_path = Path(
f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}.mlir"
)
self.falcon_vmfb_path = Path(
f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}_{self.device}.vmfb"
)
if args.use_precompiled_model:
if not self.falcon_vmfb_path.exists():
# Downloading VMFB from shark_tank
print(f"[DEBUG] Trying to download vmfb from shark_tank")
download_public_file(
f"gs://shark_tank/falcon/sharded/falcon_{args.falcon_variant_to_use}/vmfb/"
+ str(self.falcon_vmfb_path),
self.falcon_vmfb_path.absolute(),
single_file=True,
)
vmfb = get_vmfb_from_path(
self.falcon_vmfb_path,
self.device,
"linalg",
device_id=device_idx,
)
if vmfb is not None:
return vmfb, device_idx
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
if self.falcon_mlir_path.exists():
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
print(
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
)
if args.load_mlir_from_shark_tank:
# Downloading MLIR from shark_tank
print(f"[DEBUG] Trying to download mlir from shark_tank")
download_public_file(
f"gs://shark_tank/falcon/sharded/falcon_{args.falcon_variant_to_use}/mlir/"
+ str(self.falcon_mlir_path),
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
print(
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
)
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
if not mlir_generated:
print(f"[DEBUG] generating MLIR locally")
if layer_id == "word_embeddings":
f16_input_mask = [False]
elif layer_id in ["ln_f", "lm_head"]:
f16_input_mask = [True]
elif "_" in layer_id or type(layer_id) == int:
f16_input_mask = [True, False]
else:
raise ValueError("Unsupported layer: ", layer_id)
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
layer,
falconCompileInput,
is_f16=True,
f16_input_mask=f16_input_mask,
mlir_type="torchscript",
is_gptq=True,
)
del layer
print(f"[DEBUG] generating torch mlir")
module = torch_mlir.compile(
ts_graph,
falconCompileInput,
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
print(f"[DEBUG] converting to bytecode")
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
del module
f_ = open(self.falcon_mlir_path, "wb")
f_.write(bytecode)
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
f_.close()
del bytecode
shark_module = SharkInference(
mlir_module=self.falcon_mlir_path,
device=self.device,
mlir_dialect="linalg",
device_idx=device_idx,
)
path = shark_module.save_module(
self.falcon_vmfb_path.parent.absolute(),
self.falcon_vmfb_path.stem,
extra_args=[
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
]
+ [
"--iree-llvmcpu-use-fast-min-max-ops",
]
if self.precision == "int4"
else [],
debug=self.debug,
)
print("Saved falcon vmfb at ", str(path))
shark_module.load_module(path)
return shark_module, device_idx
def compile(self):
sample_input_ids = torch.zeros([100], dtype=torch.int64)
sample_attention_mask = torch.zeros([1, 1, 100, 100], dtype=torch.bool)
num_group_layers = int(
20 * (4 / args.num_shards)
) # 4 is the number of default shards
sample_hidden_states = torch.zeros(
[1, 100, 14848], dtype=torch.float32
)
# Determine number of available devices
num_devices = 1
if self.device == "rocm":
import iree.runtime as ireert
haldriver = ireert.get_driver(self.device)
num_devices = len(haldriver.query_available_devices())
if num_devices < 2:
raise ValueError(
"Cannot run Falcon-180B on a single ROCM device."
)
lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
print("Compiling Layer lm_head")
shark_lm_head, _ = self.compile_layer(
lm_head,
[sample_hidden_states],
"lm_head",
device_idx=(0 % num_devices) % args.num_shards
if self.device == "rocm"
else None,
)
shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
word_embedding = WordEmbeddingsLayer(
self.src_model.transformer.word_embeddings
)
print("Compiling Layer word_embeddings")
shark_word_embedding, _ = self.compile_layer(
word_embedding,
[sample_input_ids],
"word_embeddings",
device_idx=(1 % num_devices) % args.num_shards
if self.device == "rocm"
else None,
)
shark_word_embedding = CompiledWordEmbeddingsLayer(
shark_word_embedding
)
ln_f = LNFEmbeddingLayer(self.src_model.transformer.ln_f)
print("Compiling Layer ln_f")
shark_ln_f, _ = self.compile_layer(
ln_f,
[sample_hidden_states],
"ln_f",
device_idx=(2 % num_devices) % args.num_shards
if self.device == "rocm"
else None,
)
shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
shark_layers = []
for i in range(
int(len(self.src_model.transformer.h) / num_group_layers)
):
device_idx = i % num_devices if self.device == "rocm" else None
layer_id = i
layer_id = (
str(i * num_group_layers)
+ "_"
+ str((i + 1) * num_group_layers)
)
pytorch_class = FourWayShardingDecoderLayer
compiled_class = CompiledFourWayShardingDecoderLayer
if args.num_shards == 2:
pytorch_class = TwoWayShardingDecoderLayer
compiled_class = CompiledTwoWayShardingDecoderLayer
print("Compiling Layer {}".format(layer_id))
layer_i = self.src_model.transformer.h[
i * num_group_layers : (i + 1) * num_group_layers
]
pytorch_layer_i = pytorch_class(
layer_i, args.falcon_variant_to_use
)
shark_module, device_idx = self.compile_layer(
pytorch_layer_i,
[sample_hidden_states, sample_attention_mask],
layer_id,
device_idx=device_idx,
)
shark_layer_i = compiled_class(
layer_id,
device_idx,
args.falcon_variant_to_use,
self.device,
self.precision,
shark_module,
)
shark_layers.append(shark_layer_i)
sharded_model = ShardedFalconModel(
self.src_model,
shark_layers,
shark_word_embedding,
shark_ln_f,
shark_lm_head,
)
return sharded_model
def generate(self, prompt):
model_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.max_padding_length,
add_special_tokens=False,
return_tensors="pt",
)
model_inputs["prompt_text"] = prompt
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs.get("attention_mask", None)
# Allow empty prompts
if input_ids.shape[1] == 0:
input_ids = None
attention_mask = None
generate_kwargs = {
"max_length": self.max_num_tokens,
"do_sample": True,
"top_k": 10,
"num_return_sequences": 1,
"eos_token_id": 11,
}
generate_kwargs["input_ids"] = input_ids
generate_kwargs["attention_mask"] = attention_mask
generation_config_ = GenerationConfig.from_model_config(
self.src_model.config
)
generation_config = copy.deepcopy(generation_config_)
model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
stopping_criteria = StoppingCriteriaList()
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
model_kwargs,
) = self.src_model._prepare_model_inputs(
None, generation_config.bos_token_id, model_kwargs
)
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else model_kwargs.pop("input_ids")
)
self.logits_processor = self.src_model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids.shape[-1],
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
self.stopping_criteria = self.src_model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria,
)
self.logits_warper = self.src_model._get_logits_warper(
generation_config
)
(
self.input_ids,
self.model_kwargs,
) = self.src_model._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=self.src_model.config.is_encoder_decoder, # False
**model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id_tensor = (
torch.tensor(eos_token_id) if eos_token_id is not None else None
)
self.pad_token_id = generation_config.pad_token_id
self.eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
self.scores = (
() if (return_dict_in_generate and output_scores) else None
)
# keep track of which sequences are already finished
self.unfinished_sequences = torch.ones(
input_ids.shape[0], dtype=torch.long, device=input_ids.device
)
all_text = prompt
start = time.time()
count = 0
for i in range(self.max_num_tokens - 1):
count = count + 1
next_token = self.generate_new_token()
new_word = self.tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
all_text = all_text + new_word
print(f"{new_word}", end="", flush=True)
print(f"{all_text}", end="", flush=True)
# if eos_token was found in one sentence, set sentence to finished
if self.eos_token_id_tensor is not None:
self.unfinished_sequences = self.unfinished_sequences.mul(
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
.ne(self.eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
if (
self.unfinished_sequences.max() == 0
or self.stopping_criteria(input_ids, self.scores)
):
break
end = time.time()
print(
"\n\nTime taken is {:.2f} seconds/token\n".format(
(end - start) / count
)
)
torch.cuda.empty_cache()
gc.collect()
return all_text
def generate_new_token(self):
model_inputs = self.src_model.prepare_inputs_for_generation(
self.input_ids, **self.model_kwargs
)
outputs = self.shark_model.forward(
input_ids=model_inputs["input_ids"],
attention_mask=model_inputs["attention_mask"],
)
if self.precision in ["fp16", "int4"]:
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
next_token_scores = self.logits_processor(
self.input_ids, next_token_logits
)
next_token_scores = self.logits_warper(
self.input_ids, next_token_scores
)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if self.eos_token_id is not None:
if self.pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = (
next_token * self.unfinished_sequences
+ self.pad_token_id * (1 - self.unfinished_sequences)
)
self.input_ids = torch.cat(
[self.input_ids, next_token[:, None]], dim=-1
)
self.model_kwargs["past_key_values"] = None
if "attention_mask" in self.model_kwargs:
attention_mask = self.model_kwargs["attention_mask"]
self.model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
self.input_ids = self.input_ids[:, 1:]
self.model_kwargs["attention_mask"] = self.model_kwargs[
"attention_mask"
][:, 1:]
return next_token
class UnshardedFalcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="tiiuae/falcon-7b-instruct",
hf_auth_token: str = "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",
max_num_tokens=150,
device="cuda",
precision="fp32",
falcon_mlir_path=None,
falcon_vmfb_path=None,
debug=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
print("hf_model_path: ", self.hf_model_path)
if "180b" in self.model_name and hf_auth_token == None:
raise ValueError(
""" HF auth token required for falcon-180b. Pass it using
--hf_auth_token flag. You can ask for the access to the model
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
)
self.hf_auth_token = hf_auth_token
self.max_padding_length = 100
self.device = device
self.precision = precision
self.falcon_vmfb_path = falcon_vmfb_path
self.falcon_mlir_path = falcon_mlir_path
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.src_model = self.get_src_model()
self.shark_model = self.compile()
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path,
trust_remote_code=True,
token=self.hf_auth_token,
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 11
return tokenizer
def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {
"torch_dtype": torch.float,
"trust_remote_code": True,
"token": self.hf_auth_token,
}
if self.precision == "int4":
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
if self.precision == "int4":
falcon_model = falcon_model.to(torch.float32)
return falcon_model
def compile(self):
if args.use_precompiled_model:
if not self.falcon_vmfb_path.exists():
# Downloading VMFB from shark_tank
@@ -120,37 +707,37 @@ class Falcon(SharkLLMBase):
if vmfb is not None:
return vmfb
print(
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
f"[DEBUG] mlir path { self.falcon_mlir_path} {'exists' if self.falcon_mlir_path.exists() else 'does not exist'}"
)
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
if self.falcon_mlir_path.exists():
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
print(
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
if args.load_mlir_from_shark_tank:
# Downloading MLIR from shark_tank
print(f"[DEBUG] Trying to download mlir from shark_tank")
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
print(
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
)
mlir_generated = True
if not mlir_generated:
print(f"[DEBUG] generating MLIR locally")
compilation_input_ids = torch.randint(
low=1, high=10000, size=(1, 100)
)
@@ -167,9 +754,10 @@ class Falcon(SharkLLMBase):
ts_graph = import_with_fx(
model,
falconCompileInput,
is_f16=self.precision == "fp16",
is_f16=self.precision in ["fp16", "int4"],
f16_input_mask=[False, False],
mlir_type="torchscript",
is_gptq=self.precision == "int4",
)
del model
print(f"[DEBUG] generating torch mlir")
@@ -189,35 +777,37 @@ class Falcon(SharkLLMBase):
bytecode = bytecode_stream.getvalue()
del module
print(f"[DEBUG] writing mlir to file")
with open(f"{self.model_name}.mlir", "wb") as f_:
with redirect_stdout(f_):
print(module.operation.get_asm())
f_ = open(self.falcon_mlir_path, "wb")
f_.write(bytecode)
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
f_.close()
del bytecode
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="linalg"
mlir_module=self.falcon_mlir_path,
device=self.device,
mlir_dialect="linalg",
)
path = shark_module.save_module(
self.falcon_vmfb_path.parent.absolute(),
self.falcon_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-spirv-index-bits=64",
],
]
+ [
"--iree-llvmcpu-use-fast-min-max-ops",
]
if self.precision == "int4"
else [],
debug=self.debug,
)
print("Saved falcon vmfb at ", str(path))
shark_module.load_module(path)
return shark_module
def compile(self):
falcon_shark_model = self.compile_falcon()
return falcon_shark_model
def generate(self, prompt):
model_inputs = self.tokenizer(
prompt,
@@ -345,7 +935,11 @@ class Falcon(SharkLLMBase):
all_text = prompt
start = time.time()
count = 0
for i in range(self.max_num_tokens - 1):
count = count + 1
next_token = self.generate_new_token()
new_word = self.tokenizer.decode(
next_token.cpu().numpy(),
@@ -372,6 +966,13 @@ class Falcon(SharkLLMBase):
):
break
end = time.time()
print(
"\n\nTime taken is {:.2f} seconds/token\n".format(
(end - start) / count
)
)
torch.cuda.empty_cache()
gc.collect()
@@ -387,7 +988,7 @@ class Falcon(SharkLLMBase):
(model_inputs["input_ids"], model_inputs["attention_mask"]),
)
)
if self.precision == "fp16":
if self.precision in ["fp16", "int4"]:
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
@@ -466,18 +1067,39 @@ if __name__ == "__main__":
else Path(args.falcon_vmfb_path)
)
falcon = Falcon(
"falcon_" + args.falcon_variant_to_use,
hf_model_path="tiiuae/falcon-"
+ args.falcon_variant_to_use
+ "-instruct",
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,
falcon_vmfb_path=falcon_vmfb_path,
)
if args.precision == "int4":
if args.falcon_variant_to_use == "180b":
hf_model_path_value = "TheBloke/Falcon-180B-Chat-GPTQ"
else:
hf_model_path_value = (
"TheBloke/falcon-"
+ args.falcon_variant_to_use
+ "-instruct-GPTQ"
)
else:
if args.falcon_variant_to_use == "180b":
hf_model_path_value = "tiiuae/falcon-180B-chat"
else:
hf_model_path_value = (
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
)
import gc
if not args.sharded:
falcon = UnshardedFalcon(
model_name="falcon_" + args.falcon_variant_to_use,
hf_model_path=hf_model_path_value,
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,
falcon_vmfb_path=falcon_vmfb_path,
)
else:
falcon = ShardedFalcon(
model_name="falcon_" + args.falcon_variant_to_use,
hf_model_path=hf_model_path_value,
device=args.device,
precision=args.precision,
)
default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
continue_execution = True
@@ -497,7 +1119,11 @@ if __name__ == "__main__":
prompt = input("Please enter the prompt text: ")
print("\nPrompt Text: ", prompt)
res_str = falcon.generate(prompt)
prompt_template = f"""A helpful assistant who helps the user with any questions asked.
User: {prompt}
Assistant:"""
res_str = falcon.generate(prompt_template)
torch.cuda.empty_cache()
gc.collect()
print(

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,68 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
class BaseProcessor:
def __init__(self):
self.transform = lambda x: x
return
def __call__(self, item):
return self.transform(item)
@classmethod
def from_config(cls, cfg=None):
return cls()
def build(self, **kwargs):
cfg = OmegaConf.create(kwargs)
return self.from_config(cfg)
class BlipImageBaseProcessor(BaseProcessor):
def __init__(self, mean=None, std=None):
if mean is None:
mean = (0.48145466, 0.4578275, 0.40821073)
if std is None:
std = (0.26862954, 0.26130258, 0.27577711)
self.normalize = transforms.Normalize(mean, std)
class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
def __init__(self, image_size=224, mean=None, std=None):
super().__init__(mean=mean, std=std)
self.transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
self.normalize,
]
)
def __call__(self, item):
return self.transform(item)
@classmethod
def from_config(cls, cfg=None):
if cfg is None:
cfg = OmegaConf.create()
image_size = cfg.get("image_size", 224)
mean = cfg.get("mean", None)
std = cfg.get("std", None)
return cls(image_size=image_size, mean=mean, std=std)

View File

@@ -0,0 +1,5 @@
datasets:
cc_sbu_align:
data_type: images
build_info:
storage: /path/to/cc_sbu_align/

View File

@@ -0,0 +1,33 @@
model:
arch: mini_gpt4
# vit encoder
image_size: 224
drop_path_rate: 0
use_grad_checkpoint: False
vit_precision: "fp16"
freeze_vit: True
freeze_qformer: True
# Q-Former
num_query_token: 32
# Vicuna
llama_model: "lmsys/vicuna-7b-v1.3"
# generation configs
prompt: ""
preprocess:
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"

View File

@@ -0,0 +1,25 @@
model:
arch: mini_gpt4
model_type: pretrain_vicuna
freeze_vit: True
freeze_qformer: True
max_txt_len: 160
end_sym: "###"
low_resource: False
prompt_path: "apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt"
prompt_template: '###Human: {} ###Assistant: '
ckpt: 'prerained_minigpt4_7b.pth'
datasets:
cc_sbu_align:
vis_processor:
train:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
run:
task: image_text_pretrain

View File

@@ -0,0 +1,629 @@
# Based on EVA, BEIT, timm and DeiT code bases
# https://github.com/baaivision/EVA
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
import math
import requests
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
def _cfg(url="", **kwargs):
return {
"url": url,
"num_classes": 1000,
"input_size": (3, 224, 224),
"pool_size": None,
"crop_pct": 0.9,
"interpolation": "bicubic",
"mean": (0.5, 0.5, 0.5),
"std": (0.5, 0.5, 0.5),
**kwargs,
}
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
# x = self.drop(x)
# commit this for the orignal BERT implement
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
window_size=None,
attn_head_dim=None,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
if window_size:
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1
) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(
torch.meshgrid([coords_h, coords_w])
) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += (
window_size[0] - 1
) # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1,) * 2,
dtype=relative_coords.dtype,
)
relative_position_index[1:, 1:] = relative_coords.sum(
-1
) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer(
"relative_position_index", relative_position_index
)
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias=None):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat(
(
self.q_bias,
torch.zeros_like(self.v_bias, requires_grad=False),
self.v_bias,
)
)
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if self.relative_position_bias_table is not None:
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1,
-1,
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
init_values=None,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
window_size=None,
attn_head_dim=None,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
window_size=window_size,
attn_head_dim=attn_head_dim,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = (
DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
if init_values is not None and init_values > 0:
self.gamma_1 = nn.Parameter(
init_values * torch.ones((dim)), requires_grad=True
)
self.gamma_2 = nn.Parameter(
init_values * torch.ones((dim)), requires_grad=True
)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x, rel_pos_bias=None):
if self.gamma_1 is None:
x = x + self.drop_path(
self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
)
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(
self.gamma_1
* self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
)
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
"""Image to Patch Embedding"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (
img_size[0] // patch_size[0]
)
self.patch_shape = (
img_size[0] // patch_size[0],
img_size[1] // patch_size[1],
)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
def forward(self, x, **kwargs):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1
) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1,) * 2,
dtype=relative_coords.dtype,
)
relative_position_index[1:, 1:] = relative_coords.sum(
-1
) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer(
"relative_position_index", relative_position_index
)
# trunc_normal_(self.relative_position_bias_table, std=.02)
def forward(self):
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1,
-1,
) # Wh*Ww,Wh*Ww,nH
return relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
class VisionTransformer(nn.Module):
"""Vision Transformer with support for patch or hybrid CNN input stage"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
norm_layer=nn.LayerNorm,
init_values=None,
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,
use_mean_pooling=True,
init_scale=0.001,
use_checkpoint=False,
):
super().__init__()
self.image_size = img_size
self.num_classes = num_classes
self.num_features = (
self.embed_dim
) = embed_dim # num_features for consistency with other models
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim)
)
else:
self.pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(
window_size=self.patch_embed.patch_shape, num_heads=num_heads
)
else:
self.rel_pos_bias = None
self.use_checkpoint = use_checkpoint
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
self.blocks = nn.ModuleList(
[
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
init_values=init_values,
window_size=self.patch_embed.patch_shape
if use_rel_pos_bias
else None,
)
for i in range(depth)
]
)
# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=0.02)
trunc_normal_(self.cls_token, std=0.02)
# trunc_normal_(self.mask_token, std=.02)
# if isinstance(self.head, nn.Linear):
# trunc_normal_(self.head.weight, std=.02)
self.apply(self._init_weights)
self.fix_init_weight()
# if isinstance(self.head, nn.Linear):
# self.head.weight.data.mul_(init_scale)
# self.head.bias.data.mul_(init_scale)
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=""):
self.num_classes = num_classes
self.head = (
nn.Linear(self.embed_dim, num_classes)
if num_classes > 0
else nn.Identity()
)
def forward_features(self, x):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(
batch_size, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
rel_pos_bias = (
self.rel_pos_bias() if self.rel_pos_bias is not None else None
)
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, rel_pos_bias)
else:
x = blk(x, rel_pos_bias)
return x
# x = self.norm(x)
# if self.fc_norm is not None:
# t = x[:, 1:, :]
# return self.fc_norm(t.mean(1))
# else:
# return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
# x = self.head(x)
return x
def get_intermediate_layers(self, x):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(
batch_size, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
features = []
rel_pos_bias = (
self.rel_pos_bias() if self.rel_pos_bias is not None else None
)
for blk in self.blocks:
x = blk(x, rel_pos_bias)
features.append(x)
return features
def interpolate_pos_embed(model, checkpoint_model):
if "pos_embed" in checkpoint_model:
pos_embed_checkpoint = checkpoint_model["pos_embed"].float()
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int(
(pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5
)
# height (== width) for the new position embedding
new_size = int(num_patches**0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print(
"Position interpolate from %dx%d to %dx%d"
% (orig_size, orig_size, new_size, new_size)
)
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(
-1, orig_size, orig_size, embedding_size
).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens,
size=(new_size, new_size),
mode="bicubic",
align_corners=False,
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model["pos_embed"] = new_pos_embed
def convert_weights_to_fp16(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
# l.weight.data = l.weight.data.half()
l.weight.data = l.weight.data
if l.bias is not None:
# l.bias.data = l.bias.data.half()
l.bias.data = l.bias.data
# if isinstance(l, (nn.MultiheadAttention, Attention)):
# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
# tensor = getattr(l, attr)
# if tensor is not None:
# tensor.data = tensor.data.half()
model.apply(_convert_weights_to_fp16)
def create_eva_vit_g(
img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"
):
model = VisionTransformer(
img_size=img_size,
patch_size=14,
use_mean_pooling=False,
embed_dim=1408,
depth=39,
num_heads=1408 // 88,
mlp_ratio=4.3637,
qkv_bias=True,
drop_path_rate=drop_path_rate,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
use_checkpoint=use_checkpoint,
)
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
local_filename = "eva_vit_g.pth"
response = requests.get(url)
if response.status_code == 200:
with open(local_filename, "wb") as f:
f.write(response.content)
print("File downloaded successfully.")
state_dict = torch.load(local_filename, map_location="cpu")
interpolate_pos_embed(model, state_dict)
incompatible_keys = model.load_state_dict(state_dict, strict=False)
if precision == "fp16":
# model.to("cuda")
convert_weights_to_fp16(model)
return model

View File

@@ -0,0 +1,4 @@
<Img><ImageHere></Img> Describe this image in detail.
<Img><ImageHere></Img> Take a look at this image and describe what you notice.
<Img><ImageHere></Img> Please provide a detailed description of the picture.
<Img><ImageHere></Img> Could you describe the contents of this image for me?

View File

@@ -32,11 +32,13 @@ class SharkStableLM(SharkLLMBase):
max_num_tokens=512,
device="cuda",
precision="fp32",
debug="False",
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_len = 256
self.device = device
self.precision = precision
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
@@ -111,7 +113,7 @@ class SharkStableLM(SharkLLMBase):
shark_module.compile()
path = shark_module.save_module(
vmfb_path.parent.absolute(), vmfb_path.stem
vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug
)
print("Saved vmfb at ", str(path))

View File

@@ -3,11 +3,12 @@ from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from pathlib import Path
from shark.shark_downloader import download_public_file
# expects a Path / str as arg
# returns None if path not found or SharkInference module
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
if not isinstance(vmfb_path, Path):
vmfb_path = Path(vmfb_path)
@@ -17,9 +18,31 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
return None
print("Loading vmfb from: ", vmfb_path)
print("Device from get_vmfb_from_path - ", device)
shark_module = SharkInference(
None, device=device, mlir_dialect=mlir_dialect
None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id
)
shark_module.load_module(vmfb_path)
print("Successfully loaded vmfb")
return shark_module
def get_vmfb_from_config(
shark_container,
model,
precision,
device,
vmfb_path,
padding=None,
device_id=None,
):
vmfb_url = (
f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
)
if padding:
vmfb_url = vmfb_url + f"_{padding}"
vmfb_url = vmfb_url + ".vmfb"
download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
return get_vmfb_from_path(
vmfb_path, device, "tm_tensor", device_id=device_id
)

View File

@@ -0,0 +1,91 @@
from turbine_models.custom_models import stateless_llama
from shark.iree_utils.compile_utils import get_iree_compiled_module
from apps.shark_studio.api.utils import get_resource_path
import iree.runtime as ireert
import gc
import torch
llm_model_map = {
"llama2_7b": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"stop_token": 2,
"max_tokens": 4096,
}
}
class LanguageModel:
def __init__(
self, model_name, hf_auth_token=None, device=None, precision="fp32"
):
print(llm_model_map[model_name])
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.torch_ir, self.tokenizer = llm_model_map[model_name][
"initializer"
](self.hf_model_name, hf_auth_token, compile_to="torch")
self.tempfile_name = get_resource_path("llm.torch.tempfile")
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
del self.torch_ir
gc.collect()
self.device = device
self.precision = precision
self.max_tokens = llm_model_map[model_name]["max_tokens"]
self.iree_module_dict = None
self.compile()
def compile(self) -> None:
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
self.iree_module_dict = get_iree_compiled_module(
self.tempfile_name, device=self.device, frontend="torch"
)
# TODO: delete the temp file
def chat(self, prompt):
history = []
for iter in range(self.max_tokens):
input_tensor = self.tokenizer(
prompt, return_tensors="pt"
).input_ids
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"], input_tensor
)
]
if iter == 0:
token = torch.tensor(
self.iree_module_dict["vmfb"]["run_initialize"](
*device_inputs
).to_host()[0][0]
)
else:
token = torch.tensor(
self.iree_module_dict["vmfb"]["run_forward"](
*device_inputs
).to_host()[0][0]
)
history.append(token)
yield self.tokenizer.decode(history)
if token == llm_model_map["llama2_7b"]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
history[i] = int(history[i])
result_output = self.tokenizer.decode(history)
yield result_output
if __name__ == "__main__":
lm = LanguageModel(
"llama2_7b",
hf_auth_token="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",
device="cpu-task",
)
print("model loaded")
for i in lm.chat("Hello, I am a robot."):
print(i)

View File

@@ -0,0 +1,14 @@
import os
import sys
def get_available_devices():
return ["cpu-task"]
def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)

View File

@@ -0,0 +1,428 @@
from multiprocessing import Process, freeze_support
import os
import sys
import logging
from ui.chat import chat_element
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
# import before IREE to avoid MLIR library issues
import torch_mlir
# import PIL, transformers, sentencepiece # ensures inclusion in pysintaller exe generation
# from apps.stable_diffusion.src import args, clear_all
# import apps.stable_diffusion.web.utils.global_obj as global_obj
def launch_app(address):
from tkinter import Tk
import webview
window = Tk()
# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False, storage_path=os.getcwd())
if __name__ == "__main__":
# if args.debug:
logging.basicConfig(level=logging.DEBUG)
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
# if args.api or "api" in args.ui.split(","):
# from apps.stable_diffusion.web.ui import (
# txt2img_api,
# img2img_api,
# upscaler_api,
# inpaint_api,
# outpaint_api,
# llm_chat_api,
# )
#
# from fastapi import FastAPI, APIRouter
# import uvicorn
#
# # init global sd pipeline and config
# global_obj._init()
#
# app = FastAPI()
# app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
# app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
# app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
# app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
# app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
#
# # chat APIs needed for compatibility with multiple extensions using OpenAI API
# app.add_api_route(
# "/v1/chat/completions", llm_chat_api, methods=["post"]
# )
# app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
# app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
# app.add_api_route("/completions", llm_chat_api, methods=["post"])
# app.add_api_route(
# "/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
# )
# app.include_router(APIRouter())
# uvicorn.run(app, host="0.0.0.0", port=args.server_port)
# sys.exit(0)
#
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
# from apps.stable_diffusion.web.utils.gradio_configs import (
# config_gradio_tmp_imgs_folder,
# )
# config_gradio_tmp_imgs_folder()
import gradio as gr
# Create custom models folders if they don't exist
# from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
# create_custom_models_folders()
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
# from apps.stable_diffusion.web.ui import (
# txt2img_web,
# txt2img_custom_model,
# txt2img_gallery,
# txt2img_png_info_img,
# txt2img_status,
# txt2img_sendto_img2img,
# txt2img_sendto_inpaint,
# txt2img_sendto_outpaint,
# txt2img_sendto_upscaler,
## h2ogpt_upload,
## h2ogpt_web,
# img2img_web,
# img2img_custom_model,
# img2img_gallery,
# img2img_init_image,
# img2img_status,
# img2img_sendto_inpaint,
# img2img_sendto_outpaint,
# img2img_sendto_upscaler,
# inpaint_web,
# inpaint_custom_model,
# inpaint_gallery,
# inpaint_init_image,
# inpaint_status,
# inpaint_sendto_img2img,
# inpaint_sendto_outpaint,
# inpaint_sendto_upscaler,
# outpaint_web,
# outpaint_custom_model,
# outpaint_gallery,
# outpaint_init_image,
# outpaint_status,
# outpaint_sendto_img2img,
# outpaint_sendto_inpaint,
# outpaint_sendto_upscaler,
# upscaler_web,
# upscaler_custom_model,
# upscaler_gallery,
# upscaler_init_image,
# upscaler_status,
# upscaler_sendto_img2img,
# upscaler_sendto_inpaint,
# upscaler_sendto_outpaint,
## lora_train_web,
## model_web,
## model_config_web,
# hf_models,
# modelmanager_sendto_txt2img,
# modelmanager_sendto_img2img,
# modelmanager_sendto_inpaint,
# modelmanager_sendto_outpaint,
# modelmanager_sendto_upscaler,
# stablelm_chat,
# minigpt4_web,
# outputgallery_web,
# outputgallery_tab_select,
# outputgallery_watch,
# outputgallery_filename,
# outputgallery_sendto_txt2img,
# outputgallery_sendto_img2img,
# outputgallery_sendto_inpaint,
# outputgallery_sendto_outpaint,
# outputgallery_sendto_upscaler,
# )
# init global sd pipeline and config
# global_obj._init()
def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_modelmanager_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
"None",
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_outputgallery_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
# have a unique id that doesn't clash with any of the other tabs,
# and that the order in the code here is the order they should
# appear in the ui, as the id value doesn't determine the order.
# Where possible, avoid changing the id of any tab that is the
# destination of one of the 'send to' buttons. If you do have to change
# that id, make sure you update the relevant register_button_click calls
# further down with the new id.
# with gr.TabItem(label="Text-to-Image", id=0):
# txt2img_web.render()
# with gr.TabItem(label="Image-to-Image", id=1):
# img2img_web.render()
# with gr.TabItem(label="Inpainting", id=2):
# inpaint_web.render()
# with gr.TabItem(label="Outpainting", id=3):
# outpaint_web.render()
# with gr.TabItem(label="Upscaler", id=4):
# upscaler_web.render()
# if args.output_gallery:
# with gr.TabItem(label="Output Gallery", id=5) as og_tab:
# outputgallery_web.render()
# # extra output gallery configuration
# outputgallery_tab_select(og_tab.select)
# outputgallery_watch(
# [
# txt2img_status,
# img2img_status,
# inpaint_status,
# outpaint_status,
# upscaler_status,
# ]
# )
## with gr.TabItem(label="Model Manager", id=6):
## model_web.render()
## with gr.TabItem(label="LoRA Training (Experimental)", id=7):
## lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=0):
chat_element.render()
## with gr.TabItem(
## label="Generate Sharding Config (Experimental)", id=9
## ):
## model_config_web.render()
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
# minigpt4_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
# send to buttons
# register_button_click(
# txt2img_sendto_img2img,
# 1,
# [txt2img_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_inpaint,
# 2,
# [txt2img_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_outpaint,
# 3,
# [txt2img_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_upscaler,
# 4,
# [txt2img_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_inpaint,
# 2,
# [img2img_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_outpaint,
# 3,
# [img2img_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_upscaler,
# 4,
# [img2img_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_img2img,
# 1,
# [inpaint_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_outpaint,
# 3,
# [inpaint_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_upscaler,
# 4,
# [inpaint_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_img2img,
# 1,
# [outpaint_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_inpaint,
# 2,
# [outpaint_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_upscaler,
# 4,
# [outpaint_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_img2img,
# 1,
# [upscaler_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_inpaint,
# 2,
# [upscaler_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_outpaint,
# 3,
# [upscaler_gallery],
# [outpaint_init_image, tabs],
# )
# if args.output_gallery:
# register_outputgallery_button(
# outputgallery_sendto_txt2img,
# 0,
# [outputgallery_filename],
# [txt2img_png_info_img, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_img2img,
# 1,
# [outputgallery_filename],
# [img2img_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_inpaint,
# 2,
# [outputgallery_filename],
# [inpaint_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_outpaint,
# 3,
# [outputgallery_filename],
# [outpaint_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_upscaler,
# 4,
# [outputgallery_filename],
# [upscaler_init_image, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_txt2img,
# 0,
# [hf_models],
# [txt2img_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_img2img,
# 1,
# [hf_models],
# [img2img_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_inpaint,
# 2,
# [hf_models],
# [inpaint_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_outpaint,
# 3,
# [hf_models],
# [outpaint_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_upscaler,
# 4,
# [hf_models],
# [upscaler_custom_model, tabs],
# )
sd_web.queue()
# if args.ui == "app":
# t = Process(
# target=launch_app, args=[f"http://localhost:{args.server_port}"]
# )
# t.start()
sd_web.launch(
share=True,
inbrowser=True,
server_name="0.0.0.0",
server_port=11911, # args.server_port,
)

View File

View File

@@ -0,0 +1,517 @@
import gradio as gr
import os
from pathlib import Path
from datetime import datetime as dt
import json
import sys
from apps.shark_studio.api.utils import (
get_available_devices,
)
from apps.shark_studio.api.llm import (
llm_model_map,
LanguageModel,
)
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
language_model = None
# NOTE: Each `model_name` should have its own start message
start_message = {
"llama2_7b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_13b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_70b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence "
"assistant. The assistant gives helpful, detailed, and "
"polite answers to the user's questions.\n"
),
}
def create_prompt(model_name, history, prompt_prefix):
return ""
system_message = ""
if prompt_prefix:
system_message = start_message[model_name]
if "llama2" in model_name:
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
conversation = "".join(
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
)
if prompt_prefix:
msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
else:
msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
elif model_name in ["vicuna"]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)
msg = system_message + conversation
msg = msg.strip()
return msg
def get_default_config():
return False
import torch
from transformers import AutoTokenizer
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
from apps.language_models.src.model_wrappers.vicuna_model import (
CombinedModel,
)
from shark.shark_generate_model_config import GenerateConfigFile
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
c.split_into_layers()
# model_vmfb_key = ""
def chat_fn(
prompt_prefix,
history,
model,
device,
precision,
download_vmfb,
config_file,
cli=False,
progress=gr.Progress(),
):
global language_model
if language_model is None:
language_model = LanguageModel(
model, device=device, precision=precision
)
language_model.chat(prompt_prefix)
return "", ""
global past_key_values
global model_vmfb_key
device_id = None
model_name, model_path = list(map(str.strip, model.split("=>")))
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "rocm" in device:
device = "rocm"
else:
print("unrecognized device")
from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
from apps.stable_diffusion.src import args
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
model_vmfb_key = new_model_vmfb_key
max_toks = 128 if model_name == "codegen" else 512
# get iree flags that need to be overridden, from commandline args
_extra_args = []
# vulkan target triple
vulkan_target_triple = args.iree_vulkan_target_triple
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
get_vulkan_target_triple,
)
if device == "vulkan":
vulkaninfo_list = get_all_vulkan_devices()
if vulkan_target_triple == "":
# We already have the device_id extracted via WebUI, so we directly use
# that to find the target triple.
vulkan_target_triple = get_vulkan_target_triple(
vulkaninfo_list[device_id]
)
_extra_args.append(
f"-iree-vulkan-target-triple={vulkan_target_triple}"
)
if "rdna" in vulkan_target_triple:
flags_to_add = [
"--iree-spirv-index-bits=64",
]
_extra_args = _extra_args + flags_to_add
if device_id is None:
id = 0
for device in vulkaninfo_list:
target_triple = get_vulkan_target_triple(
vulkaninfo_list[id]
)
if target_triple == vulkan_target_triple:
device_id = id
break
id += 1
assert (
device_id
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
print(f"Will use vulkan target triple : {vulkan_target_triple}")
elif "rocm" in device:
# add iree rocm flags
_extra_args.append(
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
)
print(f"extra args = {_extra_args}")
if model_name == "vicuna4":
vicuna_model = ShardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
compressed=True,
extra_args_cmd=_extra_args,
)
else:
# if config_file is None:
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
vulkan_target_triple=vulkan_target_triple,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=download_vmfb,
load_mlir_from_shark_tank=True,
extra_args_cmd=_extra_args,
device_id=device_id,
)
if vicuna_model is None:
sys.exit("Unable to instantiate the model object, exiting.")
prompt = create_prompt(model_name, history, prompt_prefix)
partial_text = ""
token_count = 0
total_time_ms = 0.001 # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, msg, exec_time in progress.tqdm(
vicuna_model.generate(prompt, cli=cli),
desc="generating response",
):
if msg is None:
if is_first:
prefill_time = exec_time
is_first = False
else:
total_time_ms += exec_time
token_count += 1
partial_text += text + " "
history[-1][1] = partial_text
yield history, f"Prefill: {prefill_time:.2f}"
elif "formatted" in msg:
history[-1][1] = text
tokens_per_sec = (token_count / total_time_ms) * 1000
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
else:
sys.exit(
"unexpected message from the vicuna generate call, exiting."
)
return history, ""
def llm_chat_api(InputData: dict):
return None
print(f"Input keys : {InputData.keys()}")
# print(f"model : {InputData['model']}")
is_chat_completion_api = (
"messages" in InputData.keys()
) # else it is the legacy `completion` api
# For Debugging input data from API
# if is_chat_completion_api:
# print(f"message -> role : {InputData['messages'][0]['role']}")
# print(f"message -> content : {InputData['messages'][0]['content']}")
# else:
# print(f"prompt : {InputData['prompt']}")
# print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now
global vicuna_model
model_name = (
InputData["model"] if "model" in InputData.keys() else "codegen"
)
model_path = llm_model_map[model_name]
device = "cpu-task"
precision = "fp16"
max_toks = (
None
if "max_tokens" not in InputData.keys()
else InputData["max_tokens"]
)
if max_toks is None:
max_toks = 128 if model_name == "codegen" else 512
# make it working for codegen first
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)
device_id = None
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
else:
print("unrecognized device")
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=True,
load_mlir_from_shark_tank=True,
device_id=device_id,
)
# TODO: add role dict for different models
if is_chat_completion_api:
# TODO: add funtionality for multiple messages
prompt = create_prompt(
model_name, [(InputData["messages"][0]["content"], "")]
)
else:
prompt = InputData["prompt"]
print("prompt = ", prompt)
res = vicuna_model.generate(prompt)
res_op = None
for op in res:
res_op = op
if is_chat_completion_api:
choices = [
{
"index": 0,
"message": {
"role": "assistant",
"content": res_op, # since we are yeilding the result
},
"finish_reason": "stop", # or length
}
]
else:
choices = [
{
"text": res_op,
"index": 0,
"logprobs": None,
"finish_reason": "stop", # or length
}
]
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
return {
"id": end_time,
"object": "chat.completion"
if is_chat_completion_api
else "text_completion",
"created": int(end_time),
"choices": choices,
}
def view_json_file(file_obj):
content = ""
with open(file_obj.name, "r") as fopen:
content = fopen.read()
return content
with gr.Blocks(title="Chat") as chat_element:
with gr.Row():
model_choices = list(llm_model_map.keys())
model = gr.Dropdown(
label="Select Model",
value=model_choices[0],
choices=model_choices,
allow_custom_value=True,
)
supported_devices = get_available_devices()
enabled = True
if len(supported_devices) == 0:
supported_devices = ["cpu-task"]
supported_devices = [x for x in supported_devices if "sync" not in x]
device = gr.Dropdown(
label="Device",
value=supported_devices[0],
choices=supported_devices,
interactive=enabled,
allow_custom_value=True,
)
precision = gr.Radio(
label="Precision",
value="int4",
choices=[
# "int4",
# "int8",
# "fp16",
"fp32",
],
visible=False,
)
tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Column():
download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available",
value=True,
interactive=True,
)
prompt_prefix = gr.Checkbox(
label="Add System Prompt",
value=False,
interactive=True,
)
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="Chat Message Box",
placeholder="Chat Message Box",
show_label=False,
interactive=enabled,
container=False,
)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit", interactive=enabled)
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(
label="Upload sharding configuration", visible=False
)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
submit_event = msg.submit(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat_fn,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
submit_click_event = submit.click(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat_fn,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)

View File

@@ -7,16 +7,16 @@ Compile Commands FP32/FP16:
```shell
Vulkan AMD:
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux /path/to/input/mlir -o /path/to/output/vmfb
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
# use iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models
CUDA NVIDIA:
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda /path/to/input/mlir -o /path/to/output/vmfb
CPU:
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu /path/to/input/mlir -o /path/to/output/vmfb
```

View File

@@ -34,7 +34,7 @@ from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
from diffusers.models.attention_processor import LoRAXFormersAttnProcessor
import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend
@@ -287,7 +287,7 @@ def lora_train(
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
lora_attn_procs[name] = LoRAXFormersAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
)

View File

@@ -7,12 +7,16 @@ import sys
sys.setrecursionlimit(sys.getrecursionlimit() * 5)
# python path for pyinstaller
pathex = [".", "./apps/language_models/langchain"]
pathex = [
".",
"./apps/language_models/langchain",
"./apps/language_models/src/pipelines/minigpt4_utils",
]
# datafiles for pyinstaller
datas = []
datas += collect_data_files("torch")
datas += copy_metadata("torch")
datas += copy_metadata("tokenizers")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
@@ -26,25 +30,30 @@ datas += copy_metadata("safetensors")
datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("pytorch_lightning")
datas += collect_data_files("opencv_python")
datas += collect_data_files("skimage")
datas += collect_data_files("gradio")
datas += collect_data_files("gradio_client")
datas += collect_data_files("iree")
datas += collect_data_files("google_cloud_storage")
datas += collect_data_files("shark")
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")
datas += collect_data_files("cv2")
datas += collect_data_files("einops")
datas += [
("src/utils/resources/prompts.json", "resources"),
("src/utils/resources/model_db.json", "resources"),
@@ -52,6 +61,14 @@ datas += [
("src/utils/resources/base_model.json", "resources"),
("web/ui/css/*", "ui/css"),
("web/ui/logos/*", "logos"),
(
"../language_models/src/pipelines/minigpt4_utils/configs/*",
"minigpt4_utils/configs",
),
(
"../language_models/src/pipelines/minigpt4_utils/prompts/*",
"minigpt4_utils/prompts",
),
]
@@ -59,6 +76,13 @@ datas += [
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [
x for x in collect_submodules("transformers") if "tests" not in x
x for x in collect_submodules("diffusers") if "tests" not in x
]
blacklist = ["tests", "convert"]
hiddenimports += [
x
for x in collect_submodules("transformers")
if not any(kw in x for kw in blacklist)
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]

View File

@@ -8,6 +8,7 @@ import traceback
import subprocess
import sys
import os
import requests
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_opt_flags,
@@ -16,6 +17,7 @@ from apps.stable_diffusion.src.utils import (
preprocessCKPT,
convert_original_vae,
get_path_to_diffusers_checkpoint,
get_civitai_checkpoint,
fetch_and_update_base_model_id,
get_path_stem,
get_extended_name,
@@ -94,21 +96,19 @@ class SharkifyStableDiffusionModel:
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
self.custom_weights = custom_weights.strip()
self.use_quantize = use_quantize
if custom_weights != "":
if "civitai" in custom_weights:
weights_id = custom_weights.split("/")[-1]
# TODO: use model name and identify file type by civitai rest api
weights_path = (
str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
)
if not os.path.isfile(weights_path):
subprocess.run(
["wget", custom_weights, "-O", weights_path]
)
if custom_weights.startswith("https://civitai.com/api/"):
# download the checkpoint from civitai if we don't already have it
weights_path = get_civitai_checkpoint(custom_weights)
# act as if we were given the local file as custom_weights originally
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
self.custom_weights = weights_path
# needed to ensure webui sets the correct model name metadata
args.ckpt_loc = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
@@ -116,6 +116,7 @@ class SharkifyStableDiffusionModel:
custom_weights = get_path_to_diffusers_checkpoint(
custom_weights
)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":
@@ -177,9 +178,11 @@ class SharkifyStableDiffusionModel:
"unet",
"unet512",
"stencil_unet",
"stencil_unet_512",
"vae",
"vae_encode",
"stencil_adaptor",
"stencil_adaptor_512",
]
index = 0
for model in sub_model_list:
@@ -339,7 +342,7 @@ class SharkifyStableDiffusionModel:
)
return shark_vae, vae_mlir
def get_controlled_unet(self):
def get_controlled_unet(self, use_large=False):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self,
@@ -415,6 +418,16 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
model_name = "stencil_unet"
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[:2]
+ (torch.nn.functional.pad(inputs[2], pad),)
+ inputs[3:]
)
model_name = "stencil_unet_512"
input_mask = [
True,
True,
@@ -437,19 +450,19 @@ class SharkifyStableDiffusionModel:
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["stencil_unet"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_unet",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_controlled_unet, controlled_unet_mlir
def get_control_net(self):
def get_control_net(self, use_large=False):
class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.use_stencil, low_cpu_mem_usage=False
@@ -497,17 +510,34 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_adaptor"])
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
)
else:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor"]
)
input_mask = [True, True, True, True]
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
inputs,
extended_model_name=self.model_name["stencil_adaptor"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="stencil_adaptor",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
@@ -681,8 +711,11 @@ class SharkifyStableDiffusionModel:
return self.text_encoder(input)[0]
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
save_dir = ""
if self.debug:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["clip"]
)
os.makedirs(
save_dir,
exist_ok=True,
@@ -748,7 +781,7 @@ class SharkifyStableDiffusionModel:
else:
return self.get_unet(use_large=use_large)
else:
return self.get_controlled_unet()
return self.get_controlled_unet(use_large=use_large)
def vae_encode(self):
try:
@@ -847,12 +880,14 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def controlnet(self):
def controlnet(self, use_large=False):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
use_large=use_large
)
check_compilation(compiled_stencil_adaptor, "Stencil")
if self.return_mlir:

View File

@@ -29,6 +29,10 @@ from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
from apps.stable_diffusion.src.utils import (
resamplers,
resampler_list,
)
class Image2ImagePipeline(StableDiffusionPipeline):
@@ -84,13 +88,21 @@ class Image2ImagePipeline(StableDiffusionPipeline):
num_inference_steps,
strength,
dtype,
resample_type,
):
# Pre process image -> get image encoded -> process latents
# TODO: process with variable HxW combos
# Pre process image
image = image.resize((width, height))
# Pre-process image
resample_type = (
resamplers[resample_type]
if resample_type in resampler_list
# Fallback to Lanczos
else Image.Resampling.LANCZOS
)
image = image.resize((width, height), resample=resample_type)
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
@@ -147,6 +159,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
resample_type,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -186,6 +199,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
resample_type=resample_type,
)
# Get Image latents

View File

@@ -58,6 +58,7 @@ class StencilPipeline(StableDiffusionPipeline):
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.controlnet = None
self.controlnet_512 = None
def load_controlnet(self):
if self.controlnet is not None:
@@ -68,6 +69,15 @@ class StencilPipeline(StableDiffusionPipeline):
del self.controlnet
self.controlnet = None
def load_controlnet_512(self):
if self.controlnet_512 is not None:
return
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
def unload_controlnet_512(self):
del self.controlnet_512
self.controlnet_512 = None
def prepare_latents(
self,
batch_size,
@@ -111,8 +121,12 @@ class StencilPipeline(StableDiffusionPipeline):
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
self.load_controlnet()
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
self.load_controlnet()
else:
self.load_unet_512()
self.load_controlnet_512()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
@@ -135,43 +149,82 @@ class StencilPipeline(StableDiffusionPipeline):
).to(dtype)
else:
latent_model_input_1 = latent_model_input
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
if text_embeddings.shape[1] <= self.model_max_length:
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
else:
control = self.controlnet_512(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
else:
print(self.unet_512)
noise_pred = self.unet_512(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
@@ -191,7 +244,9 @@ class StencilPipeline(StableDiffusionPipeline):
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
self.unload_controlnet()
self.unload_controlnet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -218,6 +273,7 @@ class StencilPipeline(StableDiffusionPipeline):
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
resample_type,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.

View File

@@ -84,9 +84,6 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
def _import(self):
scaling_model = ScalingModel()

View File

@@ -28,6 +28,7 @@ from apps.stable_diffusion.src.utils.utils import (
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,
parse_seed_input,
batch_seeds,
get_path_stem,
get_extended_name,
@@ -40,3 +41,8 @@ from apps.stable_diffusion.src.utils.utils import (
resize_stencil,
_compile_module,
)
from apps.stable_diffusion.src.utils.civitai import get_civitai_checkpoint
from apps.stable_diffusion.src.utils.resamplers import (
resamplers,
resampler_list,
)

View File

@@ -0,0 +1,42 @@
import re
import requests
from apps.stable_diffusion.src.utils.stable_args import args
from pathlib import Path
from tqdm import tqdm
def get_civitai_checkpoint(url: str):
with requests.get(url, allow_redirects=True, stream=True) as response:
response.raise_for_status()
# civitai api returns the filename in the content disposition
base_filename = re.findall(
'"([^"]*)"', response.headers["Content-Disposition"]
)[0]
destination_path = (
Path.cwd() / (args.ckpt_dir or "models") / base_filename
)
# we don't have this model downloaded yet
if not destination_path.is_file():
print(
f"downloading civitai model from {url} to {destination_path}"
)
size = int(response.headers["content-length"], 0)
progress_bar = tqdm(total=size, unit="iB", unit_scale=True)
with open(destination_path, "wb") as f:
for chunk in response.iter_content(chunk_size=65536):
f.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
# we already have this model downloaded
else:
print(f"civitai model already downloaded to {destination_path}")
response.close()
return destination_path.as_posix()

View File

@@ -3,7 +3,9 @@ from apps.stable_diffusion.src.utils.stable_args import args
# Helper function to profile the vulkan device.
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
if args.vulkan_debug_utils and "vulkan" in args.device:
from shark.parser import shark_args
if shark_args.vulkan_debug_utils and "vulkan" in args.device:
import iree
print(f"Profiling and saving to {file_path}.")

View File

@@ -0,0 +1,12 @@
import PIL.Image as Image
resamplers = {
"Lanczos": Image.Resampling.LANCZOS,
"Nearest Neighbor": Image.Resampling.NEAREST,
"Bilinear": Image.Resampling.BILINEAR,
"Bicubic": Image.Resampling.BICUBIC,
"Hamming": Image.Resampling.HAMMING,
"Box": Image.Resampling.BOX,
}
resampler_list = resamplers.keys()

View File

@@ -11,12 +11,12 @@
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
@@ -28,7 +28,7 @@
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
},
@@ -37,7 +37,7 @@
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}
@@ -45,12 +45,12 @@
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}

View File

@@ -109,7 +109,7 @@ def load_lower_configs(base_model_id=None):
spec = spec.split("-")[0]
if args.annotation_model == "vae":
if not spec or spec in ["rdna3", "sm_80"]:
if not spec or spec in ["sm_80"]:
config_name = (
f"{args.annotation_model}_{args.precision}_{device}.json"
)
@@ -158,9 +158,9 @@ def load_lower_configs(base_model_id=None):
f"{spec}.json"
)
full_gs_url = config_bucket + config_name
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
print("Loading lowering config file from ", lowering_config_dir)
full_gs_url = config_bucket + config_name
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir
@@ -203,8 +203,8 @@ def dump_after_mlir(input_mlir, use_winograd):
if use_winograd:
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
"iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
"iree-linalg-ext-convert-conv2d-to-winograd))"
@@ -212,8 +212,8 @@ def dump_after_mlir(input_mlir, use_winograd):
else:
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
"iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32}))"
)

View File

@@ -2,6 +2,8 @@ import argparse
import os
from pathlib import Path
from apps.stable_diffusion.src.utils.resamplers import resampler_list
def path_expand(s):
return Path(s).expanduser().resolve()
@@ -66,9 +68,9 @@ p.add_argument(
p.add_argument(
"--seed",
type=int,
type=str,
default=-1,
help="The seed to use. -1 for a random one.",
help="The seed or list of seeds to use. -1 for a random one.",
)
p.add_argument(
@@ -132,6 +134,47 @@ p.add_argument(
"img2img.",
)
p.add_argument(
"--use_hiresfix",
type=bool,
default=False,
help="Use Hires Fix to do higher resolution images, while trying to "
"avoid the issues that come with it. This is accomplished by first "
"generating an image using txt2img, then running it through img2img.",
)
p.add_argument(
"--hiresfix_height",
type=int,
default=768,
choices=range(128, 769, 8),
help="The height of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_width",
type=int,
default=768,
choices=range(128, 769, 8),
help="The width of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_strength",
type=float,
default=0.6,
help="The denoising strength to apply for the Hires Fix.",
)
p.add_argument(
"--resample_type",
type=str,
default="Nearest Neighbor",
choices=resampler_list,
help="The resample type to use when resizing an image before being run "
"through stable diffusion.",
)
##############################################################################
# Stable Diffusion Training Params
##############################################################################
@@ -202,28 +245,30 @@ p.add_argument(
"--left",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend left for outpainting.",
help="If extend left for outpainting.",
)
p.add_argument(
"--right",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend right for outpainting.",
help="If extend right for outpainting.",
)
p.add_argument(
"--up",
"--top",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend top for outpainting.",
help="If extend top for outpainting.",
)
p.add_argument(
"--down",
"--bottom",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend bottom for outpainting.",
help="If extend bottom for outpainting.",
)
p.add_argument(
@@ -255,7 +300,7 @@ p.add_argument(
p.add_argument(
"--import_mlir",
default=False,
default=True,
action=argparse.BooleanOptionalAction,
help="Imports the model from torch module to shark_module otherwise "
"downloads the model from shark_tank.",
@@ -278,7 +323,7 @@ p.add_argument(
p.add_argument(
"--use_tuned",
default=True,
default=False,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available.",
)
@@ -371,7 +416,7 @@ p.add_argument(
p.add_argument(
"--use_stencil",
choices=["canny", "openpose", "scribble"],
choices=["canny", "openpose", "scribble", "zoedepth"],
help="Enable the stencil feature.",
)
@@ -400,6 +445,21 @@ p.add_argument(
help="Load and unload models for low VRAM.",
)
p.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication tokens for models like Llama2.",
)
p.add_argument(
"--device_allocator_heap_key",
type=str,
default="",
help="Specify heap key for device caching allocator."
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################
@@ -418,27 +478,6 @@ p.add_argument(
help="Specify target triple for metal.",
)
p.add_argument(
"--vulkan_debug_utils",
default=False,
action=argparse.BooleanOptionalAction,
help="Profiles vulkan device and collects the .rdc info.",
)
p.add_argument(
"--vulkan_large_heap_block_size",
default="2073741824",
help="Flag for setting VMA preferredLargeHeapBlockSize for "
"vulkan device, default is 4G.",
)
p.add_argument(
"--vulkan_validation_layers",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for disabling vulkan validation layers when benchmarking.",
)
##############################################################################
# Misc. Debug and Optimization flags
##############################################################################
@@ -533,6 +572,20 @@ p.add_argument(
"in shark importer. Does nothing if import_mlir is false (the default).",
)
p.add_argument(
"--compile_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag to toggle debug assert/verify flags for imported IR in the"
"iree-compiler. Default to false.",
)
p.add_argument(
"--iree_constant_folding",
default=True,
action=argparse.BooleanOptionalAction,
help="Controls constant folding in iree-compile for all SD models.",
)
##############################################################################
# Web UI flags
@@ -582,6 +635,25 @@ p.add_argument(
help="Flag for enabling rest API.",
)
p.add_argument(
"--api_accept_origin",
action="append",
type=str,
help="An origin to be accepted by the REST api for Cross Origin"
"Resource Sharing (CORS). Use multiple times for multiple origins, "
'or use --api_accept_origin="*" to accept all origins. If no origins '
"are set no CORS headers will be returned by the api. Use, for "
"instance, if you need to access the REST api from Javascript running "
"in a web browser.",
)
p.add_argument(
"--debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for enabling debugging log in WebUI.",
)
p.add_argument(
"--output_gallery",
default=True,
@@ -659,6 +731,18 @@ p.add_argument(
help="Specifies whether the docuchat's web version is running or not.",
)
##############################################################################
# rocm Flags
##############################################################################
p.add_argument(
"--iree_rocm_target_chip",
type=str,
default="",
help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Use `hipinfo` "
"or `iree-run-module --dump_devices=rocm` or `hipinfo` to get desired arch name",
)
args, unknown = p.parse_known_args()
if args.import_debug:
os.environ["IREE_SAVE_TEMPS"] = os.path.join(

View File

@@ -1,2 +1,3 @@
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector
from apps.stable_diffusion.src.utils.stencils.zoe import ZoeDetector

View File

@@ -1,14 +1,46 @@
import numpy as np
from PIL import Image
import torch
import os
from pathlib import Path
import torchvision
import time
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
ZoeDetector,
)
stencil = {}
def save_img(img):
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)
subdir = Path(
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
)
os.makedirs(subdir, exist_ok=True)
if isinstance(img, Image.Image):
img.save(
os.path.join(
subdir, "controlnet_" + str(int(time.time())) + ".png"
)
)
elif isinstance(img, np.ndarray):
img = Image.fromarray(img)
img.save(os.path.join(subdir, str(int(time.time())) + ".png"))
else:
converter = torchvision.transforms.ToPILImage()
for i in img:
converter(i).save(
os.path.join(subdir, str(int(time.time())) + ".png")
)
def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
@@ -117,6 +149,9 @@ def controlnet_hint_conversion(
case "scribble":
print("Working with scribble")
controlnet_hint = hint_scribble(image)
case "zoedepth":
print("Working with ZoeDepth")
controlnet_hint = hint_zoedepth(image)
case _:
return None
controlnet_hint = controlnet_hint_shaping(
@@ -127,7 +162,7 @@ def controlnet_hint_conversion(
stencil_to_model_id_map = {
"canny": "lllyasviel/control_v11p_sd15_canny",
"depth": "lllyasviel/control_v11p_sd15_depth",
"zoedepth": "lllyasviel/control_v11f1p_sd15_depth",
"hed": "lllyasviel/sd-controlnet-hed",
"mlsd": "lllyasviel/control_v11p_sd15_mlsd",
"normal": "lllyasviel/control_v11p_sd15_normalbae",
@@ -157,6 +192,7 @@ def hint_canny(
detected_map = stencil["canny"](
input_image, low_threshold, high_threshold
)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map
@@ -172,6 +208,7 @@ def hint_openpose(
stencil["openpose"] = OpenposeDetector()
detected_map, _ = stencil["openpose"](input_image)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map
@@ -183,4 +220,19 @@ def hint_scribble(image: Image.Image):
detected_map = np.zeros_like(input_image, dtype=np.uint8)
detected_map[np.min(input_image, axis=2) < 127] = 255
save_img(detected_map)
return detected_map
# Stencil 4. Depth (Only Zoe Preprocessing)
def hint_zoedepth(image: Image.Image):
with torch.no_grad():
input_image = np.array(image)
if not "depth" in stencil:
stencil["depth"] = ZoeDetector()
detected_map = stencil["depth"](input_image)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map

View File

@@ -0,0 +1,58 @@
import numpy as np
import torch
from pathlib import Path
import requests
from einops import rearrange
remote_model_path = (
"https://huggingface.co/lllyasviel/Annotators/resolve/main/ZoeD_M12_N.pt"
)
class ZoeDetector:
def __init__(self):
cwd = Path.cwd()
ckpt_path = Path(cwd, "stencil_annotator")
ckpt_path.mkdir(parents=True, exist_ok=True)
modelpath = ckpt_path / "ZoeD_M12_N.pt"
with requests.get(remote_model_path, stream=True) as r:
r.raise_for_status()
with open(modelpath, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
model = torch.hub.load(
"monorimet/ZoeDepth:torch_update",
"ZoeD_N",
pretrained=False,
force_reload=False,
)
model.load_state_dict(
torch.load(modelpath, map_location=model.device)["model"]
)
model.eval()
self.model = model
def __call__(self, input_image):
assert input_image.ndim == 3
image_depth = input_image
with torch.no_grad():
image_depth = torch.from_numpy(image_depth).float()
image_depth = image_depth / 255.0
image_depth = rearrange(image_depth, "h w c -> 1 c h w")
depth = self.model.infer(image_depth)
depth = depth[0, 0].cpu().numpy()
vmin = np.percentile(depth, 2)
vmax = np.percentile(depth, 85)
depth -= vmin
depth /= vmax - vmin
depth = 1.0 - depth
depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)
return depth_image

View File

@@ -18,13 +18,14 @@ import tempfile
import torch
from safetensors.torch import load_file
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
get_iree_vulkan_runtime_flags,
)
from shark.iree_utils.metal_utils import get_metal_target_triple
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from shark.iree_utils.gpu_utils import get_cuda_sm_cc, get_iree_rocm_args
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
@@ -77,7 +78,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
)
)
path = shark_module.save_module(
os.getcwd(), model_name, extra_args
os.getcwd(), model_name, extra_args, debug=args.compile_debug
)
shark_module.load_module(path, extra_args=extra_args)
else:
@@ -117,7 +118,7 @@ def compile_through_fx(
is_f16=False,
f16_input_mask=None,
use_tuned=False,
save_dir=tempfile.gettempdir(),
save_dir="",
debug=False,
generate_vmfb=True,
extra_args=None,
@@ -153,8 +154,8 @@ def compile_through_fx(
f16_input_mask=f16_input_mask,
debug=debug,
model_name=extended_model_name,
save_dir=save_dir,
)
if use_tuned:
if "vae" in extended_model_name.split("_")[0]:
args.annotation_model = "vae"
@@ -167,6 +168,14 @@ def compile_through_fx(
mlir_module, extended_model_name, base_model_id
)
if not os.path.isdir(save_dir):
save_dir = ""
mlir_module = save_mlir(
mlir_module,
model_name=extended_model_name,
dir=save_dir,
)
shark_module = SharkInference(
mlir_module,
device=args.device if device is None else device,
@@ -178,20 +187,22 @@ def compile_through_fx(
mlir_module,
)
del mlir_module
gc.collect()
def set_iree_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
]
# TODO: This function should be device-agnostic and piped properly
# to general runtime driver init.
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
if args.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
if args.device_allocator_heap_key:
vulkan_runtime_flags += [
f"--device_allocator=caching:device_local={args.device_allocator_heap_key}",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
@@ -461,18 +472,43 @@ def get_available_devices():
device_name = (
cpu_name if device["name"] == "default" else device["name"]
)
device_list.append(f"{device_name} => {driver_name}://{i}")
if "local" in driver_name:
device_list.append(
f"{device_name} => {driver_name.replace('local', 'cpu')}"
)
else:
# for drivers with single devices
# let the default device be selected without any indexing
if len(device_list_dict) == 1:
device_list.append(f"{device_name} => {driver_name}")
else:
device_list.append(
f"{device_name} => {driver_name}://{i}"
)
return device_list
set_iree_runtime_flags()
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
)
vulkaninfo_list = get_all_vulkan_devices()
vulkan_devices = []
id = 0
for device in vulkaninfo_list:
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
id += 1
if id != 0:
print(f"vulkan devices are available.")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
rocm_devices = get_devices_by_name("rocm")
available_devices.extend(rocm_devices)
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
@@ -496,10 +532,15 @@ def get_opt_flags(model, precision="fp16"):
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if "rocm" in args.device:
rocm_args = get_iree_rocm_args()
iree_flags.extend(rocm_args)
print(iree_flags)
if args.iree_constant_folding == False:
iree_flags.append("--iree-opt-const-expr-hoisting=False")
iree_flags.append(
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
)
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
iree_flags += opt_flags[model][is_tuned][precision][
@@ -563,7 +604,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
)
num_in_channels = 9 if is_inpaint else 4
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=custom_weights,
checkpoint_path_or_dict=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
num_in_channels=num_in_channels,
@@ -727,7 +768,8 @@ def fetch_and_update_base_model_id(model_to_run, base_model=""):
# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed):
def sanitize_seed(seed: int | str):
seed = int(seed)
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
@@ -735,20 +777,49 @@ def sanitize_seed(seed):
return seed
# Generate a set of seeds, using as the first seed of the set,
# optionally using it as the rng seed for subsequent seeds in the set
def batch_seeds(seed, batch_count, repeatable=False):
# use the passed seed as the initial seed of the batch
seeds = [sanitize_seed(seed)]
# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):
if isinstance(seed_input, str):
try:
seed_input = json.loads(seed_input)
except (ValueError, TypeError):
seed_input = None
if isinstance(seed_input, int):
return [seed_input]
if isinstance(seed_input, list) and all(
type(seed) is int for seed in seed_input
):
return seed_input
raise TypeError(
"Seed input must be an integer or an array of integers in JSON format"
)
# Generate a set of seeds from an input expression for batch_count batches,
# optionally using that input as the rng seed for any randomly generated seeds.
def batch_seeds(
seed_input: str | list | int, batch_count: int, repeatable=False
):
# turn the input into a list if possible
seeds = parse_seed_input(seed_input)
# slice or pad the list to be of batch_count length
seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))
if repeatable:
# use the initial seed as the rng generator seed
saved_random_state = random_getstate()
seed_random(seed)
if all(seed < 0 for seed in seeds):
seeds[0] = sanitize_seed(seeds[0])
# generate the additional seeds
for i in range(1, batch_count):
seeds.append(sanitize_seed(-1))
# set seed for the rng based on what we have so far
saved_random_state = random_getstate()
seed_random(str([n for n in seeds if n > -1]))
# generate any seeds that are unspecified
seeds = [sanitize_seed(seed) for seed in seeds]
if repeatable:
# reset the rng back to normal
@@ -784,6 +855,8 @@ def clear_all():
elif os.name == "unix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
if args.local_tank_cache != "":
shutil.rmtree(args.local_tank_cache)
def get_generated_imgs_path() -> Path:
@@ -829,6 +902,13 @@ def save_output_img(output_img, img_seed, extra_info=None):
pngInfo = PngImagePlugin.PngInfo()
if args.write_metadata_to_png:
# Using a conditional expression caused problems, so setting a new
# variable for now.
if args.use_hiresfix:
png_size_text = f"{args.hiresfix_width}x{args.hiresfix_height}"
else:
png_size_text = f"{args.width}x{args.height}"
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}"
@@ -837,7 +917,7 @@ def save_output_img(output_img, img_seed, extra_info=None):
f"Sampler: {args.scheduler}, "
f"CFG scale: {args.guidance_scale}, "
f"Seed: {img_seed},"
f"Size: {args.width}x{args.height}, "
f"Size: {png_size_text}, "
f"Model: {img_model}, "
f"VAE: {img_vae}, "
f"LoRA: {img_lora}",
@@ -864,8 +944,10 @@ def save_output_img(output_img, img_seed, extra_info=None):
"CFG_SCALE": args.guidance_scale,
"PRECISION": args.precision,
"STEPS": args.steps,
"HEIGHT": args.height,
"WIDTH": args.width,
"HEIGHT": args.height
if not args.use_hiresfix
else args.hiresfix_height,
"WIDTH": args.width if not args.use_hiresfix else args.hiresfix_width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
"VAE": img_vae,
@@ -903,6 +985,10 @@ def get_generation_text_info(seeds, device):
)
text_output += (
f"\nsize={args.height}x{args.width}, "
if not args.use_hiresfix
else f"\nsize={args.hiresfix_height}x{args.hiresfix_width}, "
)
text_output += (
f"batch_count={args.batch_count}, "
f"batch_size={args.batch_size}, "
f"max_length={args.max_length}"

View File

@@ -0,0 +1 @@
from apps.stable_diffusion.web.api.sdapi_v1 import sdapi

View File

@@ -0,0 +1,579 @@
import os
from collections import defaultdict
from enum import Enum
from fastapi import FastAPI
from pydantic import BaseModel, Field, conlist, model_validator
from apps.stable_diffusion.web.api.utils import (
frozen_args,
sampler_aliases,
encode_pil_to_base64,
decode_base64_to_image,
get_model_from_request,
get_scheduler_from_request,
get_lora_params,
get_device,
GenerationInputData,
GenerationResponseData,
)
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_files,
get_custom_model_pathfile,
predefined_models,
predefined_paint_models,
predefined_upscaler_models,
scheduler_list,
)
from apps.stable_diffusion.web.ui.txt2img_ui import txt2img_inf
from apps.stable_diffusion.web.ui.img2img_ui import img2img_inf
from apps.stable_diffusion.web.ui.inpaint_ui import inpaint_inf
from apps.stable_diffusion.web.ui.outpaint_ui import outpaint_inf
from apps.stable_diffusion.web.ui.upscaler_ui import upscaler_inf
sdapi = FastAPI()
# Rest API: /sdapi/v1/sd-models (lists available models)
class AppParam(str, Enum):
txt2img = "txt2img"
img2img = "img2img"
inpaint = "inpaint"
outpaint = "outpaint"
upscaler = "upscaler"
@sdapi.get(
"/v1/sd-models",
summary="lists available models",
description=(
"This is all the models that this server currently knows about.\n "
"Models listed may still have a compilation and build pending that "
"will be triggered the first time they are used."
),
)
def sd_models_api(app: AppParam = frozen_args.app):
match app:
case "inpaint" | "outpaint":
checkpoint_type = "inpainting"
predefined = predefined_paint_models
case "upscaler":
checkpoint_type = "upscaler"
predefined = predefined_upscaler_models
case _:
checkpoint_type = ""
predefined = predefined_models
return [
{
"title": model_file,
"model_name": model_file,
"hash": None,
"sha256": None,
"filename": get_custom_model_pathfile(model_file),
"config": None,
}
for model_file in get_custom_model_files(
custom_checkpoint_type=checkpoint_type
)
] + [
{
"title": model,
"model_name": model,
"hash": None,
"sha256": None,
"filename": None,
"config": None,
}
for model in predefined
]
# Rest API: /sdapi/v1/samplers (lists schedulers)
@sdapi.get(
"/v1/samplers",
summary="lists available schedulers/samplers",
description=(
"These are all the Schedulers defined and available. Not "
"every scheduler is compatible with all apis. Aliases are "
"equivalent samplers in A1111 if they are known."
),
)
def sd_samplers_api():
reverse_sampler_aliases = defaultdict(list)
for key, value in sampler_aliases.items():
reverse_sampler_aliases[value].append(key)
return (
{
"name": scheduler,
"aliases": reverse_sampler_aliases.get(scheduler, []),
"options": {},
}
for scheduler in scheduler_list
)
# Rest API: /sdapi/v1/options (lists application level options)
@sdapi.get(
"/v1/options",
summary="lists current settings of application level options",
description=(
"A subset of the command line arguments set at startup renamed "
"to correspond to the A1111 naming. Only a small subset of A1111 "
"options are returned."
),
)
def options_api():
# This is mostly just enough to support what Koboldcpp wants, with a
# few other things that seemed obvious
return {
"samples_save": True,
"samples_format": frozen_args.output_img_format,
"sd_model_checkpoint": os.path.basename(frozen_args.ckpt_loc)
if frozen_args.ckpt_loc
else frozen_args.hf_model_id,
"sd_lora": frozen_args.use_lora,
"sd_vae": frozen_args.custom_vae or "Automatic",
"enable_pnginfo": frozen_args.write_metadata_to_png,
}
# Rest API: /sdapi/v1/cmd-flags (lists command line argument settings)
@sdapi.get(
"/v1/cmd-flags",
summary="lists the command line arguments value that were set on startup.",
)
def cmd_flags_api():
return vars(frozen_args)
# Rest API: /sdapi/v1/txt2img (Text to image)
class ModelOverrideSettings(BaseModel):
sd_model_checkpoint: str = get_model_from_request(
fallback_model="stabilityai/stable-diffusion-2-1-base"
)
class Txt2ImgInputData(GenerationInputData):
enable_hr: bool = frozen_args.use_hiresfix
hr_resize_y: int = Field(
default=frozen_args.hiresfix_height, ge=128, le=768, multiple_of=8
)
hr_resize_x: int = Field(
default=frozen_args.hiresfix_width, ge=128, le=768, multiple_of=8
)
override_settings: ModelOverrideSettings = None
@sdapi.post(
"/v1/txt2img",
summary="Does text to image generation",
response_model=GenerationResponseData,
)
def txt2img_api(InputData: Txt2ImgInputData):
model_id = get_model_from_request(
InputData,
fallback_model="stabilityai/stable-diffusion-2-1-base",
)
scheduler = get_scheduler_from_request(
InputData, "txt2img_hires" if InputData.enable_hr else "txt2img"
)
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
print(
f"Prompt: {InputData.prompt}, "
f"Negative Prompt: {InputData.negative_prompt}, "
f"Seed: {InputData.seed},"
f"Model: {model_id}, "
f"Scheduler: {scheduler}. "
)
res = txt2img_inf(
InputData.prompt,
InputData.negative_prompt,
InputData.height,
InputData.width,
InputData.steps,
InputData.cfg_scale,
InputData.seed,
batch_count=InputData.n_iter,
batch_size=1,
scheduler=scheduler,
model_id=model_id,
custom_vae=frozen_args.custom_vae or "None",
precision="fp16",
device=get_device(frozen_args.device),
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
use_hiresfix=InputData.enable_hr,
hiresfix_height=InputData.hr_resize_y,
hiresfix_width=InputData.hr_resize_x,
hiresfix_strength=frozen_args.hiresfix_strength,
resample_type=frozen_args.resample_type,
)
# Since we're not streaming we just want the last generator result
for items_so_far in res:
items = items_so_far
return {
"images": encode_pil_to_base64(items[0]),
"parameters": {},
"info": items[1],
}
# Rest API: /sdapi/v1/img2img (Image to image)
class StencilParam(str, Enum):
canny = "canny"
openpose = "openpose"
scribble = "scribble"
zoedepth = "zoedepth"
class Img2ImgInputData(GenerationInputData):
init_images: conlist(str, min_length=1, max_length=2)
denoising_strength: float = frozen_args.strength
use_stencil: StencilParam = frozen_args.use_stencil
override_settings: ModelOverrideSettings = None
@model_validator(mode="after")
def check_image_supplied_for_scribble_stencil(self) -> "Img2ImgInputData":
if (
self.use_stencil == StencilParam.scribble
and len(self.init_images) < 2
):
raise ValueError(
"a second image must be supplied for the controlnet:scribble stencil"
)
return self
@sdapi.post(
"/v1/img2img",
summary="Does image to image generation",
response_model=GenerationResponseData,
)
def img2img_api(
InputData: Img2ImgInputData,
):
model_id = get_model_from_request(
InputData,
fallback_model="stabilityai/stable-diffusion-2-1-base",
)
scheduler = get_scheduler_from_request(InputData, "img2img")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.init_images[0])
mask_image = (
decode_base64_to_image(InputData.init_images[1])
if len(InputData.init_images) > 1
else None
)
print(
f"Prompt: {InputData.prompt}, "
f"Negative Prompt: {InputData.negative_prompt}, "
f"Seed: {InputData.seed}, "
f"Model: {model_id}, "
f"Scheduler: {scheduler}."
)
res = img2img_inf(
InputData.prompt,
InputData.negative_prompt,
{"image": init_image, "mask": mask_image},
InputData.height,
InputData.width,
InputData.steps,
InputData.denoising_strength,
InputData.cfg_scale,
InputData.seed,
batch_count=InputData.n_iter,
batch_size=1,
scheduler=scheduler,
model_id=model_id,
custom_vae=frozen_args.custom_vae or "None",
precision="fp16",
device=get_device(frozen_args.device),
max_length=frozen_args.max_length,
use_stencil=InputData.use_stencil,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
resample_type=frozen_args.resample_type,
)
# Since we're not streaming we just want the last generator result
for items_so_far in res:
items = items_so_far
return {
"images": encode_pil_to_base64(items[0]),
"parameters": {},
"info": items[1],
}
# Rest API: /sdapi/v1/inpaint (Inpainting)
class PaintModelOverideSettings(BaseModel):
sd_model_checkpoint: str = get_model_from_request(
checkpoint_type="inpainting",
fallback_model="stabilityai/stable-diffusion-2-inpainting",
)
class InpaintInputData(GenerationInputData):
image: str = Field(description="Base64 encoded input image")
mask: str = Field(description="Base64 encoded mask image")
is_full_res: bool = False # Is this setting backwards in the UI?
full_res_padding: int = Field(default=32, ge=0, le=256, multiple_of=4)
denoising_strength: float = frozen_args.strength
use_stencil: StencilParam = frozen_args.use_stencil
override_settings: PaintModelOverideSettings = None
@sdapi.post(
"/v1/inpaint",
summary="Does inpainting generation on an image",
response_model=GenerationResponseData,
)
def inpaint_api(
InputData: InpaintInputData,
):
model_id = get_model_from_request(
InputData,
checkpoint_type="inpainting",
fallback_model="stabilityai/stable-diffusion-2-inpainting",
)
scheduler = get_scheduler_from_request(InputData, "inpaint")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.image)
mask = decode_base64_to_image(InputData.mask)
print(
f"Prompt: {InputData.prompt}, "
f'Negative Prompt: {InputData.negative_prompt}", '
f'Seed: {InputData.seed}", '
f"Model: {model_id}, "
f"Scheduler: {scheduler}."
)
res = inpaint_inf(
InputData.prompt,
InputData.negative_prompt,
{"image": init_image, "mask": mask},
InputData.height,
InputData.width,
InputData.is_full_res,
InputData.full_res_padding,
InputData.steps,
InputData.cfg_scale,
InputData.seed,
batch_count=InputData.n_iter,
batch_size=1,
scheduler=scheduler,
model_id=model_id,
custom_vae=frozen_args.custom_vae or "None",
precision="fp16",
device=get_device(frozen_args.device),
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)
# Since we're not streaming we just want the last generator result
for items_so_far in res:
items = items_so_far
return {
"images": encode_pil_to_base64(items[0]),
"parameters": {},
"info": items[1],
}
# Rest API: /sdapi/v1/outpaint (Outpainting)
class DirectionParam(str, Enum):
left = "left"
right = "right"
up = "up"
down = "down"
class OutpaintInputData(GenerationInputData):
init_images: list[str]
pixels: int = Field(
default=frozen_args.pixels, ge=8, le=256, multiple_of=8
)
mask_blur: int = Field(default=frozen_args.mask_blur, ge=0, le=64)
directions: set[DirectionParam] = [
direction
for direction in ["left", "right", "up", "down"]
if vars(frozen_args)[direction]
]
noise_q: float = frozen_args.noise_q
color_variation: float = frozen_args.color_variation
override_settings: PaintModelOverideSettings = None
@sdapi.post(
"/v1/outpaint",
summary="Does outpainting generation on an image",
response_model=GenerationResponseData,
)
def outpaint_api(
InputData: OutpaintInputData,
):
model_id = get_model_from_request(
InputData,
checkpoint_type="inpainting",
fallback_model="stabilityai/stable-diffusion-2-inpainting",
)
scheduler = get_scheduler_from_request(InputData, "outpaint")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.init_images[0])
print(
f"Prompt: {InputData.prompt}, "
f"Negative Prompt: {InputData.negative_prompt}, "
f"Seed: {InputData.seed}, "
f"Model: {model_id}, "
f"Scheduler: {scheduler}."
)
res = outpaint_inf(
InputData.prompt,
InputData.negative_prompt,
init_image,
InputData.pixels,
InputData.mask_blur,
InputData.directions,
InputData.noise_q,
InputData.color_variation,
InputData.height,
InputData.width,
InputData.steps,
InputData.cfg_scale,
InputData.seed,
batch_count=InputData.n_iter,
batch_size=1,
scheduler=scheduler,
model_id=model_id,
custom_vae=frozen_args.custom_vae or "None",
precision="fp16",
device=get_device(frozen_args.device),
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)
# Since we're not streaming we just want the last generator result
for items_so_far in res:
items = items_so_far
return {
"images": encode_pil_to_base64(items[0]),
"parameters": {},
"info": items[1],
}
# Rest API: /sdapi/v1/upscaler (Upscaling)
class UpscalerModelOverideSettings(BaseModel):
sd_model_checkpoint: str = get_model_from_request(
checkpoint_type="upscaler",
fallback_model="stabilityai/stable-diffusion-x4-upscaler",
)
class UpscalerInputData(GenerationInputData):
init_images: list[str] = Field(
description="Base64 encoded image to upscale"
)
noise_level: int = frozen_args.noise_level
override_settings: UpscalerModelOverideSettings = None
@sdapi.post(
"/v1/upscaler",
summary="Does image upscaling",
response_model=GenerationResponseData,
)
def upscaler_api(
InputData: UpscalerInputData,
):
model_id = get_model_from_request(
InputData,
checkpoint_type="upscaler",
fallback_model="stabilityai/stable-diffusion-x4-upscaler",
)
scheduler = get_scheduler_from_request(InputData, "upscaler")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.init_images[0])
print(
f"Prompt: {InputData.prompt}, "
f"Negative Prompt: {InputData.negative_prompt}, "
f"Seed: {InputData.seed}, "
f"Model: {model_id}, "
f"Scheduler: {scheduler}."
)
res = upscaler_inf(
InputData.prompt,
InputData.negative_prompt,
init_image,
InputData.height,
InputData.width,
InputData.steps,
InputData.noise_level,
InputData.cfg_scale,
InputData.seed,
batch_count=InputData.n_iter,
batch_size=1,
scheduler=scheduler,
model_id=model_id,
custom_vae=frozen_args.custom_vae or "None",
precision="fp16",
device=get_device(frozen_args.device),
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)
# Since we're not streaming we just want the last generator result
for items_so_far in res:
items = items_so_far
return {
"images": encode_pil_to_base64(items[0]),
"parameters": {},
"info": items[1],
}

View File

@@ -0,0 +1,211 @@
import base64
import pickle
from argparse import Namespace
from fastapi.exceptions import HTTPException
from io import BytesIO
from PIL import Image
from pydantic import BaseModel, Field
from apps.stable_diffusion.src import args
from apps.stable_diffusion.web.ui.utils import (
available_devices,
get_custom_model_files,
predefined_models,
predefined_paint_models,
predefined_upscaler_models,
scheduler_list,
scheduler_list_cpu_only,
)
# Probably overly cautious, but try to ensure we only use the starting
# args in each api call, as the code does `args.<whatever> = <changed_value>`
# in lots of places and in testing, it seemed to me, these changes leaked
# into subsequent api calls.
# Roundtripping through pickle for deepcopy, there is probably a better way
frozen_args = Namespace(**(pickle.loads(pickle.dumps(vars(args)))))
# an attempt to map some of the A1111 sampler names to scheduler names
# https://github.com/huggingface/diffusers/issues/4167 is where the
# (not so obvious) ones come from
sampler_aliases = {
# a1111/onnx (these point to diffusers classes in A1111)
"pndm": "PNDM",
"heun": "HeunDiscrete",
"ddim": "DDIM",
"ddpm": "DDPM",
"euler": "EulerDiscrete",
"euler-ancestral": "EulerAncestralDiscrete",
"dpm": "DPMSolverMultistep",
# a1111/k_diffusion (the obvious ones)
"Euler a": "EulerAncestralDiscrete",
"Euler": "EulerDiscrete",
"LMS": "LMSDiscrete",
"Heun": "HeunDiscrete",
# a1111/k_diffusion (not so obvious)
"DPM++ 2M": "DPMSolverMultistep",
"DPM++ 2M Karras": "DPMSolverMultistepKarras",
"DPM++ 2M SDE": "DPMSolverMultistep++",
"DPM++ 2M SDE Karras": "DPMSolverMultistepKarras++",
"DPM2": "KDPM2Discrete",
"DPM2 a": "KDPM2AncestralDiscrete",
}
allowed_schedulers = {
"txt2img": {
"schedulers": scheduler_list,
"fallback": "SharkEulerDiscrete",
},
"txt2img_hires": {
"schedulers": scheduler_list_cpu_only,
"fallback": "DEISMultistep",
},
"img2img": {
"schedulers": scheduler_list_cpu_only,
"fallback": "EulerDiscrete",
},
"inpaint": {
"schedulers": scheduler_list_cpu_only,
"fallback": "DDIM",
},
"outpaint": {
"schedulers": scheduler_list_cpu_only,
"fallback": "DDIM",
},
"upscaler": {
"schedulers": scheduler_list_cpu_only,
"fallback": "DDIM",
},
}
# base pydantic model for sd generation apis
class GenerationInputData(BaseModel):
prompt: str = ""
negative_prompt: str = ""
hf_model_id: str | None = None
height: int = Field(
default=frozen_args.height, ge=128, le=768, multiple_of=8
)
width: int = Field(
default=frozen_args.width, ge=128, le=768, multiple_of=8
)
sampler_name: str = frozen_args.scheduler
cfg_scale: float = Field(default=frozen_args.guidance_scale, ge=1)
steps: int = Field(default=frozen_args.steps, ge=1, le=100)
seed: int = frozen_args.seed
n_iter: int = Field(default=frozen_args.batch_count)
class GenerationResponseData(BaseModel):
images: list[str] = Field(description="Generated images, Base64 encoded")
properties: dict = {}
info: str
# image encoding/decoding
def encode_pil_to_base64(images: list[Image.Image]):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if frozen_args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif frozen_args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
def decode_base64_to_image(encoding: str):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=400, detail="Invalid encoded image")
# get valid sd models/vaes/schedulers etc.
def get_predefined_models(custom_checkpoint_type: str):
match custom_checkpoint_type:
case "inpainting":
return predefined_paint_models
case "upscaler":
return predefined_upscaler_models
case _:
return predefined_models
def get_model_from_request(
request_data=None,
checkpoint_type: str = "",
fallback_model: str = "",
):
model = None
if request_data:
if request_data.hf_model_id:
model = request_data.hf_model_id
elif request_data.override_settings:
model = request_data.override_settings.sd_model_checkpoint
# if the request didn't specify a model try the command line args
result = model or frozen_args.ckpt_loc or frozen_args.hf_model_id
# make sure whatever we have is a valid model for the checkpoint type
if result in get_custom_model_files(
custom_checkpoint_type=checkpoint_type
) + get_predefined_models(checkpoint_type):
return result
# if not return what was specified as the fallback
else:
return fallback_model
def get_scheduler_from_request(
request_data: GenerationInputData, operation: str
):
allowed = allowed_schedulers[operation]
requested = request_data.sampler_name
requested = sampler_aliases.get(requested, requested)
return (
requested
if requested in allowed["schedulers"]
else allowed["fallback"]
)
def get_lora_params(use_lora: str):
# TODO: since the inference functions in the webui, which we are
# still calling into for the api, jam these back together again before
# handing them off to the pipeline, we should remove this nonsense
# and unify their selection in the UI and command line args proper
if use_lora in get_custom_model_files("lora"):
return (use_lora, "")
return ("None", use_lora)
def get_device(device_str: str):
# first substring match in the list available devices, with first
# device when none are matched
return next(
(device for device in available_devices if device_str in device),
available_devices[0],
)

View File

@@ -1,6 +1,8 @@
from multiprocessing import Process, freeze_support
from multiprocessing import freeze_support
import os
import sys
import logging
import apps.stable_diffusion.web.utils.app as app
if sys.platform == "darwin":
# import before IREE to avoid torch-MLIR library issues
@@ -20,78 +22,71 @@ if args.clear_all:
clear_all()
def launch_app(address):
from tkinter import Tk
import webview
window = Tk()
# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False)
if __name__ == "__main__":
if args.debug:
logging.basicConfig(level=logging.DEBUG)
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
if args.api or "api" in args.ui.split(","):
from apps.stable_diffusion.web.ui import (
txt2img_api,
img2img_api,
upscaler_api,
inpaint_api,
outpaint_api,
llm_chat_api,
)
from apps.stable_diffusion.web.api import sdapi
from fastapi import FastAPI, APIRouter
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
# init global sd pipeline and config
global_obj._init()
app = FastAPI()
app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
api = FastAPI()
api.mount("/sdapi/", sdapi)
# chat APIs needed for compatibility with multiple extensions using OpenAI API
app.add_api_route(
api.add_api_route(
"/v1/chat/completions", llm_chat_api, methods=["post"]
)
app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
app.add_api_route("/completions", llm_chat_api, methods=["post"])
app.add_api_route(
api.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
api.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
api.add_api_route("/completions", llm_chat_api, methods=["post"])
api.add_api_route(
"/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
)
app.include_router(APIRouter())
uvicorn.run(app, host="0.0.0.0", port=args.server_port)
api.include_router(APIRouter())
# deal with CORS requests if CORS accept origins are set
if args.api_accept_origin:
print(
f"API Configured for CORS. Accepting origins: { args.api_accept_origin }"
)
api.add_middleware(
CORSMiddleware,
allow_origins=args.api_accept_origin,
allow_methods=["GET", "POST"],
allow_headers=["*"],
)
else:
print("API not configured for CORS")
uvicorn.run(api, host="0.0.0.0", port=args.server_port)
sys.exit(0)
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
from apps.stable_diffusion.web.utils.gradio_configs import (
config_gradio_tmp_imgs_folder,
from apps.stable_diffusion.web.utils.tmp_configs import (
config_tmp,
)
config_gradio_tmp_imgs_folder()
config_tmp()
import gradio as gr
# Create custom models folders if they don't exist
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
from apps.stable_diffusion.web.ui.utils import (
create_custom_models_folders,
nodicon_loc,
)
create_custom_models_folders()
@@ -107,7 +102,6 @@ if __name__ == "__main__":
from apps.stable_diffusion.web.ui import (
txt2img_web,
txt2img_custom_model,
txt2img_hf_model_id,
txt2img_gallery,
txt2img_png_info_img,
txt2img_status,
@@ -115,10 +109,10 @@ if __name__ == "__main__":
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
h2ogpt_web,
# h2ogpt_upload,
# h2ogpt_web,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
img2img_gallery,
img2img_init_image,
img2img_status,
@@ -127,7 +121,6 @@ if __name__ == "__main__":
img2img_sendto_upscaler,
inpaint_web,
inpaint_custom_model,
inpaint_hf_model_id,
inpaint_gallery,
inpaint_init_image,
inpaint_status,
@@ -136,7 +129,6 @@ if __name__ == "__main__":
inpaint_sendto_upscaler,
outpaint_web,
outpaint_custom_model,
outpaint_hf_model_id,
outpaint_gallery,
outpaint_init_image,
outpaint_status,
@@ -145,15 +137,15 @@ if __name__ == "__main__":
outpaint_sendto_upscaler,
upscaler_web,
upscaler_custom_model,
upscaler_hf_model_id,
upscaler_gallery,
upscaler_init_image,
upscaler_status,
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
lora_train_web,
model_web,
# lora_train_web,
# model_web,
# model_config_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
@@ -161,6 +153,7 @@ if __name__ == "__main__":
modelmanager_sendto_outpaint,
modelmanager_sendto_upscaler,
stablelm_chat,
minigpt4_web,
outputgallery_web,
outputgallery_tab_select,
outputgallery_watch,
@@ -207,9 +200,18 @@ if __name__ == "__main__":
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
css=dark_theme, analytics_enabled=False, title="SHARK AI Studio"
) as sd_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
# have a unique id that doesn't clash with any of the other tabs,
# and that the order in the code here is the order they should
# appear in the ui, as the id value doesn't determine the order.
# Where possible, avoid changing the id of any tab that is the
# destination of one of the 'send to' buttons. If you do have to change
# that id, make sure you update the relevant register_button_click calls
# further down with the new id.
with gr.TabItem(label="Text-to-Image", id=0):
txt2img_web.render()
with gr.TabItem(label="Image-to-Image", id=1):
@@ -220,14 +222,8 @@ if __name__ == "__main__":
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.TabItem(label="Model Manager", id=5):
model_web.render()
with gr.TabItem(label="Chat Bot(Experimental)", id=6):
stablelm_chat.render()
with gr.TabItem(label="LoRA Training(Experimental)", id=7):
lora_train_web.render()
if args.output_gallery:
with gr.TabItem(label="Output Gallery", id=8) as og_tab:
with gr.TabItem(label="Output Gallery", id=5) as og_tab:
outputgallery_web.render()
# extra output gallery configuration
@@ -241,8 +237,31 @@ if __name__ == "__main__":
upscaler_status,
]
)
with gr.TabItem(label="DocuChat(Experimental)", id=9):
h2ogpt_web.render()
# with gr.TabItem(label="Model Manager", id=6):
# model_web.render()
# with gr.TabItem(label="LoRA Training (Experimental)", id=7):
# lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=8):
stablelm_chat.render()
# with gr.TabItem(
# label="Generate Sharding Config (Experimental)", id=9
# ):
# model_config_web.render()
with gr.TabItem(label="MultiModal (Experimental)", id=10):
minigpt4_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
actual_port = app.usable_port()
if actual_port != args.server_port:
sd_web.load(
fn=lambda: gr.Info(
f"Port {args.server_port} is in use by another application. "
f"Shark is running on port {actual_port} instead."
)
)
# send to buttons
register_button_click(
@@ -376,42 +395,38 @@ if __name__ == "__main__":
modelmanager_sendto_txt2img,
0,
[hf_models],
[txt2img_custom_model, txt2img_hf_model_id, tabs],
[txt2img_custom_model, tabs],
)
register_modelmanager_button(
modelmanager_sendto_img2img,
1,
[hf_models],
[img2img_custom_model, img2img_hf_model_id, tabs],
[img2img_custom_model, tabs],
)
register_modelmanager_button(
modelmanager_sendto_inpaint,
2,
[hf_models],
[inpaint_custom_model, inpaint_hf_model_id, tabs],
[inpaint_custom_model, tabs],
)
register_modelmanager_button(
modelmanager_sendto_outpaint,
3,
[hf_models],
[outpaint_custom_model, outpaint_hf_model_id, tabs],
[outpaint_custom_model, tabs],
)
register_modelmanager_button(
modelmanager_sendto_upscaler,
4,
[hf_models],
[upscaler_custom_model, upscaler_hf_model_id, tabs],
[upscaler_custom_model, tabs],
)
sd_web.queue()
if args.ui == "app":
t = Process(
target=launch_app, args=[f"http://localhost:{args.server_port}"]
)
t.start()
sd_web.launch(
share=args.share,
inbrowser=args.ui == "web",
inbrowser=not app.launch(actual_port),
server_name="0.0.0.0",
server_port=args.server_port,
server_port=actual_port,
favicon_path=nodicon_loc,
)

View File

@@ -1,9 +1,7 @@
from apps.stable_diffusion.web.ui.txt2img_ui import (
txt2img_inf,
txt2img_api,
txt2img_web,
txt2img_custom_model,
txt2img_hf_model_id,
txt2img_gallery,
txt2img_png_info_img,
txt2img_status,
@@ -14,10 +12,8 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
)
from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_inf,
img2img_api,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
img2img_gallery,
img2img_init_image,
img2img_status,
@@ -27,10 +23,8 @@ from apps.stable_diffusion.web.ui.img2img_ui import (
)
from apps.stable_diffusion.web.ui.inpaint_ui import (
inpaint_inf,
inpaint_api,
inpaint_web,
inpaint_custom_model,
inpaint_hf_model_id,
inpaint_gallery,
inpaint_init_image,
inpaint_status,
@@ -40,10 +34,8 @@ from apps.stable_diffusion.web.ui.inpaint_ui import (
)
from apps.stable_diffusion.web.ui.outpaint_ui import (
outpaint_inf,
outpaint_api,
outpaint_web,
outpaint_custom_model,
outpaint_hf_model_id,
outpaint_gallery,
outpaint_init_image,
outpaint_status,
@@ -53,10 +45,8 @@ from apps.stable_diffusion.web.ui.outpaint_ui import (
)
from apps.stable_diffusion.web.ui.upscaler_ui import (
upscaler_inf,
upscaler_api,
upscaler_web,
upscaler_custom_model,
upscaler_hf_model_id,
upscaler_gallery,
upscaler_init_image,
upscaler_status,
@@ -78,7 +68,8 @@ from apps.stable_diffusion.web.ui.stablelm_ui import (
stablelm_chat,
llm_chat_api,
)
from apps.stable_diffusion.web.ui.h2ogpt import h2ogpt_web
from apps.stable_diffusion.web.ui.generate_config import model_config_web
from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web
from apps.stable_diffusion.web.ui.outputgallery_ui import (
outputgallery_web,
outputgallery_tab_select,

View File

@@ -0,0 +1,55 @@
from apps.stable_diffusion.web.ui.utils import (
HSLHue,
hsl_color,
get_lora_metadata,
)
# Answers HTML to show the most frequent tags used when a LoRA was trained,
# taken from the metadata of its .safetensors file.
def lora_changed(lora_file):
# tag frequency percentage, that gets maximum amount of the staring hue
TAG_COLOR_THRESHOLD = 0.55
# tag frequency percentage, above which a tag is displayed
TAG_DISPLAY_THRESHOLD = 0.65
# template for the html used to display a tag
TAG_HTML_TEMPLATE = '<span class="lora-tag" style="border: 1px solid {color};">{tag}</span>'
if lora_file == "None":
return ["<div><i>No LoRA selected</i></div>"]
elif not lora_file.lower().endswith(".safetensors"):
return [
"<div><i>Only metadata queries for .safetensors files are currently supported</i></div>"
]
else:
metadata = get_lora_metadata(lora_file)
if metadata:
frequencies = metadata["frequencies"]
return [
"".join(
[
f'<div class="lora-model">Trained against weights in: {metadata["model"]}</div>'
]
+ [
TAG_HTML_TEMPLATE.format(
color=hsl_color(
(tag[1] - TAG_COLOR_THRESHOLD)
/ (1 - TAG_COLOR_THRESHOLD),
start=HSLHue.RED,
end=HSLHue.GREEN,
),
tag=tag[0],
)
for tag in frequencies
if tag[1] > TAG_DISPLAY_THRESHOLD
],
)
]
elif metadata is None:
return [
"<div><i>This LoRA does not publish tag frequency metadata</i></div>"
]
else:
return [
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
]

View File

@@ -105,6 +105,18 @@ body {
background-color: var(--background-fill-primary);
}
.generating.svelte-zlszon.svelte-zlszon {
border: none;
}
.generating {
border: none !important;
}
#chatbot {
height: 100% !important;
}
/* display in full width for desktop devices */
@media (min-width: 1536px)
{
@@ -117,16 +129,12 @@ body {
padding: 0 var(--size-4) !important;
}
.container {
background-color: black !important;
padding-top: var(--size-5) !important;
}
#ui_title {
padding: var(--size-2) 0 0 var(--size-1);
}
#top_logo {
color: transparent;
background-color: transparent;
border-radius: 0 !important;
border: 0;
@@ -250,10 +258,39 @@ footer {
background-color: var(--block-label-background-fill);
}
/* lora tag pills */
.lora-tags {
border: 1px solid var(--border-color-primary);
color: var(--block-info-text-color) !important;
padding: var(--block-padding);
}
.lora-tag {
display: inline-block;
height: 2em;
color: rgb(212 212 212) !important;
margin-right: 5pt;
margin-bottom: 5pt;
padding: 2pt 5pt;
border-radius: 5pt;
white-space: nowrap;
}
.lora-model {
margin-bottom: var(--spacing-lg);
color: var(--block-info-text-color) !important;
line-height: var(--line-sm);
}
/* output gallery tab */
.output_parameters_dataframe table.table {
/* works around a gradio bug that always shows scrollbars */
overflow: clip auto;
}
.output_parameters_dataframe tbody td {
font-size: small;
line-height: var(--line-xs)
line-height: var(--line-xs);
}
.output_icon_button {

View File

@@ -0,0 +1,41 @@
import gradio as gr
import torch
from transformers import AutoTokenizer
from apps.language_models.src.model_wrappers.vicuna_model import CombinedModel
from shark.shark_generate_model_config import GenerateConfigFile
def get_model_config():
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
return c.split_into_layers()
with gr.Blocks() as model_config_web:
with gr.Row():
hf_models = gr.Dropdown(
label="Model List",
choices=["Vicuna"],
value="Vicuna",
visible=True,
)
get_model_config_btn = gr.Button(value="Get Model Config")
json_view = gr.JSON()
get_model_config_btn.click(
fn=get_model_config,
inputs=[],
outputs=[json_view],
)

View File

@@ -12,6 +12,10 @@ from apps.language_models.langchain.enums import (
LangChainAction,
)
import apps.language_models.langchain.gen as gen
from gpt_langchain import (
path_to_docs,
create_or_update_db,
)
from apps.stable_diffusion.src import args
@@ -21,129 +25,148 @@ def user(message, history):
sharkModel = 0
sharded_model = 0
h2ogpt_model = 0
past_key_values = None
model_map = {
"codegen": "Salesforce/codegen25-7b-multi",
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
"StableLM": "stabilityai/stablelm-tuned-alpha-3b",
}
# NOTE: Each `model_name` should have its own start message
start_message = {
"StableLM": (
"<|SYSTEM|># StableLM Tuned (Alpha version)"
"\n- StableLM is a helpful and harmless open-source AI language model "
"developed by StabilityAI."
"\n- StableLM is excited to be able to help the user, but will refuse "
"to do anything that could be considered harmful to the user."
"\n- StableLM is more than just an information source, StableLM is also "
"able to write poetry, short stories, and make jokes."
"\n- StableLM will refuse to participate in anything that "
"could harm a human."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
),
"vicuna1p3": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
),
"codegen": "",
}
start_message = """
SHARK DocuChat
Chat with an AI, contextualized with provided files.
"""
def create_prompt(model_name, history):
system_message = start_message[model_name]
def create_prompt(history):
system_message = start_message
for item in history:
print("His item: ", item)
if model_name in ["StableLM", "vicuna", "vicuna1p3"]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)
conversation = "<|endoftext|>".join(
[
"<|endoftext|><|answer|>".join([item[0], item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
return msg
def chat(curr_system_message, history, model, device, precision):
def chat(curr_system_message, history, device, precision):
args.run_docuchat_web = True
global sharded_model
global past_key_values
global h2ogpt_model
global sharkModel
global h2ogpt_tokenizer
global model_state
global langchain
global userpath_selector
from apps.language_models.langchain.h2oai_pipeline import generate_token
model_name, model_path = list(map(str.strip, model.split("=>")))
print(f"In chat for {model_name}")
if h2ogpt_model == 0:
if "cuda" in device:
shark_device = "cuda"
elif "sync" in device:
shark_device = "cpu"
elif "task" in device:
shark_device = "cpu"
elif "vulkan" in device:
shark_device = "vulkan"
else:
print("unrecognized device")
# if h2ogpt_model == 0:
# if "cuda" in device:
# device = "cuda"
# elif "sync" in device:
# device = "cpu-sync"
# elif "task" in device:
# device = "cpu-task"
# elif "vulkan" in device:
# device = "vulkan"
# else:
# print("unrecognized device")
device = "cpu" if shark_device == "cpu" else "cuda"
# max_toks = 128 if model_name == "codegen" else 512
# h2ogpt_model = UnshardedVicuna(
# model_name,
# hf_model_path=model_path,
# device=device,
# precision=precision,
# max_num_tokens=max_toks,
# )
# prompt = create_prompt(model_name, history)
# print("prompt = ", prompt)
args.device = shark_device
args.precision = precision
# for partial_text in h2ogpt_model.generate(prompt):
# history[-1][1] = partial_text
# yield history
output = gen.evaluate(
None, # model_state
None, # my_db_state
None, # instruction
None, # iinput
history, # context
False, # stream_output
None, # prompt_type
None, # prompt_dict
None, # temperature
None, # top_p
None, # top_k
None, # num_beams
None, # max_new_tokens
None, # min_new_tokens
None, # early_stopping
None, # max_time
None, # repetition_penalty
None, # num_return_sequences
False, # do_sample
False, # chat
None, # instruction_nochat
curr_system_message, # iinput_nochat
"Disabled", # langchain_mode
LangChainAction.QUERY.value, # langchain_action
3, # top_k_docs
True, # chunk
512, # chunk_size
[DocumentChoices.All_Relevant.name], # document_choice
from apps.language_models.langchain.gen import Langchain
langchain = Langchain(device, precision)
h2ogpt_model, h2ogpt_tokenizer, _ = langchain.get_model(
load_4bit=True
if device == "cuda"
else False, # load model in 4bit if device is cuda to save memory
load_gptq="",
use_safetensors=False,
infer_devices=True,
device=device,
base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
inference_server="",
tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
lora_weights="",
gpu_id=0,
reward_type=None,
local_files_only=False,
resume_download=True,
use_auth_token=False,
trust_remote_code=True,
offload_folder=None,
compile_model=False,
verbose=False,
)
model_state = dict(
model=h2ogpt_model,
tokenizer=h2ogpt_tokenizer,
device=device,
base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
lora_weights="",
inference_server="",
prompt_type=None,
prompt_dict=None,
)
from apps.language_models.langchain.h2oai_pipeline import (
H2OGPTSHARKModel,
)
sharkModel = H2OGPTSHARKModel()
prompt = create_prompt(history)
output_dict = langchain.evaluate(
model_state=model_state,
my_db_state=None,
instruction=prompt,
iinput="",
context="",
stream_output=True,
prompt_type="prompt_answer",
prompt_dict={
"promptA": "",
"promptB": "",
"PreInstruct": "<|prompt|>",
"PreInput": None,
"PreResponse": "<|answer|>",
"terminate_response": [
"<|prompt|>",
"<|answer|>",
"<|endoftext|>",
],
"chat_sep": "<|endoftext|>",
"chat_turn_sep": "<|endoftext|>",
"humanstr": "<|prompt|>",
"botstr": "<|answer|>",
"generates_leading_space": False,
},
temperature=0.1,
top_p=0.75,
top_k=40,
num_beams=1,
max_new_tokens=256,
min_new_tokens=0,
early_stopping=False,
max_time=180,
repetition_penalty=1.07,
num_return_sequences=1,
do_sample=False,
chat=True,
instruction_nochat=prompt,
iinput_nochat="",
langchain_mode="UserData",
langchain_action=LangChainAction.QUERY.value,
top_k_docs=3,
chunk=True,
chunk_size=512,
document_choice=[DocumentChoices.All_Relevant.name],
concurrency_count=1,
memory_restriction_level=2,
raise_generate_gpu_exceptions=False,
@@ -154,24 +177,28 @@ def chat(curr_system_message, history, model, device, precision):
db_type="chroma",
n_jobs=-1,
first_para=False,
max_max_time=60 * 2,
model_state0=model_state,
model_lock=True,
user_path=userpath_selector.value,
)
output = generate_token(sharkModel, **output_dict)
for partial_text in output:
history[-1][1] = partial_text
yield history
return history
with gr.Blocks(title="H2OGPT") as h2ogpt_web:
userpath_selector = gr.Textbox(
label="Document Directory",
value=str(os.path.abspath("apps/language_models/langchain/user_path/")),
interactive=True,
container=True,
)
with gr.Blocks(title="DocuChat") as h2ogpt_web:
with gr.Row():
model_choices = list(
map(lambda x: f"{x[0]: <10} => {x[1]}", model_map.items())
)
model = gr.Dropdown(
label="Select Model",
value=model_choices[0],
choices=model_choices,
)
supported_devices = available_devices
enabled = len(supported_devices) > 0
# show cpu-task device first in list for chatbot
@@ -185,6 +212,7 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
allow_custom_value=True,
)
precision = gr.Radio(
label="Precision",
@@ -220,7 +248,7 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model, device, precision],
inputs=[system_msg, chatbot, device, precision],
outputs=[chatbot],
queue=True,
)
@@ -228,7 +256,7 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, model, device, precision],
inputs=[system_msg, chatbot, device, precision],
outputs=[chatbot],
queue=True,
)
@@ -240,3 +268,100 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)
with gr.Blocks(title="DocuChat Upload") as h2ogpt_upload:
import pathlib
upload_path = None
database = None
database_directory = os.path.abspath(
"apps/language_models/langchain/db_path/"
)
def read_path():
global upload_path
filenames = [
[f]
for f in os.listdir(upload_path)
if os.path.isfile(os.path.join(upload_path, f))
]
filenames.sort()
return filenames
def upload_file(f):
names = []
for tmpfile in f:
name = tmpfile.name.split("/")[-1]
basename = os.path.join(upload_path, name)
with open(basename, "wb") as w:
with open(tmpfile.name, "rb") as r:
w.write(r.read())
update_or_create_db()
return read_path()
def update_userpath(newpath):
global upload_path
upload_path = newpath
pathlib.Path(upload_path).mkdir(parents=True, exist_ok=True)
return read_path()
def update_or_create_db():
global database
global upload_path
sources = path_to_docs(
upload_path,
verbose=True,
fail_any_exception=False,
n_jobs=-1,
chunk=True,
chunk_size=512,
url=None,
enable_captions=False,
captions_model=None,
caption_loader=None,
enable_ocr=False,
)
pathlib.Path(database_directory).mkdir(parents=True, exist_ok=True)
database = create_or_update_db(
"chroma",
database_directory,
"UserData",
sources,
False,
True,
True,
"sentence-transformers/all-MiniLM-L6-v2",
)
def first_run():
global database
if database is None:
update_or_create_db()
update_userpath(
os.path.abspath("apps/language_models/langchain/user_path/")
)
h2ogpt_upload.load(fn=first_run)
h2ogpt_web.load(fn=first_run)
with gr.Column():
text = gr.DataFrame(
col_count=(1, "fixed"),
type="array",
label="Documents",
value=read_path(),
)
with gr.Row():
upload = gr.UploadButton(
label="Upload documents",
file_count="multiple",
)
upload.upload(fn=upload_file, inputs=upload, outputs=text)
userpath_selector.render()
userpath_selector.input(
fn=update_userpath, inputs=userpath_selector, outputs=text
).then(fn=update_or_create_db)

View File

@@ -3,10 +3,9 @@ import torch
import time
import gradio as gr
import PIL
from math import ceil
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -16,6 +15,7 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
@@ -29,6 +29,7 @@ from apps.stable_diffusion.src import (
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
resampler_list,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
import numpy as np
@@ -50,12 +51,11 @@ def img2img_inf(
steps: int,
strength: float,
guidance_scale: float,
seed: int,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
@@ -67,6 +67,7 @@ def img2img_inf(
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -101,21 +102,17 @@ def img2img_inf(
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
# .safetensor or .chkpt on the custom model path
if model_id in get_custom_model_files():
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = custom_model
args.hf_model_id = model_id
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
@@ -230,10 +227,12 @@ def img2img_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
extra_info = {"STRENGTH": strength}
text_output = ""
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None
for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
@@ -243,7 +242,7 @@ def img2img_inf(
batch_size,
height,
width,
steps,
ceil(steps / strength),
strength,
guidance_scale,
seeds[current_batch],
@@ -253,6 +252,7 @@ def img2img_inf(
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
resample_type=resample_type,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
@@ -277,87 +277,6 @@ def img2img_inf(
return generated_imgs, text_output, ""
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Img2Img Rest API.
def img2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = img2img_inf(
InputData["prompt"],
InputData["negative_prompt"],
init_image,
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["denoising_strength"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
use_stencil=InputData["use_stencil"]
if "use_stencil" in InputData.keys()
else "None",
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
@@ -376,31 +295,19 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
i2i_model_info = (str(get_custom_model_path())).replace(
"\\", "\n\\"
i2i_model_info = (
f"Custom Model Path: {str(get_custom_model_path())}"
)
i2i_model_info = f"Custom Model Path: {i2i_model_info}"
img2img_custom_model = gr.Dropdown(
label=f"Models",
info=i2i_model_info,
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-1-base",
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
)
img2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3, "
"https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
choices=get_custom_model_files() + predefined_models,
allow_custom_value=True,
scale=2,
)
# janky fix for overflowing text
i2i_vae_info = (str(get_custom_model_path("vae"))).replace(
@@ -415,6 +322,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -430,7 +339,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
lines=2,
elem_id="negative_prompt_box",
)
# TODO: make this import image prompt info if it exists
img2img_init_image = gr.Image(
label="Input Image",
source="upload",
@@ -445,7 +354,13 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
elem_id="stencil_model",
label="Stencil model",
value="None",
choices=["None", "canny", "openpose", "scribble"],
choices=[
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
],
)
def show_canvas(choice):
@@ -506,6 +421,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
).replace("\\", "\n\\")
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
lora_weights = gr.Dropdown(
allow_custom_value=True,
label=f"Standalone LoRA Weights",
info=i2i_lora_info,
elem_id="lora_weights",
@@ -522,6 +438,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -529,6 +450,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
allow_custom_value=True,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -548,15 +470,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
@@ -579,11 +492,26 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
step=0.01,
label="Denoising Strength",
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=resampler_list,
label="Resample Type",
allow_custom_value=True,
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
@@ -617,25 +545,18 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
visible=False,
)
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -647,13 +568,26 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
value=f"{i2i_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
lines=2,
elem_id="std_output",
show_label=False,
)
img2img_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
img2img_sendto_outpaint = gr.Button(
@@ -679,7 +613,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
batch_size,
scheduler,
img2img_custom_model,
img2img_hf_model_id,
custom_vae,
precision,
device,
@@ -691,6 +624,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
lora_hf_id,
ondemand,
repeatable_seeds,
resample_type,
],
outputs=[img2img_gallery, std_output, img2img_status],
show_progress="minimal" if args.progress_bar else "none",
@@ -711,3 +645,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -4,9 +4,7 @@ import time
import sys
import gradio as gr
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -16,6 +14,7 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_paint_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
@@ -49,12 +48,11 @@ def inpaint_inf(
inpaint_full_res_padding: int,
steps: int,
guidance_scale: float,
seed: int,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
@@ -89,21 +87,17 @@ def inpaint_inf(
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
# .safetensor or .chkpt on the custom model path
if model_id in get_custom_model_files(custom_checkpoint_type="inpainting"):
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = custom_model
args.hf_model_id = model_id
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
@@ -181,10 +175,13 @@ def inpaint_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
image = image_dict["image"]
mask_image = image_dict["mask"]
text_output = ""
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None
for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
@@ -225,86 +222,6 @@ def inpaint_inf(
return generated_imgs, text_output
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Inpaint Rest API.
def inpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
init_image = decode_base64_to_image(InputData["image"])
mask = decode_base64_to_image(InputData["mask"])
res = inpaint_inf(
InputData["prompt"],
InputData["negative_prompt"],
{"image": init_image, "mask": mask},
InputData["height"],
InputData["width"],
InputData["is_full_res"],
InputData["full_res_padding"],
InputData["steps"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-inpainting",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
@@ -324,34 +241,21 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
with gr.Row():
# janky fix for overflowing text
inpaint_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
inpaint_model_info = (
f"Custom Model Path: {inpaint_model_info}"
f"Custom Model Path: {str(get_custom_model_path())}"
)
inpaint_custom_model = gr.Dropdown(
label=f"Models",
info=inpaint_model_info,
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-inpainting",
choices=["None"]
+ get_custom_model_files(
choices=get_custom_model_files(
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
)
inpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
"https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
allow_custom_value=True,
scale=2,
)
# janky fix for overflowing text
inpaint_vae_info = (
@@ -366,6 +270,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -403,6 +309,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -414,6 +321,11 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -421,6 +333,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
allow_custom_value=True,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -514,25 +427,18 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
visible=False,
)
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -544,14 +450,26 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
value=f"{inpaint_model_info}\n"
"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
lines=2,
elem_id="std_output",
show_label=False,
)
inpaint_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
inpaint_sendto_outpaint = gr.Button(
@@ -578,7 +496,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
batch_size,
scheduler,
inpaint_custom_model,
inpaint_hf_model_id,
custom_vae,
precision,
device,
@@ -608,3 +525,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

View File

@@ -3,7 +3,7 @@ import os
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import lora_train
from apps.stable_diffusion.src import prompt_examples, args
from apps.stable_diffusion.src import prompt_examples, args, utils
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -50,6 +50,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
allow_custom_value=True,
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
@@ -73,6 +74,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -105,6 +107,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
label="Scheduler",
value=args.scheduler,
choices=scheduler_list,
allow_custom_value=True,
)
with gr.Row():
height = gr.Slider(
@@ -168,13 +171,16 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
value=utils.parse_seed_input(args.seed)[0],
precision=0,
label="Seed",
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
with gr.Column(scale=2):

View File

@@ -0,0 +1,194 @@
# ========================================
# Gradio Setting
# ========================================
import gradio as gr
# from apps.language_models.src.pipelines.minigpt4_pipeline import (
# # MiniGPT4,
# CONV_VISION,
# )
from pathlib import Path
chat = None
def gradio_reset(chat_state, img_list):
if chat_state is not None:
chat_state.messages = []
if img_list is not None:
img_list = []
return (
None,
gr.update(value=None, interactive=True),
gr.update(
placeholder="Please upload your image first", interactive=False
),
gr.update(value="Upload & Start Chat", interactive=True),
chat_state,
img_list,
)
def upload_img(gr_img, text_input, chat_state, device, precision, _compile):
global chat
if chat is None:
from apps.language_models.src.pipelines.minigpt4_pipeline import (
MiniGPT4,
CONV_VISION,
)
vision_model_precision = precision
if precision in ["int4", "int8"]:
vision_model_precision = "fp16"
vision_model_vmfb_path = Path(
f"vision_model_{vision_model_precision}_{device}.vmfb"
)
qformer_vmfb_path = Path(f"qformer_fp32_{device}.vmfb")
chat = MiniGPT4(
model_name="MiniGPT4",
hf_model_path=None,
max_new_tokens=30,
device=device,
precision=precision,
_compile=_compile,
vision_model_vmfb_path=vision_model_vmfb_path,
qformer_vmfb_path=qformer_vmfb_path,
)
if gr_img is None:
return None, None, gr.update(interactive=True), chat_state, None
chat_state = CONV_VISION.copy()
img_list = []
llm_message = chat.upload_img(gr_img, chat_state, img_list)
return (
gr.update(interactive=False),
gr.update(interactive=True, placeholder="Type and press Enter"),
gr.update(value="Start Chatting", interactive=False),
chat_state,
img_list,
)
def gradio_ask(user_message, chatbot, chat_state):
if len(user_message) == 0:
return (
gr.update(
interactive=True, placeholder="Input should not be empty!"
),
chatbot,
chat_state,
)
chat.ask(user_message, chat_state)
chatbot = chatbot + [[user_message, None]]
return "", chatbot, chat_state
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
llm_message = chat.answer(
conv=chat_state,
img_list=img_list,
num_beams=num_beams,
temperature=temperature,
max_new_tokens=300,
max_length=2000,
)[0]
print(llm_message)
print("************")
chatbot[-1][1] = llm_message
return chatbot, chat_state, img_list
title = """<h1 align="center">MultiModal SHARK (experimental)</h1>"""
description = """<h3>Upload your images and start chatting!</h3>"""
article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
"""
# TODO show examples below
with gr.Blocks() as minigpt4_web:
gr.Markdown(title)
gr.Markdown(description)
with gr.Row():
with gr.Column():
image = gr.Image(type="pil")
upload_button = gr.Button(
value="Upload & Start Chat",
interactive=True,
variant="primary",
)
clear = gr.Button("Restart")
num_beams = gr.Slider(
minimum=1,
maximum=10,
value=1,
step=1,
interactive=True,
label="beam search numbers)",
)
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=1.0,
step=0.1,
interactive=True,
label="Temperature",
)
device = gr.Dropdown(
label="Device",
value="cuda",
# if enabled
# else "Only CUDA Supported for now",
choices=["cuda"],
interactive=False,
allow_custom_value=True,
)
with gr.Column():
chat_state = gr.State()
img_list = gr.State()
chatbot = gr.Chatbot(label="MiniGPT-4")
text_input = gr.Textbox(
label="User",
placeholder="Please upload your image first",
interactive=False,
)
precision = gr.Radio(
label="Precision",
value="int8",
choices=[
"int8",
"fp16",
"fp32",
],
visible=True,
)
_compile = gr.Checkbox(
value=False,
label="Compile",
interactive=True,
)
upload_button.click(
upload_img,
[image, text_input, chat_state, device, precision, _compile],
[image, text_input, upload_button, chat_state, img_list],
)
text_input.submit(
gradio_ask,
[text_input, chatbot, chat_state],
[text_input, chatbot, chat_state],
).then(
gradio_answer,
[chatbot, chat_state, img_list, num_beams, temperature],
[chatbot, chat_state, img_list],
)
clear.click(
gradio_reset,
[chat_state, img_list],
[chatbot, image, text_input, upload_button, chat_state, img_list],
queue=False,
)

View File

@@ -98,6 +98,7 @@ with gr.Blocks() as model_web:
choices=None,
value=None,
visible=False,
allow_custom_value=True,
)
# TODO: select and SendTo
civit_models = gr.Gallery(

View File

@@ -3,9 +3,8 @@ import torch
import time
import gradio as gr
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -49,12 +48,11 @@ def outpaint_inf(
width: int,
steps: int,
guidance_scale: float,
seed: int,
seed: str,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
@@ -88,21 +86,17 @@ def outpaint_inf(
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
# .safetensor or .chkpt on the custom model path
if model_id in get_custom_model_files(custom_checkpoint_type="inpainting"):
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = custom_model
args.hf_model_id = model_id
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
@@ -178,7 +172,10 @@ def outpaint_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None
left = True if "left" in directions else False
right = True if "right" in directions else False
@@ -230,88 +227,6 @@ def outpaint_inf(
return generated_imgs, text_output, ""
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Inpaint Rest API.
def outpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = outpaint_inf(
InputData["prompt"],
InputData["negative_prompt"],
init_image,
InputData["pixels"],
InputData["mask_blur"],
InputData["directions"],
InputData["noise_q"],
InputData["color_variation"],
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-inpainting",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Convert Generator to Subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
@@ -329,36 +244,22 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
outpaint_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
outpaint_model_info = (
f"Custom Model Path: {outpaint_model_info}"
f"Custom Model Path: {str(get_custom_model_path())}"
)
outpaint_custom_model = gr.Dropdown(
label=f"Models",
info=outpaint_model_info,
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-inpainting",
choices=["None"]
+ get_custom_model_files(
choices=get_custom_model_files(
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
)
outpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
"https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
allow_custom_value=True,
scale=2,
)
# janky fix for overflowing text
outpaint_vae_info = (
@@ -373,8 +274,9 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
@@ -408,6 +310,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -419,6 +322,11 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -426,6 +334,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
allow_custom_value=True,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -542,25 +451,18 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
visible=False,
)
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -572,13 +474,26 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
value=f"{outpaint_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
lines=2,
elem_id="std_output",
show_label=False,
)
outpaint_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -606,7 +521,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
batch_size,
scheduler,
outpaint_custom_model,
outpaint_hf_model_id,
custom_vae,
precision,
device,
@@ -636,3 +550,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -91,7 +91,7 @@ with gr.Blocks() as outputgallery_web:
value=gallery_files.value,
visible=False,
show_label=True,
columns=2,
columns=4,
)
with gr.Column(scale=4):
@@ -109,6 +109,7 @@ with gr.Blocks() as outputgallery_web:
value="",
interactive=True,
elem_classes="dropdown_no_container",
allow_custom_value=True,
)
with gr.Column(
scale=1,
@@ -203,6 +204,9 @@ with gr.Blocks() as outputgallery_web:
),
]
def on_image_columns_change(columns):
return gr.Gallery.update(columns=columns)
def on_select_subdir(subdir) -> list:
# evt.value is the subdirectory name
new_images = outputgallery_filenames(subdir)
@@ -364,53 +368,6 @@ with gr.Blocks() as outputgallery_web:
gr.update(),
)
# Unfortunately as of gradio 3.34.0 gr.update against Galleries doesn't
# support things set with .style, nor the elem_classes kwarg, so we have
# to directly set things up via JavaScript if we want the client to take
# notice of our changes to the number of columns after it decides to put
# them back to the original number when we change something
def js_set_columns_in_browser(timeout_length):
return f"""
(new_cols) => {{
setTimeout(() => {{
required_style = "auto ".repeat(new_cols).trim();
gallery = document.querySelector('#outputgallery_gallery .grid-container');
if (gallery) {{
gallery.style.gridTemplateColumns = required_style
}}
}}, {timeout_length});
return []; // prevents console error from gradio
}}
"""
# --- Wire handlers up to the actions
# Many actions reset the number of columns shown in the gallery on the
# browser end, so we have to set them back to what we think they should
# be after the initial action.
#
# None of the actions on this tab trigger inference, and we want the
# user to be able to do them whilst other tabs have ongoing inference
# running. Waiting in the queue behind inference jobs would mean the UI
# can't fully respond until the inference tasks complete,
# hence queue=False on all of these.
set_gallery_columns_immediate = dict(
fn=None,
inputs=[image_columns],
# gradio blanks the UI on Chrome on Linux on gallery select if
# I don't put an output here
outputs=[dev_null],
_js=js_set_columns_in_browser(0),
queue=False,
)
# setting columns after selecting a gallery item needs a real
# timeout length for the number of columns to actually be applied.
# Not really sure why, maybe something has to finish animating?
set_gallery_columns_delayed = dict(
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
)
# clearing images when we need to completely change what's in the
# gallery avoids current images being shown replacing piecemeal and
# prevents weirdness and errors if the user selects an image during the
@@ -422,32 +379,35 @@ with gr.Blocks() as outputgallery_web:
queue=False,
)
image_columns.change(**set_gallery_columns_immediate)
subdirectories.select(**clear_gallery).then(
on_select_subdir,
[subdirectories],
[gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)
open_subdir.click(
on_open_subdir, inputs=[subdirectories], queue=False
).then(**set_gallery_columns_immediate)
open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False)
refresh.click(**clear_gallery).then(
on_refresh,
[subdirectories],
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)
image_columns.change(
fn=on_image_columns_change,
inputs=[image_columns],
outputs=[gallery],
queue=False,
)
gallery.select(
on_select_image,
[gallery_files],
[outputgallery_filename, image_parameters],
queue=False,
).then(**set_gallery_columns_delayed)
)
outputgallery_filename.change(
on_outputgallery_filename_change,
@@ -476,7 +436,7 @@ with gr.Blocks() as outputgallery_web:
open_subdir,
],
queue=False,
).then(**set_gallery_columns_immediate)
)
# We should have been passed a list of components on other tabs that update
# when a new image has generated on that tab, so set things up so the user
@@ -488,4 +448,4 @@ with gr.Blocks() as outputgallery_web:
inputs=[subdirectories, subdirectory_paths, component],
outputs=[gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)

View File

@@ -6,7 +6,10 @@ from transformers import (
AutoModelForCausalLM,
)
from apps.stable_diffusion.web.ui.utils import available_devices
from shark.iree_utils.compile_utils import clean_device_info
from datetime import datetime as dt
import json
import sys
def user(message, history):
@@ -22,66 +25,81 @@ past_key_values = None
model_map = {
"llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
"llama2_13b": "meta-llama/Llama-2-13b-chat-hf",
"llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
"codegen": "Salesforce/codegen25-7b-multi",
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
"StableLM": "stabilityai/stablelm-tuned-alpha-3b",
}
# NOTE: Each `model_name` should have its own start message
start_message = {
"llama2": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
"llama2_7b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"StableLM": (
"<|SYSTEM|># StableLM Tuned (Alpha version)"
"\n- StableLM is a helpful and harmless open-source AI language model "
"developed by StabilityAI."
"\n- StableLM is excited to be able to help the user, but will refuse "
"to do anything that could be considered harmful to the user."
"\n- StableLM is more than just an information source, StableLM is also "
"able to write poetry, short stories, and make jokes."
"\n- StableLM will refuse to participate in anything that "
"could harm a human."
"llama2_13b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_70b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
"A chat between a curious user and an artificial intelligence "
"assistant. The assistant gives helpful, detailed, and "
"polite answers to the user's questions.\n"
),
"vicuna1p3": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
),
"codegen": "",
}
def create_prompt(model_name, history):
system_message = start_message[model_name]
def create_prompt(model_name, history, prompt_prefix):
system_message = ""
if prompt_prefix:
system_message = start_message[model_name]
if model_name in ["StableLM", "vicuna", "vicuna1p3", "llama2"]:
if "llama2" in model_name:
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
conversation = "".join(
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
)
if prompt_prefix:
msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
else:
msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
elif model_name in ["vicuna"]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)
msg = system_message + conversation
msg = msg.strip()
msg = system_message + conversation
msg = msg.strip()
return msg
@@ -90,76 +108,172 @@ def set_vicuna_model(model):
vicuna_model = model
def get_default_config():
import torch
from transformers import AutoTokenizer
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
from apps.language_models.src.model_wrappers.vicuna_model import (
CombinedModel,
)
from shark.shark_generate_model_config import GenerateConfigFile
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
c.split_into_layers()
model_vmfb_key = ""
# TODO: Make chat reusable for UI and API
def chat(curr_system_message, history, model, device, precision, cli=True):
def chat(
prompt_prefix,
history,
model,
device,
precision,
download_vmfb,
config_file,
cli=False,
progress=gr.Progress(),
):
global past_key_values
global model_vmfb_key
global vicuna_model
model_name, model_path = list(map(str.strip, model.split("=>")))
if model_name in ["vicuna", "vicuna1p3", "codegen", "llama2"]:
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
model_name, model_path = list(map(str.strip, model.split("=>")))
device, device_id = clean_device_info(device)
from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
from apps.stable_diffusion.src import args
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
model_vmfb_key = new_model_vmfb_key
max_toks = 128 if model_name == "codegen" else 512
# get iree flags that need to be overridden, from commandline args
_extra_args = []
# vulkan target triple
vulkan_target_triple = args.iree_vulkan_target_triple
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
get_vulkan_target_triple,
)
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
else:
print("unrecognized device")
if device == "vulkan":
vulkaninfo_list = get_all_vulkan_devices()
if vulkan_target_triple == "":
# We already have the device_id extracted via WebUI, so we directly use
# that to find the target triple.
vulkan_target_triple = get_vulkan_target_triple(
vulkaninfo_list[device_id]
)
_extra_args.append(
f"-iree-vulkan-target-triple={vulkan_target_triple}"
)
if "rdna" in vulkan_target_triple:
flags_to_add = [
"--iree-spirv-index-bits=64",
]
_extra_args = _extra_args + flags_to_add
max_toks = 128 if model_name == "codegen" else 512
vicuna_model = UnshardedVicuna(
if device_id is None:
id = 0
for device in vulkaninfo_list:
target_triple = get_vulkan_target_triple(
vulkaninfo_list[id]
)
if target_triple == vulkan_target_triple:
device_id = id
break
id += 1
assert (
device_id
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
print(f"Will use vulkan target triple : {vulkan_target_triple}")
elif "rocm" in device:
# add iree rocm flags
if args.iree_rocm_target_chip != "":
_extra_args.append(
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
)
print(f"extra args = {_extra_args}")
if model_name == "vicuna4":
vicuna_model = ShardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
compressed=True,
extra_args_cmd=_extra_args,
)
else:
# if config_file is None:
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
vulkan_target_triple=vulkan_target_triple,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=download_vmfb,
load_mlir_from_shark_tank=True,
extra_args_cmd=_extra_args,
device_id=device_id,
)
prompt = create_prompt(model_name, history)
for partial_text in vicuna_model.generate(prompt, cli=cli):
history[-1][1] = partial_text
yield history
if vicuna_model is None:
sys.exit("Unable to instantiate the model object, exiting.")
return history
# else Model is StableLM
global sharkModel
from apps.language_models.src.pipelines.stablelm_pipeline import (
SharkStableLM,
)
if sharkModel == 0:
# max_new_tokens=512
shark_slm = SharkStableLM(
model_name
) # pass elements from UI as required
# Construct the input message string for the model by concatenating the
# current system message and conversation history
if len(curr_system_message.split()) > 160:
print("clearing context")
prompt = create_prompt(model_name, history)
generate_kwargs = dict(prompt=prompt)
words_list = shark_slm.generate(**generate_kwargs)
prompt = create_prompt(model_name, history, prompt_prefix)
partial_text = ""
for new_text in words_list:
print(new_text)
partial_text += new_text
history[-1][1] = partial_text
# Yield an empty string to clean up the message textbox and the updated
# conversation history
yield history
return words_list
token_count = 0
total_time_ms = 0.001 # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, msg, exec_time in progress.tqdm(
vicuna_model.generate(prompt, cli=cli),
desc="generating response",
):
if msg is None:
if is_first:
prefill_time = exec_time
is_first = False
else:
total_time_ms += exec_time
token_count += 1
partial_text += text + " "
history[-1][1] = partial_text
yield history, f"Prefill: {prefill_time:.2f}"
elif "formatted" in msg:
history[-1][1] = text
tokens_per_sec = (token_count / total_time_ms) * 1000
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
else:
sys.exit(
"unexpected message from the vicuna generate call, exiting."
)
return history, ""
def llm_chat_api(InputData: dict):
@@ -195,17 +309,9 @@ def llm_chat_api(InputData: dict):
UnshardedVicuna,
)
device_id = None
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
else:
print("unrecognized device")
device, device_id = clean_device_info(device)
vicuna_model = UnshardedVicuna(
model_name,
@@ -213,6 +319,9 @@ def llm_chat_api(InputData: dict):
device=device,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=True,
load_mlir_from_shark_tank=True,
device_id=device_id,
)
# TODO: add role dict for different models
@@ -261,6 +370,13 @@ def llm_chat_api(InputData: dict):
}
def view_json_file(file_obj):
content = ""
with open(file_obj.name, "r") as fopen:
content = fopen.read()
return content
with gr.Blocks(title="Chatbot") as stablelm_chat:
with gr.Row():
model_choices = list(
@@ -270,13 +386,13 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
label="Select Model",
value=model_choices[0],
choices=model_choices,
allow_custom_value=True,
)
supported_devices = available_devices
enabled = len(supported_devices) > 0
# show cpu-task device first in list for chatbot
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
print(supported_devices)
device = gr.Dropdown(
label="Device",
value=supported_devices[0]
@@ -284,19 +400,43 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
allow_custom_value=True,
# multiselect=True,
)
precision = gr.Radio(
label="Precision",
value="fp16",
value="int4",
choices=[
"int4",
"int8",
"fp16",
"fp32",
],
visible=True,
visible=False,
)
chatbot = gr.Chatbot(height=500)
tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Column():
download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available",
value=True,
interactive=True,
)
prompt_prefix = gr.Checkbox(
label="Add System Prompt",
value=False,
interactive=True,
)
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(
label="Upload sharding configuration", visible=False
)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
chatbot = gr.Chatbot(elem_id="chatbot")
with gr.Row():
with gr.Column():
msg = gr.Textbox(
@@ -311,24 +451,47 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
submit = gr.Button("Submit", interactive=enabled)
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
system_msg = gr.Textbox(
start_message, label="System Message", interactive=False, visible=False
)
submit_event = msg.submit(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat,
inputs=[system_msg, chatbot, model, device, precision],
outputs=[chatbot],
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
submit_click_event = submit.click(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat,
inputs=[system_msg, chatbot, model, device, precision],
outputs=[chatbot],
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
stop.click(

View File

@@ -4,18 +4,19 @@ import time
import sys
import gradio as gr
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from math import ceil
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
scheduler_list_cpu_only,
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
@@ -26,10 +27,12 @@ from apps.stable_diffusion.src import (
utils,
save_output_img,
prompt_examples,
Image2ImagePipeline,
)
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
resampler_list,
)
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
@@ -46,12 +49,11 @@ def txt2img_inf(
width: int,
steps: int,
guidance_scale: float,
seed: int,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
@@ -62,6 +64,11 @@ def txt2img_inf(
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
use_hiresfix: bool,
hiresfix_height: int,
hiresfix_width: int,
hiresfix_strength: float,
resample_type: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -84,21 +91,17 @@ def txt2img_inf(
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
# .safetensor or .chkpt on the custom model path
if model_id in get_custom_model_files():
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = custom_model
args.hf_model_id = model_id
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
@@ -138,6 +141,11 @@ def txt2img_inf(
args.max_length = max_length
args.height = height
args.width = width
args.use_hiresfix = use_hiresfix
args.hiresfix_height = hiresfix_height
args.hiresfix_width = hiresfix_width
args.hiresfix_strength = hiresfix_strength
args.resample_type = resample_type
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platform = init_iree_metal_target_platform
@@ -178,8 +186,11 @@ def txt2img_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
text_output = ""
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None
for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
@@ -197,6 +208,81 @@ def txt2img_inf(
cpu_scheduling,
args.max_embeddings_multiples,
)
# TODO: allow user to save original image
# TODO: add option to let user keep both pipelines loaded, and unload
# either at will
# TODO: add custom step value slider
# TODO: add option to use secondary model for the img2img pass
if use_hiresfix is True:
new_config_obj = Config(
"img2img",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
1,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil="None",
ondemand=ondemand,
)
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(args.scheduler)
global_obj.set_sd_obj(
Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
1,
hiresfix_height,
hiresfix_width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(args.scheduler)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
out_imgs[0],
batch_size,
hiresfix_height,
hiresfix_width,
ceil(steps / hiresfix_strength),
hiresfix_strength,
guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil="None",
resample_type=resample_type,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
@@ -216,70 +302,6 @@ def txt2img_inf(
return generated_imgs, text_output, ""
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Text2Img Rest API.
def txt2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
res = txt2img_inf(
InputData["prompt"],
InputData["negative_prompt"],
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Convert Generator to Subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
@@ -299,32 +321,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
# janky fix for overflowing text
t2i_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
t2i_model_info = (
f"Custom Model Path: {t2i_model_info}"
)
t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
txt2img_custom_model = gr.Dropdown(
label=f"Models",
info=t2i_model_info,
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-1-base",
choices=["None"]
+ get_custom_model_files()
choices=get_custom_model_files()
+ predefined_models,
)
txt2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the dropdown "
"on the left and enter model ID here.",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL.",
lines=3,
allow_custom_value=True,
scale=2,
)
# janky fix for overflowing text
t2i_vae_info = (
@@ -340,6 +348,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
else "None",
choices=["None"]
+ get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Column(scale=1, min_width=170):
txt2img_png_info_img = gr.Image(
@@ -376,6 +386,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -387,6 +398,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -394,8 +410,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Scheduler",
value=args.scheduler,
choices=scheduler_list,
allow_custom_value=True,
)
with gr.Group():
with gr.Column():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
@@ -480,26 +497,54 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Accordion(label="Hires Fix Options", open=False):
with gr.Group():
with gr.Row():
use_hiresfix = gr.Checkbox(
value=args.use_hiresfix,
label="Use Hires Fix",
interactive=True,
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=resampler_list,
label="Resample Type",
allow_custom_value=False,
)
hiresfix_height = gr.Slider(
384,
768,
value=args.hiresfix_height,
step=8,
label="Hires Fix Height",
)
hiresfix_width = gr.Slider(
384,
768,
value=args.hiresfix_width,
step=8,
label="Hires Fix Width",
)
hiresfix_strength = gr.Slider(
0,
1,
value=args.hiresfix_strength,
step=0.01,
label="Hires Fix Denoising Strength",
)
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
@@ -518,13 +563,26 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
value=f"{t2i_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
)
txt2img_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img")
txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -549,7 +607,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
batch_size,
scheduler,
txt2img_custom_model,
txt2img_hf_model_id,
custom_vae,
precision,
device,
@@ -560,6 +617,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
lora_hf_id,
ondemand,
repeatable_seeds,
use_hiresfix,
hiresfix_height,
hiresfix_width,
hiresfix_strength,
resample_type,
],
outputs=[txt2img_gallery, std_output, txt2img_status],
show_progress="minimal" if args.progress_bar else "none",
@@ -594,7 +656,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
width,
height,
txt2img_custom_model,
txt2img_hf_model_id,
lora_weights,
lora_hf_id,
custom_vae,
@@ -610,9 +671,35 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
width,
height,
txt2img_custom_model,
txt2img_hf_model_id,
lora_weights,
lora_hf_id,
custom_vae,
],
)
# SharkEulerDiscrete doesn't work with img2img which hires_fix uses
def set_compatible_schedulers(hires_fix_selected):
if hires_fix_selected:
return gr.Dropdown.update(
choices=scheduler_list_cpu_only,
value="DEISMultistep",
)
else:
return gr.Dropdown.update(
choices=scheduler_list,
value="SharkEulerDiscrete",
)
use_hiresfix.change(
fn=set_compatible_schedulers,
inputs=[use_hiresfix],
outputs=[scheduler],
queue=False,
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -3,9 +3,7 @@ import torch
import time
import gradio as gr
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -15,6 +13,7 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_upscaler_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
@@ -42,12 +41,11 @@ def upscaler_inf(
steps: int,
noise_level: int,
guidance_scale: float,
seed: int,
seed: str,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
model_id: str,
custom_vae: str,
precision: str,
device: str,
@@ -85,21 +83,17 @@ def upscaler_inf(
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
else:
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = get_custom_model_pathfile(custom_model)
# .safetensor or .chkpt on the custom model path
if model_id in get_custom_model_files(custom_checkpoint_type="upscaler"):
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = custom_model
args.hf_model_id = model_id
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
@@ -177,8 +171,11 @@ def upscaler_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
extra_info = {"NOISE LEVEL": noise_level}
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None
for current_batch in range(batch_count):
low_res_img = image
@@ -249,83 +246,6 @@ def upscaler_inf(
yield generated_imgs, text_output, ""
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(images):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
if args.output_img_format.lower() == "png":
image.save(output_bytes, format="PNG")
elif args.output_img_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG")
else:
raise HTTPException(
status_code=500, detail="Invalid image format"
)
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
# Upscaler Rest API.
def upscaler_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = upscaler_inf(
InputData["prompt"],
InputData["negative_prompt"],
init_image,
InputData["height"],
InputData["width"],
InputData["steps"],
InputData["noise_level"],
InputData["cfg_scale"],
InputData["seed"],
batch_count=1,
batch_size=1,
scheduler="EulerDiscrete",
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
custom_vae="None",
precision="fp16",
device=available_devices[0],
max_length=64,
save_metadata_to_json=False,
save_metadata_to_png=False,
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
"info": res[1],
}
with gr.Blocks(title="Upscaler") as upscaler_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
@@ -343,36 +263,22 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
upscaler_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
upscaler_model_info = (
f"Custom Model Path: {upscaler_model_info}"
f"Custom Model Path: {str(get_custom_model_path())}"
)
upscaler_custom_model = gr.Dropdown(
label=f"Models",
info=upscaler_model_info,
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-x4-upscaler",
choices=["None"]
+ get_custom_model_files(
choices=get_custom_model_files(
custom_checkpoint_type="upscaler"
)
+ predefined_upscaler_models,
)
upscaler_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3, "
"https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
allow_custom_value=True,
scale=2,
)
# janky fix for overflowing text
upscaler_vae_info = (
@@ -387,6 +293,8 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
scale=1,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -422,6 +330,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -433,6 +342,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -440,6 +354,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Scheduler",
value="DDIM",
choices=scheduler_list_cpu_only,
allow_custom_value=True,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -534,25 +449,18 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
visible=False,
)
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -564,14 +472,26 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at "
value=f"{upscaler_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
lines=2,
elem_id="std_output",
show_label=False,
)
upscaler_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -595,7 +515,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
batch_size,
scheduler,
upscaler_custom_model,
upscaler_hf_model_id,
custom_vae,
precision,
device,
@@ -625,3 +544,10 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -1,10 +1,16 @@
import os
import sys
from apps.stable_diffusion.src import get_available_devices
import glob
import math
import json
import safetensors
from pathlib import Path
from apps.stable_diffusion.src import args
from dataclasses import dataclass
from enum import IntEnum
from apps.stable_diffusion.src import get_available_devices
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
@@ -25,7 +31,16 @@ class Config:
device: str
use_lora: str
use_stencil: str
ondemand: str
ondemand: str # should this be expecting a bool instead?
class HSLHue(IntEnum):
RED = 0
YELLOW = 60
GREEN = 120
CYAN = 180
BLUE = 240
MAGENTA = 300
custom_model_filetypes = (
@@ -161,6 +176,69 @@ def get_custom_vae_or_lora_weights(weights, hf_id, model):
return use_weight
def hsl_color(alpha: float, start, end):
b = (end - start) * (alpha if alpha > 0 else 0)
result = b + start
# Return a CSS HSL string
return f"hsl({math.floor(result)}, 80%, 35%)"
def get_lora_metadata(lora_filename):
# get the metadata from the file
filename = get_custom_model_pathfile(lora_filename, "lora")
with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
metadata = f.metadata()
# guard clause for if there isn't any metadata
if not metadata:
return None
# metadata is a dictionary of strings, the values of the keys we're
# interested in are actually json, and need to be loaded as such
tag_frequencies = json.loads(metadata.get("ss_tag_frequency", str("{}")))
dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", str("{}")))
tag_dirs = [dir for dir in tag_frequencies.keys()]
# gather the tag frequency information for all the datasets trained
all_frequencies = {}
for dataset in tag_dirs:
frequencies = sorted(
[entry for entry in tag_frequencies[dataset].items()],
reverse=True,
key=lambda x: x[1],
)
# get a figure for the total number of images processed for this dataset
# either then number actually listed or in its dataset_dir entry or
# the highest frequency's number if that doesn't exist
img_count = dataset_dirs.get(dir, {}).get(
"img_count", frequencies[0][1]
)
# add the dataset frequencies to the overall frequencies replacing the
# frequency counts on the tags with a percentage/ratio
all_frequencies.update(
[(entry[0], entry[1] / img_count) for entry in frequencies]
)
trained_model_id = " ".join(
[
metadata.get("ss_sd_model_hash", ""),
metadata.get("ss_sd_model_name", ""),
metadata.get("ss_base_model_version", ""),
]
).strip()
# return the topmost <count> of all frequencies in all datasets
return {
"model": trained_model_id,
"frequencies": sorted(
all_frequencies.items(), reverse=True, key=lambda x: x[1]
),
}
def cancel_sd():
# Try catch it, as gc can delete global_obj.sd_obj while switching model
try:
@@ -170,4 +248,5 @@ def cancel_sd():
nodlogo_loc = resource_path("logos/nod-logo.png")
nodicon_loc = resource_path("logos/nod-icon.png")
available_devices = get_available_devices()

View File

@@ -0,0 +1,105 @@
import os
import sys
import webview
import webview.util
import socket
from contextlib import closing
from multiprocessing import Process
from apps.stable_diffusion.src import args
def webview2_installed():
if sys.platform != "win32":
return False
# On windows we want to ensure we have MS webview2 available so we don't fall back
# to MSHTML (aka ye olde Internet Explorer) which is deprecated by pywebview, and
# apparently causes SHARK not to load in properly.
# Checking these registry entries is how Microsoft says to detect a webview2 installation:
# https://learn.microsoft.com/en-us/microsoft-edge/webview2/concepts/distribution
import winreg
path = r"SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}"
# only way can find if a registry entry even exists is to try and open it
try:
# check for an all user install
with winreg.OpenKey(
winreg.HKEY_LOCAL_MACHINE,
path,
0,
winreg.KEY_QUERY_VALUE | winreg.KEY_WOW64_64KEY,
) as registry_key:
value, type = winreg.QueryValueEx(registry_key, "pv")
# if it didn't exist, we want to continue on...
except WindowsError:
try:
# ...to check for a current user install
with winreg.OpenKey(
winreg.HKEY_CURRENT_USER,
path,
0,
winreg.KEY_QUERY_VALUE | winreg.KEY_WOW64_64KEY,
) as registry_key:
value, type = winreg.QueryValueEx(registry_key, "pv")
except WindowsError:
value = None
finally:
return (value is not None) and value != "" and value != "0.0.0.0"
def window(address):
from tkinter import Tk
window = Tk()
# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False, storage_path=os.getcwd())
def usable_port():
# Make sure we can actually use the port given in args.server_port. If
# not ask the OS for a port and return that as our port to use.
port = args.server_port
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
try:
sock.bind(("0.0.0.0", port))
except OSError:
with closing(
socket.socket(socket.AF_INET, socket.SOCK_STREAM)
) as sock:
sock.bind(("0.0.0.0", 0))
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return sock.getsockname()[1]
return port
def launch(port):
# setup to launch as an app if app mode has been requested and we're able
# to do it, answering whether we succeeded.
if args.ui == "app" and (sys.platform != "win32" or webview2_installed()):
try:
t = Process(target=window, args=[f"http://localhost:{port}"])
t.start()
return True
except webview.util.WebViewException:
return False
else:
return False

View File

@@ -149,7 +149,6 @@ def import_png_metadata(
width,
height,
custom_model,
hf_model_id,
custom_lora,
hf_lora_id,
custom_vae,
@@ -175,10 +174,8 @@ def import_png_metadata(
if "Model" in metadata and png_custom_model:
custom_model = png_custom_model
hf_model_id = ""
if "Model" in metadata and png_hf_model_id:
custom_model = "None"
hf_model_id = png_hf_model_id
elif "Model" in metadata and png_hf_model_id:
custom_model = png_hf_model_id
if "LoRA" in metadata and lora_custom_model:
custom_lora = lora_custom_model
@@ -217,7 +214,6 @@ def import_png_metadata(
width,
height,
custom_model,
hf_model_id,
custom_lora,
hf_lora_id,
custom_vae,

View File

@@ -5,11 +5,25 @@ from time import time
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
def config_gradio_tmp_imgs_folder():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
def clear_tmp_mlir():
cleanup_start = time()
print(
"Clearing .mlir temporary files from a prior run. This may take some time..."
)
mlir_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.endswith(".mlir")
]
for filename in mlir_files:
os.remove(shark_tmp + filename)
print(
f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds."
)
def clear_tmp_imgs():
# tell gradio to use a directory under shark_tmp for its temporary
# image files unless somewhere else has been set
if "GRADIO_TEMP_DIR" not in os.environ:
@@ -52,3 +66,12 @@ def config_gradio_tmp_imgs_folder():
)
else:
print("No temporary images files to clear.")
def config_tmp():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
clear_tmp_mlir()
clear_tmp_imgs()

View File

@@ -129,12 +129,12 @@ pytest_benchmark_param = pytest.mark.parametrize(
pytest.param(True, "cpu", marks=pytest.mark.skip),
pytest.param(
False,
"gpu",
"cuda",
marks=pytest.mark.skipif(
check_device_drivers("gpu"), reason="nvidia-smi not found"
check_device_drivers("cuda"), reason="nvidia-smi not found"
),
),
pytest.param(True, "gpu", marks=pytest.mark.skip),
pytest.param(True, "cuda", marks=pytest.mark.skip),
pytest.param(
False,
"vulkan",

View File

@@ -24,13 +24,13 @@ def get_image(url, local_filename):
shutil.copyfileobj(res.raw, f)
def compare_images(new_filename, golden_filename):
def compare_images(new_filename, golden_filename, upload=False):
new = np.array(Image.open(new_filename)) / 255.0
golden = np.array(Image.open(golden_filename)) / 255.0
diff = np.abs(new - golden)
mean = np.mean(diff)
if mean > 0.1:
if os.name != "nt":
if os.name != "nt" and upload == True:
subprocess.run(
[
"gsutil",
@@ -39,7 +39,7 @@ def compare_images(new_filename, golden_filename):
"gs://shark_tank/testdata/builder/",
]
)
raise SystemExit("new and golden not close")
raise AssertionError("new and golden not close")
else:
print("SUCCESS")

View File

@@ -1,5 +1,6 @@
#!/bin/bash
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
IMPORTER=1 BENCHMARK=1 NO_BREVITAS=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python build_tools/stable_diffusion_testing.py --gen
python tank/generate_sharktank.py

View File

@@ -63,7 +63,14 @@ def get_inpaint_inputs():
open("./test_images/inputs/mask.png", "wb").write(mask.content)
def test_loop(device="vulkan", beta=False, extra_flags=[]):
def test_loop(
device="vulkan",
beta=False,
extra_flags=[],
upload_bool=True,
exit_on_fail=True,
do_gen=False,
):
# Get golden values from tank
shutil.rmtree("./test_images", ignore_errors=True)
model_metrics = []
@@ -71,7 +78,10 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
os.mkdir("./test_images/golden")
get_inpaint_inputs()
hf_model_names = model_config_dicts[0].values()
tuned_options = ["--no-use_tuned", "--use_tuned"]
tuned_options = [
"--no-use_tuned",
"--use_tuned",
]
import_options = ["--import_mlir", "--no-import_mlir"]
prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
@@ -81,6 +91,8 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
if beta:
extra_flags.append("--beta_models=True")
extra_flags.append("--no-progress_bar")
if do_gen:
extra_flags.append("--import_debug")
to_skip = [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
@@ -103,6 +115,8 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
and use_tune == tuned_options[1]
):
continue
elif use_tune == tuned_options[1]:
continue
command = (
[
executable, # executable is the python from the venv used to run this
@@ -181,7 +195,14 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
"./test_images/golden/" + model_name + "/*.png"
)
golden_file = glob(golden_path)[0]
compare_images(test_file, golden_file)
try:
compare_images(
test_file, golden_file, upload=upload_bool
)
except AssertionError as e:
print(e)
if exit_on_fail == True:
raise
else:
print(command)
print("failed to generate image for this configuration")
@@ -200,6 +221,9 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
extra_flags.remove(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
if do_gen:
prepare_artifacts()
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
f.write(header)
@@ -218,15 +242,49 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
f.write(";".join(output) + "\n")
def prepare_artifacts():
gen_path = os.path.join(os.getcwd(), "gen_shark_tank")
if not os.path.isdir(gen_path):
os.mkdir(gen_path)
for dirname in os.listdir(os.getcwd()):
for modelname in ["clip", "unet", "vae"]:
if modelname in dirname and "vmfb" not in dirname:
if not os.path.isdir(os.path.join(gen_path, dirname)):
shutil.move(os.path.join(os.getcwd(), dirname), gen_path)
print(f"Moved dir: {dirname} to {gen_path}.")
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--device", default="vulkan")
parser.add_argument(
"-b", "--beta", action=argparse.BooleanOptionalAction, default=False
)
parser.add_argument("-e", "--extra_args", type=str, default=None)
parser.add_argument(
"-u", "--upload", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
"-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
"-g", "--gen", action=argparse.BooleanOptionalAction, default=False
)
if __name__ == "__main__":
args = parser.parse_args()
print(args)
test_loop(args.device, args.beta, [])
extra_args = []
if args.extra_args:
for arg in args.extra_args.split(","):
extra_args.append(arg)
test_loop(
args.device,
args.beta,
extra_args,
args.upload,
args.exit_on_fail,
args.gen,
)
if args.gen:
prepare_artifacts()

View File

@@ -0,0 +1,14 @@
import os
from sys import executable
import subprocess
from apps.language_models.scripts import vicuna
def test_loop():
precisions = ["fp16", "int8", "int4"]
devices = ["cpu"]
for precision in precisions:
for device in devices:
model = vicuna.UnshardedVicuna(device=device, precision=precision)
model.compile()
del model

View File

@@ -27,7 +27,7 @@ include(FetchContent)
FetchContent_Declare(
iree
GIT_REPOSITORY https://github.com/nod-ai/shark-runtime.git
GIT_REPOSITORY https://github.com/nod-ai/srt.git
GIT_TAG shark
GIT_SUBMODULES_RECURSE OFF
GIT_SHALLOW OFF

View File

@@ -40,7 +40,7 @@ cmake --build build/
*Prepare the model*
```bash
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux resnet50_tf.mlir -o resnet50_tf.vmfb
```
*Prepare the input*
@@ -65,18 +65,18 @@ A tool for benchmarking other models is built and can be invoked with a command
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
```bash
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux stable_diff_tf.mlir -o stable_diff_tf.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
```
VAE and Autoencoder are also available
```bash
# VAE
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux vae.mlir -o vae.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
# CLIP Autoencoder
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux clip_autoencoder.mlir -o clip_autoencoder.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
```

View File

@@ -55,7 +55,7 @@ The command line for compilation will start something like this, where the `-` n
The `-o output_filename.vmfb` flag can be used to specify the location to save the compiled vmfb. Note that a dump of the
dispatches that can be compiled + run in isolation can be generated by adding `--iree-hal-dump-executable-benchmarks-to=/some/directory`. Say, if they are in the `benchmarks` directory, the following compile/run commands would work for Vulkan on RDNA3.
```
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
iree-benchmark-module --module=benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb --function=forward --device=vulkan
```
@@ -63,8 +63,8 @@ Where `${NUM}` is the dispatch number that you want to benchmark/profile in isol
### Enabling Tracy for Vulkan profiling
To begin profiling with Tracy, a build of IREE runtime with tracing enabled is needed. SHARK-Runtime builds an
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SHARK-Runtime/releases)), however this is only available for Linux. For Windows, tracing can be enabled by enabling a CMake flag.
To begin profiling with Tracy, a build of IREE runtime with tracing enabled is needed. SHARK-Runtime (SRT) builds an
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SRT/releases)), however this is only available for Linux. For Windows, tracing can be enabled by enabling a CMake flag.
```
$env:IREE_ENABLE_RUNTIME_TRACING="ON"
```

140
docs/shark_sd_koboldcpp.md Normal file
View File

@@ -0,0 +1,140 @@
# Overview
In [1.47.2](https://github.com/LostRuins/koboldcpp/releases/tag/v1.47.2) [Koboldcpp](https://github.com/LostRuins/koboldcpp) added AUTOMATIC1111 integration for image generation. Since SHARK implements a small subset of the A1111 REST api, you can also use SHARK for this. This document gives a starting point for how to get this working.
## In Action
![preview](https://user-images.githubusercontent.com/121311569/280557602-bb97bad0-fdf5-4922-a2cc-4f327f2760db.jpg)
## Memory considerations
Since both Koboldcpp and SHARK will use VRAM on your graphic card(s) running both at the same time using the same card will impose extra limitations on the model size you can fully offload to the video card in Koboldcpp. For me, on a RX 7900 XTX on Windows with 24 GiB of VRAM, the limit was about a 13 Billion parameter model with Q5_K_M quantisation.
## Performance Considerations
When using SHARK for image generation, especially with Koboldcpp, you need to be aware that it is currently designed to pay a large upfront cost in time compiling and tuning the model you select, to get an optimal individual image generation time. You need to be the judge as to whether this trade-off is going to be worth it for your OS and hardware combination.
It means that the first time you run a particular Stable Diffusion model for a particular combination of image size, LoRA, and VAE, SHARK will spend *many minutes* - even on a beefy machaine with very fast graphics card with lots of memory - building that model combination just so it can save it to disk. It may even have to go away and download the model if it doesn't already have it locally. Once it has done its build of a model combination for your hardware once, it shouldn't need to do it again until you upgrade to a newer SHARK version, install different drivers or change your graphics hardware. It will just upload the files it generated the first time to your graphics card and proceed from there.
This does mean however, that on a brand new fresh install of SHARK that has not generated any images on a model you haven't selected before, the first image Koboldcpp requests may look like it is *never* going finish and that the whole process has broken. Be forewarned, make yourself a cup of coffee, and expect a lot of messages about compilation and tuning from SHARK in the terminal you ran it from.
## Setup SHARK and prerequisites:
* Make sure you have suitable drivers for your graphics card installed. See the prerequisties section of the [README](https://github.com/nod-ai/SHARK#readme).
* Download the latest SHARK studio .exe from [here](https://github.com/nod-ai/SHARK/releases) or follow the instructions in the [README](https://github.com/nod-ai/SHARK#readme) for an advanced, Linux or Mac install.
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK default of `8080`, also include both the `--api_accept_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See the if you want to run SHARK on a different port)
```powershell
## Run the .exe in API mode, with CORS support, on the A1111 endpoint port:
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_accept_origin="*" --server_port=7860
## Run trom the base directory of a source clone of SHARK on Windows:
.\setup_venv.ps1
python .\apps\stable_diffusion\web\index.py --api --api_accept_origin="*" --server_port=7860
## Run a the base directory of a source clone of SHARK on Linux:
./setup_venv.sh
source shark.venv/bin/activate
python ./apps/stable_diffusion/web/index.py --api --api_accept_origin="*" --server_port=7860
## An example giving improved performance on AMD cards using vulkan, that runs on the same port as A1111
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860
## Since the api respects most applicable SHARK command line arguments for options not specified,
## or currently unimplemented by API, there might be some you want to set, as listed in `--help`
.\node_ai_shark_studio_20320901_2525.exe --help
## For instance, the example above, but with a a custom VAE specified
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
## An example with multiple specific CORS origins
python apps/stable_diffusion/web/index.py --api --api_accept_origin="koboldcpp.example.com:7001" --api_accept_origin="koboldcpp.example.com:7002" --server_port=7860
```
SHARK should start in server mode, and you should see something like this:
![SHARK API startup](https://user-images.githubusercontent.com/121311569/280556294-c3f7fc1a-c8e2-467d-afe6-365638d6823a.png)
* Note: When running in api mode with `--api`, the .exe will not function as a webUI. Thus, the address or port shown in the terminal output will only be useful for API requests.
## Configure Koboldcpp for local image generation:
* Get the latest [Koboldcpp](https://github.com/LostRuins/koboldcpp/releases) if you don't already have it. If you have a recent AMD card that has ROCm HIP [support for Windows](https://rocmdocs.amd.com/en/latest/release/windows_support.html#windows-supported-gpus) or [support for Linux](https://rocmdocs.amd.com/en/latest/release/gpu_os_support.html#linux-supported-gpus), you'll likely prefer [YellowRosecx's ROCm fork](https://github.com/YellowRoseCx/koboldcpp-rocm).
* Start Koboldcpp in another terminal/Powershell and setup your model configuration. Refer to the [Koboldcpp README](https://github.com/YellowRoseCx/koboldcpp-rocm) for more details on how to do this if this is your first time using Koboldcpp.
* Once the main UI has loaded into your browser click the settings button, go to the advanced tab, and then choose *Local A1111* from the generate images dropdown:
![Settings button location](https://user-images.githubusercontent.com/121311569/280556246-10692d79-e89f-4fdf-87ba-82f3d78ed49d.png)
![Advanced Settings with 'Local A1111' location](https://user-images.githubusercontent.com/121311569/280556234-6ebc8ba7-1469-442a-93a7-5626a094ddf1.png)
*if you get an error here, see the next section [below](#connecting-to-shark-on-a-different-address-or-port)*
* A list of Stable Diffusion models available to your SHARK instance should now be listed in the box below *generate images*. The default value will usually be set to `stabilityai/stable-diffusion-2-1-base`. Choose the model you want to use for image generation from the list (but see [performance considerations](#performance-considerations)).
* You should now be ready to generate images, either by clicking the 'Add Img' button above the text entry box:
![Add Image Button](https://user-images.githubusercontent.com/121311569/280556161-846c7883-4a83-4458-a56a-bd9f93ca354c.png)
...or by selecting the 'Autogenerate' option in the settings:
![Setting the autogenerate images option](https://user-images.githubusercontent.com/121311569/280556230-ae221a46-ba68-499b-a519-c8f290bbbeae.png)
*I often find that even if I have selected autogenerate I have to do an 'add img' to get things started off*
* There is one final piece of image generation configuration within Koboldcpp you might want to do. This is also in the generate images section of advanced settings. Here there is, not very obviously, a 'style' button:
![Selecting the 'styles' button](https://user-images.githubusercontent.com/121311569/280556694-55cd1c55-a059-4b54-9293-63d66a32368e.png)
This will bring up a dialog box where you can enter a short text that will sent as a prefix to the Prompt sent to SHARK:
![Entering extra image styles](https://user-images.githubusercontent.com/121311569/280556172-4aab9794-7a77-46d7-bdda-43df570ad19a.png)
## Connecting to SHARK on a different address or port
If you didn't set the port to `--server_port=7860` when starting SHARK, or you are running it on different machine on your network than you are running Koboldcpp, or to where you are running the koboldcpp's kdlite client frontend, then you very likely got the following error:
![Can't find the A1111 endpoint error](https://user-images.githubusercontent.com/121311569/280555857-601f53dc-35e9-4027-9180-baa61d2393ba.png)
As long as SHARK is running correctly, this means you need to set the url and port to the correct values in Koboldcpp. For instance. to set the port that Koboldcpp looks for an image generator to SHARK's default port of 8080:
* Select the cog icon the Generate Images section of Advanced settings:
![Selecting the endpoint cog](https://user-images.githubusercontent.com/121311569/280555866-4287ecc5-f29f-4c03-8f5a-abeaf31b0442.png)
* Then edit the port number at the end of the url in the 'A1111 Endpoint Selection' dialog box to read 8080:
![Changing the endpoint port](https://user-images.githubusercontent.com/121311569/280556170-f8848b7b-6fc9-4cf7-80eb-5c312f332fd9.png)
* Similarly, when running SHARK on a different machine you will need to change host part of the endpoint url to the hostname or ip address where SHARK is running, similarly:
![Changing the endpoint hostname](https://user-images.githubusercontent.com/121311569/280556167-c6541dea-0f85-417a-b661-fdf4dc40d05f.png)
## Examples
Here's how Koboldcpp shows an image being requested:
![An image being generated]((https://user-images.githubusercontent.com/121311569/280556210-bb1c9efd-79ac-478e-b726-b25b82ef2186.png)
The generated image in context in story mode:
![A generated image](https://user-images.githubusercontent.com/121311569/280556179-4e9f3752-f349-4cba-bc6a-f85f8dc79b10.jpg)
And the same image when clicked on:
![A selected image](https://user-images.githubusercontent.com/121311569/280556216-2ca4c0a4-3889-4ef5-8a09-30084fb34081.jpg)
## Where to find the images in SHARK
Even though Koboldcpp requests images at a size of 512x512, it resizes then to 256x256, converts them to `.jpeg`, and only shows them at 200x200 in the main text window. It does this so it can save them compactly embedded in your story as a `data://` uri.
However the images at the original size are saved by SHARK in its `output_dir` which is usually a folder named for the current date. inside `generated_imgs` folder in the SHARK installation directory.
You can browse these, either using the Output Gallery tab from within the SHARK web ui:
![SHARK web ui output gallery tab](https://user-images.githubusercontent.com/121311569/280556582-9303ca85-2594-4a8c-97a2-fbd72337980b.jpg)
...or by browsing to the `output_dir` in your operating system's file manager:
![SHARK output directory subfolder in Windows File Explorer](https://user-images.githubusercontent.com/121311569/280556297-66173030-2324-415c-a236-ef3fcd73e6ed.jpg)

View File

@@ -1,192 +0,0 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required(VERSION 3.17)
project(sharkbackend LANGUAGES C CXX)
#
# Options
#
option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON)
option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo")
set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo")
set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo")
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()
#
# Dependencies
#
# FetchContent requires us to include the transitive closure of all
# repos that we depend on so that we can override the tags.
#
include(FetchContent)
FetchContent_Declare(
repo-common
GIT_REPOSITORY https://github.com/triton-inference-server/common.git
GIT_TAG ${TRITON_COMMON_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_Declare(
repo-core
GIT_REPOSITORY https://github.com/triton-inference-server/core.git
GIT_TAG ${TRITON_CORE_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_Declare(
repo-backend
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
GIT_TAG ${TRITON_BACKEND_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_MakeAvailable(repo-common repo-core repo-backend)
#
# The backend must be built into a shared library. Use an ldscript to
# hide all symbols except for the TRITONBACKEND API.
#
configure_file(src/libtriton_dshark.ldscript libtriton_dshark.ldscript COPYONLY)
add_library(
triton-dshark-backend SHARED
src/dshark.cc
#src/dshark_driver_module.c
)
add_library(
SharkBackend::triton-dshark-backend ALIAS triton-dshark-backend
)
target_include_directories(
triton-dshark-backend
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/src
)
list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/lib/cmake/mlir")
add_subdirectory(thirdparty/shark-runtime EXCLUDE_FROM_ALL)
target_link_libraries(triton-dshark-backend PRIVATE iree_base_base
iree_hal_hal
iree_hal_cuda_cuda
iree_hal_cuda_registration_registration
iree_hal_vmvx_registration_registration
iree_hal_dylib_registration_registration
iree_modules_hal_hal
iree_vm_vm
iree_vm_bytecode_module
iree_hal_local_loaders_system_library_loader
iree_hal_local_loaders_vmvx_module_loader
)
target_compile_features(triton-dshark-backend PRIVATE cxx_std_11)
target_link_libraries(
triton-dshark-backend
PRIVATE
triton-core-serverapi # from repo-core
triton-core-backendapi # from repo-core
triton-core-serverstub # from repo-core
triton-backend-utils # from repo-backend
)
if(WIN32)
set_target_properties(
triton-dshark-backend PROPERTIES
POSITION_INDEPENDENT_CODE ON
OUTPUT_NAME triton_dshark
)
else()
set_target_properties(
triton-dshark-backend PROPERTIES
POSITION_INDEPENDENT_CODE ON
OUTPUT_NAME triton_dshark
LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_dshark.ldscript
LINK_FLAGS "-Wl,--version-script libtriton_dshark.ldscript"
)
endif()
#
# Install
#
include(GNUInstallDirs)
set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/SharkBackend)
install(
TARGETS
triton-dshark-backend
EXPORT
triton-dshark-backend-targets
LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
)
install(
EXPORT
triton-dshark-backend-targets
FILE
SharkBackendTargets.cmake
NAMESPACE
SharkBackend::
DESTINATION
${INSTALL_CONFIGDIR}
)
include(CMakePackageConfigHelpers)
configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/cmake/SharkBackendConfig.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
INSTALL_DESTINATION ${INSTALL_CONFIGDIR}
)
install(
FILES
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
DESTINATION ${INSTALL_CONFIGDIR}
)
#
# Export from build tree
#
export(
EXPORT triton-dshark-backend-targets
FILE ${CMAKE_CURRENT_BINARY_DIR}/SharkBackendTargets.cmake
NAMESPACE SharkBackend::
)
export(PACKAGE SharkBackend)

View File

@@ -1,100 +0,0 @@
# SHARK Triton Backend
The triton backend for shark.
# Build
Install SHARK
```
git clone https://github.com/nod-ai/SHARK.git
# skip above step if dshark is already installed
cd SHARK/inference
```
install dependancies
```
apt-get install patchelf rapidjson-dev python3-dev
git submodule update --init
```
update the submodules of iree
```
cd thirdparty/shark-runtime
git submodule update --init
```
Next, make the backend and install it
```
cd ../..
mkdir build && cd build
cmake -DTRITON_ENABLE_GPU=ON \
-DIREE_HAL_DRIVER_CUDA=ON \
-DIREE_TARGET_BACKEND_CUDA=ON \
-DMLIR_ENABLE_CUDA_RUNNER=ON \
-DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install \
-DTRITON_BACKEND_REPO_TAG=r22.02 \
-DTRITON_CORE_REPO_TAG=r22.02 \
-DTRITON_COMMON_REPO_TAG=r22.02 ..
make install
```
# Incorporating into Triton
There are much more in depth explenations for the following steps in triton's documentation:
https://github.com/triton-inference-server/server/blob/main/docs/compose.md#triton-with-unsupported-and-custom-backends
There should be a file at /build/install/backends/dshark/libtriton_dshark.so. You will need to copy it into your triton server image.
More documentation is in the link above, but to create the docker image, you need to run the compose.py command in the triton-backend server repo
To first build your image, clone the tritonserver repo.
```
git clone https://github.com/triton-inference-server/server.git
```
then run `compose.py` to build a docker compose file
```
cd server
python3 compose.py --repoagent checksum --dry-run
```
Because dshark is a third party backend, you will need to manually modify the `Dockerfile.compose` to include the dshark backend. To do this, in the Dockerfile.compose file produced, copy this line.
the dshark backend will be located in the build folder from earlier under `/build/install/backends`
```
COPY /path/to/build/install/backends/dshark /opt/tritonserver/backends/dshark
```
Next run
```
docker build -t tritonserver_custom -f Dockerfile.compose .
docker run -it --gpus=1 --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
```
where `path/to/model_repos` is where you are storing the models you want to run
if your not using gpus, omit `--gpus=1`
```
docker run -it --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
```
# Setting up a model
to include a model in your backend, add a directory with your model name to your model repository directory. examples of models can be seen here: https://github.com/triton-inference-server/backend/tree/main/examples/model_repos/minimal_models
make sure to adjust the input correctly in the config.pbtxt file, and save a vmfb file under 1/model.vmfb
# CUDA
if you're having issues with cuda, make sure your correct drivers are installed, and that `nvidia-smi` works, and also make sure that the nvcc compiler is on the path.

Some files were not shown because too many files have changed in this diff Show More