Compare commits

...

146 Commits

Author SHA1 Message Date
Ean Garvey
0010804517 Merge branch 'main' into ean-pytest-bench 2023-08-02 14:32:46 -05:00
Stefan Kapusniak
6bb329c4af Unsharded Vicuna: Fix Memory Error compiling mlir for lmsys/vicuna-7b-v1.3 fp16 with 64 GiB (#1702) 2023-08-01 06:07:56 -07:00
Vivek Khandelwal
98fb6c52df Expand pipelines to fix streaming of tokens 2023-07-31 22:11:01 +05:30
Stefan Kapusniak
206c1b70f4 UI/Web: Reorder tabs to separate SD and LLM (#1701)
Shuffle the tabs around so that:

* All the SD tabs are together
* All the LLM tabs are together
* All the experimental tabs are together
2023-07-29 22:25:30 -04:00
PhaneeshB
cdb037ee54 use shark_args for vulkan debug utils flag 2023-07-30 07:54:26 +05:30
PhaneeshB
ce2fd84538 fix cpu device name for SharkStudio 2023-07-30 07:54:26 +05:30
PhaneeshB
4684afad34 update upscalar example 2023-07-28 21:06:28 +05:30
PhaneeshB
8d65456b7a Move vulkan runtime flags to shark_args 2023-07-28 21:06:28 +05:30
PhaneeshB
d6759a852b add vulkan vma alloc flag 2023-07-28 21:06:28 +05:30
Daniel Garvey
ab57af43c1 Couple of fixes for vicuna.py (#1696)
* mega vicuna merge pt 2

* add fallback to ensure compile is called
2023-07-27 15:53:05 -07:00
jinchen62
4d5c55dd9f Fix vicuna script (#1697) 2023-07-27 17:24:26 -05:00
Vivek Khandelwal
07399ad65c [Langchain] Remove unused code (#1698) 2023-07-27 11:59:54 -05:00
Vivek Khandelwal
776a9c2293 Fix for Langchain (#1694)
For CPU, remove max time stopping criteria
Fix web UI issue
2023-07-26 09:00:23 -07:00
Eliasj42
9d399eb988 fixed bug where device_idx was hardcoded (#1693)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-07-25 19:00:13 -05:00
Vivek Khandelwal
927b662aa7 Add Langchain SHARK Compilation support for all paths 2023-07-25 22:15:42 +05:30
Abhishek Varma
47f8a79c75 [MiniGPT4] Add MiniGPT4 to SHARK (#1554)
* [MiniGPT4] Add MiniGPT4 to SHARK

-- This is the first installment of MiniGPT4 in SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* Add int8 support for MiniGPT4

-- This commit adds int8 support for MiniGPT4.

Signed-off-by: Abhishek Varma <abhishek@nod-lab.com>

* Update .spec for MiniGPT4's config files

* black format MiniGPT4

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Signed-off-by: Abhishek Varma <abhishek@nod-lab.com>
2023-07-25 09:42:27 -07:00
Ean Garvey
f03bdeed74 Merge branch 'main' into ean-pytest-bench 2023-07-25 11:18:47 -05:00
Stefan Kapusniak
289f983f41 SD - Implement seed arrays for batch runs (#1690)
* SD Scripts and UI tabs that support batch_count can now take a
string containing a JSON array, or a list of integers, as their seed
input.
* Each batch in a run will now take the seed specified at the
corresponding array index if one exists. If there is no seed at
that index, the seed value will be treated as -1 and a random
seed will be assigned at that position. If an integer rather than
a list or json array has been, everything works as before.
* UI seed input controls are now Textboxes with info lines about
the seed formats allowed.
* UI error handling updated to be more helpful if the seed input is
invalid.
2023-07-24 19:22:34 -07:00
Daniel Garvey
453e46562f mega vicuna merge pt 2 (#1685) 2023-07-24 12:42:20 -05:00
Gaurav Shukla
5497af1f56 [config] Add support for uploading sharding config file in chatbot (#1689)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-07-24 10:18:03 -07:00
Vivek Khandelwal
f3cb63fc9c Fix Langchain multiple device isssue (#1688) 2023-07-24 08:03:46 -07:00
Vivek Khandelwal
d7092aafaa Fix multiple issue for Langchain
This commit fixes the following issue for the Langchain:
1.) Web UI not able to fetch results.
2.) For each query model getting reloaded.
3.) SHARK module not using user provided device and precision.
4.) Create a class for main Langchain code.
5.) Misc issues
2023-07-21 21:56:27 +05:30
Vivek Khandelwal
a415f3f70e Fix Langchain Prompt issue and add web UI support (#1682) 2023-07-21 06:36:55 -07:00
Ean Garvey
6870cba402 Merge branch 'main' into ean-pytest-bench 2023-07-20 12:10:48 -05:00
Vivek Khandelwal
c292e5c9d7 Add Langchain CPU support and update requirements 2023-07-20 18:53:34 +05:30
Vivek Khandelwal
03c4d9e171 Add support for Llama-2-70b for web and cli, and for hf_auth_token 2023-07-20 14:57:48 +05:30
jinchen62
3662224c04 Update brevitas requirement (#1677)
also clean up useless args

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-19 22:03:32 -07:00
Vivek Khandelwal
db3f222933 Revert "Add Llama2 70B option in CLI and WebUI (#1673)" (#1679)
This reverts commit 41e5088908.
2023-07-19 22:02:48 -07:00
Stefan Kapusniak
68b3021325 Fixes cosmetic problems with Gradio 3.37.0 (#1676)
* Fix nod-ai logo having a white border
* Fix control labels having a black background
* Remove extra lower border below Save Prompt checkboxes in Txt2Img UI
2023-07-19 17:28:53 -07:00
AyaanShah2204
336469154d added copy-metadata for pyyaml (#1678) 2023-07-19 17:27:25 -07:00
Abhishek Varma
41e5088908 Add Llama2 70B option in CLI and WebUI (#1673) 2023-07-19 10:41:42 -07:00
PhaneeshB
0a8f7673f4 Add README for CodeGen server 2023-07-19 23:10:23 +05:30
PhaneeshB
c482ab78da fix second vic clearing for low mem device 2023-07-19 23:10:23 +05:30
Vivek Khandelwal
4be80f7158 Add support for the Llama-2 model 2023-07-19 20:57:08 +05:30
AyaanShah2204
536aba1424 unpinned torch_mlir (#1668)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-19 06:28:00 -07:00
Ean Garvey
dd738a0e02 small changes to opt_perf_comparison.py (#1670)
* Use longer prompts for OPT comparison script

* small tweaks
2023-07-19 06:26:50 -07:00
Daniel Garvey
8927cb0a2c set optional vmfb download (#1667) 2023-07-18 10:57:28 -07:00
Ean Garvey
d1ad69cc73 Fix pytest benchmarks and revive sharktank torch model generation. 2023-07-18 17:16:06 +00:00
Daniel Garvey
8c317e4809 fix cli for vicuna (#1666) 2023-07-18 10:03:40 -07:00
Vivek Khandelwal
b0136593df Add support for different compilation paths for DocuChat (#1665) 2023-07-18 09:49:44 -07:00
Vivek Khandelwal
11f62d7fac Minor fixes for MiniLM Training 2023-07-18 17:16:44 +05:30
powderluv
14559dd620 Update DocuChat as experimental (#1660) 2023-07-17 22:12:05 -07:00
AyaanShah2204
e503a3e8d6 fixed joblib import error (#1659) 2023-07-17 12:56:10 -07:00
AyaanShah2204
22a4254adf fixed pyinstaller path for langchain imports (#1658) 2023-07-17 12:19:21 -07:00
Vivek Khandelwal
ab01f0f048 Add Langchain model in SHARK (#1657)
* Add H2OGPT

* Add UI tab for h2ogpt

* Add source files from h2ogpt

* Add the rest of the files

* Add h2ogpt support

* Add SHARK Compilation support for langchain model for cli mode

---------

Co-authored-by: George Petterson <gpetters@protonmail.com>
2023-07-17 09:58:15 -07:00
Phaneesh Barwaria
c471d17cca codegen API (#1655) 2023-07-16 20:00:39 -07:00
Stefan Kapusniak
a2a436eb0c SD - Add repeatable (batch) seeds option (#1654)
* Generates the seeds for all batch_count batches being run up front
rather than generating the seed for a batch just before it is run.
* Adds a --repeatable_seeds argument defaulting to False
* When repeatable_seeds=True, the first seed for a set of batches will
also be used as the rng seed for the subsequent batch seeds in the run.
The rng seed is then reset.
* When repeatable_seeds=False, batch seeding works as currently.
* Update scripts under apps/scripts that support the batch_count
argument to also support the repeatable_seeds argument.
* UI/Web: Adds a checkbox element on each SD tab after batch count/size
for toggling repeatable seeds, and update _inf functions to take
this into account.
* UI/Web: Moves the Stop buttons out of the Advanced sections and next
to Generate to make things not fit quite so badly with the extra UI
elements.
* UI/Web: Fixes logging to the upscaler output text box not working
correctly when running multiple batches.
2023-07-15 16:22:41 -07:00
powderluv
1adb51b29d Update docker README.md 2023-07-15 14:31:56 -07:00
anush elangovan
aab2233e25 Add a dev Ubuntu 22.04 docker image 2023-07-15 16:25:37 +00:00
jinchen62
e20cd71314 Change to a separate pass to unpack quantized weights (#1652) 2023-07-15 04:54:53 -07:00
powderluv
5ec91143f5 add a HF accelerate requirement (#1651) 2023-07-14 05:56:12 -07:00
Ean Garvey
7cf19230e2 add perf comparison script for opt. (#1650) 2023-07-13 13:29:48 -05:00
powderluv
1bcf6b2c5b pin diffusers to 0.18.1 (#1648) 2023-07-13 01:02:24 -07:00
jinchen62
91027f8719 Remove done TODOs, a sup PR for #1644 (#1647) 2023-07-12 23:30:45 -07:00
powderluv
a909fc2e78 add tiktoken to spec file (#1646) 2023-07-12 16:12:02 -07:00
jinchen62
247f69cf9d Apply canonicalize for unpacking int4 (#1644)
- tested it unpacks int4 as expected
- tested it doesn't make difference on int8
2023-07-11 19:41:09 -07:00
PhaneeshB
3b8f7cc231 Add codegen support in UI + lint 2023-07-11 21:58:01 +05:30
PhaneeshB
6e8dbf72bd mlir/vmfb path fixes for vic pipeline 2023-07-11 21:58:01 +05:30
PhaneeshB
38e5b62d80 adapt UI to send model details to pipeline 2023-07-11 21:58:01 +05:30
PhaneeshB
1c7eecc981 add codegen support in vic pipeline 2023-07-11 21:58:01 +05:30
PhaneeshB
be417f0bf4 fix precision for fp16 2023-07-11 21:58:01 +05:30
AyaanShah2204
a517e217b0 Added support for building ZIP distributions (#1639)
* added support for zip files

* making linter happy

* Added temporary fix for NoneType padding

* Removed zip script

* Added shared imports file

* making linter happy
2023-07-09 06:45:36 -07:00
Ranvir Singh Virk
9fcae4f808 Metal testing (#1595)
* Fixing metal_platform and device selection

* fixing for metal platform

* fixed for black lint formating
2023-07-08 15:22:53 -07:00
Stefan Kapusniak
788d469c5b UI/Web Refix remaining gradio deprecation warning (#1638) 2023-07-08 13:48:36 -07:00
Stefan Kapusniak
8a59f7cc27 UI/Web add 'open folder' button to output gallery (#1634)
* Adds a button that opens the currently selected subdirectory using
the default OS file manager
* Improve output gallery handling of having images deleted out from
under it.
* Don't show VAE or LoRA lines in parameter info panel when their
value is 'None'
* Use a css class for small icon buttons on the output gallery
tab instead using the same id for multiple buttons
2023-07-08 12:44:59 -07:00
Stefan Kapusniak
1c2ec3c7a2 Some Fixes for Gradio 3.36.1 (#1637)
* Clear .style deprecation warnings.
* Re-remove download button from Nod logos.
* Add work around for `container=false` not doing what it did before on
dropdowns to the output gallery CSS
2023-07-08 11:20:34 -07:00
powderluv
af0f715e20 Unpin gradio 2023-07-08 09:41:14 -07:00
jinchen62
47ec7275e6 Fix brevitas quantize argument (#1633) 2023-07-07 11:30:31 -07:00
powderluv
3a24cff901 change binary names 2023-07-06 23:59:14 -07:00
powderluv
1f72907886 Fix the pyinstaller for chatbots (#1631) 2023-07-06 23:30:01 -07:00
Daniel Garvey
06c8aabd01 remove local-sync from webui (#1629) 2023-07-06 13:58:59 -07:00
Phaneesh Barwaria
55a12cc0c4 cpu name in device (#1628)
* show cpu name in devices

* change device order for chatbot
2023-07-06 12:00:09 -07:00
Ean Garvey
7dcbbde523 Xfail models for data tiling flag changes (#1624) 2023-07-06 06:57:17 -07:00
Abhishek Varma
1b62dc4529 [Vicuna] Revert the formatting for Brevitas op (#1626)
-- This commit reverts the formatting for Brevitas op.
-- It also excludes vicuna.py script from `black` formatter.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-07-06 06:56:17 -07:00
Daniel Garvey
c5a47887f4 Revert revert negative prompt change (#1625)
* revert default flag changes

* revert revert negative prompt change

* revert revert negative prompt change
2023-07-05 22:09:06 -07:00
Daniel Garvey
c72d0eaf87 revert default flag changes (#1622) 2023-07-05 15:43:26 -05:00
powderluv
c41f58042a Update compile_utils.py (#1617)
* Update compile_utils.py

* Update compile_utils.py

* Update compile_utils.py
2023-07-05 10:06:48 -07:00
xzuyn
043e5a5c7a fix a mistake I made, and more formatting changes, and add ++/Karras (#1619)
* fixed missing line break in `stablelm_ui.py` `start_message`
- also more formatting changes

* fix variable spelling mistake

* revert some formatting cause black wants it different

* one less line, still less than 79

* add ++, karras, and karras++ types of dpmsolver.

* black line length 79

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-05 09:00:16 -07:00
Abhishek Varma
a1b1ce935c int8 e2e for WebUI (#1620) 2023-07-05 07:08:36 -07:00
jinchen62
bc6fee1a0c Add int4/int8 vicuna (#1598) 2023-07-05 07:01:51 -07:00
xzuyn
91ab594744 minor fix, some changes, some additions, and cleaning up (#1618)
* - fix overflowing text (a janky fix)
- add DEISMultistep scheduler as an option
- set default scheduler to DEISMultistep
- set default CFG to 3.5
- set default steps to 16
- add `xzuyn/PhotoMerge` as a model option
- add 3 new example prompts (which work nicely with PhotoMerge)
- formatting

* Set DEISMultistep in the cpu_only list instead

* formatting

* formatting

* modify prompts

* resize window to 81% & 85% monitor resolution instead of (WxH / 1.0625).

* increase steps to 32 after some testing. somewhere in between 16 and 32 is best compromise on speed/quality for DEIS, so 32 steps to play it safe.

* black line length 79

* revert settings DEIS as default scheduler.

* add more schedulers & revert accidental DDIM change
- add DPMSolverSingleStep, KDPM2AncestralDiscrete, & HeunDiscrete.
- did not add `DPMSolverMultistepInverse` or `DDIMInverse` as they only output latent noise, there are a few I did not try adding yet.
- accidentally set `upscaler_ui.py` to EulerDiscrete by default last commit while reverting DEIS changes.
- add `xzuyn/PhotoMerge-inpainting` as an in or out painting model.

* black line length 79

* add help section stuff and some other changes
- list the rest of the schedulers in argument help section.
- replace mutable default arguments.
- increased default window height to 91% to remove any scrolling for the main txt2img page (tested on a 1920x1080 monitor). width is the same as its just enough to have the image output on the side instead of the bottom.
- cleanup
2023-07-04 18:51:23 -07:00
Eliasj42
4015793f84 changed method of compiling vicuna to remove first and second vicuna (#1611)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-03 12:12:43 -07:00
Ean Garvey
d63ce76dd8 Use sortable image filenames for SD outputs. (#1528) 2023-07-03 10:30:47 -07:00
Prashant Kumar
1c32915570 Add the shark compile downstream due to https://github.com/pytorch/pytorch/pull/104185#issuecomment-1615110613 (#1615) 2023-07-01 08:30:58 -07:00
Ean Garvey
6d286c0609 Enable tuning for rectangle sizes on rdna2. (#1608) 2023-06-30 22:28:24 -07:00
Stefan Kapusniak
7392b22731 UI/Web Reduce animation of default --progress_bars (#1613) 2023-06-30 21:12:10 -07:00
jinchen62
534de05791 Update precision check for vicuna (#1610) 2023-06-29 16:16:33 -05:00
Daniel Garvey
5779e8c039 int4/int8 vicuna download support (#1609)
* set task_topology_max_group to cpu_count

by default. Can be overriden with a flag of the same str

* add download for int4/int8 mlir
2023-06-29 13:35:51 -07:00
Abhishek Varma
d496053590 [SHARK] Add a compile API to use for quick testing of inference (#1606) 2023-06-28 08:40:28 -07:00
gpetters94
6274a813c9 Add unet512 support for the other StableDiffusion pipelines (#1602) 2023-06-27 12:28:57 -07:00
Gaurav Shukla
1d6a1f9f8a [vicuna] Add tokens streaming(step=3) (#1600)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-27 08:59:27 -07:00
Daniel Garvey
75672c0e28 set task_topology_max_group to cpu_count (#1594)
by default. Can be overriden with a flag of the same str
2023-06-26 14:54:06 -07:00
Prashant Kumar
74a7202173 Make the tensors contiguous. 2023-06-26 17:29:54 +05:30
Prashant Kumar
27a08735db Add the shark backend for torch.compile API. (#1596) 2023-06-26 03:53:32 -07:00
Stefan Kapusniak
eaa49cce17 UI/App - Allow text selection (#1593)
* When run in app mode on windows, allows selection of text from
non-input controls, which is the same behaviour as web mode.
2023-06-26 02:16:53 -07:00
powderluv
10657d6fb1 Disable upx 2023-06-25 07:28:52 -07:00
Stefan Kapusniak
e3ab844cd1 Fix output gallery for csv format inc. VAE & LoRA (#1591) 2023-06-24 06:20:53 -07:00
powderluv
5ce6001b41 Update stablelm_ui.py to default to fp16 2023-06-23 22:55:47 -07:00
powderluv
501d0ca52e Add sentencepiece to webui for pyinstaller 2023-06-23 22:52:06 -07:00
powderluv
b444528715 Pin torch-mlir for windows too 2023-06-23 19:19:28 -07:00
Ean Garvey
6e6c90f62b Pin torch-mlir and use local-task in OPT. (#1592) 2023-06-23 19:17:05 -07:00
AyaanShah2204
8cdb38496e Final REST API Fixes (#1590)
* fixed outpaint api and added tests

* fixed text2img api

* more elegant generator to subscriptable conversion

* final fixes
2023-06-23 16:46:47 -07:00
powderluv
726d73d6ba Revert "[vicuna] Add streaming of tokens (#1587)" (#1588)
This reverts commit 4d55e51d46.
2023-06-23 10:29:00 -07:00
Gaurav Shukla
4d55e51d46 [vicuna] Add streaming of tokens (#1587)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-23 08:20:46 -07:00
Prashant Kumar
6ef78ee7ba Add cpu compile time flags. (#1585) 2023-06-23 07:23:26 -07:00
jinchen62
4002da7161 Add int4/int8 options to chatbot webui (#1586) 2023-06-23 07:18:34 -07:00
powderluv
ecb5e8e5d8 Update txt2img_ui.py 2023-06-23 06:42:12 -07:00
PhaneeshB
28e0919321 Add AMD cpu device 2023-06-23 18:47:04 +05:30
Daniel Garvey
28f4d44a6b downloader was double downloading (#1580) 2023-06-22 18:30:27 -07:00
AyaanShah2204
97f7e79391 [Blender Integration] Fixed Inpainting REST API (#1577)
* fixed inpaint api

* added inpainting test

* fixed linter errors

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-22 16:08:26 -07:00
Nelson Sharpe
44a8f2f8db Include VAE & LoRA data into PNG metadata (#1573)
* include custom lora and vae data in png metadata

* include pycharm settings

* lint with black
2023-06-22 16:05:54 -07:00
Eliasj42
8822b9acd7 added ability to use config file to shard vicuna (#1565)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-06-22 17:40:35 -05:00
Daniel Garvey
0ca3b9fce3 fix some mmap and vicuna bugs (#1576) 2023-06-22 17:39:55 -05:00
Nithin Meganathan
045f2bb147 Add dispatch-level config file generator for manual annotation (#1566) 2023-06-22 15:11:41 -07:00
Prashant Kumar
a811b867b9 Add shark_eager mode.
-- Eager mode with step by step op compilation and execution.
2023-06-22 22:59:14 +05:30
Abhishek Varma
cdd505e2dd [SharkInference-SharkRuntime] Adds capability to mmap vmfbs
-- This commit is based on [VmModule.mmap() API](https://github.com/openxla/iree/pull/14124).
-- It thereby adds capability to mmap vmfbs in SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-22 20:43:40 +05:30
powderluv
1b0f39107c Move torch_mlir import to the top (#1574) 2023-06-21 22:31:35 -07:00
powderluv
b9b8955f74 exclude vulkan on macos 2023-06-21 22:22:27 -07:00
powderluv
6f7a85eee3 switch to metal backend for CI 2023-06-21 22:17:11 -07:00
Ranvir Singh Virk
18c8e9e51e Metal typo fix (#1572)
* fixing typos for metal changes

* black formating
2023-06-21 21:56:11 -07:00
Daniel Garvey
a202bb466a fp16 fixes for webui (#1571) 2023-06-21 20:24:02 -07:00
Ranvir Singh Virk
07c1e1d712 Adding metal_utils for iree_utils (#1561)
* Adding metal_utils for iree_utils

* Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559)

-- It also modifies the mega_test.py script

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [SD] Update unet in_channels API and add PIL metadata to spec. (#1560)

* Fix deprecation warning for unet config.

* Include PIL metadata instead of hidden imports in SD spec.

* Fixing iree-metal-target-platform

* adding metal to txt2img pipeline

* Fixing Copyright date

* removing debug prints

* black lint formating

* fixing device dump

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <avarma094@gmail.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-21 19:09:03 -07:00
Ranvir Singh Virk
18daec78c8 Added check for python version (#1570)
* Added check for python version

* Update for PYTHON_VERSION_X_Y
2023-06-21 18:56:47 -07:00
Ean Garvey
1a8e2024d6 Exclude non-square sizes from use_tuned on rdna2 (#1568) 2023-06-21 11:36:55 -05:00
AyaanShah2204
d61b6641fb Rest API: Resolved Generator Object not Subscripatable error (#1556) 2023-06-20 19:27:41 -07:00
Phaneesh Barwaria
88cc2423cc Enable Vicuna fp16 cpu (#1562)
* fix second vic mlir gen

* fp16 mlir/vmfb download from shark_tank
2023-06-20 13:43:21 -05:00
Ean Garvey
ccf944c1bd Enable tuner for upscaler unet. (#1563) 2023-06-20 13:40:13 -05:00
Ean Garvey
0def74f520 [SD] Update unet in_channels API and add PIL metadata to spec. (#1560)
* Fix deprecation warning for unet config.

* Include PIL metadata instead of hidden imports in SD spec.
2023-06-20 10:26:36 -07:00
Abhishek Varma
3fb72e192e Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559)
-- It also modifies the mega_test.py script

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-20 10:04:17 -07:00
Vivek Khandelwal
855435ee24 Fix for the user input for Falcon pipeline 2023-06-20 18:09:32 +05:30
Elias Joseph
6f9f868fc0 fixed a bug where designating device for vicuna didn't work 2023-06-20 17:09:32 +05:30
powderluv
fb865f1b99 Move to checkout@v3
This will break Windows again but we have to fix it up since the old node.js is now deprecated.
2023-06-19 18:44:36 -07:00
rprasad2
3e5c50f07b changes for tuning (#1542)
* Add tuning sizes for rdna3
2023-06-19 15:29:08 -05:00
powderluv
a544f30a8f Move mega to the shark examples (#1555) 2023-06-19 11:10:51 -07:00
Abhishek Varma
1fe56d460a [MEGABYTE] Add script to compile MEGABYTE through SHARK (#1553)
-- Usage: `python mega_test.py`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-19 11:00:35 -07:00
Vivek Khandelwal
fafd713141 Minor change to falcon pipeline 2023-06-19 22:36:32 +05:30
Vivek Khandelwal
015d0132c3 Modify falcon pipeline to add fp16 support (#1551) 2023-06-19 09:57:13 -07:00
powderluv
20ddd96ef7 unpin diffusers (#1550) 2023-06-18 13:45:55 -07:00
powderluv
ee33cfd2d1 Add PIL in main index.py (#1549)
* Add PIL in main index.py

This is to ensure pyinstaller picks it up

* Update index.py
2023-06-18 11:51:44 -07:00
Stefan Kapusniak
a3cba21d5b Fix load of unet512 vmfb fail on get of iree opts (#1546)
* Change retrieval of Iree options used when loading an existing
unet512 vmfb to look up the "unet" options rather than attempt to
find a non-existent set of options for "unet512"

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-18 06:42:20 -07:00
Stefan Kapusniak
a7b6ec4095 Fix unet512 always being used when --max_length=77 (#1547)
* Switches a few places in the SD pipeline where an assumption of
max_length=64 was being made, to using the actual max_length
as passed into the pipeline. This prevents unet512 always being
used and producing different images than previously when
--max_length=77
2023-06-18 06:41:25 -07:00
Ean Garvey
d80b087d95 Add PIL hidden imports to sd spec. (#1544)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-18 06:39:08 -07:00
Stefan Kapusniak
297a209608 Remove workarounds for gradio tempfile bugs (#1548) 2023-06-17 19:50:36 -07:00
gpetters94
b204113563 Add UNet512 (#1504)
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-06-17 03:46:25 -04:00
Chi_Liu
f60ab1f4fa Add Deberta to stablehlo in shark tank (#1545) 2023-06-16 13:24:44 -07:00
Surya Jasper
b203779462 Added Adreno target triples to vulkan_utils (#1543) 2023-06-15 16:42:59 -07:00
136 changed files with 24190 additions and 2439 deletions

View File

@@ -2,4 +2,4 @@
count = 1
show-source = 1
select = E9,F63,F7,F82
exclude = lit.cfg.py
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py

View File

@@ -54,8 +54,8 @@ jobs:
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/shark_sd.exe ./dist/nodai_shark_sd_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_sd_${{ env.package_version_ }}.exe
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
- name: Upload Release Assets
id: upload-release-assets

View File

@@ -35,6 +35,8 @@ jobs:
include:
- os: ubuntu-latest
suite: lint
- os: MacStudio
suite: metal
exclude:
- os: ubuntu-latest
suite: vulkan
@@ -46,6 +48,8 @@ jobs:
suite: cuda
- os: MacStudio
suite: cpu
- os: MacStudio
suite: vulkan
- os: icelake
suite: vulkan
- os: icelake
@@ -61,7 +65,6 @@ jobs:
steps:
- uses: actions/checkout@v3
if: matrix.os != '7950x'
- name: Set Environment Variables
if: matrix.os != '7950x'
@@ -84,9 +87,6 @@ jobs:
#cache-dependency-path: |
# **/requirements-importer.txt
# **/requirements.txt
- uses: actions/checkout@v2
if: matrix.os == '7950x'
- name: Install dependencies
if: matrix.suite == 'lint'
@@ -115,6 +115,7 @@ jobs:
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
python build_tools/vicuna_testing.py
- name: Validate Models on NVIDIA GPU
if: matrix.suite == 'cuda'
@@ -129,15 +130,14 @@ jobs:
# python build_tools/stable_diffusion_testing.py --device=cuda
- name: Validate Vulkan Models (MacOS)
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
if: matrix.suite == 'metal' && matrix.os == 'MacStudio'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'

8
.gitignore vendored
View File

@@ -2,6 +2,8 @@
__pycache__/
*.py[cod]
*$py.class
*.mlir
*.vmfb
# C extensions
*.so
@@ -157,7 +159,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
# vscode related
.vscode
@@ -187,3 +189,7 @@ apps/stable_diffusion/web/models/
# Stencil annotators.
stencil_annotator/
# For DocuChat
apps/language_models/langchain/user_path/
db_dir_UserData

View File

@@ -0,0 +1,16 @@
## CodeGen Setup using SHARK-server
### Setup Server
- clone SHARK and setup the venv
- host the server using `python apps/stable_diffusion/web/index.py --api --server_port=<PORT>`
- default server address is `http://0.0.0.0:8080`
### Setup Client
1. fauxpilot-vscode (VSCode Extension):
- Code for the extension can be found [here](https://github.com/Venthe/vscode-fauxpilot)
- PreReq: VSCode extension (will need [`nodejs` and `npm`](https://nodejs.org/en/download) to compile and run the extension)
- Compile and Run the extension on VSCode (press F5 on VSCode), this opens a new VSCode window with the extension running
- Open VSCode settings, search for fauxpilot in settings and modify `server : http://<IP>:<PORT>`, `Model : codegen` , `Max Lines : 30`
2. Others (REST API curl, OpenAI Python bindings) as shown [here](https://github.com/fauxpilot/fauxpilot/blob/main/documentation/client.md)
- using Github Copilot VSCode extension with SHARK-server needs more work to be functional.

View File

@@ -0,0 +1,18 @@
# Langchain
## How to run the model
1.) Install all the dependencies by running:
```shell
pip install -r apps/language_models/langchain/langchain_requirements.txt
sudo apt-get install -y libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
```
2.) Create a folder named `user_path` in `apps/language_models/langchain/` directory.
Now, you are ready to use the model.
3.) To run the model, run the following command:
```shell
python apps/language_models/langchain/gen.py --cli=True
```

View File

@@ -0,0 +1,186 @@
import copy
import torch
from evaluate_params import eval_func_param_names
from gen import Langchain
from prompter import non_hf_types
from utils import clear_torch_cache, NullContext, get_kwargs
def run_cli( # for local function:
base_model=None,
lora_weights=None,
inference_server=None,
debug=None,
chat_context=None,
examples=None,
memory_restriction_level=None,
# for get_model:
score_model=None,
load_8bit=None,
load_4bit=None,
load_half=None,
load_gptq=None,
use_safetensors=None,
infer_devices=None,
tokenizer_base_model=None,
gpu_id=None,
local_files_only=None,
resume_download=None,
use_auth_token=None,
trust_remote_code=None,
offload_folder=None,
compile_model=None,
# for some evaluate args
stream_output=None,
prompt_type=None,
prompt_dict=None,
temperature=None,
top_p=None,
top_k=None,
num_beams=None,
max_new_tokens=None,
min_new_tokens=None,
early_stopping=None,
max_time=None,
repetition_penalty=None,
num_return_sequences=None,
do_sample=None,
chat=None,
langchain_mode=None,
langchain_action=None,
document_choice=None,
top_k_docs=None,
chunk=None,
chunk_size=None,
# for evaluate kwargs
src_lang=None,
tgt_lang=None,
concurrency_count=None,
save_dir=None,
sanitize_bot_response=None,
model_state0=None,
max_max_new_tokens=None,
is_public=None,
max_max_time=None,
raise_generate_gpu_exceptions=None,
load_db_if_exists=None,
dbs=None,
user_path=None,
detect_user_path_changes_every_query=None,
use_openai_embedding=None,
use_openai_model=None,
hf_embedding_model=None,
db_type=None,
n_jobs=None,
first_para=None,
text_limit=None,
verbose=None,
cli=None,
reverse_docs=None,
use_cache=None,
auto_reduce_chunks=None,
max_chunks=None,
model_lock=None,
force_langchain_evaluate=None,
model_state_none=None,
# unique to this function:
cli_loop=None,
):
Langchain.check_locals(**locals())
score_model = "" # FIXME: For now, so user doesn't have to pass
n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
device = "cpu" if n_gpus == 0 else "cuda"
context_class = NullContext if n_gpus > 1 or n_gpus == 0 else torch.device
with context_class(device):
from functools import partial
# get score model
smodel, stokenizer, sdevice = Langchain.get_score_model(
reward_type=True,
**get_kwargs(
Langchain.get_score_model,
exclude_names=["reward_type"],
**locals()
)
)
model, tokenizer, device = Langchain.get_model(
reward_type=False,
**get_kwargs(
Langchain.get_model, exclude_names=["reward_type"], **locals()
)
)
model_dict = dict(
base_model=base_model,
tokenizer_base_model=tokenizer_base_model,
lora_weights=lora_weights,
inference_server=inference_server,
prompt_type=prompt_type,
prompt_dict=prompt_dict,
)
model_state = dict(model=model, tokenizer=tokenizer, device=device)
model_state.update(model_dict)
my_db_state = [None]
fun = partial(
Langchain.evaluate,
model_state,
my_db_state,
**get_kwargs(
Langchain.evaluate,
exclude_names=["model_state", "my_db_state"]
+ eval_func_param_names,
**locals()
)
)
example1 = examples[-1] # pick reference example
all_generations = []
while True:
clear_torch_cache()
instruction = input("\nEnter an instruction: ")
if instruction == "exit":
break
eval_vars = copy.deepcopy(example1)
eval_vars[eval_func_param_names.index("instruction")] = eval_vars[
eval_func_param_names.index("instruction_nochat")
] = instruction
eval_vars[eval_func_param_names.index("iinput")] = eval_vars[
eval_func_param_names.index("iinput_nochat")
] = "" # no input yet
eval_vars[
eval_func_param_names.index("context")
] = "" # no context yet
# grab other parameters, like langchain_mode
for k in eval_func_param_names:
if k in locals():
eval_vars[eval_func_param_names.index(k)] = locals()[k]
gener = fun(*tuple(eval_vars))
outr = ""
res_old = ""
for gen_output in gener:
res = gen_output["response"]
extra = gen_output["sources"]
if base_model not in non_hf_types or base_model in ["llama"]:
if not stream_output:
print(res)
else:
# then stream output for gradio that has full output each generation, so need here to show only new chars
diff = res[len(res_old) :]
print(diff, end="", flush=True)
res_old = res
outr = res # don't accumulate
else:
outr += res # just is one thing
if extra:
# show sources at end after model itself had streamed to std rest of response
print(extra, flush=True)
all_generations.append(outr + "\n")
if not cli_loop:
break
return all_generations

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,103 @@
from enum import Enum
class PromptType(Enum):
custom = -1
plain = 0
instruct = 1
quality = 2
human_bot = 3
dai_faq = 4
summarize = 5
simple_instruct = 6
instruct_vicuna = 7
instruct_with_end = 8
human_bot_orig = 9
prompt_answer = 10
open_assistant = 11
wizard_lm = 12
wizard_mega = 13
instruct_vicuna2 = 14
instruct_vicuna3 = 15
wizard2 = 16
wizard3 = 17
instruct_simple = 18
wizard_vicuna = 19
openai = 20
openai_chat = 21
gptj = 22
prompt_answer_openllama = 23
vicuna11 = 24
mptinstruct = 25
mptchat = 26
falcon = 27
class DocumentChoices(Enum):
All_Relevant = 0
All_Relevant_Only_Sources = 1
Only_All_Sources = 2
Just_LLM = 3
non_query_commands = [
DocumentChoices.All_Relevant_Only_Sources.name,
DocumentChoices.Only_All_Sources.name,
]
class LangChainMode(Enum):
"""LangChain mode"""
DISABLED = "Disabled"
CHAT_LLM = "ChatLLM"
LLM = "LLM"
ALL = "All"
WIKI = "wiki"
WIKI_FULL = "wiki_full"
USER_DATA = "UserData"
MY_DATA = "MyData"
GITHUB_H2OGPT = "github h2oGPT"
H2O_DAI_DOCS = "DriverlessAI docs"
class LangChainAction(Enum):
"""LangChain action"""
QUERY = "Query"
# WIP:
# SUMMARIZE_MAP = "Summarize_map_reduce"
SUMMARIZE_MAP = "Summarize"
SUMMARIZE_ALL = "Summarize_all"
SUMMARIZE_REFINE = "Summarize_refine"
no_server_str = no_lora_str = no_model_str = "[None/Remove]"
# from site-packages/langchain/llms/openai.py
# but needed since ChatOpenAI doesn't have this information
model_token_mapping = {
"gpt-4": 8192,
"gpt-4-0314": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0314": 32768,
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-16k": 16 * 1024,
"gpt-3.5-turbo-0301": 4096,
"text-ada-001": 2049,
"ada": 2049,
"text-babbage-001": 2040,
"babbage": 2049,
"text-curie-001": 2049,
"curie": 2049,
"davinci": 2049,
"text-davinci-003": 4097,
"text-davinci-002": 4097,
"code-davinci-002": 8001,
"code-davinci-001": 8001,
"code-cushman-002": 2048,
"code-cushman-001": 2048,
}
source_prefix = "Sources [Score | Link]:"
source_postfix = "End Sources<p>"

View File

@@ -0,0 +1,53 @@
no_default_param_names = [
"instruction",
"iinput",
"context",
"instruction_nochat",
"iinput_nochat",
]
gen_hyper = [
"temperature",
"top_p",
"top_k",
"num_beams",
"max_new_tokens",
"min_new_tokens",
"early_stopping",
"max_time",
"repetition_penalty",
"num_return_sequences",
"do_sample",
]
eval_func_param_names = (
[
"instruction",
"iinput",
"context",
"stream_output",
"prompt_type",
"prompt_dict",
]
+ gen_hyper
+ [
"chat",
"instruction_nochat",
"iinput_nochat",
"langchain_mode",
"langchain_action",
"top_k_docs",
"chunk",
"chunk_size",
"document_choice",
]
)
# form evaluate defaults for submit_nochat_api
eval_func_param_names_defaults = eval_func_param_names.copy()
for k in no_default_param_names:
if k in eval_func_param_names_defaults:
eval_func_param_names_defaults.remove(k)
eval_extra_columns = ["prompt", "response", "score"]

View File

@@ -0,0 +1,432 @@
"""Load question answering chains."""
from __future__ import annotations
from typing import (
Any,
Mapping,
Optional,
Dict,
List,
Sequence,
Tuple,
Union,
Protocol,
)
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.question_answering import stuff_prompt
from langchain.prompts.base import BasePromptTemplate
from langchain.docstore.document import Document
from abc import ABC, abstractmethod
from langchain.chains.base import Chain
from langchain.callbacks.manager import (
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.input import get_colored_text
from langchain.load.dump import dumpd
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import LLMResult, PromptValue
from pydantic import Extra, Field, root_validator
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
"""Format a document into a string based on a prompt template."""
base_info = {"page_content": doc.page_content}
base_info.update(doc.metadata)
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:
required_metadata = [
iv for iv in prompt.input_variables if iv != "page_content"
]
raise ValueError(
f"Document prompt requires documents to have metadata variables: "
f"{required_metadata}. Received document with missing metadata: "
f"{list(missing_metadata)}."
)
document_info = {k: base_info[k] for k in prompt.input_variables}
return prompt.format(**document_info)
class BaseCombineDocumentsChain(Chain, ABC):
"""Base interface for chains combining documents."""
input_key: str = "input_documents" #: :meta private:
output_key: str = "output_text" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return output key.
:meta private:
"""
return [self.output_key]
def prompt_length(
self, docs: List[Document], **kwargs: Any
) -> Optional[int]:
"""Return the prompt length given the documents passed in.
Returns None if the method does not depend on the prompt length.
"""
return None
@abstractmethod
def combine_docs(
self, docs: List[Document], **kwargs: Any
) -> Tuple[str, dict]:
"""Combine documents into a single string."""
def _call(
self,
inputs: Dict[str, List[Document]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = (
run_manager or CallbackManagerForChainRun.get_noop_manager()
)
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
output, extra_return_dict = self.combine_docs(
docs, callbacks=_run_manager.get_child(), **other_keys
)
extra_return_dict[self.output_key] = output
return extra_return_dict
class LLMChain(Chain):
"""Chain to run queries against LLMs.
Example:
.. code-block:: python
from langchain import LLMChain, OpenAI, PromptTemplate
prompt_template = "Tell me a {adjective} joke"
prompt = PromptTemplate(
input_variables=["adjective"], template=prompt_template
)
llm = LLMChain(llm=OpenAI(), prompt=prompt)
"""
@property
def lc_serializable(self) -> bool:
return True
prompt: BasePromptTemplate
"""Prompt object to use."""
llm: BaseLanguageModel
output_key: str = "text" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return self.prompt.input_variables
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
response = self.generate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
def generate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
return self.llm.generate_prompt(
prompts,
stop,
callbacks=run_manager.get_child() if run_manager else None,
)
def prep_prompts(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
if "stop" in input_list[0]:
stop = input_list[0]["stop"]
prompts = []
for inputs in input_list:
selected_inputs = {
k: inputs[k] for k in self.prompt.input_variables
}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
return prompts, stop
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
{"input_list": input_list},
)
try:
response = self.generate(input_list, run_manager=run_manager)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise e
outputs = self.create_outputs(response)
run_manager.on_chain_end({"outputs": outputs})
return outputs
def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
"""Create outputs from response."""
return [
# Get the text of the top generated string.
{self.output_key: generation[0].text}
for generation in response.generations
]
def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
callbacks: Callbacks to pass to LLMChain
**kwargs: Keys to pass to prompt template.
Returns:
Completion from LLM.
Example:
.. code-block:: python
completion = llm.predict(adjective="funny")
"""
return self(kwargs, callbacks=callbacks)[self.output_key]
def predict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, Any]]:
"""Call predict and then parse the results."""
result = self.predict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
def apply_and_parse(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
result = self.apply(input_list, callbacks=callbacks)
return self._parse_result(result)
def _parse_result(
self, result: List[Dict[str, str]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
if self.prompt.output_parser is not None:
return [
self.prompt.output_parser.parse(res[self.output_key])
for res in result
]
else:
return result
@property
def _chain_type(self) -> str:
return "llm_chain"
@classmethod
def from_string(cls, llm: BaseLanguageModel, template: str) -> LLMChain:
"""Create LLMChain from LLM and template."""
prompt_template = PromptTemplate.from_template(template)
return cls(llm=llm, prompt=prompt_template)
def _get_default_document_prompt() -> PromptTemplate:
return PromptTemplate(
input_variables=["page_content"], template="{page_content}"
)
class StuffDocumentsChain(BaseCombineDocumentsChain):
"""Chain that combines documents by stuffing into context."""
llm_chain: LLMChain
"""LLM wrapper to use after formatting documents."""
document_prompt: BasePromptTemplate = Field(
default_factory=_get_default_document_prompt
)
"""Prompt to use to format each document."""
document_variable_name: str
"""The variable name in the llm_chain to put the documents in.
If only one variable in the llm_chain, this need not be provided."""
document_separator: str = "\n\n"
"""The string with which to join the formatted documents"""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def get_default_document_variable_name(cls, values: Dict) -> Dict:
"""Get default document variable name, if not provided."""
llm_chain_variables = values["llm_chain"].prompt.input_variables
if "document_variable_name" not in values:
if len(llm_chain_variables) == 1:
values["document_variable_name"] = llm_chain_variables[0]
else:
raise ValueError(
"document_variable_name must be provided if there are "
"multiple llm_chain_variables"
)
else:
if values["document_variable_name"] not in llm_chain_variables:
raise ValueError(
f"document_variable_name {values['document_variable_name']} was "
f"not found in llm_chain input_variables: {llm_chain_variables}"
)
return values
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
# Format each document according to the prompt
doc_strings = [
format_document(doc, self.document_prompt) for doc in docs
]
# Join the documents together to put them in the prompt.
inputs = {
k: v
for k, v in kwargs.items()
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(
doc_strings
)
return inputs
def prompt_length(
self, docs: List[Document], **kwargs: Any
) -> Optional[int]:
"""Get the prompt length by formatting the prompt."""
inputs = self._get_inputs(docs, **kwargs)
prompt = self.llm_chain.prompt.format(**inputs)
return self.llm_chain.llm.get_num_tokens(prompt)
def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM."""
inputs = self._get_inputs(docs, **kwargs)
# Call predict on the LLM.
return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
@property
def _chain_type(self) -> str:
return "stuff_documents_chain"
class LoadingCallable(Protocol):
"""Interface for loading the combine documents chain."""
def __call__(
self, llm: BaseLanguageModel, **kwargs: Any
) -> BaseCombineDocumentsChain:
"""Callable to load the combine documents chain."""
def _load_stuff_chain(
llm: BaseLanguageModel,
prompt: Optional[BasePromptTemplate] = None,
document_variable_name: str = "context",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> StuffDocumentsChain:
_prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
llm_chain = LLMChain(
llm=llm,
prompt=_prompt,
verbose=verbose,
callback_manager=callback_manager,
callbacks=callbacks,
)
# TODO: document prompt
return StuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name=document_variable_name,
verbose=verbose,
callback_manager=callback_manager,
**kwargs,
)
def load_qa_chain(
llm: BaseLanguageModel,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> BaseCombineDocumentsChain:
"""Load question answering chain.
Args:
llm: Language Model to use in the chain.
chain_type: Type of document combining chain to use. Should be one of "stuff",
"map_reduce", "map_rerank", and "refine".
verbose: Whether chains should be run in verbose mode or not. Note that this
applies to all chains that make up the final chain.
callback_manager: Callback manager to use for the chain.
Returns:
A chain to use for question answering.
"""
loader_mapping: Mapping[str, LoadingCallable] = {
"stuff": _load_stuff_chain,
}
if chain_type not in loader_mapping:
raise ValueError(
f"Got unsupported chain type: {chain_type}. "
f"Should be one of {loader_mapping.keys()}"
)
return loader_mapping[chain_type](
llm, verbose=verbose, callback_manager=callback_manager, **kwargs
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,380 @@
import inspect
import os
from functools import partial
from typing import Dict, Any, Optional, List
from langchain.callbacks.manager import CallbackManagerForLLMRun
from pydantic import root_validator
from langchain.llms import gpt4all
from dotenv import dotenv_values
from utils import FakeTokenizer
def get_model_tokenizer_gpt4all(base_model, **kwargs):
# defaults (some of these are generation parameters, so need to be passed in at generation time)
model_kwargs = dict(
n_threads=os.cpu_count() // 2,
temp=kwargs.get("temperature", 0.2),
top_p=kwargs.get("top_p", 0.75),
top_k=kwargs.get("top_k", 40),
n_ctx=2048 - 256,
)
env_gpt4all_file = ".env_gpt4all"
model_kwargs.update(dotenv_values(env_gpt4all_file))
# make int or float if can to satisfy types for class
for k, v in model_kwargs.items():
try:
if float(v) == int(v):
model_kwargs[k] = int(v)
else:
model_kwargs[k] = float(v)
except:
pass
if base_model == "llama":
if "model_path_llama" not in model_kwargs:
raise ValueError("No model_path_llama in %s" % env_gpt4all_file)
model_path = model_kwargs.pop("model_path_llama")
# FIXME: GPT4All version of llama doesn't handle new quantization, so use llama_cpp_python
from llama_cpp import Llama
# llama sets some things at init model time, not generation time
func_names = list(inspect.signature(Llama.__init__).parameters)
model_kwargs = {
k: v for k, v in model_kwargs.items() if k in func_names
}
model_kwargs["n_ctx"] = int(model_kwargs["n_ctx"])
model = Llama(model_path=model_path, **model_kwargs)
elif base_model in "gpt4all_llama":
if (
"model_name_gpt4all_llama" not in model_kwargs
and "model_path_gpt4all_llama" not in model_kwargs
):
raise ValueError(
"No model_name_gpt4all_llama or model_path_gpt4all_llama in %s"
% env_gpt4all_file
)
model_name = model_kwargs.pop("model_name_gpt4all_llama")
model_type = "llama"
from gpt4all import GPT4All as GPT4AllModel
model = GPT4AllModel(model_name=model_name, model_type=model_type)
elif base_model in "gptj":
if (
"model_name_gptj" not in model_kwargs
and "model_path_gptj" not in model_kwargs
):
raise ValueError(
"No model_name_gpt4j or model_path_gpt4j in %s"
% env_gpt4all_file
)
model_name = model_kwargs.pop("model_name_gptj")
model_type = "gptj"
from gpt4all import GPT4All as GPT4AllModel
model = GPT4AllModel(model_name=model_name, model_type=model_type)
else:
raise ValueError("No such base_model %s" % base_model)
return model, FakeTokenizer(), "cpu"
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
# streaming to std already occurs without this
# sys.stdout.write(token)
# sys.stdout.flush()
pass
def get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=[]):
# default from class
model_kwargs = {
k: v.default
for k, v in dict(inspect.signature(cls).parameters).items()
if k not in exclude_list
}
# from our defaults
model_kwargs.update(default_kwargs)
# from user defaults
model_kwargs.update(env_kwargs)
# ensure only valid keys
func_names = list(inspect.signature(cls).parameters)
model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
return model_kwargs
def get_llm_gpt4all(
model_name,
model=None,
max_new_tokens=256,
temperature=0.1,
repetition_penalty=1.0,
top_k=40,
top_p=0.7,
streaming=False,
callbacks=None,
prompter=None,
verbose=False,
):
assert prompter is not None
env_gpt4all_file = ".env_gpt4all"
env_kwargs = dotenv_values(env_gpt4all_file)
n_ctx = env_kwargs.pop("n_ctx", 2048 - max_new_tokens)
default_kwargs = dict(
context_erase=0.5,
n_batch=1,
n_ctx=n_ctx,
n_predict=max_new_tokens,
repeat_last_n=64 if repetition_penalty != 1.0 else 0,
repeat_penalty=repetition_penalty,
temp=temperature,
temperature=temperature,
top_k=top_k,
top_p=top_p,
use_mlock=True,
verbose=verbose,
)
if model_name == "llama":
cls = H2OLlamaCpp
model_path = (
env_kwargs.pop("model_path_llama") if model is None else model
)
model_kwargs = get_model_kwargs(
env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"]
)
model_kwargs.update(
dict(
model_path=model_path,
callbacks=callbacks,
streaming=streaming,
prompter=prompter,
)
)
llm = cls(**model_kwargs)
llm.client.verbose = verbose
elif model_name == "gpt4all_llama":
cls = H2OGPT4All
model_path = (
env_kwargs.pop("model_path_gpt4all_llama")
if model is None
else model
)
model_kwargs = get_model_kwargs(
env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"]
)
model_kwargs.update(
dict(
model=model_path,
backend="llama",
callbacks=callbacks,
streaming=streaming,
prompter=prompter,
)
)
llm = cls(**model_kwargs)
elif model_name == "gptj":
cls = H2OGPT4All
model_path = (
env_kwargs.pop("model_path_gptj") if model is None else model
)
model_kwargs = get_model_kwargs(
env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"]
)
model_kwargs.update(
dict(
model=model_path,
backend="gptj",
callbacks=callbacks,
streaming=streaming,
prompter=prompter,
)
)
llm = cls(**model_kwargs)
else:
raise RuntimeError("No such model_name %s" % model_name)
return llm
class H2OGPT4All(gpt4all.GPT4All):
model: Any
prompter: Any
"""Path to the pre-trained GPT4All model file."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in the environment."""
try:
if isinstance(values["model"], str):
from gpt4all import GPT4All as GPT4AllModel
full_path = values["model"]
model_path, delimiter, model_name = full_path.rpartition("/")
model_path += delimiter
values["client"] = GPT4AllModel(
model_name=model_name,
model_path=model_path or None,
model_type=values["backend"],
allow_download=False,
)
if values["n_threads"] is not None:
# set n_threads
values["client"].model.set_thread_count(
values["n_threads"]
)
else:
values["client"] = values["model"]
try:
values["backend"] = values["client"].model_type
except AttributeError:
# The below is for compatibility with GPT4All Python bindings <= 0.2.3.
values["backend"] = values["client"].model.model_type
except ImportError:
raise ValueError(
"Could not import gpt4all python package. "
"Please install it with `pip install gpt4all`."
)
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs,
) -> str:
# Roughly 4 chars per token if natural language
prompt = prompt[-self.n_ctx * 4 :]
# use instruct prompting
data_point = dict(context="", instruction=prompt, input="")
prompt = self.prompter.generate_prompt(data_point)
verbose = False
if verbose:
print("_call prompt: %s" % prompt, flush=True)
# FIXME: GPT4ALl doesn't support yield during generate, so cannot support streaming except via itself to stdout
return super()._call(prompt, stop=stop, run_manager=run_manager)
from langchain.llms import LlamaCpp
class H2OLlamaCpp(LlamaCpp):
model_path: Any
prompter: Any
"""Path to the pre-trained GPT4All model file."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that llama-cpp-python library is installed."""
if isinstance(values["model_path"], str):
model_path = values["model_path"]
model_param_names = [
"lora_path",
"lora_base",
"n_ctx",
"n_parts",
"seed",
"f16_kv",
"logits_all",
"vocab_only",
"use_mlock",
"n_threads",
"n_batch",
"use_mmap",
"last_n_tokens_size",
]
model_params = {k: values[k] for k in model_param_names}
# For backwards compatibility, only include if non-null.
if values["n_gpu_layers"] is not None:
model_params["n_gpu_layers"] = values["n_gpu_layers"]
try:
from llama_cpp import Llama
values["client"] = Llama(model_path, **model_params)
except ImportError:
raise ModuleNotFoundError(
"Could not import llama-cpp-python library. "
"Please install the llama-cpp-python library to "
"use this embedding model: pip install llama-cpp-python"
)
except Exception as e:
raise ValueError(
f"Could not load Llama model from path: {model_path}. "
f"Received error {e}"
)
else:
values["client"] = values["model_path"]
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs,
) -> str:
verbose = False
# tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
# still have to avoid crazy sizes, else hit llama_tokenize: too many tokens -- might still hit, not fatal
prompt = prompt[-self.n_ctx * 4 :]
prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
num_prompt_tokens = len(prompt_tokens)
if num_prompt_tokens > self.n_ctx:
# conservative by using int()
chars_per_token = int(len(prompt) / num_prompt_tokens)
prompt = prompt[-self.n_ctx * chars_per_token :]
if verbose:
print(
"reducing tokens, assuming average of %s chars/token: %s"
% chars_per_token,
flush=True,
)
prompt_tokens2 = self.client.tokenize(
b" " + prompt.encode("utf-8")
)
num_prompt_tokens2 = len(prompt_tokens2)
print(
"reduced tokens from %d -> %d"
% (num_prompt_tokens, num_prompt_tokens2),
flush=True,
)
# use instruct prompting
data_point = dict(context="", instruction=prompt, input="")
prompt = self.prompter.generate_prompt(data_point)
if verbose:
print("_call prompt: %s" % prompt, flush=True)
if self.streaming:
text_callback = None
if run_manager:
text_callback = partial(
run_manager.on_llm_new_token, verbose=self.verbose
)
# parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
if text_callback:
text_callback(prompt)
text = ""
for token in self.stream(
prompt=prompt, stop=stop, run_manager=run_manager
):
text_chunk = token["choices"][0]["text"]
# self.stream already calls text_callback
# if text_callback:
# text_callback(text_chunk)
text += text_chunk
return text
else:
params = self._get_parameters(stop)
params = {**params, **kwargs}
result = self.client(prompt=prompt, **params)
return result["choices"][0]["text"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,93 @@
import traceback
from typing import Callable
import os
from gradio_client.client import Job
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
from gradio_client import Client
class GradioClient(Client):
"""
Parent class of gradio client
To handle automatically refreshing client if detect gradio server changed
"""
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
super().__init__(*args, **kwargs)
self.server_hash = self.get_server_hash()
def get_server_hash(self):
"""
Get server hash using super without any refresh action triggered
Returns: git hash of gradio server
"""
return super().submit(api_name="/system_hash").result()
def refresh_client_if_should(self):
# get current hash in order to update api_name -> fn_index map in case gradio server changed
# FIXME: Could add cli api as hash
server_hash = self.get_server_hash()
if self.server_hash != server_hash:
self.refresh_client()
self.server_hash = server_hash
else:
self.reset_session()
def refresh_client(self):
"""
Ensure every client call is independent
Also ensure map between api_name and fn_index is updated in case server changed (e.g. restarted with new code)
Returns:
"""
# need session hash to be new every time, to avoid "generator already executing"
self.reset_session()
client = Client(*self.args, **self.kwargs)
for k, v in client.__dict__.items():
setattr(self, k, v)
def submit(
self,
*args,
api_name: str | None = None,
fn_index: int | None = None,
result_callbacks: Callable | list[Callable] | None = None,
) -> Job:
# Note predict calls submit
try:
self.refresh_client_if_should()
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
except Exception as e:
print("Hit e=%s" % str(e), flush=True)
# force reconfig in case only that
self.refresh_client()
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
# see if immediately failed
e = job.future._exception
if e is not None:
print(
"GR job failed: %s %s"
% (str(e), "".join(traceback.format_tb(e.__traceback__))),
flush=True,
)
# force reconfig in case only that
self.refresh_client()
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
e2 = job.future._exception
if e2 is not None:
print(
"GR job failed again: %s\n%s"
% (
str(e2),
"".join(traceback.format_tb(e2.__traceback__)),
),
flush=True,
)
return job

View File

@@ -0,0 +1,802 @@
import os
from apps.stable_diffusion.src.utils.utils import _compile_module
from io import BytesIO
import torch_mlir
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
from stopping import get_stopping
from prompter import Prompter, PromptType
from transformers.generation import (
GenerationConfig,
LogitsProcessorList,
StoppingCriteriaList,
)
import copy
import torch
from transformers import AutoConfig, AutoModelForCausalLM
import gc
from pathlib import Path
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from apps.stable_diffusion.src import args
# Brevitas
from typing import List, Tuple
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
def brevitasmatmul_rhs_group_quant〡shape(
lhs: List[int],
rhs: List[int],
rhs_scale: List[int],
rhs_zero_point: List[int],
rhs_bit_width: int,
rhs_group_size: int,
) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
return [lhs[0], rhs[0]]
else:
raise ValueError("Input shapes not supported.")
def brevitasmatmul_rhs_group_quant〡dtype(
lhs_rank_dtype: Tuple[int, int],
rhs_rank_dtype: Tuple[int, int],
rhs_scale_rank_dtype: Tuple[int, int],
rhs_zero_point_rank_dtype: Tuple[int, int],
rhs_bit_width: int,
rhs_group_size: int,
) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype
def brevitasmatmul_rhs_group_quant〡has_value_semantics(
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
) -> None:
return
brevitas_matmul_rhs_group_quant_library = [
brevitasmatmul_rhs_group_quant〡shape,
brevitasmatmul_rhs_group_quant〡dtype,
brevitasmatmul_rhs_group_quant〡has_value_semantics,
]
global_device = "cuda"
global_precision = "fp16"
if not args.run_docuchat_web:
args.device = global_device
args.precision = global_precision
tensor_device = "cpu" if args.device == "cpu" else "cuda"
class H2OGPTModel(torch.nn.Module):
def __init__(self, device, precision):
super().__init__()
torch_dtype = (
torch.float32
if precision == "fp32" or device == "cpu"
else torch.float16
)
device_map = {"": "cpu"} if device == "cpu" else {"": 0}
model_kwargs = {
"local_files_only": False,
"torch_dtype": torch_dtype,
"resume_download": True,
"use_auth_token": False,
"trust_remote_code": True,
"offload_folder": "offline_folder",
"device_map": device_map,
}
config = AutoConfig.from_pretrained(
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
use_auth_token=False,
trust_remote_code=True,
offload_folder="offline_folder",
)
self.model = AutoModelForCausalLM.from_pretrained(
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
config=config,
**model_kwargs,
)
if precision in ["int4", "int8"]:
print("Applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
self.model.transformer.h,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=128,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(self, input_ids, attention_mask):
input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": None,
"use_cache": True,
}
output = self.model(
**input_dict,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
return output.logits[:, -1, :]
class H2OGPTSHARKModel(torch.nn.Module):
def __init__(self):
super().__init__()
model_name = "h2ogpt_falcon_7b"
extended_model_name = (
model_name + "_" + args.precision + "_" + args.device
)
vmfb_path = Path(extended_model_name + ".vmfb")
mlir_path = Path(model_name + "_" + args.precision + ".mlir")
shark_module = None
need_to_compile = False
if not vmfb_path.exists():
need_to_compile = True
# Downloading VMFB from shark_tank
print("Trying to download pre-compiled vmfb from shark tank.")
download_public_file(
"gs://shark_tank/langchain/" + str(vmfb_path),
vmfb_path.absolute(),
single_file=True,
)
if vmfb_path.exists():
print(
"Pre-compiled vmfb downloaded from shark tank successfully."
)
need_to_compile = False
if need_to_compile:
if not mlir_path.exists():
print("Trying to download pre-generated mlir from shark tank.")
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/langchain/" + str(mlir_path),
mlir_path.absolute(),
single_file=True,
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
# Generating the mlir
bytecode = self.get_bytecode(tensor_device, args.precision)
shark_module = SharkInference(
mlir_module=bytecode,
device=args.device,
mlir_dialect="linalg",
)
print(f"[DEBUG] generating vmfb.")
shark_module = _compile_module(
shark_module, extended_model_name, []
)
print("Saved newly generated vmfb.")
if shark_module is None:
if vmfb_path.exists():
print("Compiled vmfb found. Loading it from: ", vmfb_path)
shark_module = SharkInference(
None, device=args.device, mlir_dialect="linalg"
)
shark_module.load_module(str(vmfb_path))
print("Compiled vmfb loaded successfully.")
else:
raise ValueError("Unable to download/generate a vmfb.")
self.model = shark_module
def get_bytecode(self, device, precision):
h2ogpt_model = H2OGPTModel(device, precision)
compilation_input_ids = torch.randint(
low=1, high=10000, size=(1, 400)
).to(device=device)
compilation_attention_mask = torch.ones(1, 400, dtype=torch.int64).to(
device=device
)
h2ogptCompileInput = (
compilation_input_ids,
compilation_attention_mask,
)
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
h2ogpt_model,
h2ogptCompileInput,
is_f16=False,
precision=precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
del h2ogpt_model
del self.src_model
print(f"[DEBUG] generating torch mlir")
if precision in ["int4", "int8"]:
from torch_mlir.compiler_utils import (
run_pipeline_with_repro_report,
)
module = torch_mlir.compile(
ts_graph,
[*h2ogptCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
module = torch_mlir.compile(
ts_graph,
[*h2ogptCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
print(f"[DEBUG] converting to bytecode")
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
del module
return bytecode
def forward(self, input_ids, attention_mask):
result = torch.from_numpy(
self.model(
"forward",
(input_ids.to(device="cpu"), attention_mask.to(device="cpu")),
)
).to(device=tensor_device)
return result
h2ogpt_model = H2OGPTSHARKModel()
def pad_or_truncate_inputs(
input_ids, attention_mask, max_padding_length=400, do_truncation=False
):
inp_shape = input_ids.shape
if inp_shape[1] < max_padding_length:
# do padding
num_add_token = max_padding_length - inp_shape[1]
padded_input_ids = torch.cat(
[
torch.tensor([[11] * num_add_token]).to(device=tensor_device),
input_ids,
],
dim=1,
)
padded_attention_mask = torch.cat(
[
torch.tensor([[0] * num_add_token]).to(device=tensor_device),
attention_mask,
],
dim=1,
)
return padded_input_ids, padded_attention_mask
elif inp_shape[1] > max_padding_length or do_truncation:
# do truncation
num_remove_token = inp_shape[1] - max_padding_length
truncated_input_ids = input_ids[:, num_remove_token:]
truncated_attention_mask = attention_mask[:, num_remove_token:]
return truncated_input_ids, truncated_attention_mask
else:
return input_ids, attention_mask
class H2OTextGenerationPipeline(TextGenerationPipeline):
def __init__(
self,
*args,
debug=False,
chat=False,
stream_output=False,
sanitize_bot_response=False,
use_prompter=True,
prompter=None,
prompt_type=None,
prompt_dict=None,
max_input_tokens=2048 - 256,
**kwargs,
):
"""
HF-like pipeline, but handle instruction prompting and stopping (for some models)
:param args:
:param debug:
:param chat:
:param stream_output:
:param sanitize_bot_response:
:param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter
:param prompter: prompter, can pass if have already
:param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py.
If use_prompter, then will make prompter and use it.
:param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom
:param max_input_tokens:
:param kwargs:
"""
super().__init__(*args, **kwargs)
self.prompt_text = None
self.use_prompter = use_prompter
self.prompt_type = prompt_type
self.prompt_dict = prompt_dict
self.prompter = prompter
if self.use_prompter:
if self.prompter is not None:
assert self.prompter.prompt_type is not None
else:
self.prompter = Prompter(
self.prompt_type,
self.prompt_dict,
debug=debug,
chat=chat,
stream_output=stream_output,
)
self.human = self.prompter.humanstr
self.bot = self.prompter.botstr
self.can_stop = True
else:
self.prompter = None
self.human = None
self.bot = None
self.can_stop = False
self.sanitize_bot_response = sanitize_bot_response
self.max_input_tokens = (
max_input_tokens # not for generate, so ok that not kwargs
)
@staticmethod
def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
verbose = bool(int(os.getenv("VERBOSE_PIPELINE", "0")))
if hasattr(tokenizer, "model_max_length"):
# model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
model_max_length = tokenizer.model_max_length
if max_prompt_length is not None:
model_max_length = min(model_max_length, max_prompt_length)
# cut at some upper likely limit to avoid excessive tokenization etc
# upper bound of 10 chars/token, e.g. special chars sometimes are long
if len(prompt_text) > model_max_length * 10:
len0 = len(prompt_text)
prompt_text = prompt_text[-model_max_length * 10 :]
if verbose:
print(
"Cut of input: %s -> %s" % (len0, len(prompt_text)),
flush=True,
)
else:
# unknown
model_max_length = None
num_prompt_tokens = None
if model_max_length is not None:
# can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
# For https://github.com/h2oai/h2ogpt/issues/192
for trial in range(0, 3):
prompt_tokens = tokenizer(prompt_text)["input_ids"]
num_prompt_tokens = len(prompt_tokens)
if num_prompt_tokens > model_max_length:
# conservative by using int()
chars_per_token = int(len(prompt_text) / num_prompt_tokens)
# keep tail, where question is if using langchain
prompt_text = prompt_text[
-model_max_length * chars_per_token :
]
if verbose:
print(
"reducing %s tokens, assuming average of %s chars/token for %s characters"
% (
num_prompt_tokens,
chars_per_token,
len(prompt_text),
),
flush=True,
)
else:
if verbose:
print(
"using %s tokens with %s chars"
% (num_prompt_tokens, len(prompt_text)),
flush=True,
)
break
return prompt_text, num_prompt_tokens
def preprocess(
self,
prompt_text,
prefix="",
handle_long_generation=None,
**generate_kwargs,
):
(
prompt_text,
num_prompt_tokens,
) = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
data_point = dict(context="", instruction=prompt_text, input="")
if self.prompter is not None:
prompt_text = self.prompter.generate_prompt(data_point)
self.prompt_text = prompt_text
if handle_long_generation is None:
# forces truncation of inputs to avoid critical failure
handle_long_generation = None # disable with new approaches
return super().preprocess(
prompt_text,
prefix=prefix,
handle_long_generation=handle_long_generation,
**generate_kwargs,
)
def postprocess(
self,
model_outputs,
return_type=ReturnType.FULL_TEXT,
clean_up_tokenization_spaces=True,
):
records = super().postprocess(
model_outputs,
return_type=return_type,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
for rec in records:
if self.use_prompter:
outputs = rec["generated_text"]
outputs = self.prompter.get_response(
outputs,
prompt=self.prompt_text,
sanitize_bot_response=self.sanitize_bot_response,
)
elif self.bot and self.human:
outputs = (
rec["generated_text"]
.split(self.bot)[1]
.split(self.human)[0]
)
else:
outputs = rec["generated_text"]
rec["generated_text"] = outputs
print(
"prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs),
flush=True,
)
return records
def generate_new_token(self):
model_inputs = self.model.prepare_inputs_for_generation(
self.input_ids, **self.model_kwargs
)
outputs = h2ogpt_model.forward(
model_inputs["input_ids"], model_inputs["attention_mask"]
)
if args.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
next_token_scores = self.logits_processor(
self.input_ids, next_token_logits
)
next_token_scores = self.logits_warper(
self.input_ids, next_token_scores
)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if self.eos_token_id is not None:
if self.pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = (
next_token * self.unfinished_sequences
+ self.pad_token_id * (1 - self.unfinished_sequences)
)
self.input_ids = torch.cat(
[self.input_ids, next_token[:, None]], dim=-1
)
self.model_kwargs["past_key_values"] = None
if "attention_mask" in self.model_kwargs:
attention_mask = self.model_kwargs["attention_mask"]
self.model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
self.truncated_input_ids.append(self.input_ids[:, 0])
self.input_ids = self.input_ids[:, 1:]
self.model_kwargs["attention_mask"] = self.model_kwargs[
"attention_mask"
][:, 1:]
return next_token
def generate_token(self, **generate_kwargs):
del generate_kwargs["max_time"]
self.truncated_input_ids = []
generation_config_ = GenerationConfig.from_model_config(
self.model.config
)
generation_config = copy.deepcopy(generation_config_)
self.model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
self.stopping_criteria = (
self.stopping_criteria
if self.stopping_criteria is not None
else StoppingCriteriaList()
)
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
self.model_kwargs,
) = self.model._prepare_model_inputs(
None, generation_config.bos_token_id, self.model_kwargs
)
batch_size = inputs_tensor.shape[0]
self.model_kwargs[
"output_attentions"
] = generation_config.output_attentions
self.model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
self.model_kwargs["use_cache"] = generation_config.use_cache
self.input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else self.model_kwargs.pop("input_ids")
)
input_ids_seq_length = self.input_ids.shape[-1]
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length
)
self.logits_processor = self.model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
self.stopping_criteria = self.model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=self.stopping_criteria,
)
self.logits_warper = self.model._get_logits_warper(generation_config)
(
self.input_ids,
self.model_kwargs,
) = self.model._expand_inputs_for_generation(
input_ids=self.input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=self.model.config.is_encoder_decoder, # False
**self.model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id_tensor = (
torch.tensor(eos_token_id).to(device=tensor_device)
if eos_token_id is not None
else None
)
self.pad_token_id = generation_config.pad_token_id
self.eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
output_attentions = generation_config.output_attentions # False
output_hidden_states = generation_config.output_hidden_states # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
self.scores = (
() if (return_dict_in_generate and output_scores) else None
)
decoder_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
cross_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
decoder_hidden_states = (
() if (return_dict_in_generate and output_hidden_states) else None
)
# keep track of which sequences are already finished
self.unfinished_sequences = torch.ones(
self.input_ids.shape[0],
dtype=torch.long,
device=self.input_ids.device,
)
timesRan = 0
import time
start = time.time()
print("\n")
while True:
next_token = self.generate_new_token()
new_word = self.tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
print(f"{new_word}", end="", flush=True)
# if eos_token was found in one sentence, set sentence to finished
if self.eos_token_id_tensor is not None:
self.unfinished_sequences = self.unfinished_sequences.mul(
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
.ne(self.eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
if (
self.unfinished_sequences.max() == 0
or self.stopping_criteria(self.input_ids, self.scores)
):
break
timesRan = timesRan + 1
end = time.time()
print(
"\n\nTime taken is {:.2f} seconds/token\n".format(
(end - start) / timesRan
)
)
self.input_ids = torch.cat(
[
torch.tensor(self.truncated_input_ids)
.to(device=tensor_device)
.unsqueeze(dim=0),
self.input_ids,
],
dim=-1,
)
torch.cuda.empty_cache()
gc.collect()
return self.input_ids
def _forward(self, model_inputs, **generate_kwargs):
if self.can_stop:
stopping_criteria = get_stopping(
self.prompt_type,
self.prompt_dict,
self.tokenizer,
self.device,
human=self.human,
bot=self.bot,
model_max_length=self.tokenizer.model_max_length,
)
generate_kwargs["stopping_criteria"] = stopping_criteria
# return super()._forward(model_inputs, **generate_kwargs)
return self.__forward(model_inputs, **generate_kwargs)
# FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
# FIXME: https://github.com/h2oai/h2ogpt/issues/172
def __forward(self, model_inputs, **generate_kwargs):
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs.get("attention_mask", None)
# Allow empty prompts
if input_ids.shape[1] == 0:
input_ids = None
attention_mask = None
in_b = 1
else:
in_b = input_ids.shape[0]
prompt_text = model_inputs.pop("prompt_text")
## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
# generate_kwargs = copy.deepcopy(generate_kwargs)
prefix_length = generate_kwargs.pop("prefix_length", 0)
if prefix_length > 0:
has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
"generation_config" in generate_kwargs
and generate_kwargs["generation_config"].max_new_tokens
is not None
)
if not has_max_new_tokens:
generate_kwargs["max_length"] = (
generate_kwargs.get("max_length")
or self.model.config.max_length
)
generate_kwargs["max_length"] += prefix_length
has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
"generation_config" in generate_kwargs
and generate_kwargs["generation_config"].min_new_tokens
is not None
)
if not has_min_new_tokens and "min_length" in generate_kwargs:
generate_kwargs["min_length"] += prefix_length
# BS x SL
# pad or truncate the input_ids and attention_mask
max_padding_length = 400
input_ids, attention_mask = pad_or_truncate_inputs(
input_ids, attention_mask, max_padding_length=max_padding_length
)
self.stopping_criteria = generate_kwargs["stopping_criteria"]
generated_sequence = self.generate_token(
input_ids=input_ids,
attention_mask=attention_mask,
**generate_kwargs,
)
out_b = generated_sequence.shape[0]
generated_sequence = generated_sequence.reshape(
in_b, out_b // in_b, *generated_sequence.shape[1:]
)
return {
"generated_sequence": generated_sequence,
"input_ids": input_ids,
"prompt_text": prompt_text,
}

View File

@@ -0,0 +1,247 @@
"""
Based upon ImageCaptionLoader in LangChain version: langchain/document_loaders/image_captions.py
But accepts preloaded model to avoid slowness in use and CUDA forking issues
Loader that loads image captions
By default, the loader utilizes the pre-trained BLIP image captioning model.
https://huggingface.co/Salesforce/blip-image-captioning-base
"""
from typing import List, Union, Any, Tuple
import requests
from langchain.docstore.document import Document
from langchain.document_loaders import ImageCaptionLoader
from utils import get_device, NullContext
import pkg_resources
try:
assert pkg_resources.get_distribution("bitsandbytes") is not None
have_bitsandbytes = True
except (pkg_resources.DistributionNotFound, AssertionError):
have_bitsandbytes = False
class H2OImageCaptionLoader(ImageCaptionLoader):
"""Loader that loads the captions of an image"""
def __init__(
self,
path_images: Union[str, List[str]] = None,
blip_processor: str = None,
blip_model: str = None,
caption_gpu=True,
load_in_8bit=True,
# True doesn't seem to work, even though https://huggingface.co/Salesforce/blip2-flan-t5-xxl#in-8-bit-precision-int8
load_half=False,
load_gptq="",
use_safetensors=False,
min_new_tokens=20,
max_tokens=50,
):
if blip_model is None or blip_model is None:
blip_processor = "Salesforce/blip-image-captioning-base"
blip_model = "Salesforce/blip-image-captioning-base"
super().__init__(path_images, blip_processor, blip_model)
self.blip_processor = blip_processor
self.blip_model = blip_model
self.processor = None
self.model = None
self.caption_gpu = caption_gpu
self.context_class = NullContext
self.device = "cpu"
self.load_in_8bit = (
load_in_8bit and have_bitsandbytes
) # only for blip2
self.load_half = load_half
self.load_gptq = load_gptq
self.use_safetensors = use_safetensors
self.gpu_id = "auto"
# default prompt
self.prompt = "image of"
self.min_new_tokens = min_new_tokens
self.max_tokens = max_tokens
def set_context(self):
if get_device() == "cuda" and self.caption_gpu:
import torch
n_gpus = (
torch.cuda.device_count() if torch.cuda.is_available else 0
)
if n_gpus > 0:
self.context_class = torch.device
self.device = "cuda"
def load_model(self):
try:
import transformers
except ImportError:
raise ValueError(
"`transformers` package not found, please install with "
"`pip install transformers`."
)
self.set_context()
if self.caption_gpu:
if self.gpu_id == "auto":
# blip2 has issues with multi-GPU. Error says need to somehow set language model in device map
# device_map = 'auto'
device_map = {"": 0}
else:
if self.device == "cuda":
device_map = {"": self.gpu_id}
else:
device_map = {"": "cpu"}
else:
device_map = {"": "cpu"}
import torch
with torch.no_grad():
with self.context_class(self.device):
context_class_cast = (
NullContext if self.device == "cpu" else torch.autocast
)
with context_class_cast(self.device):
if "blip2" in self.blip_processor.lower():
from transformers import (
Blip2Processor,
Blip2ForConditionalGeneration,
)
if self.load_half and not self.load_in_8bit:
self.processor = Blip2Processor.from_pretrained(
self.blip_processor, device_map=device_map
).half()
self.model = (
Blip2ForConditionalGeneration.from_pretrained(
self.blip_model, device_map=device_map
).half()
)
else:
self.processor = Blip2Processor.from_pretrained(
self.blip_processor,
load_in_8bit=self.load_in_8bit,
device_map=device_map,
)
self.model = (
Blip2ForConditionalGeneration.from_pretrained(
self.blip_model,
load_in_8bit=self.load_in_8bit,
device_map=device_map,
)
)
else:
from transformers import (
BlipForConditionalGeneration,
BlipProcessor,
)
self.load_half = False # not supported
if self.caption_gpu:
if device_map == "auto":
# Blip doesn't support device_map='auto'
if self.device == "cuda":
if self.gpu_id == "auto":
device_map = {"": 0}
else:
device_map = {"": self.gpu_id}
else:
device_map = {"": "cpu"}
else:
device_map = {"": "cpu"}
self.processor = BlipProcessor.from_pretrained(
self.blip_processor, device_map=device_map
)
self.model = (
BlipForConditionalGeneration.from_pretrained(
self.blip_model, device_map=device_map
)
)
return self
def set_image_paths(self, path_images: Union[str, List[str]]):
"""
Load from a list of image files
"""
if isinstance(path_images, str):
self.image_paths = [path_images]
else:
self.image_paths = path_images
def load(self, prompt=None) -> List[Document]:
if self.processor is None or self.model is None:
self.load_model()
results = []
for path_image in self.image_paths:
caption, metadata = self._get_captions_and_metadata(
model=self.model,
processor=self.processor,
path_image=path_image,
prompt=prompt,
)
doc = Document(page_content=caption, metadata=metadata)
results.append(doc)
return results
def _get_captions_and_metadata(
self, model: Any, processor: Any, path_image: str, prompt=None
) -> Tuple[str, dict]:
"""
Helper function for getting the captions and metadata of an image
"""
if prompt is None:
prompt = self.prompt
try:
from PIL import Image
except ImportError:
raise ValueError(
"`PIL` package not found, please install with `pip install pillow`"
)
try:
if path_image.startswith("http://") or path_image.startswith(
"https://"
):
image = Image.open(
requests.get(path_image, stream=True).raw
).convert("RGB")
else:
image = Image.open(path_image).convert("RGB")
except Exception:
raise ValueError(f"Could not get image data for {path_image}")
import torch
with torch.no_grad():
with self.context_class(self.device):
context_class_cast = (
NullContext if self.device == "cpu" else torch.autocast
)
with context_class_cast(self.device):
if self.load_half:
inputs = processor(
image, prompt, return_tensors="pt"
).half()
else:
inputs = processor(image, prompt, return_tensors="pt")
min_length = len(prompt) // 4 + self.min_new_tokens
self.max_tokens = max(self.max_tokens, min_length)
output = model.generate(
**inputs,
min_length=min_length,
max_length=self.max_tokens,
)
caption: str = processor.decode(
output[0], skip_special_tokens=True
)
prompti = caption.find(prompt)
if prompti >= 0:
caption = caption[prompti + len(prompt) :]
metadata: dict = {"image_path": path_image}
return caption, metadata

View File

@@ -0,0 +1,120 @@
# for generate (gradio server) and finetune
datasets==2.13.0
sentencepiece==0.1.99
huggingface_hub==0.16.4
appdirs==1.4.4
fire==0.5.0
docutils==0.20.1
evaluate==0.4.0
rouge_score==0.1.2
sacrebleu==2.3.1
scikit-learn==1.2.2
alt-profanity-check==1.2.2
better-profanity==0.7.0
numpy==1.24.3
pandas==2.0.2
matplotlib==3.7.1
loralib==0.1.1
bitsandbytes==0.39.0
accelerate==0.20.3
peft==0.4.0
# 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026)
transformers==4.30.2
tokenizers==0.13.3
APScheduler==3.10.1
# optional for generate
pynvml==11.5.0
psutil==5.9.5
boto3==1.26.101
botocore==1.29.101
# optional for finetune
tensorboard==2.13.0
neptune==1.2.0
# for gradio client
gradio_client==0.2.10
beautifulsoup4==4.12.2
markdown==3.4.3
# data and testing
pytest==7.2.2
pytest-xdist==3.2.1
nltk==3.8.1
textstat==0.7.3
# pandoc==2.3
pypandoc==1.11; sys_platform == "darwin" and platform_machine == "arm64"
pypandoc_binary==1.11; platform_machine == "x86_64"
pypandoc_binary==1.11; sys_platform == "win32"
openpyxl==3.1.2
lm_dataformat==0.0.20
bioc==2.0
# falcon
einops==0.6.1
instructorembedding==1.0.1
# for gpt4all .env file, but avoid worrying about imports
python-dotenv==1.0.0
text-generation==0.6.0
# for tokenization when don't have HF tokenizer
tiktoken==0.4.0
# optional: for OpenAI endpoint or embeddings (requires key)
openai==0.27.8
# optional for chat with PDF
langchain==0.0.202
pypdf==3.12.2
# avoid textract, requires old six
#textract==1.6.5
# for HF embeddings
sentence_transformers==2.2.2
# local vector db
chromadb==0.3.25
# server vector db
#pymilvus==2.2.8
# weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
# unstructured==0.8.1
# strong support for images
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
unstructured[local-inference]==0.7.4
#pdf2image==1.16.3
#pytesseract==0.3.10
pillow
pdfminer.six==20221105
urllib3
requests_file
#pdf2image==1.16.3
#pytesseract==0.3.10
tabulate==0.9.0
# FYI pandoc already part of requirements.txt
# JSONLoader, but makes some trouble for some users
# jq==1.4.1
# to check licenses
# Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
pip-licenses==4.3.0
# weaviate vector db
weaviate-client==3.22.1
gpt4all==1.0.5
llama-cpp-python==0.1.73
arxiv==1.4.8
pymupdf==1.22.5 # AGPL license
# extract-msg==0.41.1 # GPL3
# sometimes unstructured fails, these work in those cases. See https://github.com/h2oai/h2ogpt/issues/320
playwright==1.36.0
# requires Chrome binary to be in path
selenium==4.10.0

View File

@@ -0,0 +1,124 @@
from typing import List, Optional, Tuple
import torch
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[
torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]
]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
assert past_key_value is None, "past_key_value is not supported"
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
assert not output_attentions, "output_attentions is not supported"
assert not use_cache, "use_cache is not supported"
# Flash attention codes from
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
# transform the data into the format required by flash attention
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask = attention_mask
if key_padding_mask is None:
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = q_len
cu_q_lens = torch.arange(
0,
(bsz + 1) * q_len,
step=q_len,
dtype=torch.int32,
device=qkv.device,
)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
)
output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
indices,
bsz,
q_len,
),
"b s (h d) -> b s h d",
h=nheads,
)
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# [bsz, seq_len]
return attention_mask
def replace_llama_attn_with_flash_attn():
print(
"Replacing original LLaMa attention with flash attention", flush=True
)
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward

View File

@@ -0,0 +1,109 @@
import functools
def get_loaders(model_name, reward_type, llama_type=None, load_gptq=""):
# NOTE: Some models need specific new prompt_type
# E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
if load_gptq:
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM
use_triton = False
functools.partial(
AutoGPTQForCausalLM.from_quantized,
quantize_config=None,
use_triton=use_triton,
)
return AutoGPTQForCausalLM.from_quantized, AutoTokenizer
if llama_type is None:
llama_type = "llama" in model_name.lower()
if llama_type:
from transformers import LlamaForCausalLM, LlamaTokenizer
return LlamaForCausalLM.from_pretrained, LlamaTokenizer
elif "distilgpt2" in model_name.lower():
from transformers import AutoModelForCausalLM, AutoTokenizer
return AutoModelForCausalLM.from_pretrained, AutoTokenizer
elif "gpt2" in model_name.lower():
from transformers import GPT2LMHeadModel, GPT2Tokenizer
return GPT2LMHeadModel.from_pretrained, GPT2Tokenizer
elif "mbart-" in model_name.lower():
from transformers import (
MBartForConditionalGeneration,
MBart50TokenizerFast,
)
return (
MBartForConditionalGeneration.from_pretrained,
MBart50TokenizerFast,
)
elif (
"t5" == model_name.lower()
or "t5-" in model_name.lower()
or "flan-" in model_name.lower()
):
from transformers import AutoTokenizer, T5ForConditionalGeneration
return T5ForConditionalGeneration.from_pretrained, AutoTokenizer
elif "bigbird" in model_name:
from transformers import (
BigBirdPegasusForConditionalGeneration,
AutoTokenizer,
)
return (
BigBirdPegasusForConditionalGeneration.from_pretrained,
AutoTokenizer,
)
elif (
"bart-large-cnn-samsum" in model_name
or "flan-t5-base-samsum" in model_name
):
from transformers import pipeline
return pipeline, "summarization"
elif (
reward_type
or "OpenAssistant/reward-model".lower() in model_name.lower()
):
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
)
return (
AutoModelForSequenceClassification.from_pretrained,
AutoTokenizer,
)
else:
from transformers import AutoTokenizer, AutoModelForCausalLM
model_loader = AutoModelForCausalLM
tokenizer_loader = AutoTokenizer
return model_loader.from_pretrained, tokenizer_loader
def get_tokenizer(
tokenizer_loader,
tokenizer_base_model,
local_files_only,
resume_download,
use_auth_token,
):
tokenizer = tokenizer_loader.from_pretrained(
tokenizer_base_model,
local_files_only=local_files_only,
resume_download=resume_download,
use_auth_token=use_auth_token,
padding_side="left",
)
tokenizer.pad_token_id = 0 # different from the eos token
# when generating, we will use the logits of right-most token to predict the next token
# so the padding should be on the left,
# e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
tokenizer.padding_side = "left" # Allow batched inference
return tokenizer

View File

@@ -0,0 +1,208 @@
import os
import fire
from gpt_langchain import (
path_to_docs,
get_some_dbs_from_hf,
all_db_zips,
some_db_zips,
create_or_update_db,
)
from utils import get_ngpus_vis
def glob_to_db(
user_path,
chunk=True,
chunk_size=512,
verbose=False,
fail_any_exception=False,
n_jobs=-1,
url=None,
enable_captions=True,
captions_model=None,
caption_loader=None,
enable_ocr=False,
):
sources1 = path_to_docs(
user_path,
verbose=verbose,
fail_any_exception=fail_any_exception,
n_jobs=n_jobs,
chunk=chunk,
chunk_size=chunk_size,
url=url,
enable_captions=enable_captions,
captions_model=captions_model,
caption_loader=caption_loader,
enable_ocr=enable_ocr,
)
return sources1
def make_db_main(
use_openai_embedding: bool = False,
hf_embedding_model: str = None,
persist_directory: str = "db_dir_UserData",
user_path: str = "user_path",
url: str = None,
add_if_exists: bool = True,
collection_name: str = "UserData",
verbose: bool = False,
chunk: bool = True,
chunk_size: int = 512,
fail_any_exception: bool = False,
download_all: bool = False,
download_some: bool = False,
download_one: str = None,
download_dest: str = "./",
n_jobs: int = -1,
enable_captions: bool = True,
captions_model: str = "Salesforce/blip-image-captioning-base",
pre_load_caption_model: bool = False,
caption_gpu: bool = True,
enable_ocr: bool = False,
db_type: str = "chroma",
):
"""
# To make UserData db for generate.py, put pdfs, etc. into path user_path and run:
python make_db.py
# once db is made, can use in generate.py like:
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b --langchain_mode=UserData
or zip-up the db_dir_UserData and share:
zip -r db_dir_UserData.zip db_dir_UserData
# To get all db files (except large wiki_full) do:
python make_db.py --download_some=True
# To get a single db file from HF:
python make_db.py --download_one=db_dir_DriverlessAI_docs.zip
:param use_openai_embedding: Whether to use OpenAI embedding
:param hf_embedding_model: HF embedding model to use. Like generate.py, uses 'hkunlp/instructor-large' if have GPUs, else "sentence-transformers/all-MiniLM-L6-v2"
:param persist_directory: where to persist db
:param user_path: where to pull documents from (None means url is not None. If url is not None, this is ignored.)
:param url: url to generate documents from (None means user_path is not None)
:param add_if_exists: Add to db if already exists, but will not add duplicate sources
:param collection_name: Collection name for new db if not adding
:param verbose: whether to show verbose messages
:param chunk: whether to chunk data
:param chunk_size: chunk size for chunking
:param fail_any_exception: whether to fail if any exception hit during ingestion of files
:param download_all: whether to download all (including 23GB Wikipedia) example databases from h2o.ai HF
:param download_some: whether to download some small example databases from h2o.ai HF
:param download_one: whether to download one chosen example databases from h2o.ai HF
:param download_dest: Destination for downloads
:param n_jobs: Number of cores to use for ingesting multiple files
:param enable_captions: Whether to enable captions on images
:param captions_model: See generate.py
:param pre_load_caption_model: See generate.py
:param caption_gpu: Caption images on GPU if present
:param enable_ocr: Whether to enable OCR on images
:param db_type: Type of db to create. Currently only 'chroma' and 'weaviate' is supported.
:return: None
"""
db = None
# match behavior of main() in generate.py for non-HF case
n_gpus = get_ngpus_vis()
if n_gpus == 0:
if hf_embedding_model is None:
# if no GPUs, use simpler embedding model to avoid cost in time
hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
else:
if hf_embedding_model is None:
# if still None, then set default
hf_embedding_model = "hkunlp/instructor-large"
if download_all:
print("Downloading all (and unzipping): %s" % all_db_zips, flush=True)
get_some_dbs_from_hf(download_dest, db_zips=all_db_zips)
if verbose:
print("DONE", flush=True)
return db, collection_name
elif download_some:
print(
"Downloading some (and unzipping): %s" % some_db_zips, flush=True
)
get_some_dbs_from_hf(download_dest, db_zips=some_db_zips)
if verbose:
print("DONE", flush=True)
return db, collection_name
elif download_one:
print("Downloading %s (and unzipping)" % download_one, flush=True)
get_some_dbs_from_hf(
download_dest, db_zips=[[download_one, "", "Unknown License"]]
)
if verbose:
print("DONE", flush=True)
return db, collection_name
if enable_captions and pre_load_caption_model:
# preload, else can be too slow or if on GPU have cuda context issues
# Inside ingestion, this will disable parallel loading of multiple other kinds of docs
# However, if have many images, all those images will be handled more quickly by preloaded model on GPU
from image_captions import H2OImageCaptionLoader
caption_loader = H2OImageCaptionLoader(
None,
blip_model=captions_model,
blip_processor=captions_model,
caption_gpu=caption_gpu,
).load_model()
else:
if enable_captions:
caption_loader = "gpu" if caption_gpu else "cpu"
else:
caption_loader = False
if verbose:
print("Getting sources", flush=True)
assert (
user_path is not None or url is not None
), "Can't have both user_path and url as None"
if not url:
assert os.path.isdir(user_path), (
"user_path=%s does not exist" % user_path
)
sources = glob_to_db(
user_path,
chunk=chunk,
chunk_size=chunk_size,
verbose=verbose,
fail_any_exception=fail_any_exception,
n_jobs=n_jobs,
url=url,
enable_captions=enable_captions,
captions_model=captions_model,
caption_loader=caption_loader,
enable_ocr=enable_ocr,
)
exceptions = [x for x in sources if x.metadata.get("exception")]
print("Exceptions: %s" % exceptions, flush=True)
sources = [x for x in sources if "exception" not in x.metadata]
assert len(sources) > 0, "No sources found"
db = create_or_update_db(
db_type,
persist_directory,
collection_name,
sources,
use_openai_embedding,
add_if_exists,
verbose,
hf_embedding_model,
)
assert db is not None
if verbose:
print("DONE", flush=True)
return db, collection_name
if __name__ == "__main__":
fire.Fire(make_db_main)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,403 @@
"""Load Data from a MediaWiki dump xml."""
import ast
import glob
import pickle
import uuid
from typing import List, Optional
import os
import bz2
import csv
import numpy as np
import pandas as pd
import pytest
from matplotlib import pyplot as plt
from langchain.docstore.document import Document
from langchain.document_loaders import MWDumpLoader
# path where downloaded wiki files exist, to be processed
root_path = "/data/jon/h2o-llm"
def unescape(x):
try:
x = ast.literal_eval(x)
except:
try:
x = x.encode("ascii", "ignore").decode("unicode_escape")
except:
pass
return x
def get_views():
# views = pd.read_csv('wiki_page_views_more_1000month.csv')
views = pd.read_csv("wiki_page_views_more_5000month.csv")
views.index = views["title"]
views = views["views"]
views = views.to_dict()
views = {str(unescape(str(k))): v for k, v in views.items()}
views2 = {k.replace("_", " "): v for k, v in views.items()}
# views has _ but pages has " "
views.update(views2)
return views
class MWDumpDirectLoader(MWDumpLoader):
def __init__(
self,
data: str,
encoding: Optional[str] = "utf8",
title_words_limit=None,
use_views=True,
verbose=True,
):
"""Initialize with file path."""
self.data = data
self.encoding = encoding
self.title_words_limit = title_words_limit
self.verbose = verbose
if use_views:
# self.views = get_views()
# faster to use global shared values
self.views = global_views
else:
self.views = None
def load(self) -> List[Document]:
"""Load from file path."""
import mwparserfromhell
import mwxml
dump = mwxml.Dump.from_page_xml(self.data)
docs = []
for page in dump.pages:
if self.views is not None and page.title not in self.views:
if self.verbose:
print("Skipped %s low views" % page.title, flush=True)
continue
for revision in page:
if self.title_words_limit is not None:
num_words = len(" ".join(page.title.split("_")).split(" "))
if num_words > self.title_words_limit:
if self.verbose:
print("Skipped %s" % page.title, flush=True)
continue
if self.verbose:
if self.views is not None:
print(
"Kept %s views: %s"
% (page.title, self.views[page.title]),
flush=True,
)
else:
print("Kept %s" % page.title, flush=True)
code = mwparserfromhell.parse(revision.text)
text = code.strip_code(
normalize=True, collapse=True, keep_template_params=False
)
title_url = str(page.title).replace(" ", "_")
metadata = dict(
title=page.title,
source="https://en.wikipedia.org/wiki/" + title_url,
id=page.id,
redirect=page.redirect,
views=self.views[page.title]
if self.views is not None
else -1,
)
metadata = {k: v for k, v in metadata.items() if v is not None}
docs.append(Document(page_content=text, metadata=metadata))
return docs
def search_index(search_term, index_filename):
byte_flag = False
data_length = start_byte = 0
index_file = open(index_filename, "r")
csv_reader = csv.reader(index_file, delimiter=":")
for line in csv_reader:
if not byte_flag and search_term == line[2]:
start_byte = int(line[0])
byte_flag = True
elif byte_flag and int(line[0]) != start_byte:
data_length = int(line[0]) - start_byte
break
index_file.close()
return start_byte, data_length
def get_start_bytes(index_filename):
index_file = open(index_filename, "r")
csv_reader = csv.reader(index_file, delimiter=":")
start_bytes = set()
for line in csv_reader:
start_bytes.add(int(line[0]))
index_file.close()
return sorted(start_bytes)
def get_wiki_filenames():
# requires
# wget http://ftp.acc.umu.se/mirror/wikimedia.org/dumps/enwiki/20230401/enwiki-20230401-pages-articles-multistream-index.txt.bz2
base_path = os.path.join(
root_path, "enwiki-20230401-pages-articles-multistream"
)
index_file = "enwiki-20230401-pages-articles-multistream-index.txt"
index_filename = os.path.join(base_path, index_file)
wiki_filename = os.path.join(
base_path, "enwiki-20230401-pages-articles-multistream.xml.bz2"
)
return index_filename, wiki_filename
def get_documents_by_search_term(search_term):
index_filename, wiki_filename = get_wiki_filenames()
start_byte, data_length = search_index(search_term, index_filename)
with open(wiki_filename, "rb") as wiki_file:
wiki_file.seek(start_byte)
data = bz2.BZ2Decompressor().decompress(wiki_file.read(data_length))
loader = MWDumpDirectLoader(data.decode())
documents = loader.load()
return documents
def get_one_chunk(
wiki_filename,
start_byte,
end_byte,
return_file=True,
title_words_limit=None,
use_views=True,
):
data_length = end_byte - start_byte
with open(wiki_filename, "rb") as wiki_file:
wiki_file.seek(start_byte)
data = bz2.BZ2Decompressor().decompress(wiki_file.read(data_length))
loader = MWDumpDirectLoader(
data.decode(), title_words_limit=title_words_limit, use_views=use_views
)
documents1 = loader.load()
if return_file:
base_tmp = "temp_wiki"
if not os.path.isdir(base_tmp):
os.makedirs(base_tmp, exist_ok=True)
filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle")
with open(filename, "wb") as f:
pickle.dump(documents1, f)
return filename
return documents1
from joblib import Parallel, delayed
global_views = get_views()
def get_all_documents(small_test=2, n_jobs=None, use_views=True):
print("DO get all wiki docs: %s" % small_test, flush=True)
index_filename, wiki_filename = get_wiki_filenames()
start_bytes = get_start_bytes(index_filename)
end_bytes = start_bytes[1:]
start_bytes = start_bytes[:-1]
if small_test:
start_bytes = start_bytes[:small_test]
end_bytes = end_bytes[:small_test]
if n_jobs is None:
n_jobs = 5
else:
if n_jobs is None:
n_jobs = os.cpu_count() // 4
# default loky backend leads to name space conflict problems
return_file = True # large return from joblib hangs
documents = Parallel(n_jobs=n_jobs, verbose=10, backend="multiprocessing")(
delayed(get_one_chunk)(
wiki_filename,
start_byte,
end_byte,
return_file=return_file,
use_views=use_views,
)
for start_byte, end_byte in zip(start_bytes, end_bytes)
)
if return_file:
# then documents really are files
files = documents.copy()
documents = []
for fil in files:
with open(fil, "rb") as f:
documents.extend(pickle.load(f))
os.remove(fil)
else:
from functools import reduce
from operator import concat
documents = reduce(concat, documents)
assert isinstance(documents, list)
print("DONE get all wiki docs", flush=True)
return documents
def test_by_search_term():
search_term = "Apollo"
assert len(get_documents_by_search_term(search_term)) == 100
search_term = "Abstract (law)"
assert len(get_documents_by_search_term(search_term)) == 100
search_term = "Artificial languages"
assert len(get_documents_by_search_term(search_term)) == 100
def test_start_bytes():
index_filename, wiki_filename = get_wiki_filenames()
assert len(get_start_bytes(index_filename)) == 227850
def test_get_all_documents():
small_test = 20 # 227850
n_jobs = os.cpu_count() // 4
assert (
len(
get_all_documents(
small_test=small_test, n_jobs=n_jobs, use_views=False
)
)
== small_test * 100
)
assert (
len(
get_all_documents(
small_test=small_test, n_jobs=n_jobs, use_views=True
)
)
== 429
)
def get_one_pageviews(fil):
df1 = pd.read_csv(
fil,
sep=" ",
header=None,
names=["region", "title", "views", "foo"],
quoting=csv.QUOTE_NONE,
)
df1.index = df1["title"]
df1 = df1[df1["region"] == "en"]
df1 = df1.drop("region", axis=1)
df1 = df1.drop("foo", axis=1)
df1 = df1.drop("title", axis=1) # already index
base_tmp = "temp_wiki_pageviews"
if not os.path.isdir(base_tmp):
os.makedirs(base_tmp, exist_ok=True)
filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.csv")
df1.to_csv(filename, index=True)
return filename
def test_agg_pageviews(gen_files=False):
if gen_files:
path = os.path.join(
root_path,
"wiki_pageviews/dumps.wikimedia.org/other/pageviews/2023/2023-04",
)
files = glob.glob(os.path.join(path, "pageviews*.gz"))
# files = files[:2] # test
n_jobs = os.cpu_count() // 2
csv_files = Parallel(
n_jobs=n_jobs, verbose=10, backend="multiprocessing"
)(delayed(get_one_pageviews)(fil) for fil in files)
else:
# to continue without redoing above
csv_files = glob.glob(
os.path.join(root_path, "temp_wiki_pageviews/*.csv")
)
df_list = []
for csv_file in csv_files:
print(csv_file)
df1 = pd.read_csv(csv_file)
df_list.append(df1)
df = pd.concat(df_list, axis=0)
df = df.groupby("title")["views"].sum().reset_index()
df.to_csv("wiki_page_views.csv", index=True)
def test_reduce_pageview():
filename = "wiki_page_views.csv"
df = pd.read_csv(filename)
df = df[df["views"] < 1e7]
#
plt.hist(df["views"], bins=100, log=True)
views_avg = np.mean(df["views"])
views_median = np.median(df["views"])
plt.title("Views avg: %s median: %s" % (views_avg, views_median))
plt.savefig(filename.replace(".csv", ".png"))
plt.close()
#
views_limit = 5000
df = df[df["views"] > views_limit]
filename = "wiki_page_views_more_5000month.csv"
df.to_csv(filename, index=True)
#
plt.hist(df["views"], bins=100, log=True)
views_avg = np.mean(df["views"])
views_median = np.median(df["views"])
plt.title("Views avg: %s median: %s" % (views_avg, views_median))
plt.savefig(filename.replace(".csv", ".png"))
plt.close()
@pytest.mark.skip("Only if doing full processing again, some manual steps")
def test_do_wiki_full_all():
# Install other requirements for wiki specific conversion:
# pip install -r reqs_optional/requirements_optional_wikiprocessing.txt
# Use "Transmission" in Ubuntu to get wiki dump using torrent:
# See: https://meta.wikimedia.org/wiki/Data_dump_torrents
# E.g. magnet:?xt=urn:btih:b2c74af2b1531d0b63f1166d2011116f44a8fed0&dn=enwiki-20230401-pages-articles-multistream.xml.bz2&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337
# Get index
os.system(
"wget http://ftp.acc.umu.se/mirror/wikimedia.org/dumps/enwiki/20230401/enwiki-20230401-pages-articles-multistream-index.txt.bz2"
)
# Test that can use LangChain to get docs from subset of wiki as sampled out of full wiki directly using bzip multistream
test_get_all_documents()
# Check can search wiki multistream
test_by_search_term()
# Test can get all start bytes in index
test_start_bytes()
# Get page views, e.g. for entire month of April 2023
os.system(
"wget -b -m -k -o wget.log -e robots=off https://dumps.wikimedia.org/other/pageviews/2023/2023-04/"
)
# Aggregate page views from many files into single file
test_agg_pageviews(gen_files=True)
# Reduce page views to some limit, so processing of full wiki is not too large
test_reduce_pageview()
# Start generate.py with requesting wiki_full in prep. This will use page views as referenced in get_views.
# Note get_views as global() function done once is required to avoid very slow processing
# WARNING: Requires alot of memory to handle, used up to 300GB system RAM at peak
"""
python generate.py --langchain_mode='wiki_full' --visible_langchain_modes="['wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']" &> lc_out.log
"""

View File

@@ -0,0 +1,121 @@
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
from enums import PromptType
class StoppingCriteriaSub(StoppingCriteria):
def __init__(
self, stops=[], encounters=[], device="cuda", model_max_length=None
):
super().__init__()
assert (
len(stops) % len(encounters) == 0
), "Number of stops and encounters must match"
self.encounters = encounters
self.stops = [stop.to(device) for stop in stops]
self.num_stops = [0] * len(stops)
self.model_max_length = model_max_length
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
for stopi, stop in enumerate(self.stops):
if torch.all((stop == input_ids[0][-len(stop) :])).item():
self.num_stops[stopi] += 1
if (
self.num_stops[stopi]
>= self.encounters[stopi % len(self.encounters)]
):
# print("Stopped", flush=True)
return True
if (
self.model_max_length is not None
and input_ids[0].shape[0] >= self.model_max_length
):
# critical limit
return True
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
return False
def get_stopping(
prompt_type,
prompt_dict,
tokenizer,
device,
human="<human>:",
bot="<bot>:",
model_max_length=None,
):
# FIXME: prompt_dict unused currently
if prompt_type in [
PromptType.human_bot.name,
PromptType.instruct_vicuna.name,
PromptType.instruct_with_end.name,
]:
if prompt_type == PromptType.human_bot.name:
# encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
# stopping only starts once output is beyond prompt
# 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
stop_words = [human, bot, "\n" + human, "\n" + bot]
encounters = [1, 2]
elif prompt_type == PromptType.instruct_vicuna.name:
# even below is not enough, generic strings and many ways to encode
stop_words = [
"### Human:",
"""
### Human:""",
"""
### Human:
""",
"### Assistant:",
"""
### Assistant:""",
"""
### Assistant:
""",
]
encounters = [1, 2]
else:
# some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
stop_words = ["### End"]
encounters = [1]
stop_words_ids = [
tokenizer(stop_word, return_tensors="pt")["input_ids"].squeeze()
for stop_word in stop_words
]
# handle single token case
stop_words_ids = [
x if len(x.shape) > 0 else torch.tensor([x])
for x in stop_words_ids
]
stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
# avoid padding in front of tokens
if (
tokenizer._pad_token
): # use hidden variable to avoid annoying properly logger bug
stop_words_ids = [
x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x
for x in stop_words_ids
]
# handle fake \n added
stop_words_ids = [
x[1:] if y[0] == "\n" else x
for x, y in zip(stop_words_ids, stop_words)
]
# build stopper
stopping_criteria = StoppingCriteriaList(
[
StoppingCriteriaSub(
stops=stop_words_ids,
encounters=encounters,
device=device,
model_max_length=model_max_length,
)
]
)
else:
stopping_criteria = StoppingCriteriaList()
return stopping_criteria

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,69 @@
from typing import Any, Dict, List, Union, Optional
import time
import queue
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
class StreamingGradioCallbackHandler(BaseCallbackHandler):
"""
Similar to H2OTextIteratorStreamer that is for HF backend, but here LangChain backend
"""
def __init__(self, timeout: Optional[float] = None, block=True):
super().__init__()
self.text_queue = queue.SimpleQueue()
self.stop_signal = None
self.do_stop = False
self.timeout = timeout
self.block = block
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running. Clean the queue."""
while not self.text_queue.empty():
try:
self.text_queue.get(block=False)
except queue.Empty:
continue
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
self.text_queue.put(token)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.text_queue.put(self.stop_signal)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
self.text_queue.put(self.stop_signal)
def __iter__(self):
return self
def __next__(self):
while True:
try:
value = (
self.stop_signal
) # value looks unused in pycharm, not true
if self.do_stop:
print("hit stop", flush=True)
# could raise or break, maybe best to raise and make parent see if any exception in thread
raise StopIteration()
# break
value = self.text_queue.get(
block=self.block, timeout=self.timeout
)
break
except queue.Empty:
time.sleep(0.01)
if value == self.stop_signal:
raise StopIteration()
else:
return value

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,503 @@
import torch
import dataclasses
from enum import auto, Enum
from typing import List, Any
from transformers import StoppingCriteria
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
class LayerNorm(torch.nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class VisionModel(torch.nn.Module):
def __init__(
self,
ln_vision,
visual_encoder,
precision="fp32",
weight_group_size=128,
):
super().__init__()
self.ln_vision = ln_vision
self.visual_encoder = visual_encoder
if precision in ["int4", "int8"]:
print("Vision Model applying weight quantization to ln_vision")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
self.ln_vision,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
print(
"Vision Model applying weight quantization to visual_encoder"
)
quantize_model(
self.visual_encoder,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(self, image):
image_embeds = self.ln_vision(self.visual_encoder(image))
return image_embeds
class QformerBertModel(torch.nn.Module):
def __init__(self, qformer_bert):
super().__init__()
self.qformer_bert = qformer_bert
def forward(self, query_tokens, image_embeds, image_atts):
query_output = self.qformer_bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
return query_output.last_hidden_state
class FirstLlamaModel(torch.nn.Module):
def __init__(self, model, precision="fp32", weight_group_size=128):
super().__init__()
self.model = model
print("SHARK: Loading LLAMA Done")
if precision in ["int4", "int8"]:
print("First Llama applying weight quantization")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
self.model,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(self, inputs_embeds, position_ids, attention_mask):
print("************************************")
print(
"inputs_embeds: ",
inputs_embeds.shape,
" dtype: ",
inputs_embeds.dtype,
)
print(
"position_ids: ",
position_ids.shape,
" dtype: ",
position_ids.dtype,
)
print(
"attention_mask: ",
attention_mask.shape,
" dtype: ",
attention_mask.dtype,
)
print("************************************")
config = {
"inputs_embeds": inputs_embeds,
"position_ids": position_ids,
"past_key_values": None,
"use_cache": True,
"attention_mask": attention_mask,
}
output = self.model(
**config,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
return_vals = []
return_vals.append(output.logits)
temp_past_key_values = output.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
class SecondLlamaModel(torch.nn.Module):
def __init__(self, model, precision="fp32", weight_group_size=128):
super().__init__()
self.model = model
print("SHARK: Loading LLAMA Done")
if precision in ["int4", "int8"]:
print("Second Llama applying weight quantization")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
self.model,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(
self,
input_ids,
position_ids,
attention_mask,
i1,
i2,
i3,
i4,
i5,
i6,
i7,
i8,
i9,
i10,
i11,
i12,
i13,
i14,
i15,
i16,
i17,
i18,
i19,
i20,
i21,
i22,
i23,
i24,
i25,
i26,
i27,
i28,
i29,
i30,
i31,
i32,
i33,
i34,
i35,
i36,
i37,
i38,
i39,
i40,
i41,
i42,
i43,
i44,
i45,
i46,
i47,
i48,
i49,
i50,
i51,
i52,
i53,
i54,
i55,
i56,
i57,
i58,
i59,
i60,
i61,
i62,
i63,
i64,
):
print("************************************")
print("input_ids: ", input_ids.shape, " dtype: ", input_ids.dtype)
print(
"position_ids: ",
position_ids.shape,
" dtype: ",
position_ids.dtype,
)
print(
"attention_mask: ",
attention_mask.shape,
" dtype: ",
attention_mask.dtype,
)
print("past_key_values: ", i1.shape, i2.shape, i63.shape, i64.shape)
print("past_key_values dtype: ", i1.dtype)
print("************************************")
config = {
"input_ids": input_ids,
"position_ids": position_ids,
"past_key_values": (
(i1, i2),
(
i3,
i4,
),
(
i5,
i6,
),
(
i7,
i8,
),
(
i9,
i10,
),
(
i11,
i12,
),
(
i13,
i14,
),
(
i15,
i16,
),
(
i17,
i18,
),
(
i19,
i20,
),
(
i21,
i22,
),
(
i23,
i24,
),
(
i25,
i26,
),
(
i27,
i28,
),
(
i29,
i30,
),
(
i31,
i32,
),
(
i33,
i34,
),
(
i35,
i36,
),
(
i37,
i38,
),
(
i39,
i40,
),
(
i41,
i42,
),
(
i43,
i44,
),
(
i45,
i46,
),
(
i47,
i48,
),
(
i49,
i50,
),
(
i51,
i52,
),
(
i53,
i54,
),
(
i55,
i56,
),
(
i57,
i58,
),
(
i59,
i60,
),
(
i61,
i62,
),
(
i63,
i64,
),
),
"use_cache": True,
"attention_mask": attention_mask,
}
output = self.model(
**config,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
return_vals = []
return_vals.append(output.logits)
temp_past_key_values = output.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
skip_next: bool = False
conv_id: Any = None
def get_prompt(self):
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep
for role, message in self.messages:
if message:
ret += role + ": " + message + self.sep
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def append_message(self, role, message):
self.messages.append([role, message])
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
conv_id=self.conv_id,
)
def dict(self):
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
"conv_id": self.conv_id,
}
class StoppingCriteriaSub(StoppingCriteria):
def __init__(self, stops=[], encounters=1):
super().__init__()
self.stops = stops
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
for stop in self.stops:
if torch.all((stop == input_ids[0][-len(stop) :])).item():
return True
return False
CONV_VISION = Conversation(
system="Give the following image: <Img>ImageContent</Img>. "
"You will be able to see the image once I provide it to you. Please answer my questions.",
roles=("Human", "Assistant"),
messages=[],
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)

View File

@@ -1,14 +1,41 @@
import torch
from transformers import AutoModelForCausalLM
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
class FirstVicuna(torch.nn.Module):
def __init__(self, model_path):
def __init__(
self,
model_path,
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
if precision in ["int4", "int8"]:
print("First Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(self, input_ids):
op = self.model(input_ids=input_ids, use_cache=True)
@@ -22,12 +49,36 @@ class FirstVicuna(torch.nn.Module):
class SecondVicuna(torch.nn.Module):
def __init__(self, model_path):
def __init__(
self,
model_path,
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
if precision in ["int4", "int8"]:
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(
self,

View File

@@ -62,7 +62,105 @@ class SecondVicunaLayer(torch.nn.Module):
)
class CompiledFirstVicunaLayer(torch.nn.Module):
class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers, lmhead, embedding, norm):
super().__init__()
self.model = model
assert len(layers) == len(model.model.layers)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers = layers
self.norm = norm
self.embedding = embedding
self.lmhead = lmhead
self.model.model.norm = self.norm
self.model.model.embed_tokens = self.embedding
self.model.lm_head = self.lmhead
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers
)
def forward(
self,
input_ids,
is_first=True,
past_key_values=None,
attention_mask=None,
):
return self.model.forward(
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
)
class LMHead(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states):
output = self.model(hidden_states)
return output
class LMHeadCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(self, hidden_states):
hidden_states = hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output
class VicunaNorm(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, hidden_states):
output = self.model(hidden_states)
return output
class VicunaNormCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(self, hidden_states):
hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output
class VicunaEmbedding(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids):
output = self.model(input_ids)
return output
class VicunaEmbeddingCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(self, input_ids):
input_ids.detach()
output = self.model("forward", (input_ids,))
output = torch.tensor(output)
return output
class CompiledVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
@@ -76,103 +174,55 @@ class CompiledFirstVicunaLayer(torch.nn.Module):
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
class CompiledSecondVicunaLayer(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv0 = past_key_value[0].detach()
pkv1 = past_key_value[1].detach()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
pkv0,
pkv1,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers0, layers1):
super().__init__()
self.model = model
assert len(layers0) == len(model.model.layers)
# self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers0 = layers0
self.layers1 = layers1
def forward(
self,
input_ids,
is_first=True,
past_key_values=None,
attention_mask=None,
):
if is_first:
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers0
if past_key_value is None:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
output = self.model(
"first_vicuna_forward",
(
hidden_states,
attention_mask,
position_ids,
),
)
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)
return self.model.forward(input_ids, attention_mask=attention_mask)
else:
self.model.model.layers = torch.nn.modules.container.ModuleList(
self.layers1
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv0 = past_key_value[0].detach()
pkv1 = past_key_value[1].detach()
output = self.model(
"second_vicuna_forward",
(
hidden_states,
attention_mask,
position_ids,
pkv0,
pkv1,
),
)
return self.model.forward(
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
output0 = torch.tensor(output[0])
output1 = torch.tensor(output[1])
output2 = torch.tensor(output[2])
return (
output0,
(
output1,
output2,
),
)

View File

@@ -28,8 +28,9 @@ parser = argparse.ArgumentParser(
description="runs a falcon model",
)
parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
parser.add_argument(
"--precision", "-p", default="fp32", help="fp32, fp16, int8, int4"
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
@@ -40,7 +41,12 @@ parser.add_argument(
default=None,
help="path to falcon's mlir file",
)
parser.add_argument(
"--use_precompiled_model",
default=True,
action=argparse.BooleanOptionalAction,
help="use the precompiled vmfb",
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=False,
@@ -59,12 +65,12 @@ class Falcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="tiiuae/falcon-7b-instruct",
hf_model_path,
max_num_tokens=150,
device="cuda",
precision="fp32",
falcon_mlir_path=Path("falcon.mlir"),
falcon_vmfb_path=Path("falcon.vmfb"),
falcon_mlir_path=None,
falcon_vmfb_path=None,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_padding_length = 100
@@ -85,7 +91,7 @@ class Falcon(SharkLLMBase):
return tokenizer
def get_src_model(self):
print("Loading src model")
print("Loading src model: ", self.model_name)
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
@@ -93,9 +99,26 @@ class Falcon(SharkLLMBase):
return falcon_model
def compile_falcon(self):
vmfb = get_vmfb_from_path(self.falcon_vmfb_path, self.device, "linalg")
if vmfb is not None:
return vmfb
if args.use_precompiled_model:
if not self.falcon_vmfb_path.exists():
# Downloading VMFB from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ "_"
+ self.device
+ ".vmfb",
self.falcon_vmfb_path.absolute(),
single_file=True,
)
vmfb = get_vmfb_from_path(
self.falcon_vmfb_path, self.device, "linalg"
)
if vmfb is not None:
return vmfb
print(
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
@@ -106,27 +129,26 @@ class Falcon(SharkLLMBase):
bytecode = f.read()
else:
mlir_generated = False
if args.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
download_public_file(
"gs://shark_tank/falcon/7b/cuda/falcon.mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
)
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
if not mlir_generated:
compilation_input_ids = torch.randint(
@@ -184,6 +206,7 @@ class Falcon(SharkLLMBase):
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-spirv-index-bits=64",
],
)
print("Saved falcon vmfb at ", str(path))
@@ -192,17 +215,6 @@ class Falcon(SharkLLMBase):
return shark_module
def compile(self):
if (
not self.falcon_vmfb_path.exists()
and self.device == "cuda"
and self.precision == "fp32"
):
download_public_file(
"gs://shark_tank/falcon/7b/cuda/falcon.vmfb",
self.falcon_vmfb_path.absolute(),
single_file=True,
)
falcon_shark_model = self.compile_falcon()
return falcon_shark_model
@@ -375,6 +387,8 @@ class Falcon(SharkLLMBase):
(model_inputs["input_ids"], model_inputs["attention_mask"]),
)
)
if self.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
@@ -428,18 +442,35 @@ if __name__ == "__main__":
args = parser.parse_args()
falcon_mlir_path = (
Path("falcon.mlir")
Path(
"falcon_"
+ args.falcon_variant_to_use
+ "_"
+ args.precision
+ ".mlir"
)
if args.falcon_mlir_path is None
else Path(args.falcon_mlir_path)
)
falcon_vmfb_path = (
Path("falcon.vmfb")
Path(
"falcon_"
+ args.falcon_variant_to_use
+ "_"
+ args.precision
+ "_"
+ args.device
+ ".vmfb"
)
if args.falcon_vmfb_path is None
else Path(args.falcon_vmfb_path)
)
falcon = Falcon(
"falcon",
"falcon_" + args.falcon_variant_to_use,
hf_model_path="tiiuae/falcon-"
+ args.falcon_variant_to_use
+ "-instruct",
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,
@@ -451,11 +482,16 @@ if __name__ == "__main__":
default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
continue_execution = True
print("\n-----\nScript executing for the following config: \n")
print("Falcon Model: ", falcon.model_name)
print("Precision: ", args.precision)
print("Device: ", args.device)
while continue_execution:
use_default_prompt = input(
"\nDo you wish to use the default prompt text? True or False?: "
"\nDo you wish to use the default prompt text? Y/N ?: "
)
if use_default_prompt:
if use_default_prompt in ["Y", "y"]:
prompt = default_prompt_text
else:
prompt = input("Please enter the prompt text: ")
@@ -469,5 +505,8 @@ if __name__ == "__main__":
res_str,
)
continue_execution = input(
"\nDo you wish to run script one more time? True or False?: "
"\nDo you wish to run script one more time? Y/N ?: "
)
continue_execution = (
True if continue_execution in ["Y", "y"] else False
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,68 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
class BaseProcessor:
def __init__(self):
self.transform = lambda x: x
return
def __call__(self, item):
return self.transform(item)
@classmethod
def from_config(cls, cfg=None):
return cls()
def build(self, **kwargs):
cfg = OmegaConf.create(kwargs)
return self.from_config(cfg)
class BlipImageBaseProcessor(BaseProcessor):
def __init__(self, mean=None, std=None):
if mean is None:
mean = (0.48145466, 0.4578275, 0.40821073)
if std is None:
std = (0.26862954, 0.26130258, 0.27577711)
self.normalize = transforms.Normalize(mean, std)
class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
def __init__(self, image_size=224, mean=None, std=None):
super().__init__(mean=mean, std=std)
self.transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
self.normalize,
]
)
def __call__(self, item):
return self.transform(item)
@classmethod
def from_config(cls, cfg=None):
if cfg is None:
cfg = OmegaConf.create()
image_size = cfg.get("image_size", 224)
mean = cfg.get("mean", None)
std = cfg.get("std", None)
return cls(image_size=image_size, mean=mean, std=std)

View File

@@ -0,0 +1,5 @@
datasets:
cc_sbu_align:
data_type: images
build_info:
storage: /path/to/cc_sbu_align/

View File

@@ -0,0 +1,33 @@
model:
arch: mini_gpt4
# vit encoder
image_size: 224
drop_path_rate: 0
use_grad_checkpoint: False
vit_precision: "fp16"
freeze_vit: True
freeze_qformer: True
# Q-Former
num_query_token: 32
# Vicuna
llama_model: "lmsys/vicuna-7b-v1.3"
# generation configs
prompt: ""
preprocess:
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"

View File

@@ -0,0 +1,25 @@
model:
arch: mini_gpt4
model_type: pretrain_vicuna
freeze_vit: True
freeze_qformer: True
max_txt_len: 160
end_sym: "###"
low_resource: False
prompt_path: "apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt"
prompt_template: '###Human: {} ###Assistant: '
ckpt: 'prerained_minigpt4_7b.pth'
datasets:
cc_sbu_align:
vis_processor:
train:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
run:
task: image_text_pretrain

View File

@@ -0,0 +1,629 @@
# Based on EVA, BEIT, timm and DeiT code bases
# https://github.com/baaivision/EVA
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
import math
import requests
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
def _cfg(url="", **kwargs):
return {
"url": url,
"num_classes": 1000,
"input_size": (3, 224, 224),
"pool_size": None,
"crop_pct": 0.9,
"interpolation": "bicubic",
"mean": (0.5, 0.5, 0.5),
"std": (0.5, 0.5, 0.5),
**kwargs,
}
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
# x = self.drop(x)
# commit this for the orignal BERT implement
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
window_size=None,
attn_head_dim=None,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
if window_size:
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1
) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(
torch.meshgrid([coords_h, coords_w])
) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += (
window_size[0] - 1
) # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1,) * 2,
dtype=relative_coords.dtype,
)
relative_position_index[1:, 1:] = relative_coords.sum(
-1
) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer(
"relative_position_index", relative_position_index
)
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias=None):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat(
(
self.q_bias,
torch.zeros_like(self.v_bias, requires_grad=False),
self.v_bias,
)
)
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if self.relative_position_bias_table is not None:
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1,
-1,
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
init_values=None,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
window_size=None,
attn_head_dim=None,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
window_size=window_size,
attn_head_dim=attn_head_dim,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = (
DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
if init_values is not None and init_values > 0:
self.gamma_1 = nn.Parameter(
init_values * torch.ones((dim)), requires_grad=True
)
self.gamma_2 = nn.Parameter(
init_values * torch.ones((dim)), requires_grad=True
)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x, rel_pos_bias=None):
if self.gamma_1 is None:
x = x + self.drop_path(
self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
)
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(
self.gamma_1
* self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
)
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
"""Image to Patch Embedding"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (
img_size[0] // patch_size[0]
)
self.patch_shape = (
img_size[0] // patch_size[0],
img_size[1] // patch_size[1],
)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
def forward(self, x, **kwargs):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1
) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1,) * 2,
dtype=relative_coords.dtype,
)
relative_position_index[1:, 1:] = relative_coords.sum(
-1
) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer(
"relative_position_index", relative_position_index
)
# trunc_normal_(self.relative_position_bias_table, std=.02)
def forward(self):
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1,
-1,
) # Wh*Ww,Wh*Ww,nH
return relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
class VisionTransformer(nn.Module):
"""Vision Transformer with support for patch or hybrid CNN input stage"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
norm_layer=nn.LayerNorm,
init_values=None,
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,
use_mean_pooling=True,
init_scale=0.001,
use_checkpoint=False,
):
super().__init__()
self.image_size = img_size
self.num_classes = num_classes
self.num_features = (
self.embed_dim
) = embed_dim # num_features for consistency with other models
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim)
)
else:
self.pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(
window_size=self.patch_embed.patch_shape, num_heads=num_heads
)
else:
self.rel_pos_bias = None
self.use_checkpoint = use_checkpoint
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
self.blocks = nn.ModuleList(
[
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
init_values=init_values,
window_size=self.patch_embed.patch_shape
if use_rel_pos_bias
else None,
)
for i in range(depth)
]
)
# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=0.02)
trunc_normal_(self.cls_token, std=0.02)
# trunc_normal_(self.mask_token, std=.02)
# if isinstance(self.head, nn.Linear):
# trunc_normal_(self.head.weight, std=.02)
self.apply(self._init_weights)
self.fix_init_weight()
# if isinstance(self.head, nn.Linear):
# self.head.weight.data.mul_(init_scale)
# self.head.bias.data.mul_(init_scale)
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=""):
self.num_classes = num_classes
self.head = (
nn.Linear(self.embed_dim, num_classes)
if num_classes > 0
else nn.Identity()
)
def forward_features(self, x):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(
batch_size, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
rel_pos_bias = (
self.rel_pos_bias() if self.rel_pos_bias is not None else None
)
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, rel_pos_bias)
else:
x = blk(x, rel_pos_bias)
return x
# x = self.norm(x)
# if self.fc_norm is not None:
# t = x[:, 1:, :]
# return self.fc_norm(t.mean(1))
# else:
# return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
# x = self.head(x)
return x
def get_intermediate_layers(self, x):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(
batch_size, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
features = []
rel_pos_bias = (
self.rel_pos_bias() if self.rel_pos_bias is not None else None
)
for blk in self.blocks:
x = blk(x, rel_pos_bias)
features.append(x)
return features
def interpolate_pos_embed(model, checkpoint_model):
if "pos_embed" in checkpoint_model:
pos_embed_checkpoint = checkpoint_model["pos_embed"].float()
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int(
(pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5
)
# height (== width) for the new position embedding
new_size = int(num_patches**0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print(
"Position interpolate from %dx%d to %dx%d"
% (orig_size, orig_size, new_size, new_size)
)
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(
-1, orig_size, orig_size, embedding_size
).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens,
size=(new_size, new_size),
mode="bicubic",
align_corners=False,
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model["pos_embed"] = new_pos_embed
def convert_weights_to_fp16(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
# l.weight.data = l.weight.data.half()
l.weight.data = l.weight.data
if l.bias is not None:
# l.bias.data = l.bias.data.half()
l.bias.data = l.bias.data
# if isinstance(l, (nn.MultiheadAttention, Attention)):
# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
# tensor = getattr(l, attr)
# if tensor is not None:
# tensor.data = tensor.data.half()
model.apply(_convert_weights_to_fp16)
def create_eva_vit_g(
img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"
):
model = VisionTransformer(
img_size=img_size,
patch_size=14,
use_mean_pooling=False,
embed_dim=1408,
depth=39,
num_heads=1408 // 88,
mlp_ratio=4.3637,
qkv_bias=True,
drop_path_rate=drop_path_rate,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
use_checkpoint=use_checkpoint,
)
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
local_filename = "eva_vit_g.pth"
response = requests.get(url)
if response.status_code == 200:
with open(local_filename, "wb") as f:
f.write(response.content)
print("File downloaded successfully.")
state_dict = torch.load(local_filename, map_location="cpu")
interpolate_pos_embed(model, state_dict)
incompatible_keys = model.load_state_dict(state_dict, strict=False)
if precision == "fp16":
# model.to("cuda")
convert_weights_to_fp16(model)
return model

View File

@@ -0,0 +1,4 @@
<Img><ImageHere></Img> Describe this image in detail.
<Img><ImageHere></Img> Take a look at this image and describe what you notice.
<Img><ImageHere></Img> Please provide a detailed description of the picture.
<Img><ImageHere></Img> Could you describe the contents of this image for me?

View File

@@ -1,559 +0,0 @@
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.utils import (
get_vmfb_from_path,
)
from io import BytesIO
from pathlib import Path
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
import torch
import torch_mlir
import os
class Vicuna(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
max_num_tokens=512,
device="cuda",
precision="fp32",
first_vicuna_mlir_path=Path("first_vicuna.mlir"),
second_vicuna_mlir_path=Path("second_vicuna.mlir"),
first_vicuna_vmfb_path=Path("first_vicuna.vmfb"),
second_vicuna_vmfb_path=Path("second_vicuna.vmfb"),
load_mlir_from_shark_tank=True,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
self.precision = precision
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
self.first_vicuna_mlir_path = first_vicuna_mlir_path
self.second_vicuna_mlir_path = second_vicuna_mlir_path
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path, use_fast=False
)
return tokenizer
def get_src_model(self):
kwargs = {"torch_dtype": torch.float}
vicuna_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
return vicuna_model
def compile_first_vicuna(self):
vmfb = get_vmfb_from_path(
self.first_vicuna_vmfb_path, self.device, "tm_tensor"
)
if vmfb is not None:
return vmfb
# Compilation path needs some more work before it is functional
print(
f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with"
f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}"
)
if self.first_vicuna_mlir_path.exists():
with open(self.first_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
download_public_file(
"gs://shark_tank/vicuna/unsharded/mlir/first_vicuna.mlir",
self.first_vicuna_mlir_path.absolute(),
single_file=True,
)
if self.first_vicuna_mlir_path.exists():
with open(self.first_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
)
if not mlir_generated:
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = self.tokenizer(
compilation_prompt
).input_ids
compilation_input_ids = torch.tensor(
compilation_input_ids
).reshape([1, 19])
firstVicunaCompileInput = (compilation_input_ids,)
model = FirstVicuna(self.hf_model_path)
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
model,
firstVicunaCompileInput,
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
mlir_type="torchscript",
)
del model
print(f"[DEBUG] generating torch mlir")
firstVicunaCompileInput = list(firstVicunaCompileInput)
firstVicunaCompileInput[0] = torch_mlir.TensorPlaceholder.like(
firstVicunaCompileInput[0], dynamic_axes=[1]
)
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
def remove_constant_dim(line):
if "19x" in line:
line = re.sub("19x", "?x", line)
line = re.sub(
"tensor.empty\(\)", "tensor.empty(%dim)", line
)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)",
"tensor.empty(%dim, %dim)",
line,
)
if "arith.cmpi" in line:
line = re.sub("c19", "dim", line)
if " 19," in line:
line = re.sub(" 19,", " %dim,", line)
return line
module = str(module)
new_lines = []
print(f"[DEBUG] rewriting torch_mlir file")
for line in module.splitlines():
line = remove_constant_dim(line)
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
new_lines.append(
"%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
)
if (
"%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
in line
):
continue
new_lines.append(line)
module = "\n".join(new_lines)
print(f"[DEBUG] converting to bytecode")
del new_lines
module = module.encode("UTF-8")
module = BytesIO(module)
bytecode = module.read()
del module
print(f"[DEBUG] writing mlir to file")
f_ = open(self.first_vicuna_mlir_path, "wb")
f_.write(bytecode)
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
)
path = shark_module.save_module(
self.first_vicuna_vmfb_path.parent.absolute(),
self.first_vicuna_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
print("Saved first vic vmfb at vmfb at ", str(path))
shark_module.load_module(path)
return shark_module
def compile_second_vicuna(self):
vmfb = get_vmfb_from_path(
self.second_vicuna_vmfb_path, self.device, "tm_tensor"
)
if vmfb is not None:
return vmfb
# Compilation path needs some more work before it is functional
print(
f"[DEBUG] mlir path {self.second_vicuna_mlir_path} {'exists' if self.second_vicuna_mlir_path.exists() else 'does not exist'}"
)
if self.second_vicuna_mlir_path.exists():
with open(self.second_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
download_public_file(
"gs://shark_tank/vicuna/unsharded/mlir/second_vicuna.mlir",
self.second_vicuna_mlir_path.absolute(),
single_file=True,
)
if self.second_vicuna_mlir_path.exists():
with open(self.second_vicuna_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.second_vicuna_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
)
if not mlir_generated:
compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64)
pkv = tuple(
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
for _ in range(64)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
model = SecondVicuna(self.hf_model_path)
ts_graph = import_with_fx(
model,
secondVicunaCompileInput,
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
mlir_type="torchscript",
)
secondVicunaCompileInput = list(secondVicunaCompileInput)
for i in range(len(secondVicunaCompileInput)):
if i != 0:
secondVicunaCompileInput[
i
] = torch_mlir.TensorPlaceholder.like(
secondVicunaCompileInput[i], dynamic_axes=[2]
)
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
def remove_constant_dim(line):
if "c19_i64" in line:
line = re.sub("c19_i64", "dim_i64", line)
if "19x" in line:
line = re.sub("19x", "?x", line)
line = re.sub(
"tensor.empty\(\)", "tensor.empty(%dim)", line
)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)",
"tensor.empty(%dim, %dim)",
line,
)
if "arith.cmpi" in line:
line = re.sub("c19", "dim", line)
if " 19," in line:
line = re.sub(" 19,", " %dim,", line)
if "20x" in line:
line = re.sub("20x", "?x", line)
line = re.sub(
"tensor.empty\(\)", "tensor.empty(%dimp1)", line
)
if " 20," in line:
line = re.sub(" 20,", " %dimp1,", line)
return line
module_str = str(module)
new_lines = []
for line in module_str.splitlines():
if "%c19_i64 = arith.constant 19 : i64" in line:
new_lines.append("%c2 = arith.constant 2 : index")
new_lines.append(
"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128xf32>"
)
new_lines.append(
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
)
continue
if "%c2 = arith.constant 2 : index" in line:
continue
if "%c20_i64 = arith.constant 20 : i64" in line:
new_lines.append("%c1_i64 = arith.constant 1 : i64")
new_lines.append(
"%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
)
new_lines.append(
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
)
continue
line = remove_constant_dim(line)
new_lines.append(line)
module_str = "\n".join(new_lines)
bytecode = module_str.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
f_ = open(self.second_vicuna_mlir_path, "wb")
f_.write(bytecode)
f_.close()
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
)
path = shark_module.save_module(
self.second_vicuna_vmfb_path.parent.absolute(),
self.second_vicuna_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
print("Saved vmfb at ", str(path))
shark_module.load_module(self.second_vicuna_vmfb_path)
# self.shark_module = shark_module
return shark_module
def compile(self):
# Cannot load both the models in the memory at once
# due to memory constraints, hence on demand compilation
# is being used until the space is enough for both models
# Testing : DO NOT Download Vmfbs if not found. Modify later
# download vmfbs for A100
if (
not self.first_vicuna_vmfb_path.exists()
and self.device == "cuda"
and self.precision == "fp32"
):
download_public_file(
"gs://shark_tank/vicuna/unsharded/first_vicuna.vmfb",
self.first_vicuna_vmfb_path.absolute(),
single_file=True,
)
else:
# get first vic
# TODO: Remove after testing to avoid memory overload
# fvic_shark_model = self.compile_first_vicuna()
pass
if (
not self.second_vicuna_vmfb_path.exists()
and self.device == "cuda"
and self.precision == "fp32"
):
download_public_file(
"gs://shark_tank/vicuna/unsharded/second_vicuna.vmfb",
self.second_vicuna_vmfb_path.absolute(),
single_file=True,
)
else:
# get second vic
# TODO: Remove after testing to avoid memory overload
# svic_shark_model = self.compile_second_vicuna()
pass
# get first vic
# fvic_shark_model = self.compile_first_vicuna()
# get second vic
# svic_shark_model = self.compile_second_vicuna()
# return tuple of shark_modules
# return fvic_shark_model, svic_shark_model
return None
# return tuple of shark_modules once mem is supported
# return fvic_shark_model, svic_shark_model
def generate(self, prompt, cli=False):
# TODO: refactor for cleaner integration
import gc
res = []
res_tokens = []
params = {
"prompt": prompt,
"is_first": True,
"fv": self.compile_first_vicuna(),
}
generated_token_op = self.generate_new_token(params=params)
token = generated_token_op["token"]
logits = generated_token_op["logits"]
pkv = generated_token_op["pkv"]
detok = generated_token_op["detok"]
res.append(detok)
res_tokens.append(token)
if cli:
print(f"Assistant: {detok}", end=" ", flush=True)
# Clear First Vic from Memory (main and cuda)
del params
torch.cuda.empty_cache()
gc.collect()
sec_vic = self.compile_second_vicuna()
for _ in range(self.max_num_tokens - 2):
params = {
"prompt": None,
"is_first": False,
"logits": logits,
"pkv": pkv,
"sv": sec_vic,
}
generated_token_op = self.generate_new_token(params=params)
token = generated_token_op["token"]
logits = generated_token_op["logits"]
pkv = generated_token_op["pkv"]
detok = generated_token_op["detok"]
if token == 2:
break
res_tokens.append(token)
if detok == "<0x0A>":
res.append("\n")
if cli:
print("\n", end="", flush=True)
else:
res.append(detok)
if cli:
print(f"{detok}", end=" ", flush=True)
del sec_vic, pkv, logits
torch.cuda.empty_cache()
gc.collect()
for i in range(len(res_tokens)):
if type(res_tokens[i]) != int:
res_tokens[i] = int(res_tokens[i][0])
res_str = self.tokenizer.decode(res_tokens)
# print(f"[DEBUG] final output : \n{res_str}")
return res_str
def generate_new_token(self, params, debug=False):
def forward_first(first_vic, prompt, cache_outputs=False):
input_ids = self.tokenizer(prompt).input_ids
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
firstVicunaInput = (input_ids,)
assert first_vic is not None
output_first_vicuna = first_vic("forward", firstVicunaInput)
output_first_vicuna_tensor = torch.tensor(output_first_vicuna[1:])
logits_first_vicuna = torch.tensor(output_first_vicuna[0])
if cache_outputs:
torch.save(
logits_first_vicuna, "logits_first_vicuna_tensor.pt"
)
torch.save(
output_first_vicuna_tensor, "output_first_vicuna_tensor.pt"
)
token = torch.argmax(
torch.tensor(logits_first_vicuna)[:, -1, :], dim=1
)
return token, logits_first_vicuna, output_first_vicuna_tensor
def forward_second(sec_vic, inputs=None, load_inputs=False):
if inputs is not None:
logits = inputs[0]
pkv = inputs[1:]
elif load_inputs:
pkv = torch.load("output_first_vicuna_tensor.pt")
pkv = tuple(torch.tensor(x) for x in pkv)
logits = torch.load("logits_first_vicuna_tensor.pt")
else:
print(
"Either inputs must be given, or load_inputs must be true"
)
return None
token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
token = token.to(torch.int64).reshape([1, 1])
secondVicunaInput = (token,) + tuple(pkv)
secondVicunaOutput = sec_vic("forward", secondVicunaInput)
new_pkv = secondVicunaOutput[1:]
new_logits = secondVicunaOutput[0]
new_token = torch.argmax(torch.tensor(new_logits)[:, -1, :], dim=1)
return new_token, new_logits, new_pkv
is_first = params["is_first"]
if is_first:
prompt = params["prompt"]
fv = params["fv"]
token, logits, pkv = forward_first(
fv, # self.shark_model[0],
prompt=prompt,
cache_outputs=False,
)
else:
_logits = params["logits"]
_pkv = params["pkv"]
inputs = (_logits,) + tuple(_pkv)
sv = params["sv"]
token, logits, pkv = forward_second(
sv, # self.shark_model[1],
inputs=inputs,
load_inputs=False,
)
detok = self.tokenizer.decode(token)
if debug:
print(
f"[DEBUG] is_first: {is_first} |"
f" token : {token} | detok : {detok}"
)
ret_dict = {
"token": token,
"logits": logits,
"pkv": pkv,
"detok": detok,
}
return ret_dict
def autocomplete(self, prompt):
# use First vic alone to complete a story / prompt / sentence.
pass

View File

@@ -1,408 +0,0 @@
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
FirstVicunaLayer,
SecondVicunaLayer,
CompiledFirstVicunaLayer,
CompiledSecondVicunaLayer,
ShardedVicunaModel,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from shark.shark_importer import import_with_fx
from io import BytesIO
from pathlib import Path
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
from torch_mlir import TensorPlaceholder
import re
import torch
import torch_mlir
import os
class Vicuna(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
max_num_tokens=512,
device="cuda",
precision="fp32",
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
self.precision = precision
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path, use_fast=False
)
return tokenizer
def get_src_model(self):
kwargs = {"torch_dtype": torch.float}
vicuna_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
return vicuna_model
def write_in_dynamic_inputs0(self, module, dynamic_input_size):
new_lines = []
for line in module.splitlines():
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim", line)
new_lines.append(line)
new_module = "\n".join(new_lines)
return new_module
def write_in_dynamic_inputs1(self, module, dynamic_input_size):
new_lines = []
for line in module.splitlines():
if "dim_42 =" in line:
continue
if f"%c{dynamic_input_size}_i64 =" in line:
new_lines.append(
"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
)
new_lines.append(
f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
)
continue
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub(
"tensor.empty\(\)", "tensor.empty(%dim_42)", line
)
line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim_42\)",
"tensor.empty(%dim_42, %dim_42)",
line,
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
new_lines.append(line)
new_module = "\n".join(new_lines)
return new_module
def compile_vicuna_layer(
self,
vicuna_layer,
hidden_states,
attention_mask,
position_ids,
past_key_value0=None,
past_key_value1=None,
):
if past_key_value0 is None and past_key_value1 is None:
model_inputs = (hidden_states, attention_mask, position_ids)
else:
model_inputs = (
hidden_states,
attention_mask,
position_ids,
past_key_value0,
past_key_value1,
)
mlir_bytecode = import_with_fx(
vicuna_layer,
model_inputs,
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
mlir_type="torchscript",
)
return mlir_bytecode
def compile_to_vmfb(self, inputs, layers, is_first=True):
mlirs, modules = [], []
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
if is_first:
mlir_path = Path(f"{idx}_0.mlir")
vmfb_path = Path(f"{idx}_0.vmfb")
else:
mlir_path = Path(f"{idx}_1.mlir")
vmfb_path = Path(f"{idx}_1.vmfb")
if vmfb_path.exists():
continue
if mlir_path.exists():
# print(f"Found layer {idx} mlir")
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
else:
hidden_states_placeholder = TensorPlaceholder.like(
inputs[0], dynamic_axes=[1]
)
attention_mask_placeholder = TensorPlaceholder.like(
inputs[1], dynamic_axes=[3]
)
position_ids_placeholder = TensorPlaceholder.like(
inputs[2], dynamic_axes=[1]
)
if not is_first:
pkv0_placeholder = TensorPlaceholder.like(
inputs[3], dynamic_axes=[2]
)
pkv1_placeholder = TensorPlaceholder.like(
inputs[4], dynamic_axes=[2]
)
print(f"Compiling layer {idx} mlir")
if is_first:
ts_g = self.compile_vicuna_layer(
layer, inputs[0], inputs[1], inputs[2]
)
module = torch_mlir.compile(
ts_g,
(
hidden_states_placeholder,
inputs[1],
inputs[2],
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
else:
ts_g = self.compile_vicuna_layer(
layer,
inputs[0],
inputs[1],
inputs[2],
inputs[3],
inputs[4],
)
module = torch_mlir.compile(
ts_g,
(
inputs[0],
attention_mask_placeholder,
inputs[2],
pkv0_placeholder,
pkv1_placeholder,
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
# bytecode_stream = BytesIO()
# module.operation.write_bytecode(bytecode_stream)
# bytecode = bytecode_stream.getvalue()
if is_first:
module = self.write_in_dynamic_inputs0(str(module), 137)
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
else:
module = self.write_in_dynamic_inputs1(str(module), 138)
if idx in [0, 5, 6, 7]:
module_str = module
module_str = module_str.splitlines()
new_lines = []
for line in module_str:
if len(line) < 1000:
new_lines.append(line)
else:
new_lines.append(line[:999])
module_str = "\n".join(new_lines)
f1_ = open(f"{idx}_1_test.mlir", "w+")
f1_.write(module_str)
f1_.close()
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
f_.close()
mlirs.append(bytecode)
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
if is_first:
vmfb_path = Path(f"{idx}_0.vmfb")
if idx < 25:
device = "cpu"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module(
module_name=f"{idx}_0",
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
else:
vmfb_path = Path(f"{idx}_1.vmfb")
if idx < 25:
device = "cpu"
else:
device = "cpu"
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
module = SharkInference(
None, device=device, mlir_dialect="tm_tensor"
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
module = SharkInference(
mlirs[idx], device=device, mlir_dialect="tm_tensor"
)
module.save_module(
module_name=f"{idx}_1",
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
return mlirs, modules
def get_sharded_model(self):
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess
# please don't change it
SAMPLE_INPUT_LEN = 137
vicuna_model = self.get_src_model()
placeholder_input0 = (
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
)
placeholder_input1 = (
torch.zeros([1, 1, 4096]),
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
torch.zeros([1, 1], dtype=torch.int64),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
)
layers0 = [
FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
]
_, modules0 = self.compile_to_vmfb(
placeholder_input0, layers0, is_first=True
)
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
layers1 = [
SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
]
_, modules1 = self.compile_to_vmfb(
placeholder_input1, layers1, is_first=False
)
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
sharded_model = ShardedVicunaModel(
vicuna_model, shark_layers0, shark_layers1
)
return sharded_model
def compile(self):
return self.get_sharded_model()
def generate(self, prompt, cli=False):
# TODO: refactor for cleaner integration
tokens_generated = []
_past_key_values = None
_token = None
detoks_generated = []
for iteration in range(self.max_num_tokens):
params = {
"prompt": prompt,
"is_first": iteration == 0,
"token": _token,
"past_key_values": _past_key_values,
}
generated_token_op = self.generate_new_token(params=params)
_token = generated_token_op["token"]
_past_key_values = generated_token_op["past_key_values"]
_detok = generated_token_op["detok"]
if _token == 2:
break
detoks_generated.append(_detok)
tokens_generated.append(_token)
for i in range(len(tokens_generated)):
if type(tokens_generated[i]) != int:
tokens_generated[i] = int(tokens_generated[i][0])
result_output = self.tokenizer.decode(tokens_generated)
return result_output
def generate_new_token(self, params):
is_first = params["is_first"]
if is_first:
prompt = params["prompt"]
input_ids = self.tokenizer(prompt).input_ids
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
output = self.shark_model.forward(input_ids, is_first=is_first)
else:
token = params["token"]
past_key_values = params["past_key_values"]
input_ids = [token]
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
output = self.shark_model.forward(
input_ids, past_key_values=past_key_values, is_first=is_first
)
_logits = output["logits"]
_past_key_values = output["past_key_values"]
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
_detok = self.tokenizer.decode(_token)
ret_dict = {
"token": _token,
"detok": _detok,
"past_key_values": _past_key_values,
}
print(f" token : {_token} | detok : {_detok}")
return ret_dict
def autocomplete(self, prompt):
# use First vic alone to complete a story / prompt / sentence.
pass

View File

@@ -3,6 +3,7 @@ from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from pathlib import Path
from shark.shark_downloader import download_public_file
# expects a Path / str as arg
@@ -17,9 +18,23 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
return None
print("Loading vmfb from: ", vmfb_path)
print("Device from get_vmfb_from_path - ", device)
shark_module = SharkInference(
None, device=device, mlir_dialect=mlir_dialect
)
shark_module.load_module(vmfb_path)
print("Successfully loaded vmfb")
return shark_module
def get_vmfb_from_config(
shark_container, model, precision, device, vmfb_path, padding=None
):
vmfb_url = (
f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
)
if padding:
vmfb_url = vmfb_url + f"_{padding}"
vmfb_url = vmfb_url + ".vmfb"
download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
return get_vmfb_from_path(vmfb_path, device, "tm_tensor")

View File

@@ -103,6 +103,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
)
total_time = time.time() - start_time

View File

@@ -58,11 +58,8 @@ def main():
ondemand=args.ondemand,
)
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
for current_batch in range(args.batch_count):
if current_batch > 0:
seed = -1
seed = utils.sanitize_seed(seed)
start_time = time.time()
generated_imgs = inpaint_obj.generate_images(
args.prompts,
@@ -76,11 +73,12 @@ def main():
args.inpaint_full_res_padding,
args.steps,
args.guidance_scale,
seed,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
@@ -89,7 +87,10 @@ def main():
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
)
text_output += f"seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)

View File

@@ -51,11 +51,8 @@ def main():
ondemand=args.ondemand,
)
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
for current_batch in range(args.batch_count):
if current_batch > 0:
seed = -1
seed = utils.sanitize_seed(seed)
start_time = time.time()
generated_imgs = outpaint_obj.generate_images(
args.prompts,
@@ -74,11 +71,12 @@ def main():
args.width,
args.steps,
args.guidance_scale,
seed,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
@@ -87,7 +85,10 @@ def main():
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
)
text_output += f"seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)

View File

@@ -223,7 +223,8 @@ def lora_train(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, both must not be "
"empty.",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:

View File

@@ -17,6 +17,10 @@ from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
def load_mlir_module():
if "upscaler" in args.hf_model_id:
is_upscaler = True
else:
is_upscaler = False
sd_model = SharkifyStableDiffusionModel(
args.hf_model_id,
args.ckpt_loc,
@@ -27,6 +31,7 @@ def load_mlir_module():
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
is_upscaler=is_upscaler,
use_tuned=False,
low_cpu_mem_usage=args.low_cpu_mem_usage,
return_mlir=True,

View File

@@ -42,11 +42,8 @@ def main():
ondemand=args.ondemand,
)
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
for current_batch in range(args.batch_count):
if current_batch > 0:
seed = -1
seed = utils.sanitize_seed(seed)
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
@@ -56,11 +53,12 @@ def main():
args.width,
args.steps,
args.guidance_scale,
seed,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
@@ -69,7 +67,12 @@ def main():
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
)
text_output += (
f"seed={seeds[current_batch]}, size={args.height}x{args.width}"
)
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)

View File

@@ -21,7 +21,7 @@ if __name__ == "__main__":
print("Flag --img_path is required.")
exit()
# When the models get uploaded, it should be default to False.
# When the models get uploaded, it should be defaulted to False.
args.import_mlir = True
cpu_scheduling = not args.scheduler.startswith("Shark")
@@ -73,6 +73,7 @@ if __name__ == "__main__":
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -1,57 +1,13 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('opencv-python')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += collect_data_files('tkinter')
datas += collect_data_files('webview')
datas += collect_data_files('sentencepiece')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
( 'src/utils/resources/opt_flags.json', 'resources' ),
( 'src/utils/resources/base_model.json', 'resources' ),
( 'web/ui/css/*', 'ui/css' ),
( 'web/ui/logos/*', 'logos' )
]
from apps.stable_diffusion.shark_studio_imports import pathex, datas, hiddenimports
binaries = []
block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
a = Analysis(
['web/index.py'],
pathex=['.'],
pathex=pathex,
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
@@ -73,11 +29,11 @@ exe = EXE(
a.zipfiles,
a.datas,
[],
name='shark_sd',
name='nodai_shark_studio',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx=False,
upx_exclude=[],
runtime_tmpdir=None,
console=True,

View File

@@ -29,6 +29,7 @@ datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += collect_data_files('py-cpuinfo')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),

View File

@@ -0,0 +1,77 @@
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules
import sys
sys.setrecursionlimit(sys.getrecursionlimit() * 5)
# python path for pyinstaller
pathex = [
".",
"./apps/language_models/langchain",
"./apps/language_models/src/pipelines/minigpt4_utils",
]
# datafiles for pyinstaller
datas = []
datas += collect_data_files("torch")
datas += copy_metadata("torch")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
datas += copy_metadata("packaging")
datas += copy_metadata("filelock")
datas += copy_metadata("numpy")
datas += copy_metadata("importlib_metadata")
datas += copy_metadata("torch-mlir")
datas += copy_metadata("omegaconf")
datas += copy_metadata("safetensors")
datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("pytorch_lightning")
datas += collect_data_files("opencv_python")
datas += collect_data_files("skimage")
datas += collect_data_files("gradio")
datas += collect_data_files("gradio_client")
datas += collect_data_files("iree")
datas += collect_data_files("google_cloud_storage")
datas += collect_data_files("shark")
datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += [
("src/utils/resources/prompts.json", "resources"),
("src/utils/resources/model_db.json", "resources"),
("src/utils/resources/opt_flags.json", "resources"),
("src/utils/resources/base_model.json", "resources"),
("web/ui/css/*", "ui/css"),
("web/ui/logos/*", "logos"),
(
"../language_models/src/pipelines/minigpt4_utils/configs/*",
"minigpt4_utils/configs",
),
(
"../language_models/src/pipelines/minigpt4_utils/prompts/*",
"minigpt4_utils/prompts",
),
]
# hidden imports for pyinstaller
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [
x for x in collect_submodules("transformers") if "tests" not in x
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]

View File

@@ -45,6 +45,7 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape.append(width * mul_val)
elif "/" in shape[i]:
import math
div_val = int(shape[i].split("/")[1])
if "batch_size" in shape[i]:
new_shape.append(math.ceil(batch_size / div_val))
@@ -59,7 +60,9 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
def check_compilation(model, model_name):
if not model:
raise Exception(f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues")
raise Exception(
f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues"
)
class SharkifyStableDiffusionModel:
@@ -97,16 +100,22 @@ class SharkifyStableDiffusionModel:
if "civitai" in custom_weights:
weights_id = custom_weights.split("/")[-1]
# TODO: use model name and identify file type by civitai rest api
weights_path = str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
weights_path = (
str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
)
if not os.path.isfile(weights_path):
subprocess.run(["wget", custom_weights, "-O", weights_path])
subprocess.run(
["wget", custom_weights, "-O", weights_path]
)
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
self.custom_weights = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
custom_weights = get_path_to_diffusers_checkpoint(
custom_weights
)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":
@@ -126,7 +135,7 @@ class SharkifyStableDiffusionModel:
+ "_"
+ precision
)
print(f'use_tuned? sharkify: {use_tuned}')
print(f"use_tuned? sharkify: {use_tuned}")
self.use_tuned = use_tuned
if use_tuned:
self.model_name = self.model_name + "_tuned"
@@ -163,14 +172,24 @@ class SharkifyStableDiffusionModel:
def get_extended_name_for_all_model(self):
model_name = {}
sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
sub_model_list = [
"clip",
"unet",
"unet512",
"stencil_unet",
"vae",
"vae_encode",
"stencil_adaptor",
]
index = 0
for model in sub_model_list:
sub_model = model
model_config = self.model_name
if "vae" == model:
if self.custom_vae != "":
model_config = model_config + get_path_stem(self.custom_vae)
model_config = model_config + get_path_stem(
self.custom_vae
)
if self.base_vae:
sub_model = "base_vae"
if "stencil_adaptor" == model and self.use_stencil is not None:
@@ -197,7 +216,11 @@ class SharkifyStableDiffusionModel:
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, self.max_len, self.width, self.height, self.batch_size
shape,
self.max_len,
self.width,
self.height,
self.batch_size,
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
@@ -209,10 +232,12 @@ class SharkifyStableDiffusionModel:
sys.exit("shape isn't specified correctly.")
input_map.append(tensor)
return input_map
def get_vae_encode(self):
class VaeEncodeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
@@ -226,7 +251,11 @@ class SharkifyStableDiffusionModel:
vae_encode = VaeEncodeModel()
inputs = tuple(self.inputs["vae_encode"])
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
is_f16 = (
True
if not self.is_upscaler and self.precision == "fp16"
else False
)
shark_vae_encode, vae_encode_mlir = compile_through_fx(
vae_encode,
inputs,
@@ -243,7 +272,13 @@ class SharkifyStableDiffusionModel:
def get_vae(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae, low_cpu_mem_usage=False):
def __init__(
self,
model_id=self.model_id,
base_vae=self.base_vae,
custom_vae=self.custom_vae,
low_cpu_mem_usage=False,
):
super().__init__()
self.vae = None
if custom_vae == "":
@@ -279,7 +314,11 @@ class SharkifyStableDiffusionModel:
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
is_f16 = (
True
if not self.is_upscaler and self.precision == "fp16"
else False
)
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
@@ -303,7 +342,10 @@ class SharkifyStableDiffusionModel:
def get_controlled_unet(self):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
@@ -316,12 +358,43 @@ class SharkifyStableDiffusionModel:
self.in_channels = self.unet.in_channels
self.train(False)
def forward( self, latent, timestep, text_embedding, guidance_scale, control1,
control2, control3, control4, control5, control6, control7,
control8, control9, control10, control11, control12, control13,
def forward(
self,
latent,
timestep,
text_embedding,
guidance_scale,
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
control13,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
db_res_samples = tuple([ control1, control2, control3, control4, control5, control6, control7, control8, control9, control10, control11, control12,])
db_res_samples = tuple(
[
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
]
)
mb_res_samples = control13
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
@@ -342,7 +415,25 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True,]
input_mask = [
True,
True,
True,
False,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
]
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
inputs,
@@ -386,16 +477,23 @@ class SharkifyStableDiffusionModel:
stencil_image = torch.cat(
[stencil_image_input] * 2
) # needs to be same as controlledUNET latents
down_block_res_samples, mid_block_res_sample = self.cnet.forward(
(
down_block_res_samples,
mid_block_res_sample,
) = self.cnet.forward(
latents,
timestep,
encoder_hidden_states=text_embedding,
controlnet_cond=stencil_image,
return_dict=False,
)
return tuple(list(down_block_res_samples) + [mid_block_res_sample])
return tuple(
list(down_block_res_samples) + [mid_block_res_sample]
)
scnet = StencilControlNetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
scnet = StencilControlNetModel(
low_cpu_mem_usage=self.low_cpu_mem_usage
)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_adaptor"])
@@ -415,9 +513,14 @@ class SharkifyStableDiffusionModel:
)
return shark_cnet, cnet_mlir
def get_unet(self):
def get_unet(self, use_large=False):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
@@ -426,17 +529,26 @@ class SharkifyStableDiffusionModel:
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
self.in_channels = self.unet.in_channels
self.in_channels = self.unet.config.in_channels
self.train(False)
if(args.attention_slicing is not None and args.attention_slicing != "none"):
if(args.attention_slicing.isdigit()):
self.unet.set_attention_slice(int(args.attention_slicing))
if (
args.attention_slicing is not None
and args.attention_slicing != "none"
):
if args.attention_slicing.isdigit():
self.unet.set_attention_slice(
int(args.attention_slicing)
)
else:
self.unet.set_attention_slice(args.attention_slicing)
# TODO: Instead of flattening the `control` try to use the list.
def forward(
self, latent, timestep, text_embedding, guidance_scale,
self,
latent,
timestep,
text_embedding,
guidance_scale,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
@@ -452,17 +564,33 @@ class SharkifyStableDiffusionModel:
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["unet512"]
)
else:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["unet"]
)
input_mask = [True, True, True, False]
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
model_name = "unet512" if use_large else "unet"
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
@@ -471,15 +599,17 @@ class SharkifyStableDiffusionModel:
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
def get_unet_upscaler(self):
def get_unet_upscaler(self, use_large=False):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
def __init__(
self, model_id=self.model_id, low_cpu_mem_usage=False
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
@@ -502,17 +632,27 @@ class SharkifyStableDiffusionModel:
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
)
input_mask = [True, True, True, False]
model_name = "unet512" if use_large else "unet"
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
@@ -520,7 +660,12 @@ class SharkifyStableDiffusionModel:
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
@@ -528,7 +673,9 @@ class SharkifyStableDiffusionModel:
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.text_encoder, use_lora, "text_encoder")
update_lora_weight(
self.text_encoder, use_lora, "text_encoder"
)
def forward(self, input):
return self.text_encoder(input)[0]
@@ -567,34 +714,47 @@ class SharkifyStableDiffusionModel:
vae_checkpoint = None
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
if custom_vae.endswith(".ckpt"):
vae_checkpoint = torch.load(self.custom_vae, map_location="cpu")
vae_checkpoint = torch.load(
self.custom_vae, map_location="cpu"
)
else:
vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
vae_checkpoint = safetensors.torch.load_file(
self.custom_vae, device="cpu"
)
if "state_dict" in vae_checkpoint:
vae_checkpoint = vae_checkpoint["state_dict"]
try:
vae_checkpoint = convert_original_vae(vae_checkpoint)
finally:
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
vae_dict = {
k: v
for k, v in vae_checkpoint.items()
if k[0:4] != "loss" and k not in vae_ignore_keys
}
return vae_dict
def compile_unet_variants(self, model):
def compile_unet_variants(self, model, use_large=False):
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler()
return self.get_unet_upscaler(use_large=use_large)
# TODO: Plug the experimental "int8" support at right place.
elif self.use_quantize == "int8":
from apps.stable_diffusion.src.models.opt_params import get_unet
from apps.stable_diffusion.src.models.opt_params import (
get_unet,
)
return get_unet()
else:
return self.get_unet()
return self.get_unet(use_large=use_large)
else:
return self.get_controlled_unet()
def vae_encode(self):
try:
self.inputs["vae_encode"] = self.get_input_info_for(base_models["vae_encode"])
self.inputs["vae_encode"] = self.get_input_info_for(
base_models["vae_encode"]
)
compiled_vae_encode, vae_encode_mlir = self.get_vae_encode()
check_compilation(compiled_vae_encode, "Vae Encode")
@@ -616,25 +776,35 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def unet(self):
def unet(self, use_large=False):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
compiled_unet = None
unet_inputs = base_models[model]
if self.base_model_id != "":
self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
compiled_unet, unet_mlir = self.compile_unet_variants(model)
self.inputs["unet"] = self.get_input_info_for(
unet_inputs[self.base_model_id]
)
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
)
else:
for model_id in unet_inputs:
self.base_model_id = model_id
self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
self.inputs["unet"] = self.get_input_info_for(
unet_inputs[model_id]
)
try:
compiled_unet, unet_mlir = self.compile_unet_variants(model)
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
)
except Exception as e:
print(e)
print("Retrying with a different base model configuration")
print(
"Retrying with a different base model configuration"
)
continue
# -- Once a successful compilation has taken place we'd want to store
@@ -657,7 +827,11 @@ class SharkifyStableDiffusionModel:
def vae(self):
try:
vae_input = base_models["vae"]["vae_upscaler"] if self.is_upscaler else base_models["vae"]["vae"]
vae_input = (
base_models["vae"]["vae_upscaler"]
if self.is_upscaler
else base_models["vae"]["vae"]
)
self.inputs["vae"] = self.get_input_info_for(vae_input)
is_base_vae = self.base_vae
@@ -675,7 +849,9 @@ class SharkifyStableDiffusionModel:
def controlnet(self):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(base_models["stencil_adaptor"])
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
check_compilation(compiled_stencil_adaptor, "Stencil")

View File

@@ -17,9 +17,13 @@ hf_model_variant_map = {
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
"runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
"stabilityai/stable-diffusion-2-inpainting": [
"stablediffusion",
"inpaint_v2",
],
}
# TODO: Add the quantized model as a part model_db.json.
# This is currently in experimental phase.
def get_quantize_model():
@@ -27,9 +31,12 @@ def get_quantize_model():
model_key = "unet_int8"
iree_flags = get_opt_flags("unet", precision="fp16")
if args.height != 512 and args.width != 512 and args.max_length != 77:
sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
sys.exit(
"The int8 quantized model currently requires the height and width to be 512, and max_length to be 77"
)
return bucket_key, model_key, iree_flags
def get_variant_version(hf_model_id):
return hf_model_variant_map[hf_model_id]

View File

@@ -15,6 +15,11 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -38,6 +43,11 @@ class Image2ImagePipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -135,6 +145,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
):
# prompts and negative prompts must be a list.
@@ -156,7 +167,10 @@ class Image2ImagePipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -14,6 +14,11 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -37,6 +42,11 @@ class InpaintPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -378,6 +388,7 @@ class InpaintPipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -408,7 +419,10 @@ class InpaintPipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -14,6 +14,11 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -38,6 +43,11 @@ class OutpaintPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -379,6 +389,7 @@ class OutpaintPipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -409,7 +420,10 @@ class OutpaintPipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -14,6 +14,12 @@ from diffusers import (
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -38,6 +44,12 @@ class StencilPipeline(StableDiffusionPipeline):
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDPMScheduler,
KDPM2DiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -204,6 +216,7 @@ class StencilPipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
):
# Control Embedding check & conversion
@@ -230,7 +243,10 @@ class StencilPipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -13,6 +13,10 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -34,6 +38,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -81,6 +89,7 @@ class Text2ImagePipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -112,7 +121,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -17,6 +17,9 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
@@ -67,6 +70,11 @@ class UpscalerPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
low_res_scheduler: Union[
DDIMScheduler,
@@ -78,6 +86,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -168,7 +180,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.status = SD_STATE_IDLE
self.load_unet()
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
else:
self.load_unet_512()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
latent_model_input = torch.cat([latents] * 2)
@@ -182,15 +197,26 @@ class UpscalerPipeline(StableDiffusionPipeline):
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
else:
noise_pred = self.unet_512(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
end_profiling(profile_device)
noise_pred = torch.from_numpy(noise_pred)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
@@ -219,6 +245,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -243,6 +270,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -264,7 +292,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# 4. Preprocess image

View File

@@ -15,6 +15,9 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
@@ -48,6 +51,10 @@ class StableDiffusionPipeline:
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
@@ -57,6 +64,7 @@ class StableDiffusionPipeline:
self.vae = None
self.text_encoder = None
self.unet = None
self.unet_512 = None
self.model_max_length = 77
self.scheduler = scheduler
# TODO: Implement using logging python utility.
@@ -66,7 +74,8 @@ class StableDiffusionPipeline:
self.import_mlir = import_mlir
self.use_lora = use_lora
self.ondemand = ondemand
# TODO: Find a better workaround for fetching base_model_id early enough for CLIPTokenizer.
# TODO: Find a better workaround for fetching base_model_id early
# enough for CLIPTokenizer.
try:
self.tokenizer = get_tokenizer()
except:
@@ -81,7 +90,8 @@ class StableDiffusionPipeline:
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
"Warning: LoRA provided but import_mlir not specified. "
"Importing MLIR anyways."
)
self.text_encoder = self.sd_model.clip()
else:
@@ -114,6 +124,24 @@ class StableDiffusionPipeline:
del self.unet
self.unet = None
def load_unet_512(self):
if self.unet_512 is not None:
return
if self.import_mlir or self.use_lora:
self.unet_512 = self.sd_model.unet(use_large=True)
else:
try:
self.unet_512 = get_unet(use_large=True)
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.unet_512 = self.sd_model.unet(use_large=True)
def unload_unet_512(self):
del self.unet_512
self.unet_512 = None
def load_vae(self):
if self.vae is not None:
return
@@ -203,7 +231,10 @@ class StableDiffusionPipeline:
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.load_unet()
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
else:
self.load_unet_512()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
@@ -222,16 +253,28 @@ class StableDiffusionPipeline:
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
else:
noise_pred = self.unet_512(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
@@ -254,6 +297,7 @@ class StableDiffusionPipeline:
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -275,6 +319,10 @@ class StableDiffusionPipeline:
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
import_mlir: bool,
model_id: str,
@@ -359,16 +407,21 @@ class StableDiffusionPipeline:
prompt (`str` or `list(int)`):
prompt to be encoded
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
The prompt or prompts not to guide the image generation.
Ignored when not using guidance
(i.e., ignored if `guidance_scale` is less than `1`).
model_max_length (int):
SHARK: pass the max length instead of relying on pipe.tokenizer.model_max_length
SHARK: pass the max length instead of relying on
pipe.tokenizer.model_max_length
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not,
SHARK: must be set to True as we always expect neg embeddings (defaulted to True)
SHARK: must be set to True as we always expect neg embeddings
(defaulted to True)
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
SHARK: max_embeddings_multiples>1 produce a tensor shape error (defaulted to 1)
The max multiple length of prompt embeddings compared to the
max output length of text encoder.
SHARK: max_embeddings_multiples>1 produce a tensor shape error
(defaulted to 1)
num_images_per_prompt (`int`):
number of images that should be generated per prompt
SHARK: num_images_per_prompt is not used (defaulted to 1)
@@ -387,9 +440,11 @@ class StableDiffusionPipeline:
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
f"`negative_prompt`: "
f"{negative_prompt} has batch size {len(negative_prompt)}, "
f"but `prompt`: {prompt} has batch size {batch_size}. "
f"Please make sure that passed `negative_prompt` matches "
"the batch size of `prompt`."
)
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
@@ -402,16 +457,43 @@ class StableDiffusionPipeline:
)
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = text_embeddings.shape
# text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
# text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# text_embeddings = text_embeddings.repeat(
# 1,
# num_images_per_prompt,
# 1
# )
# text_embeddings = (
# text_embeddings.view(
# bs_embed * num_images_per_prompt,
# seq_len,
# -1
# )
# )
if do_classifier_free_guidance:
# SHARK: we are not using num_images_per_prompt
# bs_embed, seq_len, _ = uncond_embeddings.shape
# uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
# uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# uncond_embeddings = (
# uncond_embeddings.repeat(
# 1,
# num_images_per_prompt,
# 1
# )
# )
# uncond_embeddings = (
# uncond_embeddings.view(
# bs_embed * num_images_per_prompt,
# seq_len,
# -1
# )
# )
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
if text_embeddings.shape[1] > model_max_length:
pad = (0, 0) * (len(text_embeddings.shape) - 2)
pad = pad + (0, 512 - text_embeddings.shape[1])
text_embeddings = torch.nn.functional.pad(text_embeddings, pad)
# SHARK: Report clip inference time
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
@@ -446,7 +528,8 @@ re_attention = re.compile(
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Parses a string with attention tokens and returns a list of pairs:
text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
@@ -673,6 +756,12 @@ def get_unweighted_text_embeddings(
return text_embeddings
# This function deals with NoneType values occuring in tokens after padding
# It switches out None with 49407 as truncating None values causes matrix dimension errors,
def filter_nonetype_tokens(tokens: List[List]):
return [[49407 if token is None else token for token in tokens[0]]]
def get_weighted_text_embeddings(
pipe: StableDiffusionPipeline,
prompt: Union[str, List[str]],
@@ -764,6 +853,10 @@ def get_weighted_text_embeddings(
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# FIXME: This is a hacky fix caused by tokenizer padding with None values
prompt_tokens = filter_nonetype_tokens(prompt_tokens)
# prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
if uncond_prompt is not None:
@@ -776,6 +869,10 @@ def get_weighted_text_embeddings(
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# FIXME: This is a hacky fix caused by tokenizer padding with None values
uncond_tokens = filter_nonetype_tokens(uncond_tokens)
# uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
uncond_tokens = torch.tensor(
uncond_tokens, dtype=torch.long, device="cpu"

View File

@@ -8,6 +8,9 @@ from diffusers import (
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
@@ -38,9 +41,28 @@ def get_schedulers(model_id):
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver"
)
schedulers[
"DPMSolverMultistep++"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
)
schedulers[
"DPMSolverMultistepKarras"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
use_karras_sigmas=True,
)
schedulers[
"DPMSolverMultistepKarras++"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
algorithm_type="dpmsolver++",
use_karras_sigmas=True,
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
@@ -62,5 +84,21 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverSinglestep"
] = DPMSolverSinglestepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"KDPM2AncestralDiscrete"
] = KDPM2AncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
return schedulers

View File

@@ -28,6 +28,8 @@ from apps.stable_diffusion.src.utils.utils import (
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,
parse_seed_input,
batch_seeds,
get_path_stem,
get_extended_name,
get_generated_imgs_path,
@@ -37,4 +39,5 @@ from apps.stable_diffusion.src.utils.utils import (
get_generation_text_info,
update_lora_weight,
resize_stencil,
_compile_module,
)

View File

@@ -3,7 +3,9 @@ from apps.stable_diffusion.src.utils.stable_args import args
# Helper function to profile the vulkan device.
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
if args.vulkan_debug_utils and "vulkan" in args.device:
from shark.parser import shark_args
if shark_args.vulkan_debug_utils and "vulkan" in args.device:
import iree
print(f"Profiling and saving to {file_path}.")

View File

@@ -5,4 +5,7 @@
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]]
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"],
["A photo of a beach, sunset, calm, beautiful landscape, waves, water"],
["(a large body of water with snowy mountains in the background), (fog, foggy, rolling fog), (clouds, cloudy, rolling clouds), dramatic sky and landscape, extraordinary landscape, (beautiful snow capped mountain background), (forest, dirt path)"],
["a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smokes coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"]]

View File

@@ -116,7 +116,7 @@ def load_lower_configs(base_model_id=None):
else:
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
else:
if not spec or spec in ["rdna3", "sm_80"]:
if not spec or spec in ["sm_80"]:
if (
version in ["v2_1", "v2_1base"]
and args.height == 768
@@ -125,10 +125,38 @@ def load_lower_configs(base_model_id=None):
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
elif spec in ["rdna3"] and version in [
"v2_1",
"v2_1base",
"v1_4",
"v1_5",
]:
config_name = (
f"{args.annotation_model}_"
f"{version}_"
f"{args.max_length}_"
f"{args.precision}_"
f"{device}_"
f"{spec}_"
f"{args.width}x{args.height}.json"
)
elif spec in ["rdna2"] and version in ["v2_1", "v2_1base", "v1_4"]:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}_{args.width}x{args.height}.json"
config_name = (
f"{args.annotation_model}_"
f"{version}_"
f"{args.precision}_"
f"{device}_"
f"{spec}_"
f"{args.width}x{args.height}.json"
)
else:
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
config_name = (
f"{args.annotation_model}_"
f"{version}_"
f"{args.precision}_"
f"{device}_"
f"{spec}.json"
)
full_gs_url = config_bucket + config_name
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
@@ -173,9 +201,22 @@ def dump_after_mlir(input_mlir, use_winograd):
device, device_spec_args = get_device_args()
if use_winograd:
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
"iree-linalg-ext-convert-conv2d-to-winograd))"
)
else:
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32}))"
)
dump_module = ireec.compile_str(
input_mlir,

View File

@@ -19,48 +19,56 @@ p = argparse.ArgumentParser(
)
##############################################################################
### Stable Diffusion Params
# Stable Diffusion Params
##############################################################################
p.add_argument(
"-a",
"--app",
default="txt2img",
help="which app to use, one of: txt2img, img2img, outpaint, inpaint",
help="Which app to use, one of: txt2img, img2img, outpaint, inpaint.",
)
p.add_argument(
"-p",
"--prompts",
nargs="+",
default=["cyberpunk forest by Salvador Dali"],
help="text of which images to be generated.",
default=[
"a photo taken of the front of a super-car drifting on a road near "
"mountains at high speeds with smokes coming off the tires, front "
"angle, front point of view, trees in the mountains of the "
"background, ((sharp focus))"
],
help="Text of which images to be generated.",
)
p.add_argument(
"--negative_prompts",
nargs="+",
default=["trees, green"],
help="text you don't want to see in the generated image.",
default=[
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
"blurry, ugly, blur, oversaturated, cropped"
],
help="Text you don't want to see in the generated image.",
)
p.add_argument(
"--img_path",
type=str,
help="Path to the image input for img2img/inpainting",
help="Path to the image input for img2img/inpainting.",
)
p.add_argument(
"--steps",
type=int,
default=50,
help="the no. of steps to do the sampling.",
help="The number of steps to do the sampling.",
)
p.add_argument(
"--seed",
type=int,
type=str,
default=-1,
help="the seed to use. -1 for a random one.",
help="The seed or list of seeds to use. -1 for a random one.",
)
p.add_argument(
@@ -68,7 +76,7 @@ p.add_argument(
type=int,
default=1,
choices=range(1, 4),
help="the number of inferences to be made in a single `batch_count`.",
help="The number of inferences to be made in a single `batch_count`.",
)
p.add_argument(
@@ -76,7 +84,7 @@ p.add_argument(
type=int,
default=512,
choices=range(128, 769, 8),
help="the height of the output image.",
help="The height of the output image.",
)
p.add_argument(
@@ -84,77 +92,86 @@ p.add_argument(
type=int,
default=512,
choices=range(128, 769, 8),
help="the width of the output image.",
help="The width of the output image.",
)
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
help="the value to be used for guidance scaling.",
help="The value to be used for guidance scaling.",
)
p.add_argument(
"--noise_level",
type=int,
default=20,
help="the value to be used for noise level of upscaler.",
help="The value to be used for noise level of upscaler.",
)
p.add_argument(
"--max_length",
type=int,
default=64,
help="max length of the tokenizer output, options are 64 and 77.",
help="Max length of the tokenizer output, options are 64 and 77.",
)
p.add_argument(
"--max_embeddings_multiples",
type=int,
default=5,
help="The max multiple length of prompt embeddings compared to the max "
"output length of text encoder.",
)
p.add_argument(
"--strength",
type=float,
default=0.8,
help="the strength of change applied on the given input image for img2img",
help="The strength of change applied on the given input image for "
"img2img.",
)
##############################################################################
### Stable Diffusion Training Params
# Stable Diffusion Training Params
##############################################################################
p.add_argument(
"--lora_save_dir",
type=str,
default="models/lora/",
help="Directory to save the lora fine tuned model",
help="Directory to save the lora fine tuned model.",
)
p.add_argument(
"--training_images_dir",
type=str,
default="models/lora/training_images/",
help="Directory containing images that are an example of the prompt",
help="Directory containing images that are an example of the prompt.",
)
p.add_argument(
"--training_steps",
type=int,
default=2000,
help="The no. of steps to train",
help="The number of steps to train.",
)
##############################################################################
### Inpainting and Outpainting Params
# Inpainting and Outpainting Params
##############################################################################
p.add_argument(
"--mask_path",
type=str,
help="Path to the mask image input for inpainting",
help="Path to the mask image input for inpainting.",
)
p.add_argument(
"--inpaint_full_res",
default=False,
action=argparse.BooleanOptionalAction,
help="If inpaint only masked area or whole picture",
help="If inpaint only masked area or whole picture.",
)
p.add_argument(
@@ -162,7 +179,7 @@ p.add_argument(
type=int,
default=32,
choices=range(0, 257, 4),
help="Number of pixels for only masked padding",
help="Number of pixels for only masked padding.",
)
p.add_argument(
@@ -170,7 +187,7 @@ p.add_argument(
type=int,
default=128,
choices=range(8, 257, 8),
help="Number of expended pixels for one direction for outpainting",
help="Number of expended pixels for one direction for outpainting.",
)
p.add_argument(
@@ -178,89 +195,92 @@ p.add_argument(
type=int,
default=8,
choices=range(0, 65),
help="Number of blur pixels for outpainting",
help="Number of blur pixels for outpainting.",
)
p.add_argument(
"--left",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend left for outpainting",
help="If expend left for outpainting.",
)
p.add_argument(
"--right",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend right for outpainting",
help="If expend right for outpainting.",
)
p.add_argument(
"--top",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend top for outpainting",
help="If expend top for outpainting.",
)
p.add_argument(
"--bottom",
default=False,
action=argparse.BooleanOptionalAction,
help="If expend bottom for outpainting",
help="If expend bottom for outpainting.",
)
p.add_argument(
"--noise_q",
type=float,
default=1.0,
help="Fall-off exponent for outpainting (lower=higher detail) (min=0.0, max=4.0)",
help="Fall-off exponent for outpainting (lower=higher detail) "
"(min=0.0, max=4.0).",
)
p.add_argument(
"--color_variation",
type=float,
default=0.05,
help="Color variation for outpainting (min=0.0, max=1.0)",
help="Color variation for outpainting (min=0.0, max=1.0).",
)
##############################################################################
### Model Config and Usage Params
# Model Config and Usage Params
##############################################################################
p.add_argument(
"--device", type=str, default="vulkan", help="device to run the model."
"--device", type=str, default="vulkan", help="Device to run the model."
)
p.add_argument(
"--precision", type=str, default="fp16", help="precision to run the model."
"--precision", type=str, default="fp16", help="Precision to run the model."
)
p.add_argument(
"--import_mlir",
default=False,
action=argparse.BooleanOptionalAction,
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
help="Imports the model from torch module to shark_module otherwise "
"downloads the model from shark_tank.",
)
p.add_argument(
"--load_vmfb",
default=True,
action=argparse.BooleanOptionalAction,
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
help="Attempts to load the model from a precompiled flat-buffer "
"and compiles + saves it if not found.",
)
p.add_argument(
"--save_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="saves the compiled flatbuffer to the local directory",
help="Saves the compiled flat-buffer to the local directory.",
)
p.add_argument(
"--use_tuned",
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
help="Download and use the tuned version of the model if available.",
)
p.add_argument(
@@ -274,28 +294,42 @@ p.add_argument(
"--scheduler",
type=str,
default="SharkEulerDiscrete",
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
"DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
"DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
"DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, "
"HeunDiscrete].",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="specify the format in which output image is save. Supported options: jpg / png",
help="Specify the format in which output image is save. "
"Supported options: jpg / png.",
)
p.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory path to save the output images and json",
help="Directory path to save the output images and json.",
)
p.add_argument(
"--batch_count",
type=int,
default=1,
help="number of batch to be generated with random seeds in single execution",
help="Number of batches to be generated with random seeds in "
"single execution.",
)
p.add_argument(
"--repeatable_seeds",
default=False,
action=argparse.BooleanOptionalAction,
help="The seed of the first batch will be used as the rng seed to "
"generate the subsequent seeds for subsequent batches in that run.",
)
p.add_argument(
@@ -309,7 +343,8 @@ p.add_argument(
"--custom_vae",
type=str,
default="",
help="HuggingFace repo-id or path to SD model's checkpoint whose Vae needs to be plugged in.",
help="HuggingFace repo-id or path to SD model's checkpoint whose VAE "
"needs to be plugged in.",
)
p.add_argument(
@@ -323,14 +358,15 @@ p.add_argument(
"--low_cpu_mem_usage",
default=False,
action=argparse.BooleanOptionalAction,
help="Use the accelerate package to reduce cpu memory consumption",
help="Use the accelerate package to reduce cpu memory consumption.",
)
p.add_argument(
"--attention_slicing",
type=str,
default="none",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', or an integer)",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', "
"or an integer).",
)
p.add_argument(
@@ -343,209 +379,221 @@ p.add_argument(
"--use_lora",
type=str,
default="",
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
help="Use standalone LoRA weight using a HF ID or a checkpoint "
"file (~3 MB).",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
help="Runs the quantized version of stable diffusion model. "
"This is currently in experimental phase. "
"Currently, only runs the stable-diffusion-2-1-base model in "
"int8 quantization.",
)
p.add_argument(
"--ondemand",
default=False,
action=argparse.BooleanOptionalAction,
help="Load and unload models for low VRAM",
help="Load and unload models for low VRAM.",
)
p.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication tokens for models like Llama2.",
)
##############################################################################
### IREE - Vulkan supported flags
# IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree_vulkan_target_triple",
type=str,
default="",
help="Specify target triple for vulkan",
help="Specify target triple for vulkan.",
)
p.add_argument(
"--vulkan_debug_utils",
default=False,
action=argparse.BooleanOptionalAction,
help="Profiles vulkan device and collects the .rdc info",
)
p.add_argument(
"--vulkan_large_heap_block_size",
default="2073741824",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
p.add_argument(
"--vulkan_validation_layers",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for disabling vulkan validation layers when benchmarking",
"--iree_metal_target_platform",
type=str,
default="",
help="Specify target triple for metal.",
)
##############################################################################
### Misc. Debug and Optimization flags
# Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="use the default scheduler precompiled into the model if available",
help="Use the default scheduler precompiled into the model if available.",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
help="Specify where to save downloaded shark_tank artifacts. "
"If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
help="When enabled call amdllpc to get ISA dumps. "
"Use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help='dispatches to return benchamrk data on. use "All" for all, and None for none.',
help="Dispatches to return benchmark data on. "
'Use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
help="Directory where you want to store dispatch data "
'generated with "--dispatch_benchmarks".',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for inserting debug frames between iterations for use with rgp.",
help="Flag for inserting debug frames between iterations "
"for use with rgp.",
)
p.add_argument(
"--hide_steps",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for hiding the details of iteration/sec for each step.",
help="Flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="flag setting warmup count for clip and vae [>= 0].",
help="Flag setting warmup count for CLIP and VAE [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
help="Flag to clear all mlir and vmfb from common locations. "
"Recompiling will take several minutes.",
)
p.add_argument(
"--save_metadata_to_json",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save a generation information json file with the image.",
help="Flag for whether or not to save a generation information "
"json file with the image.",
)
p.add_argument(
"--write_metadata_to_png",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
help="Flag for whether or not to save generation information in "
"PNG chunk text to generated images.",
)
p.add_argument(
"--import_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="if import_mlir is True, saves mlir via the debug option in shark importer. Does nothing if import_mlir is false (the default)",
help="If import_mlir is True, saves mlir via the debug option "
"in shark importer. Does nothing if import_mlir is false (the default).",
)
##############################################################################
### Web UI flags
# Web UI flags
##############################################################################
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the progress bar animation during image generation",
help="Flag for removing the progress bar animation during "
"image generation.",
)
p.add_argument(
"--ckpt_dir",
type=str,
default="",
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
help="Path to directory where all .ckpts are stored in order to populate "
"them in the web UI.",
)
# TODO: replace API flag when these can be run together
p.add_argument(
"--ui",
type=str,
default="app" if os.name == "nt" else "web",
help="one of: [api, app, web]",
help="One of: [api, app, web].",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
help="Flag for generating a public URL.",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
help="Flag for setting server port.",
)
p.add_argument(
"--api",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for enabling rest API",
help="Flag for enabling rest API.",
)
p.add_argument(
"--output_gallery",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the output gallery tab, and avoid exposing images under --output_dir in the UI",
help="Flag for removing the output gallery tab, and avoid exposing "
"images under --output_dir in the UI.",
)
p.add_argument(
"--output_gallery_followlinks",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether the output gallery tab in the UI should follow symlinks when listing subdirectorys under --output_dir",
help="Flag for whether the output gallery tab in the UI should "
"follow symlinks when listing subdirectories under --output_dir.",
)
##############################################################################
### SD model auto-annotation flags
# SD model auto-annotation flags
##############################################################################
p.add_argument(
"--annotation_output",
type=path_expand,
default="./",
help="Directory to save the annotated mlir file",
help="Directory to save the annotated mlir file.",
)
p.add_argument(
@@ -559,33 +607,43 @@ p.add_argument(
"--save_annotation",
default=False,
action=argparse.BooleanOptionalAction,
help="Save annotated mlir file",
help="Save annotated mlir file.",
)
##############################################################################
### SD model auto-tuner flags
# SD model auto-tuner flags
##############################################################################
p.add_argument(
"--tuned_config_dir",
type=path_expand,
default="./",
help="Directory to save the tuned config file",
help="Directory to save the tuned config file.",
)
p.add_argument(
"--num_iters",
type=int,
default=400,
help="Number of iterations for tuning",
help="Number of iterations for tuning.",
)
p.add_argument(
"--search_op",
type=str,
default="all",
help="Op to be optimized, options are matmul, bmm, conv and all",
help="Op to be optimized, options are matmul, bmm, conv and all.",
)
##############################################################################
# DocuChat Flags
##############################################################################
p.add_argument(
"--run_docuchat_web",
default=False,
action=argparse.BooleanOptionalAction,
help="Specifies whether the docuchat's web version is running or not.",
)
args, unknown = p.parse_known_args()
if args.import_debug:

View File

@@ -8,7 +8,12 @@ from datetime import datetime as dt
from csv import DictWriter
from pathlib import Path
import numpy as np
from random import randint
from random import (
randint,
seed as seed_random,
getstate as random_getstate,
setstate as random_setstate,
)
import tempfile
import torch
from safetensors.torch import load_file
@@ -17,7 +22,9 @@ from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
get_iree_vulkan_runtime_flags,
)
from shark.iree_utils.metal_utils import get_metal_target_triple
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
@@ -31,6 +38,7 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
import requests
from io import BytesIO
from omegaconf import OmegaConf
from cpuinfo import get_cpu_info
def get_extended_name(model_name):
@@ -47,6 +55,7 @@ def get_vmfb_path_name(model_name):
def _load_vmfb(shark_module, vmfb_path, model, precision):
model = "vae" if "base_vae" in model or "vae_encode" in model else model
model = "unet" if "stencil" in model else model
model = "unet" if "unet512" in model else model
precision = "fp32" if "clip" in model else precision
extra_args = get_opt_flags(model, precision)
shark_module.load_module(vmfb_path, extra_args=extra_args)
@@ -78,7 +87,9 @@ def _compile_module(shark_module, model_name, extra_args=[]):
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
def get_shark_model(tank_url, model_name, extra_args=None):
if extra_args is None:
extra_args = []
from shark.parser import shark_args
# Set local shark_tank cache directory.
@@ -110,12 +121,15 @@ def compile_through_fx(
save_dir=tempfile.gettempdir(),
debug=False,
generate_vmfb=True,
extra_args=[],
extra_args=None,
base_model_id=None,
model_name=None,
precision=None,
return_mlir=False,
device=None,
):
if extra_args is None:
extra_args = []
if not return_mlir and model_name is not None:
vmfb_path = get_vmfb_path_name(extended_model_name)
if os.path.isfile(vmfb_path):
@@ -145,7 +159,10 @@ def compile_through_fx(
if use_tuned:
if "vae" in extended_model_name.split("_")[0]:
args.annotation_model = "vae"
if "unet" in model_name.split("_")[0]:
if (
"unet" in model_name.split("_")[0]
or "unet_512" in model_name.split("_")[0]
):
args.annotation_model = "unet"
mlir_module = sd_model_annotation(
mlir_module, extended_model_name, base_model_id
@@ -153,7 +170,7 @@ def compile_through_fx(
shark_module = SharkInference(
mlir_module,
device=args.device,
device=args.device if device is None else device,
mlir_dialect="tm_tensor",
)
if generate_vmfb:
@@ -167,10 +184,7 @@ def compile_through_fx(
def set_iree_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
]
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
if args.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
@@ -198,13 +212,15 @@ def get_device_mapping(driver, key_combination=3):
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for device name.
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired combination of name/path.
dict: map to possible device names user can input mapped to desired
combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
@@ -218,7 +234,7 @@ def get_device_mapping(driver, key_combination=3):
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
return dev_dict["name"], f"{driver}://{dev_dict['path']}"
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
@@ -231,10 +247,12 @@ def get_device_mapping(driver, key_combination=3):
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
"""Gives the appropriate device data (supported name/path) for user
selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for device name.
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
@@ -242,7 +260,8 @@ def map_device_to_name_path(device, key_combination=3):
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
str / tuple: returns the mapping str or tuple of mapping str for
the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
@@ -265,10 +284,21 @@ def set_init_device_flags():
if triple is not None:
args.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
f"Found device {device_name}. Using target triple "
f"{args.iree_vulkan_target_triple}."
)
elif "cuda" in args.device:
args.device = "cuda"
elif "metal" in args.device:
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_metal_target_platform:
triple = get_metal_target_triple(device_name)
if triple is not None:
args.iree_metal_target_platform = triple.split("-")[-1]
print(
f"Found device {device_name}. Using target triple "
f"{args.iree_metal_target_platform}."
)
elif "cpu" in args.device:
args.device = "cpu"
@@ -293,13 +323,24 @@ def set_init_device_flags():
if (
args.precision != "fp16"
or args.height not in [512, 768]
or (args.height == 512 and args.width != 512)
or (args.height == 768 and args.width != 768)
or (args.height == 512 and args.width not in [512, 768])
or (args.height == 768 and args.width not in [512, 768])
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
):
args.use_tuned = False
elif (
args.height != args.width
and "rdna2" in args.iree_vulkan_target_triple
and base_model_id
not in [
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
]
):
args.use_tuned = False
elif base_model_id not in [
"Linaqruf/anything-v3.0",
"dreamlike-art/dreamlike-diffusion-1.0",
@@ -354,7 +395,8 @@ def set_init_device_flags():
if args.use_tuned:
print(
f"Using tuned models for {base_model_id}(fp16) on device {args.device}."
f"Using tuned models for {base_model_id}(fp16) on "
f"device {args.device}."
)
else:
print("Tuned models are currently not supported for this setting.")
@@ -412,8 +454,17 @@ def get_available_devices():
except:
print(f"{driver_name} devices are not available.")
else:
cpu_name = get_cpu_info()["brand_raw"]
for i, device in enumerate(device_list_dict):
device_list.append(f"{device['name']} => {driver_name}://{i}")
device_name = (
cpu_name if device["name"] == "default" else device["name"]
)
if "local" in driver_name:
device_list.append(
f"{device_name} => {driver_name.replace('local', 'cpu')}"
)
else:
device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
@@ -421,9 +472,14 @@ def get_available_devices():
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("device => cpu")
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
return available_devices
@@ -500,10 +556,10 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
from_safetensors = (
True if custom_weights.lower().endswith(".safetensors") else False
)
# EMA weights usually yield higher quality images for inference but non-EMA weights have
# been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
# weight extraction or not.
# EMA weights usually yield higher quality images for inference but
# non-EMA weights have been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if
# they want to go for EMA weight extraction or not.
extract_ema = False
print(
"Loading diffusers' pipeline from original stable diffusion checkpoint"
@@ -524,7 +580,10 @@ def convert_original_vae(vae_checkpoint):
for key in list(vae_checkpoint.keys()):
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
config_url = (
"https://raw.githubusercontent.com/CompVis/stable-diffusion/"
"main/configs/stable-diffusion/v1-inference.yaml"
)
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file)
vae_config = create_vae_diffusers_config(original_config, image_size=512)
@@ -645,7 +704,7 @@ def update_lora_weight(model, use_lora, model_name):
# `fetch_and_update_base_model_id` is a resource utility function which
# helps maintaining mapping of the model to run with its base model.
# helps to maintain mapping of the model to run with its base model.
# If `base_model` is "", then this function tries to fetch the base model
# info for the `model_to_run`.
def fetch_and_update_base_model_id(model_to_run, base_model=""):
@@ -662,14 +721,17 @@ def fetch_and_update_base_model_id(model_to_run, base_model=""):
return base_model
elif base_model == "":
return base_model
# Update JSON data to contain an entry mapping model_to_run with base_model.
# Update JSON data to contain an entry mapping model_to_run with
# base_model.
json_data.update(data)
with open(variants_path, "w", encoding="utf-8") as jsonFile:
json.dump(json_data, jsonFile)
# Generate and return a new seed if the provided one is not in the supported range (including -1)
def sanitize_seed(seed):
# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed: int | str):
seed = int(seed)
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
@@ -677,6 +739,56 @@ def sanitize_seed(seed):
return seed
# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):
if isinstance(seed_input, str):
try:
seed_input = json.loads(seed_input)
except (ValueError, TypeError):
seed_input = None
if isinstance(seed_input, int):
return [seed_input]
if isinstance(seed_input, list) and all(
type(seed) is int for seed in seed_input
):
return seed_input
raise TypeError(
"Seed input must be an integer or an array of integers in JSON format"
)
# Generate a set of seeds from an input expression for batch_count batches,
# optionally using that input as the rng seed for any randomly generated seeds.
def batch_seeds(
seed_input: str | list | int, batch_count: int, repeatable=False
):
# turn the input into a list if possible
seeds = parse_seed_input(seed_input)
# slice or pad the list to be of batch_count length
seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))
if repeatable:
# set seed for the rng based on what we have so far
saved_random_state = random_getstate()
if all(seed < 0 for seed in seeds):
seeds[0] = sanitize_seed(seeds[0])
seed_random(str(seeds))
# generate any seeds that are unspecified
seeds = [sanitize_seed(seed) for seed in seeds]
if repeatable:
# reset the rng back to normal
random_setstate(saved_random_state)
return seeds
# clear all the cached objects to recompile cleanly.
def clear_all():
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
@@ -687,7 +799,8 @@ def clear_all():
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
# Temporary workaround of deleting yaml files to incorporate
# diffusers' pipeline.
# TODO: Remove this once we have better weight updation logic.
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
for yaml in inference_yaml:
@@ -716,7 +829,9 @@ def get_generated_imgs_todays_subdir() -> str:
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed, extra_info={}):
def save_output_img(output_img, img_seed, extra_info=None):
if extra_info is None:
extra_info = {}
generated_imgs_path = Path(
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
)
@@ -724,14 +839,20 @@ def save_output_img(output_img, img_seed, extra_info={}):
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
out_img_name = (
f"{prompt_slice}_{img_seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
)
out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
img_model = args.hf_model_id
if args.ckpt_loc:
img_model = Path(os.path.basename(args.ckpt_loc)).stem
img_vae = None
if args.custom_vae:
img_vae = Path(os.path.basename(args.custom_vae)).stem
img_lora = None
if args.use_lora:
img_lora = Path(os.path.basename(args.use_lora)).stem
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
@@ -742,17 +863,30 @@ def save_output_img(output_img, img_seed, extra_info={}):
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
f"{args.prompts[0]}"
f"\nNegative prompt: {args.negative_prompts[0]}"
f"\nSteps: {args.steps},"
f"Sampler: {args.scheduler}, "
f"CFG scale: {args.guidance_scale}, "
f"Seed: {img_seed},"
f"Size: {args.width}x{args.height}, "
f"Model: {img_model}, "
f"VAE: {img_vae}, "
f"LoRA: {img_lora}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if args.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not supported yet."
"Image saved as png instead. Supported formats: png / jpg"
f"[ERROR] Format {args.output_img_format} is not "
f"supported yet. Image saved as png instead."
f"Supported formats: png / jpg"
)
# To be as low-impact as possible to the existing CSV format, we append
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
# importance for each data point. Something to consider.
new_entry = {
"VARIANT": img_model,
"SCHEDULER": args.scheduler,
@@ -766,12 +900,17 @@ def save_output_img(output_img, img_seed, extra_info={}):
"WIDTH": args.width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
"VAE": img_vae,
"LORA": img_lora,
}
new_entry.update(extra_info)
with open(csv_path, "a", encoding="utf-8") as csv_obj:
csv_mode = "a" if os.path.isfile(csv_path) else "w"
with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
if csv_mode == "w":
dictwriter_obj.writeheader()
dictwriter_obj.writerow(new_entry)
csv_obj.close()
@@ -785,16 +924,27 @@ def save_output_img(output_img, img_seed, extra_info={}):
def get_generation_text_info(seeds, device):
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch_count={args.batch_count}, batch_size={args.batch_size}, max_length={args.max_length}"
text_output += (
f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, " f"device={device}"
text_output += (
f"\nsteps={args.steps}, "
f"guidance_scale={args.guidance_scale}, "
f"seed={seeds}"
)
text_output += (
f"\nsize={args.height}x{args.width}, "
f"batch_count={args.batch_count}, "
f"batch_size={args.batch_size}, "
f"max_length={args.max_length}"
)
return text_output
# For stencil, the input image can be of any size but we need to ensure that
# it conforms with our model contraints :-
# For stencil, the input image can be of any size, but we need to ensure that
# it conforms with our model constraints :-
# Both width and height should be in the range of [128, 768] and multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.

View File

@@ -0,0 +1,51 @@
# -*- mode: python ; coding: utf-8 -*-
from apps.stable_diffusion.shark_studio_imports import pathex, datas, hiddenimports
binaries = []
block_cipher = None
a = Analysis(
['web\\index.py'],
pathex=pathex,
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
[],
exclude_binaries=True,
name='studio_bundle',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)
coll = COLLECT(
exe,
a.binaries,
a.zipfiles,
a.datas,
strip=False,
upx=True,
upx_exclude=[],
name='studio_bundle',
)

View File

@@ -1,7 +1,13 @@
from multiprocessing import Process, freeze_support
import os
import sys
import transformers # ensures inclusion in pysintaller exe generation
if sys.platform == "darwin":
# import before IREE to avoid torch-MLIR library issues
import torch_mlir
import shutil
import PIL, transformers, sentencepiece # ensures inclusion in pysintaller exe generation
from apps.stable_diffusion.src import args, clear_all
import apps.stable_diffusion.web.utils.global_obj as global_obj
@@ -20,11 +26,16 @@ def launch_app(address):
window = Tk()
# getting screen width and height of display
width = window.winfo_screenwidth()
height = window.winfo_screenheight()
# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio", url=address, width=width, height=height
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False)
@@ -38,7 +49,10 @@ if __name__ == "__main__":
img2img_api,
upscaler_api,
inpaint_api,
outpaint_api,
llm_chat_api,
)
from fastapi import FastAPI, APIRouter
import uvicorn
@@ -49,23 +63,36 @@ if __name__ == "__main__":
app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
# app.add_api_route(
# "/sdapi/v1/outpaint", outpaint_api, methods=["post"]
# )
app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
# chat APIs needed for compatibility with multiple extensions using OpenAI API
app.add_api_route(
"/v1/chat/completions", llm_chat_api, methods=["post"]
)
app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
app.add_api_route("/completions", llm_chat_api, methods=["post"])
app.add_api_route(
"/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
)
app.include_router(APIRouter())
uvicorn.run(app, host="127.0.0.1", port=args.server_port)
uvicorn.run(app, host="0.0.0.0", port=args.server_port)
sys.exit(0)
import gradio as gr
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
from apps.stable_diffusion.web.utils.gradio_configs import (
clear_gradio_tmp_imgs_folder,
config_gradio_tmp_imgs_folder,
)
config_gradio_tmp_imgs_folder()
import gradio as gr
# Create custom models folders if they don't exist
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
# Clear all gradio tmp images from the last session
clear_gradio_tmp_imgs_folder()
# Create custom models folders if they don't exist
create_custom_models_folders()
def resource_path(relative_path):
@@ -88,6 +115,7 @@ if __name__ == "__main__":
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
h2ogpt_web,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
@@ -133,6 +161,7 @@ if __name__ == "__main__":
modelmanager_sendto_outpaint,
modelmanager_sendto_upscaler,
stablelm_chat,
minigpt4_web,
outputgallery_web,
outputgallery_tab_select,
outputgallery_watch,
@@ -192,14 +221,8 @@ if __name__ == "__main__":
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.TabItem(label="Model Manager", id=5):
model_web.render()
with gr.TabItem(label="Chat Bot(Experimental)", id=6):
stablelm_chat.render()
with gr.TabItem(label="LoRA Training(Experimental)", id=7):
lora_train_web.render()
if args.output_gallery:
with gr.TabItem(label="Output Gallery", id=8) as og_tab:
with gr.TabItem(label="Output Gallery", id=5) as og_tab:
outputgallery_web.render()
# extra output gallery configuration
@@ -213,6 +236,16 @@ if __name__ == "__main__":
upscaler_status,
]
)
with gr.TabItem(label="Model Manager", id=6):
model_web.render()
with gr.TabItem(label="LoRA Training (Experimental)", id=8):
lora_train_web.render()
with gr.TabItem(label="Chat Bot (Experimental)", id=7):
stablelm_chat.render()
with gr.TabItem(label="MultiModal (Experimental)", id=9):
minigpt4_web.render()
with gr.TabItem(label="DocuChat(Experimental)", id=10):
h2ogpt_web.render()
# send to buttons
register_button_click(

View File

@@ -74,7 +74,12 @@ from apps.stable_diffusion.web.ui.model_manager import (
modelmanager_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web
from apps.stable_diffusion.web.ui.stablelm_ui import stablelm_chat
from apps.stable_diffusion.web.ui.stablelm_ui import (
stablelm_chat,
llm_chat_api,
)
from apps.stable_diffusion.web.ui.h2ogpt import h2ogpt_web
from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web
from apps.stable_diffusion.web.ui.outputgallery_ui import (
outputgallery_web,
outputgallery_tab_select,

View File

@@ -117,16 +117,12 @@ body {
padding: 0 var(--size-4) !important;
}
.container {
background-color: black !important;
padding-top: var(--size-5) !important;
}
#ui_title {
padding: var(--size-2) 0 0 var(--size-1);
}
#top_logo {
color: transparent;
background-color: transparent;
border-radius: 0 !important;
border: 0;
@@ -227,17 +223,36 @@ footer {
}
/* Hide the download icon from the nod logo */
#top_logo .download {
#top_logo button {
display: none;
}
/* workarounds for container=false not currently working for dropdowns */
.dropdown_no_container {
padding: 0 !important;
}
#output_subdir_container :first-child {
border: none;
}
/* reduced animation load when generating */
.generating {
animation-play-state: paused !important;
}
/* better clarity when progress bars are minimal */
.meta-text {
background-color: var(--block-label-background-fill);
}
/* output gallery tab */
.output_parameters_dataframe tbody td {
font-size: small;
line-height: var(--line-xs)
}
#output_refresh_button {
.output_icon_button {
max-width: 30px;
align-self: end;
padding-bottom: 8px;

View File

@@ -0,0 +1,251 @@
import gradio as gr
import torch
import os
from pathlib import Path
from transformers import (
AutoModelForCausalLM,
)
from apps.stable_diffusion.web.ui.utils import available_devices
from apps.language_models.langchain.enums import (
DocumentChoices,
LangChainAction,
)
import apps.language_models.langchain.gen as gen
from apps.stable_diffusion.src import args
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
sharkModel = 0
h2ogpt_model = 0
# NOTE: Each `model_name` should have its own start message
start_message = """
SHARK DocuChat
Chat with an AI, contextualized with provided files.
"""
def create_prompt(history):
system_message = start_message
conversation = "".join(["".join([item[0], item[1]]) for item in history])
msg = system_message + conversation
msg = msg.strip()
return msg
def chat(curr_system_message, history, device, precision):
args.run_docuchat_web = True
global h2ogpt_model
global h2ogpt_tokenizer
global model_state
global langchain
global userpath_selector
if h2ogpt_model == 0:
if "cuda" in device:
shark_device = "cuda"
elif "sync" in device:
shark_device = "cpu"
elif "task" in device:
shark_device = "cpu"
elif "vulkan" in device:
shark_device = "vulkan"
else:
print("unrecognized device")
device = "cpu" if shark_device == "cpu" else "cuda"
args.device = shark_device
args.precision = precision
from apps.language_models.langchain.gen import Langchain
langchain = Langchain(device, precision)
h2ogpt_model, h2ogpt_tokenizer, _ = langchain.get_model(
load_4bit=True
if device == "cuda"
else False, # load model in 4bit if device is cuda to save memory
load_gptq="",
use_safetensors=False,
infer_devices=True,
device=device,
base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
inference_server="",
tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
lora_weights="",
gpu_id=0,
reward_type=None,
local_files_only=False,
resume_download=True,
use_auth_token=False,
trust_remote_code=True,
offload_folder=None,
compile_model=False,
verbose=False,
)
model_state = dict(
model=h2ogpt_model,
tokenizer=h2ogpt_tokenizer,
device=device,
base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
lora_weights="",
inference_server="",
prompt_type=None,
prompt_dict=None,
)
prompt = create_prompt(history)
output = langchain.evaluate(
model_state=model_state,
my_db_state=None,
instruction=prompt,
iinput="",
context="",
stream_output=True,
prompt_type="prompt_answer",
prompt_dict={
"promptA": "",
"promptB": "",
"PreInstruct": "<|prompt|>",
"PreInput": None,
"PreResponse": "<|answer|>",
"terminate_response": [
"<|prompt|>",
"<|answer|>",
"<|endoftext|>",
],
"chat_sep": "<|endoftext|>",
"chat_turn_sep": "<|endoftext|>",
"humanstr": "<|prompt|>",
"botstr": "<|answer|>",
"generates_leading_space": False,
},
temperature=0.1,
top_p=0.75,
top_k=40,
num_beams=1,
max_new_tokens=256,
min_new_tokens=0,
early_stopping=False,
max_time=180,
repetition_penalty=1.07,
num_return_sequences=1,
do_sample=False,
chat=True,
instruction_nochat=prompt,
iinput_nochat="",
langchain_mode="UserData",
langchain_action=LangChainAction.QUERY.value,
top_k_docs=3,
chunk=True,
chunk_size=512,
document_choice=[DocumentChoices.All_Relevant.name],
concurrency_count=1,
memory_restriction_level=2,
raise_generate_gpu_exceptions=False,
chat_context="",
use_openai_embedding=False,
use_openai_model=False,
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
db_type="chroma",
n_jobs=-1,
first_para=False,
max_max_time=60 * 2,
model_state0=model_state,
model_lock=True,
user_path=userpath_selector.value,
)
for partial_text in output:
history[-1][1] = partial_text["response"]
yield history
return history
with gr.Blocks(title="H2OGPT") as h2ogpt_web:
with gr.Row():
supported_devices = available_devices
enabled = len(supported_devices) > 0
# show cpu-task device first in list for chatbot
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
print(supported_devices)
device = gr.Dropdown(
label="Device",
value=supported_devices[0]
if enabled
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
)
precision = gr.Radio(
label="Precision",
value="fp16",
choices=[
"int4",
"int8",
"fp16",
"fp32",
],
visible=True,
)
userpath_selector = gr.Textbox(
label="Document Directory",
value=str(
os.path.abspath("apps/language_models/langchain/user_path/")
),
interactive=True,
container=True,
)
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="Chat Message Box",
placeholder="Chat Message Box",
show_label=False,
interactive=enabled,
container=False,
)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit", interactive=enabled)
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
system_msg = gr.Textbox(
start_message, label="System Message", interactive=False, visible=False
)
submit_event = msg.submit(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, device, precision],
outputs=[chatbot],
queue=True,
)
submit_click_event = submit.click(
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[system_msg, chatbot, device, precision],
outputs=[chatbot],
queue=True,
)
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)

View File

@@ -50,7 +50,7 @@ def img2img_inf(
steps: int,
strength: float,
guidance_scale: float,
seed: int,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
@@ -66,6 +66,7 @@ def img2img_inf(
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -104,7 +105,8 @@ def img2img_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -132,7 +134,8 @@ def img2img_inf(
image, width, height = resize_stencil(image)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
f"Shark schedulers are not supported. Switching to EulerDiscrete "
f"scheduler"
)
args.scheduler = "EulerDiscrete"
cpu_scheduling = not args.scheduler.startswith("Shark")
@@ -227,13 +230,14 @@ def img2img_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"STRENGTH": strength}
text_output = ""
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
@@ -244,16 +248,18 @@ def img2img_inf(
steps,
strength,
guidance_scale,
img_seed,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
@@ -262,7 +268,7 @@ def img2img_inf(
else:
save_output_img(
out_imgs[0],
img_seed,
seeds[current_batch],
extra_info,
)
generated_imgs.extend(out_imgs)
@@ -307,7 +313,9 @@ def img2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = img2img_inf(
@@ -339,7 +347,12 @@ def img2img_api(
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
@@ -357,13 +370,21 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
i2i_model_info = (str(get_custom_model_path())).replace(
"\\", "\n\\"
)
i2i_model_info = f"Custom Model Path: {i2i_model_info}"
img2img_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
label=f"Models",
info=i2i_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -374,13 +395,23 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
)
img2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3, "
"https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
i2i_vae_info = (str(get_custom_model_path("vae"))).replace(
"\\", "\n\\"
)
i2i_vae_info = f"VAE Path: {i2i_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label=f"Custom VAE Models",
info=i2i_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -392,13 +423,13 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
@@ -407,7 +438,8 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
source="upload",
tool="sketch",
type="pil",
).style(height=300)
height=300,
)
with gr.Accordion(label="Stencil Options", open=False):
with gr.Row():
@@ -470,15 +502,24 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
i2i_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=i2i_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -528,16 +569,18 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
strength = gr.Slider(
0,
1,
value=args.strength,
step=0.01,
label="Denoising Strength",
)
with gr.Column(scale=3):
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
with gr.Column(scale=3):
strength = gr.Slider(
0,
1,
value=args.strength,
step=0.01,
label="Denoising Strength",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
@@ -561,6 +604,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Batch Count",
interactive=True,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
batch_size = gr.Slider(
1,
4,
@@ -570,10 +618,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
@@ -582,16 +631,15 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -599,9 +647,12 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
columns=2,
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
@@ -643,9 +694,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
lora_weights,
lora_hf_id,
ondemand,
repeatable_seeds,
],
outputs=[img2img_gallery, std_output, img2img_status],
show_progress=args.progress_bar,
show_progress="minimal" if args.progress_bar else "none",
)
status_kwargs = dict(

View File

@@ -49,7 +49,7 @@ def inpaint_inf(
inpaint_full_res_padding: int,
steps: int,
guidance_scale: float,
seed: int,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
@@ -64,6 +64,7 @@ def inpaint_inf(
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: int,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -92,7 +93,8 @@ def inpaint_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -179,14 +181,15 @@ def inpaint_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
image = image_dict["image"]
mask_image = image_dict["mask"]
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None
for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
@@ -199,25 +202,27 @@ def inpaint_inf(
inpaint_full_res_padding,
steps,
guidance_scale,
img_seed,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
save_output_img(out_imgs[0], seeds[current_batch])
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Inpaint", i + 1, batch_count, batch_size
"Inpaint", current_batch + 1, batch_count, batch_size
)
return generated_imgs, text_output
@@ -257,7 +262,9 @@ def inpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
init_image = decode_base64_to_image(InputData["image"])
mask = decode_base64_to_image(InputData["mask"])
@@ -278,7 +285,7 @@ def inpaint_api(
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
else "stabilityai/stable-diffusion-2-inpainting",
custom_vae="None",
precision="fp16",
device=available_devices[0],
@@ -288,7 +295,12 @@ def inpaint_api(
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
@@ -306,13 +318,23 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
inpaint_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
inpaint_model_info = (
f"Custom Model Path: {inpaint_model_info}"
)
inpaint_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
label=f"Models",
info=inpaint_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -325,13 +347,23 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
)
inpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
"https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
inpaint_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
inpaint_vae_info = f"VAE Path: {inpaint_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label=f"Custom VAE Models",
info=inpaint_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -343,13 +375,13 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
@@ -358,19 +390,29 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
source="upload",
tool="sketch",
type="pil",
).style(height=350)
height=350,
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
inpaint_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
inpaint_lora_info = f"LoRA Path: {inpaint_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=inpaint_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -460,6 +502,11 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="Batch Count",
interactive=True,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
batch_size = gr.Slider(
1,
4,
@@ -469,10 +516,11 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
@@ -481,16 +529,15 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -498,9 +545,12 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
columns=[2],
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
@@ -543,9 +593,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
lora_weights,
lora_hf_id,
ondemand,
repeatable_seeds,
],
outputs=[inpaint_gallery, std_output, inpaint_status],
show_progress=args.progress_bar,
show_progress="minimal" if args.progress_bar else "none",
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Inpaint", 0, bc, bs),

View File

@@ -3,7 +3,7 @@ import os
import gradio as gr
from PIL import Image
from apps.stable_diffusion.scripts import lora_train
from apps.stable_diffusion.src import prompt_examples, args
from apps.stable_diffusion.src import prompt_examples, args, utils
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -24,15 +24,25 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
# janky fix for overflowing text
train_lora_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
train_lora_model_info = (
f"Custom Model Path: {train_lora_model_info}"
)
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
label=f"Models",
info=train_lora_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -43,22 +53,33 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
placeholder="Select 'None' in the Models "
"dropdown on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
# janky fix for overflowing text
train_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
train_lora_info = f"LoRA Path: {train_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights to initialize weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA weights to initialize weights",
info=train_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use a "
"standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID to initialize weights",
lines=3,
@@ -74,7 +95,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
with gr.Accordion(label="Advanced Options", open=False):
@@ -147,7 +168,9 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
value=utils.parse_seed_input(args.seed)[0],
precision=0,
label="Seed",
)
device = gr.Dropdown(
elem_id="device",
@@ -215,7 +238,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
),
],
outputs=[std_output],
show_progress=args.progress_bar,
show_progress="minimal" if args.progress_bar else "none",
)
prompt_submit = prompt.submit(**kwargs)

View File

@@ -0,0 +1,193 @@
# ========================================
# Gradio Setting
# ========================================
import gradio as gr
# from apps.language_models.src.pipelines.minigpt4_pipeline import (
# # MiniGPT4,
# CONV_VISION,
# )
from pathlib import Path
chat = None
def gradio_reset(chat_state, img_list):
if chat_state is not None:
chat_state.messages = []
if img_list is not None:
img_list = []
return (
None,
gr.update(value=None, interactive=True),
gr.update(
placeholder="Please upload your image first", interactive=False
),
gr.update(value="Upload & Start Chat", interactive=True),
chat_state,
img_list,
)
def upload_img(gr_img, text_input, chat_state, device, precision, _compile):
global chat
if chat is None:
from apps.language_models.src.pipelines.minigpt4_pipeline import (
MiniGPT4,
CONV_VISION,
)
vision_model_precision = precision
if precision in ["int4", "int8"]:
vision_model_precision = "fp16"
vision_model_vmfb_path = Path(
f"vision_model_{vision_model_precision}_{device}.vmfb"
)
qformer_vmfb_path = Path(f"qformer_fp32_{device}.vmfb")
chat = MiniGPT4(
model_name="MiniGPT4",
hf_model_path=None,
max_new_tokens=30,
device=device,
precision=precision,
_compile=_compile,
vision_model_vmfb_path=vision_model_vmfb_path,
qformer_vmfb_path=qformer_vmfb_path,
)
if gr_img is None:
return None, None, gr.update(interactive=True), chat_state, None
chat_state = CONV_VISION.copy()
img_list = []
llm_message = chat.upload_img(gr_img, chat_state, img_list)
return (
gr.update(interactive=False),
gr.update(interactive=True, placeholder="Type and press Enter"),
gr.update(value="Start Chatting", interactive=False),
chat_state,
img_list,
)
def gradio_ask(user_message, chatbot, chat_state):
if len(user_message) == 0:
return (
gr.update(
interactive=True, placeholder="Input should not be empty!"
),
chatbot,
chat_state,
)
chat.ask(user_message, chat_state)
chatbot = chatbot + [[user_message, None]]
return "", chatbot, chat_state
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
llm_message = chat.answer(
conv=chat_state,
img_list=img_list,
num_beams=num_beams,
temperature=temperature,
max_new_tokens=300,
max_length=2000,
)[0]
print(llm_message)
print("************")
chatbot[-1][1] = llm_message
return chatbot, chat_state, img_list
title = """<h1 align="center">MultiModal SHARK (experimental)</h1>"""
description = """<h3>Upload your images and start chatting!</h3>"""
article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
"""
# TODO show examples below
with gr.Blocks() as minigpt4_web:
gr.Markdown(title)
gr.Markdown(description)
with gr.Row():
with gr.Column(scale=0.5):
image = gr.Image(type="pil")
upload_button = gr.Button(
value="Upload & Start Chat",
interactive=True,
variant="primary",
)
clear = gr.Button("Restart")
num_beams = gr.Slider(
minimum=1,
maximum=10,
value=1,
step=1,
interactive=True,
label="beam search numbers)",
)
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=1.0,
step=0.1,
interactive=True,
label="Temperature",
)
device = gr.Dropdown(
label="Device",
value="cuda",
# if enabled
# else "Only CUDA Supported for now",
choices=["cuda"],
interactive=False,
)
with gr.Column():
chat_state = gr.State()
img_list = gr.State()
chatbot = gr.Chatbot(label="MiniGPT-4")
text_input = gr.Textbox(
label="User",
placeholder="Please upload your image first",
interactive=False,
)
precision = gr.Radio(
label="Precision",
value="int8",
choices=[
"int8",
"fp16",
"fp32",
],
visible=True,
)
_compile = gr.Checkbox(
value=False,
label="Compile",
interactive=True,
)
upload_button.click(
upload_img,
[image, text_input, chat_state, device, precision, _compile],
[image, text_input, upload_button, chat_state, img_list],
)
text_input.submit(
gradio_ask,
[text_input, chatbot, chat_state],
[text_input, chatbot, chat_state],
).then(
gradio_answer,
[chatbot, chat_state, img_list, num_beams, temperature],
[chatbot, chat_state, img_list],
)
clear.click(
gradio_reset,
[chat_state, img_list],
[chatbot, image, text_input, upload_button, chat_state, img_list],
queue=False,
)

View File

@@ -19,7 +19,10 @@ def get_hf_list(num_of_models=20):
def get_civit_list(num_of_models=50):
path = f"https://civitai.com/api/v1/models?limit={num_of_models}&types=Checkpoint"
path = (
f"https://civitai.com/api/v1/models?limit="
f"{num_of_models}&types=Checkpoint"
)
headers = {"Content-Type": "application/json"}
raw_json = requests.get(path, headers=headers).json()
models = list(raw_json.items())[0][1]
@@ -79,7 +82,7 @@ with gr.Blocks() as model_web:
type="value",
label="Model Source",
)
model_numebr = gr.Slider(
model_number = gr.Slider(
1,
100,
value=10,
@@ -111,9 +114,9 @@ with gr.Blocks() as model_web:
modelmanager_sendto_outpaint = gr.Button(value="SendTo Outpaint")
modelmanager_sendto_upscaler = gr.Button(value="SendTo Upscaler")
def get_model_list(model_source, model_numebr):
def get_model_list(model_source, model_number):
if model_source == "Hugging Face":
hf_model_list = get_hf_list(model_numebr)
hf_model_list = get_hf_list(model_number)
models = []
for model in hf_model_list:
# TODO: add model info
@@ -124,7 +127,7 @@ with gr.Blocks() as model_web:
gr.Row.update(visible=True),
)
elif model_source == "Civitai":
civit_model_list = get_civit_list(model_numebr)
civit_model_list = get_civit_list(model_number)
models = []
for model in civit_model_list:
image = get_image_from_model(model)
@@ -148,7 +151,7 @@ with gr.Blocks() as model_web:
get_model_btn.click(
fn=get_model_list,
inputs=[model_source, model_numebr],
inputs=[model_source, model_number],
outputs=[
hf_models,
civit_models,

View File

@@ -29,7 +29,6 @@ from apps.stable_diffusion.src.utils import (
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -50,7 +49,7 @@ def outpaint_inf(
width: int,
steps: int,
guidance_scale: float,
seed: int,
seed: str,
batch_count: int,
batch_size: int,
scheduler: str,
@@ -65,6 +64,7 @@ def outpaint_inf(
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -92,7 +92,8 @@ def outpaint_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -177,8 +178,10 @@ def outpaint_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None
left = True if "left" in directions else False
right = True if "right" in directions else False
@@ -186,9 +189,7 @@ def outpaint_inf(
bottom = True if "down" in directions else False
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
@@ -206,25 +207,27 @@ def outpaint_inf(
width,
steps,
guidance_scale,
img_seed,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
save_output_img(out_imgs[0], seeds[current_batch])
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Outpaint", i + 1, batch_count, batch_size
"Outpaint", current_batch + 1, batch_count, batch_size
)
return generated_imgs, text_output, ""
@@ -264,7 +267,9 @@ def outpaint_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = outpaint_inf(
@@ -287,7 +292,7 @@ def outpaint_api(
custom_model="None",
hf_model_id=InputData["hf_model_id"]
if "hf_model_id" in InputData.keys()
else "stabilityai/stable-diffusion-2-1-base",
else "stabilityai/stable-diffusion-2-inpainting",
custom_vae="None",
precision="fp16",
device=available_devices[0],
@@ -297,7 +302,12 @@ def outpaint_api(
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Convert Generator to Subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
@@ -315,13 +325,23 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
outpaint_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
outpaint_model_info = (
f"Custom Model Path: {outpaint_model_info}"
)
outpaint_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
label=f"Models",
info=outpaint_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -334,13 +354,23 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
)
outpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
"https://civitai.com/api/download/models/3433",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
outpaint_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
outpaint_vae_info = f"VAE Path: {outpaint_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label=f"Custom VAE Models",
info=outpaint_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -352,31 +382,42 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
outpaint_init_image = gr.Image(
label="Input Image", type="pil"
).style(height=300)
label="Input Image",
type="pil",
height=300,
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
outpaint_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
outpaint_lora_info = f"LoRA Path: {outpaint_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=outpaint_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -488,6 +529,12 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="Batch Count",
interactive=True,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
batch_size = gr.Slider(
1,
4,
@@ -497,10 +544,11 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
@@ -509,16 +557,15 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -526,9 +573,12 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
columns=[2],
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
@@ -571,9 +621,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
lora_weights,
lora_hf_id,
ondemand,
repeatable_seeds,
],
outputs=[outpaint_gallery, std_output, outpaint_status],
show_progress=args.progress_bar,
show_progress="minimal" if args.progress_bar else "none",
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Outpaint", 0, bc, bs),

View File

@@ -1,6 +1,8 @@
import glob
import gradio as gr
import os
import subprocess
import sys
from PIL import Image
from apps.stable_diffusion.src import args
@@ -9,9 +11,6 @@ from apps.stable_diffusion.src.utils import (
get_generated_imgs_todays_subdir,
)
from apps.stable_diffusion.web.ui.utils import nodlogo_loc
from apps.stable_diffusion.web.utils.gradio_configs import (
gradio_tmp_galleries_folder,
)
from apps.stable_diffusion.web.utils.metadata import displayable_metadata
# -- Functions for file, directory and image info querying
@@ -41,14 +40,14 @@ def output_subdirs() -> list[str]:
)
]
# It is less confusing to always including the subdir that will take any images generated
# today even if it doesn't exist yet
# It is less confusing to always including the subdir that will take any
# images generated today even if it doesn't exist yet
if get_generated_imgs_todays_subdir() not in relative_paths:
relative_paths.append(get_generated_imgs_todays_subdir())
# sort subdirectories so that that the date named ones we probably created in this or
# previous sessions come first, sorted with the most recent first. Other subdirs are listed
# after.
# sort subdirectories so that the date named ones we probably
# created in this or previous sessions come first, sorted with the most
# recent first. Other subdirs are listed after.
generated_paths = sorted(
[path for path in relative_paths if path.isnumeric()], reverse=True
)
@@ -63,26 +62,14 @@ def output_subdirs() -> list[str]:
return result_paths
# clear zero length temporary files that gradio 3.22.0 buggily creates
# TODO: remove once gradio is upgraded to or past 3.32.0
def clear_zero_length_temps():
zero_length_temps = [
os.path.join(root, file)
for root, dirs, files in os.walk(gradio_tmp_galleries_folder)
for file in files
if os.path.getsize(os.path.join(root, file)) == 0
]
for file in zero_length_temps:
os.remove(file)
# --- Define UI layout for Gradio
with gr.Blocks() as outputgallery_web:
nod_logo = Image.open(nodlogo_loc)
with gr.Row(elem_id="outputgallery_gallery"):
# needed to workaround gradio issue: https://github.com/gradio-app/gradio/issues/2907
# needed to workaround gradio issue:
# https://github.com/gradio-app/gradio/issues/2907
dev_null = gr.Textbox("", visible=False)
gallery_files = gr.State(value=[])
@@ -104,34 +91,56 @@ with gr.Blocks() as outputgallery_web:
value=gallery_files.value,
visible=False,
show_label=True,
).style(columns=4)
gallery.DEFAULT_TEMP_DIR = gradio_tmp_galleries_folder
columns=2,
)
with gr.Column(scale=4):
with gr.Box():
with gr.Row():
with gr.Column(scale=16, min_width=160):
with gr.Column(
scale=15,
min_width=160,
elem_id="output_subdir_container",
):
subdirectories = gr.Dropdown(
label=f"Subdirectories of {output_dir}",
type="value",
choices=subdirectory_paths.value,
value="",
interactive=True,
).style(container=False)
elem_classes="dropdown_no_container",
)
with gr.Column(
scale=1, min_width=32, elem_id="output_refresh_button"
scale=1,
min_width=32,
elem_classes="output_icon_button",
):
open_subdir = gr.Button(
variant="secondary",
value="\U0001F5C1", # unicode open folder
interactive=False,
size="sm",
)
with gr.Column(
scale=1,
min_width=32,
elem_classes="output_icon_button",
):
refresh = gr.Button(
variant="secondary",
value="\u21BB", # unicode clockwise arrow circle
).style(size="sm")
size="sm",
)
image_columns = gr.Slider(
label="Columns shown", value=4, minimum=1, maximum=16, step=1
)
outputgallery_filename = gr.Textbox(
label="Filename", value="None", interactive=False
).style(show_copy_button=True)
label="Filename",
value="None",
interactive=False,
show_copy_button=True,
)
with gr.Accordion(
label="Parameter Information", open=False
@@ -150,36 +159,40 @@ with gr.Blocks() as outputgallery_web:
value="Txt2Img",
interactive=False,
elem_classes="outputgallery_sendto",
).style(size="sm")
size="sm",
)
outputgallery_sendto_img2img = gr.Button(
value="Img2Img",
interactive=False,
elem_classes="outputgallery_sendto",
).style(size="sm")
size="sm",
)
outputgallery_sendto_inpaint = gr.Button(
value="Inpaint",
interactive=False,
elem_classes="outputgallery_sendto",
).style(size="sm")
size="sm",
)
outputgallery_sendto_outpaint = gr.Button(
value="Outpaint",
interactive=False,
elem_classes="outputgallery_sendto",
).style(size="sm")
size="sm",
)
outputgallery_sendto_upscaler = gr.Button(
value="Upscaler",
interactive=False,
elem_classes="outputgallery_sendto",
).style(size="sm")
size="sm",
)
# --- Event handlers
def on_clear_gallery():
clear_zero_length_temps()
return [
gr.Gallery.update(
value=[],
@@ -209,17 +222,32 @@ with gr.Blocks() as outputgallery_web:
),
]
def on_open_subdir(subdir):
subdir_path = os.path.normpath(os.path.join(output_dir, subdir))
if os.path.isdir(subdir_path):
if sys.platform == "linux":
subprocess.run(["xdg-open", subdir_path])
elif sys.platform == "darwin":
subprocess.run(["open", subdir_path])
elif sys.platform == "win32":
os.startfile(subdir_path)
def on_refresh(current_subdir: str) -> list:
# get an up to date subdirectory list
# get an up-to-date subdirectory list
refreshed_subdirs = output_subdirs()
# get the images using either the current subdirectory or the most recent valid one
# get the images using either the current subdirectory or the most
# recent valid one
new_subdir = (
current_subdir
if current_subdir in refreshed_subdirs
else refreshed_subdirs[0]
)
new_images = outputgallery_filenames(new_subdir)
new_label = f"{len(new_images)} images in {os.path.join(output_dir, new_subdir)}"
new_label = (
f"{len(new_images)} images in "
f"{os.path.join(output_dir, new_subdir)}"
)
return [
gr.Dropdown.update(
@@ -238,18 +266,22 @@ with gr.Blocks() as outputgallery_web:
]
def on_new_image(subdir, subdir_paths, status) -> list:
# prevent error triggered when an image generates before the tab has even been selected
# prevent error triggered when an image generates before the tab
# has even been selected
subdir_paths = (
subdir_paths
if len(subdir_paths) > 0
else [get_generated_imgs_todays_subdir()]
)
# only update if the current subdir is the most recent one as new images only go there
# only update if the current subdir is the most recent one as
# new images only go there
if subdir_paths[0] == subdir:
clear_zero_length_temps()
new_images = outputgallery_filenames(subdir)
new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)} - {status}"
new_label = (
f"{len(new_images)} images in "
f"{os.path.join(output_dir, subdir)} - {status}"
)
return [
new_images,
@@ -264,19 +296,27 @@ with gr.Blocks() as outputgallery_web:
),
]
else:
# otherwise change nothing, (only untyped gradio gr.update() does this)
# otherwise change nothing,
# (only untyped gradio gr.update() does this)
return [gr.update(), gr.update(), gr.update()]
def on_select_image(images: list[str], evt: gr.SelectData) -> list:
# evt.index is an index into the full list of filenames for the current subdirectory
# evt.index is an index into the full list of filenames for
# the current subdirectory
filename = images[evt.index]
params = displayable_metadata(filename)
if params:
return [
filename,
list(map(list, params["parameters"].items())),
]
if params["source"] == "missing":
return [
"Could not find this image file, refresh the gallery and update the images",
[["Status", "File missing"]],
]
else:
return [
filename,
list(map(list, params["parameters"].items())),
]
return [
filename,
@@ -286,7 +326,8 @@ with gr.Blocks() as outputgallery_web:
def on_outputgallery_filename_change(filename: str) -> list:
exists = filename != "None" and os.path.exists(filename)
return [
# disable or enable each of the sendto button based on whether an image is selected
# disable or enable each of the sendto button based on whether
# an image is selected
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
gr.Button.update(interactive=exists),
@@ -295,17 +336,23 @@ with gr.Blocks() as outputgallery_web:
gr.Button.update(interactive=exists),
]
# The time first our tab is selected we need to do an initial refresh to populate
# the subdirectory select box and the images from the most recent subdirectory.
# The time first our tab is selected we need to do an initial refresh
# to populate the subdirectory select box and the images from the most
# recent subdirectory.
#
# We do it at this point rather than setting this up in the controls' definitions
# as when you refresh the browser you always get what was *initially* set, which
# won't include any new subdirectories or images that might have created since
# the application was started. Doing it this way means a browser refresh/reload
# always gets the most up to date data.
def on_select_tab(subdir_paths):
# We do it at this point rather than setting this up in the controls'
# definitions as when you refresh the browser you always get what was
# *initially* set, which won't include any new subdirectories or images
# that might have created since the application was started. Doing it
# this way means a browser refresh/reload always gets the most
# up-to-date data.
def on_select_tab(subdir_paths, request: gr.Request):
local_client = request.headers["host"].startswith(
"127.0.0.1:"
) or request.headers["host"].startswith("localhost:")
if len(subdir_paths) == 0:
return on_refresh("")
return on_refresh("") + [gr.update(interactive=local_client)]
else:
return (
# Change nothing, (only untyped gr.update() does this)
@@ -314,13 +361,14 @@ with gr.Blocks() as outputgallery_web:
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
# Unfortunately as of gradio 3.22.0 gr.update against Galleries doesn't support
# things set with .style, nor the elem_classes kwarg so we have to directly set
# things up via JavaScript if we want the client to take notice of any of our
# changes to the number of columns after it decides to put them back to the
# original number when we change something
# Unfortunately as of gradio 3.34.0 gr.update against Galleries doesn't
# support things set with .style, nor the elem_classes kwarg, so we have
# to directly set things up via JavaScript if we want the client to take
# notice of our changes to the number of columns after it decides to put
# them back to the original number when we change something
def js_set_columns_in_browser(timeout_length):
return f"""
(new_cols) => {{
@@ -337,32 +385,36 @@ with gr.Blocks() as outputgallery_web:
# --- Wire handlers up to the actions
# - Many actions reset the number of columns shown in the gallery on the browser end,
# so we have to set them back to what we think they should be after the initial
# action.
# - None of the actions on this tab trigger inference, and we want the user to be able
# to do them whilst other tabs have ongoing inference running. Waiting in the queue
# behind inference jobs would mean the UI can't fully respond until the inference tasks
# complete, hence queue=False on all of these.
# Many actions reset the number of columns shown in the gallery on the
# browser end, so we have to set them back to what we think they should
# be after the initial action.
#
# None of the actions on this tab trigger inference, and we want the
# user to be able to do them whilst other tabs have ongoing inference
# running. Waiting in the queue behind inference jobs would mean the UI
# can't fully respond until the inference tasks complete,
# hence queue=False on all of these.
set_gallery_columns_immediate = dict(
fn=None,
inputs=[image_columns],
# gradio blanks the UI on Chrome on Linux on gallery select if I don't put an output here
# gradio blanks the UI on Chrome on Linux on gallery select if
# I don't put an output here
outputs=[dev_null],
_js=js_set_columns_in_browser(0),
queue=False,
)
# setting columns after selecting a gallery item needs a real timeout length for the
# number of columns to actually be applied. Not really sure why, maybe something has
# to finish animating?
# setting columns after selecting a gallery item needs a real
# timeout length for the number of columns to actually be applied.
# Not really sure why, maybe something has to finish animating?
set_gallery_columns_delayed = dict(
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
)
# clearing images when we need to completely change what's in the gallery avoids current
# images being shown replacing piecemeal and prevents weirdness and errors if the user
# selects an image during the replacement phase.
# clearing images when we need to completely change what's in the
# gallery avoids current images being shown replacing piecemeal and
# prevents weirdness and errors if the user selects an image during the
# replacement phase.
clear_gallery = dict(
fn=on_clear_gallery,
inputs=None,
@@ -379,6 +431,10 @@ with gr.Blocks() as outputgallery_web:
queue=False,
).then(**set_gallery_columns_immediate)
open_subdir.click(
on_open_subdir, inputs=[subdirectories], queue=False
).then(**set_gallery_columns_immediate)
refresh.click(**clear_gallery).then(
on_refresh,
[subdirectories],
@@ -417,6 +473,7 @@ with gr.Blocks() as outputgallery_web:
gallery_files,
gallery,
logo,
open_subdir,
],
queue=False,
).then(**set_gallery_columns_immediate)

View File

@@ -6,13 +6,7 @@ from transformers import (
AutoModelForCausalLM,
)
from apps.stable_diffusion.web.ui.utils import available_devices
start_message = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
from datetime import datetime as dt
def user(message, history):
@@ -24,53 +18,139 @@ sharkModel = 0
sharded_model = 0
vicuna_model = 0
start_message_vicuna = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
past_key_values = None
model_map = {
"llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
"llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
"codegen": "Salesforce/codegen25-7b-multi",
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
"StableLM": "stabilityai/stablelm-tuned-alpha-3b",
}
def chat(curr_system_message, history, model, device, precision):
print(f"In chat for {model}")
global sharded_model
global past_key_values
global vicuna_model
if "vicuna" in model:
from apps.language_models.src.pipelines.vicuna_pipeline import (
Vicuna,
)
# NOTE: Each `model_name` should have its own start message
start_message = {
"llama2_7b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"llama2_70b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"StableLM": (
"<|SYSTEM|># StableLM Tuned (Alpha version)"
"\n- StableLM is a helpful and harmless open-source AI language model "
"developed by StabilityAI."
"\n- StableLM is excited to be able to help the user, but will refuse "
"to do anything that could be considered harmful to the user."
"\n- StableLM is more than just an information source, StableLM is also "
"able to write poetry, short stories, and make jokes."
"\n- StableLM will refuse to participate in anything that "
"could harm a human."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
),
"vicuna1p3": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
),
"codegen": "",
}
curr_system_message = start_message_vicuna
if vicuna_model == 0:
first_vic_vmfb_path = Path("first_vicuna.vmfb")
second_vic_vmfb_path = Path("second_vicuna.vmfb")
if "cuda" in device:
device = "cuda"
vicuna_model = Vicuna(
"vicuna",
hf_model_path=model,
device=device,
precision=precision,
first_vicuna_vmfb_path=first_vic_vmfb_path,
second_vicuna_vmfb_path=second_vic_vmfb_path,
)
messages = curr_system_message + "".join(
def create_prompt(model_name, history):
system_message = start_message[model_name]
if model_name in [
"StableLM",
"vicuna",
"vicuna1p3",
"llama2_7b",
"llama2_70b",
]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
prompt = messages.strip()
print("prompt = ", prompt)
sentence = vicuna_model.generate(prompt)
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)
partial_text = ""
for new_text in sentence.split(" "):
# print(new_text)
partial_text += new_text + " "
msg = system_message + conversation
msg = msg.strip()
return msg
def set_vicuna_model(model):
global vicuna_model
vicuna_model = model
# TODO: Make chat reusable for UI and API
def chat(curr_system_message, history, model, device, precision, cli=True):
global past_key_values
global vicuna_model
model_name, model_path = list(map(str.strip, model.split("=>")))
if model_name in [
"vicuna",
"vicuna1p3",
"codegen",
"llama2_7b",
"llama2_70b",
]:
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)
from apps.stable_diffusion.src import args
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
else:
print("unrecognized device")
max_toks = 128 if model_name == "codegen" else 512
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
precision=precision,
max_num_tokens=max_toks,
)
prompt = create_prompt(model_name, history)
for partial_text in vicuna_model.generate(prompt, cli=cli):
history[-1][1] = partial_text
# Yield an empty string to cleanup the message textbox and the updated conversation history
yield history
history[-1][1] = sentence
return history
# else Model is StableLM
@@ -82,48 +162,151 @@ def chat(curr_system_message, history, model, device, precision):
if sharkModel == 0:
# max_new_tokens=512
shark_slm = SharkStableLM(
"StableLM"
model_name
) # pass elements from UI as required
# Construct the input message string for the model by concatenating the current system message and conversation history
# Construct the input message string for the model by concatenating the
# current system message and conversation history
if len(curr_system_message.split()) > 160:
print("clearing context")
curr_system_message = start_message
messages = curr_system_message + "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
generate_kwargs = dict(prompt=messages)
prompt = create_prompt(model_name, history)
generate_kwargs = dict(prompt=prompt)
words_list = shark_slm.generate(**generate_kwargs)
partial_text = ""
for new_text in words_list:
# print(new_text)
print(new_text)
partial_text += new_text
history[-1][1] = partial_text
# Yield an empty string to cleanup the message textbox and the updated conversation history
# Yield an empty string to clean up the message textbox and the updated
# conversation history
yield history
return words_list
def llm_chat_api(InputData: dict):
print(f"Input keys : {InputData.keys()}")
# print(f"model : {InputData['model']}")
is_chat_completion_api = (
"messages" in InputData.keys()
) # else it is the legacy `completion` api
# For Debugging input data from API
# if is_chat_completion_api:
# print(f"message -> role : {InputData['messages'][0]['role']}")
# print(f"message -> content : {InputData['messages'][0]['content']}")
# else:
# print(f"prompt : {InputData['prompt']}")
# print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now
global vicuna_model
model_name = (
InputData["model"] if "model" in InputData.keys() else "codegen"
)
model_path = model_map[model_name]
device = "cpu-task"
precision = "fp16"
max_toks = (
None
if "max_tokens" not in InputData.keys()
else InputData["max_tokens"]
)
if max_toks is None:
max_toks = 128 if model_name == "codegen" else 512
# make it working for codegen first
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
else:
print("unrecognized device")
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
)
# TODO: add role dict for different models
if is_chat_completion_api:
# TODO: add funtionality for multiple messages
prompt = create_prompt(
model_name, [(InputData["messages"][0]["content"], "")]
)
else:
prompt = InputData["prompt"]
print("prompt = ", prompt)
res = vicuna_model.generate(prompt)
res_op = None
for op in res:
res_op = op
if is_chat_completion_api:
choices = [
{
"index": 0,
"message": {
"role": "assistant",
"content": res_op, # since we are yeilding the result
},
"finish_reason": "stop", # or length
}
]
else:
choices = [
{
"text": res_op,
"index": 0,
"logprobs": None,
"finish_reason": "stop", # or length
}
]
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
return {
"id": end_time,
"object": "chat.completion"
if is_chat_completion_api
else "text_completion",
"created": int(end_time),
"choices": choices,
}
def view_json_file(file_obj):
content = ""
with open(file_obj.name, "r") as fopen:
content = fopen.read()
return content
with gr.Blocks(title="Chatbot") as stablelm_chat:
with gr.Row():
model_choices = list(
map(lambda x: f"{x[0]: <10} => {x[1]}", model_map.items())
)
model = gr.Dropdown(
label="Select Model",
value="TheBloke/vicuna-7B-1.1-HF",
choices=[
"stabilityai/stablelm-tuned-alpha-3b",
"TheBloke/vicuna-7B-1.1-HF",
],
value=model_choices[0],
choices=model_choices,
)
supported_devices = [
device for device in available_devices if "cuda" in device
]
supported_devices = available_devices
enabled = len(supported_devices) > 0
# show cpu-task device first in list for chatbot
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
print(supported_devices)
device = gr.Dropdown(
label="Device",
value=supported_devices[0]
@@ -134,14 +317,24 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
)
precision = gr.Radio(
label="Precision",
value="fp32",
value="fp16",
choices=[
"int4",
"int8",
"fp16",
"fp32",
],
visible=True,
)
chatbot = gr.Chatbot().style(height=500)
with gr.Row():
with gr.Group():
config_file = gr.File(label="Upload sharding configuration")
json_view_button = gr.Button("View as JSON")
json_view = gr.JSON()
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
@@ -149,7 +342,8 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
placeholder="Chat Message Box",
show_label=False,
interactive=enabled,
).style(container=False)
container=False,
)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit", interactive=enabled)

View File

@@ -34,6 +34,7 @@ from apps.stable_diffusion.src.utils import (
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platform = args.iree_metal_target_platform
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
@@ -45,7 +46,7 @@ def txt2img_inf(
width: int,
steps: int,
guidance_scale: float,
seed: int,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
@@ -60,6 +61,7 @@ def txt2img_inf(
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -86,7 +88,8 @@ def txt2img_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -137,6 +140,7 @@ def txt2img_inf(
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platform = init_iree_metal_target_platform
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
@@ -174,12 +178,13 @@ def txt2img_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
text_output = ""
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None
for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
@@ -188,25 +193,27 @@ def txt2img_inf(
width,
steps,
guidance_scale,
img_seed,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
seeds.append(img_seed)
total_time = time.time() - start_time
text_output = get_generation_text_info(seeds, device)
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], img_seed)
save_output_img(out_imgs[0], seeds[current_batch])
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Text-to-Image", i + 1, batch_count, batch_size
"Text-to-Image", current_batch + 1, batch_count, batch_size
)
return generated_imgs, text_output, ""
@@ -235,7 +242,9 @@ def txt2img_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}.'
)
res = txt2img_inf(
InputData["prompt"],
@@ -261,7 +270,12 @@ def txt2img_api(
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Convert Generator to Subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
@@ -279,15 +293,25 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
# janky fix for overflowing text
t2i_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
t2i_model_info = (
f"Custom Model Path: {t2i_model_info}"
)
txt2img_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
label=f"Models",
info=t2i_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -298,13 +322,21 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
)
txt2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
placeholder="Select 'None' in the dropdown "
"on the left and enter model ID here.",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL.",
lines=3,
)
# janky fix for overflowing text
t2i_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
t2i_vae_info = f"VAE Path: {t2i_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label=f"VAE Models",
info=t2i_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -325,26 +357,35 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
t2i_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
t2i_lora_info = f"LoRA Path: {t2i_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=t2i_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -357,7 +398,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
value=args.scheduler,
choices=scheduler_list,
)
with gr.Group():
with gr.Column():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
@@ -402,16 +443,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Column(scale=3):
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
@@ -436,10 +479,15 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Batch Size",
interactive=True,
)
stop_batch = gr.Button("Stop Batch")
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
@@ -448,17 +496,15 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
@@ -473,9 +519,12 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
columns=[2],
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
@@ -515,9 +564,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
lora_weights,
lora_hf_id,
ondemand,
repeatable_seeds,
],
outputs=[txt2img_gallery, std_output, txt2img_status],
show_progress=args.progress_bar,
show_progress="minimal" if args.progress_bar else "none",
)
status_kwargs = dict(
@@ -550,6 +600,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
height,
txt2img_custom_model,
txt2img_hf_model_id,
lora_weights,
lora_hf_id,
custom_vae,
],
outputs=[
txt2img_png_info_img,
@@ -563,5 +616,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
height,
txt2img_custom_model,
txt2img_hf_model_id,
lora_weights,
lora_hf_id,
custom_vae,
],
)

View File

@@ -42,7 +42,7 @@ def upscaler_inf(
steps: int,
noise_level: int,
guidance_scale: float,
seed: int,
seed: str,
batch_count: int,
batch_size: int,
scheduler: str,
@@ -57,6 +57,7 @@ def upscaler_inf(
lora_weights: str,
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -88,7 +89,8 @@ def upscaler_inf(
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
"Please provide either custom model or huggingface model ID, "
"both must not be empty.",
)
if "civitai" in hf_model_id:
args.ckpt_loc = hf_model_id
@@ -175,12 +177,13 @@ def upscaler_inf(
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
extra_info = {"NOISE LEVEL": noise_level}
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None
for current_batch in range(batch_count):
if current_batch > 0:
img_seed = utils.sanitize_seed(-1)
low_res_img = image
high_res_img = Image.new("RGB", (height * 4, width * 4))
@@ -197,11 +200,12 @@ def upscaler_inf(
steps,
noise_level,
guidance_scale,
img_seed,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
@@ -211,27 +215,40 @@ def upscaler_inf(
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, " f"device={device}"
text_output += (
f"\nsteps={steps}, "
f"noise_level={noise_level}, "
f"guidance_scale={guidance_scale}, "
f"seed={seeds[:current_batch + 1]}"
)
text_output += (
f"\ninput size={height}x{width}, "
f"output size={height*4}x{width*4}, "
f"batch_count={batch_count}, "
f"batch_size={batch_size}, "
f"max_length={args.max_length}\n"
)
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(high_res_img, img_seed, extra_info)
save_output_img(high_res_img, seeds[current_batch], extra_info)
generated_imgs.append(high_res_img)
seeds.append(img_seed)
global_obj.get_sd_obj().log += "\n"
yield generated_imgs, global_obj.get_sd_obj().log, status_label(
yield generated_imgs, text_output, status_label(
"Upscaler", current_batch + 1, batch_count, batch_size
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
text_output += global_obj.get_sd_obj().log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
yield generated_imgs, text_output, ""
@@ -269,7 +286,9 @@ def upscaler_api(
InputData: dict,
):
print(
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
f'Prompt: {InputData["prompt"]}, '
f'Negative Prompt: {InputData["negative_prompt"]}, '
f'Seed: {InputData["seed"]}'
)
init_image = decode_base64_to_image(InputData["init_images"][0])
res = upscaler_inf(
@@ -298,7 +317,11 @@ def upscaler_api(
lora_weights="None",
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
)
# Converts generator type to subscriptable
res = next(res)
return {
"images": encode_pil_to_base64(res[0]),
"parameters": {},
@@ -316,13 +339,23 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=50)
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
# janky fix for overflowing text
upscaler_model_info = (
str(get_custom_model_path())
).replace("\\", "\n\\")
upscaler_model_info = (
f"Custom Model Path: {upscaler_model_info}"
)
upscaler_custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {get_custom_model_path()})",
label=f"Models",
info=upscaler_model_info,
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
@@ -335,13 +368,23 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
)
upscaler_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
placeholder="Select 'None' in the Models dropdown "
"on the left and enter model ID here "
"e.g: SG161222/Realistic_Vision_V1.3, "
"https://civitai.com/api/download/models/15236",
value="",
label="HuggingFace Model ID or Civitai model download URL",
label="HuggingFace Model ID or Civitai model "
"download URL",
lines=3,
)
# janky fix for overflowing text
upscaler_vae_info = (
str(get_custom_model_path("vae"))
).replace("\\", "\n\\")
upscaler_vae_info = f"VAE Path: {upscaler_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
label=f"Custom VAE Models",
info=upscaler_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
@@ -353,31 +396,42 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=1,
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=1,
lines=2,
elem_id="negative_prompt_box",
)
upscaler_init_image = gr.Image(
label="Input Image", type="pil"
).style(height=300)
label="Input Image",
type="pil",
height=300,
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
upscaler_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
upscaler_lora_info = f"LoRA Path: {upscaler_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
label=f"Standalone LoRA Weights",
info=upscaler_lora_info,
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
@@ -468,6 +522,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Batch Count",
interactive=True,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
batch_size = gr.Slider(
1,
4,
@@ -477,10 +536,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
interactive=False,
visible=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
seed = gr.Number(
value=args.seed, precision=0, label="Seed"
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
@@ -489,16 +549,15 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
choices=available_devices,
)
with gr.Row():
with gr.Column(scale=2):
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
with gr.Column(scale=6):
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -506,9 +565,12 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Generated images",
show_label=False,
elem_id="gallery",
).style(columns=[2], object_fit="contain")
columns=[2],
object_fit="contain",
)
std_output = gr.Textbox(
value=f"Images will be saved at {get_generated_imgs_path()}",
value=f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
@@ -548,9 +610,10 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
lora_weights,
lora_hf_id,
ondemand,
repeatable_seeds,
],
outputs=[upscaler_gallery, std_output, upscaler_status],
show_progress=args.progress_bar,
show_progress="minimal" if args.progress_bar else "none",
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Upscaler", 0, bc, bs),

View File

@@ -39,8 +39,16 @@ scheduler_list_cpu_only = [
"LMSDiscrete",
"KDPM2Discrete",
"DPMSolverMultistep",
"DPMSolverMultistep++",
"DPMSolverMultistepKarras",
"DPMSolverMultistepKarras++",
"EulerDiscrete",
"EulerAncestralDiscrete",
"DEISMultistep",
"KDPM2AncestralDiscrete",
"DPMSolverSinglestep",
"DDPM",
"HeunDiscrete",
]
scheduler_list = scheduler_list_cpu_only + [
"SharkEulerDiscrete",
@@ -50,6 +58,7 @@ predefined_models = [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"xzuyn/PhotoMerge",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
@@ -58,6 +67,7 @@ predefined_models = [
predefined_paint_models = [
"runwayml/stable-diffusion-inpainting",
"stabilityai/stable-diffusion-2-inpainting",
"xzuyn/PhotoMerge-inpainting",
]
predefined_upscaler_models = [
"stabilityai/stable-diffusion-x4-upscaler",
@@ -79,7 +89,8 @@ def create_custom_models_folders():
else:
if not os.path.isdir(args.ckpt_dir):
sys.exit(
f"Invalid --ckpt_dir argument, {args.ckpt_dir} folder does not exists."
f"Invalid --ckpt_dir argument, "
f"{args.ckpt_dir} folder does not exists."
)
for root in dir:
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)

View File

@@ -1,60 +1,54 @@
import os
import shutil
import tempfile
import gradio
from time import time
gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
gradio_tmp_galleries_folder = os.path.join(gradio_tmp_imgs_folder, "galleries")
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
# Clear all gradio tmp images
def clear_gradio_tmp_imgs_folder():
if not os.path.exists(gradio_tmp_imgs_folder):
return
def config_gradio_tmp_imgs_folder():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
# tell gradio to use a directory under shark_tmp for its temporary
# image files unless somewhere else has been set
if "GRADIO_TEMP_DIR" not in os.environ:
os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
# clear all gradio tmp files created by generation galleries
print(
"Clearing gradio temporary image files from a prior run. This may take some time..."
f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
+ "You may change this by setting the GRADIO_TEMP_DIR environment variable."
)
image_files = [
filename
for filename in os.listdir(gradio_tmp_imgs_folder)
if os.path.isfile(os.path.join(gradio_tmp_imgs_folder, filename))
and filename.startswith("tmp")
and filename.endswith(".png")
]
if len(image_files) > 0:
# Clear all gradio tmp images from the last session
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
cleanup_start = time()
for filename in image_files:
os.remove(gradio_tmp_imgs_folder + filename)
print(
f"Clearing generation temporary image files took {time() - cleanup_start:4f} seconds"
"Clearing gradio UI temporary image files from a prior run. This may take some time..."
)
else:
print("no generation temporary files to clear")
# Clear all gradio tmp files created by output galleries
if os.path.exists(gradio_tmp_galleries_folder):
cleanup_start = time()
shutil.rmtree(gradio_tmp_galleries_folder, ignore_errors=True)
shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
print(
f"Clearing output gallery temporary image files took {time() - cleanup_start:4f} seconds"
f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
)
# older SHARK versions had to workaround gradio bugs and stored things differently
else:
print("no output gallery temporary files to clear")
# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder
def save_pil_to_file(pil_image, dir=None):
if not os.path.exists(gradio_tmp_imgs_folder):
os.mkdir(gradio_tmp_imgs_folder)
file_obj = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=gradio_tmp_imgs_folder
)
pil_image.save(file_obj)
return file_obj
# Register save_pil_to_file override
gradio.processing_utils.save_pil_to_file = save_pil_to_file
image_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.startswith("tmp")
and filename.endswith(".png")
]
if len(image_files) > 0:
print(
"Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
)
cleanup_start = time()
for filename in image_files:
os.remove(shark_tmp + filename)
print(
f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
)
else:
print("No temporary images files to clear.")

View File

@@ -11,21 +11,35 @@ def has_csv(image_filename: str) -> bool:
return os.path.exists(csv_path(image_filename))
def parse_csv(image_filename: str):
# We use a reader instead of a DictReader here for images_details.csv files due to the lack of
# headers, and then match up the return list for each row with our guess at which column format
# the file is using.
def matching_filename(image_filename: str, row):
# we assume the final column of the csv has the original filename with full path and match that
# against the image_filename. We then exclude the filename from the output, hence the -1's.
# against the image_filename if we are given a list. Otherwise we assume a dict and and take
# the value of the OUTPUT key
return os.path.basename(image_filename) in (
row[-1] if isinstance(row, list) else row["OUTPUT"]
)
def parse_csv(image_filename: str):
csv_filename = csv_path(image_filename)
matches = [
humanize(row)
for row in csv.reader(open(csv_filename, "r", newline=""))
if row
and humanizable(row)
and os.path.basename(image_filename) in row[-1]
]
with open(csv_filename, "r", newline="") as csv_file:
# We use a reader or DictReader here for images_details.csv depending on whether we think it
# has headers or not. Having headers means less guessing of the format.
has_header = csv.Sniffer().has_header(csv_file.read(2048))
csv_file.seek(0)
reader = (
csv.DictReader(csv_file) if has_header else csv.reader(csv_file)
)
matches = [
# we rely on humanize and humanizable to work out the parsing of the individual .csv rows
humanize(row)
for row in reader
if row
and (has_header or humanizable(row))
and matching_filename(image_filename, row)
]
return matches[0] if matches else {}

View File

@@ -8,6 +8,9 @@ from .format import compact, humanize
def displayable_metadata(image_filename: str) -> dict:
if not os.path.isfile(image_filename):
return {"source": "missing", "parameters": {}}
pil_image = Image.open(image_filename)
# we have PNG generation parameters (preferred, as it's what the txt2img dropzone reads,

View File

@@ -50,7 +50,22 @@ PARAMS_FORMATS = {
},
}
PARAMS_FORMAT_LONGEST = PARAMS_FORMATS[max(PARAMS_FORMATS.keys())]
PARAMS_FORMAT_CURRENT = {
"VARIANT": "Model",
"VAE": "VAE",
"LORA": "LoRA",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"HEIGHT": "Height",
"WIDTH": "Width",
"MAX_LENGTH": "Max Length",
"OUTPUT": "Filename",
}
def compact(metadata: dict) -> dict:
@@ -76,6 +91,18 @@ def compact(metadata: dict) -> dict:
else:
result["Hires resize"] = f"{hires_y}x{hires_x}"
# remove VAE if it exists and is empty
if (result.keys() & {"VAE"}) and (
not result["VAE"] or result["VAE"] == "None"
):
result.pop("VAE")
# remove LoRA if it exists and is empty
if (result.keys() & {"LoRA"}) and (
not result["LoRA"] or result["LoRA"] == "None"
):
result.pop("LoRA")
return result
@@ -97,19 +124,20 @@ def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
)
# For dictionaries we try to use the matching length parameter format if
# available, otherwise we use the longest. Then we swap keys in the
# metadata that match keys in the format for the friendlier name that we
# have set in the format value
# available, otherwise we just use the current format which is assumed to
# have everything currently known about. Then we swap keys in the metadata
# that match keys in the format for the friendlier name that we have set
# in the format value
if isinstance(metadata, dict):
if humanizable(metadata, includes_filename):
format = PARAMS_FORMATS[lookup_key]
else:
format = PARAMS_FORMAT_LONGEST
format = PARAMS_FORMAT_CURRENT
return {
format[key]: value
for (key, value) in metadata.items()
if key in format.keys()
format[key]: metadata[key]
for key in format.keys()
if key in metadata.keys() and metadata[key]
}
raise TypeError("Can only humanize parameter lists or dictionaries")

View File

@@ -62,6 +62,82 @@ def parse_generation_parameters(x: str):
return res
def try_find_model_base_from_png_metadata(
file: str, folder: str = "models"
) -> str:
custom = ""
# Remove extension from file info
if file.endswith(".safetensors") or file.endswith(".ckpt"):
file = Path(file).stem
# Check for the file name match with one of the local ckpt or safetensors files
if Path(get_custom_model_pathfile(file + ".ckpt", folder)).is_file():
custom = file + ".ckpt"
if Path(
get_custom_model_pathfile(file + ".safetensors", folder)
).is_file():
custom = file + ".safetensors"
return custom
def find_model_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
png_hf_id = ""
png_custom = ""
if key in metadata:
model_file = metadata[key]
png_custom = try_find_model_base_from_png_metadata(model_file)
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
if model_file in predefined_models:
png_custom = model_file
# If nothing had matched, check vendor/hf_model_id
if not png_custom and model_file.count("/"):
png_hf_id = model_file
# No matching model was found
if not png_custom and not png_hf_id:
print(
"Import PNG info: Unable to find a matching model for %s"
% model_file
)
return png_custom, png_hf_id
def find_vae_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> str:
vae_custom = ""
if key in metadata:
vae_file = metadata[key]
vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae")
# VAE input is optional, should not print or throw an error if missing
return vae_custom
def find_lora_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
lora_hf_id = ""
lora_custom = ""
if key in metadata:
lora_file = metadata[key]
lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
# If nothing had matched, check vendor/hf_model_id
if not lora_custom and lora_file.count("/"):
lora_hf_id = lora_file
# LoRA input is optional, should not print or throw an error if missing
return lora_custom, lora_hf_id
def import_png_metadata(
pil_data,
prompt,
@@ -74,40 +150,21 @@ def import_png_metadata(
height,
custom_model,
hf_model_id,
custom_lora,
hf_lora_id,
custom_vae,
):
try:
png_info = pil_data.info["parameters"]
metadata = parse_generation_parameters(png_info)
png_hf_model_id = ""
png_custom_model = ""
if "Model" in metadata:
# Remove extension from model info
if metadata["Model"].endswith(".safetensors") or metadata[
"Model"
].endswith(".ckpt"):
metadata["Model"] = Path(metadata["Model"]).stem
# Check for the model name match with one of the local ckpt or safetensors files
if Path(
get_custom_model_pathfile(metadata["Model"] + ".ckpt")
).is_file():
png_custom_model = metadata["Model"] + ".ckpt"
if Path(
get_custom_model_pathfile(metadata["Model"] + ".safetensors")
).is_file():
png_custom_model = metadata["Model"] + ".safetensors"
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
if metadata["Model"] in predefined_models:
png_custom_model = metadata["Model"]
# If nothing had matched, check vendor/hf_model_id
if not png_custom_model and metadata["Model"].count("/"):
png_hf_model_id = metadata["Model"]
# No matching model was found
if not png_custom_model and not png_hf_model_id:
print(
"Import PNG info: Unable to find a matching model for %s"
% metadata["Model"]
)
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
"Model", metadata
)
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
"LoRA", metadata
)
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
negative_prompt = metadata["Negative prompt"]
steps = int(metadata["Steps"])
@@ -115,12 +172,24 @@ def import_png_metadata(
seed = int(metadata["Seed"])
width = float(metadata["Size-1"])
height = float(metadata["Size-2"])
if "Model" in metadata and png_custom_model:
custom_model = png_custom_model
hf_model_id = ""
if "Model" in metadata and png_hf_model_id:
custom_model = "None"
hf_model_id = png_hf_model_id
if "LoRA" in metadata and lora_custom_model:
custom_lora = lora_custom_model
hf_lora_id = ""
if "LoRA" in metadata and lora_hf_model_id:
custom_lora = "None"
hf_lora_id = lora_hf_model_id
if "VAE" in metadata and vae_custom_model:
custom_vae = vae_custom_model
if "Prompt" in metadata:
prompt = metadata["Prompt"]
if "Sampler" in metadata:
@@ -149,4 +218,7 @@ def import_png_metadata(
height,
custom_model,
hf_model_id,
custom_lora,
hf_lora_id,
custom_vae,
)

View File

@@ -0,0 +1,88 @@
ARG IMAGE_NAME
FROM ${IMAGE_NAME}:12.2.0-runtime-ubuntu22.04 as base
ENV NV_CUDA_LIB_VERSION "12.2.0-1"
FROM base as base-amd64
ENV NV_CUDA_CUDART_DEV_VERSION 12.2.53-1
ENV NV_NVML_DEV_VERSION 12.2.81-1
ENV NV_LIBCUSPARSE_DEV_VERSION 12.1.1.53-1
ENV NV_LIBNPP_DEV_VERSION 12.1.1.14-1
ENV NV_LIBNPP_DEV_PACKAGE libnpp-dev-12-2=${NV_LIBNPP_DEV_VERSION}
ENV NV_LIBCUBLAS_DEV_VERSION 12.2.1.16-1
ENV NV_LIBCUBLAS_DEV_PACKAGE_NAME libcublas-dev-12-2
ENV NV_LIBCUBLAS_DEV_PACKAGE ${NV_LIBCUBLAS_DEV_PACKAGE_NAME}=${NV_LIBCUBLAS_DEV_VERSION}
ENV NV_CUDA_NSIGHT_COMPUTE_VERSION 12.2.0-1
ENV NV_CUDA_NSIGHT_COMPUTE_DEV_PACKAGE cuda-nsight-compute-12-2=${NV_CUDA_NSIGHT_COMPUTE_VERSION}
ENV NV_NVPROF_VERSION 12.2.60-1
ENV NV_NVPROF_DEV_PACKAGE cuda-nvprof-12-2=${NV_NVPROF_VERSION}
FROM base as base-arm64
ENV NV_CUDA_CUDART_DEV_VERSION 12.2.53-1
ENV NV_NVML_DEV_VERSION 12.2.81-1
ENV NV_LIBCUSPARSE_DEV_VERSION 12.1.1.53-1
ENV NV_LIBNPP_DEV_VERSION 12.1.1.14-1
ENV NV_LIBNPP_DEV_PACKAGE libnpp-dev-12-2=${NV_LIBNPP_DEV_VERSION}
ENV NV_LIBCUBLAS_DEV_PACKAGE_NAME libcublas-dev-12-2
ENV NV_LIBCUBLAS_DEV_VERSION 12.2.1.16-1
ENV NV_LIBCUBLAS_DEV_PACKAGE ${NV_LIBCUBLAS_DEV_PACKAGE_NAME}=${NV_LIBCUBLAS_DEV_VERSION}
ENV NV_CUDA_NSIGHT_COMPUTE_VERSION 12.2.0-1
ENV NV_CUDA_NSIGHT_COMPUTE_DEV_PACKAGE cuda-nsight-compute-12-2=${NV_CUDA_NSIGHT_COMPUTE_VERSION}
FROM base-${TARGETARCH}
ARG TARGETARCH
LABEL maintainer "SHARK<stdin@nod.com>"
# Register the ROCM package repository, and install rocm-dev package
ARG ROCM_VERSION=5.6
ARG AMDGPU_VERSION=5.6
ARG APT_PREF
RUN echo "$APT_PREF" > /etc/apt/preferences.d/rocm-pin-600
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends ca-certificates curl libnuma-dev gnupg \
&& curl -sL https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - \
&& printf "deb [arch=amd64] https://repo.radeon.com/rocm/apt/$ROCM_VERSION/ jammy main" | tee /etc/apt/sources.list.d/rocm.list \
&& printf "deb [arch=amd64] https://repo.radeon.com/amdgpu/$AMDGPU_VERSION/ubuntu jammy main" | tee /etc/apt/sources.list.d/amdgpu.list \
&& apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
sudo \
libelf1 \
kmod \
file \
python3 \
python3-pip \
rocm-dev \
rocm-libs \
rocm-hip-libraries \
build-essential && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
RUN groupadd -g 109 render
RUN apt-get update && apt-get install -y --no-install-recommends \
cuda-cudart-dev-12-2=${NV_CUDA_CUDART_DEV_VERSION} \
cuda-command-line-tools-12-2=${NV_CUDA_LIB_VERSION} \
cuda-minimal-build-12-2=${NV_CUDA_LIB_VERSION} \
cuda-libraries-dev-12-2=${NV_CUDA_LIB_VERSION} \
cuda-nvml-dev-12-2=${NV_NVML_DEV_VERSION} \
${NV_NVPROF_DEV_PACKAGE} \
${NV_LIBNPP_DEV_PACKAGE} \
libcusparse-dev-12-2=${NV_LIBCUSPARSE_DEV_VERSION} \
${NV_LIBCUBLAS_DEV_PACKAGE} \
${NV_CUDA_NSIGHT_COMPUTE_DEV_PACKAGE} \
&& rm -rf /var/lib/apt/lists/*
RUN apt install rocm-hip-libraries
# Keep apt from auto upgrading the cublas and nccl packages. See https://gitlab.com/nvidia/container-images/cuda/-/issues/88
RUN apt-mark hold ${NV_LIBCUBLAS_DEV_PACKAGE_NAME}
ENV LIBRARY_PATH /usr/local/cuda/lib64/stubs

View File

@@ -0,0 +1,41 @@
On your host install your Nvidia or AMD gpu drivers.
**HOST Setup**
*Ubuntu 23.04 Nvidia*
```
sudo ubuntu-drivers install
```
Install [docker](https://docs.docker.com/engine/install/ubuntu/) and the post-install to run as a [user](https://docs.docker.com/engine/install/linux-postinstall/)
Install Nvidia [Container and register it](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). In Ubuntu 23.04 systems follow [this](https://github.com/NVIDIA/nvidia-container-toolkit/issues/72#issuecomment-1584574298)
Build docker with :
```
docker build . -f Dockerfile-ubuntu-22.04 -t shark/dev-22.04:5.6 --build-arg=ROCM_VERSION=5.6 --build-arg=AMDGPU_VERSION=5.6 --build-arg=APT_PREF="Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600" --build-arg=IMAGE_NAME=nvidia/cuda --build-arg=TARGETARCH=amd64
```
Run with:
*CPU*
```
docker run -it docker.io/shark/dev-22.04:5.6
```
*Nvidia GPU*
```
docker run --rm -it --gpus all docker.io/shark/dev-22.04:5.6
```
*AMD GPUs*
```
docker run --device /dev/kfd --device /dev/dri docker.io/shark/dev-22.04:5.6
```
More AMD instructions are [here](https://docs.amd.com/en/latest/deploy/docker.html)

View File

@@ -24,13 +24,13 @@ def get_image(url, local_filename):
shutil.copyfileobj(res.raw, f)
def compare_images(new_filename, golden_filename):
def compare_images(new_filename, golden_filename, upload=False):
new = np.array(Image.open(new_filename)) / 255.0
golden = np.array(Image.open(golden_filename)) / 255.0
diff = np.abs(new - golden)
mean = np.mean(diff)
if mean > 0.1:
if os.name != "nt":
if os.name != "nt" and upload == True:
subprocess.run(
[
"gsutil",
@@ -39,7 +39,7 @@ def compare_images(new_filename, golden_filename):
"gs://shark_tank/testdata/builder/",
]
)
raise SystemExit("new and golden not close")
raise AssertionError("new and golden not close")
else:
print("SUCCESS")

View File

@@ -1,5 +1,6 @@
#!/bin/bash
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
IMPORTER=1 BENCHMARK=1 NO_BREVITAS=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python build_tools/stable_diffusion_testing.py --gen
python tank/generate_sharktank.py

View File

@@ -63,7 +63,14 @@ def get_inpaint_inputs():
open("./test_images/inputs/mask.png", "wb").write(mask.content)
def test_loop(device="vulkan", beta=False, extra_flags=[]):
def test_loop(
device="vulkan",
beta=False,
extra_flags=[],
upload_bool=True,
exit_on_fail=True,
do_gen=False,
):
# Get golden values from tank
shutil.rmtree("./test_images", ignore_errors=True)
model_metrics = []
@@ -81,6 +88,8 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
if beta:
extra_flags.append("--beta_models=True")
extra_flags.append("--no-progress_bar")
if do_gen:
extra_flags.append("--import_debug")
to_skip = [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
@@ -181,7 +190,14 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
"./test_images/golden/" + model_name + "/*.png"
)
golden_file = glob(golden_path)[0]
compare_images(test_file, golden_file)
try:
compare_images(
test_file, golden_file, upload=upload_bool
)
except AssertionError as e:
print(e)
if exit_on_fail == True:
raise
else:
print(command)
print("failed to generate image for this configuration")
@@ -200,6 +216,9 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
extra_flags.remove(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
if do_gen:
prepare_artifacts()
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
f.write(header)
@@ -218,15 +237,49 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
f.write(";".join(output) + "\n")
def prepare_artifacts():
gen_path = os.path.join(os.getcwd(), "gen_shark_tank")
if not os.path.isdir(gen_path):
os.mkdir(gen_path)
for dirname in os.listdir(os.getcwd()):
for modelname in ["clip", "unet", "vae"]:
if modelname in dirname and "vmfb" not in dirname:
if not os.path.isdir(os.path.join(gen_path, dirname)):
shutil.move(os.path.join(os.getcwd(), dirname), gen_path)
print(f"Moved dir: {dirname} to {gen_path}.")
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--device", default="vulkan")
parser.add_argument(
"-b", "--beta", action=argparse.BooleanOptionalAction, default=False
)
parser.add_argument("-e", "--extra_args", type=str, default=None)
parser.add_argument(
"-u", "--upload", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
"-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
"-g", "--gen", action=argparse.BooleanOptionalAction, default=False
)
if __name__ == "__main__":
args = parser.parse_args()
print(args)
test_loop(args.device, args.beta, [])
extra_args = []
if args.extra_args:
for arg in args.extra_args.split(","):
extra_args.append(arg)
test_loop(
args.device,
args.beta,
extra_args,
args.upload,
args.exit_on_fail,
args.gen,
)
if args.gen:
prepare_artifacts()

View File

@@ -0,0 +1,14 @@
import os
from sys import executable
import subprocess
from apps.language_models.scripts import vicuna
def test_loop():
precisions = ["fp16", "int8", "int4"]
devices = ["cpu"]
for precision in precisions:
for device in devices:
model = vicuna.UnshardedVicuna(device=device, precision=precision)
model.compile()
del model

View File

@@ -24,7 +24,9 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=100)
width=150,
height=100,
)
datasets, images, ds_w_prompts = get_datasets(args.gs_url)
prompt_data = dict()
@@ -37,7 +39,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
with gr.Row(elem_id="ui_body"):
# TODO: add ability to search image by typing
with gr.Column(scale=1, min_width=600):
image = gr.Image(type="filepath").style(height=512)
image = gr.Image(type="filepath", height=512)
with gr.Column(scale=1, min_width=600):
prompts = gr.Dropdown(

View File

@@ -56,3 +56,14 @@ for line in fileinput.input(path_to_lazy_loader, inplace=True):
)
else:
print(line, end="")
# For getting around timm's packaging.
# Refer: https://github.com/pyinstaller/pyinstaller/issues/5673#issuecomment-808731505
path_to_timm_activations = Path(
get_python_lib() + "/timm/layers/activations_jit.py"
)
for line in fileinput.input(path_to_timm_activations, inplace=True):
if "@torch.jit.script" in line:
print("@torch.jit._script_if_tracing", end="\n")
else:
print(line, end="")

View File

@@ -14,4 +14,5 @@ build-backend = "setuptools.build_meta"
[tool.black]
line-length = 79
include = '\.pyi?$'
exclude = "apps/language_models/scripts/vicuna.py"
extend-exclude = "apps/language_models/src/pipelines/minigpt4_pipeline.py"

View File

@@ -3,7 +3,7 @@
numpy>1.22.4
pytorch-triton
torchvision==0.16.0.dev20230322
torchvision
tabulate
tqdm
@@ -15,7 +15,7 @@ iree-tools-tf
# TensorFlow and JAX.
gin-config
tensorflow>2.11
tf-nightly
keras
#tf-models-nightly
#tensorflow-text-nightly

View File

@@ -16,10 +16,12 @@ parameterized
# Add transformers, diffusers and scipy since it most commonly used
transformers
diffusers @ git+https://github.com/huggingface/diffusers@e47459c80f6f6a5a1c19d32c3fd74edf94f47aa2
diffusers
#accelerate is now required for diffusers import from ckpt.
accelerate
scipy
ftfy
gradio==3.34.0
gradio
altair
omegaconf
safetensors
@@ -29,7 +31,14 @@ pytorch_lightning # for runwayml models
tk
pywebview
sentencepiece
py-cpuinfo
tiktoken # for codegen
joblib # for langchain
timm # for MiniGPT4
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
pyinstaller
# vicuna quantization
brevitas @ git+https://github.com/Xilinx/brevitas.git@dev

Some files were not shown because too many files have changed in this diff Show More