Compare commits

...

803 Commits

Author SHA1 Message Date
Ean Garvey
dea405285e Revert python version change and update actions/setup-python to v5 2024-08-08 16:50:15 -05:00
Ean Garvey
b1d2cb3bad Change python version in nightly .yml to 3.11.9 2024-08-08 16:48:41 -05:00
Ean Garvey
4759e808f2 Rest API support and cleanup 2024-08-08 11:37:53 -05:00
Ean Garvey
d5f37eaf20 Bump requirements 2024-06-17 18:16:44 -05:00
Ean Garvey
84bc1437a7 Strip sdxl-turbo options 2024-06-17 17:57:40 -05:00
Ean Garvey
83f424a83e Close advanced setting by default. 2024-06-06 13:22:57 -05:00
Ean Garvey
5b3b262359 Simplify ui further, add CLI option to load a default config 2024-06-06 13:21:25 -05:00
Ean Garvey
67b438eb9f take all ireert calls out of studio flow 2024-06-04 01:46:48 -05:00
Ean Garvey
4aa2d8b2a2 Purge shark/ directory, minimal ireert api usage for dynamically loaded plugins 2024-06-04 00:53:07 -05:00
Ean Garvey
dac7a29eef Purge unused code and patch out iree runtime handling from init 2024-06-03 18:00:05 -05:00
Ean Garvey
59600456be seed fixes 2024-06-02 16:25:16 -05:00
Ean Garvey
e514910202 Remove sdxl 30step config 2024-06-02 14:28:58 -05:00
Ean Garvey
33f6c312d9 limit steps to 2 (gives best results) 2024-06-02 14:25:29 -05:00
Ean Garvey
ab06047108 set a default 2024-06-02 14:23:40 -05:00
Ean Garvey
ac48b843e7 actually reduce steps 2024-06-02 14:00:45 -05:00
Ean Garvey
5f1b5e58d6 igpu dont error on device parse fail 2024-06-02 12:56:44 -05:00
Ean Garvey
6adae49d9b igpu restrictions 2024-06-02 12:51:53 -05:00
Ean Garvey
6abd9ff5cf Reduce available step options for turbo. 2024-06-02 11:41:23 -05:00
Ean Garvey
9957c96014 More noticeable seed changes 2024-06-02 11:39:21 -05:00
Ean Garvey
36b8c2fd6d disable pndm 2024-06-02 11:30:18 -05:00
Ean Garvey
9163c1fc50 small fixes 2024-06-02 11:28:37 -05:00
Ean Garvey
349e9f70fb Progress indicators 2024-06-02 10:18:09 -05:00
Ean Garvey
64e63e7130 znver4 device handling 2024-06-02 10:08:00 -05:00
Ean Garvey
ea8738fb1a Update SRT links 2024-06-02 09:50:09 -05:00
Ean Garvey
2a5bec3c4f Fixes for seed. 2024-06-02 09:46:22 -05:00
Ean Garvey
bb58b01d75 Switch to fixed steps, tweak config loading to prevent race condition 2024-06-01 20:15:53 -05:00
Ean Garvey
02285b33a4 More fixes for demo. 2024-06-01 19:46:52 -05:00
Ean Garvey
f9a1d35b59 Hide chatbot. 2024-06-01 14:24:37 -05:00
Ean Garvey
b1ca19a6e6 Cleanup for demo. 2024-06-01 13:42:51 -05:00
Ean Garvey
b5dea85808 Reduce UI for demos. 2024-06-01 12:00:22 -05:00
Ean Garvey
e75f96f2d7 fixup conditional 2024-06-01 12:00:11 -05:00
Ean Garvey
bf67e2aa3b Formatting 2024-06-01 11:59:10 -05:00
Ean Garvey
c088247aa1 Fix default configs, config loading, and add warnings/early returns for bad configs. 2024-06-01 11:58:51 -05:00
Ean Garvey
42abc6787d Small tweaks to ckpt processing, add tool to prefix params keys 2024-06-01 11:53:40 -05:00
Ean Garvey
26f80ccbbb Fixes to UI config defaults, config loading, and warnings. (#2153) 2024-05-31 18:14:27 -04:00
Ean Garvey
d2c3752dc7 Fix batch count and tweaks to chatbot. (#2151)
* Fix batch count

* Add button to unload models manually.

* Add compiled pipeline option

* Add brevitas to requirements

* Tweaks to chatbot

* Change script loading trigger
2024-05-31 18:48:28 +05:30
Ean Garvey
4505c4549f Force inlined weights on igpu for now, small fixes to chatbot (#2149)
* Add igpu and custom triple support.

* Small fixes to igpu, SDXL-turbo

* custom pipe loading

* formatting

* Remove old nodlogo import.
2024-05-30 11:40:42 -05:00
Gaurav Shukla
793495c9c6 [ui] Add AMD logo in shark studio
Signed-Off-by: Gaurav Shukla <gaurav.shukla@amd.com>
2024-05-30 21:43:15 +05:30
Ean Garvey
13e1d8d98a Add igpu and custom triple support. (#2148) 2024-05-29 17:39:36 -05:00
Ean Garvey
2074df40ad Point to nod fork of diffusers. (#2146) 2024-05-29 00:56:21 -05:00
Ean Garvey
7b30582408 Point to SRT links for windows. (#2145) 2024-05-29 01:20:30 -04:00
Ean Garvey
151195ab74 Add a few requirements for ensured parity with turbine-models requirements. (#2142)
* Add scipy to requirements.

Adds diffusers req and a note for torchsde.
2024-05-28 15:37:31 -05:00
Ean Garvey
8146f0bd2f Remove leftover merge conflict line from setup script. (#2141) 2024-05-28 11:04:45 -07:00
Ean Garvey
68e9281778 (Studio2) Refactors SD pipeline to rely on turbine-models pipeline, fixes to LLM, gitignore (#2129)
* Shark Studio SDXL support, HIP driver support, simpler device info, small fixes

* Fixups to llm API/UI and ignore user config files.

* Small fixes for unifying pipelines.

* Update requirements.txt for iree-turbine (#2130)

* Fix Llama2 on CPU (#2133)

* Filesystem cleanup and custom model fixes (#2127)

* Fix some formatting issues

* Remove IREE pin (fixes exe issue) (#2126)

* Update find links for IREE packages (#2136)

* Shark Studio SDXL support, HIP driver support, simpler device info, small fixes

* Abstract out SD pipelines from Studio Webui (WIP)

* Switch from pin to minimum torch version and fix index url

* Fix device parsing.

* Fix linux setup

* Fix custom weights.

---------

Co-authored-by: saienduri <77521230+saienduri@users.noreply.github.com>
Co-authored-by: gpetters-amd <159576198+gpetters-amd@users.noreply.github.com>
Co-authored-by: gpetters94 <gpetters@protonmail.com>
2024-05-28 13:18:31 -04:00
Ean Garvey
fd07cae991 Update find links for IREE packages (#2136) 2024-05-13 11:43:17 -05:00
gpetters94
6cb86a843e Remove IREE pin (fixes exe issue) (#2126)
* Diagnose a build issue

* Remove IREE pin

* Revert the build on pull request change
2024-04-30 12:27:30 -05:00
gpetters-amd
7db1612a5c Filesystem cleanup and custom model fixes (#2127)
* Initial filesystem cleanup

* More filesystem cleanup

* Fix some formatting issues

* Address comments
2024-04-30 11:18:33 -05:00
gpetters-amd
81d6e059ac Fix Llama2 on CPU (#2133) 2024-04-29 12:18:16 -05:00
saienduri
e003d0abe8 Update requirements.txt for iree-turbine (#2130)
* Update requirements.txt to iree-turbine creation

* Update requirements.txt

* Update requirements.txt

* Update requirements.txt
2024-04-29 12:28:14 -04:00
Quinn Dawkins
cf2513e7b1 Update IREE discord link (#2118)
Discord links for IREE were purged, so update the link on the readme.
2024-04-15 12:54:27 -07:00
Ean Garvey
60d8591e95 Change shark-turbine requirement target branch to main. (#2116) 2024-04-11 19:31:39 -04:00
gpetters-amd
ff91982168 Remove target env (#2114) 2024-04-08 16:52:45 -05:00
powderluv
a6a9e524c1 Drop linux nightly for now 2024-04-05 12:04:36 -07:00
powderluv
732df2e263 Updated signtool key 2024-04-05 12:01:42 -07:00
gpetters-amd
1ee16bd256 Fix the nightly build (#2111) 2024-04-05 19:22:33 +05:30
gpetters-amd
752d775fbd Fix a typo in the nightly build script (#2110) 2024-03-30 17:31:51 -07:00
gpetters-amd
4d1a6a204d Fix builder issue (#2109) 2024-03-30 16:21:55 -07:00
Ean Garvey
0eff62a468 (Studio 2.0) add Stable Diffusion features (#2037)
* (WIP): Studio2 app infra and SD API

UI/app structure and utility implementation.

- Initializers for webui/API launch
- Schedulers file for SD scheduling utilities
- Additions to API-level utilities
- Added embeddings module for LoRA, Lycoris, yada yada
- Added image_processing module for resamplers, resize tools,
  transforms, and any image annotation (PNG metadata)
- shared_cmd_opts module -- sorry, this is stable_args.py. It lives on.
  We still want to have some global control over the app exclusively
  from the command-line. At least we will be free from shark_args.
- Moving around some utility pieces.
- Try to make api+webui concurrency possible in index.py
- SD UI -- this is just img2imgUI but hopefully a little better.
- UI utilities for your nod logos and your gradio temps.

Enable UI / bugfixes / tweaks

* Studio2/SD: Use more correct LoRA alpha calculation (#2034)

* Updates ProcessLoRA to use both embedded LoRA alpha, and lora_strength
optional parameter (default 1.0) when applying LoRA weights.
* Updates ProcessLoRA to cover more dim cases.
* This bring ProcessLoRA into line with PR #2015 against Studio1

* Studio2: Remove duplications from api/utils.py (#2035)

* Remove duplicate os import
* Remove duplicate parse_seed_input function

Migrating to JSON requests in SD UI

More UI and app flow improvements, logging, shared device cache

Model loading

Complete SD pipeline.

Tweaks to VAE, pipeline states

Pipeline tweaks, add cmd_opts parsing to sd api

* Add test for SD

* Small cleanup

* Shark2/SD/UI: Respect ckpt_dir, share and server_port args (#2070)

* Takes whether to generate a gradio live link from the existing --share command
line parameter, rather than hardcoding as True.
* Takes server port from existing --server_port command line parameter, rather than
hardcoding as 11911.
* Default --ckpt_dir parameter to '../models'
* Use --ckpt_dir rather than hardcoding ../models as the base directory for
checkpoints, vae, and lora, etc
* Add a 'checkpoints' directory below --ckpt_dir to match ComfyUI folder structure.
Read custom_weights choices from there, and/or subfolders below there matching
the selected base model.
* Fix --ckpt_dir possibly not working correctly when an absolute rather than relative path
is specified.
* Relabel "Custom Weights" to "Custom Weights Checkpoint" in the UI

* Add StreamingLLM support to studio2 chat (#2060)

* Streaming LLM

* Update precision and add gpu support

* (studio2) Separate weights generation for quantization support

* Adapt prompt changes to studio flow

* Remove outdated flag from llm compile flags.

* (studio2) use turbine vmfbRunner

* tweaks to prompts

* Update CPU path and llm api test.

* Change device in test to cpu.

* Fixes to runner, device names, vmfb mgmt

* Use small test without external weights.

* HF-Reference LLM mode + Update test result to match latest Turbine. (#2080)

* HF-Reference LLM mode.

* Fixup test to match current output from Turbine.

* lint

* Fix test error message + Only initialize HF torch model when used.

* Remove redundant format_out change.

* Add rest API endpoint from LanguageModel API

* Add StreamingLLM support to studio2 chat (#2060)

* Streaming LLM

* Update precision and add gpu support

* (studio2) Separate weights generation for quantization support

* Adapt prompt changes to studio flow

* Remove outdated flag from llm compile flags.

* (studio2) use turbine vmfbRunner

* tweaks to prompts

* Update CPU path and llm api test.

* Change device in test to cpu.

* Fixes to runner, device names, vmfb mgmt

* Use small test without external weights.

* Formatting and init files.

* Remove unused import.

* Small fixes

* Studio2/SD/UI: Improve various parts of the UI for Stable Diffusion (#2074)

* Studio2/SD/UI: Improve various parts of the UI of Shark 2

* Update Gradio pin to 4.15.0.
* Port workarounds for Gradio >4.8.0 main container sizing from Shark 1.0.
* Move nod Logo out of the SD tab and onto the top right of the main tab bar.
* Set nod logo icon as the favicon (as current Shark 1.0).
* Create a tabbed right hand panel within the SD UI sized to the viewport height.
* Make Input Image tab 1 in the right hand panel.
* Make output images, generation log, and  generation buttons, tab 2 in the
right hand panel
* Make config JSON display, with config load, save and clear, tab 3 in the
right hand panel
* Make gallery  area of the Output tab take up all vertical space the other controls
on the tab do not.
* Tidy up the controls on the Config tab somewhat.

* Studio2/SD/UI: Reorganise inputs on Left Panel of SD tab

* Rename previously added Right Panel Output tab to 'Generate'.
* Move Batch Count, Batch Size, and Repeatable Seeds, off of Left Panel and onto 'Generate' Tab.
* On 'Generate' tab, rename 'Generate Image(s)' button to 'Start', and 'Stop Batch' button to 'Stop'. They are now below the Batch inputs on a Generate tab so don't need the specificity.
* Move Device, Low VRAM, and Precision inputs into their own 'Device Settings' Accordion control. (starts closed)
* Rename 'Custom Weights Checkpoint' to 'Checkpoint Weights'
* Move Checkpoint Weights, VAE Model, Standalone Lora Weights, and Embeddings Options controls, into their own 'Model Weights' Accordion control.  (starts closed)
* Move Denoising Strength, and Resample Type controls into their own 'Input Image Processing' Accordion. (starts closed)
* Move any remaining controls in the 'Advanced Options' Accordion directly onto the left panel, and remove the Accordion.
* Enable the copy button for all text boxes on the SD tab.
* Add emoji/unicode glyphs to all top level controls and Accordions on the SD Left Panel.
* Start with the 'Generate' as the initially selected tab in the SD Right Panel, working around Gradio issue #7805
* Tweaks to SD Right Tab Panel vertical height.

* Studio2/SD/UI: Sizing tweaks for Right Panel, and >1920 width

* Set height of right panel using vmin rather than vh, with explicit affordances
for fixed areas above and below.
* Port >1920 width Gradio >4.8 CSS workaround from Shark 1.0.

* Studio2/SD: Fix sd pipeline up to "Windows not supported" (#2082)

* Studio2/SD: Fix sd pipeline up to "Windows not supported"

A number of fixes to the SD pipeline as run from the UI, up until the point that dynamo
complains "Windows not yet supported for torch.compile".

* Remove separate install of iree-runtime and iree-compile in setup_venv.ps1, and rely on the
versions installed via the Turbine requirements.txt. Fixes #2063 for me.
* Replace any "None" strings with python None when pulling the config in the UI.
* Add 'hf_auth_token' param to api StableDiffusion class, defaulting to None, and then pass
that in to the various Models where it is required and wasn't already being done before.
* Fix clip custom_weight_params being passed to export_clip_model as "external_weight_file"
rather than "external_weights"
* Don't pass the non-existent "custom_vae" parameter to the Turbine Vae Model; instead
pass custom_vae as the "hf_model_id" if it is set. (this may be wrong in the custom vae
case, but stops the code *always* breaking).

* Studio2/SD/UI: Improve UI config None handling

* When populating the UI from a JSON Config set controls to "None" for null/None
values.
* When generating a JSON Config from the UI set props to null/None for controls
set to "None".
* Use null rather than the string 'None' in the default config

---------

Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>

* Studio2/SD/UI: Further sd ui pipeline fixes (#2091)

On Windows, this gets us all the way to failing in iree compile with the SD 2.1 base.

- Fix merge errors with sd right pane config UI tab.
- Remove non-requirement.txt install/build of torch/mlir/iree/SRT in setup_venv.ps1, fixing "torch.compile not supported on Windows" error.
- Fix gradio deprecation warning for `root=` FileExplorer kwarg.
- Comment out `precision` and `max_length` kwargs being passed to unet, as not yet supported on main Turbine branch. Avoids keyword argument error.

* Tweak compile-time flags for SD submodels.

* Small fixes to sd, pin mpmath

* Add pyinstaller spec and imports script.

* Fix the .exe (#2101)

* Fix _IREE_TARGET_MAP (#2103) (#2108)

- Change target passed to iree for vulkan from 'vulkan'
to 'vulkan-spirv', as 'vulkan' is not a valid value for
--iree-hal-target-backends with the current iree compiler.

Co-authored-by: Stefan Kapusniak <121311569+one-lithe-rune@users.noreply.github.com>

* Cleanup sd model map.

* Update dependencies.

* Studio2/SD/UI: Update gradio to 4.19.2 (sd-studio2) (#2097)

- Move pin for gradio from 4.15 -> 4.19.2 on the sd-studio2 branch

* fix formatting and disable explicit vulkan env settings.

---------

Co-authored-by: Stefan Kapusniak <121311569+one-lithe-rune@users.noreply.github.com>
Co-authored-by: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com>
Co-authored-by: gpetters-amd <159576198+gpetters-amd@users.noreply.github.com>
Co-authored-by: gpetters94 <gpetters@protonmail.com>
2024-03-29 18:13:21 -04:00
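Among the changes folded into the entry above is the ProcessLoRA alpha handling (#2034), which applies LoRA weights using both the alpha embedded in the LoRA file and an optional lora_strength multiplier. A minimal PyTorch sketch of that standard weight-merge calculation (tensor names and shapes here are illustrative, not the actual ProcessLoRA code):

```
import torch

def apply_lora(weight, lora_down, lora_up, alpha, lora_strength=1.0):
    # weight: (out, in); lora_down: (rank, in); lora_up: (out, rank).
    # alpha is the scaling value embedded in the LoRA file.
    rank = lora_down.shape[0]
    delta = (lora_up @ lora_down) * (alpha / rank) * lora_strength
    return weight + delta

base = torch.zeros(8, 4)
down, up = torch.randn(2, 4), torch.randn(8, 2)
merged = apply_lora(base, down, up, alpha=1.6, lora_strength=0.8)
print(merged.shape)  # torch.Size([8, 4])
```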
dependabot[bot]
5a5de545c9 Bump gradio from 3.34.0 to 4.19.2 in /dataset (#2093)
Bumps [gradio](https://github.com/gradio-app/gradio) from 3.34.0 to 4.19.2.
- [Release notes](https://github.com/gradio-app/gradio/releases)
- [Changelog](https://github.com/gradio-app/gradio/blob/main/CHANGELOG.md)
- [Commits](https://github.com/gradio-app/gradio/compare/v3.34.0...gradio@4.19.2)

---
updated-dependencies:
- dependency-name: gradio
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2024-03-28 10:01:26 -05:00
Stefan Kapusniak
58f194a450 Fix _IREE_TARGET_MAP (#2103)
- Change target passed to iree for vulkan from 'vulkan'
to 'vulkan-spirv', as 'vulkan' is not a valid value for
--iree-hal-target-backends with the current iree compiler.
2024-03-18 00:21:44 -05:00
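For context on the fix above: the target map is a lookup from SHARK device names to IREE backend identifiers passed to --iree-hal-target-backends. A minimal sketch of such a map; only the vulkan entry reflects this fix, and the other entries, like the helper name, are illustrative assumptions rather than the exact contents of SHARK's _IREE_TARGET_MAP:

```
# Illustrative device-to-backend map; not the repo's actual _IREE_TARGET_MAP.
_IREE_TARGET_MAP = {
    "cpu": "llvm-cpu",
    "cuda": "cuda",
    "rocm": "rocm",
    "vulkan": "vulkan-spirv",  # 'vulkan' alone is rejected by --iree-hal-target-backends
}

def iree_target_backend(device: str) -> str:
    """Map a SHARK device string (e.g. 'vulkan://0') to an IREE target backend."""
    key = device.split("://")[0]
    try:
        return _IREE_TARGET_MAP[key]
    except KeyError:
        raise ValueError(f"Unknown device '{device}'")

print(iree_target_backend("vulkan://0"))  # vulkan-spirv
```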
Stefan Kapusniak
c5cf005292 Add *.safetensors to .gitignore (#2089)
- shark2 is putting base model .safetensors files in model-specific subfolders. Easiest to just ignore
.safetensors completely.
2024-02-17 21:37:47 -06:00
Stefan Kapusniak
12094ec49c Update README to recommend SHARK-1.0 for now (#2087)
- Add a note at the top of the README to use SHARK-1.0 whilst the
rewrite with Turbine is ongoing.
- Update installation section to point at SHARK-1.0 branch.
- Suggest the latest SHARK-1.0 pre-release as well as stable.
- Recommend running the .exe from the command line.
2024-02-08 14:22:27 -06:00
Daniel Garvey
100e5b8244 address refactor in turbine (#2086)
python/turbine_models -> models
shark-turbine -> core
2024-02-05 13:05:01 -08:00
Stanley Winata
6bf51f1f1d HF-Reference LLM mode + Update test result to match latest Turbine. (#2080)
* HF-Reference LLM mode.

* Fixup test to match current output from Turbine.

* lint

* Fix test error message + Only initialize HF torch model when used.

* Remove redundant format_out change.
2024-02-01 11:46:22 -06:00
Ean Garvey
05b498267e Add StreamingLLM support to studio2 chat (#2060)
* Streaming LLM 

* Update precision and add gpu support

* (studio2) Separate weights generation for quantization support

* Adapt prompt changes to studio flow

* Remove outdated flag from llm compile flags.

* (studio2) use turbine vmfbRunner

* tweaks to prompts

* Update CPU path and llm api test.

* Change device in test to cpu.

* Fixes to runner, device names, vmfb mgmt

* Use small test without external weights.
2024-01-18 19:01:07 -06:00
Ean Garvey
fa95ed30d1 Relocate quantized matmul reassociation flag (#2047)
* Remove quantized matmul reassociation flag

This flag should be a model/use-case specific addition, not a default CPU compile flag.
2023-12-20 12:48:40 -08:00
Daniel Garvey
788cc9157c Remove SHARK 1.0 implementations (#2042)
Any reimplementation of these features should be tracked in https://github.com/nod-ai/SHARK/issues/1931.
These implementations are preserved in the SHARK-1.0 branch: https://github.com/nod-ai/SHARK/tree/SHARK-1.0
2023-12-19 11:47:18 -06:00
Daniel Garvey
ebfcfec338 remove shark 1.0 tests, add support for 2.0 llm
* add support for external weights

* add tests and edit deps
2023-12-14 21:44:37 -06:00
Stefan Kapusniak
f692a012e1 UI: Fixes for Gradio 4.7.1/4.8.0 update (#2024)
* Upgrade Gradio pin from 4.7.1 to 4.8.0.
* Make Nod AI logos visible again.
* Remove image toolbars from png import boxes.
* Set Input Images on img2img, outpaint and upscaler tabs to be upload
only.
* Change Image control to an ImageEditor control for masking on the
inpaint tab. Remove previous height restriction as this hides the
editing controls.
* Move Input Image/Masked Image on img2img, inpaint, outpaint and
upscaler tabs to be the first control on their tabs.
* Remove download buttons from all galleries as they download some
html rather than the image (gradio issue #6595)
* Remove add new row and column from Output Gallery parameters
dataframe.
* Add partial workaround for not being able to select text in the Output
Gallery parameters dataframe (gradio issue #6086)
* Fix uglified formatting of subdirectory selection dropdown, refresh
button, and open folder buttons on the Output Gallery tab.
* Force Output Gallery to use the full width of the Gallery control
for the preview overlay when an image is selected, rather than
an overlay the width of the selected image.
* Fix sendto buttons.
* Reset Inpaint ImageEditor control with the Mask Layer after generation
is complete, as it gets lost if the image was sent to the tab from
another tab rather than being uploaded. Also rework queuing and
progress rendering along this codepath. This doesn't solve the
underlying problem of the Mask Layer being removed, but does get inpaint
fully working with the Gradio update.
2023-12-14 14:56:37 -06:00
Vivek Khandelwal
3cc643b2de Add support for StableLM-3B model (#2019)
* Add support for StableLM-3B model

* Add support for Quantized StableLM-3B model

* Update stablelm_pipeline.py
2023-12-12 22:39:50 +05:30
Phaneesh Barwaria
bf70e80d20 vulkan device id fix (#2028) 2023-12-08 19:00:26 -06:00
Ean Garvey
7159698496 (Studio) Fix controlnet switching. (#2026)
* Fix controlnet switching.

* Fix txt2img + control adapters
2023-12-07 00:52:36 -06:00
gpetters94
7e12d1782a Fix stencil pipeline to use input image (#2027) 2023-12-07 00:25:18 -06:00
Ean Garvey
bb5f133e1c Many UI fixes and controlnet improvements (#2025)
* multi-controlnet UI and perf fixes

* Controlnet fixes
2023-12-06 20:10:06 -06:00
Richard Pastirčák
3af0c6c658 #1843 - Add Export Default settings button (#2016)
* #1843 - Add Export Default settings button

* #1843 reformatting unit tests

---------

Co-authored-by: Richard Pastirčák <richard.pastircak@student.tuke.sk>
2023-12-06 14:58:17 -06:00
Ean Garvey
3322b7264f (vicuna.py) Move enable_tracy_tracing outside of BenchmarkRunInfo (#2011) 2023-12-06 14:57:32 -06:00
Ean Garvey
eeb7bdd143 Fix nodlogo (#2023) 2023-12-06 14:57:16 -06:00
Ean Garvey
2d6f48821d Fix SharkEulerDiscrete (#2022) 2023-12-06 12:25:06 -06:00
Gaurav Shukla
c74b55f24e [ui] Add UI for sharding
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-12-06 17:25:49 +05:30
Elias Joseph
1a723645fb finalized fixes for sharded llama2 2023-12-06 15:35:29 +05:30
Eliasj42
dfdd3b1f78 improved sharded performance and fixed issue with lmhead on rocm (#2008)
* improved sharded performance and fixed issue with lmhead on rocm

* mmap shards + disable sharing of device arrays across devices

* fix device_idx for non-layer vmfbs

* fix time calc for sharded

---------

Co-authored-by: Elias Joseph <elias@nod-labs.com>
Co-authored-by: PhaneeshB <b.phaneesh@gmail.com>
2023-12-05 11:53:44 -08:00
Ean Garvey
6384780d16 Fixes to llama2 cpu compilation and studio UI, schedulers (#2013)
* Fix some issues with defaults

Fixes to llama2 cpu compilation (turns off data tiling for old argmax
mode)

---------

Co-authored-by: Max Dawkins <max.dawkins@gmail.com>
2023-12-05 11:19:19 -05:00
gpetters94
db0c53ae59 Fix zoedepth (#2010) 2023-12-05 04:31:50 -05:00
Ean Garvey
ce9ce3a7c8 (SD) Fix schedulers and multi-controlnet. (#2006)
* (SD) Fixes schedulers if receiving noise preds as numpy arrays

* Fix schedulers and stencil name

* Multicontrolnet fixes
2023-12-05 03:29:18 -06:00
Ean Garvey
d72da3801f (Studio) Update gradio and multicontrolnet UI. (#2001)
* (Studio) Update gradio and multicontrolnet UI.

* Fixes for outputgallery, exe build

* Fix image return types.

* Update Gradio to 4.7.1

* Fix send buttons and hiresfix

* Various bugfixes and SDXL additions.

* More UI fixes and txt2img_sdxl presets.

* Enable SDXL-Turbo and custom models, custom VAE for sdxl

* img2img ui tweaks
2023-12-04 12:37:51 -06:00
Eliasj42
9c50edc664 fixed functionality of sharded vicuna/llama2 (#1982)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-12-04 09:11:52 -08:00
Abhishek Varma
a1b7110550 [SDXL] Add SDXL pipeline to SHARK (#1941)
* [SDXL] Add SDXL pipeline to SHARK

-- This commit adds SDXL pipeline to SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* (SDXL) Fix --ondemand and vae scale factor use, and fix VAE flags.

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-12-02 03:15:15 -06:00
gpetters94
ff15fd74f6 Add multicontrolnet (#1958) 2023-12-01 13:51:20 -06:00
gpetters94
552b2c3ee3 Add controlmode (#1957) 2023-12-01 13:04:47 -06:00
Ean Garvey
795fc33001 Update default compilation flags for data tiling. (#2000)
* Update default CPU compilation flags.

c5a6cdc8dd

52eb7e9b82

tweak CPU iree-compile flags to match upstream changes.

* Add an option for data tiling on SD models.
2023-11-30 17:05:37 -06:00
gpetters94
2910841fe6 Fix an importer issue on Linux (#1986) 2023-11-30 10:50:33 -06:00
Vivek Khandelwal
396a054856 Fix Sharded Falcon-180b 2023-11-30 21:51:57 +05:30
Vivek Khandelwal
5c66948d4f Fix unsharded Falcon pipeline 2023-11-30 21:51:57 +05:30
Ean Garvey
ed3dda94c0 Cleanup xfails in pytest suite. (#1995) 2023-11-29 23:16:15 -06:00
Quinn Dawkins
d31d28b082 [SD] Add flag to collapse reduction dims pre dispatch formation (#1999) 2023-11-30 00:09:17 -05:00
Evan Ruttenberg
78c607e1d3 Fix typo in default_rocm_arch (#1998) 2023-11-29 20:40:56 -05:00
Vivek Khandelwal
666e601dd9 Remove sharding support for non-180B falcon variants 2023-11-27 13:45:13 +05:30
Vivek Khandelwal
ca58908e5b Add Falcon-GPTQ Support for 2-way sharding 2023-11-27 13:45:13 +05:30
Jakub Kuderski
1f5b39f56e [vicuna.py] Add option to enable tracing (#1993)
This makes the program wait for tracy profiler to connect before exiting
and flush profiling data after each token.

I don't know how to select the tracy iree-runtime variant
programmatically -- instead, print an error and exit.
2023-11-24 12:25:03 -08:00
Jakub Kuderski
2da31c4109 [vicuna.py] Rework benchmark statistics calculation (#1992)
- Move statistics out of the main loop
- Add 'end-to-end' numbers
- Switch the main display unit from s to ms
- Start measuring time at 0

The new print format looks like this:
```
Number of iterations: 5
Num tokens: 1 (prompt), 512 (generated), 513 (total)
Prefill: avg. 0.01 ms (stdev 0.00), avg. 97.99 tokens/s
Decode: avg. 4840.44 ms (stdev 28.80), avg. 97.99 tokens/s
Decode end-2-end: avg. 85.78 tokens/s (w/o prompt), avg. 95.98 (w/ prompt)
```
2023-11-23 12:04:03 -05:00
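The aggregation behind a report like the one above is straightforward; a minimal sketch of reducing per-iteration timings to averages, standard deviations, and tokens/s (illustrative only, not the actual vicuna.py benchmark code):

```
from statistics import mean, stdev

def summarize(prefill_ms, decode_ms, num_generated_tokens):
    """Reduce per-iteration timings (milliseconds) to printable statistics."""
    decode_tok_s = [num_generated_tokens / (ms / 1000.0) for ms in decode_ms]
    return {
        "prefill_avg_ms": mean(prefill_ms),
        "prefill_stdev_ms": stdev(prefill_ms),
        "decode_avg_ms": mean(decode_ms),
        "decode_stdev_ms": stdev(decode_ms),
        "decode_avg_tok_s": mean(decode_tok_s),
    }

# Hypothetical timings for 3 iterations generating 512 tokens each.
print(summarize([0.010, 0.012, 0.011], [4810.2, 4866.9, 4844.3], 512))
```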
Ean Garvey
da50a16242 Create specified dir if needed during save_mlir and fix vulkan device fetching without URI/ID (#1989) 2023-11-23 01:01:41 -06:00
Stefan Kapusniak
ce38d49f05 Add .mlir to startup shark_tmp cleanup (#1991)
* Add .mlir to the files that are deleted from `./shark_tmp` when studio
is started.
* refactor/rename existing gradio temp file cleanup on startup to be
consistent with a general `./shark_tmp` cleanup
2023-11-22 14:34:28 -06:00
PhaneeshB
2f780f0d38 quick fix rocm None device 2023-11-22 21:17:25 +05:30
Ean Garvey
d051c3a4a7 Use clean_device_info() by default and don't write .mlir to /tmp/ (#1984)
* Move clean_device_info to compile_utils

* Update compile_utils.py

* Fix .mlir writes for some user-level permissions

* Fix cases where full URI is given

* Fix conditionals.

* Fix device path handling in vulkan utils.
2023-11-20 13:10:31 -06:00
Ean Garvey
1b11c82c9d Small UI tweaks for chatbot, fix torchvision requirements (#1988)
- add torchvision to setup_venv.ps1 -- we need this for the torchvision::nms that is now a dependency of controlnet features.
- Don't have bad flashy orange updates when using the chatbot
- Don't limit the height of the chatbot -- there's mixed opinions and solutions around this one. I think the default (400) is just way too small and LLMs generate plenty enough to justify matching the output.
2023-11-21 00:09:10 +05:30
gpetters94
80a33d427f Save intermediate values of controlnet (#1981) 2023-11-17 19:05:41 -05:00
Stefan Kapusniak
4125a26294 API/Docs: Fix incorrect cors arguments listing (#1983)
* Replace `api_cors_origin` in the api/koboldcpp doc, with the correct
 `api_accept_origin`
2023-11-17 12:29:01 -06:00
Ean Garvey
905d0103ff Revert "Re-enable SD tunings without matmuls. (#1976)" (#1979)
This reverts commit 70817bb50a.
2023-11-17 23:44:33 +05:30
Stefan Kapusniak
192b3b2c61 UI: Output gallery cleanups (#1959)
* Workaround gradio bug that causes the parameters frame to always show
scrollbars.
* Remove the original funky method of setting the number of image
columns in the gallery using _fn= javascript events. The version
of gradio we now have pinned allows doing this by setting the property
on the gallery directly and also doesn't keep resetting the columns on
other events being fired.
2023-11-15 22:20:42 -06:00
Stefan Kapusniak
8f9adc4a2a UI: Display top tag frequencies for selected LoRA (#1972)
* Adds a function to webui utils to read metadata from
.safetensors LoRA files and do limited parsing of the format written
out by the Kohya SS scripts (https://github.com/kohya-ss/sd-scripts)
to get tag frequency and trained model information.
* Adds a new common_ui_events.py file for gradio event handlers
needed for multiple UI tabs, and adds an event handler for binding to
the change event of the LoRA selection boxes, that outputs HTML
to display the LoRA tag frequency and model information.
* Adds an HTML gradio control to each of the SD tabs to show the
LoRA model name, and most frequently trained tags.
* Bind the change event of the LoRA selection box on each tab
to our new event handler, with the output set to the relevant HTML
control.
2023-11-15 22:19:54 -06:00
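For reference on the entry above: the .safetensors format carries file-level metadata in its JSON header, which the safetensors Python API exposes. A minimal sketch of pulling Kohya-style tag frequencies out of a LoRA file; the 'ss_tag_frequency' key is an assumption about the Kohya sd-scripts format, and error handling is omitted:

```
import json
from collections import Counter
from safetensors import safe_open

def lora_top_tags(path, top_n=10):
    """Return the most frequent training tags recorded in a LoRA .safetensors file."""
    with safe_open(path, framework="pt") as f:
        metadata = f.metadata() or {}
    # Kohya sd-scripts store tag counts as a JSON string keyed by dataset name
    # (assumed key name; adjust if the file uses a different layout).
    freq = json.loads(metadata.get("ss_tag_frequency", "{}"))
    counts = Counter()
    for dataset_tags in freq.values():
        counts.update(dataset_tags)
    return counts.most_common(top_n)
```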
Ean Garvey
70817bb50a Re-enable SD tunings without matmuls. (#1976) 2023-11-15 20:42:53 -06:00
jinchen62
dd37c26d36 Update brevitas quant api (#1975) 2023-11-15 10:04:07 -08:00
PhaneeshB
a708879c6c fix iree version mismatch 2023-11-15 01:24:42 +05:30
Ean Garvey
bb1b49eb6f Add --no-index to setup_venv.sh runtime pip install. 2023-11-14 21:44:20 +05:30
Ean Garvey
f6d41affd9 (SHARK Studio) Add Turbine-based llm chatbot. (#1933)
* Dan shark studio (#1970)

* Fix issue in Falcon-GPTQ

* initial webui and llama2

---------

Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>

* Fix formatting.

---------

Co-authored-by: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2023-11-14 09:56:28 -06:00
Stefan Kapusniak
c2163488d8 SD/UI Restrict hires fix/img2img resamplers/schedulers (#1955)
* Restrict resamplers for img2img and high res fix to the ones that
PIL.Image actually supports, since it uses that to do the resampling.
Removed: Antialias, Affine, Cubic. Added: Hamming.
* Set list of available schedulers to CPU only when high res fix
is selected in the web ui. Set list to all schedulers when high res fix
is deselected.
* Put hi res fix in its own Accordion in the txt2img UI instead of
grouping it with Advanced Options.
2023-11-13 16:08:24 -06:00
PhaneeshB
54bff4611d fix cli rocm device selection 2023-11-13 23:35:55 +05:30
PhaneeshB
11510d5111 add intra rocm vmfb differentiator 2023-11-13 23:35:55 +05:30
PhaneeshB
32cab73a29 add iree-rocm-target-chip only if added by user 2023-11-13 23:35:55 +05:30
PhaneeshB
392bade0bf enable non default rocm device selection for webui 2023-11-13 23:35:55 +05:30
Stefan Kapusniak
91df5f0613 API/Docs: Fix an image link in koboldcpp doc (#1954)
* Fix the image link for the koboldcpp style button pointing to the
dialog image rather than the button image.
2023-11-13 11:14:29 -06:00
dependabot[bot]
df20cf9c8a Bump langchain in /apps/language_models/langchain (#1968)
Bumps [langchain](https://github.com/langchain-ai/langchain) from 0.0.325 to 0.0.329.
- [Release notes](https://github.com/langchain-ai/langchain/releases)
- [Commits](https://github.com/langchain-ai/langchain/compare/v0.0.325...v0.0.329)

---
updated-dependencies:
- dependency-name: langchain
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-12 19:46:00 -08:00
Ean Garvey
c4a908c3ea Pin pydantic to 2.4.1 in requirements (#1967)
pyinstaller-hooks-contrib doesn't see beta versions of pydantic as versions greater than 2.0.0, and so it looks for an attribute `compile` only available in versions older than 2.0.0 if you have a beta version of pydantic.
2023-11-10 21:34:52 -06:00
Stefan Kapusniak
6285430d8a UI: Fix webui launch on non-Windows (#1963)
* Moves the imports of winreg and Tk, into the functions that use them,
with winreg behind a guard clause. This should hopefully mean that if
you're not on Windows or not using `ui=app` we won't trip over either
of these due to them not being there.
2023-11-10 16:38:32 -06:00
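The guard described above amounts to deferring platform-specific imports until they are actually needed. A small, illustrative sketch of the pattern (function and fallback names are assumptions, not the actual SHARK code):

```
import sys

def launch_app_window(url: str):
    """Open the studio UI in an embedded window, falling back to a browser."""
    if sys.platform != "win32":
        # Not on Windows: never touch winreg, just use the default browser.
        import webbrowser
        webbrowser.open(url)
        return
    import winreg  # only imported on Windows, inside the function
    # ... query the registry for the WebView2 runtime here, and fall back to
    # the browser path above if it is not installed ...
```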
PhaneeshB
51afe19e20 fix rocm arch selection 2023-11-10 13:22:51 +05:30
Ean Garvey
31005bcf73 Don't require vulkan installation to query devices. (#1953) 2023-11-09 14:46:44 -06:00
dependabot[bot]
f41ad87ef6 Bump langchain in /apps/language_models/langchain (#1926)
Bumps [langchain](https://github.com/langchain-ai/langchain) from 0.0.202 to 0.0.325.
- [Release notes](https://github.com/langchain-ai/langchain/releases)
- [Commits](https://github.com/langchain-ai/langchain/compare/v0.0.202...v0.0.325)

---
updated-dependencies:
- dependency-name: langchain
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-09 11:03:47 -06:00
dependabot[bot]
d811524a00 Bump pypdf from 3.12.2 to 3.17.0 in /apps/language_models/langchain (#1929)
Bumps [pypdf](https://github.com/py-pdf/pypdf) from 3.12.2 to 3.17.0.
- [Release notes](https://github.com/py-pdf/pypdf/releases)
- [Changelog](https://github.com/py-pdf/pypdf/blob/main/CHANGELOG.md)
- [Commits](https://github.com/py-pdf/pypdf/compare/3.12.2...3.17.0)

---
updated-dependencies:
- dependency-name: pypdf
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-09 11:02:43 -06:00
Sungsoon Cho
51e1bd1c5d (OPT) Fix typo in the message; s/reponse/response (#1920) 2023-11-09 11:00:48 -06:00
Phaneesh Barwaria
db89b1bdc1 Fix MacOS web execution flow (#1899)
* fix metal device path for chatbot

* single device remove indexing

* lint fix
2023-11-09 10:59:29 -06:00
Huang Qi
2754e2e257 Fix wrong parameter index passed to 'compile_module_to_flatbuffer' (#1921)
compile_str is always False in compile_module_to_flatbuffer since there
is a parameter 'model_name' before 'debug'.

This issue is related to https://github.com/nod-ai/SHARK/pull/1863.

Then we can use the mlir model buffer in RAM to run inference.
2023-11-09 10:58:05 -06:00
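The bug class fixed above is worth spelling out: when a new positional parameter is inserted ahead of an existing one, call sites that pass arguments by position silently shift. A generic illustration; the signature below is hypothetical, not SHARK's actual compile_module_to_flatbuffer:

```
# Hypothetical signature: a 'model_name' parameter now sits before 'debug' and 'compile_str'.
def compile_module_to_flatbuffer(module, device, frontend, model_name,
                                 debug=False, compile_str=False):
    return f"{model_name}: debug={debug}, compile_str={compile_str}"

# A positional call written against the old signature (without 'model_name'):
# the True intended for 'compile_str' shifts one slot left, so compile_str stays False.
print(compile_module_to_flatbuffer("mod", "cpu", "torch", False, True))

# Keyword arguments make the call robust to parameter insertions.
print(compile_module_to_flatbuffer("mod", "cpu", "torch", "opt", compile_str=True))
```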
PhaneeshB
ab0e870c43 fix vicuna cli vulkan 2023-11-09 22:27:13 +05:30
Stefan Kapusniak
fb30e8c226 UI: Fix some webui launch corner cases (#1952)
* On Windows, insist on the presence of webview2 as the embeddable
browser for `ui=app`. If we can't find it, effectively switch back to
`ui=web`. This should prevent pywebview trying to use MSHTML, whilst
saying it's deprecated, and apparently we are too much for poor old IE11
* Add webview2 runtime droppings to .gitignore.
* If we can't bind to args.server_port get another suitable port from
the OS and advise the user that we did this in the UI.
* Make `ui=web` mode use 'SHARK AI Studio' as its title. This makes it
consistent with `ui=app`.
* Replace the generic gradio favicon with a nod swirl one instead.
2023-11-09 10:53:28 -06:00
Ean Garvey
a07d542400 (Studio) Disable SD tunings and sub-model downloads (#1944)
* sets --no-use_tuned and --import_mlir as defaults in SHARK Studio.
2023-11-07 15:55:30 -06:00
Stefan Kapusniak
ad55cb696f SD/API: Add missing A1111 APIs to Shark to support koboldcpp image generation (#1924)
* SD/API: Add missing a1111 API features for Koboldcpp

* Refactors SD api functions into their own file
* Adds the following apis implemented by a1111 as needed by koboldcpp:
   - adds /sdapi/v1/sd-models (lists available models)
   - adds /sdapi/v1/options (only the bare minimum needed)
* Adds optional CORS support, use the '--api_accept_origin' command line
argument to activate and configure.
* Extends existing APIs to include optional sampler/scheduler selection
* Extends /sdapi/v1/txt2img to recognise the method used by koboldcpp
to select the model.
* Where possible take values not provided to the API in the request from
the existing relevant command line parameters rather than hardcoding
them.
* return a 400 response when a request doesn't have required properties.
* changed default schedulers and models for some apis to ones that
actually seem to work.
* Update api_test.py to include the new APIs.
* Update api_test.py to include a '--verbose' command line option.

* SD/API: Take more API values from args

* Take LoRA from '--use_lora' command line arg if specified
* Take device from '--device' command line arg if specified (substring
match, so a short name such as 'vulkan://0' should work)

* SD/API: add more endpoints and pydantic typing

* Mount the whole of /sdapi from index.py as a FastAPI application,
rather than each endpoint individually
* Add the following additional API endpoints:
  * /sdapi/v1/samplers
  * /sdapi/v1/cmd-flags
* Make scheduler/sampler selection checking and fallback much more
robust.
* Support aliasing some A1111 scheduler/sampler names to the diffusers
ones we are using.
* Expand response /sdapi/v1/options to add a few more things.
* Split non-api functions and variables into their own utils.py file.
* Support 'n_iter' request property and the return of multiple images
from generation endpoints. Equivalent of '--batch_count', batch_size
is still hardcoded at 1
* Include (some) hires_fix request properties in txt2img endpoint
* Rework endpoints using pydantic model classes for better request
validation and so we get much improved swagger api docs at
/sdapi/docs and redoc at /sdapi/redoc

* SD/API Delete commented out code from index.py

* Delete some code that is no longer needed by the SD API in index.py
(and one line in sdapi_v1.py) that I'd previously only commented out.

* SD/UI: Add shark_sd_koboldcpp.md document

* Add documentation on how to set up Koboldcpp with SHARK
* Link this and the existing blender set up document from the main
README.md

* SD/API Improve stencil options in img2img endpoint

In /sdapi/v1/img2img:
  * Add zoedepth to the controlnet use_stencil options
  * Require and use second image as stencil mask for controlnet scribble
2023-11-06 15:20:19 -06:00
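Since the endpoints above follow the A1111 shapes, a client can call them the same way it would call A1111. A minimal sketch; the host/port and exactly which optional request fields SHARK honors are assumptions here:

```
import base64
import requests

# Request body follows the A1111 txt2img shape described in the entry above.
payload = {
    "prompt": "a photo of a corgi wearing sunglasses",
    "negative_prompt": "blurry",
    "steps": 20,
    "width": 512,
    "height": 512,
    "n_iter": 1,  # equivalent of --batch_count, per the notes above
}
resp = requests.post("http://127.0.0.1:8080/sdapi/v1/txt2img", json=payload, timeout=600)
resp.raise_for_status()
for i, img_b64 in enumerate(resp.json()["images"]):  # A1111-style list of base64 PNGs
    with open(f"txt2img_{i}.png", "wb") as f:
        f.write(base64.b64decode(img_b64))
```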
Jakub Kuderski
488a172292 [vicuna.py] Allow to pass extra arguments to iree-compile (#1935)
Add a new flag `-Xiree_compile` to forward extra compiler arguments to
`iree-compile`. This flag can be set multiple times to pass more than
one extra argument.
2023-11-06 12:12:34 -05:00
Stanley Winata
500c4f2306 [compile utils] Fix ROCM to not expect config.id as a default. (#1939) 2023-11-06 08:44:53 -08:00
Vivek Khandelwal
92b694db4d Add support for Falcon-40b-GPTQ 2023-11-06 19:49:19 +05:30
Vivek Khandelwal
322874f7f9 Fix issue in Falcon-GPTQ 2023-11-03 11:48:36 +05:30
Ean Garvey
5001db3415 Add 7800xt to target triples explicitly. (#1928) 2023-11-01 17:11:45 -05:00
Vivek Khandelwal
71846344a2 Add sharded Falcon-GPTQ support
This commit adds support for sharded Falcon-7b-GPTQ and
Falcon-180B-GPTQ. It also adds support for 4-way
sharding of the Falcon model on the ROCM device.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2023-11-01 12:11:44 +05:30
gpetters94
72e27c96fc Add ZoeDepth (#1834)
* Add ZoeDepth

* Add einops to Studio imports.

* Specify ref for forked torch.hub repos.

* Unpin timm.

---------

Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: Ean Garvey <garveyej@gmail.com>
2023-10-30 11:57:45 -05:00
PhaneeshB
7963abb8ec remove caching for rocm args 2023-10-29 07:07:57 +05:30
Ean Garvey
98244232dd Add smoothquant OPT to examples. (#1922) 2023-10-27 12:32:12 -05:00
PhaneeshB
679a452139 fix calls and remove unused imports for check_device_drivers 2023-10-27 10:30:40 +05:30
PhaneeshB
72c0a8abc8 remove dependency on external commands for driver installation check 2023-10-27 10:30:40 +05:30
Vivek Khandelwal
ea920f2955 Add sharded Falcon support 2023-10-26 21:53:25 +05:30
Phaneesh Barwaria
486202377a update dependency on rocm/hip info command (#1900)
* add support for rocm flags

* add rocm target flag to chat args

* rm rocm libs dependency message
2023-10-26 15:18:25 +05:30
Sungsoon Cho
0c38c33d0a Add opt_causallm_samples.py. (#1916) 2023-10-25 11:52:51 -05:00
Ean Garvey
841773fa32 Updates to opt_causallm example (#1905)
* Updates to opt_causallm example

* Fixup opt_perf_comparison.py

* Use same filenames across opt examples.
2023-10-24 10:54:39 -07:00
Stefan Kapusniak
0361db46f9 SD: Fix unet untuned opt_flags (#1912)
* correct my sloppy copy/paste for the untuned unet default compilation
flags that introduced an extra 'detach' into what should have been
'iree-global-opt-convert-1x1-filter-conv2d-to-matmul'
2023-10-24 12:47:33 -05:00
xzuyn
a012433ffd Save hiresfix info if used (#1914) 2023-10-24 12:45:10 -05:00
xzuyn
5061193da3 Move Generate, Randomize Seed, & Stop Batch to same positions as txt2img (#1915) 2023-10-24 12:44:39 -05:00
xzuyn
bff48924be LLaMa 2 Chat template fix (#1913) 2023-10-23 18:51:15 -05:00
Stefan Kapusniak
825b36cbdd Fix MLIR Textual PassPipeline Error (#1910) 2023-10-22 07:39:52 -07:00
Stefan Kapusniak
134441957d SD - Fix civitai download on Windows +improvements (#1907) 2023-10-21 11:17:41 -07:00
Stefan Kapusniak
7cd14fdc47 SD/UI: Use a single model selection box on UI tabs (#1906)
* Allow entry of a huggingface model id or civitai download url to be
done in the main model selection dropdown on SD tabs
* Remove separate textbox for entering huggingface model id or civitai
download url on SD Tabs
* Remove 'None' option from the model selection dropdown (no longer
needed) on SD tabs
* Update png metadata drop zone on txt2img tab to work with a single
argument for model selection
* Update UI generate functions on SD tabs to work with single argument
model selection
* Update API code for changes to the UI generate functions
* Move info about the custom model path to the logging textarea on SD
tabs
2023-10-21 10:06:05 -07:00
Ean Garvey
e6cb5cef57 Add --additional_runtime_args option and use in OPT example. (#1855)
* Add --additional_runtime_args option and use in OPT example.

Fix the func name. (#1838)

Co-authored-by: Sungsoon Cho <sungsoon.cho@gmail.com>
2023-10-19 13:29:39 -05:00
Huang Qi
66abee8e5b SharkInference: Fix various examples and README.md (#1903)
Following https://github.com/nod-ai/SHARK/pull/708, remove the parameter 'func_name'
for SharkInference.
2023-10-19 09:28:36 -05:00
Ean Garvey
4797bb89f5 Stringify path for ireec.compile_file (#1901)
* Stringify path for ireec.compile_file

* Update test-models.yml
2023-10-18 14:59:23 -05:00
Vivek Khandelwal
205e57683a Modify Falcon-180b-GPTQ sharded pipeline 2023-10-17 20:26:01 +05:30
Vivek Khandelwal
2866d665ee Fix Sharded Falcon-180b-GPTQ Pipeline 2023-10-17 20:26:01 +05:30
Stefan Kapusniak
71d25ec5d8 SD: Fix repeatable seeds when intial seed is random (#1893) 2023-10-14 22:50:42 -07:00
Vivek Khandelwal
202ffff67b Add support for sharded Falcon model 2023-10-13 22:05:10 +05:30
Ean Garvey
0b77059628 Add matmul reassociation flags (#1891) 2023-10-12 20:12:37 -05:00
Stefan Kapusniak
a208302bb9 Fix repeatable seeds consistency over batch counts (#1889)
* Set the input seed for the random number generator when
generating repeatable seeds to exclude any negative numbers
in the parsed seed input. This makes seeds generated for
different batch counts consistent where they have the same
input for the initial seed or set of seeds.
2023-10-12 17:15:19 -05:00
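A small sketch of the seeding scheme described above: derive the batch's seed list deterministically from the non-negative part of the parsed input, so the same initial seed(s) expand to the same sequence regardless of batch count. This is illustrative only, not the exact SHARK code:

```
import random

MAX_SEED = 2**32 - 1

def expand_seeds(parsed_seeds, batch_count):
    """Deterministically expand user-supplied seeds to one seed per batch item."""
    # Seed the RNG only from the non-negative entries (negative means "randomize").
    rng = random.Random(str([s for s in parsed_seeds if s >= 0]))
    seeds = []
    for i in range(batch_count):
        if i < len(parsed_seeds) and parsed_seeds[i] >= 0:
            seeds.append(parsed_seeds[i])            # keep explicit seeds as-is
        else:
            seeds.append(rng.randint(0, MAX_SEED))   # fill the rest repeatably
    return seeds

print(expand_seeds([42, -1], 4))  # seeds for smaller batch counts are a prefix of this
```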
Vivek Khandelwal
b83d32fafe Fix Falcon GPTQ Pipeline 2023-10-11 20:09:32 +05:30
Vivek Khandelwal
0a618e1863 Add support for Falcon GPTQ 2023-10-11 10:47:48 +05:30
Phaneesh Barwaria
a731eb6ed4 Macos fixes (#1883)
* fix venv setup for MacOS

* allow stream fuse binding on mac

* clean iree metal args
2023-10-09 23:36:12 -07:00
Ean Garvey
2004d16945 Revert "[SDXL] Add SDXL pipeline to SHARK (#1731)" (#1882)
This reverts commit 9f0a421764.
2023-10-09 18:01:44 -07:00
Gaurav Shukla
6e409bfb77 fix else if syntax error
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-10 06:23:56 +05:30
Gaurav Shukla
77727d149c [warning] Fix dropdown warning
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-10 05:18:43 +05:30
Ean Garvey
66f6e79d68 Split CPU/GPU definitions conditionally outside of torch contexts. (#1879) 2023-10-09 16:46:41 -07:00
Ean Garvey
3b825579a7 (LLaMa-2) Point to int4 + f32 acc .mlir for cpu (#1878)
- fixes some issues with non-system prompt invocation

Co-authored-by: Gaurav Shukla <gauravshukla789@gmail.com>
2023-10-09 14:37:35 -05:00
Abhishek Varma
9f0a421764 [SDXL] Add SDXL pipeline to SHARK (#1731)
-- This commit adds SDXL pipeline to SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-10-09 13:01:37 -05:00
Gaurav Shukla
c28682110c [chatbot] Flag to add system prompt
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-09 22:17:39 +05:30
Ean Garvey
caf6cc5d8f Switch most compile flows to use ireec.compile_file. (#1863)
* Switch most compile flows to use ireec.compile_file.

* re-add input type to compile_str path.

* Check if mlir_module exists before checking if it's a path or pyobject.

* Fix some save_dir cases
2023-10-06 23:04:43 -05:00
Ean Garvey
8614a18474 Remove tf dependencies from importer path. (#1874)
* Remove tf dependencies from import path.

* Fix formatting.
2023-10-06 12:27:12 -07:00
Jakub Kuderski
86c1c0c215 Add aggregate statistics to microbenchmark (#1871)
Print averaged results at the end of all iterations. Increase the
default number of iterations to 5.

Example:
```
Number of iterations: 5
Prefill: avg. 0.03 s, stddev 0.00
Decode: avg. 43.34 tokens/s, stdev 0.13
```

Also remove the -2 in the number of generated tokens -- I did not find
any evidence we need it.
2023-10-06 10:03:07 -07:00
Daniel Garvey
8bb364bcb8 enforce fp32 accumulates for cpu (#1873) 2023-10-06 11:34:49 -05:00
Daniel Garvey
7abddd01ec argmax inside model + brevitas pin (#1872) 2023-10-05 20:15:21 -07:00
Abhishek Varma
2a451fa0c7 [Llama2] Add a standalone utility for dynamic and combining IRs
-- This script adds a standalone utility for converting Llama IRs
   to dynamic and combining them as well.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-10-05 20:01:06 +05:30
Jakub Kuderski
9c4610b9da Add microbenchmark mode to vicuna CLI (#1864)
Add flags to enable a non-interactive mode for microbenchmarking llama
models. In this mode, the system and user prompts are specified with CLI
flags, and the number of generated tokens and iterations is fixed.

Also move the stats below the response and trim any response whitespace.
2023-10-05 00:12:08 -04:00
powderluv
a38cc9d216 Update vulkan_utils.py for Radeon 780m igpu (#1866) 2023-10-04 20:33:07 -07:00
Jakub Kuderski
1c382449ec [vulkan] Print note about module load times. NFC. (#1862)
Print a note ahead of a potentially long inactivity to set the right expectations.

Separately, we should add progress to the UI and make this loading faster.
2023-10-03 17:27:27 -04:00
Gaurav Shukla
7cc9b3f8e8 [llama cli] Fix llama cli
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-03 20:39:53 +05:30
Gaurav Shukla
e54517e967 [UI] Disable config generator, lora train and model manager (#1858)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-10-02 22:34:40 -07:00
Ean Garvey
326327a799 Collect pipeline submodules for diffusers ckpt preprocessing. (#1859) 2023-10-03 00:29:28 -04:00
Ean Garvey
785b65c7b0 Add flag for specifying device-local caching allocator heap key. (#1856) 2023-10-03 00:28:39 -04:00
Sungsoon Cho
0d16c81687 Remove unused import. (#1857) 2023-10-02 11:36:08 -05:00
Vivek Khandelwal
8dd7850c69 Add Falcon-GPTQ support 2023-10-02 16:39:57 +05:30
Gaurav Shukla
e930ba85b4 [os] Remove os dependency from vmfb naming (#1854)
Also fixes a small ui issue for chatbot.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 12:38:17 -05:00
Gaurav Shukla
cd732e7a38 [chatbot] split execution time to prefill and decode
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 13:18:03 +05:30
Gaurav Shukla
8e0f8b3227 [ui] Update chatbot UI
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 13:18:03 +05:30
Gaurav Shukla
b8210ef796 [chatbot] Re-instantiate the chatbot object if device id changes
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-29 13:18:03 +05:30
PhaneeshB
94594542a9 remove use of vulkaninfo 2023-09-28 21:57:00 +05:30
Gaurav Shukla
82f833e87d [vulkan] Update vmfb naming
Update vmfb naming for vulkan devices in order to resolve naming
conflicts in the presence of multiple vulkan devices.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-28 14:52:11 +05:30
Vivek Khandelwal
c9d6870105 Modify falcon pipeline for 180b support 2023-09-28 12:39:35 +05:30
Jakub Kuderski
4fec03a6cc [vulkan] Switch from coop matrix NV to KHR (#1848) 2023-09-27 21:43:37 -04:00
harsh-nod
9a27f51378 Deprecate inference directory
This patch removes the inference directory that was no longer being used.
2023-09-27 14:29:00 -07:00
Abhishek Varma
ad1a0f35ff Fix misdirection while saving vmfb
-- Currently SHARK suggests that vmfb has been saved, while
    that is not the case and no vmfb is generated. 
    This creates a misdirection for IR/vmfbs which are of larger
    size.
-- This commit therefore fixes that misdirection.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-27 16:25:29 +05:30
Nelson Sharpe
6773278ec2 Fix checkpoint_path unexpected argument (#1832) 2023-09-24 14:17:52 -07:00
Abhishek Varma
9a0efffcca [Llama2] Fix wrong Vulkan device ID + Add Vulkan compile flags
-- This commit fixes the wrong Vulkan device being selected during
   runtime.
-- It also adds a couple of IREE compilation flags to target a specific
   Vulkan device.
-- It also changes the Vulkan device listing to be more in tune with
   lowering control flow.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-22 22:24:18 +05:30
gpetters94
61c6f153d9 Switch to keras-nightly to fix a Linux issue (#1835) 2023-09-21 12:33:45 -04:00
Phaneesh Barwaria
effd42e8f5 pin gradio to v3.44.3 2023-09-21 17:33:43 +05:30
Sungsoon Cho
b5fbb1a8a0 Rename the func arg save_json to avoid name collision. (#1837)
* Rename the func arg save_json to avoid name collision.

* black formatted.
2023-09-19 17:29:27 -05:00
Quinn Dawkins
ded74d09cd [vicuna.py] Keep past key values on device (#1836)
The past key values are only used within the models themselves and can
be kept on device. For vulkan int4, this gives 44 tok/s (for the first
prompt) and settles at around 26 tok/s on 7900xtx.
2023-09-19 18:17:41 -04:00
Boian Petkantchin
79267931c1 Add argument --additional_compile_args (#1119)
This allows passing more arguments to the IREE compiler.
Example:
python my-app.py --additional_compile_args="--mlir-pretty-debuginfo --mlir-timing"

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-09-19 11:26:03 -05:00
zjgarvey
9eceba69b7 local_tank_cache included into clear_all (#1833) 2023-09-18 00:27:23 -05:00
Ean Garvey
ca609afb6a Update README.md (#1830) 2023-09-14 10:33:57 -05:00
Gaurav Shukla
11bdce9790 [flags] Fix vulkan runtime flags as vma is dropped from iree (#1831) 2023-09-14 08:58:59 -05:00
Ean Garvey
684943a4a6 (SD) Fix tokenizers imports in pyinstaller builds. (#1828)
* Fix tokenizers metadata.

* (SD) Disable VAE lowering configs (rdna3) and add versioned tunings.

* Update sd_annotation.py

* (SD) Add cv2 to spec.

* Update stencil pipeline with the new img2img arg.
2023-09-12 12:23:48 -05:00
PhaneeshB
b817bb8455 add roles for llama2 2023-09-12 10:59:28 +05:30
Ean Garvey
780f520f02 Fix vk.target_env extensions and remove redundant SD imports. (#1826)
* Remove redundant IREE runtime imports.

* Fix vulkan target env extensions.
2023-09-11 13:42:52 -05:00
Dom
c61b6f8d65 Code refactoring (#1817)
* use join

* fix bug

* further code optimizations

---------

Co-authored-by: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
2023-09-11 11:30:56 -05:00
Abhishek Varma
c854208d49 [Llama2] Prefetch llama2 tokenizer configs (#1824)
-- This commit prefetches llama2 tokenizer configs from shark_tank.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-08 11:29:54 -07:00
Gaurav Shukla
c5dcfc1f13 [vicuna] Exit when mlir is not present in shark tank (#1825)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-08 10:30:29 -07:00
Abhishek Varma
bde63ee8ae Add logging feature in WebUI (#1821) 2023-09-08 05:48:05 -07:00
Vivek Khandelwal
9681d494eb Update decomp list and shark trainer for DLRM 2023-09-06 21:24:50 +05:30
Gaurav Shukla
ede6bf83e2 [vicuna] Disabling the IR generation path
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-09-06 20:13:17 +05:30
Ean Garvey
2c2693fb7d Fix torchvision versioning in Linux importer setup. (#1809) 2023-09-05 12:57:03 -05:00
Vivek Khandelwal
1d31b2b2c6 Fix StableHLO Compilation flag 2023-09-05 21:32:33 +05:30
Gaurav Shukla
d2f64eefa3 [chatbot] Remove few outdated models from list (#1814) 2023-09-04 09:26:32 -07:00
Abhishek Varma
87ae14b6ff [SD] Add sdpfa decomposition + update IREE flag
-- This commit adds Scaled Dot Product Flash Attention's decomposition
   in shark_importer.
-- It also updates `iree-flow-enable-data-tiling` to `iree-opt-data-tiling`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-09-04 18:03:53 +05:30
Phaneesh Barwaria
1ccafa1fc1 fix llama2-70b rewrite tensor dim 2023-09-01 17:27:06 +05:30
jinchen62
4c3d8a0a7f Enable downloading vmfb/mlir for webui (#1807) 2023-08-31 11:05:47 -07:00
jinchen62
3601dc7c3b Fix llama2 13b combined ir (#1803) 2023-08-28 11:34:44 -07:00
Daniel Garvey
671881cf87 Llama2 70b (#1783)
* llama2 70b IR gen

* fix IR sec llama2 + debug

* llama270b

---------

Co-authored-by: PhaneeshB <b.phaneesh@gmail.com>
2023-08-25 23:04:28 -07:00
Gaurav Shukla
4e9be6be59 [chatbot] Add debug as class attribute (#1799)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-25 21:46:29 -07:00
Ean Garvey
9c8cbaf498 Add support for ROCM (Windows) in Studio + compile utils (#1770)
* WIP: MSVC ROCM support for SHARK Studio

* Make get_iree_rocm_args platform-agnostic.

* Update stable_args.py

* Update rocm arg handling in SD utils

* Guard quantization imports.

Co-authored-by: jam https://github.com/jammm
2023-08-25 20:56:05 -07:00
Ean Garvey
9e348a114e Revert changes process_skipfiles.py (#1798)
Keeps a small typo fix but reverts the rest of the changes to this file from 450c231171
2023-08-25 15:31:49 -07:00
jinchen62
51f90a4d56 Update conversion passes for brevitas quant op (#1795) 2023-08-25 17:28:07 -05:00
Abhishek Varma
310d5d0a49 Fix llama2 13b crashing + add spec file for CLI execution of Llama (#1797)
* [Llama2] Add a fix for Llama2 13B downloading/crashing

-- This commit fixes downloading/crashing of llama2 13B on the wrong
   .mlir file.
-- Also adds support for downloading vmfb from shark_tank in CLI.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [llama2] Add a spec file to run Llama/Vicuna CLI exe

-- This commit adds a spec file to run Llama/Vicuna CLI exe.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-08-25 09:36:09 -05:00
Ean Garvey
9697981004 Pipe through a debug option to iree compile utils. (#1796)
* Update compile_utils.py

* Pipe through a flag to toggle debug options in compile utils.

* Update SharkLLMBase.py
2023-08-25 07:11:11 -07:00
Ean Garvey
450c231171 Add tokenizers to requirements.txt (#1790)
* Add tokenizers to requirements and pin version

* Update process_skipfiles.py
2023-08-24 19:44:04 -05:00
Ean Garvey
07f6f4a2f7 Add a short README for the OPT examples and small tweaks. (#1793)
* Small changes to OPT example.

* Update opt README.

* Add a few modes to batch script.

* Update README.md
2023-08-24 17:26:11 -07:00
jinchen62
610813c72f Add iree flag to strip assertions (#1791) 2023-08-24 10:51:19 -07:00
Ean Garvey
8e3860c9e6 Remove flags that are default in upstream IREE (#1785)
* Remove index bits flags now set by default

* Update shark_studio_imports.py
2023-08-24 11:57:54 -05:00
xzuyn
e37d6720eb Add Hires Fix (#1787)
* improper test hiresfix

* add sliders & use `clear_cache`

* add resample choices & fix step adjustment

* add step adjustment to img2img

* add resample options to img2img

* simplify hiresfix
- import `img2img_inf` from `img2img_ui.py` instead of just copying it into `txt2img_ui.py`

* set `hri` to None after using

* add more resample types, and don't show output until hiresfix is done

* cleaner implementation

* ran black

* ran black again with jupyter dependencies
2023-08-24 09:01:41 -07:00
Vivek Khandelwal
16160d9a7d Fix combine mlir script 2023-08-24 19:10:49 +05:30
Sungsoon Cho
79075a1a07 Opt perf (#1786)
* Define command line args, model-name, max-seq-len, platform, etc.

* Add usage example.

* Add opt_perf_comparision_batch.py.

* Use shlex instead.
2023-08-24 08:33:12 -05:00
Abhishek Varma
db990826d3 Add Llama2 13B int4 fp16 support (#1784)
Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-08-23 10:00:32 -07:00
gpetters94
7ee3e4ba5d Add stencil_unet_512 support (#1778)
This should fix any remaining issues with stencils and long prompts.
2023-08-22 12:23:46 -04:00
Vivek Khandelwal
05889a8fe1 Add LLaMa2-int4-fp16 support (#1782) 2023-08-22 07:45:50 -07:00
jinchen62
b87efe7686 Fix venv setup for brevitas (#1779) 2023-08-21 11:58:51 -07:00
gpetters94
82b462de3a Fix stencils for long prompts (#1777) 2023-08-19 00:26:51 -07:00
Daniel Garvey
d8f0f7bade replace public with private (#1776)
unload footguns
2023-08-18 14:22:46 -07:00
gpetters94
79bd0b84a1 Fix an issue with diffusers>0.19.3 (#1775) 2023-08-18 14:06:06 -04:00
jinchen62
8738571d1e Adapt the change of brevitas custom op name (#1772) 2023-08-17 14:24:43 -07:00
Gaurav Shukla
a4c354ce54 [version] Pin diffusers==0.19.3
Once the latest diffusers works with LoRA training, unpin it.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 21:27:10 +05:30
Gaurav Shukla
cc53efa89f [cli] Fix chatbot cli
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 21:27:10 +05:30
Gaurav Shukla
9ae8bc921e [chatbot] Fix chatbot cli and webview warning
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 21:27:10 +05:30
Gaurav Shukla
32eb78f0f9 [chatbot] Fix switching parameters in chatbot
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-17 19:14:17 +05:30
Ean Garvey
cb509343d9 Fix pytest benchmarks and shark_tank generation. (#1632)
- fix setup_venv.sh for benchmarks/imports etc.
- fix torch benchmarks in SharkBenchmarkRunner
- generate SD artifacts using build_tools/stable_diffusion_testing.py and --import_mlir
- decouple SD gen from tank/generate_sharktank for now
2023-08-16 17:48:47 -05:00
powderluv
6da391c9b1 update signtool to use /fd certHash 2023-08-15 15:11:40 -07:00
Ean Garvey
9dee7ae652 fix tkinter window (#1766) 2023-08-15 13:23:09 -07:00
Ean Garvey
343dfd901c Update SHARK-Runtime links to SRT (#1765)
* Update nightly.yml

* Update setup_venv.ps1

* Update CMakeLists.txt

* Update shark_iree_profiling.md

* Update setup_venv.sh

* Update README.md

* Update .gitmodules

* Update CMakeLists.txt

* Update README.md

* fix signtool flags

* Update nightly.yml

* Update benchmark_utils.py

* uncomment tkinter launch
2023-08-15 12:40:44 -07:00
Ean Garvey
57260b9c37 (Studio) Add hf-hub to pyinstaller metadata (#1761) 2023-08-14 23:01:50 -05:00
Ean Garvey
18e7d2d061 Enable vae tunings for rdna3. (#1764) 2023-08-14 21:00:14 -07:00
Stanley Winata
51a1009796 Add Forward method to SHARKRunner and fix examples. (#1756) 2023-08-14 19:20:37 -07:00
Daniel Garvey
045c3c3852 enable iree-opt-const-expr-hoisting in vicuna (#1742)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-08-14 18:43:42 -07:00
Ean Garvey
0139dd58d9 Specify max allocation size in IREE compile args. (#1760) 2023-08-14 15:43:09 -05:00
Ean Garvey
c96571855a prevents recompiles for cuda benchmarks + update benchmark_module path (#1759)
* xfail resnet50_fp16

* Fix cuda benchmarks and prevent recompilation.
2023-08-14 15:30:32 -05:00
PhaneeshB
4f61d69d86 add support for passing iree flags for LLMs 2023-08-15 00:22:56 +05:30
Phaneesh Barwaria
531d447768 set default allocator for metal device creation (#1755) 2023-08-14 06:17:52 -07:00
Vivek Khandelwal
16f46f8de9 Update langchain_requirements.txt 2023-08-14 14:32:19 +05:30
Vivek Khandelwal
c4723f469f Update langchain_requirements.txt 2023-08-14 14:32:19 +05:30
Vivek Khandelwal
d804f45a61 Update langchain_requirements.txt 2023-08-14 14:32:19 +05:30
Vivek Khandelwal
d22177f936 Update requirements.txt 2023-08-14 14:32:19 +05:30
George Petterson
75e68f02f4 Remove CUDNN 2023-08-14 14:32:19 +05:30
Gaurav Shukla
4dc9c59611 [chatbot] Add tokens generated per second (#1753) 2023-08-13 11:25:41 -07:00
Gaurav Shukla
18801dcabc [chat] Update chatbot ui
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-13 18:39:22 +05:30
Gaurav Shukla
3c577f7168 [vicuna] fix shard config generator script (#1747)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-10 11:26:03 -07:00
Stefan Kapusniak
f5e4fa6ffe UI/Web - Revert tab order (#1724)
* Revert ui tab order

* Reverts the tab order, so that SD, LLM, and Experimental are grouped
together again as far as is possible.
* Labelled "Generate Sharding Config" as experimental as pressing the
'Get Model Config' errors for me.

* Fix formatting in index.py
2023-08-10 11:25:36 -07:00
powderluv
48de445325 Enable caching and disable vma (#1746)
* Enable caching allocator by default

Going to toggle VMA off too, as this is required for performance. Will have to monitor in-the-wild reports.

* Disable VMA

Disable VMA
2023-08-10 10:49:44 -07:00
Gaurav Shukla
8e90f1b81a [vicuna] add default config in case of sharded vicuna
Signed-Off-by: Gaurav Shukla<gaurav@nod-labs.com>
2023-08-10 21:28:08 +05:30
Vivek Khandelwal
e8c1203be2 Fix vicuna script (#1745) 2023-08-10 06:11:14 -07:00
Vivek Khandelwal
e4d7abb519 Final patch for fixing Langchain token streaming issue (#1744) 2023-08-09 10:09:41 -07:00
powderluv
96185c9dc1 pin safetensors to 0.3.1 (#1740) 2023-08-08 19:24:44 -07:00
powderluv
bc22a81925 re-enable constant folding (#1739)
Tested and works well. (modulo unrelated driver issue)
2023-08-08 17:17:38 -07:00
Eliasj42
5203679f1f Bandaid fix 2 (#1728)
* download all mlirs

* fixed install method

* download all mlirs (#1727)

Co-authored-by: Elias Joseph <elias@nod-labs.com>

* added taggs

* fix name check for file existence

* Remove SD from all_models.csv (#1706)

Removes SD from pytests as it has its own test suite.

* gpt_langchain.py fixes for pydantic (#1722)

* removed dead code

---------

Co-authored-by: Elias Joseph <elias@nod-labs.com>
Co-authored-by: PhaneeshB <b.phaneesh@gmail.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: Stefan Kapusniak <121311569+one-lithe-rune@users.noreply.github.com>
2023-08-08 12:14:57 -05:00
Vivek Khandelwal
bf073f8f37 [Langchain] Expand pipelines to fix token streaming issue 2023-08-08 10:27:23 +05:30
Stella Laurenzo
cec6eda6b4 Optimize device enumeration overhead and log details on long operations. (#1734)
* Optimize device enumeration overhead and log details on long operations.

* Various fixes to add `@functools.cache` to what should be one-time, expensive device enumeration and setup activities. Cuts several seconds off initialization on my machine.
* Add detailed tracing to actual invocations if they exceed a certain timeout or have an exception.
* Add detailed tracing to loading status.
* By default detail logging is only printed if an operation takes an excessive amount of time. All logging/timing can be printed by setting the variable `$env:SHARK_DETAIL_TRACE = "1"`

* Remove cache from unhashable functions
2023-08-07 17:20:53 -07:00
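The commit above describes two reusable patterns: caching expensive, one-time device enumeration with `@functools.cache`, and only emitting detailed timing logs when an operation is slow or `SHARK_DETAIL_TRACE` is set. A minimal sketch of both follows; the function names and the time threshold are illustrative, not taken from the actual diff.

```python
# Illustrative sketch only -- names and threshold are assumptions.
import functools
import os
import time


@functools.cache
def enumerate_devices():
    # Expensive driver/device discovery runs once per process; later calls
    # return the cached result. (Any arguments would need to be hashable.)
    return []  # placeholder for the real enumeration


def traced(label, threshold_s=2.0):
    """Log a timing line only if the call is slow or detail tracing is on."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            start = time.time()
            try:
                return fn(*args, **kwargs)
            finally:
                elapsed = time.time() - start
                if elapsed > threshold_s or os.getenv("SHARK_DETAIL_TRACE") == "1":
                    print(f"[trace] {label} took {elapsed:.2f}s")
        return wrapper
    return decorator
```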
Stella Laurenzo
9e37e03741 Clearly differentiate phases of loading modules to better understand if things are taking a long time. (#1733) 2023-08-07 14:03:12 -07:00
Stefan Kapusniak
9b8c4401b5 gpt_langchain.py fixes for pydantic (#1722) 2023-08-07 00:55:38 -07:00
Ean Garvey
a9f95a218b Remove SD from all_models.csv (#1706)
Removes SD from pytests as it has its own test suite.
2023-08-05 15:55:52 -05:00
PhaneeshB
872bd72d0b fix name check for file existence 2023-08-05 21:33:53 +05:30
Eliasj42
fd1c4db5d0 download all mlirs (#1727)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-08-04 18:22:06 -05:00
Daniel Garvey
759664bb48 add py files to pyinstaller for shark (#1723) 2023-08-04 14:10:43 -07:00
Daniel Garvey
14fd0cdd87 add missing subprocess import (#1721) 2023-08-04 15:15:22 -05:00
Daniel Garvey
a57eccc997 fix lint (#1720) 2023-08-04 14:54:33 -05:00
Daniel Garvey
a686d7d89f temporarily disable langchain stuff in webui (#1719)
its breaking the exe
2023-08-04 12:48:06 -07:00
Eliasj42
ed484b8253 added functionality for int8 vicuna and 4 shards (#1712)
combined vicuna_4_shards.py and vicuna.py to reduce code duplication

Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-08-04 14:05:05 -05:00
gpetters94
7fe57ebaaf Add vector database and add support on the web UI (#1699) 2023-08-04 13:47:19 -04:00
Nithin Meganathan
c287fd2be8 Add GPU IDs in model_config.json by default for manual annotation (#1718) 2023-08-04 12:46:27 -05:00
Gaurav Shukla
51ec1a1360 [vicuna] Integrate sharded vicuna in web (#1717)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-04 11:46:53 -05:00
Gaurav Shukla
bd30044c0b [Shard] Add sharding generation in shark studio
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-08-04 21:51:14 +05:30
Ean Garvey
c9de2729b2 Add flag for toggling constant folding. (#1714) 2023-08-04 04:55:52 -07:00
Vivek Khandelwal
a5b13fcc2f [Langchain] Patch for fixing streaming of tokens (#1709) 2023-08-03 10:06:49 -07:00
Stefan Kapusniak
6bb329c4af Unsharded Vicuna: Fix Memory Error compiling mlir for lmsys/vicuna-7b-v1.3 fp16 with 64 GiB (#1702) 2023-08-01 06:07:56 -07:00
Vivek Khandelwal
98fb6c52df Expand pipelines to fix streaming of tokens 2023-07-31 22:11:01 +05:30
Stefan Kapusniak
206c1b70f4 UI/Web: Reorder tabs to separate SD and LLM (#1701)
Shuffle the tabs around so that:

* All the SD tabs are together
* All the LLM tabs are together
* All the experimental tabs are together
2023-07-29 22:25:30 -04:00
PhaneeshB
cdb037ee54 use shark_args for vulkan debug utils flag 2023-07-30 07:54:26 +05:30
PhaneeshB
ce2fd84538 fix cpu device name for SharkStudio 2023-07-30 07:54:26 +05:30
PhaneeshB
4684afad34 update upscaler example 2023-07-28 21:06:28 +05:30
PhaneeshB
8d65456b7a Move vulkan runtime flags to shark_args 2023-07-28 21:06:28 +05:30
PhaneeshB
d6759a852b add vulkan vma alloc flag 2023-07-28 21:06:28 +05:30
Daniel Garvey
ab57af43c1 Couple of fixes for vicuna.py (#1696)
* mega vicuna merge pt 2

* add fallback to ensure compile is called
2023-07-27 15:53:05 -07:00
jinchen62
4d5c55dd9f Fix vicuna script (#1697) 2023-07-27 17:24:26 -05:00
Vivek Khandelwal
07399ad65c [Langchain] Remove unused code (#1698) 2023-07-27 11:59:54 -05:00
Vivek Khandelwal
776a9c2293 Fix for Langchain (#1694)
For CPU, remove max time stopping criteria
Fix web UI issue
2023-07-26 09:00:23 -07:00
Eliasj42
9d399eb988 fixed bug where device_idx was hardcoded (#1693)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-07-25 19:00:13 -05:00
Vivek Khandelwal
927b662aa7 Add Langchain SHARK Compilation support for all paths 2023-07-25 22:15:42 +05:30
Abhishek Varma
47f8a79c75 [MiniGPT4] Add MiniGPT4 to SHARK (#1554)
* [MiniGPT4] Add MiniGPT4 to SHARK

-- This is the first installment of MiniGPT4 in SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* Add int8 support for MiniGPT4

-- This commit adds int8 support for MiniGPT4.

Signed-off-by: Abhishek Varma <abhishek@nod-lab.com>

* Update .spec for MiniGPT4's config files

* black format MiniGPT4

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Signed-off-by: Abhishek Varma <abhishek@nod-lab.com>
2023-07-25 09:42:27 -07:00
Stefan Kapusniak
289f983f41 SD - Implement seed arrays for batch runs (#1690)
* SD Scripts and UI tabs that support batch_count can now take a
string containing a JSON array, or a list of integers, as their seed
input.
* Each batch in a run will now take the seed specified at the
corresponding array index if one exists. If there is no seed at
that index, the seed value will be treated as -1 and a random
seed will be assigned at that position. If an integer rather than
a list or JSON array has been provided, everything works as before.
* UI seed input controls are now Textboxes with info lines about
the seed formats allowed.
* UI error handling updated to be more helpful if the seed input is
invalid.
2023-07-24 19:22:34 -07:00
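A hedged sketch of the seed handling described in the commit above: the seed input may be a single integer, a list of integers, or a string containing a JSON array, and any missing position (or -1) gets a fresh random seed. The helper name and seed range are assumptions.

```python
# Illustrative only -- helper name and MAX_SEED are assumptions.
import json
import random

MAX_SEED = 2**32 - 1


def batch_seeds(seed_input, batch_count):
    if isinstance(seed_input, str):
        seed_input = json.loads(seed_input)       # e.g. "[42, -1, 7]"
    if isinstance(seed_input, int):
        seed_input = [seed_input]
    seeds = []
    for i in range(batch_count):
        s = seed_input[i] if i < len(seed_input) else -1
        seeds.append(random.randint(0, MAX_SEED) if s == -1 else s)
    return seeds
```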
Daniel Garvey
453e46562f mega vicuna merge pt 2 (#1685) 2023-07-24 12:42:20 -05:00
Gaurav Shukla
5497af1f56 [config] Add support for uploading sharding config file in chatbot (#1689)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-07-24 10:18:03 -07:00
Vivek Khandelwal
f3cb63fc9c Fix Langchain multiple device issue (#1688) 2023-07-24 08:03:46 -07:00
Vivek Khandelwal
d7092aafaa Fix multiple issue for Langchain
This commit fixes the following issue for the Langchain:
1.) Web UI not able to fetch results.
2.) For each query model getting reloaded.
3.) SHARK module not using user provided device and precision.
4.) Create a class for main Langchain code.
5.) Misc issues
2023-07-21 21:56:27 +05:30
Vivek Khandelwal
a415f3f70e Fix Langchain Prompt issue and add web UI support (#1682) 2023-07-21 06:36:55 -07:00
Vivek Khandelwal
c292e5c9d7 Add Langchain CPU support and update requirements 2023-07-20 18:53:34 +05:30
Vivek Khandelwal
03c4d9e171 Add support for Llama-2-70b for web and cli, and for hf_auth_token 2023-07-20 14:57:48 +05:30
jinchen62
3662224c04 Update brevitas requirement (#1677)
also clean up useless args

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-19 22:03:32 -07:00
Vivek Khandelwal
db3f222933 Revert "Add Llama2 70B option in CLI and WebUI (#1673)" (#1679)
This reverts commit 41e5088908.
2023-07-19 22:02:48 -07:00
Stefan Kapusniak
68b3021325 Fixes cosmetic problems with Gradio 3.37.0 (#1676)
* Fix nod-ai logo having a white border
* Fix control labels having a black background
* Remove extra lower border below Save Prompt checkboxes in Txt2Img UI
2023-07-19 17:28:53 -07:00
AyaanShah2204
336469154d added copy-metadata for pyyaml (#1678) 2023-07-19 17:27:25 -07:00
Abhishek Varma
41e5088908 Add Llama2 70B option in CLI and WebUI (#1673) 2023-07-19 10:41:42 -07:00
PhaneeshB
0a8f7673f4 Add README for CodeGen server 2023-07-19 23:10:23 +05:30
PhaneeshB
c482ab78da fix second vic clearing for low mem device 2023-07-19 23:10:23 +05:30
Vivek Khandelwal
4be80f7158 Add support for the Llama-2 model 2023-07-19 20:57:08 +05:30
AyaanShah2204
536aba1424 unpinned torch_mlir (#1668)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-19 06:28:00 -07:00
Ean Garvey
dd738a0e02 small changes to opt_perf_comparison.py (#1670)
* Use longer prompts for OPT comparison script

* small tweaks
2023-07-19 06:26:50 -07:00
Daniel Garvey
8927cb0a2c set optional vmfb download (#1667) 2023-07-18 10:57:28 -07:00
Daniel Garvey
8c317e4809 fix cli for vicuna (#1666) 2023-07-18 10:03:40 -07:00
Vivek Khandelwal
b0136593df Add support for different compilation paths for DocuChat (#1665) 2023-07-18 09:49:44 -07:00
Vivek Khandelwal
11f62d7fac Minor fixes for MiniLM Training 2023-07-18 17:16:44 +05:30
powderluv
14559dd620 Update DocuChat as experimental (#1660) 2023-07-17 22:12:05 -07:00
AyaanShah2204
e503a3e8d6 fixed joblib import error (#1659) 2023-07-17 12:56:10 -07:00
AyaanShah2204
22a4254adf fixed pyinstaller path for langchain imports (#1658) 2023-07-17 12:19:21 -07:00
Vivek Khandelwal
ab01f0f048 Add Langchain model in SHARK (#1657)
* Add H2OGPT

* Add UI tab for h2ogpt

* Add source files from h2ogpt

* Add the rest of the files

* Add h2ogpt support

* Add SHARK Compilation support for langchain model for cli mode

---------

Co-authored-by: George Petterson <gpetters@protonmail.com>
2023-07-17 09:58:15 -07:00
Phaneesh Barwaria
c471d17cca codegen API (#1655) 2023-07-16 20:00:39 -07:00
Stefan Kapusniak
a2a436eb0c SD - Add repeatable (batch) seeds option (#1654)
* Generates the seeds for all batch_count batches being run up front
rather than generating the seed for a batch just before it is run.
* Adds a --repeatable_seeds argument defaulting to False
* When repeatable_seeds=True, the first seed for a set of batches will
also be used as the rng seed for the subsequent batch seeds in the run.
The rng seed is then reset.
* When repeatable_seeds=False, batch seeding works as currently.
* Update scripts under apps/scripts that support the batch_count
argument to also support the repeatable_seeds argument.
* UI/Web: Adds a checkbox element on each SD tab after batch count/size
for toggling repeatable seeds, and update _inf functions to take
this into account.
* UI/Web: Moves the Stop buttons out of the Advanced sections and next
to Generate to make things not fit quite so badly with the extra UI
elements.
* UI/Web: Fixes logging to the upscaler output text box not working
correctly when running multiple batches.
2023-07-15 16:22:41 -07:00
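A rough sketch of the `--repeatable_seeds` behaviour outlined above: all batch seeds are produced up front, and when the option is on, the first seed also seeds the RNG that draws the remaining ones, so a run can be reproduced from a single seed. Using a local `random.Random` instance stands in for the "rng seed is then reset" step mentioned in the commit; names are illustrative.

```python
# Illustrative only -- function name and seed range are assumptions.
import random

MAX_SEED = 2**32 - 1


def generate_batch_seeds(first_seed, batch_count, repeatable_seeds=False):
    if first_seed == -1:
        first_seed = random.randint(0, MAX_SEED)
    if repeatable_seeds:
        rng = random.Random(first_seed)   # deterministic follow-on seeds
    else:
        rng = random                      # module-level RNG, as before
    return [first_seed] + [rng.randint(0, MAX_SEED) for _ in range(batch_count - 1)]
```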
powderluv
1adb51b29d Update docker README.md 2023-07-15 14:31:56 -07:00
anush elangovan
aab2233e25 Add a dev Ubuntu 22.04 docker image 2023-07-15 16:25:37 +00:00
jinchen62
e20cd71314 Change to a separate pass to unpack quantized weights (#1652) 2023-07-15 04:54:53 -07:00
powderluv
5ec91143f5 add a HF accelerate requirement (#1651) 2023-07-14 05:56:12 -07:00
Ean Garvey
7cf19230e2 add perf comparison script for opt. (#1650) 2023-07-13 13:29:48 -05:00
powderluv
1bcf6b2c5b pin diffusers to 0.18.1 (#1648) 2023-07-13 01:02:24 -07:00
jinchen62
91027f8719 Remove done TODOs, a sup PR for #1644 (#1647) 2023-07-12 23:30:45 -07:00
powderluv
a909fc2e78 add tiktoken to spec file (#1646) 2023-07-12 16:12:02 -07:00
jinchen62
247f69cf9d Apply canonicalize for unpacking int4 (#1644)
- tested it unpacks int4 as expected
- tested it doesn't make a difference on int8
2023-07-11 19:41:09 -07:00
PhaneeshB
3b8f7cc231 Add codegen support in UI + lint 2023-07-11 21:58:01 +05:30
PhaneeshB
6e8dbf72bd mlir/vmfb path fixes for vic pipeline 2023-07-11 21:58:01 +05:30
PhaneeshB
38e5b62d80 adapt UI to send model details to pipeline 2023-07-11 21:58:01 +05:30
PhaneeshB
1c7eecc981 add codegen support in vic pipeline 2023-07-11 21:58:01 +05:30
PhaneeshB
be417f0bf4 fix precision for fp16 2023-07-11 21:58:01 +05:30
AyaanShah2204
a517e217b0 Added support for building ZIP distributions (#1639)
* added support for zip files

* making linter happy

* Added temporary fix for NoneType padding

* Removed zip script

* Added shared imports file

* making linter happy
2023-07-09 06:45:36 -07:00
Ranvir Singh Virk
9fcae4f808 Metal testing (#1595)
* Fixing metal_platform and device selection

* fixing for metal platform

* fixed for black lint formatting
2023-07-08 15:22:53 -07:00
Stefan Kapusniak
788d469c5b UI/Web Refix remaining gradio deprecation warning (#1638) 2023-07-08 13:48:36 -07:00
Stefan Kapusniak
8a59f7cc27 UI/Web add 'open folder' button to output gallery (#1634)
* Adds a button that opens the currently selected subdirectory using
the default OS file manager
* Improve output gallery handling of having images deleted out from
under it.
* Don't show VAE or LoRA lines in parameter info panel when their
value is 'None'
* Use a css class for small icon buttons on the output gallery
tab instead of using the same id for multiple buttons
2023-07-08 12:44:59 -07:00
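An "open folder" button like the one added above typically boils down to a small platform switch; this is an assumed illustration of that call, not the project's actual implementation.

```python
# Illustrative cross-platform "open folder in the OS file manager" helper.
import os
import platform
import subprocess


def open_folder(path):
    if platform.system() == "Windows":
        os.startfile(path)                    # opens in Explorer
    elif platform.system() == "Darwin":
        subprocess.Popen(["open", path])      # opens in Finder
    else:
        subprocess.Popen(["xdg-open", path])  # default Linux file manager
```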
Stefan Kapusniak
1c2ec3c7a2 Some Fixes for Gradio 3.36.1 (#1637)
* Clear .style deprecation warnings.
* Re-remove download button from Nod logos.
* Add work around for `container=false` not doing what it did before on
dropdowns to the output gallery CSS
2023-07-08 11:20:34 -07:00
powderluv
af0f715e20 Unpin gradio 2023-07-08 09:41:14 -07:00
jinchen62
47ec7275e6 Fix brevitas quantize argument (#1633) 2023-07-07 11:30:31 -07:00
powderluv
3a24cff901 change binary names 2023-07-06 23:59:14 -07:00
powderluv
1f72907886 Fix the pyinstaller for chatbots (#1631) 2023-07-06 23:30:01 -07:00
Daniel Garvey
06c8aabd01 remove local-sync from webui (#1629) 2023-07-06 13:58:59 -07:00
Phaneesh Barwaria
55a12cc0c4 cpu name in device (#1628)
* show cpu name in devices

* change device order for chatbot
2023-07-06 12:00:09 -07:00
Ean Garvey
7dcbbde523 Xfail models for data tiling flag changes (#1624) 2023-07-06 06:57:17 -07:00
Abhishek Varma
1b62dc4529 [Vicuna] Revert the formatting for Brevitas op (#1626)
-- This commit reverts the formatting for Brevitas op.
-- It also excludes vicuna.py script from `black` formatter.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-07-06 06:56:17 -07:00
Daniel Garvey
c5a47887f4 Revert revert negative prompt change (#1625)
* revert default flag changes

* revert revert negative prompt change

* revert revert negative prompt change
2023-07-05 22:09:06 -07:00
Daniel Garvey
c72d0eaf87 revert default flag changes (#1622) 2023-07-05 15:43:26 -05:00
powderluv
c41f58042a Update compile_utils.py (#1617)
* Update compile_utils.py

* Update compile_utils.py

* Update compile_utils.py
2023-07-05 10:06:48 -07:00
xzuyn
043e5a5c7a fix a mistake I made, and more formatting changes, and add ++/Karras (#1619)
* fixed missing line break in `stablelm_ui.py` `start_message`
- also more formatting changes

* fix variable spelling mistake

* revert some formatting cause black wants it different

* one less line, still less than 79

* add ++, karras, and karras++ types of dpmsolver.

* black line length 79

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-05 09:00:16 -07:00
Abhishek Varma
a1b1ce935c int8 e2e for WebUI (#1620) 2023-07-05 07:08:36 -07:00
jinchen62
bc6fee1a0c Add int4/int8 vicuna (#1598) 2023-07-05 07:01:51 -07:00
xzuyn
91ab594744 minor fix, some changes, some additions, and cleaning up (#1618)
* - fix overflowing text (a janky fix)
- add DEISMultistep scheduler as an option
- set default scheduler to DEISMultistep
- set default CFG to 3.5
- set default steps to 16
- add `xzuyn/PhotoMerge` as a model option
- add 3 new example prompts (which work nicely with PhotoMerge)
- formatting

* Set DEISMultistep in the cpu_only list instead

* formatting

* formatting

* modify prompts

* resize window to 81% & 85% monitor resolution instead of (WxH / 1.0625).

* increase steps to 32 after some testing. Somewhere between 16 and 32 is the best compromise on speed/quality for DEIS, so 32 steps plays it safe.

* black line length 79

* revert setting DEIS as default scheduler.

* add more schedulers & revert accidental DDIM change
- add DPMSolverSingleStep, KDPM2AncestralDiscrete, & HeunDiscrete.
- did not add `DPMSolverMultistepInverse` or `DDIMInverse` as they only output latent noise; there are a few I did not try adding yet.
- accidentally set `upscaler_ui.py` to EulerDiscrete by default last commit while reverting DEIS changes.
- add `xzuyn/PhotoMerge-inpainting` as an in or out painting model.

* black line length 79

* add help section stuff and some other changes
- list the rest of the schedulers in argument help section.
- replace mutable default arguments.
- increased default window height to 91% to remove any scrolling for the main txt2img page (tested on a 1920x1080 monitor). Width is the same, as it's just enough to have the image output on the side instead of the bottom.
- cleanup
2023-07-04 18:51:23 -07:00
Eliasj42
4015793f84 changed method of compiling vicuna to remove first and second vicuna (#1611)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-07-03 12:12:43 -07:00
Ean Garvey
d63ce76dd8 Use sortable image filenames for SD outputs. (#1528) 2023-07-03 10:30:47 -07:00
Prashant Kumar
1c32915570 Add the shark compile downstream due to https://github.com/pytorch/pytorch/pull/104185#issuecomment-1615110613 (#1615) 2023-07-01 08:30:58 -07:00
Ean Garvey
6d286c0609 Enable tuning for rectangle sizes on rdna2. (#1608) 2023-06-30 22:28:24 -07:00
Stefan Kapusniak
7392b22731 UI/Web Reduce animation of default --progress_bars (#1613) 2023-06-30 21:12:10 -07:00
jinchen62
534de05791 Update precision check for vicuna (#1610) 2023-06-29 16:16:33 -05:00
Daniel Garvey
5779e8c039 int4/int8 vicuna download support (#1609)
* set task_topology_max_group to cpu_count

by default. Can be overridden with a flag of the same name.

* add download for int4/int8 mlir
2023-06-29 13:35:51 -07:00
Abhishek Varma
d496053590 [SHARK] Add a compile API to use for quick testing of inference (#1606) 2023-06-28 08:40:28 -07:00
gpetters94
6274a813c9 Add unet512 support for the other StableDiffusion pipelines (#1602) 2023-06-27 12:28:57 -07:00
Gaurav Shukla
1d6a1f9f8a [vicuna] Add tokens streaming(step=3) (#1600)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-27 08:59:27 -07:00
Daniel Garvey
75672c0e28 set task_topology_max_group to cpu_count (#1594)
by default. Can be overridden with a flag of the same name.
2023-06-26 14:54:06 -07:00
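A small sketch of the default described above: the task topology group count falls back to the host CPU count unless the user supplies a value. The exact IREE runtime flag spelling used here is an assumption, not quoted from the diff.

```python
# Illustrative only -- the flag spelling is an assumption.
import multiprocessing


def task_topology_flags(user_value=None):
    groups = user_value if user_value is not None else multiprocessing.cpu_count()
    return [f"--task_topology_max_group_count={groups}"]
```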
Prashant Kumar
74a7202173 Make the tensors contiguous. 2023-06-26 17:29:54 +05:30
Prashant Kumar
27a08735db Add the shark backend for torch.compile API. (#1596) 2023-06-26 03:53:32 -07:00
Stefan Kapusniak
eaa49cce17 UI/App - Allow text selection (#1593)
* When run in app mode on windows, allows selection of text from
non-input controls, which is the same behaviour as web mode.
2023-06-26 02:16:53 -07:00
powderluv
10657d6fb1 Disable upx 2023-06-25 07:28:52 -07:00
Stefan Kapusniak
e3ab844cd1 Fix output gallery for csv format inc. VAE & LoRA (#1591) 2023-06-24 06:20:53 -07:00
powderluv
5ce6001b41 Update stablelm_ui.py to default to fp16 2023-06-23 22:55:47 -07:00
powderluv
501d0ca52e Add sentencepiece to webui for pyinstaller 2023-06-23 22:52:06 -07:00
powderluv
b444528715 Pin torch-mlir for windows too 2023-06-23 19:19:28 -07:00
Ean Garvey
6e6c90f62b Pin torch-mlir and use local-task in OPT. (#1592) 2023-06-23 19:17:05 -07:00
AyaanShah2204
8cdb38496e Final REST API Fixes (#1590)
* fixed outpaint api and added tests

* fixed text2img api

* more elegant generator to subscriptable conversion

* final fixes
2023-06-23 16:46:47 -07:00
powderluv
726d73d6ba Revert "[vicuna] Add streaming of tokens (#1587)" (#1588)
This reverts commit 4d55e51d46.
2023-06-23 10:29:00 -07:00
Gaurav Shukla
4d55e51d46 [vicuna] Add streaming of tokens (#1587)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-23 08:20:46 -07:00
Prashant Kumar
6ef78ee7ba Add cpu compile time flags. (#1585) 2023-06-23 07:23:26 -07:00
jinchen62
4002da7161 Add int4/int8 options to chatbot webui (#1586) 2023-06-23 07:18:34 -07:00
powderluv
ecb5e8e5d8 Update txt2img_ui.py 2023-06-23 06:42:12 -07:00
PhaneeshB
28e0919321 Add AMD cpu device 2023-06-23 18:47:04 +05:30
Daniel Garvey
28f4d44a6b downloader was double downloading (#1580) 2023-06-22 18:30:27 -07:00
AyaanShah2204
97f7e79391 [Blender Integration] Fixed Inpainting REST API (#1577)
* fixed inpaint api

* added inpainting test

* fixed linter errors

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-22 16:08:26 -07:00
Nelson Sharpe
44a8f2f8db Include VAE & LoRA data into PNG metadata (#1573)
* include custom lora and vae data in png metadata

* include pycharm settings

* lint with black
2023-06-22 16:05:54 -07:00
Eliasj42
8822b9acd7 added ability to use config file to shard vicuna (#1565)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-06-22 17:40:35 -05:00
Daniel Garvey
0ca3b9fce3 fix some mmap and vicuna bugs (#1576) 2023-06-22 17:39:55 -05:00
Nithin Meganathan
045f2bb147 Add dispatch-level config file generator for manual annotation (#1566) 2023-06-22 15:11:41 -07:00
Prashant Kumar
a811b867b9 Add shark_eager mode.
-- Eager mode with step by step op compilation and execution.
2023-06-22 22:59:14 +05:30
Abhishek Varma
cdd505e2dd [SharkInference-SharkRuntime] Adds capability to mmap vmfbs
-- This commit is based on [VmModule.mmap() API](https://github.com/openxla/iree/pull/14124).
-- It thereby adds capability to mmap vmfbs in SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-22 20:43:40 +05:30
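A minimal sketch of loading a compiled .vmfb through the `VmModule.mmap()` API referenced above; the driver name and the surrounding context setup are simplified assumptions.

```python
# Illustrative only -- driver choice and context wiring are assumptions.
import iree.runtime as ireert


def load_mmapped_vmfb(vmfb_path, driver="local-task"):
    config = ireert.Config(driver)
    # Memory-map the vmfb instead of reading it fully into memory.
    vm_module = ireert.VmModule.mmap(config.vm_instance, vmfb_path)
    return ireert.SystemContext(vm_modules=[vm_module], config=config)
```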
powderluv
1b0f39107c Move torch_mlir import to the top (#1574) 2023-06-21 22:31:35 -07:00
powderluv
b9b8955f74 exclude vulkan on macos 2023-06-21 22:22:27 -07:00
powderluv
6f7a85eee3 switch to metal backend for CI 2023-06-21 22:17:11 -07:00
Ranvir Singh Virk
18c8e9e51e Metal typo fix (#1572)
* fixing typos for metal changes

* black formatting
2023-06-21 21:56:11 -07:00
Daniel Garvey
a202bb466a fp16 fixes for webui (#1571) 2023-06-21 20:24:02 -07:00
Ranvir Singh Virk
07c1e1d712 Adding metal_utils for iree_utils (#1561)
* Adding metal_utils for iree_utils

* Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559)

-- It also modifies the mega_test.py script

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [SD] Update unet in_channels API and add PIL metadata to spec. (#1560)

* Fix deprecation warning for unet config.

* Include PIL metadata instead of hidden imports in SD spec.

* Fixing iree-metal-target-platform

* adding metal to txt2img pipeline

* Fixing Copyright date

* removing debug prints

* black lint formatting

* fixing device dump

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <avarma094@gmail.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-21 19:09:03 -07:00
Ranvir Singh Virk
18daec78c8 Added check for python version (#1570)
* Added check for python version

* Update for PYTHON_VERSION_X_Y
2023-06-21 18:56:47 -07:00
Ean Garvey
1a8e2024d6 Exclude non-square sizes from use_tuned on rdna2 (#1568) 2023-06-21 11:36:55 -05:00
AyaanShah2204
d61b6641fb Rest API: Resolved Generator Object not Subscripatable error (#1556) 2023-06-20 19:27:41 -07:00
Phaneesh Barwaria
88cc2423cc Enable Vicuna fp16 cpu (#1562)
* fix second vic mlir gen

* fp16 mlir/vmfb download from shark_tank
2023-06-20 13:43:21 -05:00
Ean Garvey
ccf944c1bd Enable tuner for upscaler unet. (#1563) 2023-06-20 13:40:13 -05:00
Ean Garvey
0def74f520 [SD] Update unet in_channels API and add PIL metadata to spec. (#1560)
* Fix deprecation warning for unet config.

* Include PIL metadata instead of hidden imports in SD spec.
2023-06-20 10:26:36 -07:00
Abhishek Varma
3fb72e192e Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559)
-- It also modifies the mega_test.py script

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-20 10:04:17 -07:00
Vivek Khandelwal
855435ee24 Fix for the user input for Falcon pipeline 2023-06-20 18:09:32 +05:30
Elias Joseph
6f9f868fc0 fixed a bug where designating device for vicuna didn't work 2023-06-20 17:09:32 +05:30
powderluv
fb865f1b99 Move to checkout@v3
This will break Windows again but we have to fix it up since the old node.js is now deprecated.
2023-06-19 18:44:36 -07:00
rprasad2
3e5c50f07b changes for tuning (#1542)
* Add tuning sizes for rdna3
2023-06-19 15:29:08 -05:00
powderluv
a544f30a8f Move mega to the shark examples (#1555) 2023-06-19 11:10:51 -07:00
Abhishek Varma
1fe56d460a [MEGABYTE] Add script to compile MEGABYTE through SHARK (#1553)
-- Usage: `python mega_test.py`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-06-19 11:00:35 -07:00
Vivek Khandelwal
fafd713141 Minor change to falcon pipeline 2023-06-19 22:36:32 +05:30
Vivek Khandelwal
015d0132c3 Modify falcon pipeline to add fp16 support (#1551) 2023-06-19 09:57:13 -07:00
powderluv
20ddd96ef7 unpin diffusers (#1550) 2023-06-18 13:45:55 -07:00
powderluv
ee33cfd2d1 Add PIL in main index.py (#1549)
* Add PIL in main index.py

This is to ensure pyinstaller picks it up

* Update index.py
2023-06-18 11:51:44 -07:00
Stefan Kapusniak
a3cba21d5b Fix load of unet512 vmfb fail on get of iree opts (#1546)
* Change retrieval of Iree options used when loading an existing
unet512 vmfb to look up the "unet" options rather than attempt to
find a non-existent set of options for "unet512"

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-18 06:42:20 -07:00
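The fix above amounts to a key fallback when fetching per-model IREE options; a tiny illustrative version follows (the option table and key names are assumed, not the project's actual code).

```python
# Illustrative only: fall back to the base "unet" options when no entry
# exists for "unet512" (or another unet variant).
def get_iree_opts(model_key, opts_by_model):
    if model_key not in opts_by_model and model_key.startswith("unet"):
        model_key = "unet"
    return opts_by_model.get(model_key, [])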
Stefan Kapusniak
a7b6ec4095 Fix unet512 always being used when --max_length=77 (#1547)
* Switches a few places in the SD pipeline where an assumption of
max_length=64 was being made, to using the actual max_length
as passed into the pipeline. This prevents unet512 always being
used and producing different images than previously when
--max_length=77
2023-06-18 06:41:25 -07:00
Ean Garvey
d80b087d95 Add PIL hidden imports to sd spec. (#1544)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-18 06:39:08 -07:00
Stefan Kapusniak
297a209608 Remove workarounds for gradio tempfile bugs (#1548) 2023-06-17 19:50:36 -07:00
gpetters94
b204113563 Add UNet512 (#1504)
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-06-17 03:46:25 -04:00
Chi_Liu
f60ab1f4fa Add Deberta to stablehlo in shark tank (#1545) 2023-06-16 13:24:44 -07:00
Surya Jasper
b203779462 Added Adreno target triples to vulkan_utils (#1543) 2023-06-15 16:42:59 -07:00
Stefan Kapusniak
38570a9bbb Some Fixes for update to gradio 3.34.0 (#1538)
* Fixes randomize seed buttons that stopped working.
* Update now-deprecated method to set initial columns for output
gallery to the newer undeprecated one.
2023-06-15 01:10:36 -07:00
dependabot[bot]
a5c882f296 Bump gradio from 3.15.0 to 3.34.0 (#1518)
Bumps [gradio](https://github.com/gradio-app/gradio) from 3.15.0 to 3.34.0.
- [Release notes](https://github.com/gradio-app/gradio/releases)
- [Changelog](https://github.com/gradio-app/gradio/blob/main/CHANGELOG.md)
- [Commits](https://github.com/gradio-app/gradio/compare/v3.15.0...v3.34.0)

---
updated-dependencies:
- dependency-name: gradio
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-14 18:13:48 -07:00
Ean Garvey
eb6d11cfed Change mlir dialects for tf tests to stablehlo. (#1535)
* Change mlir dialects for tf tests to stablehlo

* Update shark_runner.py
2023-06-14 10:43:49 -07:00
Vivek Khandelwal
46184a81ac Add Falcon pipeline (#1534) 2023-06-14 09:39:16 -07:00
PhaneeshB
149165a2f0 add multi-device multi-precision vmfb names 2023-06-14 22:08:24 +05:30
dan
bec82a665f mega vicuna merge
single endpoint in apps/language/models/scripts/vicuna.py
removed main functions from pipelines
replaced divergent utils compile with shark_importer
adds support for different precisions
2023-06-14 19:06:29 +05:30
Ean Garvey
9551490341 Remove deprecated --iree-mhlo-demote-i64-to-i32 flag usage. (#1533) 2023-06-13 22:40:47 -05:00
Ean Garvey
49b3ecdbca (pytest) don't run redundant tests in cpu suite (#1532) 2023-06-13 22:40:33 -05:00
Ean Garvey
f53e3594c3 OPT Refactor (#1516)
* Change script to 1.3b model and add pytorch comparison

* fix CLI command

* Match OPT transformers model updates + numerics against latest version

* Cleanup OPT sentence completion script.

* Fix formatting and add standalone validation scripts.

* Add minimal OPT wrapper and example with import_with_fx

* Rename OPT full model wrapper.

* Cleanup test scripts for OPT.
2023-06-13 22:40:07 -05:00
Ean Garvey
5562d1dfda Fix xfails for cpu pytest cases (#1527)
Adding cpu-sync and cpu-task device configs was allowing respective tests to bypass the xfail conditional for cpu pytests marked in tank/all_models.csv. This commit updates the conditional to xfail those cases for cpu-sync and cpu-task as well.
2023-06-13 17:01:51 -07:00
Stefan Kapusniak
c7b0c2961e UI/Web Improve output gallery temp file handling (#1531)
* On startup report that cleaning up of temp files is taking place, in
case it takes a long time.
* Have the output gallery tab delete any zero length temporary files
generated by gradio < 3.32.0 for its gallery control whenever it
needs to update that control with images. This prevents such
files from multiplying out of control.
2023-06-13 16:25:37 -05:00
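A hedged sketch of the zero-length temp file cleanup described above; the directory being scanned is an assumption (the real code targets gradio's gallery temp files).

```python
# Illustrative only -- the temp directory name is an assumption.
import os


def clean_empty_tmp_files(tmp_dir="shark_tmp"):
    for root, _, files in os.walk(tmp_dir):
        for name in files:
            path = os.path.join(root, name)
            if os.path.getsize(path) == 0:  # zero-length leftovers only
                os.remove(path)
```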
Ean Garvey
44273b0791 Fix conditional in transform_fx() (#1530) 2023-06-13 16:24:53 -05:00
Prashant Kumar
0a4c8fcb3e Minor changes in the fx transforms. 2023-06-13 21:23:35 +05:30
Stefan Kapusniak
2fec3c8169 re-indents add_upcast in shark importer (#1523)
* The two with blocks in add_upcast appear to be under-indented, making
SD 1.4 break on rdna3. I've pushed them out one more tab, and then
everything appears to work again.
2023-06-12 14:41:10 -05:00
Gaurav Shukla
5e7d5930dd [vicuna] Add device and precision propagation in vicuna (#1520)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-12 12:14:43 -05:00
Prashant Kumar
b6dbd20250 Modify the fx transforms. (#1521)
- The bounds are set properly.
- The upcasting and downcasting are done for vicuna.
2023-06-12 09:40:14 -07:00
Nithin Meganathan
34f1295349 Add a model config generator (#1511)
Model config generator takes a PyTorch model as input and generates a JSON file with model layers and other properties that define sharding on particular hardware.
2023-06-09 15:32:00 -07:00
Phaneesh Barwaria
1980d7b2c3 Cpu device map (#1515)
* update cpu iree device

* fix vmfb paths vic unsharded
2023-06-09 11:27:02 -05:00
powderluv
2cfacc5051 fix osx torch_mlir (#1513)
* fix osx torch_mlir

* Update index.py

* Update index.py
2023-06-09 00:57:26 -07:00
Phaneesh Barwaria
436f58ddc4 cli using generate and mem fixes (#1509) 2023-06-08 13:13:32 -05:00
Phaneesh Barwaria
6b29bd17c8 Enable compilation vicuna (#1507)
* add cli for unsharded vic

* enable mlir download and compile
2023-06-07 13:08:22 -07:00
Ean Garvey
2c3485ca3e Add standalone OPT sentence completion script. (#1506) 2023-06-07 10:58:03 -07:00
Daniel Garvey
f206ecc635 reenable compilation in vicuna pipeline, add flags (#1505)
* replace vicuna.py backend with pipeline

* add some memory management to first vicuna compile

reenable compilation
2023-06-07 09:49:27 -07:00
Stefan Kapusniak
a187e05ae6 Prevent having no cuda devices breaking the UI (#1503)
Don't break the UI when the LLM tab only wants cuda devices but there
aren't any.
2023-06-06 11:41:16 -07:00
Gaurav Shukla
8c21960486 [vicuna] Set only cuda devices in vicuna UI for now
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-06 22:15:20 +05:30
Gaurav Shukla
be62fce676 [vicuna] Fix vicuna chatbot (#1499)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-06 09:23:32 -07:00
PhaneeshB
f23b778a6c remove old vicuna scripts 2023-06-06 21:35:58 +05:30
PhaneeshB
436edf900d add vic sharded pipeline 2023-06-06 21:35:58 +05:30
Gaurav Shukla
ed58c2553f [vicuna] Integrate vicuna in shark studio
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-06 20:57:48 +05:30
Stefan Kapusniak
f2ca58e844 Add .csv and .json param info to output gallery (#1495) 2023-06-06 07:08:34 -07:00
Ean Garvey
1dbcc736eb [SD] (RDNA2) Enable new tuning for sd1.4 (#1498) 2023-06-06 06:48:58 -07:00
Phaneesh Barwaria
a83808ddc5 Vicuna cuda on A100 40G (#1496)
* vic chat with memory management (precompiled vmfb)

* fix vmfb path and download
2023-06-06 15:10:33 +05:30
Ean Garvey
a07fe80530 Update OPT, ResNet example scripts. (#1492)
* Update API in OPT example.

* fix resnet50 script

* Add OPT1.3b test script.
2023-06-05 20:19:35 -07:00
Ean Garvey
d0ba3ef8fa disable use_tuned on SD1.4 for rdna2 (#1490)
this is a temporary measure while we retune SD1.4 for rdna2. The current config fails during iree-compile.
2023-06-05 19:46:16 -05:00
Stefan Kapusniak
8400529c2c Fix output gallery not using shark_tmp (#1493)
This fixes the gallery component of the output gallery dumping temporary
files into the standard folders rather than shark_tmp, so those files never
got cleared out on restart and would build up.
2023-06-05 16:23:49 -05:00
powderluv
7eaee9c242 update SHARK to nodai SHARK 2023-06-05 00:44:49 -07:00
powderluv
8230eebce5 Switch to CPU torch builds for shark.whl 2023-06-05 00:36:03 -07:00
Ean Garvey
6296ea4be9 fix config handling for sd1.4 on rdna2 (#1489) 2023-06-05 00:02:30 -07:00
Ean Garvey
4151ec3a8f (pytest) tag efficientnet, mobilenet as xfails on vulkan (#1488) 2023-06-04 23:22:32 -07:00
powderluv
a2467e8d43 Enable SHARK whl packages 2023-06-04 23:21:22 -07:00
Ean Garvey
e677178bcc Replace RDNA2 SD lowering configs. (#1486) 2023-06-05 00:57:43 -05:00
Anush Elangovan
7ef1bea953 XFAIL some macos tests 2023-06-04 15:27:03 -07:00
Chi_Liu
ad89bb1413 Add distilgpt2 to stablehlo in shark tank (#1481) 2023-06-02 16:44:46 -05:00
Ean Garvey
218ed78c40 Change instances of input_type='mhlo' to 'auto' (#1482) 2023-06-02 16:43:47 -05:00
Stefan Kapusniak
6046f36ab6 UI/Web: Fix upscaler stop button (mostly) (#1479)
* UI/Web: Fix upscaler stop button

* Hook the cancel_sd function up to the Stop button.
* Adds checks for SD_STATE_CANCEL in the upscaler ui inference function.
* Set and check for SD_STATE_IDLE, SD_STATE_CANCEL in the upscaler
pipeline.

* UI/Web: lint fixes for upscaler stop button fix

---------

Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-06-01 22:26:55 -07:00
Foxlum
5915bf7de3 Add to and tweak vulkan configuration environments. (#1475)
* Update vulkan_target_env_utils.py

* Update vulkan_target_env_utils.py

Adjust target environment capabilities.

* Update vulkan_target_env_utils.py

black linted?
2023-06-01 22:25:20 -07:00
Phaneesh Barwaria
f0a4e59758 LLM Pipeline Wrapper (#1477)
* [LLM] Add LLM pipeline

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

* add base pipeline and stableLM

* StableLM on UI - full block

* add SLM default model name

* add vicuna with pipeline

* add one token gen api for vic

* Fix stableLM bugs

* debug vic memory

* lint fix

---------

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Co-authored-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-31 10:17:20 -07:00
Stefan Kapusniak
1ddef26af5 Web/UI: Add an Output Gallery tab for SD (#1470)
* WebUI: Adds an Output Gallery tab

Adds an new Output Gallery tab to the ui/webui with these features:

* Subdirectory select dropdown listing subdirectories at any depth below
the <output_dir>/generated_imgs directory,
* Large, full height, gallery area displaying the images in the selected
subdirectory. Shows nod logo when no images are in the selected
subdirectory.
* Slider that changes the number of columns of images that the gallery
displays from between 1 to 16 columns (defaults to 4).
* Expandable parameter info panel showing any generation parameters
saved in the file of the selected image for PNGs, alternatively the
image's EXIF data for JPEGs
* Send to buttons for txt2img, img2img, inpaint, outpaint and upscaler.
* Auto update of gallery and gallery label (to show generation status),
when a new image is generated by any of the stable diffusion tabs, and
is outputted to the currently selected subdirectory.
* Command line option for enabling and disabling the output gallery
(defaults to enabled)
* Command line option for following symlinks when getting entries
for the subdirectory list (defaults to off, as Python os.walk doesn't
check for circular references if following symlinks)

* Reformat with black

Reformat changes with black, then adjust some places where black's
formatting needed some rephrasing of the code to make things
clearer.

* Add back transformers and sd_cancel imports

Adds back the transformers import in index.py needed for .exe
generation. Add comment so it doesn't get mistakenly removed
next time.
Adds back sd_cancel import in upscaler.py that is currently unused
but should be used for the 'Stop' button.
2023-05-30 13:47:48 -07:00
Chi_Liu
ba8eddb12f Add GPT3/OPT to Stablehlo in shark tank (#1468)
Co-authored-by: AmosLewis <Amos_Lewsi@foxmail.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-05-29 21:58:39 -07:00
yzhang93
47b346d428 Modify the lowering config format for SPIRVMatmulPromoteVectorize pipeline (#1471) 2023-05-29 21:53:48 -07:00
Ean Garvey
1b4f4f5f4d Fix download path for SD1.4 Unet. (#1469) 2023-05-26 11:59:51 -07:00
Elias Joseph
73cd7e8320 added full vicuna to vicuna.py 2023-05-26 22:06:40 +05:30
Ean Garvey
19c0ae3702 Cleanup SD pipeline utils (#1466) 2023-05-25 12:50:11 -05:00
Ean Garvey
54e57f7771 Revive SD downloads from shark_tank. (#1465) 2023-05-25 12:03:21 -05:00
PhaneeshB
6d64b8e273 vic and slm common generation base 2023-05-25 20:29:41 +05:30
PhaneeshB
a8ea0326f5 correct SLM saved vmfb naming 2023-05-25 20:29:41 +05:30
PhaneeshB
58e9194553 add Lists import 2023-05-25 20:29:41 +05:30
PhaneeshB
eb360e255d remove unused imports 2023-05-25 20:29:41 +05:30
PhaneeshB
a6f88d7f72 refactor mlir compile 2023-05-25 20:29:41 +05:30
Prashant Kumar
8e571d165f Enable cpu f16 dtype tracing for the vicuna model. (#1461) 2023-05-24 09:37:57 -07:00
Ean Garvey
3cddd01b10 Update OPT tokenizer and xfail a few more large tests on macos CI (#1459)
* Update opt_torch_test.py

* Update all_models.csv
2023-05-23 14:36:57 -07:00
Chi_Liu
64c2b2d96b Add gpt2 to stablehlo support in shark tank (#1447)
- Add torch decomposition support when generating shark tank
- Add gpt2 stablehlo
2023-05-22 10:45:51 -07:00
Phaneesh Barwaria
f5ce121988 SLM on Sharkstudio (#1454)
* localize import, fix file reading, device cpu

* extract out model args
2023-05-19 11:21:08 -07:00
Ean Garvey
991f144598 Add iree hidden imports to SD spec (#1456)
* Add iree hidden imports to SD spec

* Update shark_sd_cli.spec
2023-05-19 11:19:16 -07:00
PhaneeshB
09bea17e59 fix #2 SLM in SharkStudio 2023-05-18 00:56:22 +05:30
Daniel Garvey
aefcf80b48 swap to cpu and remove hardcoded paths (#1448)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-05-17 10:53:34 -07:00
PhaneeshB
512235892e fix SLM for SharkStudio 2023-05-17 22:34:30 +05:30
PhaneeshB
6602a2f5ba add continuous output for CLI 2023-05-17 18:33:46 +05:30
Boian Petkantchin
20114deea0 In MiniLM JAX example verify MLIR result against JAX 2023-05-16 09:54:07 -07:00
Boian Petkantchin
9acf519078 Add option to skip venv creation in setup script 2023-05-16 09:54:07 -07:00
Boian Petkantchin
bdf37b5311 If device/backend is unknown pass it to IREE verbatim 2023-05-16 09:54:07 -07:00
powderluv
8ee2ac89f8 Rename sharded_vicuna_fp32_web.py to vicuna_web.py 2023-05-16 09:41:35 -07:00
powderluv
60cb48be2e Rename sharded_vicuna_fp32.py to vicuna.py 2023-05-16 09:40:51 -07:00
powderluv
86a215b063 Delete sharded_vicunia.py 2023-05-16 09:37:39 -07:00
powderluv
d6e3a9a236 Delete standalone_vicuna.py 2023-05-16 09:37:26 -07:00
Chi_Liu
a0097a1ead Add mlir_type for torch_model_list.csv (#1428)
- Enable stablehlo/tosa mlir output for torch model
- Add BERT stablehlo support
2023-05-15 10:23:54 -07:00
Ean Garvey
a9bae00606 Fix vulkan device selection at compile time and adapt to IREE python changes. (#1407)
* Add support for vulkan device selection at compile time.

* Don't convert device ID to int and fix .exe imports
2023-05-12 23:31:50 -07:00
Daniel Garvey
4731c1a835 prevent loading tokenizer on import (#1432)
also adds sentencepiece dep for exe
moved vicuna imports to after an if statement
in general we should avoid importing files that load whole models as
global variables
2023-05-12 19:11:45 -07:00
Ean Garvey
4c07e47e8c Specify a few models for expected failure on CUDA CI. (#1430) 2023-05-12 17:03:37 -05:00
Gaurav Shukla
e0cc2871bb [SD] Yield 2 tokens at a time in vicuna
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 23:49:01 +05:30
Gaurav Shukla
649f39408b [SD] Fix vicuna response
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 18:06:21 +05:30
Gaurav Shukla
c142297d73 [SD] Fix gradio to 3.22.0 version
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 18:05:55 +05:30
Gaurav Shukla
9e07360b00 [SD] Standalone vicuna with web
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 17:23:44 +05:30
Gaurav Shukla
7b74c86e42 [SD] Fix SAMPLE_INPUT_LEN import issue
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-11 15:41:43 +05:30
Eliasj42
fa833f8366 fixed spacing issue with chat-bot (#1417)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-10 16:07:50 -07:00
Gaurav Shukla
fcb059aa38 [SD] Integrate vicuna in the web (#1410) 2023-05-10 11:30:22 -07:00
PhaneeshB
517c670f82 vicuna chat cli 2023-05-10 22:55:06 +05:30
Eliasj42
59df14f18b added vicuna demo (#1408)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-09 21:18:20 -07:00
Ean Garvey
6c95ac0f37 Revert dialect registration in model annotator (#1406)
Matches https://github.com/nod-ai/SHARK-Runtime/pull/58
2023-05-09 11:50:19 -07:00
Daniel Garvey
7a4a51ae73 vulkan vic f16 (#1404)
Co-authored-by: dan <dan@nod-labs.com>
2023-05-08 16:46:53 -07:00
powderluv
d816cc015e Revert "added standalone vicuna script (#1399)" (#1402)
This reverts commit 0e4a8ca240.
2023-05-05 16:08:05 -07:00
Eliasj42
54ce3d48ca added standalone vicuna script (#1401)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-05 18:05:52 -05:00
Eliasj42
0e4a8ca240 added standalone vicuna script (#1399)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-05-05 15:46:05 -07:00
Daniel Garvey
6ca1298675 maximizes window size for webview launch (#1394) 2023-05-04 20:43:06 -07:00
jinchen62
bbef7a6464 Redesign model manager webui (#1391) 2023-05-04 20:41:29 -07:00
Ean Garvey
cdf2d61d53 Remove imports from iree.compiler.transforms from model annotator. (#1392) 2023-05-04 20:40:19 -07:00
Ean Garvey
6c14847d1f xfail some large tests on macOS builder and switch to hash updates. (#1341)
* Update test-models.yml

* Disable large tests on macOS builder
2023-05-04 19:47:03 -05:00
Gaurav Shukla
68ecdd2a73 [SD] Add LoRA as experimental tab
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-04 22:30:25 +05:30
Gaurav Shukla
3f4d444d18 [SD] Fix stable LM chatbot
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-05-04 22:30:25 +05:30
m68k-fr
e473d0375b [Web] Models folders cleanup (#1365) 2023-05-03 16:13:20 -05:00
Ean Garvey
e38d96850f Fix input image loading in img2img rest API (#1388) 2023-05-03 15:51:00 -05:00
Gaurav Shukla
fed63dfd4b [SD] Add stableLM chatbot (#1383)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-05-03 15:37:20 -05:00
Boian Petkantchin
eba4d06405 In MiniLM JAX example do not hardcode device (#1385)
* In MiniLM JAX example do not hardcode device

* In MiniLM JAX example don't use bytecode MLIR

---------

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-05-03 10:34:42 -07:00
Boian Petkantchin
4cfba153d2 Add example JAX MiniLM inference (#1380)
* Do not hardcode the name of the VM module in get_iree_module

* Add example JAX MiniLM inference

---------

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-05-02 15:03:54 -07:00
jinchen62
307c05f38d Convert original vae to diffusers (#1382) 2023-05-02 01:27:28 -07:00
jinchen62
696df349cb Fix curl issue (#1369) 2023-04-28 09:31:14 -07:00
jinchen62
cb54cb1348 Add model manager tab for SD webui (#1368) 2023-04-28 02:43:40 -07:00
Daniel Garvey
9bdb86637d add tkinter launch for webui (#1364) 2023-04-27 19:17:55 -05:00
jinchen62
fb6f26517f Fix webui note (#1367) 2023-04-27 16:14:43 -07:00
Chi_Liu
aa8ada9da9 Add support for torch to stablehlo and tosa in shark_importer (#1360) 2023-04-27 08:09:45 -07:00
powderluv
1db906a373 Revert "Add model manager tab for webui (#1359)" (#1362)
This reverts commit 9d1d1617d8.
2023-04-26 22:25:26 -07:00
jinchen62
9d1d1617d8 Add model manager tab for webui (#1359) 2023-04-26 13:38:18 -07:00
jinchen62
7112789cb8 Add support of using civitai model download url (#1357) 2023-04-25 23:39:52 -07:00
jinchen62
d6b8be2849 Add drawing canvas for img2img stencil scribble (#1355) 2023-04-25 14:41:01 -07:00
powderluv
822171277c Revert "[SD] Add FastChat as part of SD WebUI (#1349)" (#1350)
This reverts commit a5ae9d9f02.
2023-04-24 15:22:25 -07:00
Abhishek Varma
a5ae9d9f02 [SD] Add FastChat as part of SD WebUI (#1349)
-- This commit includes FastChat as part of SD WebUI.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-04-24 11:12:58 -07:00
powderluv
09e3f63d5b Fix pascal (#1346)
* Add fp32 for upscaler VAE

* Plumb Pascal vulkan support
2023-04-23 20:28:25 -07:00
powderluv
d60a5a9396 Add fp32 for upscaler VAE (#1345) 2023-04-23 15:27:55 -07:00
m68k-fr
90df0ee365 [Web] Gallery set to a 768px reference for high-end desktop users (#1344) 2023-04-23 11:48:06 -07:00
nirvedhmeshram
133c1bcadd add device to scheduler model names (#1338) 2023-04-22 20:13:56 -05:00
powderluv
caadbe14e9 Revert VAE to use im2col (#1339) 2023-04-22 15:23:41 -07:00
Ean Garvey
5f5823ccd9 Fix inference object imports for SD apps. (#1334) 2023-04-21 13:40:48 -05:00
Vivek Khandelwal
d2f7e03b7e Add StableLM model (#1331) 2023-04-21 09:51:02 -07:00
Gaurav Shukla
0b01bbe479 [SD] Add txt2img/upscaler/inpaint/outpaint Rest API (#1325)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-04-21 09:06:06 -07:00
yzhang93
25c5fc44ae Modify tuner.py to take vulkan target triple flag (#1328) 2023-04-20 14:31:32 -07:00
Daniel Garvey
7330729c92 enable sd pytest (#1322) 2023-04-19 22:11:30 -05:00
Ean Garvey
ce16cd5431 Create local shark_tank if needed for tuning configs. (#1321)
Now that --clear_all successfully deletes local shark_tank cache, we need to make sure it exists before trying to use it.
2023-04-19 11:44:21 -05:00
Ean Garvey
598dc5f79d Don't dump image data on img2img api call. (#1320) 2023-04-19 21:24:46 +05:30
Abhishek Varma
1f8e332cbe [SD] Fix img2img API bug for custom_vae argument (#1319)
-- https://github.com/nod-ai/SHARK/pull/1314 fails to add the `custom_vae`
   parameter to img2img_if's invocation within img2img_api.
-- This commit adds a fix to the same.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-04-19 10:39:52 -05:00
Abhishek Varma
17b9632659 [SD] Adapted SHARK's v1 img2img API for SdPaint + updated Stencil model ID (#1318) 2023-04-19 06:29:36 -07:00
jinchen62
bda92a54ab Fix custom vae path (#1317) 2023-04-18 20:50:43 -07:00
jinchen62
747ed383b1 Add custom vae dropdown in webui (#1314) 2023-04-18 17:24:02 -07:00
Ean Garvey
1afe07c296 Disable winograd on VAE with rdna2 and fix unet tuning. (#1313)
* Disable winograd on VAE with rdna2 and fix unet tuning.

* Fix batch size 1 downloads and clear_all on windows.
2023-04-18 15:55:10 -05:00
jinchen62
b70919b38d Fix memory leak with ondemand (#1312)
support ondemand for outpainting and multi batch_count
2023-04-18 13:03:16 -05:00
m68k-fr
4e513d647f Update list of scheduler available for inferences (#1298) 2023-04-17 22:37:00 -05:00
jinchen62
94cd2a0fed Fix outpainting config (#1310) 2023-04-17 10:48:52 -07:00
Kyle Herndon
606029c01c Fix LoRA device format bug and allow LoRA to resume from a previous training 2023-04-17 13:19:46 +05:30
powderluv
1aa85222e9 Add AMD W7900 target triple (#1304)
This maps to RDNA3
2023-04-16 00:14:21 -07:00
m68k-fr
1b3f468c04 [Web] Style Fixes for Gradio V3.25.0 (#1300) 2023-04-13 18:40:42 -05:00
m68k-fr
35de7e27fa [Web] remove txt2img ui dependencies from png import metadata (#1275) 2023-04-12 07:32:47 -10:00
yzhang93
467f900759 Add auto-tuner to SD apps (#1291) 2023-04-12 09:21:17 -07:00
Ean Garvey
0bd9d582c7 Add documentation for using SHARK with AI-Render (#1296) 2023-04-12 03:09:34 -10:00
jinchen62
428cfe8dae Fix low vram mode issues (#1295)
- add ondemand back to img2img
- workaround memory leak for batch count
2023-04-11 17:59:09 -07:00
Ean Garvey
f17915bedc Fix batch size appending to model name. (#1294)
* Update shark_downloader.py

* Update shark_downloader.py
2023-04-11 15:34:25 -05:00
Gaurav Shukla
1b49b5149a [SD] Add Img2Img rest API
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-04-11 23:06:58 +05:30
jinchen62
3002793301 Unload clip on demand and workaround memory leak (#1283) 2023-04-10 16:59:03 -07:00
Phaneesh Barwaria
d25ef5529f Add fix for vae fp32 Upscaler (#1284)
- fixes size mismatch error for upscaler vae
2023-04-07 14:36:40 -05:00
Ean Garvey
308856a947 Touch unet if base cfg needed for SD pipeline init (#1281) 2023-04-05 03:02:29 -05:00
m68k-fr
151b4e142f [SD] Fix encoder error for model_max_length not being 77 (#1278)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-04-04 22:39:29 -07:00
Ean Garvey
e5a69a7c36 pin diffusers to e47459c (#1279) 2023-04-04 18:29:21 -07:00
m68k-fr
450b6cafc4 [SD] Add weight emphasis to prompts encoder (#1276) 2023-04-04 09:47:04 -07:00
Daniel Garvey
237d26baa2 update model db to reflect changes (#1277)
* remove 1/1 tqdm progress bar

* update model_db to reflect changes
2023-04-04 11:46:55 -05:00
Daniel Garvey
67d6ee1104 remove 1/1 tqdm progress bar (#1274) 2023-04-03 22:30:09 -05:00
Ean Garvey
98b069488e Add tank_version.json (#1272) 2023-04-03 18:36:23 -07:00
jinchen62
e0f227643a Fix webui circular import issue (#1271) 2023-04-03 16:00:10 -07:00
jinchen62
a0af3bb0cb xload and unload models (#1242) 2023-04-03 14:42:18 -07:00
powderluv
2cd61a5b96 strip source map (#1270) 2023-04-03 14:41:32 -07:00
Gaurav Shukla
f49d41a807 [SD] Add Stable diffusion text2image rest API (#1265)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-04-03 12:02:24 -07:00
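For readers unfamiliar with the REST additions above, the sketch below shows roughly what a client call to such a text2image endpoint could look like. The route `/sdapi/v1/txt2img`, the port, and the payload fields are assumptions for illustration only, not SHARK's documented API.

```python
# Hypothetical client sketch for a txt2img REST endpoint; route and payload
# fields are assumptions, not the actual SHARK API surface.
import json
import urllib.request

def txt2img(prompt: str, host: str = "http://127.0.0.1:8080") -> bytes:
    payload = {"prompt": prompt, "steps": 20, "width": 512, "height": 512}
    req = urllib.request.Request(
        url=f"{host}/sdapi/v1/txt2img",  # assumed route name
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return resp.read()  # typically a JSON body carrying base64-encoded images

if __name__ == "__main__":
    print(len(txt2img("a photo of a crab playing a trumpet")))
```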
Ean Garvey
2191fc8952 Separate pytest benchmark modes and fix model updates for SHARK downloader / pytest. (#1264)
* Only xfail windows models in CI

* downloader: make model updates more robust.

* Separate baseline and native benchmarks in pytest.

* Fix native benchmarks

* Fix torchvision model utils.
2023-04-03 08:24:21 -07:00
PhaneeshB
aea7796e60 add gradio client to spec 2023-04-03 18:57:19 +05:30
Abhishek Varma
a376619f1e [SD] Improve vmfb caching algo and retry mechanism (#1248)
-- This commit gets rid of the all-or-nothing vmfb caching mechanism
   and improves the retry mechanism by providing lower-level granularity
   for compiling each model unit.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2023-03-31 09:38:14 -07:00
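As a rough illustration of the caching scheme described in this commit message, the sketch below caches and retries each model unit independently instead of rebuilding everything when a single compile fails. `compile_module`, `MODEL_UNITS`, and the cache layout are hypothetical stand-ins, not SHARK's actual code.

```python
# Minimal sketch of per-unit vmfb caching with a retry loop, assuming
# hypothetical helpers; the real pipeline differs in detail.
import os

MODEL_UNITS = ["clip", "unet", "vae"]

def compile_module(name: str) -> bytes:
    # Placeholder for the real IREE compile call; returns a fake flatbuffer.
    return b"\x00"

def load_or_compile(name: str, cache_dir: str, retries: int = 3) -> str:
    """Reuse a cached .vmfb if present; otherwise compile only this unit."""
    path = os.path.join(cache_dir, f"{name}.vmfb")
    if os.path.exists(path):
        return path  # cache hit: no all-or-nothing rebuild of every module
    for attempt in range(1, retries + 1):
        try:
            blob = compile_module(name)
            with open(path, "wb") as f:
                f.write(blob)
            return path
        except Exception as err:  # retry only the unit that failed
            if attempt == retries:
                raise
            print(f"compiling {name} failed ({err}), retry {attempt}/{retries}")

def build_pipeline(cache_dir: str = "./vmfb_cache") -> dict:
    os.makedirs(cache_dir, exist_ok=True)
    return {name: load_or_compile(name, cache_dir) for name in MODEL_UNITS}
```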
powderluv
02d52bb626 Add Intel ARC A770 target triple (#1263)
This just enables the plumbing. It generates black images.
2023-03-29 14:49:05 -07:00
Abhishek Varma
3b63645f79 [SD] Fix custom model path for WebUI (#1260)
-- This commit fixes custom model path for WebUI.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-29 09:48:11 -07:00
Ean Garvey
d6f740b998 allow pytest to retry getting model artifacts + disable autotuning for pytorch benchmarks (#1257)
* Adds a few xfails to enable macOS builder

* Convert string batch sizes to ints where needed.

* allow pytest to retry getting model artifacts

* Reduce attempts and add assert msg.
2023-03-28 23:38:45 -05:00
Daniel Garvey
594c6b8ea2 fix ckpt dir (#1258) 2023-03-28 14:31:01 -07:00
Ean Garvey
96b1560da5 Make batch size configurable via pytest and fix sharktank generation. (#1227)
* Fix sharktank generation and add batch_size pytest option for torch.

* Disable torch dynamo until py3.11 supported

* Compile torchmodel without dynamo if torch.compile fails

* Use release versions of TF/Keras for importer.

* Pin torchvision and remove debug prints.

* Remove duplicates from torch model list.

* Update generate_sharktank.py

* xfail a few models that fail sharktank generation/ numerics
2023-03-28 14:33:39 -05:00
Abhishek Varma
0ef6a0e234 [SD] Fix Stencil scribble crash by updating image resize (#1255)
-- This commit updates Stencil resize feature to cap the size of
   images within [128,768] as supported by the SD pipeline.
-- This solves the issue of scribble crashing on larger images.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-28 10:13:11 -07:00
Gaurav Shukla
641d535f44 [SD] Fix device path issue for cpu (#1256)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-28 10:09:49 -07:00
Daniel Garvey
5bb7846227 single entry point exe for all cli apps (#1158)
usage:
add --app="img2img" (or "inpaint" "outpaint" "txt2img")
2023-03-28 11:15:21 -05:00
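The sketch below illustrates, under stated assumptions, how a single entry point might dispatch on an `--app` flag while forwarding the remaining flags; the handler functions are placeholders rather than the real SHARK scripts.

```python
# Hedged sketch of a single CLI entry point dispatching on --app.
import argparse

def txt2img(argv):  print("running txt2img with", argv)
def img2img(argv):  print("running img2img with", argv)
def inpaint(argv):  print("running inpaint with", argv)
def outpaint(argv): print("running outpaint with", argv)

APPS = {"txt2img": txt2img, "img2img": img2img,
        "inpaint": inpaint, "outpaint": outpaint}

def main():
    parser = argparse.ArgumentParser(description="single SD CLI entry point")
    parser.add_argument("--app", choices=sorted(APPS), default="txt2img",
                        help='which pipeline to run, e.g. --app="img2img"')
    # Unrecognized flags are forwarded to the selected app untouched.
    args, passthrough = parser.parse_known_args()
    APPS[args.app](passthrough)

if __name__ == "__main__":
    main()
```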
yzhang93
8f84258fb8 Fix check for use_tuned conditions (#1252) 2023-03-27 11:21:25 -07:00
Ean Garvey
7619e76bbd Disable and xfail some models that fail validation/compilation. (#1251)
* Roll back T5 models for torch, as the inputs cause issues that aren't trivial to resolve
* xfail efficientnet-b0 on torch+cuda -- see CUDA requesting shared memory size larger than allowed size openxla/iree#12771
2023-03-27 12:42:53 -05:00
Daniel Garvey
9267eadbfa disable openjourney gen for nightly (#1249) 2023-03-27 11:55:34 -05:00
Phaneesh Barwaria
431132b8ee Fix img2img mode switch (#1247)
* add updated scheduler value in global config

* clear scheduler global variable with others
2023-03-27 07:01:22 -07:00
cstueckrath
fb35e13e7a fix Python version detection bug (#1246)
* fix Python version detection bug

* Update setup_venv.ps1
2023-03-27 07:00:40 -07:00
yzhang93
17a67897d1 Add SD v2.1 768x768 tuned model (#1244)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-24 10:39:15 -07:00
Gaurav Shukla
da449b73aa [SD] Disable lora training tab for now (#1241)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-24 09:16:24 -07:00
Kyle Herndon
0b0526699a Fix incorrect device argument initialization for LoRA training by extracting the device type and number and formatting it for pytorch (#1237)
Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
2023-03-24 01:10:50 -07:00
Boian Petkantchin
4fac46f7bb In models testing fix paths to be relative to the script dir not cwd (#1128)
authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-03-22 15:26:52 -05:00
Daniel Garvey
49925950f1 fix false positives (#1193) 2023-03-22 15:25:39 -05:00
Thomas
807947c0c8 Remove deprecated cli option iree-hal-cuda-disable-loop-nounroll-wa (#1235) 2023-03-22 12:05:15 -05:00
Abhishek Varma
593428bda4 [SD] Fix for transformers/__init__.py issue in PyInstaller (#1233)
-- This commit fixes the transformers/__init__.py issue in PyInstaller.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-22 08:43:53 -07:00
Abhishek Varma
cede9b4fec [SD] Fix custom_vae as a required parameter in inpaint (#1232) 2023-03-22 04:30:17 -07:00
Prashant Kumar
c2360303f0 Add the int8 quantized model. 2023-03-22 16:28:13 +05:30
jinchen62
420366c1b8 Move schedulers to global obj (#1225) 2023-03-21 22:40:43 -07:00
Ean Garvey
d31bae488c Set iree-input-type to tm_tensor for SD (#1228) 2023-03-21 19:07:31 -07:00
Kyle Herndon
c23fcf3748 Fix incorrect device argument initialization for LoRA training (#1231)
Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-21 19:07:18 -07:00
jinchen62
7dbbb1726a Fix SD obj not defined if fail to get models from pretrained (#1222) 2023-03-21 07:55:17 -07:00
Abhishek Varma
8b8cc7fd33 [SD] Update LoRA inference to handle various checkpoints (#1215) 2023-03-21 06:52:20 -07:00
Ean Garvey
e3c96a2b9d Move sentencepiece to importer requirements. (#1218) 2023-03-21 00:39:57 -05:00
Ean Garvey
5e3f50647d Set --vulkan_large_heap_block_size default to 2gb. (#1220) 2023-03-20 21:07:09 -07:00
gpetters94
7899e1803a Add fix for attention slicing fp16 (#1217) 2023-03-20 19:11:29 -07:00
mariecwhite
d105246b9c Fix t5 models 2023-03-21 10:39:59 +11:00
mariecwhite
90c958bca2 Add T5-base and T5-large Torch and TF Models (#1116) 2023-03-20 17:32:50 -05:00
mariecwhite
f99903e023 Add EfficientNet B0 and B7 Torch and TF models 2023-03-21 09:22:05 +11:00
mariecwhite
c6f44ef1b3 Add EfficientNet B0 and B7 Torch and TF models 2023-03-21 09:14:45 +11:00
mariecwhite
8dcd4d5aeb Make batch size configurable 2023-03-20 18:03:17 -04:00
Phoenix Meadowlark
d319f4684e Add peak memory reporting for IREE, TF and PyTorch (#1216) 2023-03-20 15:40:49 -05:00
Ean Garvey
54d7b6d83e Generate model artifacts in pytests if they don't exist in the cloud. (#1121)
* Add gen_shark_files fn to shark_downloader for OTF artifact generation

* add generate_sharktank as a tank/ python module.

* Fix some paths in tank generation.
2023-03-20 12:13:19 -05:00
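A minimal sketch of the download-or-generate fallback described above, assuming hypothetical `download_from_tank` and `generate_locally` helpers; the real `gen_shark_files` logic in shark_downloader differs in detail.

```python
# Rough sketch: use a cached artifact, else try the cloud tank, else generate
# it on the fly. Helper functions are placeholders, not SHARK's code.
import os

def download_from_tank(model_name: str, dest: str) -> bool:
    """Pretend cloud download; returns False to simulate a missing artifact."""
    return False

def generate_locally(model_name: str, dest: str) -> None:
    """Pretend local import/compile step that produces the artifact."""
    with open(dest, "w") as f:
        f.write(f"# generated artifacts for {model_name}\n")

def get_model_artifacts(model_name: str, cache_dir: str = "./gen_shark_tank") -> str:
    os.makedirs(cache_dir, exist_ok=True)
    dest = os.path.join(cache_dir, f"{model_name}.mlir")
    if os.path.exists(dest):
        return dest  # already cached locally
    if not download_from_tank(model_name, dest):
        generate_locally(model_name, dest)  # cloud miss: generate on the fly
    return dest

print(get_model_artifacts("resnet50"))
```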
m68k-fr
4a622532e5 [Web] Stop images (#1212) 2023-03-19 14:37:30 -07:00
cstueckrath
650b2ada58 add pytorch_lightning to requirements (#1211)
* add pytorch_lightning to requirements

this will additionally add lightning-utilities and torchmetrics

* Update shark_sd.spec

* Update shark_sd_cli.spec
2023-03-19 12:29:54 -07:00
m68k-fr
f87f8949f3 [Web] CSS fix for gradio V3.22.1 (#1210) 2023-03-19 06:13:59 -07:00
m68k-fr
7dc9bf8148 [Web] Move "stop Batch" button to "Advanced Options" toggle (#1209) 2023-03-18 20:54:42 -07:00
Kyle Herndon
ba48ff8d25 Implement LoRA training and UI for training and UI for inference in img2img, inpaint, outpaint (#1200)
txt2img inference UI is already committed.

Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-17 12:54:56 -07:00
Gaurav Shukla
638840925c [SD] Add support for larger size upscaling (#1204)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-17 10:20:48 -07:00
m68k-fr
b661656c03 [Web] Fix custom model path for upscaler (#1199) 2023-03-16 15:57:23 -07:00
Gaurav Shukla
0225434389 [SD] Add sendTo Upscaler
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-16 20:49:19 +05:30
Gaurav Shukla
7ffe20b1c2 [SD] Release memory used by upscaler when not in use
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-16 20:49:19 +05:30
Gaurav Shukla
d8f0c4655d [SD] Add Upscaler web
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-16 20:49:19 +05:30
Gaurav Shukla
7e8d3ec0df [SD] Add upscaler pipeline
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-16 20:49:19 +05:30
jinchen62
9c08eec565 Clear memory cache when switching model and mode (#1194) 2023-03-15 22:18:26 -07:00
m68k-fr
2d2c523ac5 [Web] Upgrade Gradio to v3.21.0 (#1188)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-03-15 10:14:49 -07:00
Abhishek Varma
f17b3128c0 [SD] Add LoRA inference to SD pipeline (#1189)
-- This commit adds LoRA inference to SD pipeline.
-- It also modifies txt2img to incorporate the new feature.
   img2img, inpaint, outpaint, etc., which use UNet, can be extended in a
   similar way.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-15 10:13:45 -07:00
Abhishek Varma
7c7e630099 [SD] Add fix for using latest diffusers + add scribble variant to Stencil (#1191)
* [SD] Add Scribble variant in Stencil

-- This commit adds scribble variant in Stencil.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [SD] Use latest diffusers

-- This commit points back to the latest diffusers and updates the
   processing script to tackle the Pix2Pix import issue.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-15 10:13:20 -07:00
m68k-fr
2dd1491ec1 [Web] Add clear queue button (#1192) 2023-03-15 10:12:59 -07:00
Daniel Garvey
236357fb61 add missing import for shark_sd.spec (#1190)
2023-03-15 09:23:29 -05:00
Phoenix Meadowlark
7bc38719de Add benchmark artifacts to .gitignore (#1186) 2023-03-14 15:19:06 -07:00
Daniel Garvey
bdbe992769 Add IREE_SAVE_TEMPS for import_debug command (#1184)
based on hf_model_id. Works on windows
2023-03-14 11:40:23 -07:00
Abhishek Varma
e6b925e012 [SD] Add Openpose to Stencil + image size issue fix (#1181)
-- This commit adds openpose model variant to stencil.
-- Fixes image size issue.
-- Also includes fix for the .exe bug introduced by https://github.com/nod-ai/SHARK/pull/1175

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-14 10:30:52 -07:00
cstueckrath
771120b76c workaround Gradio issue (#1183)
https://discord.com/channels/973663919757492264/975522729564446740/1085109774758191164
2023-03-14 01:27:24 -07:00
Boian Petkantchin
a8ce7680db Add flag to augment the device allocator (#1182)
Example:
$ python my_app.py --device_allocator caching debug
This will wrap the device allocator with first caching allocator then
debug allocator.

$ python my_app.py --device_allocator caching
Only wrap with caching allocator.

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
2023-03-13 15:49:26 -07:00
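To make the wrapping order concrete, here is a small sketch of how a multi-valued `--device_allocator` flag might be parsed and applied; the wrapper classes are invented for illustration and are not IREE's real allocator hooks.

```python
# Sketch of parsing a multi-valued --device_allocator flag and wrapping the
# base allocator in order; classes here are hypothetical stand-ins.
import argparse

class CachingAllocator:
    def __init__(self, inner): self.inner = inner

class DebugAllocator:
    def __init__(self, inner): self.inner = inner

WRAPPERS = {"caching": CachingAllocator, "debug": DebugAllocator}

parser = argparse.ArgumentParser()
parser.add_argument("--device_allocator", nargs="*", default=[],
                    help="allocator wrappers to apply, innermost first")
args = parser.parse_args(["--device_allocator", "caching", "debug"])

allocator = object()  # stand-in for the base device allocator
for name in args.device_allocator:
    allocator = WRAPPERS[name](allocator)  # caching applied first, then debug
print(type(allocator).__name__)
```

With `--device_allocator caching debug`, the caching wrapper is applied first and the debug wrapper ends up outermost, matching the description in the commit message.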
Phaneesh Barwaria
b6dcf2401b Stencil perf improvement (#1179)
* remove conditioning strength multiplier

* mod diffusers lib to v0.14.0
2023-03-13 14:37:38 -07:00
Daniel Garvey
62b5a9fd49 generate sharktank for apps dir (#966)
* merge conflict resolution

* add support to other scripts

---------

Co-authored-by: dan <dan@nod-labs.com>
2023-03-13 10:54:15 -07:00
m68k-fr
2f133e9d5c Fix png metadata (#1178) 2023-03-12 22:43:39 -07:00
powderluv
f898a1d332 Update README.md 2023-03-12 16:54:42 -07:00
m68k-fr
b94266d2b9 [Web] Randomize seed to -1 (#1176) 2023-03-12 12:42:31 -07:00
m68k-fr
1b08242aaa [Web] Improve dropdowns ux (#1175) 2023-03-12 12:41:51 -07:00
Abhishek Varma
691030fbab [SD] Improve Stencil feature to handle general image sizes
-- Currently the stencil feature works with 512x512 images only.
-- This commit relaxes this constraint and adds support for various
   image sizes.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-11 21:48:31 +05:30
m68k-fr
16ad7d57a3 [WebUi] txt2img_ui: Import png metadata (#1147) 2023-03-10 16:26:34 -08:00
Anush Elangovan
c561ebf43c Drop the torch-mlir pin
Seems to work now with top of master
2023-03-10 15:39:04 -08:00
Prashant Kumar
97fdff7f19 Add instructions how to run the LLaMA model. (#1168)
* Add instructions how to run the LLaMA model.

* Update README.md
2023-03-10 12:36:37 -08:00
Anush Elangovan
ce6d82eab2 Fix bloom lint 2023-03-10 11:53:08 -08:00
Abhishek Varma
b8f4b18951 [SD] Use dynamic stencil HF repo id
-- This commit removes the hardcoded HF ID for Stencil and instead
   utilizes a dynamic instantiation of HF model.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-03-10 23:31:45 +05:30
Eliasj42
b23d3aa584 added a more memory-efficient method to run large bloom models with sharded blooms (#1165)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-03-10 09:32:56 -08:00
Vivek Khandelwal
495670d9b6 Fix SD fine tuning script device arg usage 2023-03-10 18:37:53 +05:30
Boian Petkantchin
815e23a0b8 Update iree-compile flags --iree-llvm-xxx -> --iree-llvmcpu-xxx (#1164) 2023-03-09 11:31:50 -08:00
Boian Petkantchin
783538fe11 Move linting opts from github workflow to config files
This helps development, since you can be sure that running locally

black .
flake8 .

will do the same as in the github job.
2023-03-09 10:46:30 -08:00
Boian Petkantchin
996c645f6a In SD don't include device path in vmfb filename
Include only the driver name instead.
2023-03-09 10:45:32 -08:00
m68k-fr
1f7d249a62 Use utf-8 format for imgs_details.csv 2023-03-09 16:15:58 +05:30
jinchen62
7f6c9a2dc2 Add an inpainting option for only masked area (#1154) 2023-03-07 09:46:05 -08:00
Eliasj42
93891984f3 made sharded bloom example more user friendly (#1153)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-03-06 10:23:48 -08:00
Vivek Khandelwal
cc0ef54e0e Fix Stable diffusion fine tuning script 2023-03-06 17:52:16 +05:30
Daniel Garvey
812152485d temporarily xfail tiny convnext macos (#1142) 2023-03-03 13:30:56 -06:00
Vivek Khandelwal
0816fb403a Add Stable diffusion fine tuning script
This commit adds the sd fine tuning script which runs through the
torchdynamo path.
2023-03-03 21:59:00 +05:30
Gaurav Shukla
4f171772be [SD] Fix SD web flags
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-03 21:55:40 +05:30
mariecwhite
a52331d4aa Install IREE pre-releases (#1139) 2023-03-02 23:17:56 -06:00
yzhang93
ad821a1fc8 Use old torch-mlir package to avoid crash on rdna2 (#1137) 2023-03-02 18:16:58 -08:00
Ean Garvey
116b128802 Use nightly shark_tank for test-models (#1133)
* Use nightly shark_tank for test-models

* Update all_models.csv
2023-03-02 12:33:36 -06:00
Gaurav Shukla
b118f183d1 [SD] Fix few things in sendTo feature (#1132) 2023-03-02 09:11:55 -08:00
Gaurav Shukla
911dff16f1 [SD] Add sendTo feature in stable diffusion (#1131)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-03-02 08:42:38 -08:00
Abhishek Varma
de59a66ae4 [SD] Update diffusers to point to the fix for Stencil + add opencv-python (#1130) 2023-03-02 08:19:29 -08:00
Daniel Garvey
23f1468cc6 disable most models on windows pytest (#1125) 2023-03-02 01:37:50 -06:00
jinchen62
080350d311 Make loading custom inpainting models general (#1126) 2023-03-01 22:14:04 -08:00
Phaneesh Barwaria
7f3f92b9d5 remove extra return arg (#1123)
* remove extra return arg

txt2img expects only 3 mlirs

* add venv reqs for stencils
2023-03-01 11:45:24 -08:00
Abhishek Varma
be3cdec290 [SD] Add Stencil feature to SD pipeline (#1111)
* [WIP] Add ControlNet to SD pipeline

-- This commit adds ControlNet to SD pipeline.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [SD] Add ControlNet to img2img + fix bug for img2img scheduler

-- This commit adds ControlNet execution to img2img.
-- It restructures the addition of ControlNet variants.
-- It also fixes scheduler selecting bug for img2img pipeline.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* add shark models for stencilSD

* Add Stencil controlled SD in img2img pipeline (#1106)

* use shark stencil modules

* adjust diffusers change

* modify to use pipeline

* remove control from unet

* pump stencils through unet

* complete integration in img2img

* fix lint and comments

* [SD] Add ControlNet pipeline + integrate with WebUI + add compiled flow execution

-- This commit creates a dedicated SD pipeline for ControlNet.
-- Integrates it with img2img WebUI.
-- Integrates the compiled execution flow for ControlNet.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

* [SD] Stencil execution

* Remove integration setup

* [SD] Fix args.use_stencil overriding bug + vmfb caching issue

-- This commit fixes the args.use_stencil overriding issue, which caused
   the img2img pipeline to pick the wrong set of modules.
-- It also fixes a vmfb caching issue to speed up the loading time
   and pick the right set of modules based on a mask.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>

---------

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: PhaneeshB <b.phaneesh@gmail.com>
2023-03-01 10:44:40 -08:00
m68k-fr
f09574538c [WebUi] Remove unsupported full_width parameter, Reactivate gallery nav while multiple images are generated 2023-03-01 23:17:12 +05:30
Daniel Garvey
b1113ab551 disable benchmark on windows for pytest (#1100) 2023-02-28 18:10:29 -06:00
powderluv
ef756389e3 Revert "add cv2 and nod diffusers (#1112)" (#1114)
This reverts commit cb17d017df.
2023-02-28 14:31:40 -08:00
Phaneesh Barwaria
cb17d017df add cv2 and nod diffusers (#1112) 2023-03-01 01:33:43 +05:30
Gaurav Shukla
798f231792 [SD] Update metadata info and canvas size (#1109)
* [SD] Save missing metadata in case of img2img and outpaint

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

* [SD] Update the canvas size for inpaint/outpaint

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

* [SD] Update output gallery on each inference

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

---------

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-28 11:25:30 -08:00
m68k-fr
7136890da3 [Fix] Unsupported width and height argument error 2023-02-28 23:32:58 +05:30
mariecwhite
d567192fd3 Fix call to Torch Inductor 2023-02-28 00:35:57 -08:00
jinchen62
dcc4025c78 Fix loading custom inpainting models (#1103) 2023-02-27 17:06:09 -08:00
yzhang93
c6c8ec36a1 Enable tuned models for inpainting (#1102) 2023-02-27 16:46:57 -08:00
Quinn Dawkins
1344c0659a Add doc on profiling with Shark (#1101)
* Add doc on profiling with Shark

* Rename doc
2023-02-27 11:31:27 -08:00
powderluv
973f6d20f4 Try pre-pix2pix 2023-02-25 00:09:05 -08:00
powderluv
8b5c9c51e7 Revert "Update diffusers (#1094)" (#1096)
This reverts commit 0064cc2a6e.
2023-02-24 19:27:56 -08:00
jinchen62
bae208bcc4 Fix outpainting params (#1089) 2023-02-24 14:41:32 -08:00
Daniel Garvey
b6c14ad468 Make sd tests output performance metrics into csv (#1085)
* make some paths windows friendly (#1066)

* add csv output to builder script

and reduce number of models tested
2023-02-24 16:27:52 -06:00
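A minimal sketch of appending per-run metrics to a CSV file, in the spirit of this change; the file name and column names are assumptions, not the builder script's actual schema.

```python
# Append one row of SD performance metrics per run; header written once.
import csv
import os

def append_bench_row(path: str, model: str, device: str, avg_step_ms: float) -> None:
    new_file = not os.path.exists(path)
    with open(path, "a", newline="") as f:
        writer = csv.writer(f)
        if new_file:
            writer.writerow(["model", "device", "avg_step_time_ms"])  # header once
        writer.writerow([model, device, f"{avg_step_ms:.3f}"])

append_bench_row("bench_results.csv", "stable-diffusion-2-1-base", "vulkan", 47.192)
```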
powderluv
0064cc2a6e Update diffusers (#1094) 2023-02-24 14:09:19 -08:00
Gaurav Shukla
0a0567e944 [SD] Avoid unnecessary temp file creations (#1092)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-24 10:53:34 -08:00
gpetters94
694b1d43a8 Add attention slicing support (#1087) 2023-02-24 02:43:02 -08:00
Ean Garvey
e7eb116bd2 use tf-nightly for importer (#1077) 2023-02-23 23:14:48 -06:00
yzhang93
596499a08c Disable tuned configs on all inpainting models (#1086) 2023-02-23 13:15:22 -08:00
naveen raj
2a2e460df2 Add DEISMultistep scheduler #1076 (#1084)
* Add DEISMultistep scheduler #1076

* line length lint fix
2023-02-23 10:15:05 -08:00
jinchen62
a9039b35ed Add outpainting web UI (#1083) 2023-02-23 01:02:25 -08:00
jinchen62
a01154a507 Add SD outpainting (#1072)
python apps/stable_diffusion/scripts/outpaint.py --prompt="Face of a yellow cat, high resolution, sitting on a park bench" --img_path=test_imgs/overture-creations-5sI6fQgYIuo.png --import_mlir --hf_model_id="stabilityai/stable-diffusion-2-inpainting" --pixels=128 --mask_blur=8 --left --right --top --bottom --steps=20
2023-02-22 23:16:05 -08:00
powderluv
1d9204282d Update README.md 2023-02-22 23:12:41 -08:00
Eliasj42
5ff40a0d2d added an example to run sharded bloom (#1079)
added ability to compile sharded mlir files from huggingface models

Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-02-22 22:48:58 -08:00
jinchen62
fab6d2e4e0 Resize input image and mask for SD inpainting (#1082) 2023-02-22 22:46:59 -08:00
powderluv
abab59c25f Update nightly.yml 2023-02-22 18:44:43 -08:00
powderluv
c25840b585 Update nightly.yml 2023-02-22 18:34:37 -08:00
powderluv
1b3f9125bb Update nightly.yml 2023-02-22 18:23:44 -08:00
powderluv
b5d9f5ba49 Update nightly.yml 2023-02-22 18:20:31 -08:00
powderluv
1c22aa9c8f Resolve __init__.py issues (#1080)
Also drop torchvision. The test passed, but
we can't be sure it fixes the __init__.py issue yet.
2023-02-22 18:17:00 -08:00
Daniel Garvey
e1d7fb879c make some paths windows friendly (#1066) 2023-02-22 14:44:55 -06:00
powderluv
e912c42bf0 update the openxla links 2023-02-22 12:10:23 -08:00
powderluv
e6841acf36 Publish nightlies as pre-releases
So stable versions can be marked on the Releases page
2023-02-22 12:05:28 -08:00
Gaurav Shukla
bc4459b6f4 [SD] Add inpainting web UI (#1069)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-22 11:01:18 -08:00
cstueckrath
9b544491e0 Update setup_venv.ps1 (#1073)
* Update setup_venv.ps1

fix a bug that occurs when Python is installed but no py.exe is available

* Update setup_venv.ps1
2023-02-22 07:52:59 -08:00
m68k-fr
9c5415b598 [WebUi] css fix for Gradio v3.19.0 (#1059)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-02-21 23:50:54 -08:00
powderluv
040dbc317f unpin diffuser to latest (#1071)
Currently 0.13.x
2023-02-21 23:47:19 -08:00
powderluv
65775046d8 update IREE pip links 2023-02-21 19:31:23 -08:00
Daniel Garvey
b18bc36127 force creation of workdir (#1070) 2023-02-21 18:10:36 -08:00
cstueckrath
f01c526efd Update setup_venv.ps1 (#1064) 2023-02-21 14:13:04 -05:00
Gaurav Shukla
16168ab6b3 [SD] Update need_vae_encode correctly
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-21 20:26:06 +05:30
Gaurav Shukla
4233218629 [SD] Reset args.img_path to None in txt2img to avoid vae_encode
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-21 18:46:15 +05:30
RaINi_
b63fb36dc0 Use path.join for the winograd config directory (#1065) 2023-02-20 22:04:25 -06:00
Daniel Garvey
4e92304b89 remove annoying accelerate warning (#1056)
disables usage of low_cpu_mem_usage=True in from_pretrained() calls.
It can be re-enabled with the --low_cpu_mem_usage flag, which
defaults to False to avoid warning spam, since we don't include
accelerate in our requirements.txt
2023-02-20 14:46:26 -06:00
Ean Garvey
2ae047f1a8 Update importer/benchmark setup for python3.11 (#1043) 2023-02-20 11:29:00 -06:00
Ean Garvey
6d2a485264 Add --benchmark_dispatches option to pytest. (#800)
* Add --benchmark_dispatches option to pytest.

* Update README.md and fix filepath for dispatch benchmarks
2023-02-19 12:16:18 -06:00
Daniel Garvey
4f045db024 disable anythingv3 until issue is resolved (#1053) 2023-02-18 23:47:21 -05:00
yzhang93
5b33597b6d Enable v1.5 to use tuned configs (#1049) 2023-02-18 16:54:26 -05:00
m68k-fr
962470f610 [WebUi] Minor interface cleanup and Ui cosmetics 2023-02-17 22:00:47 +05:30
cstueckrath
ba8c116380 add KDPM2Discrete and a force flag for setup_venv (#1044)
* add KDPM2Discrete and a force flag for setup_venv

* add KDPM2Discrete and a force flag for setup_venv
also made sure that Python 3.11 is used for the venv as 3.10
doesn't work anymore

2023-02-17 07:19:56 -05:00
jinchen62
ad7330eae4 Add inpainting test (#1011) 2023-02-16 22:17:10 -06:00
yzhang93
cf126e4839 Use tuned configs on custom models with ckpt_loc (#1038) 2023-02-16 17:06:21 -08:00
powderluv
c96d25c3e2 Delete stable_diffusion_amd.md
All instructions are common now and on the main page.
2023-02-16 14:57:32 -08:00
powderluv
006aa0dae2 Update README.md 2023-02-16 14:54:00 -08:00
Daniel Garvey
5b204bee86 temporarily xfail microsoft resnet50 (#1037)
Co-authored-by: dan <dan@nod-labs.com>
2023-02-16 16:14:51 -06:00
Phaneesh Barwaria
d98b2afbe9 img2img denoise strength (#1040) 2023-02-16 13:40:20 -08:00
Daniel Garvey
681332ef32 fix tests after default flag changes (#1009)
* fix tests after default flag changes

also adds support for import-mlir

* Update setup_venv.ps1

---------
2023-02-16 12:57:50 -06:00
mariecwhite
c3a4fdcbfc Add bert-large-uncased TF model 2023-02-15 21:42:44 -08:00
mariecwhite
aac5de5b02 Add bert-large-uncased Torch model 2023-02-15 21:25:32 -08:00
powderluv
13a255afad Update nightly.yml 2023-02-15 17:11:38 -08:00
powderluv
3bffda52f9 Pin to latest diffusers (#1031) 2023-02-15 14:23:10 -08:00
Daniel Garvey
d4e62ce557 add an import-mlir fallback in case of failure (#1030)
may not cover all cases. will observe

Co-authored-by: dan <dan@nod-labs.com>
2023-02-15 16:15:23 -06:00
yzhang93
9738483b18 [SD] Map v2_1 to v2_1_base until fix (#1029) 2023-02-15 13:44:41 -08:00
Abhishek Varma
143492fe94 [SD] Add support for standalone Vae checkpoints (#1020)
-- This commit adds support for standalone Vae checkpoints.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-15 12:17:32 -08:00
Gaurav Shukla
ecc5c662c4 [SD] Save output images to different loc every day (#1027)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-15 12:16:36 -08:00
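A small sketch of the per-day output directory idea, assuming a `generated_imgs` base folder; SHARK's actual path handling may differ.

```python
# Build (and create) a dated output folder for generated images.
import os
from datetime import date

def todays_output_dir(base: str = "generated_imgs") -> str:
    out_dir = os.path.join(base, date.today().isoformat())  # e.g. generated_imgs/2023-02-15
    os.makedirs(out_dir, exist_ok=True)
    return out_dir

print(todays_output_dir())
```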
yzhang93
d973ba191d Add conditions to force use --import_mlir (#1028) 2023-02-15 10:37:09 -08:00
Gaurav Shukla
0198b183a2 [SD] Img2Img works for limited schedulers.
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-15 23:06:28 +05:30
Gaurav Shukla
0d44a3527b [SD][web] Add strength UI for img2img
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-15 22:47:41 +05:30
Gaurav Shukla
2147b6a397 [SD] Move some common code to utility
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-15 22:47:41 +05:30
Gaurav Shukla
6b5b4ba27b [SD] Add batch count in Image2Image
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-15 22:47:41 +05:30
Gaurav Shukla
67005bf57c [SD] Update iree-vulkan-target-triple after device switch
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-15 22:47:41 +05:30
PhaneeshB
0430c741c6 add strength param 2023-02-15 20:59:03 +05:30
powderluv
1ce02e365d Update README.md 2023-02-15 01:22:28 -08:00
m68k-fr
eae862adc2 Fix lint and path for gradio_tmp_imgs_folder 2023-02-15 14:27:29 +05:30
drumicube
dffa89524a Save gradio tmp images to shark_tmp folder and clean it at launch 2023-02-15 14:27:29 +05:30
yzhang93
2af1102441 [SD] Merge configs of different max lengths from the same variant into one config file (#1019) 2023-02-15 00:25:29 -08:00
powderluv
c4b472842a Update stable_diffusion_amd.md 2023-02-14 19:02:20 -08:00
powderluv
750a7d806f update docs to 3.11 2023-02-14 17:12:09 -08:00
powderluv
bc7333f1e5 Remove forcing LLPC setting (#1018)
also fix logo paths
2023-02-14 17:09:03 -08:00
powderluv
55ae50f991 Update inpaint.py 2023-02-14 14:12:05 -08:00
powderluv
a590c331ef Update img2img.py 2023-02-14 14:11:50 -08:00
powderluv
8c241b06cb Update txt2img.py 2023-02-14 14:11:36 -08:00
powderluv
9c072c8068 Update index.py 2023-02-14 14:11:20 -08:00
powderluv
ebd8b5122a Update stable_diffusion_amd.md 2023-02-14 14:09:34 -08:00
powderluv
055e484a40 Update README.md 2023-02-14 14:06:46 -08:00
powderluv
912c4a1d12 Update shark_sd.spec 2023-02-14 13:21:29 -08:00
Abhishek Varma
c203b65bf1 Fix __file__ AttributeError + Remove --enable_stack_trace (#1015) 2023-02-14 07:55:02 -08:00
powderluv
307f0334ee Drop im2col for VAE since it crashes the driver (#1010)
This is for untuned models.
2023-02-13 19:02:51 -05:00
yzhang93
5167df08b9 [SD] Fix cuda OTF annotation (#1008) 2023-02-13 12:32:50 -08:00
Gaurav Shukla
dd2e482214 [SD] Fix multiple call to device check (#1007)
- Also makes the dark theme default.
- Fix custom_vae parameter in img2img.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-13 11:57:52 -08:00
Eliasj42
87fd13d8eb added an example to run sharded bloom (#1003)
Co-authored-by: Elias Joseph <elias@nod-labs.com>
2023-02-13 10:37:47 -08:00
yzhang93
dd423bc6de [SD] Using --compile-to to dump mlir for OTF annotation (#1004)
* [SD] Using --compile-to to dump mlir for preprocessing

* Use python api for dumping process
2023-02-13 09:17:59 -08:00
powderluv
899cb9cc1f Temporarily disable signing of exe 2023-02-12 20:37:42 -08:00
drumicube
0464c7e558 Add support for command arguments to the WebUi (#1000)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
2023-02-11 19:20:21 -08:00
powderluv
f64e1fb926 Fix dark theme again for exe builds (#1001) 2023-02-11 19:08:17 -08:00
powderluv
ef7d31293d Update tests to 3.11 2023-02-11 15:38:27 -08:00
powderluv
6d54eb68dc update to support 3.11 2023-02-11 15:23:18 -08:00
powderluv
30eb10c990 Update to 3.11 2023-02-11 03:47:14 -08:00
Abhishek Varma
591bbcd058 [SD] Fix vmfb locating bug
-- This commit fixes a bug in vmfb caching due to vae_encoder and also
   involves a minor NFC change in the code.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-10 23:33:47 +05:30
Abhishek Varma
99aa77d036 [SD] Add a common way to name vmfbs including custom_vae
-- This commit adds a common way to name vmfbs, including `custom_vae`
   support.
-- This was required to provide a single place to change vmfb names
   without breaking any feature support and to handle the caching
   of vmfbs gracefully.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-10 23:33:47 +05:30
Abhishek Varma
9c13f1e635 Add custom vae support using --custom_vae flag
-- This commit adds custom VAE support to SD, wherein the user can
   point to a model checkpoint file whose VAE is to be plugged
   into the main model.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-10 23:33:47 +05:30
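Purely as an illustration of what a `--custom_vae` flag enables, the sketch below splices VAE weights from a separate checkpoint into a main model's state dict. The `first_stage_model.` key prefix and `.ckpt`-style loading are assumptions about conventional SD checkpoints; safetensors files and SHARK's actual plumbing would differ.

```python
# Illustrative only: overwrite the main model's VAE weights with those from a
# separate checkpoint, roughly what a --custom_vae flag would do.
import torch

def apply_custom_vae(model_state: dict, custom_vae_ckpt: str) -> dict:
    vae_state = torch.load(custom_vae_ckpt, map_location="cpu")
    if "state_dict" in vae_state:
        vae_state = vae_state["state_dict"]  # many .ckpt files nest the weights
    # Replace only the VAE ("first_stage_model.") keys of the main model.
    for key, value in vae_state.items():
        if key.startswith("first_stage_model."):
            model_state[key] = value
    return model_state
```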
Gaurav Shukla
24af983cfb [SD] Fix input image type
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-10 23:27:52 +05:30
Gaurav Shukla
67842a7525 [SD] Fix parameters in img2img
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-10 22:03:33 +05:30
PhaneeshB
3159a6f3e1 add support for img2img 2023-02-10 21:29:02 +05:30
Gaurav Shukla
b2f3c96835 [SD][web] Add Img2Img UI
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-10 21:27:31 +05:30
jinchen62
6582475955 Add SD inpainting
python apps/stable_diffusion/scripts/inpaint.py --prompt="prompt" --img_path=path/to/img --mask_path=path/to/mask --import_mlir --max_length=77 --hf_model_id="stabilityai/stable-diffusion-2-inpainting"
2023-02-10 15:33:20 +05:30
Anush Elangovan
41ee65b377 Revert "Enable --device_allocator=caching"
This reverts commit 83fe477066.
2023-02-09 23:00:06 -08:00
Anush Elangovan
83fe477066 Enable --device_allocator=caching 2023-02-09 22:58:46 -08:00
246 changed files with 11256 additions and 28078 deletions

5
.flake8 Normal file
View File

@@ -0,0 +1,5 @@
[flake8]
count = 1
show-source = 1
select = E9,F63,F7,F82
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py

View File

@@ -14,12 +14,12 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
python-version: ["3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
@@ -44,31 +44,20 @@ jobs:
body: |
Automatic snapshot release of nod.ai SHARK.
draft: true
prerelease: false
prerelease: true
- name: Build Package
- name: Build Package (api only)
shell: powershell
run: |
./setup_venv.ps1
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
pyinstaller .\apps\stable_diffusion\shark_sd_cli.spec
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
signtool sign /f C:\shark_2023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
python process_skipfiles.py
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip install -e .
pip freeze -l
pyinstaller .\apps\shark_studio\shark_studio_apionly.spec
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
# GHA windows VM OOMs so disable for now
#- name: Build and validate the SHARK Runtime package
# shell: powershell
# run: |
# $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
# pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
- uses: actions/upload-artifact@v2
with:
path: dist/*
- name: Upload Release Assets
id: upload-release-assets
uses: dwenegar/upload-release-assets@v1
@@ -76,7 +65,8 @@ jobs:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
assets_path: ./dist/*
assets_path: ./dist/nodai*
#asset_content_type: application/vnd.microsoft.portable-executable
- name: Publish Release
id: publish_release
@@ -85,80 +75,3 @@ jobs:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
linux-build:
runs-on: a100
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
backend: [IREE, SHARK]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Setup pip cache
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
run: |
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude shark.venv,lit.cfg.py
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
- name: Build and validate the IREE package
if: ${{ matrix.backend == 'IREE' }}
continue-on-error: true
run: |
cd $GITHUB_WORKSPACE
USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
source iree.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://iree-org.github.io/iree/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
/bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt
if !(grep -Fxq " failed" pytest_results.txt)
then
export SHA=$(git log -1 --format='%h')
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
fi
rm -rf ./wheelhouse/nodai*
- name: Build and validate the SHARK Runtime package
if: ${{ matrix.backend == 'SHARK' }}
run: |
cd $GITHUB_WORKSPACE
./setup_venv.sh
source shark.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
pytest --ci --ci_sha=${SHORT_SHA} -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt

View File

@@ -1,162 +0,0 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Validate Models on Shark Runtime
on:
push:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
pull_request:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
workflow_dispatch:
# Ensure that only a single job or workflow using the same
# concurrency group will run at a time. This would cancel
# any in-progress jobs in the same github workflow and github
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build-validate:
strategy:
fail-fast: true
matrix:
os: [7950x, icelake, a100, MacStudio, ubuntu-latest]
suite: [cpu,cuda,vulkan]
python-version: ["3.10"]
include:
- os: ubuntu-latest
suite: lint
exclude:
- os: ubuntu-latest
suite: vulkan
- os: ubuntu-latest
suite: cuda
- os: ubuntu-latest
suite: cpu
- os: MacStudio
suite: cuda
- os: MacStudio
suite: cpu
- os: icelake
suite: vulkan
- os: icelake
suite: cuda
- os: a100
suite: cpu
- os: 7950x
suite: cpu
- os: 7950x
suite: cuda
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
if: matrix.os != '7950x'
- name: Set Environment Variables
if: matrix.os != '7950x'
run: |
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
- name: Set up Python Version File ${{ matrix.python-version }}
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
run: |
# See https://github.com/actions/setup-python/issues/433
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
- name: Set up Python ${{ matrix.python-version }}
if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake'
uses: actions/setup-python@v4
with:
python-version: '${{ matrix.python-version }}'
#cache: 'pip'
#cache-dependency-path: |
# **/requirements-importer.txt
# **/requirements.txt
- uses: actions/checkout@v2
if: matrix.os == '7950x'
- name: Install dependencies
if: matrix.suite == 'lint'
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml black
- name: Lint with flake8
if: matrix.suite == 'lint'
run: |
# black format check
black --version
black --line-length 79 --check .
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude lit.cfg.py
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude lit.cfg.py
- name: Validate Models on CPU
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
- name: Validate Models on NVIDIA GPU
if: matrix.suite == 'cuda'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
# python build_tools/stable_diffusion_testing.py --device=cuda
- name: Validate Vulkan Models (MacOS)
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" -k vulkan --update_tank
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
pytest --benchmark -k vulkan -s
type bench_results.csv
- name: Validate Stable Diffusion Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
./shark.venv/Scripts/activate
python build_tools/stable_diffusion_testing.py --device=vulkan

85
.github/workflows/test-studio.yml vendored Normal file
View File

@@ -0,0 +1,85 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Validate Shark Studio
on:
push:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
pull_request:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
workflow_dispatch:
# Ensure that only a single job or workflow using the same
# concurrency group will run at a time. This would cancel
# any in-progress jobs in the same github workflow and github
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build-validate:
strategy:
fail-fast: true
matrix:
os: [nodai-ubuntu-builder-large]
suite: [cpu] #,cuda,vulkan]
python-version: ["3.11"]
include:
- os: nodai-ubuntu-builder-large
suite: lint
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Set Environment Variables
run: |
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
- name: Set up Python Version File ${{ matrix.python-version }}
run: |
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: '${{ matrix.python-version }}'
- name: Install dependencies
if: matrix.suite == 'lint'
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml black
- name: Lint with flake8
if: matrix.suite == 'lint'
run: |
# black format check
black --version
black --check apps/shark_studio
# stop the build if there are Python syntax errors or undefined names
flake8 . --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --exclude lit.cfg.py
- name: Validate Models on CPU
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
python${{ matrix.python-version }} -m venv shark.venv
source shark.venv/bin/activate
pip install -r requirements.txt --no-cache-dir
pip install -e .
# Disabled due to hang when exporting test llama2
# python apps/shark_studio/tests/api_test.py

33
.gitignore vendored
View File

@@ -2,6 +2,8 @@
__pycache__/
*.py[cod]
*$py.class
*.mlir
*.vmfb
# C extensions
*.so
@@ -157,17 +159,20 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
# vscode related
.vscode
# Shark related artefacts
# Shark related artifacts
*venv/
shark_tmp/
*.vmfb
.use-iree
tank/dict_configs.py
*.csv
reproducers/
apps/shark_studio/web/configs
# ORT related artefacts
cache_models/
@@ -178,7 +183,29 @@ generated_imgs/
# Custom model related artefacts
variants.json
models/
/models/
*.safetensors
# models folder
apps/stable_diffusion/web/models/
# model artifacts (SHARK)
*.tempfile
*.mlir
*.vmfb
# Stencil annotators.
stencil_annotator/
# For DocuChat
apps/language_models/langchain/user_path/
db_dir_UserData
# Embeded browser cache and other
apps/stable_diffusion/web/EBWebView/
# Llama2 tokenizer configs
llama2_tokenizer_configs/
# Webview2 runtime artefacts
EBWebView/

2
.gitmodules vendored
View File

@@ -1,4 +1,4 @@
[submodule "inference/thirdparty/shark-runtime"]
path = inference/thirdparty/shark-runtime
url =https://github.com/nod-ai/SHARK-Runtime.git
url =https://github.com/nod-ai/SRT.git
branch = shark-06032022

View File

@@ -1,3 +0,0 @@
[style]
based_on_style = google
column_limit = 80

123
README.md
View File

@@ -2,18 +2,20 @@
High Performance Machine Learning Distribution
*We are currently rebuilding SHARK to take advantage of [Turbine](https://github.com/nod-ai/SHARK-Turbine). Until that is complete make sure you use an .exe release or a checkout of the `SHARK-1.0` branch, for a working SHARK*
[![Nightly Release](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml/badge.svg)](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml)
[![Validate torch-models on Shark Runtime](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml/badge.svg)](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml)
<details>
<summary>Prerequisites - Drivers </summary>
#### Install your Windows hardware drivers
* [AMD RDNA Users] Download this specific driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mril-iree). Latest drivers may not work.
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
#### Linux Drivers
* MESA / RADV drivers won't work with FP16. Please use the latest AMDGPU-PRO drivers (non-pro OSS drivers also won't work) or the latest NVidia Linux Drivers.
@@ -22,27 +24,46 @@ Other users please ensure you have your latest vendor drivers and Vulkan SDK fro
</details>
### Quick Start for SHARK Stable Diffusion for Windows 10/11 Users
Install Driver from [Prerequisites](https://github.com/nod-ai/SHARK#install-your-hardware-drivers) above
Install the Driver from [Prerequisites](https://github.com/nod-ai/SHARK#install-your-hardware-drivers) above
Download the latest .exe https://github.com/nod-ai/SHARK/releases.
Download the [stable release](https://github.com/nod-ai/shark/releases/latest) or the most recent [SHARK 1.0 pre-release](https://github.com/nod-ai/shark/releases).
Double click the .exe and you should have the [UI]( http://localhost:8080/?__theme=dark) in the browser.
Double click the .exe, or [run from the command line](#running) (recommended), and you should have the [UI](http://localhost:8080/) in the browser.
If you have custom models (ckpt, safetensors) put them in a `models/` directory where the .exe is.
If you have custom models put them in a `models/` directory where the .exe is.
Enjoy.
Enjoy.
Some known AMD Driver quirks and fixes with cursors are documented [here](https://github.com/nod-ai/SHARK/blob/main/apps/stable_diffusion/stable_diffusion_amd.md ).
<details>
<summary>More installation notes</summary>
* We recommend that you download the EXE into a new folder whenever you download a new EXE version. If you download it into the same folder as a previous install, you must delete the old `*.vmfb` files with `rm *.vmfb`. You can also use the `--clear_all` flag once to clean all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you clear all the local artifacts with `--clear_all`
## Running
* Open a Command Prompt or Powershell terminal, change folder (`cd`) to the .exe folder. Then run the EXE from the command prompt. That way, if an error occurs, you'll be able to cut-and-paste it to ask for help. (if it always works for you without error, you may simply double-click the EXE)
* The first run may take a few minutes while the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
* You will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/.
* If you prefer to always run in the browser, use the `--ui=web` command argument when running the EXE.
## Stopping
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment or close the terminal.
</details>
<details>
<summary>Advanced Installation (Only for developers)</summary>
## Advanced Installation (Windows, Linux and macOS) for developers
### Windows 10/11 Users
* Install Git for Windows from [here](https://git-scm.com/download/win) if you don't already have it.
## Check out the code
```shell
@@ -50,13 +71,21 @@ git clone https://github.com/nod-ai/SHARK.git
cd SHARK
```
## Switch to the Correct Branch (IMPORTANT!)
Currently SHARK is being rebuilt for [Turbine](https://github.com/nod-ai/SHARK-Turbine) on the `main` branch. For now you are strongly discouraged from using `main` unless you are working on the rebuild effort, and should not expect the code there to produce a working application for image generation, so you'll need to switch over to the `SHARK-1.0` branch and use the stable code.
```shell
git checkout SHARK-1.0
```
The following setup instructions assume you are on this branch.
## Setup your Python VirtualEnvironment and Dependencies
### Windows 10/11 Users
* Install the latest Python 3.10.x version from [here](https://www.python.org/downloads/windows/)
* Install Git for Windows from [here](https://git-scm.com/download/win)
* Install the latest Python 3.11.x version from [here](https://www.python.org/downloads/windows/)
#### Allow the install script to run in Powershell
```powershell
@@ -72,21 +101,20 @@ set-executionpolicy remotesigned
```shell
./setup_venv.sh
source shark.venv/bin/activate
source shark1.venv/bin/activate
```
### Run Stable Diffusion on your device - WebUI
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> cd .\apps\stable_diffusion\web\
(shark.venv) PS C:\g\shark\apps\stable_diffusion\web> python .\index.py
(shark1.venv) PS C:\g\shark> cd .\apps\stable_diffusion\web\
(shark1.venv) PS C:\g\shark\apps\stable_diffusion\web> python .\index.py
```
#### Linux / macOS Users
```shell
(shark.venv) > cd apps/stable_diffusion/web
(shark.venv) > python index.py
(shark1.venv) > cd apps/stable_diffusion/web
(shark1.venv) > python index.py
```
#### Access Stable Diffusion on http://localhost:8080/?__theme=dark
@@ -100,21 +128,20 @@ source shark.venv/bin/activate
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\txt2img.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
(shark1.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
#### Linux / macOS Users
```shell
python3.10 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
python3.11 apps/stable_diffusion/scripts/main.py --app=txt2img --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
</details>
The output on a 7900XTX would like:
The output on an AMD 7900XTX would look something like:
```shell
Stats for run 0:
```shell
Average step time: 47.19188690185547ms/it
Clip Inference time (ms) = 109.531
VAE Inference time (ms): 78.590
@@ -129,7 +156,7 @@ Here are some samples generated:
![a photo of a crab playing a trumpet](https://user-images.githubusercontent.com/74956/204933258-252e7240-8548-45f7-8253-97647d38313d.jpg)
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
<details>
@@ -140,7 +167,7 @@ Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any
This step sets up a new VirtualEnv for Python
```shell
python --version #Check you have 3.10 on Linux, macOS or Windows Powershell
python --version #Check you have 3.11 on Linux, macOS or Windows Powershell
python -m venv shark_venv
source shark_venv/bin/activate # Use shark_venv/Scripts/activate on Windows
@@ -154,10 +181,10 @@ python -m pip install --upgrade pip
### Install SHARK
This step pip installs SHARK and related packages on Linux Python 3.7, 3.8, 3.9, 3.10 and macOS Python 3.10
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
```shell
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```
### Run shark tank model tests.
@@ -189,10 +216,10 @@ python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
<details>
<summary>Development, Testing and Benchmarks</summary>
If you want to use Python3.10 and with TF Import tools you can use the environment variables like:
If you want to use Python3.11 with the TF Import tools you can use the environment variables like:
Set `USE_IREE=1` to use upstream IREE
```
# PYTHON=python3.10 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
# PYTHON=python3.11 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
```
### Run any of the hundreds of SHARK tank models via the test framework
@@ -201,15 +228,15 @@ python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use g
# Or a pytest
pytest tank/test_models.py -k "MiniLM"
```
### How to use your locally built IREE / Torch-MLIR with SHARK
If you are a *Torch-MLIR developer or an IREE developer* and want to test local changes, you can uninstall
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime`, then build locally
with Python bindings and set your PYTHONPATH as described [here](https://github.com/iree-org/iree/tree/main/docs/api_docs/python#install-iree-binaries)
for IREE and [here](https://github.com/llvm/torch-mlir/blob/main/development.md#setup-python-environment-to-export-the-built-python-packages)
for Torch-MLIR.
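A minimal sketch of that workflow for IREE (the build-directory layout depends on your local IREE checkout, so the paths below are placeholders; follow the linked IREE instructions for the authoritative steps):

```shell
pip uninstall iree-compiler iree-runtime
# After building IREE with Python bindings enabled, point PYTHONPATH at the built
# packages (placeholder paths; adjust to your build directory):
export PYTHONPATH=$IREE_BUILD_DIR/compiler/bindings/python:$IREE_BUILD_DIR/runtime/bindings/python:$PYTHONPATH
```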
### How to use your locally built Torch-MLIR with SHARK
```shell
1.) Run `./setup_venv.sh` in SHARK and activate the `shark.venv` virtual env.
2.) Run `pip uninstall torch-mlir`.
@@ -227,15 +254,20 @@ Now the SHARK will use your locally build Torch-MLIR repo.
## Benchmarking Dispatches
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line argument.
If you only want to compile specific dispatches, you can specify them with a space-separated string instead of `"All"`, e.g. `--dispatch_benchmarks="0 1 2 10"`.
For example, to generate and run dispatch benchmarks for MiniLM on CUDA:
```
pytest -k "MiniLM and torch and static and cuda" --dispatch_benchmarks=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
```
The given command will populate `<dispatch_benchmarks_dir>/<model_name>/` with an `ordered_dispatches.txt` that lists and orders the dispatches and their latencies, as well as folders for each dispatch that contain .mlir, .vmfb, and results of the benchmark for that dispatch.
If you want to instead incorporate this into a Python script, you can pass the `dispatch_benchmarks` and `dispatch_benchmarks_dir` arguments when initializing `SharkInference`, and the benchmarks will be generated when the module is compiled, e.g.:
```
shark_module = SharkInference(
mlir_model,
func_name,
device=args.device,
mlir_dialect="tm_tensor",
dispatch_benchmarks="all",
@@ -246,14 +278,14 @@ shark_module = SharkInference(
Output will include the following (an illustrative layout is sketched after this list):
- An ordered list, `ordered-dispatches.txt`, of all the dispatches with their runtimes
- Inside the specified directory, there will be a directory for each dispatch (there will be mlir files for all dispatches, but only compiled binaries and benchmark data for the specified dispatches)
- An .mlir file containing the dispatch benchmark
- A compiled .vmfb file containing the dispatch benchmark
- An .mlir file containing just the hal executable
- A compiled .vmfb file of the hal executable
- A .txt file containing benchmark output
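As a purely illustrative sketch (the directory and file names below are hypothetical; the actual names come from the model and dispatches you select), the populated output directory looks roughly like:

```
my_dispatch_benchmarks/
  <model_name>/
    ordered-dispatches.txt
    dispatch_0/
      <dispatch>.mlir          # dispatch benchmark IR
      <dispatch>.vmfb          # compiled dispatch benchmark
      <hal_executable>.mlir    # just the hal executable
      <hal_executable>.vmfb    # compiled hal executable
      benchmark_output.txt     # benchmark results
    dispatch_1/
      ...
```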
See tank/README.md for further instructions on how to run model tests and benchmarks from the SHARK tank.
</details>
@@ -278,7 +310,7 @@ torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)
# SharkInference accepts mlir in linalg, mhlo, and tosa dialect.
from shark.shark_inference import SharkInference
shark_module = SharkInference(torch_mlir, device="cpu", mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward((input))
@@ -301,15 +333,20 @@ mhlo_ir = r"""builtin.module {
arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
shark_module.compile()
result = shark_module.forward((arg0, arg1))
```
</details>
## Examples Using the REST API
* [Setting up SHARK for use with Blender](./docs/shark_sd_blender.md)
* [Setting up SHARK for use with Koboldcpp](./docs/shark_sd_koboldcpp.md)
## Supported and Validated Models
SHARK is maintained to support the latest innovations in ML Models:
| TF HuggingFace Models | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------|----------|-------------|
@@ -335,7 +372,7 @@ For a complete list of the models supported in SHARK, please refer to [tank/READ
* [Upstream IREE issues](https://github.com/google/iree/issues): Feature requests,
bugs, and other work tracking
* [Upstream IREE Discord server](https://discord.gg/wEWh6Z9nMU): Daily development
discussions with the core team and collaborators
* [iree-discuss email list](https://groups.google.com/forum/#!forum/iree-discuss):
Announcements, general and low-priority discussion
@@ -350,7 +387,7 @@ For a complete list of the models supported in SHARK, please refer to [tank/READ
* Weekly meetings on Mondays 9AM PST. See [here](https://discourse.llvm.org/t/community-meeting-developer-hour-refactoring-recurring-meetings/62575) for more information.
* [MLIR topic within LLVM Discourse](https://llvm.discourse.group/c/llvm-project/mlir/31): SHARK and IREE are enabled by and heavily rely on [MLIR](https://mlir.llvm.org).
</details>
## License
nod.ai SHARK is licensed under the terms of the Apache 2.0 License with LLVM Exceptions.


@@ -0,0 +1,107 @@
# from turbine_models.custom_models.controlnet import control_adapter, preprocessors
import os
import PIL
import numpy as np
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
)
from datetime import datetime
from PIL import Image
from gradio.components.image_editor import (
EditorValue,
)
class control_adapter:
def __init__(
self,
model: str,
):
self.model = None
def export_control_adapter_model(model_keyword):
return None
def export_xl_control_adapter_model(model_keyword):
return None
class preprocessors:
def __init__(
self,
model: str,
):
self.model = None
def export_controlnet_model(model_keyword):
return None
control_adapter_map = {
"sd15": {
"canny": {"initializer": control_adapter.export_control_adapter_model},
"openpose": {"initializer": control_adapter.export_control_adapter_model},
"scribble": {"initializer": control_adapter.export_control_adapter_model},
"zoedepth": {"initializer": control_adapter.export_control_adapter_model},
},
"sdxl": {
"canny": {"initializer": control_adapter.export_xl_control_adapter_model},
},
}
preprocessor_model_map = {
"canny": {"initializer": preprocessors.export_controlnet_model},
"openpose": {"initializer": preprocessors.export_controlnet_model},
"scribble": {"initializer": preprocessors.export_controlnet_model},
"zoedepth": {"initializer": preprocessors.export_controlnet_model},
}
class PreprocessorModel:
def __init__(
self,
hf_model_id,
device="cpu",
):
self.model = hf_model_id
self.device = device
def compile(self):
print("compile not implemented for preprocessor.")
return
def run(self, inputs):
print("run not implemented for preprocessor.")
return inputs
def cnet_preview(model, input_image):
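    # Run the selected preprocessor on the input image (or pass it through for "scribble"),
    # save the resulting control hint under generated_imgs/control_hints, and return it with its path.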
curr_datetime = datetime.now().strftime("%Y-%m-%d.%H-%M-%S")
control_imgs_path = os.path.join(get_generated_imgs_path(), "control_hints")
if not os.path.exists(control_imgs_path):
os.mkdir(control_imgs_path)
img_dest = os.path.join(control_imgs_path, model + curr_datetime + ".png")
match model:
case "canny":
canny = PreprocessorModel("canny")
result = canny(
np.array(input_image),
100,
200,
)
Image.fromarray(result).save(fp=img_dest)
return result, img_dest
case "openpose":
openpose = PreprocessorModel("openpose")
result = openpose(np.array(input_image))
Image.fromarray(result[0]).save(fp=img_dest)
return result, img_dest
case "zoedepth":
zoedepth = PreprocessorModel("ZoeDepth")
result = zoedepth(np.array(input_image))
Image.fromarray(result).save(fp=img_dest)
return result, img_dest
case "scribble":
input_image.save(fp=img_dest)
return input_image, img_dest
case _:
return None, None


@@ -0,0 +1,130 @@
import importlib
import os
import signal
import sys
import warnings
import json
from threading import Thread
from apps.shark_studio.modules.timer import startup_timer
from apps.shark_studio.web.utils.tmp_configs import (
config_tmp,
clear_tmp_mlir,
clear_tmp_imgs,
shark_tmp,
)
def imports():
import torch # noqa: F401
startup_timer.record("import torch")
warnings.filterwarnings(
action="ignore", category=DeprecationWarning, module="torch"
)
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")
warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffusers")
warnings.filterwarnings(
action="ignore", category=FutureWarning, module="huggingface-hub"
)
warnings.filterwarnings(
action="ignore", category=UserWarning, module="huggingface-hub"
)
# import gradio # noqa: F401
# startup_timer.record("import gradio")
import apps.shark_studio.web.utils.globals as global_obj
global_obj._init()
startup_timer.record("initialize globals")
from apps.shark_studio.modules import (
img_processing,
) # noqa: F401
startup_timer.record("other imports")
def initialize():
configure_sigint_handler()
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
# config_tmp()
# clear_tmp_imgs()
from apps.shark_studio.web.utils.file_utils import (
create_model_folders,
)
# Create custom models folders if they don't exist
create_model_folders()
# initialize_rest(reload_script_modules=False)
def initialize_rest(*, reload_script_modules=False):
"""
Called both from initialize() and when reloading the webui.
"""
# Keep this for adding reload options to the webUI.
def dumpstacks():
import threading
import traceback
id2name = {th.ident: th.name for th in threading.enumerate()}
code = []
for threadId, stack in sys._current_frames().items():
code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
for filename, lineno, name, line in traceback.extract_stack(stack):
code.append(f"""File: "{filename}", line {lineno}, in {name}""")
if line:
code.append(" " + line.strip())
with open(os.path.join(shark_tmp, "stack_dump.log"), "w") as f:
f.write("\n".join(code))
def setup_middleware(app):
from starlette.middleware.gzip import GZipMiddleware
app.middleware_stack = (
None # reset current middleware to allow modifying user provided list
)
app.add_middleware(GZipMiddleware, minimum_size=1000)
configure_cors_middleware(app)
app.build_middleware_stack() # rebuild middleware stack on-the-fly
def configure_cors_middleware(app):
from starlette.middleware.cors import CORSMiddleware
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
cors_options = {
"allow_methods": ["*"],
"allow_headers": ["*"],
"allow_credentials": True,
}
if cmd_opts.api_accept_origin:
cors_options["allow_origins"] = cmd_opts.api_accept_origin.split(",")
app.add_middleware(CORSMiddleware, **cors_options)
def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f"Interrupted with signal {sig} in {frame}")
dumpstacks()
os._exit(0)
signal.signal(signal.SIGINT, sigint_handler)


@@ -0,0 +1,475 @@
from turbine_models.custom_models import stateless_llama
from turbine_models.model_runner import vmfbRunner
from turbine_models.gen_external_params.gen_external_params import gen_external_params
import time
from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
from apps.shark_studio.web.utils.file_utils import (
get_resource_path,
get_checkpoints_path,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.api.utils import parse_device
from urllib.request import urlopen
import iree.runtime as ireert
from itertools import chain
import gc
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
llm_model_map = {
"meta-llama/Llama-2-7b-chat-hf": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
"Trelis/Llama-2-7b-chat-hf-function-calling-v2": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
"TinyPixel/small-llama2": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "TinyPixel/small-llama2",
"compile_flags": ["--iree-opt-const-expr-hoisting=True"],
"stop_token": 2,
"max_tokens": 1024,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
}
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<s>", "</s>"
DEFAULT_CHAT_SYS_PROMPT = """<s>[INST] <<SYS>>
Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <</SYS>>\n\n
"""
def append_user_prompt(history, input_prompt):
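    # Wrap the user's message in Llama-2 [INST] ... [/INST] tags and append it to the running history string.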
user_prompt = f"{B_INST} {input_prompt} {E_INST}"
history += user_prompt
return history
class LanguageModel:
def __init__(
self,
model_name,
hf_auth_token=None,
device=None,
quantization="int4",
precision="",
external_weights=None,
use_system_prompt=True,
streaming_llm=False,
):
_, _, self.triple = parse_device(device)
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.device = device.split("=>")[-1].strip()
self.backend = self.device.split("://")[0]
self.driver = self.backend
if "cpu" in device:
self.device = "cpu"
self.backend = "llvm-cpu"
self.driver = "local-task"
print(f"Selected {self.backend} as IREE target backend.")
self.precision = "f32" if "cpu" in device else "f16"
self.quantization = quantization
self.safe_name = self.hf_model_name.replace("/", "_").replace("-", "_")
self.external_weight_file = None
# TODO: find a programmatic solution for model arch spec instead of hardcoding llama2
self.file_spec = "_".join(
[
self.safe_name,
self.precision,
]
)
if self.quantization != "None":
self.file_spec += "_" + self.quantization
if external_weights in ["safetensors", "gguf"]:
self.external_weight_file = get_resource_path(
os.path.join("..", self.file_spec + "." + external_weights)
)
else:
self.external_weights = None
self.external_weight_file = None
if streaming_llm:
# Add streaming suffix to file spec after setting external weights filename.
self.file_spec += "_streaming"
self.streaming_llm = streaming_llm
self.tempfile_name = get_resource_path(
os.path.join("..", f"{self.file_spec}.tempfile")
)
# TODO: Tag vmfb with target triple of device instead of HAL backend
self.vmfb_name = str(
get_resource_path(
os.path.join("..", f"{self.file_spec}_{self.backend}.vmfb.tempfile")
)
)
self.max_tokens = llm_model_map[model_name]["max_tokens"]
self.iree_module_dict = None
self.use_system_prompt = use_system_prompt
self.global_iter = 0
self.prev_token_len = 0
self.first_input = True
self.hf_auth_token = hf_auth_token
if self.external_weight_file is not None:
if not os.path.exists(self.external_weight_file):
print(
f"External weight file {self.external_weight_file} does not exist. Generating..."
)
gen_external_params(
hf_model_name=self.hf_model_name,
quantization=self.quantization,
weight_path=self.external_weight_file,
hf_auth_token=hf_auth_token,
precision=self.precision,
)
else:
print(
f"External weight file {self.external_weight_file} found for {self.vmfb_name}"
)
self.external_weight_file = str(self.external_weight_file)
if os.path.exists(self.vmfb_name) and (
external_weights is None or os.path.exists(str(self.external_weight_file))
):
self.runner = vmfbRunner(
device=self.driver,
vmfb_path=self.vmfb_name,
external_weight_path=self.external_weight_file,
)
if self.streaming_llm:
self.model = self.runner.ctx.modules.streaming_state_update
else:
self.model = self.runner.ctx.modules.state_update
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_name,
use_fast=False,
use_auth_token=hf_auth_token,
)
elif not os.path.exists(self.tempfile_name):
self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
"initializer"
](
self.hf_model_name,
hf_auth_token,
compile_to="torch",
external_weights=external_weights,
precision=self.precision,
quantization=self.quantization,
streaming_llm=self.streaming_llm,
decomp_attn=True,
)
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
del self.torch_ir
gc.collect()
self.compile()
else:
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_name,
use_fast=False,
use_auth_token=hf_auth_token,
)
self.compile()
# Reserved for running HF torch model as reference.
self.hf_mod = None
def compile(self) -> None:
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
# ONLY architecture/api-specific compile-time flags for each backend, if needed.
# hf_model_id-specific global flags currently in model map.
flags = []
if "cpu" in self.backend:
flags.extend(
[
"--iree-global-opt-enable-quantized-matmul-reassociation",
]
)
elif self.backend == "vulkan":
flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
elif self.backend == "rocm":
flags.extend(
[
"--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
"--iree-llvmgpu-enable-prefetch=true",
"--iree-opt-outer-dim-concat=true",
"--iree-flow-enable-aggressive-fusion",
]
)
if "gfx9" in self.triple:
flags.extend(
[
f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(self.triple, get_checkpoints_path())}",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
]
)
flags.extend(llm_model_map[self.hf_model_name]["compile_flags"])
flatbuffer_blob = compile_module_to_flatbuffer(
self.tempfile_name,
device=self.device,
frontend="auto",
model_config_path=None,
extra_args=flags,
write_to=self.vmfb_name,
)
self.runner = vmfbRunner(
device=self.driver,
vmfb_path=self.vmfb_name,
external_weight_path=self.external_weight_file,
)
if self.streaming_llm:
self.model = self.runner.ctx.modules.streaming_state_update
else:
self.model = self.runner.ctx.modules.state_update
def sanitize_prompt(self, prompt):
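        # Flatten list-style prompts into a single string, strip newlines/tabs/carriage returns, and
        # wrap the text in [INST] tags, prepending the default system prompt on the first turn if enabled.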
if isinstance(prompt, list):
prompt = list(chain.from_iterable(prompt))
prompt = " ".join([x for x in prompt if isinstance(x, str)])
prompt = prompt.replace("\n", " ")
prompt = prompt.replace("\t", " ")
prompt = prompt.replace("\r", " ")
if self.use_system_prompt and self.global_iter == 0:
prompt = append_user_prompt(DEFAULT_CHAT_SYS_PROMPT, prompt)
return prompt
else:
return f"{B_INST} {prompt} {E_INST}"
def chat(self, prompt):
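        # Streaming chat generator: tokenize the prompt, run the compiled IREE module one token at a
        # time, and yield (decoded_text_so_far, step_time) until the stop token or max_tokens is reached.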
prompt = self.sanitize_prompt(prompt)
input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
def format_out(results):
return torch.tensor(results.to_host()[0][0])
history = []
for iter in range(self.max_tokens):
if self.streaming_llm:
token_slice = max(self.prev_token_len - 1, 0)
input_tensor = input_tensor[:, token_slice:]
if self.streaming_llm and self.model["get_seq_step"]() > 600:
print("Evicting cache space!")
self.model["evict_kvcache_space"]()
token_len = input_tensor.shape[-1]
device_inputs = [
ireert.asdevicearray(self.runner.config.device, input_tensor)
]
if self.first_input or not self.streaming_llm:
st_time = time.time()
token = self.model["run_initialize"](*device_inputs)
total_time = time.time() - st_time
token_len += 1
self.first_input = False
else:
st_time = time.time()
token = self.model["run_cached_initialize"](*device_inputs)
total_time = time.time() - st_time
token_len += 1
history.append(format_out(token))
while (
format_out(token) != llm_model_map[self.hf_model_name]["stop_token"]
and len(history) < self.max_tokens
):
dec_time = time.time()
if self.streaming_llm and self.model["get_seq_step"]() > 600:
print("Evicting cache space!")
self.model["evict_kvcache_space"]()
token = self.model["run_forward"](token)
history.append(format_out(token))
total_time = time.time() - dec_time
yield self.tokenizer.decode(history), total_time
self.prev_token_len = token_len + len(history)
if format_out(token) == llm_model_map[self.hf_model_name]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
history[i] = int(history[i])
result_output = self.tokenizer.decode(history)
self.global_iter += 1
return result_output, total_time
# Reference HF model function for sanity checks.
def chat_hf(self, prompt):
if self.hf_mod is None:
self.hf_mod = AutoModelForCausalLM.from_pretrained(
self.hf_model_name,
torch_dtype=torch.float,
token=self.hf_auth_token,
)
prompt = self.sanitize_prompt(prompt)
input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
history = []
for iter in range(self.max_tokens):
token_len = input_tensor.shape[-1]
if self.first_input:
st_time = time.time()
result = self.hf_mod(input_tensor)
token = torch.argmax(result.logits[:, -1, :], dim=1)
total_time = time.time() - st_time
token_len += 1
pkv = result.past_key_values
self.first_input = False
history.append(int(token))
while token != llm_model_map[self.hf_model_name]["stop_token"]:
dec_time = time.time()
result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
history.append(int(token))
total_time = time.time() - dec_time
token = torch.argmax(result.logits[:, -1, :], dim=1)
pkv = result.past_key_values
yield self.tokenizer.decode(history), total_time
self.prev_token_len = token_len + len(history)
if token == llm_model_map[self.hf_model_name]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
history[i] = int(history[i])
result_output = self.tokenizer.decode(history)
self.global_iter += 1
return result_output, total_time
def get_mfma_spec_path(target_chip, save_dir):
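    # Fetch IREE's attention/matmul transform-dialect spec, cache it in save_dir, and return the local path.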
url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
attn_spec = urlopen(url).read().decode("utf-8")
spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
if os.path.exists(spec_path):
return spec_path
with open(spec_path, "w") as f:
f.write(attn_spec)
return spec_path
def llm_chat_api(InputData: dict):
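    # REST endpoint handler: accepts either an OpenAI-style chat payload ("messages") or a legacy
    # completion payload ("prompt"), (re)initializes the LanguageModel pipeline if needed, and
    # returns a completion-style response dict.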
from datetime import datetime as dt
import apps.shark_studio.web.utils.globals as global_obj
print(f"Input keys : {InputData.keys()}")
# print(f"model : {InputData['model']}")
is_chat_completion_api = (
"messages" in InputData.keys()
) # else it is the legacy `completion` api
# For Debugging input data from API
if is_chat_completion_api:
print(f"message -> role : {InputData['messages'][0]['role']}")
print(f"message -> content : {InputData['messages'][0]['content']}")
else:
print(f"prompt : {InputData['prompt']}")
model_name = (
InputData["model"]
if "model" in InputData.keys()
else "meta-llama/Llama-2-7b-chat-hf"
)
model_path = llm_model_map[model_name]
device = InputData["device"] if "device" in InputData.keys() else "cpu"
precision = "fp16"
max_tokens = InputData["max_tokens"] if "max_tokens" in InputData.keys() else 4096
device_id = None
if not global_obj.get_llm_obj():
print("\n[LOG] Initializing new pipeline...")
global_obj.clear_cache()
gc.collect()
if "cuda" in device:
device = "cuda"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "cpu" in device:
device = "cpu"
precision = "fp32"
else:
print("unrecognized device")
llm_model = LanguageModel(
model_name=model_name,
hf_auth_token=cmd_opts.hf_auth_token,
device=device,
quantization=cmd_opts.quantization,
external_weights="safetensors",
use_system_prompt=True,
streaming_llm=False,
)
global_obj.set_llm_obj(llm_model)
else:
llm_model = global_obj.get_llm_obj()
llm_model.max_tokens = max_tokens
# TODO: add role dict for different models
if is_chat_completion_api:
        # TODO: add functionality for multiple messages
prompt = append_user_prompt(
InputData["messages"][0]["role"], InputData["messages"][0]["content"]
)
else:
prompt = InputData["prompt"]
print("prompt = ", prompt)
for res_op, _ in llm_model.chat(prompt):
if is_chat_completion_api:
choices = [
{
"index": 0,
"message": {
"role": "assistant",
"content": res_op, # since we are yeilding the result
},
"finish_reason": "stop", # or length
}
]
else:
choices = [
{
"text": res_op,
"index": 0,
"logprobs": None,
"finish_reason": "stop", # or length
}
]
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
return {
"id": end_time,
"object": "chat.completion" if is_chat_completion_api else "text_completion",
"created": int(end_time),
"choices": choices,
}
if __name__ == "__main__":
lm = LanguageModel(
"Trelis/Llama-2-7b-chat-hf-function-calling-v2",
hf_auth_token=None,
device="cpu-task",
external_weights="safetensors",
)
print("model loaded")
for i in lm.chat("hi, what are you?"):
print(i)

apps/shark_studio/api/sd.py

@@ -0,0 +1,579 @@
import gc
import torch
import gradio as gr
import time
import os
import json
import numpy as np
import copy
import importlib.util
import sys
from tqdm.auto import tqdm
from pathlib import Path
from random import randint
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.api.utils import parse_device
from apps.shark_studio.web.utils.state import status_label
from apps.shark_studio.web.utils.file_utils import (
safe_name,
get_resource_path,
get_checkpoints_path,
)
from apps.shark_studio.modules.img_processing import (
save_output_img,
)
from subprocess import check_output
EMPTY_SD_MAP = {
"clip": None,
"scheduler": None,
"unet": None,
"vae_decode": None,
}
EMPTY_SDXL_MAP = {
"prompt_encoder": None,
"scheduled_unet": None,
"vae_decode": None,
"pipeline": None,
"full_pipeline": None,
}
EMPTY_FLAGS = {
"clip": None,
"unet": None,
"vae": None,
"pipeline": None,
}
def load_script(source, module_name):
"""
reads file source and loads it as a module
:param source: file to load
:param module_name: name of module to register in sys.modules
:return: loaded module
"""
spec = importlib.util.spec_from_file_location(module_name, source)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
class StableDiffusion:
# This class is responsible for executing image generation and creating
# /managing a set of compiled modules to run Stable Diffusion. The init
# aims to be as general as possible, and the class will infer and compile
# a list of necessary modules or a combined "pipeline module" for a
# specified job based on the inference task.
def __init__(
self,
base_model_id,
height: int,
width: int,
batch_size: int,
steps: int,
scheduler: str,
precision: str,
device: str,
target_triple: str = None,
custom_vae: str = None,
num_loras: int = 0,
import_ir: bool = True,
is_controlled: bool = False,
external_weights: str = "safetensors",
progress=gr.Progress(),
):
progress(0, desc="Initializing pipeline...")
self.ui_device = device
self.precision = precision
self.compiled_pipeline = False
self.base_model_id = base_model_id
self.custom_vae = custom_vae
self.is_sdxl = "xl" in self.base_model_id.lower()
self.is_custom = ".py" in self.base_model_id.lower()
if self.is_custom:
custom_module = load_script(
os.path.join(get_checkpoints_path("scripts"), self.base_model_id),
"custom_pipeline",
)
self.turbine_pipe = custom_module.StudioPipeline
self.dynamic_steps = False
self.model_map = custom_module.MODEL_MAP
elif self.is_sdxl:
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
SharkSDXLPipeline,
)
self.turbine_pipe = SharkSDXLPipeline
self.dynamic_steps = False
self.model_map = EMPTY_SDXL_MAP
else:
from turbine_models.custom_models.sd_inference.sd_pipeline import (
SharkSDPipeline,
)
self.turbine_pipe = SharkSDPipeline
self.dynamic_steps = True
self.model_map = EMPTY_SD_MAP
max_length = 64
target_backend, self.rt_device, triple = parse_device(device, target_triple)
pipe_id_list = [
safe_name(base_model_id),
str(batch_size),
str(max_length),
f"{str(height)}x{str(width)}",
precision,
triple,
]
if num_loras > 0:
pipe_id_list.append(str(num_loras) + "lora")
if is_controlled:
pipe_id_list.append("controlled")
if custom_vae:
pipe_id_list.append(custom_vae)
self.pipe_id = "_".join(pipe_id_list)
self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
self.weights_path = Path(
os.path.join(
get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision)
)
)
if not os.path.exists(self.weights_path):
os.mkdir(self.weights_path)
decomp_attn = True
attn_spec = None
if triple in ["gfx940", "gfx942", "gfx90a"]:
decomp_attn = False
attn_spec = "mfma"
elif triple in ["gfx1100", "gfx1103", "gfx1150"]:
decomp_attn = False
attn_spec = "wmma"
if triple in ["gfx1103", "gfx1150"]:
# external weights have issues on igpu
external_weights = None
elif target_backend == "llvm-cpu":
decomp_attn = False
progress(0.5, desc="Initializing pipeline...")
self.sd_pipe = self.turbine_pipe(
hf_model_name=base_model_id,
scheduler_id=scheduler,
height=height,
width=width,
precision=precision,
max_length=max_length,
batch_size=batch_size,
num_inference_steps=steps,
device=target_backend,
iree_target_triple=triple,
ireec_flags=EMPTY_FLAGS,
attn_spec=attn_spec,
decomp_attn=decomp_attn,
pipeline_dir=self.pipeline_dir,
external_weights_dir=self.weights_path,
external_weights=external_weights,
custom_vae=custom_vae,
)
progress(1, desc="Pipeline initialized!...")
gc.collect()
def prepare_pipe(
self,
custom_weights,
adapters,
embeddings,
is_img2img,
compiled_pipeline,
progress=gr.Progress(),
):
progress(0, desc="Preparing models...")
self.is_img2img = False
mlirs = copy.deepcopy(self.model_map)
vmfbs = copy.deepcopy(self.model_map)
weights = copy.deepcopy(self.model_map)
if not self.is_sdxl:
compiled_pipeline = False
self.compiled_pipeline = compiled_pipeline
if custom_weights:
from apps.shark_studio.modules.ckpt_processing import (
preprocessCKPT,
save_irpa,
)
custom_weights = os.path.join(
get_checkpoints_path("checkpoints"),
safe_name(self.base_model_id.split("/")[-1]),
custom_weights,
)
diffusers_weights_path = preprocessCKPT(custom_weights, self.precision)
for key in weights:
if key in ["scheduled_unet", "unet"]:
unet_weights_path = os.path.join(
diffusers_weights_path,
"unet",
"diffusion_pytorch_model.safetensors",
)
weights[key] = save_irpa(unet_weights_path, "unet.")
if key in ["mmdit"]:
mmdit_weights_path = os.path.join(
diffusers_weights_path,
"mmdit",
"diffusion_pytorch_model_fp16.safetensors",
)
weights[key] = save_irpa(mmdit_weights_path, "mmdit.")
elif key in ["clip", "prompt_encoder", "text_encoder"]:
if not self.is_sdxl and not self.is_custom:
sd1_path = os.path.join(
diffusers_weights_path, "text_encoder", "model.safetensors"
)
weights[key] = save_irpa(sd1_path, "text_encoder_model.")
elif self.is_sdxl:
clip_1_path = os.path.join(
diffusers_weights_path, "text_encoder", "model.safetensors"
)
clip_2_path = os.path.join(
diffusers_weights_path,
"text_encoder_2",
"model.safetensors",
)
weights[key] = [
save_irpa(clip_1_path, "text_encoder_model_1."),
save_irpa(clip_2_path, "text_encoder_model_2."),
]
elif self.is_custom:
clip_g_path = os.path.join(
diffusers_weights_path,
"text_encoder",
"model.fp16.safetensors",
)
clip_l_path = os.path.join(
diffusers_weights_path,
"text_encoder_2",
"model.fp16.safetensors",
)
t5xxl_path = os.path.join(
diffusers_weights_path,
"text_encoder_3",
"model.fp16.safetensors",
)
weights[key] = [
save_irpa(clip_g_path, "clip_g.transformer."),
save_irpa(clip_l_path, "clip_l.transformer."),
save_irpa(t5xxl_path, "t5xxl.transformer."),
]
elif key in ["vae_decode"] and weights[key] is None:
vae_weights_path = os.path.join(
diffusers_weights_path,
"vae",
"diffusion_pytorch_model.safetensors",
)
weights[key] = save_irpa(vae_weights_path, "vae.")
progress(0.25, desc=f"Preparing pipeline for {self.ui_device}...")
vmfbs, weights = self.sd_pipe.check_prepared(
mlirs, vmfbs, weights, interactive=False
)
progress(0.5, desc=f"Artifacts ready!")
progress(0.75, desc=f"Loading models and weights...")
self.sd_pipe.load_pipeline(
vmfbs, weights, self.rt_device, self.compiled_pipeline
)
progress(1, desc="Pipeline loaded! Generating images...")
return
def generate_images(
self,
prompt,
negative_prompt,
image,
strength,
guidance_scale,
seed,
ondemand,
resample_type,
control_mode,
hints,
progress=gr.Progress(),
):
img = self.sd_pipe.generate_images(
prompt,
negative_prompt,
1,
guidance_scale,
seed,
return_imgs=True,
)
return img
def shark_sd_fn(
prompt,
negative_prompt,
sd_init_image: list,
height: int,
width: int,
steps: int,
strength: float,
guidance_scale: float,
seed: list,
batch_count: int,
batch_size: int,
scheduler: str,
base_model_id: str,
custom_weights: str,
custom_vae: str,
precision: str,
device: str,
target_triple: str,
ondemand: bool,
compiled_pipeline: bool,
resample_type: str,
controlnets: dict,
embeddings: dict,
seed_increment: str | int = 1,
output_type: str = "png",
# progress=gr.Progress(),
):
sd_kwargs = locals()
if not isinstance(sd_init_image, list):
sd_init_image = [sd_init_image]
is_img2img = True if sd_init_image[0] is not None else False
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj
adapters = {}
is_controlled = False
control_mode = None
hints = []
num_loras = 0
import_ir = True
for i in embeddings:
num_loras += 1 if embeddings[i] else 0
if "model" in controlnets:
for i, model in enumerate(controlnets["model"]):
if "xl" not in base_model_id.lower():
adapters[f"control_adapter_{model}"] = {
"hf_id": control_adapter_map["runwayml/stable-diffusion-v1-5"][
model
],
"strength": controlnets["strength"][i],
}
else:
adapters[f"control_adapter_{model}"] = {
"hf_id": control_adapter_map["stabilityai/stable-diffusion-xl-1.0"][
model
],
"strength": controlnets["strength"][i],
}
if model is not None:
is_controlled = True
control_mode = controlnets["control_mode"]
for i in controlnets["hint"]:
                    hints.append(i)
submit_pipe_kwargs = {
"base_model_id": base_model_id,
"height": height,
"width": width,
"batch_size": batch_size,
"precision": precision,
"device": device,
"target_triple": target_triple,
"custom_vae": custom_vae,
"num_loras": num_loras,
"import_ir": import_ir,
"is_controlled": is_controlled,
"steps": steps,
"scheduler": scheduler,
}
submit_prep_kwargs = {
"custom_weights": custom_weights,
"adapters": adapters,
"embeddings": embeddings,
"is_img2img": is_img2img,
"compiled_pipeline": compiled_pipeline,
}
submit_run_kwargs = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"image": sd_init_image,
"strength": strength,
"guidance_scale": guidance_scale,
"seed": seed,
"ondemand": ondemand,
"resample_type": resample_type,
"control_mode": control_mode,
"hints": hints,
}
if global_obj.get_sd_obj() and global_obj.get_sd_obj().dynamic_steps:
submit_run_kwargs["steps"] = submit_pipe_kwargs["steps"]
submit_pipe_kwargs.pop("steps")
if (
not global_obj.get_sd_obj()
or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
):
print("\n[LOG] Initializing new pipeline...")
global_obj.clear_cache()
gc.collect()
# Initializes the pipeline and retrieves IR based on all
# parameters that are static in the turbine output format,
# which is currently MLIR in the torch dialect.
sd_pipe = StableDiffusion(
**submit_pipe_kwargs,
)
global_obj.set_sd_obj(sd_pipe)
global_obj.set_pipe_kwargs(submit_pipe_kwargs)
if (
not global_obj.get_prep_kwargs()
or global_obj.get_prep_kwargs() != submit_prep_kwargs
):
global_obj.set_prep_kwargs(submit_prep_kwargs)
global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)
generated_imgs = []
if submit_run_kwargs["seed"] in [-1, "-1"]:
submit_run_kwargs["seed"] = randint(0, 4294967295)
seed_increment = "random"
# print(f"\n[LOG] Random seed: {seed}")
# progress(None, desc=f"Generating...")
for current_batch in range(batch_count):
start_time = time.time()
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
if not isinstance(out_imgs, list):
out_imgs = [out_imgs]
# total_time = time.time() - start_time
# text_output = f"Total image(s) generation time: {total_time:.4f}sec"
# print(f"\n[LOG] {text_output}")
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
# break
# else:
for batch in range(batch_size):
if output_type == "png":
save_output_img(
out_imgs[batch],
seed,
sd_kwargs,
)
generated_imgs.extend(out_imgs)
yield generated_imgs, status_label(
"Stable Diffusion", current_batch + 1, batch_count, batch_size
)
if batch_count > 1:
submit_run_kwargs["seed"] = get_next_seed(seed, seed_increment)
return (generated_imgs, "")
def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
print("\n[LOG] Submitting Request...")
for key in sd_kwargs:
if sd_kwargs[key] in [None, []]:
sd_kwargs[key] = None
if sd_kwargs[key] in ["None"]:
sd_kwargs[key] = ""
if key in ["steps", "height", "width", "batch_count", "batch_size"]:
sd_kwargs[key] = int(sd_kwargs[key])
if key == "seed":
sd_kwargs[key] = int(sd_kwargs[key])
# TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
if not sd_kwargs["device"]:
gr.Warning("No device specified. Please specify a device.")
return None, ""
if sd_kwargs["height"] not in [512, 1024]:
gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
return None, ""
if sd_kwargs["height"] != sd_kwargs["width"]:
gr.Warning("Height and width must be the same. This is a temporary limitation.")
return None, ""
if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
if sd_kwargs["steps"] > 10:
gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
return None, ""
if sd_kwargs["guidance_scale"] > 3:
gr.Warning(
"sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
)
return None, ""
if sd_kwargs["target_triple"] == "":
if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2]:
gr.Warning(
"Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
)
return None, ""
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
return generated_imgs
def get_next_seed(seed, seed_increment: str | int = 10):
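    # Advance the seed between batches: add a fixed integer increment, or draw a fresh random 32-bit seed.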
if isinstance(seed_increment, int):
# print(f"\n[LOG] Seed after batch increment: {seed + seed_increment}")
return int(seed + seed_increment)
elif seed_increment == "random":
seed = randint(0, 4294967295)
# print(f"\n[LOG] Random seed: {seed}")
return seed
def unload_sd():
print("Unloading models.")
import apps.shark_studio.web.utils.globals as global_obj
global_obj.clear_cache()
gc.collect()
def cancel_sd():
print("Inject call to cancel longer API calls.")
return
def view_json_file(file_path):
content = ""
with open(file_path, "r") as fopen:
content = fopen.read()
return content
def safe_name(name):
return name.replace("/", "_").replace("\\", "_").replace(".", "_")
if __name__ == "__main__":
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj
global_obj._init()
sd_json = view_json_file(
get_resource_path(os.path.join(cmd_opts.config_dir, cmd_opts.default_config))
)
sd_kwargs = json.loads(sd_json)
# for arg in vars(cmd_opts):
# if arg in sd_kwargs:
# sd_kwargs[arg] = getattr(cmd_opts, arg)
for i in shark_sd_fn_dict_input(sd_kwargs):
print(i)


@@ -0,0 +1,288 @@
import numpy as np
import json
from random import (
randint,
seed as seed_random,
getstate as random_getstate,
setstate as random_setstate,
)
from pathlib import Path
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from cpuinfo import get_cpu_info
def iree_device_map(device):
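    # Map a Studio device string (e.g. "cpu", "vulkan://1") to the corresponding IREE runtime driver URI.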
uri_parts = device.split("://", 2)
iree_driver = (
_IREE_DEVICE_MAP[uri_parts[0]]
if uri_parts[0] in _IREE_DEVICE_MAP
else uri_parts[0]
)
if len(uri_parts) == 1:
return iree_driver
elif "rocm" in uri_parts:
return "rocm"
else:
return f"{iree_driver}://{uri_parts[1]}"
def get_supported_device_list():
return list(_IREE_DEVICE_MAP.keys())
_IREE_DEVICE_MAP = {
"cpu": "local-task",
"cpu-task": "local-task",
"cpu-sync": "local-sync",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "metal",
"rocm": "rocm",
"hip": "hip",
"intel-gpu": "level_zero",
}
def iree_target_map(device):
if "://" in device:
device = device.split("://")[0]
return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device
_IREE_TARGET_MAP = {
"cpu": "llvm-cpu",
"cpu-task": "llvm-cpu",
"cpu-sync": "llvm-cpu",
"cuda": "cuda",
"vulkan": "vulkan-spirv",
"metal": "metal",
"rocm": "rocm",
"hip": "rocm",
"intel-gpu": "opencl-spirv",
}
def get_available_devices():
return ["rocm", "cpu"]
def get_devices_by_name(driver_name):
device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except:
print(f"{driver_name} devices are not available.")
else:
cpu_name = get_cpu_info()["brand_raw"]
for i, device in enumerate(device_list_dict):
device_name = (
cpu_name if device["name"] == "default" else device["name"]
)
if "local" in driver_name:
device_list.append(
f"{device_name} => {driver_name.replace('local', 'cpu')}"
)
else:
# for drivers with single devices
# let the default device be selected without any indexing
if len(device_list_dict) == 1:
device_list.append(f"{device_name} => {driver_name}")
else:
device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list
# set_iree_runtime_flags()
available_devices = []
rocm_devices = get_devices_by_name("rocm")
available_devices.extend(rocm_devices)
# cpu_device = get_devices_by_name("cpu-sync")
# available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
# from shark.iree_utils.vulkan_utils import (
# get_all_vulkan_devices,
# )
# vulkaninfo_list = get_all_vulkan_devices()
# vulkan_devices = []
# id = 0
# for device in vulkaninfo_list:
# vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
# id += 1
# if id != 0:
# print(f"vulkan devices are available.")
# available_devices.extend(vulkan_devices)
# metal_devices = get_devices_by_name("metal")
# available_devices.extend(metal_devices)
# cuda_devices = get_devices_by_name("cuda")
# available_devices.extend(cuda_devices)
# hip_devices = get_devices_by_name("hip")
# available_devices.extend(hip_devices)
for idx, device_str in enumerate(available_devices):
if "AMD Radeon(TM) Graphics =>" in device_str:
igpu_id_candidates = [
x.split("w/")[-1].split("=>")[0]
for x in available_devices
if "M Graphics" in x
]
for igpu_name in igpu_id_candidates:
if igpu_name:
available_devices[idx] = device_str.replace(
"AMD Radeon(TM) Graphics", igpu_name
)
break
return available_devices
def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by Studio pipeline
# Multiple devices only supported for vulkan and rocm (as of now).
# default device must be selected for all others
device_id = None
device = raw_device if "=>" not in raw_device else raw_device.split("=>")[1].strip()
if "://" in device:
device, device_id = device.split("://")
if len(device_id) <= 2:
device_id = int(device_id)
if device not in ["hip", "rocm", "vulkan"]:
device_id = None
if device in ["hip", "rocm", "vulkan"] and device_id == None:
device_id = 0
return device, device_id
def parse_device(device_str, target_override=""):
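    # Resolve a UI device string into (iree_target_backend, runtime_device, target_triple),
    # honoring an explicit target_override when one is provided.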
rt_driver, device_id = clean_device_info(device_str)
target_backend = iree_target_map(rt_driver)
if device_id:
rt_device = f"{rt_driver}://{device_id}"
else:
rt_device = rt_driver
if target_override:
if "cpu" in device_str:
rt_device = "local-task"
return target_backend, rt_device, target_override
match target_backend:
case "vulkan-spirv":
triple = get_iree_target_triple(device_str)
return target_backend, rt_device, triple
case "rocm":
triple = get_rocm_target_chip(device_str)
return target_backend, rt_device, triple
case "llvm-cpu":
if "Ryzen 9" in device_str:
return target_backend, "local-task", "znver4"
else:
return "llvm-cpu", "local-task", "x86_64-linux-gnu"
def get_rocm_target_chip(device_str):
# TODO: Use a data file to map device_str to target chip.
rocm_chip_map = {
"6700": "gfx1031",
"6800": "gfx1030",
"6900": "gfx1030",
"7900": "gfx1100",
"MI300X": "gfx942",
"MI300A": "gfx940",
"MI210": "gfx90a",
"MI250": "gfx90a",
"MI100": "gfx908",
"MI50": "gfx906",
"MI60": "gfx906",
"780M": "gfx1103",
}
for key in rocm_chip_map:
if key in device_str:
return rocm_chip_map[key]
return None
def get_all_devices(driver_name):
"""
Inputs: driver_name
Returns a list of all the available devices for a given driver sorted by
the iree path names of the device as in --list_devices option in iree.
"""
from iree.runtime import get_driver
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
del driver
return device_list_src
# def get_device_mapping(driver, key_combination=3):
# """This method ensures consistent device ordering when choosing
# specific devices for execution
# Args:
# driver (str): execution driver (vulkan, cuda, rocm, etc)
# key_combination (int, optional): choice for mapping value for
# device name.
# 1 : path
# 2 : name
# 3 : (name, path)
# Defaults to 3.
# Returns:
# dict: map to possible device names user can input mapped to desired
# combination of name/path.
# """
# driver = iree_device_map(driver)
# device_list = get_all_devices(driver)
# device_map = dict()
# def get_output_value(dev_dict):
# if key_combination == 1:
# return f"{driver}://{dev_dict['path']}"
# if key_combination == 2:
# return dev_dict["name"]
# if key_combination == 3:
# return dev_dict["name"], f"{driver}://{dev_dict['path']}"
# # mapping driver name to default device (driver://0)
# device_map[f"{driver}"] = get_output_value(device_list[0])
# for i, device in enumerate(device_list):
# # mapping with index
# device_map[f"{driver}://{i}"] = get_output_value(device)
# # mapping with full path
# device_map[f"{driver}://{device['path']}"] = get_output_value(device)
# return device_map
# def get_opt_flags(model, precision="fp16"):
# iree_flags = []
# if len(cmd_opts.iree_vulkan_target_triple) > 0:
# iree_flags.append(
# f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
# )
# if "rocm" in cmd_opts.device:
# from shark.iree_utils.gpu_utils import get_iree_rocm_args
# rocm_args = get_iree_rocm_args()
# iree_flags.extend(rocm_args)
# if cmd_opts.iree_constant_folding == False:
# iree_flags.append("--iree-opt-const-expr-hoisting=False")
# iree_flags.append(
# "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
# )
# if cmd_opts.data_tiling == False:
# iree_flags.append("--iree-opt-data-tiling=False")
# if "vae" not in model:
# # Due to lack of support for multi-reduce, we always collapse reduction
# # dims before dispatch formation right now.
# iree_flags += ["--iree-flow-collapse-reduction-dims"]
# return iree_flags


@@ -0,0 +1,152 @@
import os
import json
import re
import requests
import torch
import safetensors
from shark_turbine.aot.params import (
ParameterArchiveBuilder,
)
from io import BytesIO
from pathlib import Path
from tqdm import tqdm
from omegaconf import OmegaConf
from diffusers import StableDiffusionPipeline
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
create_vae_diffusers_config,
convert_ldm_vae_checkpoint,
)
def get_path_to_diffusers_checkpoint(custom_weights, precision="fp16"):
path = Path(custom_weights)
diffusers_path = path.parent.absolute()
diffusers_directory_name = os.path.join("diffusers", path.stem + f"_{precision}")
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
path_to_diffusers = complete_path_to_diffusers.as_posix()
return path_to_diffusers
def preprocessCKPT(custom_weights, precision="fp16", is_inpaint=False):
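    # Convert an original .ckpt/.safetensors Stable Diffusion checkpoint into a diffusers-format
    # directory (skipped if one already exists) and return the path to that directory.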
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights, precision)
if next(Path(path_to_diffusers).iterdir(), None):
print("Checkpoint already loaded at : ", path_to_diffusers)
return path_to_diffusers
else:
print(
"Diffusers' checkpoint will be identified here : ",
path_to_diffusers,
)
from_safetensors = (
True if custom_weights.lower().endswith(".safetensors") else False
)
# EMA weights usually yield higher quality images for inference but
# non-EMA weights have been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if
# they want to go for EMA weight extraction or not.
extract_ema = False
print("Loading diffusers' pipeline from original stable diffusion checkpoint")
num_in_channels = 9 if is_inpaint else 4
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path_or_dict=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
num_in_channels=num_in_channels,
)
if precision == "fp16":
pipe.to(dtype=torch.float16)
pipe.save_pretrained(path_to_diffusers)
del pipe
print("Loading complete")
return path_to_diffusers
def save_irpa(weights_path, prepend_str):
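    # Load a .safetensors file, prefix every tensor key with prepend_str, and save the tensors as
    # an IREE parameter archive (.irpa), returning the new file's path.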
weights = safetensors.torch.load_file(weights_path)
archive = ParameterArchiveBuilder()
for key in weights.keys():
new_key = prepend_str + key
archive.add_tensor(new_key, weights[key])
if "safetensors" in weights_path:
irpa_file = weights_path.replace(".safetensors", ".irpa")
elif "irpa" in weights_path:
irpa_file = weights_path
else:
return Exception(
"Invalid file format. Please provide a .safetensors or .irpa file."
)
archive.save(irpa_file)
return irpa_file
def convert_original_vae(vae_checkpoint):
vae_state_dict = {}
for key in list(vae_checkpoint.keys()):
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
config_url = (
"https://raw.githubusercontent.com/CompVis/stable-diffusion/"
"main/configs/stable-diffusion/v1-inference.yaml"
)
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file)
vae_config = create_vae_diffusers_config(original_config, image_size=512)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_state_dict, vae_config)
return converted_vae_checkpoint
def process_custom_pipe_weights(custom_weights):
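    # Resolve custom weights (a local checkpoint or a civitai URL) to
    # (params_path, diffusers_target_path), downloading from civitai when needed.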
if custom_weights != "":
if custom_weights.startswith("https://civitai.com/api/"):
# download the checkpoint from civitai if we don't already have it
weights_path = get_civitai_checkpoint(custom_weights)
# act as if we were given the local file as custom_weights originally
custom_weights_tgt = get_path_to_diffusers_checkpoint(weights_path)
custom_weights_params = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights)
custom_weights_params = custom_weights
return custom_weights_params, custom_weights_tgt
def get_civitai_checkpoint(url: str):
with requests.get(url, allow_redirects=True, stream=True) as response:
response.raise_for_status()
# civitai api returns the filename in the content disposition
base_filename = re.findall(
'"([^"]*)"', response.headers["Content-Disposition"]
)[0]
destination_path = Path.cwd() / (cmd_opts.model_dir or "models") / base_filename
# we don't have this model downloaded yet
if not destination_path.is_file():
print(f"downloading civitai model from {url} to {destination_path}")
size = int(response.headers["content-length"], 0)
progress_bar = tqdm(total=size, unit="iB", unit_scale=True)
with open(destination_path, "wb") as f:
for chunk in response.iter_content(chunk_size=65536):
f.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
# we already have this model downloaded
else:
print(f"civitai model already downloaded to {destination_path}")
response.close()
return destination_path.as_posix()


@@ -0,0 +1,185 @@
import os
import sys
import torch
import json
import safetensors
from dataclasses import dataclass
from safetensors.torch import load_file
from apps.shark_studio.web.utils.file_utils import (
get_checkpoint_pathfile,
get_path_stem,
)
@dataclass
class LoRAweight:
up: torch.tensor
down: torch.tensor
mid: torch.tensor
alpha: torch.float32 = 1.0
def processLoRA(model, use_lora, splitting_prefix, lora_strength=0.75):
state_dict = ""
if ".safetensors" in use_lora:
state_dict = load_file(use_lora)
else:
state_dict = torch.load(use_lora)
# gather the weights from the LoRA in a more convenient form, assumes
# everything will have an up.weight.
weight_dict: dict[str, LoRAweight] = {}
for key in state_dict:
if key.startswith(splitting_prefix) and key.endswith("up.weight"):
stem = key.split("up.weight")[0]
weight_key = stem.removesuffix(".lora_")
weight_key = weight_key.removesuffix("_lora_")
weight_key = weight_key.removesuffix(".lora_linear_layer.")
if weight_key not in weight_dict:
weight_dict[weight_key] = LoRAweight(
state_dict[f"{stem}up.weight"],
state_dict[f"{stem}down.weight"],
state_dict.get(f"{stem}mid.weight", None),
(
state_dict[f"{weight_key}.alpha"]
/ state_dict[f"{stem}up.weight"].shape[1]
if f"{weight_key}.alpha" in state_dict
else 1.0
),
)
# Directly update weight in model
# Mostly adaptions of https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py
# and similar code in https://github.com/huggingface/diffusers/issues/3064
# TODO: handle mid weights (how do they even work?)
for key, lora_weight in weight_dict.items():
curr_layer = model
layer_infos = key.split(".")[0].split(splitting_prefix)[-1].split("_")
# find the target layer
temp_name = layer_infos.pop(0)
while len(layer_infos) > -1:
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
weight = curr_layer.weight.data
scale = lora_weight.alpha * lora_strength
if len(weight.size()) == 2:
if len(lora_weight.up.shape) == 4:
weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
change = torch.mm(lora_weight.up, lora_weight.down)
elif lora_weight.down.size()[2:4] == (1, 1):
weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
change = torch.nn.functional.conv2d(
lora_weight.down.permute(1, 0, 2, 3),
lora_weight.up,
).permute(1, 0, 2, 3)
curr_layer.weight.data += change * scale
return model
def update_lora_weight_for_unet(unet, use_lora, lora_strength):
extensions = [".bin", ".safetensors", ".pt"]
if not any([extension in use_lora for extension in extensions]):
# We assume if it is a HF ID with standalone LoRA weights.
unet.load_attn_procs(use_lora)
return unet
main_file_name = get_path_stem(use_lora)
if ".bin" in use_lora:
main_file_name += ".bin"
elif ".safetensors" in use_lora:
main_file_name += ".safetensors"
elif ".pt" in use_lora:
main_file_name += ".pt"
else:
sys.exit("Only .bin and .safetensors format for LoRA is supported")
try:
dir_name = os.path.dirname(use_lora)
unet.load_attn_procs(dir_name, weight_name=main_file_name)
return unet
except:
return processLoRA(unet, use_lora, "lora_unet_", lora_strength)
def update_lora_weight(model, use_lora, model_name, lora_strength=1.0):
if "unet" in model_name:
return update_lora_weight_for_unet(model, use_lora, lora_strength)
try:
return processLoRA(model, use_lora, "lora_te_", lora_strength)
except:
return None
def get_lora_metadata(lora_filename):
# get the metadata from the file
filename = get_checkpoint_pathfile(lora_filename, "lora")
with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
metadata = f.metadata()
# guard clause for if there isn't any metadata
if not metadata:
return None
# metadata is a dictionary of strings, the values of the keys we're
# interested in are actually json, and need to be loaded as such
tag_frequencies = json.loads(metadata.get("ss_tag_frequency", str("{}")))
dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", str("{}")))
tag_dirs = [dir for dir in tag_frequencies.keys()]
# gather the tag frequency information for all the datasets trained
all_frequencies = {}
for dataset in tag_dirs:
frequencies = sorted(
[entry for entry in tag_frequencies[dataset].items()],
reverse=True,
key=lambda x: x[1],
)
# get a figure for the total number of images processed for this dataset:
# either the count listed in its dataset_dirs entry, or the highest tag
# frequency if that entry doesn't exist
img_count = dataset_dirs.get(dataset, {}).get("img_count", frequencies[0][1])
# add the dataset frequencies to the overall frequencies replacing the
# frequency counts on the tags with a percentage/ratio
all_frequencies.update(
[(entry[0], entry[1] / img_count) for entry in frequencies]
)
trained_model_id = " ".join(
[
metadata.get("ss_sd_model_hash", ""),
metadata.get("ss_sd_model_name", ""),
metadata.get("ss_base_model_version", ""),
]
).strip()
# return the topmost <count> of all frequencies in all datasets
return {
"model": trained_model_id,
"frequencies": sorted(
all_frequencies.items(), reverse=True, key=lambda x: x[1]
),
}
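The merge above folds each LoRA pair directly into the base weights: for a linear layer the update is W += strength * alpha * (up @ down), with the convolutional cases handled by the squeeze/conv2d branches. A minimal usage sketch, assuming these helpers are in scope; the checkpoint path and strength are placeholders:

from diffusers import UNet2DConditionModel

# update_lora_weight dispatches UNet weights ("lora_unet_" keys) to
# update_lora_weight_for_unet and text-encoder weights ("lora_te_" keys)
# to processLoRA
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", subfolder="unet"
)
unet = update_lora_weight(
    unet, "models/lora/my_style.safetensors", "unet", lora_strength=0.8
)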


@@ -0,0 +1,204 @@
import os
import re
import json
import torch
import numpy as np
from csv import DictWriter
from PIL import Image, PngImagePlugin
from pathlib import Path
from datetime import datetime as dt
from base64 import decode
resamplers = {
"Lanczos": Image.Resampling.LANCZOS,
"Nearest Neighbor": Image.Resampling.NEAREST,
"Bilinear": Image.Resampling.BILINEAR,
"Bicubic": Image.Resampling.BICUBIC,
"Hamming": Image.Resampling.HAMMING,
"Box": Image.Resampling.BOX,
}
resampler_list = resamplers.keys()
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed, extra_info=None):
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
if extra_info is None:
extra_info = {}
elif "progress" in extra_info.keys():
extra_info.pop("progress")
generated_imgs_path = Path(
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
)
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", extra_info["prompt"][0][:15])
out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
img_model = extra_info["base_model_id"]
if extra_info["custom_weights"] not in [None, "None"]:
img_model = Path(os.path.basename(extra_info["custom_weights"])).stem
img_vae = None
if extra_info["custom_vae"]:
img_vae = Path(os.path.basename(extra_info["custom_vae"])).stem
img_loras = None
if extra_info["embeddings"]:
img_lora = []
for i in extra_info["embeddings"]:
img_lora += Path(os.path.basename(cmd_opts.use_lora)).stem
img_loras = ", ".join(img_lora)
if cmd_opts.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
else:
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
pngInfo = PngImagePlugin.PngInfo()
if cmd_opts.write_metadata_to_png:
# Using a conditional expression caused problems, so setting a new
# variable for now.
# if cmd_opts.use_hiresfix:
# png_size_text = (
# f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
# )
# else:
png_size_text = f"{extra_info['width']}x{extra_info['height']}"
pngInfo.add_text(
"parameters",
f"{extra_info['prompt'][0]}"
f"\nNegative prompt: {extra_info['negative_prompt'][0]}"
f"\nSteps: {extra_info['steps']},"
f"Sampler: {extra_info['scheduler']}, "
f"CFG scale: {extra_info['guidance_scale']}, "
f"Seed: {img_seed},"
f"Size: {png_size_text}, "
f"Model: {img_model}, "
f"VAE: {img_vae}, "
f"LoRA: {img_loras}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if cmd_opts.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {cmd_opts.output_img_format} is not "
f"supported yet. Image saved as png instead. "
f"Supported formats: png / jpg"
)
# To be as low-impact as possible to the existing CSV format, we append
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
# importance for each data point. Something to consider.
new_entry = {}
new_entry.update(extra_info)
csv_mode = "a" if os.path.isfile(csv_path) else "w"
with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
if csv_mode == "w":
dictwriter_obj.writeheader()
dictwriter_obj.writerow(new_entry)
csv_obj.close()
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
# For stencil, the input image can be of any size, but we need to ensure
# that it conforms to our model constraints: both width and height should
# be in the range [128, 768] and a multiple of 8. This utility performs the
# transformation on the input image while maintaining the aspect ratio
# before sending it to the stencil pipeline.
def resize_stencil(image: Image.Image, width, height, resampler_type=None):
aspect_ratio = width / height
min_size = min(width, height)
if min_size < 128:
n_size = 128
if width == min_size:
width = n_size
height = n_size / aspect_ratio
else:
height = n_size
width = n_size * aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
min_size = min(width, height)
if min_size > 768:
n_size = 768
if width == min_size:
height = n_size
width = n_size * aspect_ratio
else:
width = n_size
height = n_size / aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
if resampler_type in resamplers:
resampler = resamplers[resampler_type]
else:
resampler = resamplers["Nearest Neighbor"]
new_image = image.resize((n_width, n_height), resampler=resampler)
return new_image, n_width, n_height
def process_sd_init_image(self, sd_init_image, resample_type):
if isinstance(sd_init_image, list):
images = []
for img in sd_init_image:
img, _ = self.process_sd_init_image(img, resample_type)
images.append(img)
is_img2img = True
return images, is_img2img
if isinstance(sd_init_image, str):
if os.path.isfile(sd_init_image):
sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB")
# the recursive call resizes and normalizes, so return its result directly
return self.process_sd_init_image(sd_init_image, resample_type)
else:
image = None
is_img2img = False
elif isinstance(sd_init_image, Image.Image):
image = sd_init_image.convert("RGB")
elif sd_init_image:
image = sd_init_image["image"].convert("RGB")
else:
image = None
is_img2img = False
if image:
resample_type = (
resamplers[resample_type]
if resample_type in resampler_list
# Fallback to Lanczos
else Image.Resampling.LANCZOS
)
image = image.resize((self.width, self.height), resample=resample_type)
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
image_arr = 2 * (image_arr - 0.5)
is_img2img = True
image = image_arr
return image, is_img2img
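A short sketch of the stencil resize helper above, assuming resize_stencil is in scope (it lives alongside save_output_img; the module is imported elsewhere as apps.shark_studio.modules.img_processing). The input size is illustrative; the helper rescales based on the smaller side while keeping the aspect ratio and rounds both dimensions down to multiples of 8:

from PIL import Image

# a blank 1600x900 test image; any PIL image works
img = Image.new("RGB", (1600, 900))
resized, w, h = resize_stencil(img, 1600, 900, resampler_type="Bicubic")
print(w, h)  # should print 768 432 for this input, both multiples of 8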


@@ -0,0 +1,37 @@
import sys
class Logger:
def __init__(self, filename, filter=None):
self.terminal = sys.stdout
self.log = open(filename, "w")
self.filter = filter
def write(self, message):
# tee: lines containing the filter string go to the log file, everything
# else passes through to the real terminal
for x in message.splitlines(keepends=True):
if self.filter is not None and self.filter in x:
self.log.write(x)
else:
self.terminal.write(x)
def flush(self):
self.terminal.flush()
self.log.flush()
def isatty(self):
return False
def logger_test(x):
print("[LOG] This is a test")
print(f"This is another test, without the filter")
return x
def read_sd_logs():
sys.stdout.flush()
with open("shark_tmp/sd.log", "r") as f:
return f.read()
sys.stdout = Logger("shark_tmp/sd.log", filter="[LOG]")
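A minimal sketch of the tee above, assuming the Logger class is in scope; the log path mirrors the hard-coded shark_tmp/sd.log that read_sd_logs expects:

import os
import sys

os.makedirs("shark_tmp", exist_ok=True)
sys.stdout = Logger("shark_tmp/sd.log", filter="[LOG]")
print("[LOG] compiling unet...")   # routed to shark_tmp/sd.log
print("regular console output")    # routed to the real terminal
sys.stdout.flush()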


@@ -0,0 +1,205 @@
from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
load_vmfb_using_mmap,
clean_device_info,
get_iree_target_triple,
)
from apps.shark_studio.web.utils.file_utils import (
get_checkpoints_path,
get_resource_path,
)
from apps.shark_studio.modules.shared_cmd_opts import (
cmd_opts,
)
from iree import runtime as ireert
from pathlib import Path
import gc
import os
class SharkPipelineBase:
# This class is a lightweight base for managing an
# inference API class. It should provide methods for:
# - compiling a set (model map) of torch IR modules
# - preparing weights for an inference job
# - loading weights for an inference job
# - utilities like benchmarks and tests
def __init__(
self,
model_map: dict,
base_model_id: str,
static_kwargs: dict,
device: str,
import_mlir: bool = True,
):
self.model_map = model_map
self.pipe_map = {}
self.static_kwargs = static_kwargs
self.base_model_id = base_model_id
self.triple = get_iree_target_triple(device)
self.device, self.device_id = clean_device_info(device)
self.import_mlir = import_mlir
self.iree_module_dict = {}
self.tmp_dir = get_resource_path(cmd_opts.tmp_dir)
if not os.path.exists(self.tmp_dir):
os.mkdir(self.tmp_dir)
self.tempfiles = {}
self.pipe_vmfb_path = ""
def get_compiled_map(self, pipe_id, submodel="None", init_kwargs=None) -> None:
# First check whether we have precompiled .vmfbs, then populate the map
# with those executables and compile whatever remains. The weights aren't
# static here anymore, so this should run as part of pipeline
# initialization: once you have a pipeline ID unique to your static torch
# IR parameters, and the model map is populated with per-model IDs and
# their static params, call this method to fetch the artifacts for the map.
if init_kwargs is None:
init_kwargs = {}
self.pipe_id = self.safe_name(pipe_id)
self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True)
if submodel == "None":
print("\n[LOG] Gathering any pre-compiled artifacts....")
for key in self.model_map:
self.get_compiled_map(pipe_id, submodel=key)
else:
self.pipe_map[submodel] = {}
self.get_precompiled(self.pipe_id, submodel)
ireec_flags = []
if submodel in self.iree_module_dict:
return
elif "vmfb_path" in self.pipe_map[submodel]:
return
elif submodel not in self.tempfiles:
print(
f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR..."
)
if submodel in self.static_kwargs:
init_kwargs = self.static_kwargs[submodel]
for key in self.static_kwargs["pipe"]:
if key not in init_kwargs:
init_kwargs[key] = self.static_kwargs["pipe"][key]
self.import_torch_ir(submodel, init_kwargs)
self.get_compiled_map(pipe_id, submodel)
else:
ireec_flags = (
self.model_map[submodel]["ireec_flags"]
if "ireec_flags" in self.model_map[submodel]
else []
)
weights_path = self.get_io_params(submodel)
if weights_path:
ireec_flags.append("--iree-opt-const-eval=False")
self.iree_module_dict[submodel] = get_iree_compiled_module(
self.tempfiles[submodel],
device=self.device,
frontend="torch",
mmap=True,
external_weight_file=weights_path,
extra_args=ireec_flags,
write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"),
)
return
def get_io_params(self, submodel):
if "external_weight_file" in self.static_kwargs[submodel]:
# we are using custom weights
weights_path = self.static_kwargs[submodel]["external_weight_file"]
elif "external_weight_path" in self.static_kwargs[submodel]:
# we are using the default weights for the HF model
weights_path = self.static_kwargs[submodel]["external_weight_path"]
else:
# assume the torch IR contains the weights.
weights_path = None
return weights_path
def get_precompiled(self, pipe_id, submodel="None"):
if submodel == "None":
for model in self.model_map:
self.get_precompiled(pipe_id, model)
vmfbs = []
for dirpath, dirnames, filenames in os.walk(self.pipe_vmfb_path):
vmfbs.extend(filenames)
break
for file in vmfbs:
if submodel in file:
self.pipe_map[submodel]["vmfb_path"] = os.path.join(
self.pipe_vmfb_path, file
)
return
def import_torch_ir(self, submodel, kwargs):
torch_ir = self.model_map[submodel]["initializer"](
**self.safe_dict(kwargs), compile_to="torch"
)
if submodel == "clip":
# clip.export_clip_model returns (torch_ir, tokenizer)
torch_ir = torch_ir[0]
self.tempfiles[submodel] = os.path.join(
self.tmp_dir, f"{submodel}.torch.tempfile"
)
with open(self.tempfiles[submodel], "w+") as f:
f.write(torch_ir)
del torch_ir
gc.collect()
return
def load_submodels(self, submodels: list):
for submodel in submodels:
if submodel in self.iree_module_dict:
print(f"\n[LOG] {submodel} is ready for inference.")
continue
if "vmfb_path" in self.pipe_map[submodel]:
weights_path = self.get_io_params(submodel)
# print(
# f"\n[LOG] Loading .vmfb for {submodel} from {self.pipe_map[submodel]['vmfb_path']}"
# )
self.iree_module_dict[submodel] = {}
(
self.iree_module_dict[submodel]["vmfb"],
self.iree_module_dict[submodel]["config"],
self.iree_module_dict[submodel]["temp_file_to_unlink"],
) = load_vmfb_using_mmap(
self.pipe_map[submodel]["vmfb_path"],
self.device,
device_idx=0,
rt_flags=[],
external_weight_file=weights_path,
)
else:
self.get_compiled_map(self.pipe_id, submodel)
return
def unload_submodels(self, submodels: list):
for submodel in submodels:
if submodel in self.iree_module_dict:
del self.iree_module_dict[submodel]
gc.collect()
return
def run(self, submodel, inputs):
if not isinstance(inputs, list):
inputs = [inputs]
inp = [
ireert.asdevicearray(
self.iree_module_dict[submodel]["config"].device, input
)
for input in inputs
]
return self.iree_module_dict[submodel]["vmfb"]["main"](*inp)
def safe_name(self, name):
return name.replace("/", "_").replace("-", "_").replace("\\", "_")
def safe_dict(self, kwargs: dict):
flat_args = {}
for i in kwargs:
if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]:
flat_args[i] = [kwargs[i][j] for j in kwargs[i]]
else:
flat_args[i] = kwargs[i]
return flat_args
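A hypothetical sketch of how a pipeline might drive this base class, assuming SharkPipelineBase is in scope. The model-map entry, the initializer, the pipeline ID, and the device string are placeholders; a real initializer must return torch-MLIR text for its submodel:

import numpy as np

def export_unet_ir(compile_to="torch", **kwargs):
    # placeholder: a real initializer exports and returns torch IR (MLIR text)
    raise NotImplementedError

model_map = {"unet": {"initializer": export_unet_ir, "ireec_flags": []}}
static_kwargs = {"pipe": {}, "unet": {}}

pipe = SharkPipelineBase(
    model_map, "stabilityai/sdxl-turbo", static_kwargs, device="vulkan://0"
)
pipe.get_compiled_map("sdxl_turbo_vulkan")   # compile or reuse cached .vmfbs
pipe.load_submodels(["unet"])                # mmap the compiled modules
out = pipe.run("unet", [np.zeros((1, 4, 64, 64), dtype=np.float16)])
pipe.unload_submodels(["unet"])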


@@ -0,0 +1,376 @@
from typing import List, Optional, Union
from iree import runtime as ireert
import re
import torch
import numpy as np
re_attention = re.compile(
r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs:
text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included.
"""
tokens = []
weights = []
truncated = False
for text in prompt:
texts_and_weights = parse_prompt_attention(text)
text_token = []
text_weight = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(text_token) > max_length:
truncated = True
break
# truncate
if len(text_token) > max_length:
truncated = True
text_token = text_token[:max_length]
text_weight = text_weight[:max_length]
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print(
"Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
)
return tokens, weights
def pad_tokens_and_weights(
tokens,
weights,
max_length,
bos,
eos,
no_boseos_middle=True,
chunk_length=77,
):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = (
max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
)
for i in range(len(tokens)):
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
if no_boseos_middle:
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
else:
w = []
if len(weights[i]) == 0:
w = [1.0] * weights_length
else:
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][
j
* (chunk_length - 2) : min(
len(weights[i]), (j + 1) * (chunk_length - 2)
)
]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
return tokens, weights
def get_unweighted_text_embeddings(
pipe,
text_input,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
):
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
"""
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[
:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = pipe.run("clip", text_input_chunk)[0].to_host()
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
# SHARK: Convert the result to tensor
# text_embeddings = torch.concat(text_embeddings, axis=1)
text_embeddings_np = np.concatenate(np.array(text_embeddings))
text_embeddings = torch.from_numpy(text_embeddings_np)
else:
text_embeddings = pipe.run("clip", text_input)[0]
text_embeddings = torch.from_numpy(text_embeddings.to_host())
return text_embeddings
# This function deals with NoneType values occurring in tokens after padding.
# It replaces None with 49407 (the CLIP end-of-text token), since truncating
# None values causes matrix dimension errors.
def filter_nonetype_tokens(tokens: List[List]):
return [[49407 if token is None else token for token in tokens[0]]]
def get_weighted_text_embeddings(
pipe,
prompt: List[str],
uncond_prompt: List[str] = None,
max_embeddings_multiples: Optional[int] = 8,
no_boseos_middle: Optional[bool] = True,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
):
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(
pipe, prompt, max_length - 2
)
if uncond_prompt is not None:
uncond_tokens, uncond_weights = get_prompts_with_weights(
pipe, uncond_prompt, max_length - 2
)
else:
prompt_tokens = [
token[1:-1]
for token in pipe.tokenizer(
prompt, max_length=max_length, truncation=True
).input_ids
]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens = [
token[1:-1]
for token in pipe.tokenizer(
uncond_prompt, max_length=max_length, truncation=True
).input_ids
]
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
if uncond_prompt is not None:
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
max_embeddings_multiples = min(
max_embeddings_multiples,
(max_length - 1) // (pipe.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
# pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# FIXME: This is a hacky fix caused by tokenizer padding with None values
prompt_tokens = filter_nonetype_tokens(prompt_tokens)
# prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
if uncond_prompt is not None:
uncond_tokens, uncond_weights = pad_tokens_and_weights(
uncond_tokens,
uncond_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# FIXME: This is a hacky fix caused by tokenizer padding with None values
uncond_tokens = filter_nonetype_tokens(uncond_tokens)
# uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device="cpu")
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe,
prompt_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
prompt_weights = torch.tensor(prompt_weights, dtype=torch.float, device="cpu")
if uncond_prompt is not None:
uncond_embeddings = get_unweighted_text_embeddings(
pipe,
uncond_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
uncond_weights = torch.tensor(uncond_weights, dtype=torch.float, device="cpu")
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
if (not skip_parsing) and (not skip_weighting):
previous_mean = (
text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = (
text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
)
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
previous_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
return text_embeddings, uncond_embeddings
return text_embeddings, None
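The pipeline above parses the attention syntax into (text, weight) pairs, tokenizes each piece, runs CLIP over 75-token chunks, then rescales the embeddings by the per-token weights while restoring the original mean. A small illustration of the parsing step, assuming parse_prompt_attention above is in scope:

print(parse_prompt_attention("a (red:1.3) sports car, [blurry]"))
# [['a ', 1.0], ['red', 1.3], [' sports car, ', 1.0], ['blurry', ~0.909]]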


@@ -0,0 +1,118 @@
# from shark_turbine.turbine_models.schedulers import export_scheduler_model
from diffusers import (
LCMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
DDPMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
def get_schedulers(model_id):
# TODO: switch over to turbine and run all on GPU
print(f"\n[LOG] Initializing schedulers from model id: {model_id}")
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
# schedulers["DDPM"] = DDPMScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["DDIM"] = DDIMScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
# model_id, subfolder="scheduler", algorithm_type="dpmsolver"
# )
# schedulers["DPMSolverMultistep++"] = DPMSolverMultistepScheduler.from_pretrained(
# model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
# )
# schedulers["DPMSolverMultistepKarras"] = (
# DPMSolverMultistepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# use_karras_sigmas=True,
# )
# )
# schedulers["DPMSolverMultistepKarras++"] = (
# DPMSolverMultistepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# algorithm_type="dpmsolver++",
# use_karras_sigmas=True,
# )
# )
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["EulerAncestralDiscrete"] = (
EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
)
# schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["DPMSolverSinglestep"] = DPMSolverSinglestepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["KDPM2AncestralDiscrete"] = (
# KDPM2AncestralDiscreteScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# )
# schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
return schedulers
def export_scheduler_model(model):
return "None", "None"
scheduler_model_map = {
# "PNDM": export_scheduler_model("PNDMScheduler"),
# "DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
"EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
"EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"),
# "LCM": export_scheduler_model("LCMScheduler"),
# "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
# "DDPM": export_scheduler_model("DDPMScheduler"),
# "DDIM": export_scheduler_model("DDIMScheduler"),
# "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
# "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
# "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
# "DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
# "KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
# "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
}
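A brief usage sketch of the helper above, assuming get_schedulers is in scope; only the uncommented schedulers (PNDM, EulerDiscrete, EulerAncestralDiscrete) are returned, and the model id is illustrative:

schedulers = get_schedulers("stabilityai/stable-diffusion-2-1-base")
scheduler = schedulers["EulerDiscrete"]
scheduler.set_timesteps(20)
print(scheduler.timesteps)   # the 20 denoising timesteps for this scheduler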


@@ -0,0 +1,66 @@
import numpy as np
import json
from random import (
randint,
seed as seed_random,
getstate as random_getstate,
setstate as random_setstate,
)
# Return the seed unchanged if it is a valid uint32; otherwise (including the
# conventional -1) generate and return a new random seed.
def sanitize_seed(seed: int | str):
seed = int(seed)
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed
# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):
if isinstance(seed_input, str):
try:
seed_input = json.loads(seed_input)
except (ValueError, TypeError):
seed_input = None
if isinstance(seed_input, int):
return [seed_input]
if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
return seed_input
raise TypeError(
"Seed input must be an integer or an array of integers in JSON format"
)
# Generate a set of seeds from an input expression for batch_count batches,
# optionally using that input as the rng seed for any randomly generated seeds.
def batch_seeds(seed_input: str | list | int, batch_count: int, repeatable=False):
# turn the input into a list if possible
seeds = parse_seed_input(seed_input)
# slice or pad the list to be of batch_count length
seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))
if repeatable:
if all(seed < 0 for seed in seeds):
seeds[0] = sanitize_seed(seeds[0])
# set seed for the rng based on what we have so far
saved_random_state = random_getstate()
seed_random(str([n for n in seeds if n > -1]))
# generate any seeds that are unspecified
seeds = [sanitize_seed(seed) for seed in seeds]
if repeatable:
# reset the rng back to normal
random_setstate(saved_random_state)
return seeds
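A quick illustration of the helpers above, assuming they are in scope; -1 entries are replaced with random uint32 seeds, and repeatable=True derives those replacements deterministically from the explicit seeds:

print(sanitize_seed(-1))              # some random uint32 seed
print(parse_seed_input("[7, 11]"))    # [7, 11]
print(batch_seeds([1234], 3, repeatable=True))
# [1234, x, y] with the same x and y on every run, because repeatable=True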


@@ -0,0 +1,793 @@
import argparse
import os
from pathlib import Path
from apps.shark_studio.modules.img_processing import resampler_list
def path_expand(s):
return Path(s).expanduser().resolve()
def is_valid_file(arg):
if not os.path.exists(arg):
return None
else:
return arg
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
# Stable Diffusion Params
##############################################################################
p.add_argument(
"-a",
"--app",
default="txt2img",
help="Which app to use, one of: txt2img, img2img, outpaint, inpaint.",
)
p.add_argument(
"-p",
"--prompt",
nargs="+",
default=[
"A hi-res photo of a red street racer drifting around a curve on a mountain, high altitude, at night, tokyo in the background, 8k"
],
help="Text of which images to be generated.",
)
p.add_argument(
"--negative_prompt",
nargs="+",
default=[
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
"blurry, ugly, blur, oversaturated, cropped"
],
help="Text you don't want to see in the generated image.",
)
p.add_argument(
"--sd_init_image",
type=str,
help="Path to the image input for img2img/inpainting.",
)
p.add_argument(
"--steps",
type=int,
default=2,
help="The number of steps to do the sampling.",
)
p.add_argument(
"--seed",
type=str,
default=-1,
help="The seed or list of seeds to use. -1 for a random one.",
)
p.add_argument(
"--batch_size",
type=int,
default=1,
choices=range(1, 4),
help="The number of inferences to be made in a single `batch_count`.",
)
p.add_argument(
"--height",
type=int,
default=512,
choices=range(128, 1025, 8),
help="The height of the output image.",
)
p.add_argument(
"--width",
type=int,
default=512,
choices=range(128, 1025, 8),
help="The width of the output image.",
)
p.add_argument(
"--guidance_scale",
type=float,
default=0,
help="The value to be used for guidance scaling.",
)
p.add_argument(
"--noise_level",
type=int,
default=20,
help="The value to be used for noise level of upscaler.",
)
p.add_argument(
"--max_length",
type=int,
default=64,
help="Max length of the tokenizer output, options are 64 and 77.",
)
p.add_argument(
"--max_embeddings_multiples",
type=int,
default=5,
help="The max multiple length of prompt embeddings compared to the max "
"output length of text encoder.",
)
p.add_argument(
"--strength",
type=float,
default=0.8,
help="The strength of change applied on the given input image for " "img2img.",
)
p.add_argument(
"--use_hiresfix",
type=bool,
default=False,
help="Use Hires Fix to do higher resolution images, while trying to "
"avoid the issues that come with it. This is accomplished by first "
"generating an image using txt2img, then running it through img2img.",
)
p.add_argument(
"--hiresfix_height",
type=int,
default=768,
choices=range(128, 769, 8),
help="The height of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_width",
type=int,
default=768,
choices=range(128, 769, 8),
help="The width of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_strength",
type=float,
default=0.6,
help="The denoising strength to apply for the Hires Fix.",
)
p.add_argument(
"--resample_type",
type=str,
default="Nearest Neighbor",
choices=resampler_list,
help="The resample type to use when resizing an image before being run "
"through stable diffusion.",
)
##############################################################################
# Stable Diffusion Training Params
##############################################################################
p.add_argument(
"--lora_save_dir",
type=str,
default="models/lora/",
help="Directory to save the lora fine tuned model.",
)
p.add_argument(
"--training_images_dir",
type=str,
default="models/lora/training_images/",
help="Directory containing images that are an example of the prompt.",
)
p.add_argument(
"--training_steps",
type=int,
default=2000,
help="The number of steps to train.",
)
##############################################################################
# Inpainting and Outpainting Params
##############################################################################
p.add_argument(
"--mask_path",
type=str,
help="Path to the mask image input for inpainting.",
)
p.add_argument(
"--inpaint_full_res",
default=False,
action=argparse.BooleanOptionalAction,
help="If inpaint only masked area or whole picture.",
)
p.add_argument(
"--inpaint_full_res_padding",
type=int,
default=32,
choices=range(0, 257, 4),
help="Number of pixels for only masked padding.",
)
p.add_argument(
"--pixels",
type=int,
default=128,
choices=range(8, 257, 8),
help="Number of expended pixels for one direction for outpainting.",
)
p.add_argument(
"--mask_blur",
type=int,
default=8,
choices=range(0, 65),
help="Number of blur pixels for outpainting.",
)
p.add_argument(
"--left",
default=False,
action=argparse.BooleanOptionalAction,
help="If extend left for outpainting.",
)
p.add_argument(
"--right",
default=False,
action=argparse.BooleanOptionalAction,
help="If extend right for outpainting.",
)
p.add_argument(
"--up",
"--top",
default=False,
action=argparse.BooleanOptionalAction,
help="If extend top for outpainting.",
)
p.add_argument(
"--down",
"--bottom",
default=False,
action=argparse.BooleanOptionalAction,
help="If extend bottom for outpainting.",
)
p.add_argument(
"--noise_q",
type=float,
default=1.0,
help="Fall-off exponent for outpainting (lower=higher detail) "
"(min=0.0, max=4.0).",
)
p.add_argument(
"--color_variation",
type=float,
default=0.05,
help="Color variation for outpainting (min=0.0, max=1.0).",
)
##############################################################################
# Model Config and Usage Params
##############################################################################
p.add_argument("--device", type=str, default="vulkan", help="Device to run the model.")
p.add_argument(
"--precision", type=str, default="fp16", help="Precision to run the model."
)
p.add_argument(
"--import_mlir",
default=True,
action=argparse.BooleanOptionalAction,
help="Imports the model from torch module to shark_module otherwise "
"downloads the model from shark_tank.",
)
p.add_argument(
"--use_tuned",
default=False,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available.",
)
p.add_argument(
"--use_base_vae",
default=False,
action=argparse.BooleanOptionalAction,
help="Do conversion from the VAE output to pixel space on cpu.",
)
p.add_argument(
"--scheduler",
type=str,
default="DDIM",
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
"DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
"DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
"DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, "
"HeunDiscrete].",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="Specify the format in which output image is save. "
"Supported options: jpg / png.",
)
p.add_argument(
"--output_dir",
type=str,
default=os.path.join(os.getcwd(), "generated_imgs"),
help="Directory path to save the output images and json.",
)
p.add_argument(
"--batch_count",
type=int,
default=1,
help="Number of batches to be generated with random seeds in " "single execution.",
)
p.add_argument(
"--repeatable_seeds",
default=False,
action=argparse.BooleanOptionalAction,
help="The seed of the first batch will be used as the rng seed to "
"generate the subsequent seeds for subsequent batches in that run.",
)
p.add_argument(
"--custom_weights",
type=str,
default="",
help="Path to a .safetensors or .ckpt file for SD pipeline weights.",
)
p.add_argument(
"--custom_vae",
type=str,
default="",
help="HuggingFace repo-id or path to SD model's checkpoint whose VAE "
"needs to be plugged in.",
)
p.add_argument(
"--base_model_id",
type=str,
default="stabilityai/stable-diffusion-2-1-base",
help="The repo-id of hugging face.",
)
p.add_argument(
"--low_cpu_mem_usage",
default=False,
action=argparse.BooleanOptionalAction,
help="Use the accelerate package to reduce cpu memory consumption.",
)
p.add_argument(
"--attention_slicing",
type=str,
default="none",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', "
"or an integer).",
)
p.add_argument(
"--use_stencil",
choices=["canny", "openpose", "scribble", "zoedepth"],
help="Enable the stencil feature.",
)
p.add_argument(
"--control_mode",
choices=["Prompt", "Balanced", "Controlnet"],
default="Balanced",
help="How Controlnet injection should be prioritized.",
)
p.add_argument(
"--use_lora",
type=str,
default="",
help="Use standalone LoRA weight using a HF ID or a checkpoint " "file (~3 MB).",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="Runs the quantized version of stable diffusion model. "
"This is currently in experimental phase. "
"Currently, only runs the stable-diffusion-2-1-base model in "
"int8 quantization.",
)
p.add_argument(
"--lowvram",
default=False,
action=argparse.BooleanOptionalAction,
help="Load and unload models for low VRAM.",
)
p.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication tokens for models like Llama2.",
)
p.add_argument(
"--external_weights",
type=str,
default=None,
help="What type of externalized weights to use. Currently options are 'safetensors' and defaults to inlined weights.",
)
p.add_argument(
"--device_allocator_heap_key",
type=str,
default="",
help="Specify heap key for device caching allocator."
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree_vulkan_target_triple",
type=str,
default="",
help="Specify target triple for vulkan.",
)
p.add_argument(
"--iree_metal_target_platform",
type=str,
default="",
help="Specify target triple for metal.",
)
##############################################################################
# Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="Use the default scheduler precompiled into the model if available.",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. "
"If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. " "Use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help="Dispatches to return benchmark data on. "
'Use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help="Directory where you want to store dispatch data "
'generated with "--dispatch_benchmarks".',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for inserting debug frames between iterations " "for use with rgp.",
)
p.add_argument(
"--hide_steps",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="Flag setting warmup count for CLIP and VAE [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag to clear all mlir and vmfb from common locations. "
"Recompiling will take several minutes.",
)
p.add_argument(
"--save_metadata_to_json",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for whether or not to save a generation information "
"json file with the image.",
)
p.add_argument(
"--write_metadata_to_png",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for whether or not to save generation information in "
"PNG chunk text to generated images.",
)
p.add_argument(
"--import_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="If import_mlir is True, saves mlir via the debug option "
"in shark importer. Does nothing if import_mlir is false (the default).",
)
p.add_argument(
"--compile_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag to toggle debug assert/verify flags for imported IR in the"
"iree-compiler. Default to false.",
)
p.add_argument(
"--iree_constant_folding",
default=True,
action=argparse.BooleanOptionalAction,
help="Controls constant folding in iree-compile for all SD models.",
)
p.add_argument(
"--data_tiling",
default=False,
action=argparse.BooleanOptionalAction,
help="Controls data tiling in iree-compile for all SD models.",
)
p.add_argument(
"--quantization",
type=str,
default="None",
help="Quantization to be used for api-exposed model.",
)
##############################################################################
# Web UI flags
##############################################################################
p.add_argument(
"--defaults",
default="sdxl-turbo.json",
type=str,
help="Path to the default API request .json file. Works for CLI and webui.",
)
p.add_argument(
"--webui",
default=True,
action=argparse.BooleanOptionalAction,
help="controls whether the webui is launched.",
)
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for removing the progress bar animation during " "image generation.",
)
p.add_argument(
"--tmp_dir",
type=str,
default=os.path.join(os.getcwd(), "shark_tmp"),
help="Path to tmp directory",
)
p.add_argument(
"--config_dir",
type=str,
default=os.path.join(os.getcwd(), "configs"),
help="Path to config directory",
)
p.add_argument(
"--model_dir",
type=str,
default=os.path.join(os.getcwd(), "models"),
help="Path to directory where all .ckpts are stored in order to populate "
"them in the web UI.",
)
# TODO: replace API flag when these can be run together
p.add_argument(
"--ui",
type=str,
default="app" if os.name == "nt" else "web",
help="One of: [api, app, web].",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for generating a public URL.",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="Flag for setting server port.",
)
p.add_argument(
"--api",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for enabling rest API.",
)
p.add_argument(
"--api_accept_origin",
action="append",
type=str,
help="An origin to be accepted by the REST api for Cross Origin"
"Resource Sharing (CORS). Use multiple times for multiple origins, "
'or use --api_accept_origin="*" to accept all origins. If no origins '
"are set no CORS headers will be returned by the api. Use, for "
"instance, if you need to access the REST api from Javascript running "
"in a web browser.",
)
p.add_argument(
"--debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for enabling debugging log in WebUI.",
)
p.add_argument(
"--output_gallery",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for removing the output gallery tab, and avoid exposing "
"images under --output_dir in the UI.",
)
p.add_argument(
"--configs_path",
default=None,
type=str,
help="Path to .json config directory.",
)
p.add_argument(
"--output_gallery_followlinks",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for whether the output gallery tab in the UI should "
"follow symlinks when listing subdirectories under --output_dir.",
)
p.add_argument(
"--api_log",
default=False,
action=argparse.BooleanOptionalAction,
help="Enables Compatibility API logging.",
)
##############################################################################
# SD model auto-annotation flags
##############################################################################
p.add_argument(
"--annotation_output",
type=path_expand,
default="./",
help="Directory to save the annotated mlir file.",
)
p.add_argument(
"--annotation_model",
type=str,
default="unet",
help="Options are unet and vae.",
)
p.add_argument(
"--save_annotation",
default=False,
action=argparse.BooleanOptionalAction,
help="Save annotated mlir file.",
)
##############################################################################
# SD model auto-tuner flags
##############################################################################
p.add_argument(
"--tuned_config_dir",
type=path_expand,
default="./",
help="Directory to save the tuned config file.",
)
p.add_argument(
"--num_iters",
type=int,
default=400,
help="Number of iterations for tuning.",
)
p.add_argument(
"--search_op",
type=str,
default="all",
help="Op to be optimized, options are matmul, bmm, conv and all.",
)
##############################################################################
# DocuChat Flags
##############################################################################
p.add_argument(
"--run_docuchat_web",
default=False,
action=argparse.BooleanOptionalAction,
help="Specifies whether the docuchat's web version is running or not.",
)
##############################################################################
# rocm Flags
##############################################################################
p.add_argument(
"--iree_rocm_target_chip",
type=str,
default="",
help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Use `hipinfo` "
"or `iree-run-module --dump_devices=rocm` or `hipinfo` to get desired arch name",
)
cmd_opts, unknown = p.parse_known_args()
if cmd_opts.import_debug:
os.environ["IREE_SAVE_TEMPS"] = os.path.join(
os.getcwd(), cmd_opts.base_model_id.replace("/", "_")
)
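The parsed namespace is shared across the app as cmd_opts (see the imports in the modules above), and unknown flags are tolerated via parse_known_args. A minimal consumption sketch; the printed values reflect the defaults unless overridden on the command line:

from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

print(cmd_opts.device, cmd_opts.precision, cmd_opts.steps)  # e.g. "vulkan fp16 2"
if cmd_opts.api:
    print(f"REST API enabled on port {cmd_opts.server_port}")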


@@ -0,0 +1,106 @@
import time
import argparse
class TimerSubcategory:
def __init__(self, timer, category):
self.timer = timer
self.category = category
self.start = None
self.original_base_category = timer.base_category
def __enter__(self):
self.start = time.time()
self.timer.base_category = self.original_base_category + self.category + "/"
self.timer.subcategory_level += 1
if self.timer.print_log:
print(f"{' ' * self.timer.subcategory_level}{self.category}:")
def __exit__(self, exc_type, exc_val, exc_tb):
elapsed_for_subcategory = time.time() - self.start
self.timer.base_category = self.original_base_category
self.timer.add_time_to_record(
self.original_base_category + self.category,
elapsed_for_subcategory,
)
self.timer.subcategory_level -= 1
self.timer.record(self.category, disable_log=True)
class Timer:
def __init__(self, print_log=False):
self.start = time.time()
self.records = {}
self.total = 0
self.base_category = ""
self.print_log = print_log
self.subcategory_level = 0
def elapsed(self):
end = time.time()
res = end - self.start
self.start = end
return res
def add_time_to_record(self, category, amount):
if category not in self.records:
self.records[category] = 0
self.records[category] += amount
def record(self, category, extra_time=0, disable_log=False):
e = self.elapsed()
self.add_time_to_record(self.base_category + category, e + extra_time)
self.total += e + extra_time
if self.print_log and not disable_log:
print(
f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s"
)
def subcategory(self, name):
self.elapsed()
subcat = TimerSubcategory(self, name)
return subcat
def summary(self):
res = f"{self.total:.1f}s"
additions = [
(category, time_taken)
for category, time_taken in self.records.items()
if time_taken >= 0.1 and "/" not in category
]
if not additions:
return res
res += " ("
res += ", ".join(
[f"{category}: {time_taken:.1f}s" for category, time_taken in additions]
)
res += ")"
return res
def dump(self):
return {"total": self.total, "records": self.records}
def reset(self):
# keep the configured print_log setting when resetting
self.__init__(self.print_log)
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument(
"--log-startup",
action="store_true",
help="print a detailed log of what's happening at startup",
)
args = parser.parse_known_args()[0]
startup_timer = Timer(print_log=args.log_startup)
startup_record = None
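A minimal usage sketch of the Timer above, assuming the class is in scope; subcategories opened with the context manager roll up into summary():

import time

timer = Timer(print_log=True)
time.sleep(0.15)
timer.record("load config")
with timer.subcategory("compile"):
    time.sleep(0.25)
    timer.record("unet")
print(timer.summary())   # e.g. "0.4s (load config: 0.2s, compile: 0.3s)"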


@@ -0,0 +1,48 @@
# -*- mode: python ; coding: utf-8 -*-
from apps.shark_studio.studio_imports import pathex, datas, hiddenimports
binaries = []
block_cipher = None
a = Analysis(
['web/index.py'],
pathex=pathex,
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
module_collection_mode={
'gradio': 'py', # Collect gradio package as source .py files
},
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='nodai_shark_studio',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=False,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)


@@ -0,0 +1,45 @@
# -*- mode: python ; coding: utf-8 -*-
from apps.shark_studio.studio_imports_apionly import pathex, datas, hiddenimports
binaries = []
block_cipher = None
a = Analysis(
['web/index.py'],
pathex=pathex,
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_sd3_server',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=False,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)


@@ -0,0 +1,62 @@
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules
import sys
sys.setrecursionlimit(sys.getrecursionlimit() * 5)
# python path for pyinstaller
pathex = [
".",
]
# datafiles for pyinstaller
datas = []
datas += copy_metadata("torch")
datas += copy_metadata("tokenizers")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
datas += copy_metadata("packaging")
datas += copy_metadata("filelock")
datas += copy_metadata("numpy")
datas += copy_metadata("importlib_metadata")
datas += copy_metadata("safetensors")
datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += copy_metadata("gradio")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("gradio")
datas += collect_data_files("gradio_client")
datas += collect_data_files("iree", include_py_files=True)
datas += collect_data_files("shark-turbine", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += [
("web/ui/css/*", "ui/css"),
("web/ui/js/*", "ui/js"),
("web/ui/logos/*", "logos"),
]
# hidden imports for pyinstaller
hiddenimports = ["apps", "shark-turbine"]
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("diffusers") if "tests" not in x]
blacklist = ["tests", "convert"]
hiddenimports += [
x
for x in collect_submodules("transformers")
if not any(kw in x for kw in blacklist)
]
hiddenimports += [x for x in collect_submodules("iree") if "test" not in x]
hiddenimports += ["iree._runtime"]


@@ -0,0 +1,46 @@
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules
import sys
sys.setrecursionlimit(sys.getrecursionlimit() * 5)
# python path for pyinstaller
pathex = [
".",
]
# datafiles for pyinstaller
datas = []
datas += copy_metadata("torch")
datas += copy_metadata("tokenizers")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
datas += copy_metadata("packaging")
datas += copy_metadata("filelock")
datas += copy_metadata("numpy")
datas += copy_metadata("importlib_metadata")
datas += copy_metadata("safetensors")
datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += copy_metadata("gradio")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("iree", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
# hidden imports for pyinstaller
hiddenimports = ["apps", "shark-turbine"]
hiddenimports += [x for x in collect_submodules("diffusers") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "test" not in x]
hiddenimports += ["iree._runtime"]


@@ -0,0 +1,58 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import logging
import unittest
import json
import gc
from apps.shark_studio.api.llm import LanguageModel, llm_chat_api
from apps.shark_studio.api.sd import shark_sd_fn_dict_input, view_json_file
from apps.shark_studio.web.utils.file_utils import get_resource_path
# class SDAPITest(unittest.TestCase):
# def testSDSimple(self):
# from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
# import apps.shark_studio.web.utils.globals as global_obj
# global_obj._init()
# sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json"))
# sd_kwargs = json.loads(sd_json)
# for arg in vars(cmd_opts):
# if arg in sd_kwargs:
# sd_kwargs[arg] = getattr(cmd_opts, arg)
# for i in shark_sd_fn_dict_input(sd_kwargs):
# print(i)
class LLMAPITest(unittest.TestCase):
def test01_LLMSmall(self):
lm = LanguageModel(
"TinyPixel/small-llama2",
hf_auth_token=None,
device="cpu",
precision="fp32",
quantization="None",
streaming_llm=True,
)
count = 0
label = "Turkishoure Turkish"
for msg, _ in lm.chat("hi, what are you?"):
# skip first token output
if count == 0:
count += 1
continue
assert (
msg.strip(" ") == label
), f"LLM API failed to return correct response, expected '{label}', received {msg}"
break
del lm
gc.collect()
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()

View File

@@ -0,0 +1,41 @@
import torch
from diffusers import (
UNet2DConditionModel,
)
from torch.fx.experimental.proxy_tensor import make_fx
class UnetModel(torch.nn.Module):
def __init__(self, hf_model_name):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
hf_model_name,
subfolder="unet",
)
def forward(self, sample, timestep, encoder_hidden_states, guidance_scale):
samples = torch.cat([sample] * 2)
unet_out = self.unet.forward(
samples, timestep, encoder_hidden_states, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
if __name__ == "__main__":
hf_model_name = "CompVis/stable-diffusion-v1-4"
unet = UnetModel(hf_model_name)
inputs = (torch.randn(1, 4, 64, 64), 1, torch.randn(2, 77, 768), 7.5)
fx_g = make_fx(
unet,
decomposition_table={},
tracing_mode="symbolic",
_allow_non_fake_inputs=True,
_allow_fake_constant=False,
)(*inputs)
print(fx_g)
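
Beyond printing the GraphModule itself, the standard torch.fx inspection helpers apply to the result of make_fx; a small optional addition to the __main__ block above:

# Optional inspection of the traced UNet graph (standard torch.fx APIs).
fx_g.graph.print_tabular()  # one row per node: opcode, name, target, args (needs `tabulate`)
print(fx_g.code)            # Python source of the generated forward()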

Binary file not shown (image, 347 KiB)

View File

@@ -0,0 +1,45 @@
import requests
from PIL import Image
import base64
from io import BytesIO
import json
def llm_chat_test(verbose=False):
# Define values here
prompt = "What is the significance of the number 42?"
url = "http://127.0.0.1:8080/v1/chat/completions"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"model": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
"messages": [
{
"role": "",
"content": prompt,
}
],
"device": "vulkan://0",
"max_tokens": 4096,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
res_dict = json.loads(res.content.decode("utf-8"))
print(f"[chat] response from server was : {res.status_code} {res.reason}")
if verbose or res.status_code != 200:
print(f"\n{res_dict['choices'][0]['message']['content']}\n")
if __name__ == "__main__":
# "Exercises the chatbot REST API of Shark. Make sure "
# "Shark is running in API mode on 127.0.0.1:8080 before running"
# "this script."
llm_chat_test(verbose=True)

View File

@@ -0,0 +1,20 @@
from apps.shark_studio.modules.ckpt_processing import save_irpa
import argparse
import safetensors
parser = argparse.ArgumentParser()
parser.add_argument(
"--input",
type=str,
default="",
help="input safetensors/irpa",
)
parser.add_argument(
"--prefix",
type=str,
default="",
help="prefix to add to all the keys in the irpa",
)
args = parser.parse_args()
output_file = save_irpa(args.input, args.prefix)
print("saved irpa to", output_file, "with prefix", args.prefix)

View File

@@ -0,0 +1,220 @@
import base64
import io
import os
import time
import datetime
import uvicorn
import ipaddress
import requests
import threading
import collections
import gradio as gr
from PIL import Image, PngImagePlugin
from threading import Lock
from io import BytesIO
from fastapi import APIRouter, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
def decode_base64_to_image(encoding):
if encoding.startswith("http://") or encoding.startswith("https://"):
headers = {}
response = requests.get(encoding, timeout=30, headers=headers)
try:
image = Image.open(BytesIO(response.content))
return image
except Exception as e:
raise HTTPException(status_code=500, detail="Invalid image url") from e
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as e:
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
image.save(
output_bytes,
format="PNG",
pnginfo=(metadata if use_metadata else None),
)
bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data)
# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
class FIFOLock(object):
def __init__(self):
self._lock = threading.Lock()
self._inner_lock = threading.Lock()
self._pending_threads = collections.deque()
def acquire(self, blocking=True):
with self._inner_lock:
lock_acquired = self._lock.acquire(False)
if lock_acquired:
return True
elif not blocking:
return False
release_event = threading.Event()
self._pending_threads.append(release_event)
release_event.wait()
return self._lock.acquire()
def release(self):
with self._inner_lock:
if self._pending_threads:
release_event = self._pending_threads.popleft()
release_event.set()
self._lock.release()
__enter__ = acquire
def __exit__(self, t, v, tb):
self.release()
def api_middleware(app: FastAPI):
rich_available = False
try:
if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
import anyio # importing just so it can be placed on silent list
import starlette # importing just so it can be placed on silent list
from rich.console import Console
console = Console()
rich_available = True
except Exception:
pass
@app.middleware("http")
async def log_and_time(req: Request, call_next):
ts = time.time()
res: Response = await call_next(req)
duration = str(round(time.time() - ts, 4))
res.headers["X-Process-Time"] = duration
endpoint = req.scope.get("path", "err")
if cmd_opts.api_log and endpoint.startswith("/sdapi"):
print(
"API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}".format(
t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
code=res.status_code,
ver=req.scope.get("http_version", "0.0"),
cli=req.scope.get("client", ("0:0.0.0", 0))[0],
prot=req.scope.get("scheme", "err"),
method=req.scope.get("method", "err"),
endpoint=endpoint,
duration=duration,
)
)
return res
def handle_exception(request: Request, e: Exception):
err = {
"error": type(e).__name__,
"detail": vars(e).get("detail", ""),
"body": vars(e).get("body", ""),
"errors": str(e),
}
if not isinstance(
e, HTTPException
): # do not print backtrace on known httpexceptions
message = f"API error: {request.method}: {request.url} {err}"
if rich_available:
print(message)
console.print_exception(
show_locals=True,
max_frames=2,
extra_lines=1,
suppress=[anyio, starlette],
word_wrap=False,
width=min([console.width, 200]),
)
else:
print(message)
raise (e)
return JSONResponse(
status_code=vars(e).get("status_code", 500),
content=jsonable_encoder(err),
)
@app.middleware("http")
async def exception_handling(request: Request, call_next):
try:
return await call_next(request)
except Exception as e:
return handle_exception(request, e)
@app.exception_handler(Exception)
async def fastapi_exception_handler(request: Request, e: Exception):
return handle_exception(request, e)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, e: HTTPException):
return handle_exception(request, e)
class ApiCompat:
def __init__(self, app: FastAPI, queue_lock: Lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
api_middleware(self.app)
# self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["POST"])
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
def add_api_route(self, path: str, endpoint, **kwargs):
return self.app.add_api_route(path, endpoint, **kwargs)
def launch(self, server_name, port, root_path):
self.app.include_router(self.router)
uvicorn.run(
self.app,
host=server_name,
port=port,
root_path=root_path,
)
# def kill_studio(self):
# restart.stop_program()
# def restart_studio(self):
# if restart.is_restartable():
# restart.restart_program()
# return Response(status_code=501)
# def preprocess(self, args: dict):
# try:
# studio.state.begin(job="preprocess")
# preprocess(**args)
# studio.state.end()
# return models.PreprocessResponse(info="preprocess complete")
# except:
# studio.state.end()
# def stop_studio(request):
# studio.state.server_command = "stop"
# return Response("Stopping.")

View File

@@ -0,0 +1,115 @@
import base64
from fastapi import FastAPI
from io import BytesIO
from PIL import Image
from pydantic import BaseModel, Field
from fastapi.exceptions import HTTPException
from apps.shark_studio.api.sd import shark_sd_fn
sdapi = FastAPI()
class GenerationInputData(BaseModel):
prompt: list = [""]
negative_prompt: list = [""]
hf_model_id: str | None = None
height: int = Field(default=512, ge=128, le=1024, multiple_of=8)
width: int = Field(default=512, ge=128, le=1024, multiple_of=8)
sampler_name: str = "EulerDiscrete"
cfg_scale: float = Field(default=7.5, ge=1)
steps: int = Field(default=20, ge=1, le=100)
seed: int = Field(default=-1)
n_iter: int = Field(default=1)
config: dict = None
class GenerationResponseData(BaseModel):
images: list[str] = Field(description="Generated images, Base64 encoded")
properties: dict = {}
info: str
def encode_pil_to_base64(images: list[Image.Image]):
encoded_imgs = []
for image in images:
with BytesIO() as output_bytes:
image.save(output_bytes, format="PNG")
bytes_data = output_bytes.getvalue()
encoded_imgs.append(base64.b64encode(bytes_data))
return encoded_imgs
def decode_base64_to_image(encoding: str):
if encoding.startswith("data:image/"):
encoding = encoding.split(";", 1)[1].split(",", 1)[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
print(err)
raise HTTPException(status_code=400, detail="Invalid encoded image")
@sdapi.post(
"/v1/txt2img",
summary="Does text to image generation",
response_model=GenerationResponseData,
)
def txt2img_api(InputData: GenerationInputData):
model_id = (
InputData.hf_model_id or "stabilityai/stable-diffusion-3-medium-diffusers"
)
scheduler = "FlowEulerDiscrete"
print(
f"Prompt: {InputData.prompt}, "
f"Negative Prompt: {InputData.negative_prompt}, "
f"Seed: {InputData.seed},"
f"Model: {model_id}, "
f"Scheduler: {scheduler}. "
)
if not getattr(InputData, "config"):
InputData.config = {
"precision": "fp16",
"device": "rocm",
"target_triple": "gfx1150",
}
res = shark_sd_fn(
InputData.prompt,
InputData.negative_prompt,
None,
InputData.height,
InputData.width,
InputData.steps,
None,
InputData.cfg_scale,
InputData.seed,
custom_vae=None,
batch_count=InputData.n_iter,
batch_size=1,
scheduler=scheduler,
base_model_id=model_id,
custom_weights=None,
precision=InputData.config["precision"],
device=InputData.config["device"],
target_triple=InputData.config["target_triple"],
output_type="pil",
ondemand=False,
compiled_pipeline=False,
resample_type=None,
controlnets=[],
embeddings=[],
)
# Since we're not streaming we just want the last generator result
for items_so_far in res:
items = items_so_far
return {
"images": encode_pil_to_base64(items[0]),
"parameters": {},
"info": items[1],
}
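
Given the GenerationInputData schema above, and that the launcher further down mounts this sub-app under /sdapi/, a client request could look like the sketch below; the host, port, and config values are assumptions rather than something this diff pins down:

# Hypothetical txt2img client; URL, port, and config values are assumed.
import base64
import requests

payload = {
    "prompt": ["a photo of a shark, studio lighting"],
    "negative_prompt": ["blurry"],
    "height": 512,
    "width": 512,
    "steps": 20,
    "seed": -1,
    "config": {"precision": "fp16", "device": "rocm", "target_triple": "gfx1150"},
}
resp = requests.post("http://127.0.0.1:8080/sdapi/v1/txt2img", json=payload, timeout=600)
resp.raise_for_status()
with open("txt2img_out.png", "wb") as f:
    f.write(base64.b64decode(resp.json()["images"][0]))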

View File

@@ -0,0 +1,226 @@
from multiprocessing import Process, freeze_support
freeze_support()
from PIL import Image
import os
import time
import sys
import logging
import apps.shark_studio.api.initializers as initialize
from apps.shark_studio.modules import timer
startup_timer = timer.startup_timer
startup_timer.record("launcher")
initialize.imports()
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
# import before IREE to avoid MLIR library issues
import torch_mlir
def create_api(app):
from apps.shark_studio.web.api.compat import ApiCompat, FIFOLock
queue_lock = FIFOLock()
api = ApiCompat(app, queue_lock)
return api
def api_only():
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.web.api.sd import sdapi
from fastapi import FastAPI
initialize.initialize()
app = FastAPI()
initialize.setup_middleware(app)
app.mount("/sdapi/", sdapi)
api = create_api(app)
# from modules import script_callbacks
# script_callbacks.before_ui_callback()
# script_callbacks.app_started_callback(None, app)
print(f"Startup time: {startup_timer.summary()}.")
api.launch(
server_name="0.0.0.0",
port=cmd_opts.server_port,
root_path="",
)
def launch_webui(address):
from tkinter import Tk
import webview
import gradio as gr
window = Tk()
# get the screen width and height of the display and size the window
# more reasonably, as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False, storage_path=os.getcwd())
def webui():
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.web.ui.utils import (
amdicon_loc,
amdlogo_loc,
)
launch_api = cmd_opts.api
initialize.initialize()
# from ui.chat import chat_element
from ui.sd import sd_element
from ui.outputgallery import outputgallery_element
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
# if args.api or "api" in args.ui.split(","):
# from apps.shark_studio.api.llm import (
# chat,
# )
# from apps.shark_studio.web.api import sdapi
#
# from fastapi import FastAPI, APIRouter
# from fastapi.middleware.cors import CORSMiddleware
# import uvicorn
#
# # init global sd pipeline and config
# global_obj._init()
#
# api = FastAPI()
# api.mount("/sdapi/", sdapi)
#
# # chat APIs needed for compatibility with multiple extensions using OpenAI API
# api.add_api_route(
# "/v1/chat/completions", llm_chat_api, methods=["post"]
# )
# api.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
# api.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
# api.add_api_route("/completions", llm_chat_api, methods=["post"])
# api.add_api_route(
# "/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
# )
# api.include_router(APIRouter())
#
# # deal with CORS requests if CORS accept origins are set
# if args.api_accept_origin:
# print(
# f"API Configured for CORS. Accepting origins: { args.api_accept_origin }"
# )
# api.add_middleware(
# CORSMiddleware,
# allow_origins=args.api_accept_origin,
# allow_methods=["GET", "POST"],
# allow_headers=["*"],
# )
# else:
# print("API not configured for CORS")
#
# uvicorn.run(api, host="0.0.0.0", port=args.server_port)
# sys.exit(0)
import gradio as gr
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
gradio_workarounds = resource_path("ui/js/sd_gradio_workarounds.js")
# from apps.shark_studio.web.ui import load_ui_from_script
def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_outputgallery_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
with gr.Blocks(
css=dark_theme,
js=gradio_workarounds,
analytics_enabled=False,
title="Shark Studio 2.0",
) as studio_web:
amd_logo = Image.open(amdlogo_loc)
gr.Image(
value=amd_logo,
show_label=False,
interactive=False,
elem_id="tab_bar_logo",
show_download_button=False,
)
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
# have a unique id that doesn't clash with any of the other tabs,
# and that the order in the code here is the order they should
# appear in the ui, as the id value doesn't determine the order.
# Where possible, avoid changing the id of any tab that is the
# destination of one of the 'send to' buttons. If you do have to change
# that id, make sure you update the relevant register_button_click calls
# further down with the new id.
with gr.TabItem(label="Stable Diffusion", id=0):
sd_element.render()
with gr.TabItem(label="Output Gallery", id=1):
outputgallery_element.render()
# with gr.TabItem(label="Chat Bot", id=2):
# chat_element.render()
studio_web.queue()
# if args.ui == "app":
# t = Process(
# target=launch_app, args=[f"http://localhost:{args.server_port}"]
# )
# t.start()
studio_web.launch(
share=cmd_opts.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=cmd_opts.server_port,
favicon_path=amdicon_loc,
)
if __name__ == "__main__":
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
api_only()
# if cmd_opts.webui == False:
# api_only()
# else:
# webui()

View File

@@ -0,0 +1,239 @@
import gradio as gr
import time
import os
from pathlib import Path
from datetime import datetime as dt
import json
import sys
from apps.shark_studio.api.llm import (
llm_model_map,
LanguageModel,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj
B_SYS, E_SYS = "<s>", "</s>"
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
def append_bot_prompt(history, input_prompt):
user_prompt = f"{input_prompt} {E_SYS} {E_SYS}"
history += user_prompt
return history
language_model = None
def get_default_config():
return False
# model_vmfb_key = ""
def chat_fn(
prompt_prefix,
history,
model,
device,
precision,
download_vmfb,
config_file,
streaming_llm,
cli=False,
):
global language_model
if streaming_llm and prompt_prefix == "Clear":
language_model = None
return "Clearing history...", ""
if language_model is None:
history[-1][-1] = "Getting the model ready..."
yield history, ""
language_model = LanguageModel(
model,
device=device,
precision=precision,
external_weights="safetensors",
use_system_prompt=prompt_prefix,
streaming_llm=streaming_llm,
hf_auth_token=cmd_opts.hf_auth_token,
)
history[-1][-1] = "Getting the model ready... Done"
yield history, ""
history[-1][-1] = ""
token_count = 0
total_time = 0.001 # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, exec_time in language_model.chat(history):
history[-1][-1] = f"{text}{E_SYS}"
if is_first:
prefill_time = exec_time
is_first = False
yield history, f"Prefill: {prefill_time:.2f}"
else:
total_time += exec_time
token_count += 1
tokens_per_sec = token_count / total_time
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
def view_json_file(file_obj):
content = ""
with open(file_obj.name, "r") as fopen:
content = fopen.read()
return content
with gr.Blocks(title="Chat") as chat_element:
with gr.Row():
model_choices = list(llm_model_map.keys())
model = gr.Dropdown(
label="Select Model",
value=model_choices[0],
choices=model_choices,
allow_custom_value=True,
)
supported_devices = global_obj.get_device_list()
enabled = True
if len(supported_devices) == 0:
supported_devices = ["cpu-task"]
supported_devices = [x for x in supported_devices if "sync" not in x]
device = gr.Dropdown(
label="Device",
value=supported_devices[0],
choices=supported_devices,
interactive=enabled,
allow_custom_value=True,
)
precision = gr.Radio(
label="Precision",
value="fp32",
choices=[
# "int4",
# "int8",
# "fp16",
"fp32",
],
visible=False,
)
tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Column():
download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available",
value=False,
interactive=True,
visible=False,
)
streaming_llm = gr.Checkbox(
label="Run in streaming mode (requires recompilation)",
value=True,
interactive=False,
visible=False,
)
prompt_prefix = gr.Checkbox(
label="Add System Prompt",
value=True,
interactive=True,
)
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="Chat Message Box",
placeholder="Chat Message Box",
show_label=False,
interactive=enabled,
container=False,
)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit", interactive=enabled)
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(label="Upload sharding configuration", visible=False)
json_view_button = gr.Button("View as JSON", visible=False)
json_view = gr.JSON(visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
submit_event = msg.submit(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat_fn,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
streaming_llm,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
submit_click_event = submit.click(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat_fn,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
streaming_llm,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(
fn=chat_fn,
inputs=[
clear,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
streaming_llm,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
).then(lambda: None, None, [chatbot], queue=False)

View File

@@ -0,0 +1,67 @@
from apps.shark_studio.web.ui.utils import (
HSLHue,
hsl_color,
)
from apps.shark_studio.modules.embeddings import get_lora_metadata
# Returns HTML showing the most frequent tags used when a LoRA was trained,
# taken from the metadata of its .safetensors file.
def lora_changed(lora_files):
# tag frequency percentage that gets the maximum amount of the starting hue
TAG_COLOR_THRESHOLD = 0.55
# tag frequency percentage, above which a tag is displayed
TAG_DISPLAY_THRESHOLD = 0.65
# template for the html used to display a tag
TAG_HTML_TEMPLATE = (
'<span class="lora-tag" style="border: 1px solid {color};">{tag}</span>'
)
output = []
for lora_file in lora_files:
if lora_file == "":
output.extend(["<div><i>No LoRA selected</i></div>"])
elif not lora_file.lower().endswith(".safetensors"):
output.extend(
[
"<div><i>Only metadata queries for .safetensors files are currently supported</i></div>"
]
)
else:
metadata = get_lora_metadata(lora_file)
if metadata:
frequencies = metadata["frequencies"]
output.extend(
[
"".join(
[
f'<div class="lora-model">Trained against weights in: {metadata["model"]}</div>'
]
+ [
TAG_HTML_TEMPLATE.format(
color=hsl_color(
(tag[1] - TAG_COLOR_THRESHOLD)
/ (1 - TAG_COLOR_THRESHOLD),
start=HSLHue.RED,
end=HSLHue.GREEN,
),
tag=tag[0],
)
for tag in frequencies
if tag[1] > TAG_DISPLAY_THRESHOLD
],
)
]
)
elif metadata is None:
output.extend(
[
"<div><i>This LoRA does not publish tag frequency metadata</i></div>"
]
)
else:
output.extend(
[
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
]
)
return output
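
For reference, lora_changed takes a list of LoRA file paths and returns one HTML fragment per entry; a hedged example with placeholder paths:

# Hypothetical call; the .safetensors path is a placeholder.
fragments = lora_changed(["loras/my_style.safetensors", ""])
for html in fragments:
    print(html)  # tag pills for the first entry, "No LoRA selected" for the second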

View File

@@ -0,0 +1,373 @@
/*
Apply Gradio dark theme to the default Gradio theme.
Procedure to upgrade the dark theme:
- Using your browser, visit http://localhost:8080/?__theme=dark
- Open your browser inspector, search for the .dark css class
- Copy .dark class declarations, apply them here into :root
*/
:root {
--body-background-fill: var(--background-fill-primary);
--body-text-color: var(--neutral-100);
--color-accent-soft: var(--neutral-700);
--background-fill-primary: var(--neutral-950);
--background-fill-secondary: var(--neutral-900);
--border-color-accent: var(--neutral-600);
--border-color-primary: var(--neutral-700);
--link-text-color-active: var(--secondary-500);
--link-text-color: var(--secondary-500);
--link-text-color-hover: var(--secondary-400);
--link-text-color-visited: var(--secondary-600);
--body-text-color-subdued: var(--neutral-400);
--shadow-spread: 1px;
--block-background-fill: var(--neutral-800);
--block-border-color: var(--border-color-primary);
--block_border_width: None;
--block-info-text-color: var(--body-text-color-subdued);
--block-label-background-fill: var(--background-fill-secondary);
--block-label-border-color: var(--border-color-primary);
--block_label_border_width: None;
--block-label-text-color: var(--neutral-200);
--block_shadow: None;
--block_title_background_fill: None;
--block_title_border_color: None;
--block_title_border_width: None;
--block-title-text-color: var(--neutral-200);
--panel-background-fill: var(--background-fill-secondary);
--panel-border-color: var(--border-color-primary);
--panel_border_width: None;
--checkbox-background-color: var(--neutral-800);
--checkbox-background-color-focus: var(--checkbox-background-color);
--checkbox-background-color-hover: var(--checkbox-background-color);
--checkbox-background-color-selected: var(--secondary-600);
--checkbox-border-color: var(--neutral-700);
--checkbox-border-color-focus: var(--secondary-500);
--checkbox-border-color-hover: var(--neutral-600);
--checkbox-border-color-selected: var(--secondary-600);
--checkbox-border-width: var(--input-border-width);
--checkbox-label-background-fill: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-fill-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
--checkbox-label-border-color: var(--border-color-primary);
--checkbox-label-border-color-hover: var(--checkbox-label-border-color);
--checkbox-label-border-width: var(--input-border-width);
--checkbox-label-text-color: var(--body-text-color);
--checkbox-label-text-color-selected: var(--checkbox-label-text-color);
--error-background-fill: var(--background-fill-primary);
--error-border-color: var(--border-color-primary);
--error_border_width: None;
--error-text-color: #ef4444;
--input-background-fill: var(--neutral-800);
--input-background-fill-focus: var(--secondary-600);
--input-background-fill-hover: var(--input-background-fill);
--input-border-color: var(--border-color-primary);
--input-border-color-focus: var(--neutral-700);
--input-border-color-hover: var(--input-border-color);
--input_border_width: None;
--input-placeholder-color: var(--neutral-500);
--input_shadow: None;
--input-shadow-focus: 0 0 0 var(--shadow-spread) var(--neutral-700), var(--shadow-inset);
--loader_color: None;
--slider_color: None;
--stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600));
--table-border-color: var(--neutral-700);
--table-even-background-fill: var(--neutral-950);
--table-odd-background-fill: var(--neutral-900);
--table-row-focus: var(--color-accent-soft);
--button-border-width: var(--input-border-width);
--button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
--button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
--button-cancel-border-color: #dc2626;
--button-cancel-border-color-hover: var(--button-cancel-border-color);
--button-cancel-text-color: white;
--button-cancel-text-color-hover: var(--button-cancel-text-color);
--button-primary-background-fill: linear-gradient(to bottom right, var(--primary-500), var(--primary-600));
--button-primary-background-fill-hover: linear-gradient(to bottom right, var(--primary-500), var(--primary-500));
--button-primary-border-color: var(--primary-500);
--button-primary-border-color-hover: var(--button-primary-border-color);
--button-primary-text-color: white;
--button-primary-text-color-hover: var(--button-primary-text-color);
--button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700));
--button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600));
--button-secondary-border-color: var(--neutral-600);
--button-secondary-border-color-hover: var(--button-secondary-border-color);
--button-secondary-text-color: white;
--button-secondary-text-color-hover: var(--button-secondary-text-color);
--block-border-width: 1px;
--block-label-border-width: 1px;
--form-gap-width: 1px;
--error-border-width: 1px;
--input-border-width: 1px;
}
/* SHARK theme */
body {
background-color: var(--background-fill-primary);
}
.generating.svelte-zlszon.svelte-zlszon {
border: none;
}
.generating {
border: none !important;
}
#chatbot {
height: 100% !important;
}
/* display in full width for desktop devices, but see below */
@media (min-width: 1536px)
{
.gradio-container {
max-width: var(--size-full) !important;
}
}
/* media rules in custom css don't appear to be applied in
gradio versions > 4.7, so we have to define a class which
we will manually need to add and remove using javascript.
Remove this once this is fixed in gradio.
*/
.gradio-container-size-full {
max-width: var(--size-full) !important;
}
.gradio-container .contain {
padding: 0 var(--size-4) !important;
}
#top_logo {
color: transparent;
background-color: transparent;
border-radius: 0 !important;
border: 0;
}
#ui_title {
padding: var(--size-2) 0 0 var(--size-1);
}
#demo_title_outer {
border-radius: 0;
}
#prompt_box_outer div:first-child {
border-radius: 0 !important
}
#prompt_box textarea, #negative_prompt_box textarea {
background-color: var(--background-fill-primary) !important;
}
#prompt_examples {
margin: 0 !important;
}
#prompt_examples svg {
display: none !important;
}
#ui_body {
padding: var(--size-2) !important;
border-radius: 0.5em !important;
}
#img_result+div {
display: none !important;
}
footer {
display: none !important;
}
#gallery + div {
border-radius: 0 !important;
}
/* Gallery: Remove the default square ratio thumbnail and limit images height to the container */
#gallery .thumbnail-item.thumbnail-lg {
aspect-ratio: unset;
max-height: calc(55vh - (2 * var(--spacing-lg)));
}
/* fix width and height of gallery items when on very large desktop screens, but see below */
@media (min-width: 1921px) {
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
#gallery .grid-wrap, #gallery .preview{
min-height: calc(768px + 4px + var(--size-14));
max-height: calc(768px + 4px + var(--size-14));
}
/* Limit height to 768px_height + 2px_margin_height for the thumbnails */
#gallery .thumbnail-item.thumbnail-lg {
max-height: 770px !important;
}
}
/* media rules in custom css don't appear to be applied in
gradio versions > 4.7, so we have to define classes which
we will manually need to add and remove using javascript.
Remove this once this is fixed in gradio.
*/
.gallery-force-height768 .grid-wrap, .gallery-force-height768 .preview {
min-height: calc(768px + 4px + var(--size-14)) !important;
max-height: calc(768px + 4px + var(--size-14)) !important;
}
.gallery-limit-height768 .thumbnail-item.thumbnail-lg {
max-height: 770px !important;
}
/* Don't upscale when viewing in solo image mode */
#gallery .preview img {
object-fit: scale-down;
}
/* Navbar images in cover mode*/
#gallery .preview .thumbnail-item img {
object-fit: cover;
}
/* Limit the stable diffusion text output height */
#std_output textarea {
max-height: 215px;
}
/* Prevent the progress bar from blocking gallery navigation while generating images (Gradio V3.19.0) */
#gallery .wrap.default {
pointer-events: none;
}
/* Import Png info box */
#txt2img_prompt_image {
height: var(--size-32) !important;
}
/* Hide "remove buttons" from ui dropdowns */
#custom_model .token-remove.remove-all,
#lora_weights .token-remove.remove-all,
#scheduler .token-remove.remove-all,
#device .token-remove.remove-all,
#stencil_model .token-remove.remove-all {
display: none;
}
/* Hide selected items from ui dropdowns */
#custom_model .options .item .inner-item,
#scheduler .options .item .inner-item,
#device .options .item .inner-item,
#stencil_model .options .item .inner-item {
display:none;
}
/* workarounds for container=false not currently working for dropdowns */
.dropdown_no_container {
padding: 0 !important;
}
#output_subdir_container :first-child {
border: none;
}
/* reduced animation load when generating */
.generating {
animation-play-state: paused !important;
}
/* better clarity when progress bars are minimal */
.meta-text {
background-color: var(--block-label-background-fill);
}
/* lora tag pills */
.lora-tags {
border: 1px solid var(--border-color-primary);
color: var(--block-info-text-color) !important;
padding: var(--block-padding);
}
.lora-tag {
display: inline-block;
height: 2em;
color: rgb(212 212 212) !important;
margin-right: 5pt;
margin-bottom: 5pt;
padding: 2pt 5pt;
border-radius: 5pt;
white-space: nowrap;
}
.lora-model {
margin-bottom: var(--spacing-lg);
color: var(--block-info-text-color) !important;
line-height: var(--line-sm);
}
/* output gallery tab */
.output_parameters_dataframe table.table {
/* works around a gradio bug that always shows scrollbars */
overflow: clip auto;
}
.output_parameters_dataframe tbody td {
font-size: small;
line-height: var(--line-xs);
}
.output_icon_button {
max-width: 30px;
align-self: end;
padding-bottom: 8px;
}
.outputgallery_sendto {
min-width: 7em !important;
}
/* output gallery should take up most of the viewport height regardless of image size/number */
#outputgallery_gallery .fixed-height {
min-height: 89vh !important;
}
.sd-right-panel {
height: calc(100vmin - var(--size-32) - var(--size-10)) !important;
overflow-y: scroll;
}
.sd-right-panel .fill {
flex: 1;
}
/* don't stretch non-square images to be square, breaking their aspect ratio */
#outputgallery_gallery .thumbnail-item.thumbnail-lg > img {
object-fit: contain !important;
}
/* centered logo for when there are no images */
#top_logo.logo_centered {
height: 100%;
width: 100%;
}
#top_logo.logo_centered img {
object-fit: scale-down;
position: absolute;
width: 80%;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
}
#tab_bar_logo {
overflow: visible !important;
border-width: 0 !important;
height: 0px !important;
padding: 0;
margin: 0;
}
#tab_bar_logo .image-container {
object-fit: scale-down;
position: absolute !important;
top: 10px;
right: 0px;
height: 36px;
}

View File

@@ -0,0 +1,49 @@
// workaround for gradio after 4.7 not applying any @media rules from the custom .css file
() => {
console.log(`innerWidth: ${window.innerWidth}` )
// 1536px rules
const mediaQuery1536 = window.matchMedia('(min-width: 1536px)')
function handleWidth1536(event) {
// display in full width for desktop devices
document.querySelectorAll(".gradio-container")
.forEach( (node) => {
if (event.matches) {
node.classList.add("gradio-container-size-full");
} else {
node.classList.remove("gradio-container-size-full")
}
});
}
mediaQuery1536.addEventListener("change", handleWidth1536);
mediaQuery1536.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1536}));
// 1921px rules
const mediaQuery1921 = window.matchMedia('(min-width: 1921px)')
function handleWidth1921(event) {
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
/* Limit height to 768px_height + 2px_margin_height for the thumbnails */
document.querySelectorAll("#gallery")
.forEach( (node) => {
if (event.matches) {
node.classList.add("gallery-force-height768");
node.classList.add("gallery-limit-height768");
} else {
node.classList.remove("gallery-force-height768");
node.classList.remove("gallery-limit-height768");
}
});
}
mediaQuery1921.addEventListener("change", handleWidth1921);
mediaQuery1921.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1921}));
}

Binary file not shown (image, 7.1 KiB)

Binary file not shown (image, 7.4 KiB)

View File

@@ -0,0 +1,406 @@
import glob
import gradio as gr
import os
import subprocess
import sys
from PIL import Image
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)
from apps.shark_studio.web.ui.utils import amdlogo_loc
from apps.shark_studio.web.utils.metadata import displayable_metadata
# -- Functions for file, directory and image info querying
output_dir = get_generated_imgs_path()
def outputgallery_filenames(subdir) -> list[str]:
new_dir_path = os.path.join(output_dir, subdir)
if os.path.exists(new_dir_path):
filenames = [
glob.glob(new_dir_path + "/" + ext) for ext in ("*.png", "*.jpg", "*.jpeg")
]
return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True)
else:
return []
def output_subdirs() -> list[str]:
# Gets a list of subdirectories of output_dir and below, as relative paths.
relative_paths = [
os.path.relpath(entry[0], output_dir)
for entry in os.walk(
output_dir, followlinks=cmd_opts.output_gallery_followlinks
)
]
# It is less confusing to always include the subdir that will take any
# images generated today even if it doesn't exist yet
if get_generated_imgs_todays_subdir() not in relative_paths:
relative_paths.append(get_generated_imgs_todays_subdir())
# sort subdirectories so that the date named ones we probably
# created in this or previous sessions come first, sorted with the most
# recent first. Other subdirs are listed after.
generated_paths = sorted(
[path for path in relative_paths if path.isnumeric()], reverse=True
)
result_paths = generated_paths + sorted(
[path for path in relative_paths if (not path.isnumeric()) and path != "."]
)
return result_paths
# --- Define UI layout for Gradio
with gr.Blocks() as outputgallery_element:
amd_logo = Image.open(amdlogo_loc)
with gr.Row(elem_id="outputgallery_gallery"):
# needed to work around gradio issue:
# https://github.com/gradio-app/gradio/issues/2907
dev_null = gr.Textbox("", visible=False)
gallery_files = gr.State(value=[])
subdirectory_paths = gr.State(value=[])
with gr.Column(scale=6):
logo = gr.Image(
label="Getting subdirectories...",
value=amd_logo,
interactive=False,
visible=True,
show_label=True,
elem_id="top_logo",
elem_classes="logo_centered",
show_download_button=False,
)
gallery = gr.Gallery(
label="",
value=gallery_files.value,
visible=False,
show_label=True,
columns=4,
)
with gr.Column(scale=4):
with gr.Group():
with gr.Row():
with gr.Column(
scale=15,
min_width=160,
elem_id="output_subdir_container",
):
subdirectories = gr.Dropdown(
label=f"Subdirectories of {output_dir}",
type="value",
choices=subdirectory_paths.value,
value="",
interactive=True,
elem_classes="dropdown_no_container",
allow_custom_value=True,
)
with gr.Column(
scale=1,
min_width=32,
elem_classes="output_icon_button",
):
open_subdir = gr.Button(
variant="secondary",
value="\U0001F5C1", # unicode open folder
interactive=False,
size="sm",
)
with gr.Column(
scale=1,
min_width=32,
elem_classes="output_icon_button",
):
refresh = gr.Button(
variant="secondary",
value="\u21BB", # unicode clockwise arrow circle
size="sm",
)
image_columns = gr.Slider(
label="Columns shown", value=4, minimum=1, maximum=16, step=1
)
outputgallery_filename = gr.Textbox(
label="Filename",
value="None",
interactive=False,
show_copy_button=True,
)
with gr.Accordion(
label="Parameter Information", open=False
) as parameters_accordian:
image_parameters = gr.DataFrame(
headers=["Parameter", "Value"],
col_count=2,
wrap=True,
elem_classes="output_parameters_dataframe",
value=[["Status", "No image selected"]],
interactive=True,
)
with gr.Accordion(label="Send To", open=True):
with gr.Row():
outputgallery_sendto_sd = gr.Button(
value="Stable Diffusion",
interactive=False,
elem_classes="outputgallery_sendto",
size="sm",
)
# --- Event handlers
def on_clear_gallery():
return [
gr.Gallery(
value=[],
visible=False,
),
gr.Image(
visible=True,
),
]
def on_image_columns_change(columns):
return gr.Gallery(columns=columns)
def on_select_subdir(subdir) -> list:
# subdir is the selected subdirectory name
new_images = outputgallery_filenames(subdir)
new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)}"
return [
new_images,
gr.Gallery(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
]
def on_open_subdir(subdir):
subdir_path = os.path.normpath(os.path.join(output_dir, subdir))
if os.path.isdir(subdir_path):
if sys.platform == "linux":
subprocess.run(["xdg-open", subdir_path])
elif sys.platform == "darwin":
subprocess.run(["open", subdir_path])
elif sys.platform == "win32":
os.startfile(subdir_path)
def on_refresh(current_subdir: str) -> list:
# get an up-to-date subdirectory list
refreshed_subdirs = output_subdirs()
# get the images using either the current subdirectory or the most
# recent valid one
new_subdir = (
current_subdir
if current_subdir in refreshed_subdirs
else refreshed_subdirs[0]
)
new_images = outputgallery_filenames(new_subdir)
new_label = (
f"{len(new_images)} images in " f"{os.path.join(output_dir, new_subdir)}"
)
return [
gr.Dropdown(
choices=refreshed_subdirs,
value=new_subdir,
),
refreshed_subdirs,
new_images,
gr.Gallery(value=new_images, label=new_label, visible=len(new_images) > 0),
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
]
def on_new_image(subdir, subdir_paths, status) -> list:
# prevent an error triggered when an image is generated before the tab
# has even been selected
subdir_paths = (
subdir_paths
if len(subdir_paths) > 0
else [get_generated_imgs_todays_subdir()]
)
# only update if the current subdir is the most recent one as
# new images only go there
if subdir_paths[0] == subdir:
new_images = outputgallery_filenames(subdir)
new_label = (
f"{len(new_images)} images in "
f"{os.path.join(output_dir, subdir)} - {status}"
)
return [
new_images,
gr.Gallery(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
]
else:
# otherwise change nothing,
# (only untyped gradio gr.update() does this)
return [gr.update(), gr.update(), gr.update()]
def on_select_image(images: list[str], evt: gr.SelectData) -> list:
# evt.index is an index into the full list of filenames for
# the current subdirectory
filename = images[evt.index]
params = displayable_metadata(filename)
if params:
if params["source"] == "missing":
return [
"Could not find this image file, refresh the gallery and update the images",
[["Status", "File missing"]],
]
else:
return [
filename,
list(map(list, params["parameters"].items())),
]
return [
filename,
[["Status", "No parameters found"]],
]
def on_outputgallery_filename_change(filename: str) -> list:
exists = filename != "None" and os.path.exists(filename)
return [
# disable or enable each of the sendto buttons based on whether
# an image is selected
gr.Button(interactive=exists),
]
# The first time our tab is selected we need to do an initial refresh
# to populate the subdirectory select box and the images from the most
# recent subdirectory.
#
# We do it at this point rather than setting this up in the controls'
# definitions as when you refresh the browser you always get what was
# *initially* set, which won't include any new subdirectories or images
# that might have been created since the application was started. Doing it
# this way means a browser refresh/reload always gets the most
# up-to-date data.
def on_select_tab(subdir_paths, request: gr.Request):
local_client = request.headers["host"].startswith(
"127.0.0.1:"
) or request.headers["host"].startswith("localhost:")
if len(subdir_paths) == 0:
return on_refresh("") + [gr.update(interactive=local_client)]
else:
return (
# Change nothing, (only untyped gr.update() does this)
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
# clearing images when we need to completely change what's in the
# gallery avoids the current images being replaced piecemeal, and
# prevents weirdness and errors if the user selects an image during the
# replacement phase.
clear_gallery = dict(
fn=on_clear_gallery,
inputs=None,
outputs=[gallery, logo],
queue=False,
)
subdirectories.select(**clear_gallery).then(
on_select_subdir,
[subdirectories],
[gallery_files, gallery, logo],
queue=False,
)
open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False)
refresh.click(**clear_gallery).then(
on_refresh,
[subdirectories],
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
queue=False,
)
image_columns.change(
fn=on_image_columns_change,
inputs=[image_columns],
outputs=[gallery],
queue=False,
)
gallery.select(
on_select_image,
[gallery_files],
[outputgallery_filename, image_parameters],
queue=False,
)
outputgallery_filename.change(
on_outputgallery_filename_change,
[outputgallery_filename],
[
outputgallery_sendto_sd,
],
queue=False,
)
# We should have been given the .select function for our tab, so set it up
def outputgallery_tab_select(select):
select(
fn=on_select_tab,
inputs=[subdirectory_paths],
outputs=[
subdirectories,
subdirectory_paths,
gallery_files,
gallery,
logo,
open_subdir,
],
queue=False,
)
# We should have been passed a list of components on other tabs that update
# when a new image has been generated on that tab, so set things up so the user
# will see that new image if they are looking at today's subdirectory
def outputgallery_watch(components: list[gr.Textbox]):
for component in components:
component.change(
on_new_image,
inputs=[subdirectories, subdirectory_paths, component],
outputs=[gallery_files, gallery, logo],
queue=False,
)
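
The two helpers at the end are intended to be driven by whichever module builds the tab bar (index.py above renders outputgallery_element, but the wiring calls are not shown in this diff); a hedged sketch of that wiring:

# Sketch only; assumes the imports used in index.py and a hypothetical status textbox.
import gradio as gr

with gr.Blocks() as studio_web:
    with gr.Tabs():
        with gr.TabItem(label="Stable Diffusion", id=0):
            sd_element.render()
        with gr.TabItem(label="Output Gallery", id=1) as outputgallery_tab:
            outputgallery_element.render()
    # refresh the gallery contents the first time its tab is selected...
    outputgallery_tab_select(outputgallery_tab.select)
    # ...and whenever another tab reports a newly generated image
    outputgallery_watch([sd_status])  # sd_status: hypothetical gr.Textbox updated by the SD tab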

View File

@@ -0,0 +1,866 @@
import os
import json
import gradio as gr
import numpy as np
from inspect import signature
from PIL import Image
from pathlib import Path
from datetime import datetime as dt
from gradio.components.image_editor import (
EditorValue,
)
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
get_checkpoints_path,
get_checkpoints,
get_configs_path,
get_configs,
write_default_sd_configs,
)
from apps.shark_studio.api.sd import (
shark_sd_fn_dict_input,
cancel_sd,
unload_sd,
)
from apps.shark_studio.api.controlnet import (
cnet_preview,
)
from apps.shark_studio.modules.schedulers import (
scheduler_model_map,
)
from apps.shark_studio.modules.img_processing import (
resampler_list,
resize_stencil,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.web.ui.utils import (
amdlogo_loc,
none_to_str_none,
str_none_to_none,
)
from apps.shark_studio.web.utils.state import (
status_label,
)
from apps.shark_studio.web.ui.common_events import lora_changed
from apps.shark_studio.modules import logger
import apps.shark_studio.web.utils.globals as global_obj
# Disabled some models for demo purposes
sd_default_models = [
# "runwayml/stable-diffusion-v1-5",
# "stabilityai/stable-diffusion-2-1-base",
# "stabilityai/stable-diffusion-2-1",
# "stabilityai/stable-diffusion-xl-base-1.0",
# "stabilityai/sdxl-turbo",
]
sd_default_models.extend(get_checkpoints(model_type="scripts"))
def view_json_file(file_path):
content = ""
with open(file_path, "r") as fopen:
content = fopen.read()
return content
def submit_to_cnet_config(
stencil: str,
preprocessed_hint: str,
cnet_strength: int,
control_mode: str,
curr_config: dict,
):
if any(i in [None, ""] for i in [stencil, preprocessed_hint]):
return gr.update()
if curr_config is not None:
if "controlnets" in curr_config:
curr_config["controlnets"]["control_mode"] = control_mode
curr_config["controlnets"]["model"].append(stencil)
curr_config["controlnets"]["hint"].append(preprocessed_hint)
curr_config["controlnets"]["strength"].append(cnet_strength)
return curr_config
cnet_map = {}
cnet_map["controlnets"] = {
"control_mode": control_mode,
"model": [stencil],
"hint": [preprocessed_hint],
"strength": [cnet_strength],
}
return cnet_map
def update_embeddings_json(embedding):
return {"embeddings": [embedding]}
def submit_to_main_config(input_cfg: dict, main_cfg: dict):
if main_cfg in [None, "", {}]:
return input_cfg
for base_key in input_cfg:
main_cfg[base_key] = input_cfg[base_key]
return main_cfg
def pull_sd_configs(
prompt,
negative_prompt,
sd_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
target_triple,
ondemand,
compiled_pipeline,
resample_type,
controlnets,
embeddings,
):
sd_args = str_none_to_none(locals())
sd_cfg = {}
for arg in sd_args:
if arg in [
"prompt",
"negative_prompt",
"sd_init_image",
]:
sd_cfg[arg] = [sd_args[arg]]
elif arg in ["controlnets", "embeddings"]:
if isinstance(arg, dict):
sd_cfg[arg] = json.loads(sd_args[arg])
else:
sd_cfg[arg] = {}
else:
sd_cfg[arg] = sd_args[arg]
return json.dumps(sd_cfg)
def load_sd_cfg(sd_json: dict, load_sd_config: str):
if os.path.exists(load_sd_config):
config = load_sd_config
elif os.path.exists(os.path.join(get_configs_path(), load_sd_config)):
config = os.path.join(get_configs_path(), load_sd_config)
else:
print(
"Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config."
)
config = sd_json
new_sd_config = none_to_str_none(json.loads(view_json_file(config)))
if sd_json:
for key in new_sd_config:
sd_json[key] = new_sd_config[key]
else:
sd_json = new_sd_config
for i in sd_json["sd_init_image"]:
if i is not None:
if os.path.isfile(i):
sd_image = [Image.open(i, mode="r")]
else:
sd_image = None
if not sd_json["device"]:
sd_json["device"] = gr.update()
return [
sd_json["prompt"][0],
sd_json["negative_prompt"][0],
sd_image,
sd_json["height"],
sd_json["width"],
gr.update(),
sd_json["strength"],
sd_json["guidance_scale"],
sd_json["seed"],
sd_json["batch_count"],
sd_json["batch_size"],
sd_json["scheduler"],
sd_json["base_model_id"],
sd_json["custom_weights"],
sd_json["custom_vae"],
sd_json["precision"],
sd_json["device"],
sd_json["target_triple"],
sd_json["ondemand"],
sd_json["compiled_pipeline"],
sd_json["resample_type"],
sd_json["controlnets"],
sd_json["embeddings"],
sd_json,
]
def save_sd_cfg(config: dict, save_name: str):
if os.path.exists(save_name):
filepath = save_name
elif cmd_opts.configs_path:
filepath = os.path.join(cmd_opts.configs_path, save_name)
else:
filepath = os.path.join(get_configs_path(), save_name)
if ".json" not in filepath:
filepath += ".json"
with open(filepath, mode="w") as f:
f.write(json.dumps(config))
return save_name
def create_canvas(width, height):
data = Image.fromarray(
np.zeros(
shape=(height, width, 3),
dtype=np.uint8,
)
+ 255
)
img_dict = {
"background": data,
"layers": [],
"composite": None,
}
return EditorValue(img_dict)
def import_original(original_img, width, height):
if original_img is None:
resized_img = create_canvas(width, height)
return resized_img
else:
resized_img, _, _ = resize_stencil(original_img, width, height)
img_dict = {
"background": resized_img,
"layers": [],
"composite": None,
}
return EditorValue(img_dict)
def base_model_changed(base_model_id):
new_choices = get_checkpoints(
os.path.join("checkpoints", os.path.basename(str(base_model_id)))
) + get_checkpoints(model_type="checkpoints")
if "turbo" in base_model_id:
new_steps = gr.Dropdown(
value=2,
choices=[1, 2],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
)
if "stable-diffusion-xl-base-1.0" in base_model_id:
new_steps = gr.Dropdown(
value=40,
choices=[20, 25, 30, 35, 40, 45, 50],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
)
elif ".py" in base_model_id:
new_steps = gr.Dropdown(
value=20,
choices=[10, 15, 20],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
)
else:
new_steps = gr.Dropdown(
value=20,
choices=[10, 20, 30, 40, 50],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
)
return [
gr.Dropdown(
value=new_choices[0] if len(new_choices) > 0 else "None",
choices=["None"] + new_choices,
),
new_steps,
]
init_config = global_obj.get_init_config()
init_config = none_to_str_none(json.loads(view_json_file(init_config)))
with gr.Blocks(title="Stable Diffusion") as sd_element:
with gr.Column(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=2, min_width=600):
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="\U00002795\U0000FE0F Prompt",
value=init_config["prompt"][0],
lines=4,
elem_id="prompt_box",
show_copy_button=True,
)
negative_prompt = gr.Textbox(
label="\U00002796\U0000FE0F Negative Prompt",
value=init_config["negative_prompt"][0],
lines=4,
elem_id="negative_prompt_box",
show_copy_button=True,
)
with gr.Accordion(
label="\U0001F4D0\U0000FE0F Advanced Settings", open=False
):
with gr.Accordion(label="Device Settings", open=False):
device = gr.Dropdown(
elem_id="device",
label="Device",
value=(
init_config["device"]
if init_config["device"]
else "rocm"
),
choices=global_obj.get_device_list(),
allow_custom_value=True,
)
target_triple = gr.Textbox(
elem_id="target_triple",
label="Architecture",
value=init_config["target_triple"],
)
with gr.Row():
ondemand = gr.Checkbox(
value=init_config["ondemand"],
label="Low VRAM",
interactive=True,
visible=False,
)
precision = gr.Radio(
label="Precision",
value=init_config["precision"],
choices=[
"fp16",
"fp32",
],
visible=False,
)
with gr.Row():
height = gr.Slider(
512,
1024,
value=512,
step=512,
label="\U00002195\U0000FE0F Height",
interactive=False, # DEMO
visible=False, # DEMO
)
width = gr.Slider(
512,
1024,
value=512,
step=512,
label="\U00002194\U0000FE0F Width",
interactive=False, # DEMO
visible=False, # DEMO
)
with gr.Accordion(
label="\U0001F9EA\U0000FE0F Input Image Processing",
open=False,
visible=False,
):
strength = gr.Slider(
0,
1,
value=init_config["strength"],
step=0.01,
label="Denoising Strength",
)
resample_type = gr.Dropdown(
value=init_config["resample_type"],
choices=resampler_list,
label="Resample Type",
allow_custom_value=True,
)
with gr.Row():
sd_model_info = (
f"Checkpoint Path: {str(get_checkpoints_path())}"
)
base_model_id = gr.Dropdown(
label="\U000026F0\U0000FE0F Base Model",
info="Select or enter HF model ID",
elem_id="custom_model",
value=init_config["base_model_id"],
choices=sd_default_models,
allow_custom_value=True,
) # base_model_id
with gr.Row(equal_height=True):
seed = gr.Textbox(
value=init_config["seed"],
label="\U0001F331\U0000FE0F Seed",
info="An integer, -1 for random",
show_copy_button=True,
)
scheduler = gr.Dropdown(
elem_id="scheduler",
label="\U0001F4C5\U0000FE0F Scheduler",
info="\U000E0020", # forces same height as seed
value=init_config["scheduler"],
choices=scheduler_model_map.keys(),
allow_custom_value=False,
visible=False,
)
with gr.Row():
steps = gr.Dropdown(
value=20,
choices=[10, 15, 20],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
)
guidance_scale = gr.Slider(
0,
5, # DEMO
value=4,
step=0.1,
label="\U0001F5C3\U0000FE0F CFG Scale",
visible=False,
)
with gr.Row():
batch_count = gr.Slider(
1,
100,
value=init_config["batch_count"],
step=1,
label="Batch Count",
interactive=True,
visible=False,
)
batch_size = gr.Slider(
1,
4,
value=init_config["batch_size"],
step=1,
label="Batch Size",
interactive=False, # DEMO
visible=False,
)
compiled_pipeline = gr.Checkbox(
value=init_config["compiled_pipeline"],
label="Faster txt2img (SDXL only)",
visible=False, # DEMO
)
with gr.Row(elem_classes=["fill"], visible=False):
Path(get_configs_path()).mkdir(parents=True, exist_ok=True)
write_default_sd_configs(get_configs_path())
default_config_file = global_obj.get_init_config()
sd_json = gr.JSON(
elem_classes=["fill"],
value=view_json_file(default_config_file),
)
with gr.Row(visible=False):
with gr.Row():
load_sd_config = gr.Dropdown(
label="Load Config",
value=cmd_opts.defaults,
choices=get_configs(),
allow_custom_value=True,
visible=False,
)
with gr.Row():
save_sd_config = gr.Button(value="Save Config", size="sm")
clear_sd_config = gr.ClearButton(
value="Clear Config",
size="sm",
components=sd_json,
)
# with gr.Row():
sd_config_name = gr.Textbox(
value="Config Name",
info="Name of the file this config will be saved to.",
interactive=True,
show_label=False,
)
with gr.Accordion(
label="\U00002696\U0000FE0F Model Weights",
open=False,
visible=False, # DEMO
):
with gr.Column():
custom_weights = gr.Dropdown(
label="Checkpoint Weights",
info="Select or enter HF model ID",
elem_id="custom_model",
value=init_config["custom_weights"],
allow_custom_value=True,
choices=["None"]
+ get_checkpoints(os.path.basename(str(base_model_id))),
) # custom_weights
sd_vae_info = (str(get_checkpoints_path("vae"))).replace(
"\\", "\n\\"
)
sd_vae_info = f"VAE Path: {sd_vae_info}"
custom_vae = gr.Dropdown(
label=f"VAE Model",
info=sd_vae_info,
elem_id="custom_model",
value=init_config["custom_vae"],
choices=["None"] + get_checkpoints("vae"),
allow_custom_value=True,
scale=1,
)
sd_lora_info = (str(get_checkpoints_path("loras"))).replace(
"\\", "\n\\"
)
lora_opt = gr.Dropdown(
allow_custom_value=True,
label=f"Standalone LoRA Weights",
info=sd_lora_info,
elem_id="lora_weights",
value=(
init_config["embeddings"][0]
if (len(init_config["embeddings"].keys()) > 1)
else "None"
),
multiselect=True,
choices=[] + get_checkpoints("lora"),
scale=2,
)
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
embeddings_config = gr.JSON(
label="Embeddings Options", min_width=50, scale=1
)
gr.on(
triggers=[lora_opt.change],
fn=lora_changed,
inputs=[lora_opt],
outputs=[lora_tags],
queue=True,
show_progress=False,
).then(
fn=update_embeddings_json,
inputs=[lora_opt],
outputs=[embeddings_config],
show_progress=False,
)
with gr.Accordion(
label="Controlnet Options",
open=False,
visible=False,
):
preprocessed_hints = gr.State([])
with gr.Column():
sd_cnet_info = (
str(get_checkpoints_path("controlnet"))
).replace("\\", "\n\\")
with gr.Row():
cnet_config = gr.JSON()
with gr.Column():
clear_config = gr.ClearButton(
value="Clear Controlnet Config",
size="sm",
components=cnet_config,
)
control_mode = gr.Radio(
choices=["Prompt", "Balanced", "Controlnet"],
value="Balanced",
label="Control Mode",
)
with gr.Row():
with gr.Column(scale=1):
cnet_model = gr.Dropdown(
allow_custom_value=True,
label=f"Controlnet Model",
info=sd_cnet_info,
value="None",
choices=[
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
]
+ get_checkpoints("controlnet"),
)
cnet_strength = gr.Slider(
label="Controlnet Strength",
minimum=0,
maximum=100,
value=50,
step=1,
)
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=512,
maximum=1024,
value=512,
step=512,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=512,
maximum=1024,
value=512,
step=512,
)
make_canvas = gr.Button(
value="Make Canvas!",
)
use_input_img = gr.Button(
value="Use Original Image",
size="sm",
)
cnet_input = gr.Image(
value=None,
type="pil",
image_mode="RGB",
interactive=True,
)
with gr.Column(scale=1):
cnet_output = gr.Image(
value=None,
visible=True,
label="Preprocessed Hint",
interactive=False,
show_label=True,
)
cnet_gen = gr.Button(
value="Preprocess controlnet input",
)
use_result = gr.Button(
"Submit",
size="sm",
)
make_canvas.click(
fn=create_canvas,
inputs=[canvas_width, canvas_height],
outputs=[cnet_input],
queue=False,
)
cnet_gen.click(
fn=cnet_preview,
inputs=[
cnet_model,
cnet_input,
],
outputs=[
cnet_output,
preprocessed_hints,
],
)
use_result.click(
fn=submit_to_cnet_config,
inputs=[
cnet_model,
cnet_output,
cnet_strength,
control_mode,
cnet_config,
],
outputs=[
cnet_config,
],
queue=False,
)
with gr.Column(scale=3, min_width=600):
with gr.Tabs() as sd_tabs:
sd_element.load(
# Workaround for Gradio issue #7085
# TODO: revert to setting selected= in gr.Tabs declaration
# once this is resolved in Gradio
lambda: gr.Tabs(selected=101),
outputs=[sd_tabs],
)
with gr.Tab(
label="Input Image", id=100, visible=False
) as sd_tab_init_image: # DEMO
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
# TODO: make this import image prompt info if it exists
sd_init_image = gr.Image(
type="pil",
interactive=True,
show_label=False,
)
use_input_img.click(
fn=import_original,
inputs=[
sd_init_image,
canvas_width,
canvas_height,
],
outputs=[cnet_input],
queue=False,
)
with gr.Tab(label="Generate Images", id=101) as sd_tab_gallery:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
sd_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
columns=2,
object_fit="fit",
preview=True,
)
with gr.Row():
stable_diffusion = gr.Button("Start")
unload = gr.Button("Unload Models")
unload.click(
fn=unload_sd,
queue=False,
show_progress=False,
)
stop_batch = gr.Button("Stop", visible=False)
# with gr.Tab(label="Config", id=102) as sd_tab_config:
# with gr.Group():#elem_classes=["sd-right-panel"]):
# with gr.Row(elem_classes=["fill"], visible=False):
# Path(get_configs_path()).mkdir(
# parents=True, exist_ok=True
# )
# write_default_sd_configs(get_configs_path())
# default_config_file = global_obj.get_init_config()
# sd_json = gr.JSON(
# elem_classes=["fill"],
# value=view_json_file(default_config_file),
# )
# with gr.Row():
# with gr.Row():
# load_sd_config = gr.Dropdown(
# label="Load Config",
# value=cmd_opts.defaults,
# choices=get_configs(),
# allow_custom_value=True,
# )
# with gr.Row():
# save_sd_config = gr.Button(
# value="Save Config", size="sm"
# )
# clear_sd_config = gr.ClearButton(
# value="Clear Config",
# size="sm",
# components=sd_json,
# )
# # with gr.Row():
# sd_config_name = gr.Textbox(
# value="Config Name",
# info="Name of the file this config will be saved to.",
# interactive=True,
# show_label=False,
# )
with gr.Tab(label="Log", id=103, visible=False) as sd_tab_log:
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=2,
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
base_model_id.change(
fn=base_model_changed,
inputs=[base_model_id],
outputs=[custom_weights, steps],
)
load_sd_config.change(
fn=load_sd_cfg,
inputs=[sd_json, load_sd_config],
outputs=[
prompt,
negative_prompt,
sd_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
target_triple,
ondemand,
compiled_pipeline,
resample_type,
cnet_config,
embeddings_config,
sd_json,
],
)
save_sd_config.click(
fn=save_sd_cfg,
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
pull_kwargs = dict(
fn=pull_sd_configs,
inputs=[
prompt,
negative_prompt,
sd_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
target_triple,
ondemand,
compiled_pipeline,
resample_type,
cnet_config,
embeddings_config,
],
outputs=[
sd_json,
],
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Stable Diffusion", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=sd_status,
)
gen_kwargs = dict(
fn=shark_sd_fn_dict_input,
inputs=[sd_json],
outputs=[
sd_gallery,
sd_status,
],
)
prompt_submit = prompt.submit(**status_kwargs).then(**pull_kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(**pull_kwargs)
generate_click = (
stable_diffusion.click(**status_kwargs).then(**pull_kwargs).then(**gen_kwargs)
)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
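The event wiring above packs shared handler arguments into plain dicts (`status_kwargs`, `pull_kwargs`, `gen_kwargs`) and chains them with `.then()`, so the Start button and both prompt boxes run the same status → config-pull → generate pipeline. Below is a minimal, self-contained sketch of that pattern; it is not part of this diff, and the component names and handlers are placeholders.

```python
import gradio as gr

def set_status():
    # Placeholder status handler.
    return "Generating..."

def generate(text):
    # Placeholder generation handler.
    return f"Result for: {text}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    status = gr.Textbox(label="Status")
    output = gr.Textbox(label="Output")
    start = gr.Button("Start")

    # Shared kwargs, reused by every trigger.
    status_kwargs = dict(fn=set_status, outputs=status)
    gen_kwargs = dict(fn=generate, inputs=prompt, outputs=output)

    # Both triggers run the same chained pipeline.
    start.click(**status_kwargs).then(**gen_kwargs)
    prompt.submit(**status_kwargs).then(**gen_kwargs)

if __name__ == "__main__":
    demo.launch()
```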


@@ -0,0 +1,43 @@
from enum import IntEnum
import math
import sys
import os
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)
amdlogo_loc = resource_path("logos/amd-logo.jpg")
amdicon_loc = resource_path("logos/amd-icon.jpg")
class HSLHue(IntEnum):
RED = 0
YELLOW = 60
GREEN = 120
CYAN = 180
BLUE = 240
MAGENTA = 300
def hsl_color(alpha: float, start, end):
b = (end - start) * (alpha if alpha > 0 else 0)
result = b + start
# Return a CSS HSL string
return f"hsl({math.floor(result)}, 80%, 35%)"
def none_to_str_none(props: dict):
for key in props:
props[key] = "None" if props[key] == None else props[key]
return props
def str_none_to_none(props: dict):
for key in props:
props[key] = None if props[key] == "None" else props[key]
return props
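For reference, a small usage sketch of the helpers above. This is illustrative only, not part of the diff, and it assumes the file is importable as `ui_utils` (the module name is an assumption).

```python
# Assumes the module above is saved as ui_utils.py somewhere on the import path.
from ui_utils import HSLHue, hsl_color, none_to_str_none, str_none_to_none

props = {"custom_vae": "None", "custom_weights": "my_model.safetensors"}
as_python = str_none_to_none(props)       # "None" strings become real None values
as_strings = none_to_str_none(as_python)  # and back again for display in the UI

# Interpolate 40% of the way from red to green, e.g. for a progress indicator.
print(hsl_color(0.4, HSLHue.RED, HSLHue.GREEN))  # "hsl(48, 80%, 35%)"
```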


@@ -0,0 +1,12 @@
import os
import sys
def get_available_devices():
return ["cpu-task"]
def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)


@@ -0,0 +1,95 @@
default_sd_config = r"""{
"prompt": [
"a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 50,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-2-1-base",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": false,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""
sdxl_30steps = r"""{
"prompt": [
"a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 1024,
"width": 1024,
"steps": 30,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-xl-base-1.0",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": true,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""
sdxl_turbo = r"""{
"prompt": [
"A cat wearing a hat that says 'TURBO' on it. The cat is sitting on a skateboard."
],
"negative_prompt": [
""
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 2,
"strength": 0.8,
"guidance_scale": 0,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerAncestralDiscrete",
"base_model_id": "stabilityai/sdxl-turbo",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": true,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""
default_sd_configs = {
# "default_sd_config.json": sdxl_turbo,
# "sdxl-30steps.json": sdxl_30steps,
"sdxl-turbo.json": sdxl_turbo,
}
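Since the defaults above are stored as raw JSON strings, a derived config can be built with nothing but the standard `json` module. A minimal sketch (not part of the diff; the import path is taken from how this module is imported elsewhere in the changeset, and the prompt, seed, and output filename are made up):

```python
import json

from apps.shark_studio.web.utils.default_configs import sdxl_turbo

cfg = json.loads(sdxl_turbo)
cfg["prompt"] = ["a watercolor painting of a lighthouse at dawn"]
cfg["seed"] = "42"

# Write a JSON file like the ones listed by the UI's "Load Config" dropdown.
with open("my-turbo-config.json", "w") as f:
    json.dump(cfg, f, indent=2)
```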


@@ -0,0 +1,115 @@
import os
import sys
import glob
from datetime import datetime as dt
from pathlib import Path
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
checkpoints_filetypes = (
"*.ckpt",
"*.safetensors",
)
from apps.shark_studio.web.utils.default_configs import default_sd_configs
def write_default_sd_configs(path):
for key in default_sd_configs.keys():
config_fpath = os.path.join(path, key)
if not os.path.exists(config_fpath):
with open(config_fpath, "w") as f:
f.write(default_sd_configs[key])
def safe_name(name):
return name.split("/")[-1].replace("-", "_")
def get_path_stem(path):
path = Path(path)
return path.stem
def get_resource_path(path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
if os.path.isabs(path):
return path
else:
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
result = Path(os.path.join(base_path, path)).resolve(strict=False)
return result
def get_configs_path() -> Path:
configs = get_resource_path(cmd_opts.config_dir)
if not os.path.exists(configs):
os.mkdir(configs)
return Path(configs)
def get_generated_imgs_path() -> Path:
outputs = get_resource_path(cmd_opts.output_dir)
if not os.path.exists(outputs):
os.mkdir(outputs)
return Path(outputs)
def get_tmp_path() -> Path:
tmpdir = get_resource_path(cmd_opts.model_dir)
if not os.path.exists(tmpdir):
os.mkdir(tmpdir)
return Path(tmpdir)
def get_generated_imgs_todays_subdir() -> str:
return dt.now().strftime("%Y%m%d")
def create_model_folders():
dir = ["checkpoints", "vae", "lora", "vmfb"]
if not os.path.isdir(cmd_opts.model_dir):
try:
os.makedirs(cmd_opts.model_dir)
except OSError:
sys.exit(
f"Invalid --model_dir argument, "
f"{cmd_opts.model_dir} folder does not exist, and cannot be created."
)
for root in dir:
Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True)
def get_checkpoints_path(model_type=""):
return get_resource_path(os.path.join(cmd_opts.model_dir, model_type))
def get_checkpoints(model_type="checkpoints"):
ckpt_files = []
file_types = checkpoints_filetypes
if model_type == "scripts":
file_types = ["shark_*.py"]
if model_type == "lora":
file_types = file_types + ("*.pt", "*.bin")
for extn in file_types:
files = [
os.path.basename(x)
for x in glob.glob(os.path.join(get_checkpoints_path(model_type), extn))
]
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)
def get_configs():
return sorted(
[
os.path.basename(x)
for x in glob.glob(os.path.join(get_configs_path(), "*.json"))
],
key=str.casefold,
)
def get_checkpoint_pathfile(checkpoint_name, model_type="checkpoints"):
return os.path.join(get_checkpoints_path(model_type), checkpoint_name)


@@ -0,0 +1,158 @@
import gc
from ...api.utils import get_available_devices
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import os
from apps.shark_studio.web.utils.file_utils import get_configs_path
"""
The global objects include SD pipeline and config.
Maintaining the global objects would avoid creating extra pipeline objects when switching modes.
Also we could avoid memory leak when switching models by clearing the cache.
"""
def view_json_file(file_path):
content = ""
with open(file_path, "r") as fopen:
content = fopen.read()
return content
def _init():
global _sd_obj
global _llm_obj
global _devices
global _pipe_kwargs
global _prep_kwargs
global _gen_kwargs
global _schedulers
_sd_obj = None
_llm_obj = None
_devices = None
_pipe_kwargs = None
_prep_kwargs = None
_gen_kwargs = None
_schedulers = None
set_devices()
def set_sd_obj(value):
global _sd_obj
global _llm_obj
_llm_obj = None
_sd_obj = value
def set_llm_obj(value):
global _sd_obj
global _llm_obj
_llm_obj = value
_sd_obj = None
def set_devices():
global _devices
_devices = get_available_devices()
def set_sd_scheduler(key):
global _sd_obj
_sd_obj.scheduler = _schedulers[key]
def set_sd_status(value):
global _sd_obj
_sd_obj.status = value
def set_pipe_kwargs(value):
global _pipe_kwargs
_pipe_kwargs = value
def set_prep_kwargs(value):
global _prep_kwargs
_prep_kwargs = value
def set_gen_kwargs(value):
global _gen_kwargs
_gen_kwargs = value
def set_schedulers(value):
global _schedulers
_schedulers = value
def get_sd_obj():
global _sd_obj
return _sd_obj
def get_llm_obj():
global _llm_obj
return _llm_obj
def get_device_list():
global _devices
return _devices
def get_init_config():
global _init_config
if os.path.exists(cmd_opts.defaults):
_init_config = cmd_opts.defaults
elif os.path.exists(os.path.join(get_configs_path(), cmd_opts.defaults)):
_init_config = os.path.join(get_configs_path(), cmd_opts.defaults)
else:
print(
"Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config."
)
_init_config = os.path.join(get_configs_path(), "sdxl-turbo.json")
return _init_config
def get_sd_status():
global _sd_obj
return _sd_obj.status
def get_pipe_kwargs():
global _pipe_kwargs
return _pipe_kwargs
def get_prep_kwargs():
global _prep_kwargs
return _prep_kwargs
def get_gen_kwargs():
global _gen_kwargs
return _gen_kwargs
def get_scheduler(key):
global _schedulers
return _schedulers[key]
def clear_cache():
global _sd_obj
global _llm_obj
global _pipe_kwargs
global _prep_kwargs
global _gen_kwargs
global _schedulers
del _sd_obj
del _llm_obj
del _schedulers
gc.collect()
_sd_obj = None
_llm_obj = None
_pipe_kwargs = None
_prep_kwargs = None
_gen_kwargs = None
_schedulers = None
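A rough usage sketch of this module (not part of the diff). The import path matches how the module is imported elsewhere in this changeset; the printed values are only examples.

```python
import apps.shark_studio.web.utils.globals as global_obj

global_obj._init()                   # resets cached objects and populates the device list
print(global_obj.get_device_list())  # e.g. ["cpu-task"]

init_config = global_obj.get_init_config()
print(global_obj.view_json_file(init_config))

# When switching models, drop the cached pipeline and collect garbage.
global_obj.clear_cache()
```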


@@ -0,0 +1,6 @@
from .png_metadata import (
import_png_metadata,
)
from .display import (
displayable_metadata,
)


@@ -0,0 +1,43 @@
import csv
import os
from .format import humanize, humanizable
def csv_path(image_filename: str):
return os.path.join(os.path.dirname(image_filename), "imgs_details.csv")
def has_csv(image_filename: str) -> bool:
return os.path.exists(csv_path(image_filename))
def matching_filename(image_filename: str, row):
# we assume the final column of the csv has the original filename with full path and match that
# against the image_filename if we are given a list. Otherwise we assume a dict and take
# the value of the OUTPUT key
return os.path.basename(image_filename) in (
row[-1] if isinstance(row, list) else row["OUTPUT"]
)
def parse_csv(image_filename: str):
csv_filename = csv_path(image_filename)
with open(csv_filename, "r", newline="") as csv_file:
# We use a reader or DictReader here for imgs_details.csv depending on whether we think it
# has headers or not. Having headers means less guessing of the format.
has_header = csv.Sniffer().has_header(csv_file.read(2048))
csv_file.seek(0)
reader = csv.DictReader(csv_file) if has_header else csv.reader(csv_file)
matches = [
# we rely on humanize and humanizable to work out the parsing of the individual .csv rows
humanize(row)
for row in reader
if row
and (has_header or humanizable(row))
and matching_filename(image_filename, row)
]
return matches[0] if matches else {}
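Illustrative usage of the CSV lookup above (not part of the diff; the module name `csv_metadata` matches the relative import used later in this changeset, and the image path is a made-up example):

```python
from csv_metadata import has_csv, parse_csv

image_path = "generated_imgs/a_photo_of_123456_240601_120000.png"  # hypothetical file
if has_csv(image_path):
    params = parse_csv(image_path)  # {} if the filename is not found in imgs_details.csv
    print(params.get("Prompt"), params.get("Seed"))
```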


@@ -0,0 +1,53 @@
import json
import os
from PIL import Image
from .png_metadata import parse_generation_parameters
from .exif_metadata import has_exif, parse_exif
from .csv_metadata import has_csv, parse_csv
from .format import compact, humanize
def displayable_metadata(image_filename: str) -> dict:
if not os.path.isfile(image_filename):
return {"source": "missing", "parameters": {}}
pil_image = Image.open(image_filename)
# we have PNG generation parameters (preferred, as it's what the txt2img dropzone reads,
# it's what SendTo goes through, and it's directly tied to the image)
if "parameters" in pil_image.info:
return {
"source": "png",
"parameters": compact(
parse_generation_parameters(pil_image.info["parameters"])
),
}
# we have a matching json file (next most likely to be accurate when it's there)
json_path = os.path.splitext(image_filename)[0] + ".json"
if os.path.isfile(json_path):
with open(json_path) as params_file:
return {
"source": "json",
"parameters": compact(
humanize(json.load(params_file), includes_filename=False)
),
}
# we have a CSV file so try that (it can take different shapes, and it usually has no
# headers/param names, so of the things we *know* have parameters, it's the
# last resort)
if has_csv(image_filename):
params = parse_csv(image_filename)
if params: # we might not have found the filename in the csv
return {
"source": "csv",
"parameters": compact(params), # already humanized
}
# EXIF data, probably a .jpeg, may well not include parameters, but at least it's *something*
if has_exif(image_filename):
return {"source": "exif", "parameters": parse_exif(pil_image)}
# we've got nothing
return None
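A small sketch of how `displayable_metadata` resolves its sources in priority order: PNG text chunk, then a sibling JSON file, then `imgs_details.csv`, then EXIF. It is not part of the diff; the package name `metadata` and the image path are assumptions.

```python
from metadata import displayable_metadata  # assumed package name for the files above

info = displayable_metadata("generated_imgs/example.png")  # hypothetical image
if info is None:
    print("no metadata found")
else:
    print(f"source: {info['source']}")  # "png", "json", "csv", or "exif"
    for key, value in info["parameters"].items():
        print(f"{key}: {value}")
```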


@@ -0,0 +1,52 @@
from PIL import Image
from PIL.ExifTags import Base as EXIFKeys, TAGS, IFD, GPSTAGS
def has_exif(image_filename: str) -> bool:
return bool(Image.open(image_filename).getexif())
def parse_exif(pil_image: Image) -> dict:
img_exif = pil_image.getexif()
# See this stackoverflow answer for where most of this comes from: https://stackoverflow.com/a/75357594
# I did try to use the exif library but it broke just as much as my initial attempt at this (albeit
# I was probably using it wrong), so I reverted to using PIL with more filtering and saved a
# dependency
exif_tags = {
TAGS.get(key, key): str(val)
for (key, val) in img_exif.items()
if key in TAGS
and key not in (EXIFKeys.ExifOffset, EXIFKeys.GPSInfo)
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
def try_get_ifd(ifd_id):
try:
return img_exif.get_ifd(ifd_id).items()
except KeyError:
return {}
ifd_tags = {
TAGS.get(key, key): str(val)
for ifd_id in IFD
for (key, val) in try_get_ifd(ifd_id)
if ifd_id != IFD.GPSInfo
and key in TAGS
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
gps_tags = {
GPSTAGS.get(key, key): str(val)
for (key, val) in try_get_ifd(IFD.GPSInfo)
if key in GPSTAGS
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
return {**exif_tags, **ifd_tags, **gps_tags}
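Illustrative only (not part of the diff; the module name `exif_metadata` matches the relative import used elsewhere in this changeset, and the image file is a made-up example):

```python
from PIL import Image

from exif_metadata import has_exif, parse_exif  # assumed standalone module name

if has_exif("photo.jpg"):  # hypothetical .jpg with EXIF data
    for tag, value in parse_exif(Image.open("photo.jpg")).items():
        print(f"{tag}: {value}")
```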


@@ -0,0 +1,139 @@
# As SHARK has evolved more columns have been added to images_details.csv. However, since
# no version of the CSV has any headers (yet) we don't actually have anything within the
# file that tells us which parameter each column is for. So this is a list of known patterns
# indexed by length which is what we're going to have to use to guess which columns are the
# right ones for the file we're looking at.
# The same ordering is used for JSON, which does have key names, but they are not very
# human friendly, nor do they match up with what is written to the .png headers.
# So these are functions that try to get something consistent out of the raw input from all
# of these sources.
PARAMS_FORMATS = {
9: {
"VARIANT": "Model",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"OUTPUT": "Filename",
},
10: {
"MODEL": "Model",
"VARIANT": "Variant",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"OUTPUT": "Filename",
},
12: {
"VARIANT": "Model",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"HEIGHT": "Height",
"WIDTH": "Width",
"MAX_LENGTH": "Max Length",
"OUTPUT": "Filename",
},
}
PARAMS_FORMAT_CURRENT = {
"VARIANT": "Model",
"VAE": "VAE",
"LORA": "LoRA",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"HEIGHT": "Height",
"WIDTH": "Width",
"MAX_LENGTH": "Max Length",
"OUTPUT": "Filename",
}
def compact(metadata: dict) -> dict:
# we don't want to alter the original dictionary
result = dict(metadata)
# discard the filename because we should already have it
if result.keys() & {"Filename"}:
result.pop("Filename")
# make showing the sizes more compact by using only one line each
if result.keys() & {"Size-1", "Size-2"}:
result["Size"] = f"{result.pop('Size-1')}x{result.pop('Size-2')}"
elif result.keys() & {"Height", "Width"}:
result["Size"] = f"{result.pop('Height')}x{result.pop('Width')}"
if result.keys() & {"Hires resize-1", "Hires resize-1"}:
hires_y = result.pop("Hires resize-1")
hires_x = result.pop("Hires resize-2")
if hires_x == 0 and hires_y == 0:
result["Hires resize"] = "None"
else:
result["Hires resize"] = f"{hires_y}x{hires_x}"
# remove VAE if it exists and is empty
if (result.keys() & {"VAE"}) and (not result["VAE"] or result["VAE"] == "None"):
result.pop("VAE")
# remove LoRA if it exists and is empty
if (result.keys() & {"LoRA"}) and (not result["LoRA"] or result["LoRA"] == "None"):
result.pop("LoRA")
return result
def humanizable(metadata: dict | list[str], includes_filename=True) -> bool:
lookup_key = len(metadata) + (0 if includes_filename else 1)
return lookup_key in PARAMS_FORMATS.keys()
def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
lookup_key = len(metadata) + (0 if includes_filename else 1)
# For lists we can only work based on the length, we have no other information
if isinstance(metadata, list):
if humanizable(metadata, includes_filename):
return dict(zip(PARAMS_FORMATS[lookup_key].values(), metadata))
else:
raise KeyError(
f"Humanize could not find the format for a parameter list of length {len(metadata)}"
)
# For dictionaries we try to use the matching length parameter format if
# available, otherwise we just use the current format which is assumed to
# have everything currently known about. Then we swap keys in the metadata
# that match keys in the format for the friendlier name that we have set
# in the format value
if isinstance(metadata, dict):
if humanizable(metadata, includes_filename):
format = PARAMS_FORMATS[lookup_key]
else:
format = PARAMS_FORMAT_CURRENT
return {
format[key]: metadata[key]
for key in format.keys()
if key in metadata.keys() and metadata[key]
}
raise TypeError("Can only humanize parameter lists or dictionaries")
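A worked example of the length-based guessing described in the comment at the top of this file. It is not part of the diff; the module name `format_utils` is an assumption (the file is imported as `.format` within its package), and the row values are made up.

```python
from format_utils import compact, humanizable, humanize  # assumed module name

row = [  # a 9-column imgs_details.csv row (the oldest known layout), values made up
    "stable_diffusion_2_1_base", "EulerDiscrete", "a lighthouse at dawn",
    "blurry, lowres", "1234", "7.5", "fp16", "50",
    "generated_imgs/a_lighthouse_1234.png",
]
assert humanizable(row)          # 9 columns -> a known format
params = compact(humanize(row))  # keys become "Model", "Sampler", "Prompt", ...
print(params["Model"], params["Seed"])  # the "Filename" entry is dropped by compact()
```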


@@ -0,0 +1,216 @@
import re
from pathlib import Path
from apps.shark_studio.web.utils.file_utils import (
get_checkpoint_pathfile,
)
from apps.shark_studio.api.sd import EMPTY_SD_MAP as sd_model_map
from apps.shark_studio.modules.schedulers import (
scheduler_model_map,
)
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
def parse_generation_parameters(x: str):
res = {}
prompt = ""
negative_prompt = ""
done_with_prompt = False
*lines, lastline = x.strip().split("\n")
if len(re_param.findall(lastline)) < 3:
lines.append(lastline)
lastline = ""
for i, line in enumerate(lines):
line = line.strip()
if line.startswith("Negative prompt:"):
done_with_prompt = True
line = line[16:].strip()
if done_with_prompt:
negative_prompt += ("" if negative_prompt == "" else "\n") + line
else:
prompt += ("" if prompt == "" else "\n") + line
res["Prompt"] = prompt
res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline):
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
m = re_imagesize.match(v)
if m is not None:
res[k + "-1"] = m.group(1)
res[k + "-2"] = m.group(2)
else:
res[k] = v
# Missing CLIP skip means it was set to 1 (the default)
if "Clip skip" not in res:
res["Clip skip"] = "1"
hypernet = res.get("Hypernet", None)
if hypernet is not None:
res[
"Prompt"
] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
if "Hires resize-1" not in res:
res["Hires resize-1"] = 0
res["Hires resize-2"] = 0
return res
def try_find_model_base_from_png_metadata(file: str, folder: str = "models") -> str:
custom = ""
# Remove extension from file info
if file.endswith(".safetensors") or file.endswith(".ckpt"):
file = Path(file).stem
# Check for the file name match with one of the local ckpt or safetensors files
if Path(get_checkpoint_pathfile(file + ".ckpt", folder)).is_file():
custom = file + ".ckpt"
if Path(get_checkpoint_pathfile(file + ".safetensors", folder)).is_file():
custom = file + ".safetensors"
return custom
def find_model_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
png_hf_id = ""
png_custom = ""
if key in metadata:
model_file = metadata[key]
png_custom = try_find_model_base_from_png_metadata(model_file)
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
if model_file in sd_model_map:
png_custom = model_file
# If nothing had matched, check vendor/hf_model_id
if not png_custom and model_file.count("/"):
png_hf_id = model_file
# No matching model was found
if not png_custom and not png_hf_id:
print(
"Import PNG info: Unable to find a matching model for %s" % model_file
)
return png_custom, png_hf_id
def find_vae_from_png_metadata(key: str, metadata: dict[str, str | int]) -> str:
vae_custom = ""
if key in metadata:
vae_file = metadata[key]
vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae")
# VAE input is optional, should not print or throw an error if missing
return vae_custom
def find_lora_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
lora_hf_id = ""
lora_custom = ""
if key in metadata:
lora_file = metadata[key]
lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
# If nothing had matched, check vendor/hf_model_id
if not lora_custom and lora_file.count("/"):
lora_hf_id = lora_file
# LoRA input is optional, should not print or throw an error if missing
return lora_custom, lora_hf_id
def import_png_metadata(
pil_data,
prompt,
negative_prompt,
steps,
sampler,
cfg_scale,
seed,
width,
height,
custom_model,
custom_lora,
hf_lora_id,
custom_vae,
):
try:
png_info = pil_data.info["parameters"]
metadata = parse_generation_parameters(png_info)
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
"Model", metadata
)
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
"LoRA", metadata
)
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
negative_prompt = metadata["Negative prompt"]
steps = int(metadata["Steps"])
cfg_scale = float(metadata["CFG scale"])
seed = int(metadata["Seed"])
width = float(metadata["Size-1"])
height = float(metadata["Size-2"])
if "Model" in metadata and png_custom_model:
custom_model = png_custom_model
elif "Model" in metadata and png_hf_model_id:
custom_model = png_hf_model_id
if "LoRA" in metadata and lora_custom_model:
custom_lora = lora_custom_model
hf_lora_id = ""
if "LoRA" in metadata and lora_hf_model_id:
custom_lora = "None"
hf_lora_id = lora_hf_model_id
if "VAE" in metadata and vae_custom_model:
custom_vae = vae_custom_model
if "Prompt" in metadata:
prompt = metadata["Prompt"]
if "Sampler" in metadata:
if metadata["Sampler"] in scheduler_model_map:
sampler = metadata["Sampler"]
else:
print(
"Import PNG info: Unable to find a scheduler for %s"
% metadata["Sampler"]
)
except Exception as ex:
if pil_data and pil_data.info.get("parameters"):
print("import_png_metadata failed with %s" % ex)
return (
None,
prompt,
negative_prompt,
steps,
sampler,
cfg_scale,
seed,
width,
height,
custom_model,
custom_lora,
hf_lora_id,
custom_vae,
)
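A minimal sketch of the text format `parse_generation_parameters` expects, i.e. the `parameters` PNG text chunk written by txt2img. It is not part of the diff; the module name is an assumption, the prompt values are made up, and the module's own `apps.*` imports must be resolvable for the import to succeed.

```python
from png_metadata import parse_generation_parameters  # assumed standalone module name

png_info = (
    "a lighthouse at dawn, watercolor\n"
    "Negative prompt: blurry, lowres\n"
    "Steps: 20, Sampler: EulerDiscrete, CFG scale: 7.5, Seed: 1234, "
    "Size: 512x512, Model: stabilityai/stable-diffusion-2-1-base"
)
res = parse_generation_parameters(png_info)
print(res["Prompt"])                  # "a lighthouse at dawn, watercolor"
print(res["Seed"], res["CFG scale"])  # "1234" "7.5"
print(res["Size-1"], res["Size-2"])   # "512" "512" (split from "Size: 512x512")
```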


@@ -0,0 +1,39 @@
import apps.shark_studio.web.utils.globals as global_obj
import gc
def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1):
if batch_index < batch_count:
bs = f"x{batch_size}" if batch_size > 1 else ""
return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}"
else:
return f"{tab_name} complete"
def get_generation_text_info(seeds, device):
cfg_dump = {}
for cfg, value in global_obj.get_config_dict().items():
cfg_dump[cfg] = value
text_output = f"prompt={cfg_dump['prompts']}"
text_output += f"\nnegative prompt={cfg_dump['negative_prompts']}"
text_output += (
f"\nmodel_id={cfg_dump['hf_model_id']}, " f"ckpt_loc={cfg_dump['ckpt_loc']}"
)
text_output += f"\nscheduler={cfg_dump['scheduler']}, " f"device={device}"
text_output += (
f"\nsteps={cfg_dump['steps']}, "
f"guidance_scale={cfg_dump['guidance_scale']}, "
f"seed={seeds}"
)
text_output += (
f"\nsize={cfg_dump['height']}x{cfg_dump['width']}, "
if not cfg_dump["use_hiresfix"]
else f"\nsize={cfg_dump['hiresfix_height']}x{cfg_dump['hiresfix_width']}, "
)
text_output += (
f"batch_count={cfg_dump['batch_count']}, "
f"batch_size={cfg_dump['batch_size']}, "
f"max_length={cfg_dump['max_length']}"
)
return text_output


@@ -0,0 +1,75 @@
import os
import shutil
from time import time
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
shark_tmp = cmd_opts.tmp_dir # os.path.join(os.getcwd(), "shark_tmp/")
def clear_tmp_mlir():
cleanup_start = time()
print("Clearing .mlir temporary files from a prior run. This may take some time...")
mlir_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.endswith(".mlir")
]
for filename in mlir_files:
os.remove(os.path.join(shark_tmp, filename))
print(f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds.")
def clear_tmp_imgs():
# tell gradio to use a directory under shark_tmp for its temporary
# image files unless somewhere else has been set
if "GRADIO_TEMP_DIR" not in os.environ:
os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
print(
f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
+ "You may change this by setting the GRADIO_TEMP_DIR environment variable."
)
# Clear all gradio tmp images from the last session
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
cleanup_start = time()
print(
"Clearing gradio UI temporary image files from a prior run. This may take some time..."
)
shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
print(
f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
)
# older SHARK versions had to work around gradio bugs and stored things differently
else:
image_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.startswith("tmp")
and filename.endswith(".png")
]
if len(image_files) > 0:
print(
"Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
)
cleanup_start = time()
for filename in image_files:
os.remove(os.path.join(shark_tmp, filename))
print(
f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
)
else:
print("No temporary images files to clear.")
def config_tmp():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
clear_tmp_mlir()
clear_tmp_imgs()


@@ -1,87 +0,0 @@
Compile / Run Instructions:
To compile a .vmfb for SD (VAE, UNet, CLIP), run the following commands with the .mlir files in your local shark_tank cache (the default location for Linux users is `~/.local/shark_tank`). These are generated the first time the script from [this README](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md) is run.
Running the script mentioned above with the `--save_vmfb` flag will also save the .vmfb in your SHARK base directory if you want to skip straight to benchmarks.
Compile Commands FP32/FP16:
```shell
Vulkan AMD:
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
# use iree-input-type=mhlo for tf models
CUDA NVIDIA:
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
CPU:
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
```
Run / Benchmark Command (FP32 - NCHW):
(Note: a batch size of 2 is needed for the text-embedding input because classifier-free guidance performs two forward passes through the UNet.)
```shell
## Vulkan AMD:
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
## CUDA:
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=cuda --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
## CPU:
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=local-task --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
```
Run via vulkan_gui for RGP Profiling:
To build the vulkan app for profiling UNet follow the instructions [here](https://github.com/nod-ai/SHARK/tree/main/cpp) and then run the following command from the cpp directory with your compiled stable_diff.vmfb
```shell
./build/vulkan_gui/iree-vulkan-gui --module=/path/to/unet.vmfb --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
```
</details>
<details>
<summary>Debug Commands</summary>
## Debug commands and other advanced usage follow.
```shell
python txt2img.py --precision="fp32"|"fp16" --device="cpu"|"cuda"|"vulkan" --import_mlir|--no-import_mlir --prompt "enter the text"
```
## dump all dispatch .spv and isa using amdllpc
```shell
python txt2img.py --precision="fp16" --device="vulkan" --iree-vulkan-target-triple=rdna3-unknown-linux --no-load_vmfb --dispatch_benchmarks="all" --dispatch_benchmarks_dir="SD_dispatches" --dump_isa
```
## Compile and save the .vmfb (using vulkan fp16 as an example):
```shell
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb
```
## Capture an RGP trace
```shell
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb --enable_rgp
```
## Run the vae module with iree-benchmark-module (NCHW, fp16, vulkan, for example):
```shell
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf16
```
## Run the unet module with iree-benchmark-module (same config as above):
```shell
##if you want to use .npz inputs:
unzip ~/.local/shark_tank/<your unet>/inputs.npz
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --input=@arr_0.npy --input=1xf16 --input=@arr_2.npy --input=@arr_3.npy --input=@arr_4.npy
```
</details>


@@ -1 +0,0 @@
from apps.stable_diffusion.scripts.txt2img import txt2img_inf


@@ -1,240 +0,0 @@
import logging
import os
from models.stable_diffusion.main import stable_diff_inf
from models.stable_diffusion.utils import get_available_devices
from dotenv import load_dotenv
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram import BotCommand
from telegram.ext import Application, ApplicationBuilder, CallbackQueryHandler
from telegram.ext import ContextTypes, MessageHandler, CommandHandler, filters
from io import BytesIO
import random
log = logging.getLogger("TG.Bot")
logging.basicConfig()
log.warning("Start")
load_dotenv()
os.environ["AMD_ENABLE_LLPC"] = "0"
TG_TOKEN = os.getenv("TG_TOKEN")
SELECTED_MODEL = "stablediffusion"
SELECTED_SCHEDULER = "EulerAncestralDiscrete"
STEPS = 30
NEGATIVE_PROMPT = (
"Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra"
" limbs,Gross proportions,Missing arms,Mutated hands,Long"
" neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad"
" anatomy,Cloned face,Malformed limbs,Missing legs,Too many"
" fingers,blurry, lowres, text, error, cropped, worst quality, low"
" quality, jpeg artifacts, out of frame, extra fingers, mutated hands,"
" poorly drawn hands, poorly drawn face, bad anatomy, extra limbs, cloned"
" face, malformed limbs, missing arms, missing legs, extra arms, extra"
" legs, fused fingers, too many fingers"
)
GUIDANCE_SCALE = 6
available_devices = get_available_devices()
models_list = [
"stablediffusion",
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]
sheds_list = [
"DDIM",
"PNDM",
"LMSDiscrete",
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
"SharkEulerDiscrete",
]
def image_to_bytes(image):
bio = BytesIO()
bio.name = "image.jpeg"
image.save(bio, "JPEG")
bio.seek(0)
return bio
def get_try_again_markup():
keyboard = [[InlineKeyboardButton("Try again", callback_data="TRYAGAIN")]]
reply_markup = InlineKeyboardMarkup(keyboard)
return reply_markup
def generate_image(prompt):
seed = random.randint(1, 10000)
log.warning(SELECTED_MODEL)
log.warning(STEPS)
image, text = stable_diff_inf(
prompt=prompt,
negative_prompt=NEGATIVE_PROMPT,
steps=STEPS,
guidance_scale=GUIDANCE_SCALE,
seed=seed,
scheduler_key=SELECTED_SCHEDULER,
variant=SELECTED_MODEL,
device_key=available_devices[0],
)
return image, seed
async def generate_and_send_photo(
update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
progress_msg = await update.message.reply_text(
"Generating image...", reply_to_message_id=update.message.message_id
)
im, seed = generate_image(prompt=update.message.text)
await context.bot.delete_message(
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
)
await context.bot.send_photo(
update.effective_user.id,
image_to_bytes(im),
caption=f'"{update.message.text}" (Seed: {seed})',
reply_markup=get_try_again_markup(),
reply_to_message_id=update.message.message_id,
)
async def button(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
query = update.callback_query
if query.data in models_list:
global SELECTED_MODEL
SELECTED_MODEL = query.data
await query.answer()
await query.edit_message_text(text=f"Selected model: {query.data}")
return
if query.data in sheds_list:
global SELECTED_SCHEDULER
SELECTED_SCHEDULER = query.data
await query.answer()
await query.edit_message_text(text=f"Selected scheduler: {query.data}")
return
replied_message = query.message.reply_to_message
await query.answer()
progress_msg = await query.message.reply_text(
"Generating image...", reply_to_message_id=replied_message.message_id
)
if query.data == "TRYAGAIN":
prompt = replied_message.text
im, seed = generate_image(prompt)
await context.bot.delete_message(
chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
)
await context.bot.send_photo(
update.effective_user.id,
image_to_bytes(im),
caption=f'"{prompt}" (Seed: {seed})',
reply_markup=get_try_again_markup(),
reply_to_message_id=replied_message.message_id,
)
async def select_model_handler(update, context):
text = "Select model"
keyboard = []
for model in models_list:
keyboard.append(
[
InlineKeyboardButton(text=model, callback_data=model),
]
)
markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(text=text, reply_markup=markup)
async def select_scheduler_handler(update, context):
text = "Select schedule"
keyboard = []
for shed in sheds_list:
keyboard.append(
[
InlineKeyboardButton(text=shed, callback_data=shed),
]
)
markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(text=text, reply_markup=markup)
async def set_steps_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_steps ")[1]
global STEPS
STEPS = int(input_args)
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_steps 30"
)
await update.message.reply_text(input_args)
async def set_negative_prompt_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_negative_prompt ")[1]
global NEGATIVE_PROMPT
NEGATIVE_PROMPT = input_args
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_negative_prompt ugly, bad art, mutated"
)
await update.message.reply_text(input_args)
async def set_guidance_scale_handler(update, context):
input_mex = update.message.text
log.warning(input_mex)
try:
input_args = input_mex.split("/set_guidance_scale ")[1]
global GUIDANCE_SCALE
GUIDANCE_SCALE = int(input_args)
except Exception:
input_args = (
"Invalid parameter for command. Correct command looks like\n"
" /set_guidance_scale 7"
)
await update.message.reply_text(input_args)
async def setup_bot_commands(application: Application) -> None:
await application.bot.set_my_commands(
[
BotCommand("select_model", "to select model"),
BotCommand("select_scheduler", "to select scheduler"),
BotCommand("set_steps", "to set steps"),
BotCommand("set_guidance_scale", "to set guidance scale"),
BotCommand("set_negative_prompt", "to set negative prompt"),
]
)
app = (
ApplicationBuilder().token(TG_TOKEN).post_init(setup_bot_commands).build()
)
app.add_handler(CommandHandler("select_model", select_model_handler))
app.add_handler(CommandHandler("select_scheduler", select_scheduler_handler))
app.add_handler(CommandHandler("set_steps", set_steps_handler))
app.add_handler(
CommandHandler("set_guidance_scale", set_guidance_scale_handler)
)
app.add_handler(
CommandHandler("set_negative_prompt", set_negative_prompt_handler)
)
app.add_handler(
MessageHandler(filters.TEXT & ~filters.COMMAND, generate_and_send_photo)
)
app.add_handler(CallbackQueryHandler(button))
log.warning("Start bot")
app.run_polling()


@@ -1,331 +0,0 @@
import os
if "AMD_ENABLE_LLPC" not in os.environ:
os.environ["AMD_ENABLE_LLPC"] = "1"
import sys
import json
import torch
import re
import time
from pathlib import Path
from PIL import PngImagePlugin
from datetime import datetime as dt
from dataclasses import dataclass
from csv import DictWriter
from apps.stable_diffusion.src import (
args,
Text2ImagePipeline,
get_schedulers,
set_init_device_flags,
utils,
)
@dataclass
class Config:
model_id: str
ckpt_loc: str
precision: str
batch_size: int
max_length: int
height: int
width: int
device: str
# This has to come before importing cache objects
if args.clear_all:
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
from glob import glob
import shutil
vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
# Temporary workaround: delete yaml files to incorporate diffusers' pipeline.
# TODO: Remove this once we have better weight update logic.
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
for yaml in inference_yaml:
if os.path.exists(yaml):
os.remove(yaml)
home = os.path.expanduser("~")
if os.name == "nt": # Windows
appdata = os.getenv("LOCALAPPDATA")
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
elif os.name == "posix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed):
output_path = args.output_dir if args.output_dir else Path.cwd()
generated_imgs_path = Path(output_path, "generated_imgs")
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
out_img_name = (
f"{prompt_slice}_{img_seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
)
img_model = args.hf_model_id
if args.ckpt_loc:
img_model = os.path.basename(args.ckpt_loc)
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
else:
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
pngInfo = PngImagePlugin.PngInfo()
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if args.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not supported yet."
"Image saved as png instead. Supported formats: png / jpg"
)
new_entry = {
"VARIANT": img_model,
"SCHEDULER": args.scheduler,
"PROMPT": args.prompts[0],
"NEG_PROMPT": args.negative_prompts[0],
"SEED": img_seed,
"CFG_SCALE": args.guidance_scale,
"PRECISION": args.precision,
"STEPS": args.steps,
"HEIGHT": args.height,
"WIDTH": args.width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
}
with open(csv_path, "a") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
dictwriter_obj.writerow(new_entry)
if args.save_metadata_to_json:
del new_entry["OUTPUT"]
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
txt2img_obj = None
config_obj = None
schedulers = None
# Exposed to UI.
def txt2img_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
):
global txt2img_obj
global config_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = custom_model
else:
args.hf_model_id = custom_model
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
)
if config_obj != new_config_obj:
config_obj = new_config_obj
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.use_tuned = True
args.import_mlir = False
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
)
if not txt2img_obj:
sys.exit("text to image pipeline must not return a null value")
txt2img_obj.scheduler = schedulers[scheduler]
start_time = time.time()
txt2img_obj.log = ""
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = txt2img_obj.generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
txt2img_obj.log += "\n"
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
return generated_imgs, text_output
if __name__ == "__main__":
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
)
for run in range(args.runs):
if run > 0:
seed = -1
seed = utils.sanitize_seed(seed)
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
args.negative_prompts,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
# TODO: when using --runs=x, txt2img_obj.log accumulates, so each run re-displays the info from all iterations since the start
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0], seed)
print(text_output)


@@ -1,79 +0,0 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torchvision')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('diffusers')
datas += copy_metadata('transformers')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('gradio')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
( 'src/utils/resources/opt_flags.json', 'resources' ),
( 'src/utils/resources/base_model.json', 'resources' ),
( 'web/css/*', 'css' ),
( 'web/logos/*', 'logos' )
]
binaries = []
block_cipher = None
a = Analysis(
['web/index.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio', 'apps'],
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_sd',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)


@@ -1,77 +0,0 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torchvision')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('diffusers')
datas += copy_metadata('transformers')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += collect_data_files('gradio')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
( 'src/utils/resources/opt_flags.json', 'resources' ),
( 'src/utils/resources/base_model.json', 'resources' ),
]
binaries = []
block_cipher = None
a = Analysis(
['scripts/txt2img.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=['shark', 'shark.*', 'shark.shark_inference', 'shark_inference', 'iree.tools.core', 'gradio', 'apps'],
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_sd_cli',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)


@@ -1,8 +0,0 @@
from apps.stable_diffusion.src.utils import (
args,
set_init_device_flags,
prompt_examples,
get_available_devices,
)
from apps.stable_diffusion.src.pipelines import Text2ImagePipeline
from apps.stable_diffusion.src.schedulers import get_schedulers


@@ -1,11 +0,0 @@
from apps.stable_diffusion.src.models.model_wrappers import (
SharkifyStableDiffusionModel,
)
from apps.stable_diffusion.src.models.opt_params import (
get_vae,
get_unet,
get_clip,
get_tokenizer,
get_params,
get_variant_version,
)


@@ -1,295 +0,0 @@
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from collections import defaultdict
import torch
import traceback
import re
import sys
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_opt_flags,
base_models,
args,
fetch_or_delete_vmfbs,
preprocessCKPT,
get_path_to_diffusers_checkpoint,
fetch_and_update_base_model_id,
)
# These shapes are parameter dependent.
def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape = []
for i in range(len(shape)):
if shape[i] == "max_len":
new_shape.append(max_len)
elif shape[i] == "height":
new_shape.append(height)
elif shape[i] == "width":
new_shape.append(width)
elif isinstance(shape[i], str):
if "batch_size" in shape[i]:
mul_val = int(shape[i].split("*")[0])
new_shape.append(batch_size * mul_val)
else:
new_shape.append(shape[i])
return new_shape
# Get the input info for various models i.e. "unet", "clip", "vae".
def get_input_info(model_info, max_len, width, height, batch_size):
dtype_config = {"f32": torch.float32, "i64": torch.int64}
input_map = defaultdict(list)
for k in model_info:
for inp in model_info[k]:
shape = model_info[k][inp]["shape"]
dtype = dtype_config[model_info[k][inp]["dtype"]]
tensor = None
if isinstance(shape, list):
clean_shape = replace_shape_str(
shape, max_len, width, height, batch_size
)
if dtype == torch.int64:
tensor = torch.randint(1, 3, tuple(clean_shape))
else:
tensor = torch.randn(*clean_shape).to(dtype)
elif isinstance(shape, int):
tensor = torch.tensor(shape).to(dtype)
else:
sys.exit("shape isn't specified correctly.")
input_map[k].append(tensor)
return input_map
class SharkifyStableDiffusionModel:
def __init__(
self,
model_id: str,
custom_weights: str,
precision: str,
max_len: int = 64,
width: int = 512,
height: int = 512,
batch_size: int = 1,
use_base_vae: bool = False,
use_tuned: bool = False,
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
if custom_weights != "":
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
self.model_id = model_id if custom_weights == "" else custom_weights
self.precision = precision
self.base_vae = use_base_vae
self.model_name = (
str(batch_size)
+ "_"
+ str(max_len)
+ "_"
+ str(height)
+ "_"
+ str(width)
+ "_"
+ precision
)
self.use_tuned = use_tuned
if use_tuned:
self.model_name = self.model_name + "_tuned"
# We need a better naming convention for the .vmfbs because despite
# using the custom model variant the .vmfb names remain the same and
# it'll always pick up the compiled .vmfb instead of compiling the
# custom model.
# So, currently, we add `self.model_id` in the `self.model_name` of
# .vmfb file.
# TODO: Have a better way of naming the vmfbs using self.model_name.
model_name = re.sub(r"\W+", "_", self.model_id)
if model_name[0] == "_":
model_name = model_name[1:]
self.model_name = self.model_name + "_" + model_name
def check_params(self, max_len, width, height):
if not (max_len >= 32 and max_len <= 77):
sys.exit("please specify max_len in the range [32, 77].")
if not (width % 8 == 0 and width >= 384):
sys.exit("width should be greater than 384 and multiple of 8")
if not (height % 8 == 0 and height >= 384):
sys.exit("height should be greater than 384 and multiple of 8")
def get_vae(self):
class VaeModel(torch.nn.Module):
def __init__(self, model_id=self.model_id, base_vae=self.base_vae):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
)
self.base_vae = base_vae
def forward(self, input):
if not self.base_vae:
input = 1 / 0.18215 * input
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
if self.base_vae:
return x
x = x * 255.0
return x.round()
vae = VaeModel()
inputs = tuple(self.inputs["vae"])
is_f16 = True if self.precision == "fp16" else False
vae_name = "base_vae" if self.base_vae else "vae"
shark_vae = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
model_name=vae_name + self.model_name,
extra_args=get_opt_flags("vae", precision=self.precision),
)
return shark_vae
def get_unet(self):
class UnetModel(torch.nn.Module):
def __init__(self, model_id=self.model_id):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
)
self.in_channels = self.unet.in_channels
self.train(False)
def forward(
self, latent, timestep, text_embedding, guidance_scale
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents, timestep, text_embedding, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel()
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
input_mask = [True, True, True, False]
shark_unet = compile_through_fx(
unet,
inputs,
model_name="unet" + self.model_name,
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
)
return shark_unet
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(self, model_id=self.model_id):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
)
def forward(self, input):
return self.text_encoder(input)[0]
clip_model = CLIPText()
shark_clip = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
model_name="clip" + self.model_name,
extra_args=get_opt_flags("clip", precision="fp32"),
)
return shark_clip
# Compiles CLIP, UNet and VAE with `base_model_id` defining their input
# configuration.
def compile_all(self, base_model_id):
self.inputs = get_input_info(
base_models[base_model_id],
self.max_len,
self.width,
self.height,
self.batch_size,
)
compiled_unet = self.get_unet()
compiled_vae = self.get_vae()
compiled_clip = self.get_clip()
return compiled_clip, compiled_unet, compiled_vae
def __call__(self):
# Step 1:
# -- Fetch all vmfbs for the model, if present, else delete the lot.
vmfbs = fetch_or_delete_vmfbs(
self.model_name, self.base_vae, self.precision
)
if vmfbs[0]:
# -- If all vmfbs are indeed present, we also try and fetch the base
# model configuration for running SD with custom checkpoints.
if self.custom_weights != "":
args.hf_model_id = fetch_and_update_base_model_id(self.custom_weights)
if args.hf_model_id == "":
sys.exit("Base model configuration for the custom model is missing. Use `--clear_all` and re-run.")
print("Loaded vmfbs from cache and successfully fetched base model configuration.")
return vmfbs
# Step 2:
# -- If vmfbs weren't found, we try to see if the base model configuration
# for the required SD run is known to us and bypass the retry mechanism.
model_to_run = ""
if self.custom_weights != "":
model_to_run = self.custom_weights
assert self.custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
preprocessCKPT(self.custom_weights)
else:
model_to_run = args.hf_model_id
base_model_fetched = fetch_and_update_base_model_id(model_to_run)
if base_model_fetched != "":
print("Compiling all the models with the fetched base model configuration.")
if args.ckpt_loc != "":
args.hf_model_id = base_model_fetched
return self.compile_all(base_model_fetched)
# Step 3:
# -- This is the retry mechanism where the base model's configuration is not
# known to us, so we figure it out by trial and error.
print("Inferring base model configuration.")
for model_id in base_models:
try:
compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id)
except Exception as e:
if args.enable_stack_trace:
traceback.print_exc()
print("Retrying with a different base model configuration")
continue
# -- Once a successful compilation has taken place, we want to store
# the inferred base model configuration.
fetch_and_update_base_model_id(model_to_run, model_id)
# This is done because in main.py we base the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since we no longer maintain a 1:1 mapping of variants to the base
# model and instead rely on the retry mechanism to find the input configuration, we also
# update `args.hf_model_id` with the inferred base model id.
if args.ckpt_loc != "":
args.hf_model_id = model_id
return compiled_clip, compiled_unet, compiled_vae
sys.exit(
"Cannot compile the model. Please re-run the command with `--enable_stack_trace` flag and create an issue with detailed log at https://github.com/nod-ai/SHARK/issues"
)
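As an aside on get_input_info above: the shape entries in base_model.json (reproduced later in this diff) mix integers with placeholder strings such as "max_len", "height", "width" and "1*batch_size"/"2*batch_size", and the height/width passed in are already divided by 8 (latent space). A minimal standalone sketch of that substitution, using a hypothetical spec for illustration:

import torch

# Hypothetical spec in the style of base_model.json; placeholders are resolved
# against the run's max_len / latent height / latent width / batch_size.
spec = ["2*batch_size", "max_len", 1024]

def resolve(shape, max_len, width, height, batch_size):
    new_shape = []
    for dim in shape:
        if dim == "max_len":
            new_shape.append(max_len)
        elif dim == "height":
            new_shape.append(height)
        elif dim == "width":
            new_shape.append(width)
        elif isinstance(dim, str) and "batch_size" in dim:
            new_shape.append(int(dim.split("*")[0]) * batch_size)
        else:
            new_shape.append(dim)
    return new_shape

# A 512x512 run with max_len=64 and batch_size=1 -> latents are 64x64.
resolved = resolve(spec, max_len=64, width=512 // 8, height=512 // 8, batch_size=1)
print(resolved)                      # [2, 64, 1024]
print(torch.randn(*resolved).shape)  # torch.Size([2, 64, 1024])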

View File

@@ -1,89 +0,0 @@
import sys
from transformers import CLIPTokenizer
from apps.stable_diffusion.src.utils import (
models_db,
args,
get_shark_model,
get_opt_flags,
)
hf_model_variant_map = {
"Linaqruf/anything-v3.0": ["anythingv3", "v2_1base"],
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v2_1base"],
"prompthero/openjourney": ["openjourney", "v2_1base"],
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v2_1base"],
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1base"],
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
}
def get_variant_version(hf_model_id):
return hf_model_variant_map[hf_model_id]
def get_params(bucket_key, model_key, model, is_tuned, precision):
try:
bucket = models_db[0][bucket_key]
model_name = models_db[1][model_key]
except KeyError:
raise Exception(
f"{bucket_key}/{model_key} is not present in the models database"
)
iree_flags = get_opt_flags(model, precision="fp16")
return bucket, model_name, iree_flags
def get_unet():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "unet", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_vae():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
is_base = "/base" if args.use_base_vae else ""
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}"
else:
bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "vae", is_tuned, args.precision
)
return get_shark_model(bucket, model_name, iree_flags)
def get_clip():
variant, version = get_variant_version(args.hf_model_id)
bucket_key = f"{variant}/untuned"
model_key = (
f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned"
)
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, "clip", "untuned", "fp32"
)
return get_shark_model(bucket, model_name, iree_flags)
def get_tokenizer():
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id, subfolder="tokenizer"
)
return tokenizer
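The bucket_key/model_key strings built above index the two dictionaries of model_db.json (reproduced later in this diff). A minimal standalone sketch of that lookup, against a trimmed, hypothetical copy of the database:

# Trimmed copy of the two model_db.json dictionaries shown further down.
models_db = [
    {"stablediffusion/untuned": "gs://shark_tank/sd_untuned"},
    {
        "stablediffusion/v2_1base/unet/fp16/length_64/untuned":
            "unet64_512_512_fp16_stabilityai_stable_diffusion_2_1_base"
    },
]

# Values a run with the default hf_model_id, fp16 and max_length=64 would use.
variant, version = "stablediffusion", "v2_1base"
precision, max_length, is_tuned = "fp16", 64, "untuned"

bucket_key = f"{variant}/{is_tuned}"
model_key = f"{variant}/{version}/unet/{precision}/length_{max_length}/{is_tuned}"

bucket = models_db[0][bucket_key]
model_name = models_db[1][model_key]
print(bucket)      # gs://shark_tank/sd_untuned
print(model_name)  # unet64_512_512_fp16_stabilityai_stable_diffusion_2_1_base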

View File

@@ -1,3 +0,0 @@
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
Text2ImagePipeline,
)

View File

@@ -1,135 +0,0 @@
import torch
from tqdm.auto import tqdm
import numpy as np
from random import randint
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
class Text2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def generate_images(
self,
prompts,
neg_prompts,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the initial latent noise. Also handle out of range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Get Image latents
latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
)
# Img latents -> PIL images
all_imgs = []
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
return all_imgs
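Two details in generate_images above are easy to miss: seeds outside the uint32 range are silently re-rolled, and the initial latents are shaped at height//8 x width//8 before being scaled by the scheduler's init_noise_sigma. A minimal torch/numpy sketch of just those two steps (the init_noise_sigma scaling is omitted since it needs a scheduler instance):

import numpy as np
import torch
from random import randint

def sanitize_seed(seed):
    # Mirror of the out-of-range handling in generate_images above.
    info = np.iinfo(np.uint32)
    if seed < info.min or seed >= info.max:
        seed = randint(info.min, info.max)
    return seed

seed = sanitize_seed(-1)  # a negative seed gets replaced by a random uint32
generator = torch.manual_seed(seed)

batch_size, height, width = 1, 512, 512
latents = torch.randn(
    (batch_size, 4, height // 8, width // 8),
    generator=generator,
    dtype=torch.float32,
)
print(seed, latents.shape)  # <uint32 seed> torch.Size([1, 4, 64, 64])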

View File

@@ -1,206 +0,0 @@
import torch
from transformers import CLIPTokenizer
from PIL import Image
from tqdm.auto import tqdm
import time
from typing import Union
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae,
get_clip,
get_unet,
get_tokenizer,
)
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
)
class StableDiffusionPipeline:
def __init__(
self,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
):
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.unet = unet
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
text_input = self.tokenizer(
prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
# Get unconditional embeddings as well
uncond_input = self.tokenizer(
neg_prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
clip_inf_start = time.time()
text_embeddings = self.text_encoder("forward", (text_input,))
clip_inf_time = (time.time() - clip_inf_start) * 1000
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
if use_base_vae:
latents = 1 / 0.18215 * latents
latents_numpy = latents
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
images = self.vae("forward", (latents_numpy,))
vae_inf_time = (time.time() - vae_start) * 1000
end_profiling(profile_device)
self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}"
if use_base_vae:
images = torch.from_numpy(images)
images = (images.detach().cpu() * 255.0).numpy()
images = images.round()
images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
pil_images = [Image.fromarray(image) for image in images.numpy()]
return pil_images
def produce_img_latents(
self,
latents,
text_embeddings,
guidance_scale,
total_timesteps,
dtype,
cpu_scheduling,
return_all_latents=False,
):
step_time_sum = 0
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
latent_model_input = self.scheduler.scale_model_input(latents, t)
if cpu_scheduling:
latent_model_input = latent_model_input.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
latents = self.scheduler.step(
noise_pred, t, latents
).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents)
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
if not return_all_latents:
return latents
all_latents = torch.cat(latent_history, dim=0)
return all_latents
@classmethod
def from_pretrained(
cls,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
import_mlir: bool,
model_id: str,
ckpt_loc: str,
precision: str,
max_length: int,
batch_size: int,
height: int,
width: int,
use_base_vae: bool,
use_tuned: bool,
):
if import_mlir:
# TODO: Delete this when on-the-fly tuning of models works.
use_tuned = False
mlir_import = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
precision,
max_len=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
)
clip, unet, vae = mlir_import()
return cls(vae, clip, get_tokenizer(), unet, scheduler)
return cls(
get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler
)
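Note how encode_prompts concatenates the unconditional and conditional token ids into one doubled batch; the compiled UNet (see UnetModel.forward in the model wrapper earlier in this diff) then splits its output and applies classifier-free guidance. A minimal torch-only sketch of that combination, on dummy tensors:

import torch

batch_size, channels, h, w = 1, 4, 64, 64
guidance_scale = torch.tensor(7.5, dtype=torch.float32)

# Pretend UNet output for the doubled batch: [unconditional, conditional] along dim 0.
unet_out = torch.randn(2 * batch_size, channels, h, w)

noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 4, 64, 64])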

View File

@@ -1,4 +0,0 @@
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)

View File

@@ -1,51 +0,0 @@
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
def get_schedulers(model_id):
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"DPMSolverMultistep"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"EulerAncestralDiscrete"
] = EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"SharkEulerDiscrete"
] = SharkEulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["SharkEulerDiscrete"].compile()
return schedulers
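A rough usage sketch, assuming a checkout of this repo (and network access for the from_pretrained downloads): callers pick one entry from the returned dict by name, with args.scheduler defaulting to SharkEulerDiscrete in stable_args.py further down.

# Hypothetical driver snippet; get_schedulers and args come from this repo.
from apps.stable_diffusion.src.schedulers import get_schedulers
from apps.stable_diffusion.src.utils import args

schedulers = get_schedulers("stabilityai/stable-diffusion-2-1-base")
scheduler = schedulers[args.scheduler]  # e.g. "SharkEulerDiscrete"
print(type(scheduler).__name__)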

View File

@@ -1,143 +0,0 @@
import sys
import numpy as np
from typing import List, Optional, Tuple, Union
from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
)
from diffusers.configuration_utils import register_to_config
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_shark_model,
args,
)
import torch
class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
):
super().__init__(
num_train_timesteps,
beta_start,
beta_end,
beta_schedule,
trained_betas,
prediction_type,
)
def compile(self):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
BATCH_SIZE = args.batch_size
model_input = {
"euler": {
"latent": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"output": torch.randn(
BATCH_SIZE, 4, args.height // 8, args.width // 8
),
"sigma": torch.tensor(1).to(torch.float32),
"dt": torch.tensor(1).to(torch.float32),
},
}
example_latent = model_input["euler"]["latent"]
example_output = model_input["euler"]["output"]
if args.precision == "fp16":
example_latent = example_latent.half()
example_output = example_output.half()
example_sigma = model_input["euler"]["sigma"]
example_dt = model_input["euler"]["dt"]
class ScalingModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, latent, sigma):
return latent / ((sigma**2 + 1) ** 0.5)
class SchedulerStepModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, noise_pred, sigma, latent, dt):
pred_original_sample = latent - sigma * noise_pred
derivative = (latent - pred_original_sample) / sigma
return latent + derivative * dt
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if args.import_mlir:
scaling_model = ScalingModel()
self.scaling_model = compile_through_fx(
scaling_model,
(example_latent, example_sigma),
model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
step_model = SchedulerStepModel()
self.step_model = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
else:
self.scaling_model = get_shark_model(
SCHEDULER_BUCKET,
"euler_scale_model_input_" + args.precision,
iree_flags,
)
self.step_model = get_shark_model(
SCHEDULER_BUCKET, "euler_step_" + args.precision, iree_flags
)
def scale_model_input(self, sample, timestep):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
return self.scaling_model(
"forward",
(
sample,
sigma,
),
send_to_host=False,
)
def step(self, noise_pred, timestep, latent):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
dt = self.sigmas[step_index + 1] - sigma
return self.step_model(
"forward",
(
noise_pred,
sigma,
latent,
dt,
),
send_to_host=False,
)
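One thing worth noting: the SchedulerStepModel above reduces algebraically to latent + noise_pred * dt, since pred_original = latent - sigma * noise_pred makes derivative = (latent - pred_original) / sigma = noise_pred. A tiny torch check of that identity:

import torch

latent = torch.randn(1, 4, 64, 64)
noise_pred = torch.randn_like(latent)
sigma = torch.tensor(14.6)  # arbitrary nonzero sigma
dt = torch.tensor(-0.5)     # sigmas decrease over the schedule, so dt is negative

pred_original_sample = latent - sigma * noise_pred
derivative = (latent - pred_original_sample) / sigma
stepped = latent + derivative * dt

assert torch.allclose(stepped, latent + noise_pred * dt, atol=1e-5)
print(stepped.shape)  # torch.Size([1, 4, 64, 64])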

View File

@@ -1,27 +0,0 @@
from apps.stable_diffusion.src.utils.profiler import (
start_profiling,
end_profiling,
)
from apps.stable_diffusion.src.utils.resources import (
prompt_examples,
models_db,
base_models,
opt_flags,
resource_path,
)
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.utils import (
get_shark_model,
compile_through_fx,
set_iree_runtime_flags,
map_device_to_name_path,
set_init_device_flags,
get_available_devices,
get_opt_flags,
preprocessCKPT,
fetch_or_delete_vmfbs,
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,
)

View File

@@ -1,18 +0,0 @@
from apps.stable_diffusion.src.utils.stable_args import args
# Helper function to profile the vulkan device.
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
if args.vulkan_debug_utils and "vulkan" in args.device:
import iree
print(f"Profiling and saving to {file_path}.")
vulkan_device = iree.runtime.get_device(args.device)
vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path)
return vulkan_device
return None
def end_profiling(device):
if device:
return device.end_profiling()

View File

@@ -1,37 +0,0 @@
import os
import json
import sys
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
def get_json_file(path):
json_var = []
loc_json = resource_path(path)
if os.path.exists(loc_json):
with open(loc_json, encoding="utf-8") as fopen:
json_var = json.load(fopen)
if not json_var:
print(f"Unable to fetch {path}")
return json_var
# TODO: This shouldn't be called from here; every time the file is imported
# it will populate all the global vars.
prompt_examples = get_json_file("resources/prompts.json")
models_db = get_json_file("resources/model_db.json")
# The base_model contains the input configuration for the different
# models and also helps in providing information for the variants.
base_models = get_json_file("resources/base_model.json")
# Contains optimization flags for different models.
opt_flags = get_json_file("resources/opt_flags.json")

View File

@@ -1,98 +0,0 @@
{
"stabilityai/stable-diffusion-2-1": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
1024
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
},
"CompVis/stable-diffusion-v1-4": {
"unet": {
"latents": {
"shape": [
"1*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"embedding": {
"shape": [
"2*batch_size",
"max_len",
768
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 2,
"dtype": "f32"
}
},
"vae": {
"latents" : {
"shape" : [
"1*batch_size",4,"height","width"
],
"dtype":"f32"
}
},
"clip": {
"token" : {
"shape" : [
"2*batch_size",
"max_len"
],
"dtype":"i64"
}
}
}
}

View File

@@ -1,21 +0,0 @@
[
{
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
"openjourney/v1_4":"prompthero/openjourney",
"dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
},
{
"stablediffusion/fp16":"fp16",
"stablediffusion/fp32":"main",
"anythingv3/fp16":"diffusers",
"anythingv3/fp32":"diffusers",
"analogdiffusion/fp16":"main",
"analogdiffusion/fp32":"main",
"openjourney/fp16":"main",
"openjourney/fp32":"main"
}
]

View File

@@ -1,82 +0,0 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/sd_untuned",
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
"anythingv3/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
"analogdiffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
"openjourney/tuned":"gs://shark_tank/sd_tuned",
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
},
{
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
"stablediffusion/v1_4/unet/fp16/length_77/tuned":"unet_8dec_fp16_tuned",
"stablediffusion/v1_4/unet/fp16/length_77/tuned/cuda":"unet_8dec_fp16_cuda_tuned",
"stablediffusion/v1_4/unet/fp32/length_77/untuned":"unet_1dec_fp32",
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_19dec_fp16",
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32",
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
"stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"vae2base_19dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base/cuda":"vae2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
"anythingv3/v2_1base/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
"anythingv3/v2_1base/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
"anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
"anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
"anythingv3/v2_1base/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
"anythingv3/v2_1base/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
"anythingv3/v2_1base/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
"anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
"anythingv3/v2_1base/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
"anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
"analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
"analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
"analogdiffusion/v2_1base/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
"openjourney/v2_1base/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
"openjourney/v2_1base/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
"openjourney/v2_1base/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
"openjourney/v2_1base/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
"openjourney/v2_1base/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
"openjourney/v2_1base/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
"openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
"dreamlike/v2_1base/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
"dreamlike/v2_1base/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
"dreamlike/v2_1base/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
"dreamlike/v2_1base/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
"dreamlike/v2_1base/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
"dreamlike/v2_1base/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
"dreamlike/v2_1base/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
}
]

View File

@@ -1,84 +0,0 @@
{
"unet": {
"tuned": {
"fp16": {
"default_compilation_flags": []
},
"fp32": {
"default_compilation_flags": []
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
},
"vae": {
"tuned": {
"fp16": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
},
"fp32": {
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
},
"clip": {
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
}
}
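The nesting here is model -> tuned/untuned -> precision, with an extra device split only for the tuned VAE. get_opt_flags itself is not shown in this part of the diff; a minimal, hypothetical sketch of reading one entry from a trimmed copy of the structure above:

# Trimmed, hypothetical copy of the opt_flags.json structure shown above.
opt_flags = {
    "clip": {
        "untuned": {
            "fp32": {
                "default_compilation_flags": [
                    "--iree-preprocessing-pass-pipeline=builtin.module("
                    "func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
                ]
            }
        }
    }
}

model, tuning, precision = "clip", "untuned", "fp32"
flags = list(opt_flags[model][tuning][precision]["default_compilation_flags"])
print(flags)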

View File

@@ -1,8 +0,0 @@
[["A high tech solarpunk utopia in the Amazon rainforest"],
["A pikachu fine dining with a view to the Eiffel Tower"],
["A mecha robot in a favela in expressionist style"],
["an insect robot preparing a delicious meal"],
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]]

View File

@@ -1,234 +0,0 @@
import os
import io
from shark.model_annotation import model_annotation, create_context
from shark.iree_utils._common import iree_target_map, run_cmd
from shark.shark_downloader import (
download_model,
download_public_file,
WORKDIR,
)
from shark.parser import shark_args
from apps.stable_diffusion.src.utils.stable_args import args
def get_device():
device = (
args.device
if "://" not in args.device
else args.device.split("://")[0]
)
return device
# Download the model (Unet or VAE fp16) from shark_tank
def load_model_from_tank():
from apps.stable_diffusion.src.models import (
get_params,
get_variant_version,
)
variant, version = get_variant_version(args.hf_model_id)
shark_args.local_tank_cache = args.local_tank_cache
bucket_key = f"{variant}/untuned"
if args.annotation_model == "unet":
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/untuned"
elif args.annotation_model == "vae":
is_base = "/base" if args.use_base_vae else ""
model_key = f"{variant}/{version}/vae/{args.precision}/length_77/untuned{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, args.annotation_model, "untuned", args.precision
)
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=bucket,
frontend="torch",
)
return mlir_model, model_name
# Download the tuned config files from shark_tank
def load_winograd_configs():
device = get_device()
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_name = f"{args.annotation_model}_winograd_{device}.json"
full_gs_url = config_bucket + config_name
winograd_config_dir = f"{WORKDIR}configs/" + config_name
print("Loading Winograd config file from ", winograd_config_dir)
download_public_file(full_gs_url, winograd_config_dir, True)
return winograd_config_dir
def load_lower_configs():
from apps.stable_diffusion.src.models import get_variant_version
variant, version = get_variant_version(args.hf_model_id)
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_version = version
if variant in ["anythingv3", "analogdiffusion"]:
args.max_length = 77
config_version = "v1_4"
if args.annotation_model == "vae":
args.max_length = 77
device = get_device()
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{args.max_length}_{device}.json"
full_gs_url = config_bucket + config_name
lowering_config_dir = f"{WORKDIR}configs/" + config_name
print("Loading lowering config file from ", lowering_config_dir)
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir
# Annotate the model with Winograd attribute on selected conv ops
def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
with create_context() as ctx:
winograd_model = model_annotation(
ctx,
input_contents=input_mlir,
config_path=winograd_config_dir,
search_op="conv",
winograd=True,
)
bytecode_stream = io.BytesIO()
winograd_model.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
with open(out_file_path, "w") as f:
f.write(str(winograd_model))
f.close()
return bytecode, out_file_path
def dump_after_mlir(input_mlir, model_name, use_winograd):
if use_winograd:
dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
preprocess_flag = (
"--iree-preprocessing-pass-pipeline='builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
"iree-linalg-ext-convert-conv2d-to-winograd))' "
)
else:
dump_after = "iree-preprocessing-pad-linalg-ops"
preprocess_flag = (
"--iree-preprocessing-pass-pipeline='builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32}))' "
)
device_spec_args = ""
device = get_device()
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
gpu_flags = get_iree_gpu_args()
for flag in gpu_flags:
device_spec_args += flag + " "
elif device == "vulkan":
device_spec_args = (
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
)
print("Applying tuned configs on", model_name)
run_cmd(
f"iree-compile {input_mlir} "
"--iree-input-type=tm_tensor "
f"--iree-hal-target-backends={iree_target_map(device)} "
f"{device_spec_args}"
f"{preprocess_flag}"
"--iree-stream-resource-index-bits=64 "
"--iree-vm-target-index-bits=64 "
f"--mlir-print-ir-after={dump_after} "
"--compile-to=flow "
f"2>{args.annotation_output}/dump_after_winograd.mlir "
)
# For Unet annotate the model with tuned lowering configs
def annotate_with_lower_configs(
input_mlir, lowering_config_dir, model_name, use_winograd
):
# Dump IR after padding/img2col/winograd passes
dump_after_mlir(input_mlir, model_name, use_winograd)
# Annotate the model with lowering configs in the config file
with create_context() as ctx:
tuned_model = model_annotation(
ctx,
input_contents=f"{args.annotation_output}/dump_after_winograd.mlir",
config_path=lowering_config_dir,
search_op="all",
)
# Remove the intermediate mlir and save the final annotated model
os.remove(f"{args.annotation_output}/dump_after_winograd.mlir")
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
bytecode_stream = io.BytesIO()
tuned_model.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
with open(out_file_path, "w") as f:
f.write(str(tuned_model))
f.close()
return bytecode, out_file_path
def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
device = get_device()
if args.annotation_model == "unet" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
winograd_model, model_path = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
lowering_config_dir = load_lower_configs()
tuned_model, output_path = annotate_with_lower_configs(
model_path, lowering_config_dir, model_name, use_winograd
)
elif args.annotation_model == "vae" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
tuned_model, output_path = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
else:
use_winograd = False
if model_from_tank:
mlir_model = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir"
else:
# Just use this function to convert bytecode to string
orig_model, model_path = annotate_with_winograd(
mlir_model, "", model_name
)
mlir_model = model_path
lowering_config_dir = load_lower_configs()
tuned_model, output_path = annotate_with_lower_configs(
mlir_model, lowering_config_dir, model_name, use_winograd
)
print(f"Saved the annotated mlir in {output_path}.")
return tuned_model
if __name__ == "__main__":
mlir_model, model_name = load_model_from_tank()
sd_model_annotation(mlir_model, model_name, model_from_tank=True)

View File

@@ -1,345 +0,0 @@
import argparse
from pathlib import Path
def path_expand(s):
return Path(s).expanduser().resolve()
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
### Stable Diffusion Params
##############################################################################
p.add_argument(
"-p",
"--prompts",
action="append",
default=[],
help="text of which images to be generated.",
)
p.add_argument(
"--negative_prompts",
nargs="+",
default=[""],
help="text you don't want to see in the generated image.",
)
p.add_argument(
"--steps",
type=int,
default=50,
help="the no. of steps to do the sampling.",
)
p.add_argument(
"--seed",
type=int,
default=42,
help="the seed to use.",
)
p.add_argument(
"--batch_size",
type=int,
default=1,
choices=range(1, 4),
help="the number of inferences to be made in a single `run`.",
)
p.add_argument(
"--height",
type=int,
default=512,
help="the height of the output image.",
)
p.add_argument(
"--width",
type=int,
default=512,
help="the width of the output image.",
)
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
help="the value to be used for guidance scaling.",
)
p.add_argument(
"--max_length",
type=int,
default=64,
help="max length of the tokenizer output, options are 64 and 77.",
)
##############################################################################
### Model Config and Usage Params
##############################################################################
p.add_argument(
"--device", type=str, default="vulkan", help="device to run the model."
)
p.add_argument(
"--precision", type=str, default="fp16", help="precision to run the model."
)
p.add_argument(
"--import_mlir",
default=False,
action=argparse.BooleanOptionalAction,
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
)
p.add_argument(
"--load_vmfb",
default=True,
action=argparse.BooleanOptionalAction,
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
)
p.add_argument(
"--save_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="saves the compiled flatbuffer to the local directory",
)
p.add_argument(
"--use_tuned",
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
)
p.add_argument(
"--use_base_vae",
default=False,
action=argparse.BooleanOptionalAction,
help="Do conversion from the VAE output to pixel space on cpu.",
)
p.add_argument(
"--scheduler",
type=str,
default="SharkEulerDiscrete",
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="specify the format in which output image is save. Supported options: jpg / png",
)
p.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory path to save the output images and json",
)
p.add_argument(
"--runs",
type=int,
default=1,
help="number of images to be generated with random seeds in single execution",
)
p.add_argument(
"--ckpt_loc",
type=str,
default="",
help="Path to SD's .ckpt file.",
)
p.add_argument(
"--hf_model_id",
type=str,
default="stabilityai/stable-diffusion-2-1-base",
help="The repo-id of hugging face.",
)
p.add_argument(
"--enable_stack_trace",
default=False,
action=argparse.BooleanOptionalAction,
help="Enable showing the stack trace when retrying the base model configuration",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree-vulkan-target-triple",
type=str,
default="",
help="Specify target triple for vulkan",
)
p.add_argument(
"--vulkan_debug_utils",
default=False,
action=argparse.BooleanOptionalAction,
help="Profiles vulkan device and collects the .rdc info",
)
p.add_argument(
"--vulkan_large_heap_block_size",
default="4147483648",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)
p.add_argument(
"--vulkan_validation_layers",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for disabling vulkan validation layers when benchmarking",
)
##############################################################################
### Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="use the default scheduler precompiled into the model if available",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help='dispatches to return benchmark data on. Use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for inserting debug frames between iterations for use with rgp.",
)
p.add_argument(
"--hide_steps",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="flag setting warmup count for clip and vae [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
)
p.add_argument(
"--save_metadata_to_json",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save a generation information json file with the image.",
)
p.add_argument(
"--write_metadata_to_png",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
)
##############################################################################
### Web UI flags
##############################################################################
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="flag for removing the pregress bar animation during image generation",
)
p.add_argument(
"--ckpt_dir",
type=str,
default="",
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for generating a public URL",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="flag for setting server port",
)
##############################################################################
### SD model auto-annotation flags
##############################################################################
p.add_argument(
"--annotation_output",
type=path_expand,
default="./",
help="Directory to save the annotated mlir file",
)
p.add_argument(
"--annotation_model",
type=str,
default="unet",
help="Options are unet and vae.",
)
p.add_argument(
"--use_winograd",
default=False,
action=argparse.BooleanOptionalAction,
help="Apply Winograd on selected conv ops.",
)
args, unknown = p.parse_known_args()
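Because the parser ends with parse_known_args, unrecognized flags are collected rather than treated as errors, which keeps this shared argument module usable across different entry points. A minimal standalone sketch of that behaviour:

import argparse

p = argparse.ArgumentParser()
p.add_argument("--steps", type=int, default=50)
p.add_argument("--device", type=str, default="vulkan")

# Hypothetical command line; --some_unrelated_flag is not defined on this parser.
argv = ["--steps", "20", "--device", "cpu", "--some_unrelated_flag", "1"]
args, unknown = p.parse_known_args(argv)

print(args.steps, args.device)  # 20 cpu
print(unknown)                  # ['--some_unrelated_flag', '1']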

View File

@@ -1,460 +0,0 @@
import os
import gc
import json
from pathlib import Path
import numpy as np
from random import randint
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
import sys, functools, operator
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
load_pipeline_from_original_stable_diffusion_ckpt,
)
def get_vmfb_path_name(model_name):
device = (
args.device
if "://" not in args.device
else "-".join(args.device.split("://"))
)
extended_name = "{}_{}".format(model_name, device)
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
return [vmfb_path, extended_name]
def _compile_module(shark_module, model_name, extra_args=[]):
if args.load_vmfb or args.save_vmfb:
[vmfb_path, extended_name] = get_vmfb_path_name(model_name)
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
else:
if args.save_vmfb:
print("Saving to {}".format(vmfb_path))
else:
print(
"No vmfb found. Compiling and saving to {}".format(
vmfb_path
)
)
path = shark_module.save_module(
os.getcwd(), extended_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
shark_module.compile(extra_args)
return shark_module
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
from shark.parser import shark_args
# Set local shark_tank cache directory.
shark_args.local_tank_cache = args.local_tank_cache
from shark.shark_downloader import download_model
if "cuda" in args.device:
shark_args.enable_tf32 = True
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=tank_url,
frontend="torch",
)
shark_module = SharkInference(
mlir_model, device=args.device, mlir_dialect="linalg"
)
return _compile_module(shark_module, model_name, extra_args)
# Converts the torch-module into a shark_module.
def compile_through_fx(
model,
inputs,
model_name,
is_f16=False,
f16_input_mask=None,
use_tuned=False,
extra_args=[],
):
from shark.parser import shark_args
if "cuda" in args.device:
shark_args.enable_tf32 = True
mlir_module, func_name = import_with_fx(
model, inputs, is_f16, f16_input_mask
)
if use_tuned:
if "vae" in model_name.split("_")[0]:
args.annotation_model = "vae"
mlir_module = sd_model_annotation(mlir_module, model_name)
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
)
del mlir_module
gc.collect()
return _compile_module(shark_module, model_name, extra_args)
def set_iree_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
]
if args.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
def get_all_devices(driver_name):
"""
Inputs: driver_name
Returns a list of all the available devices for a given driver sorted by
the iree path names of the device as in --list_devices option in iree.
"""
from iree.runtime import get_driver
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return device_list_src
def get_device_mapping(driver, key_combination=3):
"""This method ensures consistent device ordering when choosing
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
driver = iree_device_map(driver)
device_list = get_all_devices(driver)
device_map = dict()
def get_output_value(dev_dict):
if key_combination == 1:
return f"{driver}://{dev_dict['path']}"
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
for i, device in enumerate(device_list):
# mapping with index
device_map[f"{driver}://{i}"] = get_output_value(device)
# mapping with full path
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
return device_map
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Raises:
ValueError: if the device does not map to any available device.
Returns:
str / tuple: the mapping str, or tuple of mapping strs, for the device, depending on the key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
try:
device_mapping = device_map[device]
except KeyError:
raise ValueError(f"Device '{device}' is not a valid device.")
return device_mapping
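# Hedged usage sketch (not part of the original file): resolves a user-facing
# device string to its (name, path) pair using the helpers above. "vulkan://0"
# is an illustrative value; any driver reported by the runtime works the same way.
def _example_resolve_device():
    name, path = map_device_to_name_path("vulkan://0", key_combination=3)
    print(f"Selected {name} at {path}")
    return name, path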
def set_init_device_flags():
if "vulkan" in args.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
args.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
)
elif "cuda" in args.device:
args.device = "cuda"
elif "cpu" in args.device:
args.device = "cpu"
# set max_length based on availability.
if args.hf_model_id in [
"Linaqruf/anything-v3.0",
"wavymulder/Analog-Diffusion",
"dreamlike-art/dreamlike-diffusion-1.0",
]:
args.max_length = 77
elif args.hf_model_id == "prompthero/openjourney":
args.max_length = 64
# Use tuned models only for fp16 on Vulkan RDNA3 or CUDA sm_80 devices.
if (
args.hf_model_id == "prompthero/openjourney"
or args.ckpt_loc != ""
or args.precision != "fp16"
or args.height != 512
or args.width != 512
or args.batch_size != 1
or ("vulkan" not in args.device and "cuda" not in args.device)
):
args.use_tuned = False
elif (
"vulkan" in args.device
and "rdna3" not in args.iree_vulkan_target_triple
):
args.use_tuned = False
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80"]:
args.use_tuned = False
elif args.use_base_vae and args.hf_model_id not in [
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]:
args.use_tuned = False
if args.use_tuned:
print(f"Using tuned models for {args.hf_model_id}/fp16/{args.device}.")
else:
print("Tuned models are currently not supported for this setting.")
# set import_mlir to True for unuploaded models.
if args.ckpt_loc != "":
args.import_mlir = True
elif args.hf_model_id not in [
"Linaqruf/anything-v3.0",
"dreamlike-art/dreamlike-diffusion-1.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]:
args.import_mlir = True
elif args.height != 512 or args.width != 512 or args.batch_size != 1:
args.import_mlir = True
# Utility to get list of devices available.
def get_available_devices():
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except:
print(f"{driver_name} devices are not available.")
else:
for i, device in enumerate(device_list_dict):
device_list.append(f"{device['name']} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("cpu")
return available_devices
def disk_space_check(path, lim=20):
from shutil import disk_usage
du = disk_usage(path)
free = du.free / (1024 * 1024 * 1024)
if free <= lim:
print(f"[WARNING] Only {free:.2f}GB space available in {path}.")
def get_opt_flags(model, precision="fp16"):
iree_flags = []
is_tuned = "tuned" if args.use_tuned else "untuned"
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with MoltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
iree_flags += opt_flags[model][is_tuned][precision][
"default_compilation_flags"
]
if "specified_compilation_flags" in opt_flags[model][is_tuned][precision]:
device = (
args.device
if "://" not in args.device
else args.device.split("://")[0]
)
if (
device
not in opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
]
):
device = "default_device"
iree_flags += opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
][device]
return iree_flags
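# Hedged sketch (not part of the original file) of the nesting that
# get_opt_flags() expects from the module-level `opt_flags` table, inferred
# purely from the access pattern above. The flag strings are placeholders,
# not the project's real tuning flags.
_example_opt_flags_shape = {
    "unet": {
        "tuned": {
            "fp16": {
                "default_compilation_flags": ["--example-default-flag"],
                "specified_compilation_flags": {
                    "vulkan": ["--example-vulkan-flag"],
                    "default_device": [],
                },
            }
        },
        "untuned": {"fp16": {"default_compilation_flags": []}},
    }
}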
def get_path_to_diffusers_checkpoint(custom_weights):
path = Path(custom_weights)
diffusers_path = path.parent.absolute()
diffusers_directory_name = path.stem
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
path_to_diffusers = complete_path_to_diffusers.as_posix()
return path_to_diffusers
def preprocessCKPT(custom_weights):
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights)
if next(Path(path_to_diffusers).iterdir(), None):
print("Checkpoint already loaded at : ", path_to_diffusers)
return
else:
print(
"Diffusers' checkpoint will be identified here : ",
path_to_diffusers,
)
from_safetensors = custom_weights.lower().endswith(".safetensors")
# EMA weights usually yield higher quality images for inference but non-EMA weights have
# been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
# weight extraction or not.
extract_ema = False
print(
"Loading diffusers' pipeline from original stable diffusion checkpoint"
)
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
)
pipe.save_pretrained(path_to_diffusers)
print("Loading complete")
def load_vmfb(vmfb_path, model, precision):
model = "vae" if "base_vae" in model else model
precision = "fp32" if "clip" in model else precision
extra_args = get_opt_flags(model, precision)
shark_module = SharkInference(mlir_module=None, device=args.device)
shark_module.load_module(vmfb_path, extra_args=extra_args)
return shark_module
# This utility returns the vmfbs for CLIP, UNet and VAE when all three are
# present; otherwise it deletes whichever of them exist so all three can be rebuilt together.
def fetch_or_delete_vmfbs(basic_model_name, use_base_vae, precision="fp32"):
model_name = ["clip", "unet", "base_vae" if use_base_vae else "vae"]
vmfb_path = [
get_vmfb_path_name(model + basic_model_name)[0] for model in model_name
]
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
all_vmfb_present = functools.reduce(operator.__and__, vmfb_present)
compiled_models = [None] * 3
# Delete any partial set of vmfbs: they are only reused when all three are present.
if not all_vmfb_present:
for i in range(len(vmfb_path)):
if vmfb_present[i]:
os.remove(vmfb_path[i])
print("Deleted: ", vmfb_path[i])
else:
for i in range(len(vmfb_path)):
compiled_models[i] = load_vmfb(
vmfb_path[i], model_name[i], precision
)
return compiled_models
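# Hedged usage sketch (not part of the original file): either reuses a full
# set of cached CLIP/UNet/VAE vmfbs or clears a partial set so all three get
# recompiled together. The model-name suffix is a placeholder.
def _example_fetch_cached_models():
    clip, unet, vae = fetch_or_delete_vmfbs(
        "_example_sd21base_fp16", use_base_vae=False, precision="fp16"
    )
    if clip is None:
        print("Cache incomplete; submodels will be recompiled.")
    return clip, unet, vae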
# `fetch_and_update_base_model_id` is a resource utility function which
# helps maintain the mapping between the model to run and its base model.
# If `base_model` is "", this function instead tries to fetch the base model
# info for the given `model_to_run`.
def fetch_and_update_base_model_id(model_to_run, base_model=""):
variants_path = os.path.join(os.getcwd(), "variants.json")
data = {model_to_run: base_model}
json_data = {}
if os.path.exists(variants_path):
with open(variants_path, "r", encoding="utf-8") as jsonFile:
json_data = json.load(jsonFile)
# Return with base_model's info if base_model is "".
if base_model == "":
if model_to_run in json_data:
base_model = json_data[model_to_run]
return base_model
elif base_model == "":
return base_model
# Update JSON data to contain an entry mapping model_to_run with base_model.
json_data.update(data)
with open(variants_path, "w", encoding="utf-8") as jsonFile:
json.dump(json_data, jsonFile)
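# Hedged usage sketch (not part of the original file): records a variant's
# base model, then reads it back. The model IDs are placeholders.
def _example_variant_mapping():
    # Update mode: remember that the custom variant derives from SD 2.1 base.
    fetch_and_update_base_model_id(
        "my/custom-dreambooth", "stabilityai/stable-diffusion-2-1-base"
    )
    # Fetch mode: passing base_model="" returns the stored base model id.
    return fetch_and_update_base_model_id("my/custom-dreambooth")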
# Generate and return a new random seed if the provided one falls outside the supported uint32 range (the sentinel -1 always does).
def sanitize_seed(seed):
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed
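# Hedged usage sketch (not part of the original file): -1, or any value
# outside the uint32 range, is replaced with a fresh random seed, while an
# in-range seed passes through unchanged.
def _example_sanitize_seed():
    assert sanitize_seed(42) == 42
    assert 0 <= sanitize_seed(-1) < 2**32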

View File

@@ -1,70 +0,0 @@
# Stable Diffusion optimized for AMD RDNA2/RDNA3 GPUs
Before you start, please be aware that this is beta software that relies on a special AMD driver. Like all Stable Diffusion GUIs published so far, it requires some technical expertise to set up. We apologize in advance if you bump into issues. If that happens, please don't hesitate to ask our Discord community for help! Please be assured that we (Nod and AMD) are working hard to improve the user experience in the coming months.
If it works well for you, please "star" the following GitHub projects... this is one of the best ways to help and spread the word!
* https://github.com/nod-ai/SHARK
* https://github.com/iree-org/iree
## Install these specific AMD drivers (the latest AMD release may not have all the fixes).
### AMD KB Drivers for RDNA2 and RDNA3:
*AMD Software: Adrenalin Edition 22.11.1 for MLIR/IREE Driver Version 22.20.29.09 for Windows® 10 and Windows® 11 (Windows Driver Store Version 31.0.12029.9003)*
First, for RDNA2 users, download this special driver into a folder of your choice. We recommend keeping the installation files around, since you may need to re-install the driver later if Windows Update overwrites it:
https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mlir-iree
For RDNA3, the latest driver 23.1.2 supports MLIR/IREE as well: https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-1-2-kb
KNOWN ISSUES with this special AMD driver:
* `Windows Update` may (depending on how it's configured) automatically install a new official AMD driver that overwrites this IREE-specific driver. If Stable Diffusion used to work and then, a few days later, slows down a lot or produces incorrect results (e.g. black images), this may be the cause. To fix the problem, check the installed driver version and re-install the special driver if needed. (TODO: document how to prevent this `Windows Update` behavior!)
* Some people using this special driver experience mouse pointer accuracy issues, especially if using a larger-than-default mouse pointer. The clicked point isn't centered properly. One possible work-around is to reset the pointer size to "1" in "Change pointer size and color".
## Installation
Download the latest Windows SHARK SD binary [492 here](https://github.com/nod-ai/SHARK/releases/download/20230203.492/shark_sd_20230203_492.exe) into a folder of your choice. If you want nightly builds, you can look for them on the GitHub releases page.
Notes:
* We recommend downloading each new EXE version into a new folder. If you download it into the same folder as a previous install, you must delete the old `*.vmfb` files: those contain Vulkan dispatches compiled from MLIR, which can be outdated when a new EXE runs from the same folder. You can use the `--clear_all` flag once to clean up all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you:
* clear all the local artifacts with `--clear_all` OR
* clear the Vulkan shader cache: For Windows users this can be done by clearing the contents of `C:\Users\%username%\AppData\Local\AMD\VkCache\`. On Linux the same cache is typically located at `~/.cache/AMD/VkCache/`.
* clear the `huggingface` cache. In Windows, this is `C:\Users\%username%\.cache\huggingface`.
## Running
* Open a Command Prompt or PowerShell terminal and change directory (`cd`) to the folder containing the EXE, then run the EXE from that prompt. That way, if an error occurs, you'll be able to copy and paste it when asking for help. (If it always works for you without errors, you may simply double-click the EXE instead.)
* The first run may take about 10-15 minutes when the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
* If successful, you will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/?__theme=dark.
## Stopping
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment. The application should stop.
* Please make sure to do the above step before you attempt to update the EXE to a new version.
# Results
<img width="1607" alt="webui" src="https://user-images.githubusercontent.com/74956/204939260-b8308bc2-8dc4-47f6-9ac0-f60b66edab99.png">
Here are some samples generated:
![tajmahal, snow, sunflowers, oil on canvas_0](https://user-images.githubusercontent.com/74956/204934186-141f7e43-6eb2-4e89-a99c-4704d20444b3.jpg)
![a photo of a crab playing a trumpet](https://user-images.githubusercontent.com/74956/204933258-252e7240-8548-45f7-8253-97647d38313d.jpg)
The output on a 7900XTX would look like:
```shell
Stats for run 0:
Average step time: 47.19188690185547ms/it
Clip Inference time (ms) = 109.531
VAE Inference time (ms): 78.590
Total image generation time: 2.5788655281066895sec
```
Find us on the [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble running it on your hardware.

View File

@@ -1,15 +0,0 @@
First, create your bot (https://core.telegram.org/bots#how-do-i-create-a-bot).
Then create a .env file in the web directory containing the line:
TG_TOKEN="your_token"
using your bot's token from the previous step.
Then run telegram_bot.py with the same parameters that you use when running index.py, for example:
python telegram_bot.py --max_length=77 --vulkan_large_heap_block_size=0 --use_base_vae --local_tank_cache h:\shark\TEMP
Bot commands:
/select_model
/select_scheduler
/set_steps "integer number of steps"
/set_guidance_scale "integer number"
/set_negative_prompt "negative text"
Any other text triggers the creation of an image based on it.

View File

@@ -1,209 +0,0 @@
/* Overwrite the Gradio default theme with their .dark theme declarations */
:root {
--color-focus-primary: var(--color-grey-700);
--color-focus-secondary: var(--color-grey-600);
--color-focus-ring: rgb(55 65 81);
--color-background-primary: var(--color-grey-950);
--color-background-secondary: var(--color-grey-900);
--color-background-tertiary: var(--color-grey-800);
--color-text-body: var(--color-grey-100);
--color-text-label: var(--color-grey-200);
--color-text-placeholder: var(--color-grey);
--color-text-subdued: var(--color-grey-400);
--color-text-link-base: var(--color-blue-500);
--color-text-link-hover: var(--color-blue-400);
--color-text-link-visited: var(--color-blue-600);
--color-text-link-active: var(--color-blue-500);
--color-text-code-background: var(--color-grey-800);
--color-text-code-border: color.border-primary;
--color-border-primary: var(--color-grey-700);
--color-border-secondary: var(--color-grey-600);
--color-border-highlight: var(--color-accent-base);
--color-accent-base: var(--color-orange-500);
--color-accent-light: var(--color-orange-300);
--color-accent-dark: var(--color-orange-700);
--color-functional-error-base: var(--color-red-400);
--color-functional-error-subdued: var(--color-red-300);
--color-functional-error-background: var(--color-background-primary);
--color-functional-info-base: var(--color-yellow);
--color-functional-info-subdued: var(--color-yellow-300);
--color-functional-success-base: var(--color-green);
--color-functional-success-subdued: var(--color-green-300);
--shadow-spread: 2px;
--api-background: linear-gradient(to bottom, rgba(255, 216, 180, .05), transparent);
--api-pill-background: var(--color-orange-400);
--api-pill-border: var(--color-orange-600);
--api-pill-text: var(--color-orange-900);
--block-border-color: var(--color-border-primary);
--block-background: var(--color-background-tertiary);
--uploadable-border-color-hover: var(--color-border-primary);
--uploadable-border-color-loaded: var(--color-functional-success);
--uploadable-text-color: var(--color-text-subdued);
--block_label-border-color: var(--color-border-primary);
--block_label-icon-color: var(--color-text-label);
--block_label-shadow: var(--shadow-drop);
--block_label-background: var(--color-background-secondary);
--icon_button-icon-color-base: var(--color-text-label);
--icon_button-icon-color-hover: var(--color-text-label);
--icon_button-background-base: var(--color-background-primary);
--icon_button-background-hover: var(--color-background-primary);
--icon_button-border-color-base: var(--color-background-primary);
--icon_button-border-color-hover: var(--color-border-secondary);
--input-text-color: var(--color-text-body);
--input-border-color-base: var(--color-border-primary);
--input-border-color-hover: var(--color-border-primary);
--input-border-color-focus: var(--color-border-primary);
--input-background-base: var(--color-background-tertiary);
--input-background-hover: var(--color-background-tertiary);
--input-background-focus: var(--color-background-tertiary);
--input-shadow: var(--shadow-inset);
--checkbox-border-color-base: var(--color-border-primary);
--checkbox-border-color-hover: var(--color-focus-primary);
--checkbox-border-color-focus: var(--color-blue-500);
--checkbox-background-base: var(--color-background-primary);
--checkbox-background-hover: var(--color-background-primary);
--checkbox-background-focus: var(--color-background-primary);
--checkbox-background-selected: var(--color-blue-600);
--checkbox-label-border-color-base: var(--color-border-primary);
--checkbox-label-border-color-hover: var(--color-border-primary);
--checkbox-label-border-color-focus: var(--color-border-secondary);
--checkbox-label-background-base: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
--checkbox-label-background-hover: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
--checkbox-label-background-focus: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
--form-seperator-color: var(--color-border-primary);
--button-primary-border-color-base: var(--color-orange-600);
--button-primary-border-color-hover: var(--color-orange-600);
--button-primary-border-color-focus: var(--color-orange-600);
--button-primary-text-color-base: white;
--button-primary-text-color-hover: white;
--button-primary-text-color-focus: white;
--button-primary-background-base: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-700));
--button-primary-background-hover: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-500));
--button-primary-background-focus: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-500));
--button-secondary-border-color-base: var(--color-grey-600);
--button-secondary-border-color-hover: var(--color-grey-600);
--button-secondary-border-color-focus: var(--color-grey-600);
--button-secondary-text-color-base: white;
--button-secondary-text-color-hover: white;
--button-secondary-text-color-focus: white;
--button-secondary-background-base: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-700));
--button-secondary-background-hover: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-600));
--button-secondary-background-focus: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-600));
--button-cancel-border-color-base: var(--color-red-600);
--button-cancel-border-color-hover: var(--color-red-600);
--button-cancel-border-color-focus: var(--color-red-600);
--button-cancel-text-color-base: white;
--button-cancel-text-color-hover: white;
--button-cancel-text-color-focus: white;
--button-cancel-background-base: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-700));
--button-cancel-background-focus: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-500));
--button-cancel-background-hover: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-500));
--button-plain-border-color-base: var(--color-grey-600);
--button-plain-border-color-hover: var(--color-grey-500);
--button-plain-border-color-focus: var(--color-grey-500);
--button-plain-text-color-base: var(--color-text-body);
--button-plain-text-color-hover: var(--color-text-body);
--button-plain-text-color-focus: var(--color-text-body);
--button-plain-background-base: var(--color-grey-700);
--button-plain-background-hover: var(--color-grey-700);
--button-plain-background-focus: var(--color-grey-700);
--gallery-label-background-base: var(--color-grey-50);
--gallery-label-background-hover: var(--color-grey-50);
--gallery-label-border-color-base: var(--color-border-primary);
--gallery-label-border-color-hover: var(--color-border-primary);
--gallery-thumb-background-base: var(--color-grey-900);
--gallery-thumb-background-hover: var(--color-grey-900);
--gallery-thumb-border-color-base: var(--color-border-primary);
--gallery-thumb-border-color-hover: var(--color-accent-base);
--gallery-thumb-border-color-focus: var(--color-blue-500);
--gallery-thumb-border-color-selected: var(--color-accent-base);
--chatbot-border-border-color-base: transparent;
--chatbot-border-border-color-latest: transparent;
--chatbot-user-background-base: ;
--chatbot-user-background-latest: ;
--chatbot-user-text-color-base: white;
--chatbot-user-text-color-latest: white;
--chatbot-bot-background-base: ;
--chatbot-bot-background-latest: ;
--chatbot-bot-text-color-base: white;
--chatbot-bot-text-color-latest: white;
--label-gradient-from: var(--color-orange-400);
--label-gradient-to: var(--color-orange-600);
--table-odd-background: var(--color-grey-900);
--table-even-background: var(--color-grey-950);
--table-background-edit: transparent;
--dataset-gallery-background-base: var(--color-background-primary);
--dataset-gallery-background-hover: var(--color-grey-800);
--dataset-dataframe-border-base: var(--color-border-primary);
--dataset-dataframe-border-hover: var(--color-border-secondary);
--dataset-table-background-base: transparent;
--dataset-table-background-hover: var(--color-grey-700);
--dataset-table-border-base: var(--color-grey-800);
--dataset-table-border-hover: var(--color-grey-800);
}
/* SHARK theme customization */
.gradio-container {
background-color: var(--color-background-primary);
}
.container {
background-color: black !important;
padding-top: 20px !important;
}
#ui_title {
padding: 10px !important;
}
#top_logo {
background-color: transparent;
border-radius: 0 !important;
border: 0;
}
#demo_title {
background-color: var(--color-background-primary);
border-radius: 0 !important;
border: 0;
padding-top: 15px;
padding-bottom: 0px;
width: 350px !important;
}
#demo_title_outer {
border-radius: 0;
}
#prompt_box_outer div:first-child {
border-radius: 0 !important
}
#prompt_box textarea {
background-color: var(--color-background-primary) !important;
}
#prompt_examples {
margin: 0 !important;
}
#prompt_examples svg {
display: none !important;
}
#ui_body {
background-color: var(--color-background-secondary) !important;
padding: 10px !important;
border-radius: 0.5em !important;
}
#img_result+div {
display: none !important;
}
footer {
display: none !important;
}

View File

@@ -1,264 +0,0 @@
import os
import sys
from pathlib import Path
import glob
if "AMD_ENABLE_LLPC" not in os.environ:
os.environ["AMD_ENABLE_LLPC"] = "1"
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
import gradio as gr
from PIL import Image
from apps.stable_diffusion.src import (
prompt_examples,
args,
get_available_devices,
)
from apps.stable_diffusion.scripts import txt2img_inf
nodlogo_loc = resource_path("logos/nod-logo.png")
sdlogo_loc = resource_path("logos/sd-demo-logo.png")
demo_css = resource_path("css/sd_dark_theme.css")
with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
logo2 = Image.open(sdlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
).style(width=150, height=100)
with gr.Column(scale=5, elem_id="demo_title_outer"):
gr.Image(
value=logo2,
show_label=False,
interactive=False,
elem_id="demo_title",
).style(width=150, height=100)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
ckpt_path = (
Path(args.ckpt_dir)
if args.ckpt_dir
else Path(Path.cwd(), "models")
)
ckpt_path.mkdir(parents=True, exist_ok=True)
types = (
"*.ckpt",
"*.safetensors",
) # the tuple of file types
ckpt_files = ["None"]
for extn in types:
files = glob.glob(os.path.join(ckpt_path, extn))
ckpt_files.extend(files)
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {ckpt_path})",
value="None",
choices=ckpt_files
+ [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
],
)
hf_model_id = gr.Textbox(
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value="cyberpunk forest by Salvador Dali",
lines=1,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value="trees, green",
lines=1,
elem_id="prompt_box",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
label="Scheduler",
value="SharkEulerDiscrete",
choices=[
"DDIM",
"PNDM",
"LMSDiscrete",
"DPMSolverMultistep",
"EulerDiscrete",
"EulerAncestralDiscrete",
"SharkEulerDiscrete",
],
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=True,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=False,
interactive=True,
)
with gr.Row():
height = gr.Slider(
384, 786, value=512, step=8, label="Height"
)
width = gr.Slider(
384, 786, value=512, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value="fp16",
choices=[
"fp16",
"fp32",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=64,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
steps = gr.Slider(
1, 100, value=50, step=1, label="Steps"
)
guidance_scale = gr.Slider(
0,
50,
value=7.5,
step=0.1,
label="CFG Scale",
)
with gr.Row():
batch_count = gr.Slider(
1,
10,
value=1,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=1,
step=1,
label="Batch Size",
interactive=True,
)
with gr.Row():
seed = gr.Number(value=-1, precision=0, label="Seed")
available_devices = get_available_devices()
device = gr.Dropdown(
label="Device",
value=available_devices[0],
choices=available_devices,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
None,
inputs=[],
outputs=[seed],
_js="() => Math.floor(Math.random() * 4294967295)",
)
stable_diffusion = gr.Button("Generate Image")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
).style(grid=[2], height="auto")
std_output = gr.Textbox(
value="Nothing to show.",
lines=4,
show_label=False,
)
output_dir = args.output_dir if args.output_dir else Path.cwd()
output_dir = Path(output_dir, "generated_imgs")
output_loc = gr.Textbox(
label="Saving Images at",
value=output_dir,
interactive=False,
)
kwargs = dict(
fn=txt2img_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
],
outputs=[gallery, std_output],
show_progress=args.progress_bar,
)
prompt.submit(**kwargs)
stable_diffusion.click(**kwargs)
shark_web.queue()
shark_web.launch(
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=args.server_port,
)

Binary file not shown.


Binary file not shown.


Binary file not shown.


View File

@@ -1,22 +0,0 @@
import torch
from shark.parser import parser
from benchmarks.hf_transformer import SharkHFBenchmarkRunner
parser.add_argument(
"--model_name",
type=str,
required=True,
help='Specifies the name of the HF model to benchmark (for example "microsoft/MiniLM-L12-H384-uncased").',
)
load_args, unknown = parser.parse_known_args()
if __name__ == "__main__":
model_name = load_args.model_name
test_input = torch.randint(2, (1, 128))
shark_module = SharkHFBenchmarkRunner(
model_name, (test_input,), jit_trace=True
)
shark_module.benchmark_c()
shark_module.benchmark_python((test_input,))
shark_module.benchmark_torch(test_input)
shark_module.benchmark_onnx(test_input)

View File

@@ -1,181 +0,0 @@
import torch
from shark.shark_benchmark_runner import SharkBenchmarkRunner
from shark.parser import shark_args
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from onnxruntime.transformers.benchmark import (
run_pytorch,
run_tensorflow,
run_onnxruntime,
)
from onnxruntime.transformers.huggingface_models import MODELS
from onnxruntime.transformers.benchmark_helper import ConfigModifier, Precision
import os
import psutil
class OnnxFusionOptions(object):
def __init__(self):
self.disable_gelu = False
self.disable_layer_norm = False
self.disable_attention = False
self.disable_skip_layer_norm = False
self.disable_embed_layer_norm = False
self.disable_bias_skip_layer_norm = False
self.disable_bias_gelu = False
self.enable_gelu_approximation = False
self.use_mask_index = False
self.no_attention_mask = False
class HuggingFaceLanguage(torch.nn.Module):
def __init__(self, hf_model_name):
super().__init__()
self.model = AutoModelForSequenceClassification.from_pretrained(
hf_model_name, # The pretrained model.
num_labels=2, # The number of output labels--2 for binary classification.
output_attentions=False, # Whether the model returns attentions weights.
output_hidden_states=False, # Whether the model returns all hidden-states.
torchscript=True,
)
def forward(self, tokens):
return self.model.forward(tokens)[0]
class SharkHFBenchmarkRunner(SharkBenchmarkRunner):
# SharkBenchmarkRunner-derived class with HF-specific benchmarking capabilities.
def __init__(
self,
model_name: str,
input: tuple,
dynamic: bool = False,
device: str = None,
jit_trace: bool = False,
from_aot: bool = False,
frontend: str = "torch",
):
self.device = device if device is not None else shark_args.device
if self.device == "gpu":
raise ValueError(
"Currently GPU Benchmarking is not supported due to OOM from ORT."
)
self.model_name = model_name
model = HuggingFaceLanguage(model_name)
SharkBenchmarkRunner.__init__(
self,
model,
input,
dynamic,
self.device,
jit_trace,
from_aot,
frontend,
)
def benchmark_torch(self, inputs):
use_gpu = self.device == "gpu"
# Set the model's layer number to automatic.
config_modifier = ConfigModifier(None)
num_threads = psutil.cpu_count(logical=False)
batch_sizes = [inputs.shape[0]]
sequence_lengths = [inputs.shape[-1]]
cache_dir = os.path.join(".", "cache_models")
verbose = False
result = run_pytorch(
use_gpu,
[self.model_name],
None,
config_modifier,
Precision.FLOAT32,
num_threads,
batch_sizes,
sequence_lengths,
shark_args.num_iterations,
False,
cache_dir,
verbose,
)
print(
f"ONNX Pytorch-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
)
# TODO: Currently non-functional due to a TF runtime error. There might be some issue with initializing TF.
def benchmark_tf(self, inputs):
use_gpu = self.device == "gpu"
# Set the model's layer number to automatic.
config_modifier = ConfigModifier(None)
num_threads = psutil.cpu_count(logical=False)
batch_sizes = [inputs.shape[0]]
sequence_lengths = [inputs.shape[-1]]
cache_dir = os.path.join(".", "cache_models")
verbose = False
result = run_tensorflow(
use_gpu,
[self.model_name],
None,
config_modifier,
Precision.FLOAT32,
num_threads,
batch_sizes,
sequence_lengths,
shark_args.num_iterations,
cache_dir,
verbose,
)
print(
f"ONNX TF-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
)
def benchmark_onnx(self, inputs):
if self.model_name not in MODELS:
print(
f"{self.model_name} is currently not supported in ORT's HF. Check \
https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/huggingface_models.py \
for currently supported models. Exiting benchmark ONNX."
)
return
use_gpu = self.device == "gpu"
num_threads = psutil.cpu_count(logical=False)
batch_sizes = [inputs.shape[0]]
sequence_lengths = [inputs.shape[-1]]
cache_dir = os.path.join(".", "cache_models")
onnx_dir = os.path.join(".", "onnx_models")
verbose = False
input_counts = [1]
optimize_onnx = True
validate_onnx = False
disable_ort_io_binding = False
use_raw_attention_mask = True
model_fusion_statistics = {}
overwrite = False
model_source = "pt" # Either "pt" or "tf"
provider = None
config_modifier = ConfigModifier(None)
onnx_args = OnnxFusionOptions()
result = run_onnxruntime(
use_gpu,
provider,
[self.model_name],
None,
config_modifier,
Precision.FLOAT32,
num_threads,
batch_sizes,
sequence_lengths,
shark_args.num_iterations,
input_counts,
optimize_onnx,
validate_onnx,
cache_dir,
onnx_dir,
verbose,
overwrite,
disable_ort_io_binding,
use_raw_attention_mask,
model_fusion_statistics,
model_source,
onnx_args,
)
print(
f"ONNX ORT-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
)

View File

@@ -1,231 +0,0 @@
from shark.shark_inference import SharkInference
from shark.iree_utils._common import check_device_drivers
import torch
import tensorflow as tf
import numpy as np
import torchvision.models as models
from transformers import (
AutoModelForSequenceClassification,
BertTokenizer,
TFBertModel,
)
import importlib
import pytest
import unittest
torch.manual_seed(0)
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
##################### Tensorflow Hugging Face LM Models ###################################
MAX_SEQUENCE_LENGTH = 512
BATCH_SIZE = 1
# Create a set of 2-dimensional inputs
tf_bert_input = [
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
]
class TFHuggingFaceLanguage(tf.Module):
def __init__(self, hf_model_name):
super(TFHuggingFaceLanguage, self).__init__()
# Load the pretrained TF BERT model (converted from the PyTorch weights).
self.m = TFBertModel.from_pretrained(hf_model_name, from_pt=True)
# Wrap the model call in a predict lambda so it always runs in inference mode.
self.m.predict = lambda x, y, z: self.m.call(
input_ids=x, attention_mask=y, token_type_ids=z, training=False
)
@tf.function(input_signature=tf_bert_input, jit_compile=True)
def forward(self, input_ids, attention_mask, token_type_ids):
return self.m.predict(input_ids, attention_mask, token_type_ids)
def get_TFhf_model(name):
model = TFHuggingFaceLanguage(name)
tokenizer = BertTokenizer.from_pretrained(name)
text = "Replace me by any text you'd like."
encoded_input = tokenizer(
text,
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
)
for key in encoded_input:
encoded_input[key] = tf.expand_dims(
tf.convert_to_tensor(encoded_input[key]), 0
)
test_input = (
encoded_input["input_ids"],
encoded_input["attention_mask"],
encoded_input["token_type_ids"],
)
actual_out = model.forward(*test_input)
return model, test_input, actual_out
##################### Hugging Face LM Models ###################################
class HuggingFaceLanguage(torch.nn.Module):
def __init__(self, hf_model_name):
super().__init__()
self.model = AutoModelForSequenceClassification.from_pretrained(
hf_model_name, # The pretrained model.
num_labels=2, # The number of output labels--2 for binary classification.
output_attentions=False, # Whether the model returns attentions weights.
output_hidden_states=False, # Whether the model returns all hidden-states.
torchscript=True,
)
def forward(self, tokens):
return self.model.forward(tokens)[0]
def get_hf_model(name):
model = HuggingFaceLanguage(name)
# TODO: Currently the test input is set to (1,128)
test_input = torch.randint(2, (1, 128))
actual_out = model(test_input)
return model, test_input, actual_out
################################################################################
##################### Torch Vision Models ###################################
class VisionModule(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.train(False)
def forward(self, input):
return self.model.forward(input)
def get_vision_model(torch_model):
model = VisionModule(torch_model)
# TODO: Currently the test input is fixed to (1, 3, 224, 224)
test_input = torch.randn(1, 3, 224, 224)
actual_out = model(test_input)
return model, test_input, actual_out
############################# Benchmark Tests ####################################
pytest_benchmark_param = pytest.mark.parametrize(
("dynamic", "device"),
[
pytest.param(False, "cpu"),
# TODO: Language models are failing for dynamic case..
pytest.param(True, "cpu", marks=pytest.mark.skip),
pytest.param(
False,
"gpu",
marks=pytest.mark.skipif(
check_device_drivers("gpu"), reason="nvidia-smi not found"
),
),
pytest.param(True, "gpu", marks=pytest.mark.skip),
pytest.param(
False,
"vulkan",
marks=pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
),
),
pytest.param(
True,
"vulkan",
marks=pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
),
),
],
)
@pytest.mark.skipif(
importlib.util.find_spec("iree.tools") is None,
reason="Cannot find tools to import TF",
)
@pytest_benchmark_param
def test_bench_minilm_torch(dynamic, device):
model, test_input, act_out = get_hf_model(
"microsoft/MiniLM-L12-H384-uncased"
)
shark_module = SharkInference(
model,
(test_input,),
device=device,
dynamic=dynamic,
jit_trace=True,
benchmark_mode=True,
)
try:
# If benchmarking is successful, assert success/True.
shark_module.compile()
shark_module.benchmark_all((test_input,))
assert True
except Exception as e:
# If anything goes wrong during benchmarking, assert False/failure.
assert False
@pytest.mark.skipif(
importlib.util.find_spec("iree.tools") is None,
reason="Cannot find tools to import TF",
)
@pytest_benchmark_param
def test_bench_distilbert(dynamic, device):
model, test_input, act_out = get_TFhf_model("distilbert-base-uncased")
shark_module = SharkInference(
model,
test_input,
device=device,
dynamic=dynamic,
jit_trace=True,
benchmark_mode=True,
)
try:
# If benchmarking is successful, assert success/True.
shark_module.set_frontend("tensorflow")
shark_module.compile()
shark_module.benchmark_all(test_input)
assert True
except Exception as e:
# If anything goes wrong during benchmarking, assert False/failure.
assert False
@pytest.mark.skip(reason="XLM Roberta too large to test.")
@pytest_benchmark_param
def test_bench_xlm_roberta(dynamic, device):
model, test_input, act_out = get_TFhf_model("xlm-roberta-base")
shark_module = SharkInference(
model,
test_input,
device=device,
dynamic=dynamic,
jit_trace=True,
benchmark_mode=True,
)
try:
# If benchmarking is successful, assert success/True.
shark_module.set_frontend("tensorflow")
shark_module.compile()
shark_module.benchmark_all(test_input)
assert True
except Exception as e:
# If anything goes wrong during benchmarking, assert False/failure.
assert False

View File

@@ -1,45 +0,0 @@
import torch
from benchmarks.hf_transformer import SharkHFBenchmarkRunner
import importlib
import pytest
torch.manual_seed(0)
############################# HF Benchmark Tests ####################################
# Test running benchmark module without failing.
pytest_benchmark_param = pytest.mark.parametrize(
("dynamic", "device"),
[
pytest.param(False, "cpu"),
# TODO: Language models are failing for dynamic case..
pytest.param(True, "cpu", marks=pytest.mark.skip),
],
)
@pytest.mark.skipif(
importlib.util.find_spec("onnxruntime") is None,
reason="Cannot find ONNXRUNTIME.",
)
@pytest_benchmark_param
def test_HFbench_minilm_torch(dynamic, device):
model_name = "bert-base-uncased"
test_input = torch.randint(2, (1, 128))
try:
shark_module = SharkHFBenchmarkRunner(
model_name,
(test_input,),
jit_trace=True,
dynamic=dynamic,
device=device,
)
shark_module.benchmark_c()
shark_module.benchmark_python((test_input,))
shark_module.benchmark_torch(test_input)
shark_module.benchmark_onnx(test_input)
# If benchmarking is successful, assert success/True.
assert True
except Exception as e:
# If anything goes wrong during benchmarking, assert False/failure.
assert False

Some files were not shown because too many files have changed in this diff.