Compare commits

...

34 Commits

Author SHA1 Message Date
saienduri
4529fd0461 Update requirements.txt 2024-08-06 19:29:40 -07:00
saienduri
4c2bb4b7b4 Update requirements.txt 2024-08-06 17:15:42 -07:00
saienduri
d5013fd13e Update requirements.txt (#2157) 2024-06-18 13:41:35 -07:00
Ean Garvey
26f80ccbbb Fixes to UI config defaults, config loading, and warnings. (#2153) 2024-05-31 18:14:27 -04:00
Ean Garvey
d2c3752dc7 Fix batch count and tweaks to chatbot. (#2151)
* Fix batch count

* Add button to unload models manually.

* Add compiled pipeline option

* Add brevitas to requirements

* Tweaks to chatbot

* Change script loading trigger
2024-05-31 18:48:28 +05:30
Ean Garvey
4505c4549f Force inlined weights on igpu for now, small fixes to chatbot (#2149)
* Add igpu and custom triple support.

* Small fixes to igpu, SDXL-turbo

* custom pipe loading

* formatting

* Remove old nodlogo import.
2024-05-30 11:40:42 -05:00
Gaurav Shukla
793495c9c6 [ui] Add AMD logo in shark studio
Signed-Off-by: Gaurav Shukla <gaurav.shukla@amd.com>
2024-05-30 21:43:15 +05:30
Ean Garvey
13e1d8d98a Add igpu and custom triple support. (#2148) 2024-05-29 17:39:36 -05:00
Ean Garvey
2074df40ad Point to nod fork of diffusers. (#2146) 2024-05-29 00:56:21 -05:00
Ean Garvey
7b30582408 Point to SRT links for windows. (#2145) 2024-05-29 01:20:30 -04:00
Ean Garvey
151195ab74 Add a few requirements for ensured parity with turbine-models requirements. (#2142)
* Add scipy to requirements.

Adds diffusers req and a note for torchsde.
2024-05-28 15:37:31 -05:00
Ean Garvey
8146f0bd2f Remove leftover merge conflict line from setup script. (#2141) 2024-05-28 11:04:45 -07:00
Ean Garvey
68e9281778 (Studio2) Refactors SD pipeline to rely on turbine-models pipeline, fixes to LLM, gitignore (#2129)
* Shark Studio SDXL support, HIP driver support, simpler device info, small fixes

* Fixups to llm API/UI and ignore user config files.

* Small fixes for unifying pipelines.

* Update requirements.txt for iree-turbine (#2130)

* Fix Llama2 on CPU (#2133)

* Filesystem cleanup and custom model fixes (#2127)

* Fix some formatting issues

* Remove IREE pin (fixes exe issue) (#2126)

* Update find links for IREE packages (#2136)

* Shark Studio SDXL support, HIP driver support, simpler device info, small fixes

* Abstract out SD pipelines from Studio Webui (WIP)

* Switch from pin to minimum torch version and fix index url

* Fix device parsing.

* Fix linux setup

* Fix custom weights.

---------

Co-authored-by: saienduri <77521230+saienduri@users.noreply.github.com>
Co-authored-by: gpetters-amd <159576198+gpetters-amd@users.noreply.github.com>
Co-authored-by: gpetters94 <gpetters@protonmail.com>
2024-05-28 13:18:31 -04:00
Ean Garvey
fd07cae991 Update find links for IREE packages (#2136) 2024-05-13 11:43:17 -05:00
gpetters94
6cb86a843e Remove IREE pin (fixes exe issue) (#2126)
* Diagnose a build issue

* Remove IREE pin

* Revert the build on pull request change
2024-04-30 12:27:30 -05:00
gpetters-amd
7db1612a5c Filesystem cleanup and custom model fixes (#2127)
* Initial filesystem cleanup

* More filesystem cleanup

* Fix some formatting issues

* Address comments
2024-04-30 11:18:33 -05:00
gpetters-amd
81d6e059ac Fix Llama2 on CPU (#2133) 2024-04-29 12:18:16 -05:00
saienduri
e003d0abe8 Update requirements.txt for iree-turbine (#2130)
* Update requirements.txt to iree-turbine creation

* Update requirements.txt

* Update requirements.txt

* Update requirements.txt
2024-04-29 12:28:14 -04:00
Quinn Dawkins
cf2513e7b1 Update IREE discord link (#2118)
Discord links for IREE were purged, so update the link on the readme.
2024-04-15 12:54:27 -07:00
Ean Garvey
60d8591e95 Change shark-turbine requirement target branch to main. (#2116) 2024-04-11 19:31:39 -04:00
gpetters-amd
ff91982168 Remove target env (#2114) 2024-04-08 16:52:45 -05:00
powderluv
a6a9e524c1 Drop linux nightly for now 2024-04-05 12:04:36 -07:00
powderluv
732df2e263 Updated signtool key 2024-04-05 12:01:42 -07:00
gpetters-amd
1ee16bd256 Fix the nightly build (#2111) 2024-04-05 19:22:33 +05:30
gpetters-amd
752d775fbd Fix a typo in the nightly build script (#2110) 2024-03-30 17:31:51 -07:00
gpetters-amd
4d1a6a204d Fix builder issue (#2109) 2024-03-30 16:21:55 -07:00
Ean Garvey
0eff62a468 (Studio 2.0) add Stable Diffusion features (#2037)
* (WIP): Studio2 app infra and SD API

UI/app structure and utility implementation.

- Initializers for webui/API launch
- Schedulers file for SD scheduling utilities
- Additions to API-level utilities
- Added embeddings module for LoRA, Lycoris, yada yada
- Added image_processing module for resamplers, resize tools,
  transforms, and any image annotation (PNG metadata)
- shared_cmd_opts module -- sorry, this is stable_args.py. It lives on.
  We still want to have some global control over the app exclusively
  from the command-line. At least we will be free from shark_args.
- Moving around some utility pieces.
- Try to make api+webui concurrency possible in index.py
- SD UI -- this is just img2imgUI but hopefully a little better.
- UI utilities for your nod logos and your gradio temps.

Enable UI / bugfixes / tweaks

* Studio2/SD: Use more correct LoRA alpha calculation (#2034)

* Updates ProcessLoRA to use both embedded LoRA alpha, and lora_strength
optional parameter (default 1.0) when applying LoRA weights.
* Updates ProcessLoRA to cover more dim cases.
* This bring ProcessLoRA into line with PR #2015 against Studio1

* Studio2: Remove duplications from api/utils.py (#2035)

* Remove duplicate os import
* Remove duplicate parse_seed_input function

Migrating to JSON requests in SD UI

More UI and app flow improvements, logging, shared device cache

Model loading

Complete SD pipeline.

Tweaks to VAE, pipeline states

Pipeline tweaks, add cmd_opts parsing to sd api

* Add test for SD

* Small cleanup

* Shark2/SD/UI: Respect ckpt_dir, share and server_port args (#2070)

* Takes whether to generate a gradio live link from the existing --share command
line parameter, rather than hardcoding as True.
* Takes server port from existing --server_port command line parameter, rather than
hardcoding as 11911.
* Default --ckpt_dir parameter to '../models'
* Use --ckpt_dir rather than hardcoding ../models as the base directory for
checkpoints, vae, and lora, etc
* Add a 'checkpoints' directory below --ckpt_dir to match ComfyUI folder structure.
Read custom_weights choices from there, and/or subfolders below there matching
the selected base model.
* Fix --ckpt_dir possibly not working correctly when an absolute rather than relative path
is specified.
* Relabel "Custom Weights" to "Custom Weights Checkpoint" in the UI

* Add StreamingLLM support to studio2 chat (#2060)

* Streaming LLM

* Update precision and add gpu support

* (studio2) Separate weights generation for quantization support

* Adapt prompt changes to studio flow

* Remove outdated flag from llm compile flags.

* (studio2) use turbine vmfbRunner

* tweaks to prompts

* Update CPU path and llm api test.

* Change device in test to cpu.

* Fixes to runner, device names, vmfb mgmt

* Use small test without external weights.

* HF-Reference LLM mode + Update test result to match latest Turbine. (#2080)

* HF-Reference LLM mode.

* Fixup test to match current output from Turbine.

* lint

* Fix test error message + Only initialize HF torch model when used.

* Remove redundant format_out change.

* Add rest API endpoint from LanguageModel API

* Add StreamingLLM support to studio2 chat (#2060)

* Streaming LLM

* Update precision and add gpu support

* (studio2) Separate weights generation for quantization support

* Adapt prompt changes to studio flow

* Remove outdated flag from llm compile flags.

* (studio2) use turbine vmfbRunner

* tweaks to prompts

* Update CPU path and llm api test.

* Change device in test to cpu.

* Fixes to runner, device names, vmfb mgmt

* Use small test without external weights.

* Formatting and init files.

* Remove unused import.

* Small fixes

* Studio2/SD/UI: Improve various parts of the UI for Stable Diffusion (#2074)

* Studio2/SD/UI: Improve various parts of the UI of Shark 2

* Update Gradio pin to 4.15.0.
* Port workarounds for Gradio >4.8.0 main container sizing from Shark 1.0.
* Move nod Logo out of the SD tab and onto the top right of the main tab bar.
* Set nod logo icon as the favicon (as current Shark 1.0).
* Create a tabbed right hand panel within the SD UI sized to the viewport height.
* Make Input Image tab 1 in the right hand panel.
* Make output images, generation log, and  generation buttons, tab 2 in the
right hand panel
* Make config JSON display, with config load, save and clear, tab 3 in the
right hand panel
* Make gallery  area of the Output tab take up all vertical space the other controls
on the tab do not.
* Tidy up the controls on the Config tab somewhat.

* Studio2/SD/UI: Reorganise inputs on Left Panel of SD tab

* Rename previously added Right Panel Output tab to 'Generate'.
* Move Batch Count, Batch Size, and Repeatable Seeds, off of Left Panel and onto 'Generate' Tab.
* On 'Generate' tab, rename 'Generate Image(s)' button to 'Start', and 'Stop Batch' button to 'Stop'. They are now below the Batch inputs on a Generate tab so don't need the specificity.
* Move Device, Low VRAM, and Precision inputs into their own 'Device Settings' Accordion control. (starts closed)
* Rename 'Custom Weights Checkpoint' to 'Checkpoint Weights'
* Move Checkpoint Weights, VAE Model, Standalone Lora Weights, and Embeddings Options controls, into their own 'Model Weights' Accordion control.  (starts closed)
* Move Denoising Strength, and Resample Type controls into their own 'Input Image Processing' Accordion. (starts closed)
* Move any remaining controls in the 'Advanced Options' Accorion directly onto the left panel, and remove then Accordion.
* Enable the copy button for all text boxes on the SD tab.
* Add emoji/unicode glphs to all top level controls and Accordions on the SD Left Panel.
* Start with the 'Generate' as the initially selected tab in the SD Right Panel, working around Gradio issue #7805
* Tweaks to SD Right Tab Panel vertical height.

* Studio2/SD/UI: Sizing tweaks for Right Panel, and >1920 width

* Set height of right panel using vmin rather than vh, with explicit affordances
for fixed areas above and below.
* Port >1920 width Gradio >4.8 CSS workaround from Shark 1.0.

* Studio2/SD: Fix sd pipeline up to "Windows not supported" (#2082)

* Studio2/SD: Fix sd pipeline up to "Windows not supported"

A number of fixes to the SD pipeline as run from the UI, up until the point that dynamo
complains "Windows not yet supported for torch.compile".

* Remove separate install of iree-runtime and iree-compile in setup_venv.ps1, and rely on the
versions installed via the Turbine requirements.txt. Fixes #2063 for me.
* Replace any "None" strings with python None when pulling the config in the UI.
* Add 'hf_auth_token' param to api StableDiffusion class, defaulting to None, and then pass
that in to the various Models where it is required and wasn't already being done before.
* Fix clip custom_weight_params being passed to export_clip_model as "external_weight_file"
rather than "external_weights"
* Don't pass non-existing "custom_vae" parameter to the Turbine Vae Model, instead
pass custom_vae as the "hf_model_id" if it is set. (this may be wrong in the custom vae
cast, but stops the code *always* breaking).

* Studio2/SD/UI: Improve UI config None handling

* When populating the UI from a JSON Config set controls to "None" for null/None
values.
* When generating a JSON Config from the UI set props to null/None for controls
set to "None".
* Use null rather string 'None' in the default config

---------

Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>

* Studio2/SD/UI: Further sd ui pipeline fixes (#2091)

On Windows, this gets us all the way failing in iree compile of the with SD 2.1 base.

- Fix merge errors with sd right pane config UI tab.
- Remove non-requirement.txt install/build of torch/mlir/iree/SRT in setup_venv.ps1, fixing "torch.compile not supported on Windows" error.
- Fix gradio deprecation warning for `root=` FileExplorer kwarg.
- Comment out `precision` and `max_length` kwargs being passed to unet, as not yet supported on main Turbine branch. Avoids keyword argument error.

* Tweak compile-time flags for SD submodels.

* Small fixes to sd, pin mpmath

* Add pyinstaller spec and imports script.

* Fix the .exe (#2101)

* Fix _IREE_TARGET_MAP (#2103) (#2108)

- Change target passed to iree for vulkan from 'vulkan'
to 'vulkan-spriv', as 'vulkan' is not a valid value for
--iree-hal-target-backends with the current iree compiler.

Co-authored-by: Stefan Kapusniak <121311569+one-lithe-rune@users.noreply.github.com>

* Cleanup sd model map.

* Update dependencies.

* Studio2/SD/UI: Update gradio to 4.19.2 (sd-studio2) (#2097)

- Move pin for gradio from 4.15 -> 4.19.2 on the sd-studio2 branch

* fix formatting and disable explicit vulkan env settings.

---------

Co-authored-by: Stefan Kapusniak <121311569+one-lithe-rune@users.noreply.github.com>
Co-authored-by: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com>
Co-authored-by: gpetters-amd <159576198+gpetters-amd@users.noreply.github.com>
Co-authored-by: gpetters94 <gpetters@protonmail.com>
2024-03-29 18:13:21 -04:00
dependabot[bot]
5a5de545c9 Bump gradio from 3.34.0 to 4.19.2 in /dataset (#2093)
Bumps [gradio](https://github.com/gradio-app/gradio) from 3.34.0 to 4.19.2.
- [Release notes](https://github.com/gradio-app/gradio/releases)
- [Changelog](https://github.com/gradio-app/gradio/blob/main/CHANGELOG.md)
- [Commits](https://github.com/gradio-app/gradio/compare/v3.34.0...gradio@4.19.2)

---
updated-dependencies:
- dependency-name: gradio
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
2024-03-28 10:01:26 -05:00
Stefan Kapusniak
58f194a450 Fix _IREE_TARGET_MAP (#2103)
- Change target passed to iree for vulkan from 'vulkan'
to 'vulkan-spriv', as 'vulkan' is not a valid value for
--iree-hal-target-backends with the current iree compiler.
2024-03-18 00:21:44 -05:00
Stefan Kapusniak
c5cf005292 Add *.safetensors to .gitignore (#2089)
- shark2 is putting base model .safetensors file in model specific subfolders. Easiest to just ignore
.safetensors completely.
2024-02-17 21:37:47 -06:00
Stefan Kapusniak
12094ec49c Update README to recommend SHARK-1.0 for now (#2087)
- Add a note at the of the top of the README to use SHARK-1.0 whilst
rewrite with Turbine is ongoing.
- Update installation section to point at SHARK-1.0 branch.
- Suggest the latest SHARK-1.0 pre-release as well as stable.
- Recommend running the .exe from the command line.
2024-02-08 14:22:27 -06:00
Daniel Garvey
100e5b8244 address refactor in turbine (#2086)
python/turbine_models -> models
shark-turbine -> core
2024-02-05 13:05:01 -08:00
Stanley Winata
6bf51f1f1d HF-Reference LLM mode + Update test result to match latest Turbine. (#2080)
* HF-Reference LLM mode.

* Fixup test to match current output from Turbine.

* lint

* Fix test error message + Only initialize HF torch model when used.

* Remove redundant format_out change.
2024-02-01 11:46:22 -06:00
Ean Garvey
05b498267e Add StreamingLLM support to studio2 chat (#2060)
* Streaming LLM 

* Update precision and add gpu support

* (studio2) Separate weights generation for quantization support

* Adapt prompt changes to studio flow

* Remove outdated flag from llm compile flags.

* (studio2) use turbine vmfbRunner

* tweaks to prompts

* Update CPU path and llm api test.

* Change device in test to cpu.

* Fixes to runner, device names, vmfb mgmt

* Use small test without external weights.
2024-01-18 19:01:07 -06:00
65 changed files with 7338 additions and 987 deletions

View File

@@ -50,10 +50,11 @@ jobs:
shell: powershell
run: |
./setup_venv.ps1
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip install -e .
pip freeze -l
pyinstaller .\apps\shark_studio\shark_studio.spec
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
@@ -74,80 +75,3 @@ jobs:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
linux-build:
runs-on: a100
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]
backend: [IREE, SHARK]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Setup pip cache
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
run: |
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude shark.venv,lit.cfg.py
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
- name: Build and validate the IREE package
if: ${{ matrix.backend == 'IREE' }}
continue-on-error: true
run: |
cd $GITHUB_WORKSPACE
USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
source iree.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://openxla.github.io/iree/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
/bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt
if !(grep -Fxq " failed" pytest_results.txt)
then
export SHA=$(git log -1 --format='%h')
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
fi
rm -rf ./wheelhouse/nodai*
- name: Build and validate the SHARK Runtime package
if: ${{ matrix.backend == 'SHARK' }}
run: |
cd $GITHUB_WORKSPACE
./setup_venv.sh
source shark.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
pytest --ci --ci_sha=${SHORT_SHA} -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt

View File

@@ -81,6 +81,5 @@ jobs:
source shark.venv/bin/activate
pip install -r requirements.txt --no-cache-dir
pip install -e .
pip uninstall -y torch
pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
python apps/shark_studio/tests/api_test.py
# Disabled due to hang when exporting test llama2
# python apps/shark_studio/tests/api_test.py

9
.gitignore vendored
View File

@@ -164,7 +164,7 @@ cython_debug/
# vscode related
.vscode
# Shark related artefacts
# Shark related artifacts
*venv/
shark_tmp/
*.vmfb
@@ -172,6 +172,7 @@ shark_tmp/
tank/dict_configs.py
*.csv
reproducers/
apps/shark_studio/web/configs
# ORT related artefacts
cache_models/
@@ -183,10 +184,16 @@ generated_imgs/
# Custom model related artefacts
variants.json
/models/
*.safetensors
# models folder
apps/stable_diffusion/web/models/
# model artifacts (SHARK)
*.tempfile
*.mlir
*.vmfb
# Stencil annotators.
stencil_annotator/

View File

@@ -2,18 +2,20 @@
High Performance Machine Learning Distribution
*We are currently rebuilding SHARK to take advantage of [Turbine](https://github.com/nod-ai/SHARK-Turbine). Until that is complete make sure you use an .exe release or a checkout of the `SHARK-1.0` branch, for a working SHARK*
[![Nightly Release](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml/badge.svg)](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml)
[![Validate torch-models on Shark Runtime](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml/badge.svg)](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml)
<details>
<summary>Prerequisites - Drivers </summary>
#### Install your Windows hardware drivers
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
#### Linux Drivers
* MESA / RADV drivers wont work with FP16. Please use the latest AMGPU-PRO drivers (non-pro OSS drivers also wont work) or the latest NVidia Linux Drivers.
@@ -22,23 +24,23 @@ Other users please ensure you have your latest vendor drivers and Vulkan SDK fro
</details>
### Quick Start for SHARK Stable Diffusion for Windows 10/11 Users
Install the Driver from [Prerequisites](https://github.com/nod-ai/SHARK#install-your-hardware-drivers) above
Install the Driver from (Prerequisites)[https://github.com/nod-ai/SHARK#install-your-hardware-drivers] above
Download the [stable release](https://github.com/nod-ai/shark/releases/latest)
Download the [stable release](https://github.com/nod-ai/shark/releases/latest) or the most recent [SHARK 1.0 pre-release](https://github.com/nod-ai/shark/releases).
Double click the .exe and you should have the [UI](http://localhost:8080/) in the browser.
Double click the .exe, or [run from the command line](#running) (recommended), and you should have the [UI](http://localhost:8080/) in the browser.
If you have custom models put them in a `models/` directory where the .exe is.
If you have custom models put them in a `models/` directory where the .exe is.
Enjoy.
Enjoy.
<details>
<summary>More installation notes</summary>
* We recommend that you download EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files with `rm *.vmfb`. You can also use `--clear_all` flag once to clean all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you clear all the local artifacts with `--clear_all`
* We recommend that you download EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files with `rm *.vmfb`. You can also use `--clear_all` flag once to clean all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you clear all the local artifacts with `--clear_all`
## Running
@@ -46,17 +48,22 @@ Enjoy.
* The first run may take few minutes when the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
* You will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/.
* If you prefer to always run in the browser, use the `--ui=web` command argument when running the EXE.
## Stopping
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment or close the terminal.
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment or close the terminal.
</details>
<details>
<summary>Advanced Installation (Only for developers)</summary>
## Advanced Installation (Windows, Linux and macOS) for developers
### Windows 10/11 Users
* Install Git for Windows from [here](https://git-scm.com/download/win) if you don't already have it.
## Check out the code
```shell
@@ -64,14 +71,22 @@ git clone https://github.com/nod-ai/SHARK.git
cd SHARK
```
## Switch to the Correct Branch (IMPORTANT!)
Currently SHARK is being rebuilt for [Turbine](https://github.com/nod-ai/SHARK-Turbine) on the `main` branch. For now you are strongly discouraged from using `main` unless you are working on the rebuild effort, and should not expect the code there to produce a working application for Image Generation, So for now you'll need switch over to the `SHARK-1.0` branch and use the stable code.
```shell
git checkout SHARK-1.0
```
The following setup instructions assume you are on this branch.
## Setup your Python VirtualEnvironment and Dependencies
### Windows 10/11 Users
* Install the latest Python 3.11.x version from [here](https://www.python.org/downloads/windows/)
* Install Git for Windows from [here](https://git-scm.com/download/win)
#### Allow the install script to run in Powershell
```powershell
set-executionpolicy remotesigned
@@ -86,21 +101,20 @@ set-executionpolicy remotesigned
```shell
./setup_venv.sh
source shark.venv/bin/activate
source shark1.venv/bin/activate
```
### Run Stable Diffusion on your device - WebUI
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> cd .\apps\stable_diffusion\web\
(shark.venv) PS C:\g\shark\apps\stable_diffusion\web> python .\index.py
(shark1.venv) PS C:\g\shark> cd .\apps\stable_diffusion\web\
(shark1.venv) PS C:\g\shark\apps\stable_diffusion\web> python .\index.py
```
#### Linux / macOS Users
```shell
(shark.venv) > cd apps/stable_diffusion/web
(shark.venv) > python index.py
(shark1.venv) > cd apps/stable_diffusion/web
(shark1.venv) > python index.py
```
#### Access Stable Diffusion on http://localhost:8080/?__theme=dark
@@ -114,7 +128,7 @@ source shark.venv/bin/activate
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
(shark1.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
#### Linux / macOS Users
@@ -142,7 +156,7 @@ Here are some samples generated:
![a photo of a crab playing a trumpet](https://user-images.githubusercontent.com/74956/204933258-252e7240-8548-45f7-8253-97647d38313d.jpg)
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
<details>
@@ -205,7 +219,7 @@ python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
If you want to use Python3.11 and with TF Import tools you can use the environment variables like:
Set `USE_IREE=1` to use upstream IREE
```
# PYTHON=python3.11 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
# PYTHON=python3.11 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
```
### Run any of the hundreds of SHARK tank models via the test framework
@@ -214,7 +228,7 @@ python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use g
# Or a pytest
pytest tank/test_models.py -k "MiniLM"
```
### How to use your locally built IREE / Torch-MLIR with SHARK
If you are a *Torch-mlir developer or an IREE developer* and want to test local changes you can uninstall
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
@@ -240,12 +254,12 @@ Now the SHARK will use your locally build Torch-MLIR repo.
## Benchmarking Dispatches
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line argument.
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line argument.
If you only want to compile specific dispatches, you can specify them with a space seperated string instead of `"All"`. E.G. `--dispatch_benchmarks="0 1 2 10"`
For example, to generate and run dispatch benchmarks for MiniLM on CUDA:
```
pytest -k "MiniLM and torch and static and cuda" --benchmark_dispatches=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
pytest -k "MiniLM and torch and static and cuda" --benchmark_dispatches=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
```
The given command will populate `<dispatch_benchmarks_dir>/<model_name>/` with an `ordered_dispatches.txt` that lists and orders the dispatches and their latencies, as well as folders for each dispatch that contain .mlir, .vmfb, and results of the benchmark for that dispatch.
@@ -264,7 +278,7 @@ shark_module = SharkInference(
Output will include:
- An ordered list ordered-dispatches.txt of all the dispatches with their runtime
- Inside the specified directory, there will be a directory for each dispatch (there will be mlir files for all dispatches, but only compiled binaries and benchmark data for the specified dispatches)
- An .mlir file containing the dispatch benchmark
- An .mlir file containing the dispatch benchmark
- A compiled .vmfb file containing the dispatch benchmark
- An .mlir file containing just the hal executable
- A compiled .vmfb file of the hal executable
@@ -332,7 +346,7 @@ result = shark_module.forward((arg0, arg1))
## Supported and Validated Models
SHARK is maintained to support the latest innovations in ML Models:
SHARK is maintained to support the latest innovations in ML Models:
| TF HuggingFace Models | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---------------------|----------|----------|-------------|
@@ -358,7 +372,7 @@ For a complete list of the models supported in SHARK, please refer to [tank/READ
* [Upstream IREE issues](https://github.com/google/iree/issues): Feature requests,
bugs, and other work tracking
* [Upstream IREE Discord server](https://discord.gg/26P4xW4): Daily development
* [Upstream IREE Discord server](https://discord.gg/wEWh6Z9nMU): Daily development
discussions with the core team and collaborators
* [iree-discuss email list](https://groups.google.com/forum/#!forum/iree-discuss):
Announcements, general and low-priority discussion
@@ -373,7 +387,7 @@ For a complete list of the models supported in SHARK, please refer to [tank/READ
* Weekly meetings on Mondays 9AM PST. See [here](https://discourse.llvm.org/t/community-meeting-developer-hour-refactoring-recurring-meetings/62575) for more information.
* [MLIR topic within LLVM Discourse](https://llvm.discourse.group/c/llvm-project/mlir/31) SHARK and IREE is enabled by and heavily relies on [MLIR](https://mlir.llvm.org).
</details>
## License
nod.ai SHARK is licensed under the terms of the Apache 2.0 License with LLVM Exceptions.

View File

@@ -0,0 +1,107 @@
# from turbine_models.custom_models.controlnet import control_adapter, preprocessors
import os
import PIL
import numpy as np
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
)
from datetime import datetime
from PIL import Image
from gradio.components.image_editor import (
EditorValue,
)
class control_adapter:
def __init__(
self,
model: str,
):
self.model = None
def export_control_adapter_model(model_keyword):
return None
def export_xl_control_adapter_model(model_keyword):
return None
class preprocessors:
def __init__(
self,
model: str,
):
self.model = None
def export_controlnet_model(model_keyword):
return None
control_adapter_map = {
"sd15": {
"canny": {"initializer": control_adapter.export_control_adapter_model},
"openpose": {"initializer": control_adapter.export_control_adapter_model},
"scribble": {"initializer": control_adapter.export_control_adapter_model},
"zoedepth": {"initializer": control_adapter.export_control_adapter_model},
},
"sdxl": {
"canny": {"initializer": control_adapter.export_xl_control_adapter_model},
},
}
preprocessor_model_map = {
"canny": {"initializer": preprocessors.export_controlnet_model},
"openpose": {"initializer": preprocessors.export_controlnet_model},
"scribble": {"initializer": preprocessors.export_controlnet_model},
"zoedepth": {"initializer": preprocessors.export_controlnet_model},
}
class PreprocessorModel:
def __init__(
self,
hf_model_id,
device="cpu",
):
self.model = hf_model_id
self.device = device
def compile(self):
print("compile not implemented for preprocessor.")
return
def run(self, inputs):
print("run not implemented for preprocessor.")
return inputs
def cnet_preview(model, input_image):
curr_datetime = datetime.now().strftime("%Y-%m-%d.%H-%M-%S")
control_imgs_path = os.path.join(get_generated_imgs_path(), "control_hints")
if not os.path.exists(control_imgs_path):
os.mkdir(control_imgs_path)
img_dest = os.path.join(control_imgs_path, model + curr_datetime + ".png")
match model:
case "canny":
canny = PreprocessorModel("canny")
result = canny(
np.array(input_image),
100,
200,
)
Image.fromarray(result).save(fp=img_dest)
return result, img_dest
case "openpose":
openpose = PreprocessorModel("openpose")
result = openpose(np.array(input_image))
Image.fromarray(result[0]).save(fp=img_dest)
return result, img_dest
case "zoedepth":
zoedepth = PreprocessorModel("ZoeDepth")
result = zoedepth(np.array(input_image))
Image.fromarray(result).save(fp=img_dest)
return result, img_dest
case "scribble":
input_image.save(fp=img_dest)
return input_image, img_dest
case _:
return None, None

View File

@@ -0,0 +1,125 @@
import importlib
import os
import signal
import sys
import warnings
import json
from threading import Thread
from apps.shark_studio.modules.timer import startup_timer
from apps.shark_studio.web.utils.tmp_configs import (
config_tmp,
clear_tmp_mlir,
clear_tmp_imgs,
shark_tmp,
)
def imports():
import torch # noqa: F401
startup_timer.record("import torch")
warnings.filterwarnings(
action="ignore", category=DeprecationWarning, module="torch"
)
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")
import gradio # noqa: F401
startup_timer.record("import gradio")
import apps.shark_studio.web.utils.globals as global_obj
global_obj._init()
startup_timer.record("initialize globals")
from apps.shark_studio.modules import (
img_processing,
) # noqa: F401
startup_timer.record("other imports")
def initialize():
configure_sigint_handler()
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
config_tmp()
# clear_tmp_mlir()
clear_tmp_imgs()
from apps.shark_studio.web.utils.file_utils import (
create_model_folders,
)
# Create custom models folders if they don't exist
create_model_folders()
import gradio as gr
# initialize_rest(reload_script_modules=False)
def initialize_rest(*, reload_script_modules=False):
"""
Called both from initialize() and when reloading the webui.
"""
# Keep this for adding reload options to the webUI.
def dumpstacks():
import threading
import traceback
id2name = {th.ident: th.name for th in threading.enumerate()}
code = []
for threadId, stack in sys._current_frames().items():
code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
for filename, lineno, name, line in traceback.extract_stack(stack):
code.append(f"""File: "{filename}", line {lineno}, in {name}""")
if line:
code.append(" " + line.strip())
with open(os.path.join(shark_tmp, "stack_dump.log"), "w") as f:
f.write("\n".join(code))
def setup_middleware(app):
from starlette.middleware.gzip import GZipMiddleware
app.middleware_stack = (
None # reset current middleware to allow modifying user provided list
)
app.add_middleware(GZipMiddleware, minimum_size=1000)
configure_cors_middleware(app)
app.build_middleware_stack() # rebuild middleware stack on-the-fly
def configure_cors_middleware(app):
from starlette.middleware.cors import CORSMiddleware
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
cors_options = {
"allow_methods": ["*"],
"allow_headers": ["*"],
"allow_credentials": True,
}
if cmd_opts.api_accept_origin:
cors_options["allow_origins"] = cmd_opts.api_accept_origin.split(",")
app.add_middleware(CORSMiddleware, **cors_options)
def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f"Interrupted with signal {sig} in {frame}")
dumpstacks()
os._exit(0)
signal.signal(signal.SIGINT, sigint_handler)

View File

@@ -1,21 +1,27 @@
from turbine_models.custom_models import stateless_llama
from turbine_models.model_runner import vmfbRunner
from turbine_models.gen_external_params.gen_external_params import gen_external_params
import time
from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
load_vmfb_using_mmap,
from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
from apps.shark_studio.web.utils.file_utils import (
get_resource_path,
get_checkpoints_path,
)
from apps.shark_studio.api.utils import get_resource_path
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.api.utils import parse_device
from urllib.request import urlopen
import iree.runtime as ireert
from itertools import chain
import gc
import os
import torch
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
llm_model_map = {
"llama2_7b": {
"meta-llama/Llama-2-7b-chat-hf": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
@@ -23,12 +29,34 @@ llm_model_map = {
"Trelis/Llama-2-7b-chat-hf-function-calling-v2": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
"TinyPixel/small-llama2": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "TinyPixel/small-llama2",
"compile_flags": ["--iree-opt-const-expr-hoisting=True"],
"stop_token": 2,
"max_tokens": 1024,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
}
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<s>", "</s>"
DEFAULT_CHAT_SYS_PROMPT = """<s>[INST] <<SYS>>
Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <</SYS>>\n\n
"""
def append_user_prompt(history, input_prompt):
user_prompt = f"{B_INST} {input_prompt} {E_INST}"
history += user_prompt
return history
class LanguageModel:
def __init__(
@@ -36,53 +64,114 @@ class LanguageModel:
model_name,
hf_auth_token=None,
device=None,
precision="fp32",
quantization="int4",
precision="",
external_weights=None,
use_system_prompt=True,
streaming_llm=False,
):
print(llm_model_map[model_name])
_, _, self.triple = parse_device(device)
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.tempfile_name = get_resource_path("llm.torch.tempfile")
self.vmfb_name = get_resource_path("llm.vmfb.tempfile")
self.device = device
self.precision = precision
self.safe_name = self.hf_model_name.strip("/").replace("/", "_")
self.device = device.split("=>")[-1].strip()
self.backend = self.device.split("://")[0]
self.driver = self.backend
if "cpu" in device:
self.device = "cpu"
self.backend = "llvm-cpu"
self.driver = "local-task"
print(f"Selected {self.backend} as IREE target backend.")
self.precision = "f32" if "cpu" in device else "f16"
self.quantization = quantization
self.safe_name = self.hf_model_name.replace("/", "_").replace("-", "_")
self.external_weight_file = None
# TODO: find a programmatic solution for model arch spec instead of hardcoding llama2
self.file_spec = "_".join(
[
self.safe_name,
self.precision,
]
)
if self.quantization != "None":
self.file_spec += "_" + self.quantization
if external_weights in ["safetensors", "gguf"]:
self.external_weight_file = get_resource_path(
os.path.join("..", self.file_spec + "." + external_weights)
)
else:
self.external_weights = None
self.external_weight_file = None
if streaming_llm:
# Add streaming suffix to file spec after setting external weights filename.
self.file_spec += "_streaming"
self.streaming_llm = streaming_llm
self.tempfile_name = get_resource_path(
os.path.join("..", f"{self.file_spec}.tempfile")
)
# TODO: Tag vmfb with target triple of device instead of HAL backend
self.vmfb_name = str(
get_resource_path(
os.path.join("..", f"{self.file_spec}_{self.backend}.vmfb.tempfile")
)
)
self.max_tokens = llm_model_map[model_name]["max_tokens"]
self.iree_module_dict = None
self.external_weight_file = None
if external_weights is not None:
self.external_weight_file = get_resource_path(
self.safe_name + "." + external_weights
)
self.use_system_prompt = use_system_prompt
self.global_iter = 0
self.prev_token_len = 0
self.first_input = True
self.hf_auth_token = hf_auth_token
if self.external_weight_file is not None:
if not os.path.exists(self.external_weight_file):
print(
f"External weight file {self.external_weight_file} does not exist. Generating..."
)
gen_external_params(
hf_model_name=self.hf_model_name,
quantization=self.quantization,
weight_path=self.external_weight_file,
hf_auth_token=hf_auth_token,
precision=self.precision,
)
else:
print(
f"External weight file {self.external_weight_file} found for {self.vmfb_name}"
)
self.external_weight_file = str(self.external_weight_file)
if os.path.exists(self.vmfb_name) and (
external_weights is None or os.path.exists(str(self.external_weight_file))
):
self.iree_module_dict = dict()
(
self.iree_module_dict["vmfb"],
self.iree_module_dict["config"],
self.iree_module_dict["temp_file_to_unlink"],
) = load_vmfb_using_mmap(
self.vmfb_name,
device,
device_idx=0,
rt_flags=[],
external_weight_file=self.external_weight_file,
self.runner = vmfbRunner(
device=self.driver,
vmfb_path=self.vmfb_name,
external_weight_path=self.external_weight_file,
)
if self.streaming_llm:
self.model = self.runner.ctx.modules.streaming_state_update
else:
self.model = self.runner.ctx.modules.state_update
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_name,
use_fast=False,
use_auth_token=hf_auth_token,
)
elif not os.path.exists(self.tempfile_name):
self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"](
self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
"initializer"
](
self.hf_model_name,
hf_auth_token,
compile_to="torch",
external_weights=external_weights,
external_weight_file=self.external_weight_file,
precision=self.precision,
quantization=self.quantization,
streaming_llm=self.streaming_llm,
decomp_attn=True,
)
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
@@ -96,22 +185,58 @@ class LanguageModel:
use_auth_token=hf_auth_token,
)
self.compile()
# Reserved for running HF torch model as reference.
self.hf_mod = None
def compile(self) -> None:
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
self.iree_module_dict = get_iree_compiled_module(
# ONLY architecture/api-specific compile-time flags for each backend, if needed.
# hf_model_id-specific global flags currently in model map.
flags = []
if "cpu" in self.backend:
flags.extend(
[
"--iree-global-opt-enable-quantized-matmul-reassociation",
]
)
elif self.backend == "vulkan":
flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
elif self.backend == "rocm":
flags.extend(
[
"--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
"--iree-llvmgpu-enable-prefetch=true",
"--iree-opt-outer-dim-concat=true",
"--iree-flow-enable-aggressive-fusion",
]
)
if "gfx9" in self.triple:
flags.extend(
[
f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(self.triple, get_checkpoints_path())}",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
]
)
flags.extend(llm_model_map[self.hf_model_name]["compile_flags"])
flatbuffer_blob = compile_module_to_flatbuffer(
self.tempfile_name,
device=self.device,
mmap=True,
frontend="torch",
external_weight_file=self.external_weight_file,
frontend="auto",
model_config_path=None,
extra_args=flags,
write_to=self.vmfb_name,
extra_args=["--iree-global-opt-enable-quantized-matmul-reassociation"],
)
# TODO: delete the temp file
self.runner = vmfbRunner(
device=self.driver,
vmfb_path=self.vmfb_name,
external_weight_path=self.external_weight_file,
)
if self.streaming_llm:
self.model = self.runner.ctx.modules.streaming_state_update
else:
self.model = self.runner.ctx.modules.state_update
def sanitize_prompt(self, prompt):
print(prompt)
if isinstance(prompt, list):
prompt = list(chain.from_iterable(prompt))
prompt = " ".join([x for x in prompt if isinstance(x, str)])
@@ -119,10 +244,10 @@ class LanguageModel:
prompt = prompt.replace("\t", " ")
prompt = prompt.replace("\r", " ")
if self.use_system_prompt and self.global_iter == 0:
prompt = llm_model_map["llama2_7b"]["system_prompt"] + prompt
prompt += " [/INST]"
print(prompt)
return prompt
prompt = append_user_prompt(DEFAULT_CHAT_SYS_PROMPT, prompt)
return prompt
else:
return f"{B_INST} {prompt} {E_INST}"
def chat(self, prompt):
prompt = self.sanitize_prompt(prompt)
@@ -134,28 +259,45 @@ class LanguageModel:
history = []
for iter in range(self.max_tokens):
st_time = time.time()
if iter == 0:
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"].device, input_tensor
)
]
token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs)
if self.streaming_llm:
token_slice = max(self.prev_token_len - 1, 0)
input_tensor = input_tensor[:, token_slice:]
if self.streaming_llm and self.model["get_seq_step"]() > 600:
print("Evicting cache space!")
self.model["evict_kvcache_space"]()
token_len = input_tensor.shape[-1]
device_inputs = [
ireert.asdevicearray(self.runner.config.device, input_tensor)
]
if self.first_input or not self.streaming_llm:
st_time = time.time()
token = self.model["run_initialize"](*device_inputs)
total_time = time.time() - st_time
token_len += 1
self.first_input = False
else:
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"].device,
token,
)
]
token = self.iree_module_dict["vmfb"]["run_forward"](*device_inputs)
st_time = time.time()
token = self.model["run_cached_initialize"](*device_inputs)
total_time = time.time() - st_time
token_len += 1
total_time = time.time() - st_time
history.append(format_out(token))
yield self.tokenizer.decode(history), total_time
while (
format_out(token) != llm_model_map[self.hf_model_name]["stop_token"]
and len(history) < self.max_tokens
):
dec_time = time.time()
if self.streaming_llm and self.model["get_seq_step"]() > 600:
print("Evicting cache space!")
self.model["evict_kvcache_space"]()
token = self.model["run_forward"](token)
history.append(format_out(token))
total_time = time.time() - dec_time
yield self.tokenizer.decode(history), total_time
if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]:
self.prev_token_len = token_len + len(history)
if format_out(token) == llm_model_map[self.hf_model_name]["stop_token"]:
break
for i in range(len(history)):
@@ -165,6 +307,160 @@ class LanguageModel:
self.global_iter += 1
return result_output, total_time
# Reference HF model function for sanity checks.
def chat_hf(self, prompt):
if self.hf_mod is None:
self.hf_mod = AutoModelForCausalLM.from_pretrained(
self.hf_model_name,
torch_dtype=torch.float,
token=self.hf_auth_token,
)
prompt = self.sanitize_prompt(prompt)
input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
history = []
for iter in range(self.max_tokens):
token_len = input_tensor.shape[-1]
if self.first_input:
st_time = time.time()
result = self.hf_mod(input_tensor)
token = torch.argmax(result.logits[:, -1, :], dim=1)
total_time = time.time() - st_time
token_len += 1
pkv = result.past_key_values
self.first_input = False
history.append(int(token))
while token != llm_model_map[self.hf_model_name]["stop_token"]:
dec_time = time.time()
result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
history.append(int(token))
total_time = time.time() - dec_time
token = torch.argmax(result.logits[:, -1, :], dim=1)
pkv = result.past_key_values
yield self.tokenizer.decode(history), total_time
self.prev_token_len = token_len + len(history)
if token == llm_model_map[self.hf_model_name]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
history[i] = int(history[i])
result_output = self.tokenizer.decode(history)
self.global_iter += 1
return result_output, total_time
def get_mfma_spec_path(target_chip, save_dir):
url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
attn_spec = urlopen(url).read().decode("utf-8")
spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
if os.path.exists(spec_path):
return spec_path
with open(spec_path, "w") as f:
f.write(attn_spec)
return spec_path
def llm_chat_api(InputData: dict):
from datetime import datetime as dt
import apps.shark_studio.web.utils.globals as global_obj
print(f"Input keys : {InputData.keys()}")
# print(f"model : {InputData['model']}")
is_chat_completion_api = (
"messages" in InputData.keys()
) # else it is the legacy `completion` api
# For Debugging input data from API
if is_chat_completion_api:
print(f"message -> role : {InputData['messages'][0]['role']}")
print(f"message -> content : {InputData['messages'][0]['content']}")
else:
print(f"prompt : {InputData['prompt']}")
model_name = (
InputData["model"]
if "model" in InputData.keys()
else "meta-llama/Llama-2-7b-chat-hf"
)
model_path = llm_model_map[model_name]
device = InputData["device"] if "device" in InputData.keys() else "cpu"
precision = "fp16"
max_tokens = InputData["max_tokens"] if "max_tokens" in InputData.keys() else 4096
device_id = None
if not global_obj.get_llm_obj():
print("\n[LOG] Initializing new pipeline...")
global_obj.clear_cache()
gc.collect()
if "cuda" in device:
device = "cuda"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "cpu" in device:
device = "cpu"
precision = "fp32"
else:
print("unrecognized device")
llm_model = LanguageModel(
model_name=model_name,
hf_auth_token=cmd_opts.hf_auth_token,
device=device,
quantization=cmd_opts.quantization,
external_weights="safetensors",
use_system_prompt=True,
streaming_llm=False,
)
global_obj.set_llm_obj(llm_model)
else:
llm_model = global_obj.get_llm_obj()
llm_model.max_tokens = max_tokens
# TODO: add role dict for different models
if is_chat_completion_api:
# TODO: add funtionality for multiple messages
prompt = append_user_prompt(
InputData["messages"][0]["role"], InputData["messages"][0]["content"]
)
else:
prompt = InputData["prompt"]
print("prompt = ", prompt)
for res_op, _ in llm_model.chat(prompt):
if is_chat_completion_api:
choices = [
{
"index": 0,
"message": {
"role": "assistant",
"content": res_op, # since we are yeilding the result
},
"finish_reason": "stop", # or length
}
]
else:
choices = [
{
"text": res_op,
"index": 0,
"logprobs": None,
"finish_reason": "stop", # or length
}
]
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
return {
"id": end_time,
"object": "chat.completion" if is_chat_completion_api else "text_completion",
"created": int(end_time),
"choices": choices,
}
if __name__ == "__main__":
lm = LanguageModel(

505
apps/shark_studio/api/sd.py Normal file
View File

@@ -0,0 +1,505 @@
import gc
import torch
import gradio as gr
import time
import os
import json
import numpy as np
import copy
import importlib.util
import sys
from tqdm.auto import tqdm
from pathlib import Path
from random import randint
from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
SharkSDXLPipeline,
)
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.api.utils import parse_device
from apps.shark_studio.web.utils.state import status_label
from apps.shark_studio.web.utils.file_utils import (
safe_name,
get_resource_path,
get_checkpoints_path,
)
from apps.shark_studio.modules.img_processing import (
save_output_img,
)
from apps.shark_studio.modules.ckpt_processing import (
preprocessCKPT,
save_irpa,
)
EMPTY_SD_MAP = {
"clip": None,
"scheduler": None,
"unet": None,
"vae_decode": None,
}
EMPTY_SDXL_MAP = {
"prompt_encoder": None,
"scheduled_unet": None,
"vae_decode": None,
"pipeline": None,
"full_pipeline": None,
}
EMPTY_FLAGS = {
"clip": None,
"unet": None,
"vae": None,
"pipeline": None,
}
def load_script(source, module_name):
"""
reads file source and loads it as a module
:param source: file to load
:param module_name: name of module to register in sys.modules
:return: loaded module
"""
spec = importlib.util.spec_from_file_location(module_name, source)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
class StableDiffusion:
# This class is responsible for executing image generation and creating
# /managing a set of compiled modules to run Stable Diffusion. The init
# aims to be as general as possible, and the class will infer and compile
# a list of necessary modules or a combined "pipeline module" for a
# specified job based on the inference task.
def __init__(
self,
base_model_id,
height: int,
width: int,
batch_size: int,
steps: int,
scheduler: str,
precision: str,
device: str,
target_triple: str = None,
custom_vae: str = None,
num_loras: int = 0,
import_ir: bool = True,
is_controlled: bool = False,
external_weights: str = "safetensors",
):
self.precision = precision
self.compiled_pipeline = False
self.base_model_id = base_model_id
self.custom_vae = custom_vae
self.is_sdxl = "xl" in self.base_model_id.lower()
self.is_custom = ".py" in self.base_model_id.lower()
if self.is_custom:
custom_module = load_script(
os.path.join(get_checkpoints_path("scripts"), self.base_model_id),
"custom_pipeline",
)
self.turbine_pipe = custom_module.StudioPipeline
self.model_map = custom_module.MODEL_MAP
elif self.is_sdxl:
self.turbine_pipe = SharkSDXLPipeline
self.model_map = EMPTY_SDXL_MAP
else:
self.turbine_pipe = SharkSDPipeline
self.model_map = EMPTY_SD_MAP
max_length = 64
target_backend, self.rt_device, triple = parse_device(device, target_triple)
pipe_id_list = [
safe_name(base_model_id),
str(batch_size),
str(max_length),
f"{str(height)}x{str(width)}",
precision,
triple,
]
if num_loras > 0:
pipe_id_list.append(str(num_loras) + "lora")
if is_controlled:
pipe_id_list.append("controlled")
if custom_vae:
pipe_id_list.append(custom_vae)
self.pipe_id = "_".join(pipe_id_list)
self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
self.weights_path = Path(
os.path.join(
get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision)
)
)
if not os.path.exists(self.weights_path):
os.mkdir(self.weights_path)
decomp_attn = True
attn_spec = None
if triple in ["gfx940", "gfx942", "gfx90a"]:
decomp_attn = False
attn_spec = "mfma"
elif triple in ["gfx1100", "gfx1103", "gfx1150"]:
decomp_attn = False
attn_spec = "wmma"
if triple in ["gfx1103", "gfx1150"]:
# external weights have issues on igpu
external_weights = None
elif target_backend == "llvm-cpu":
decomp_attn = False
self.sd_pipe = self.turbine_pipe(
hf_model_name=base_model_id,
scheduler_id=scheduler,
height=height,
width=width,
precision=precision,
max_length=max_length,
batch_size=batch_size,
num_inference_steps=steps,
device=target_backend,
iree_target_triple=triple,
ireec_flags=EMPTY_FLAGS,
attn_spec=attn_spec,
decomp_attn=decomp_attn,
pipeline_dir=self.pipeline_dir,
external_weights_dir=self.weights_path,
external_weights=external_weights,
custom_vae=custom_vae,
)
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
gc.collect()
def prepare_pipe(
self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline
):
print(f"\n[LOG] Preparing pipeline...")
self.is_img2img = False
mlirs = copy.deepcopy(self.model_map)
vmfbs = copy.deepcopy(self.model_map)
weights = copy.deepcopy(self.model_map)
if not self.is_sdxl:
compiled_pipeline = False
self.compiled_pipeline = compiled_pipeline
if custom_weights:
custom_weights = os.path.join(
get_checkpoints_path("checkpoints"),
safe_name(self.base_model_id.split("/")[-1]),
custom_weights,
)
diffusers_weights_path = preprocessCKPT(custom_weights, self.precision)
for key in weights:
if key in ["scheduled_unet", "unet"]:
unet_weights_path = os.path.join(
diffusers_weights_path,
"unet",
"diffusion_pytorch_model.safetensors",
)
weights[key] = save_irpa(unet_weights_path, "unet.")
elif key in ["clip", "prompt_encoder"]:
if not self.is_sdxl:
sd1_path = os.path.join(
diffusers_weights_path, "text_encoder", "model.safetensors"
)
weights[key] = save_irpa(sd1_path, "text_encoder_model.")
else:
clip_1_path = os.path.join(
diffusers_weights_path, "text_encoder", "model.safetensors"
)
clip_2_path = os.path.join(
diffusers_weights_path,
"text_encoder_2",
"model.safetensors",
)
weights[key] = [
save_irpa(clip_1_path, "text_encoder_model_1."),
save_irpa(clip_2_path, "text_encoder_model_2."),
]
elif key in ["vae_decode"] and weights[key] is None:
vae_weights_path = os.path.join(
diffusers_weights_path,
"vae",
"diffusion_pytorch_model.safetensors",
)
weights[key] = save_irpa(vae_weights_path, "vae.")
vmfbs, weights = self.sd_pipe.check_prepared(
mlirs, vmfbs, weights, interactive=False
)
print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
self.sd_pipe.load_pipeline(
vmfbs, weights, self.rt_device, self.compiled_pipeline
)
print(
"\n[LOG] Pipeline successfully prepared for runtime. Generating images..."
)
return
def generate_images(
self,
prompt,
negative_prompt,
image,
strength,
guidance_scale,
seed,
ondemand,
resample_type,
control_mode,
hints,
):
img = self.sd_pipe.generate_images(
prompt,
negative_prompt,
1,
guidance_scale,
seed,
return_imgs=True,
)
return img
def shark_sd_fn_dict_input(
sd_kwargs: dict,
):
print("\n[LOG] Submitting Request...")
for key in sd_kwargs:
if sd_kwargs[key] in [None, []]:
sd_kwargs[key] = None
if sd_kwargs[key] in ["None"]:
sd_kwargs[key] = ""
if key == "seed":
sd_kwargs[key] = int(sd_kwargs[key])
# TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
if not sd_kwargs["device"]:
gr.Warning("No device specified. Please specify a device.")
return None, ""
if sd_kwargs["height"] not in [512, 1024]:
gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
return None, ""
if sd_kwargs["height"] != sd_kwargs["width"]:
gr.Warning("Height and width must be the same. This is a temporary limitation.")
return None, ""
if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
if sd_kwargs["steps"] > 10:
gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
return None, ""
if sd_kwargs["guidance_scale"] > 3:
gr.Warning(
"sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
)
return None, ""
if sd_kwargs["target_triple"] == "":
if parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2] == "":
gr.Warning(
"Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
)
return None, ""
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
return generated_imgs
def shark_sd_fn(
prompt,
negative_prompt,
sd_init_image: list,
height: int,
width: int,
steps: int,
strength: float,
guidance_scale: float,
seed: list,
batch_count: int,
batch_size: int,
scheduler: str,
base_model_id: str,
custom_weights: str,
custom_vae: str,
precision: str,
device: str,
target_triple: str,
ondemand: bool,
compiled_pipeline: bool,
resample_type: str,
controlnets: dict,
embeddings: dict,
):
sd_kwargs = locals()
if not isinstance(sd_init_image, list):
sd_init_image = [sd_init_image]
is_img2img = True if sd_init_image[0] is not None else False
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj
adapters = {}
is_controlled = False
control_mode = None
hints = []
num_loras = 0
import_ir = True
for i in embeddings:
num_loras += 1 if embeddings[i] else 0
if "model" in controlnets:
for i, model in enumerate(controlnets["model"]):
if "xl" not in base_model_id.lower():
adapters[f"control_adapter_{model}"] = {
"hf_id": control_adapter_map["runwayml/stable-diffusion-v1-5"][
model
],
"strength": controlnets["strength"][i],
}
else:
adapters[f"control_adapter_{model}"] = {
"hf_id": control_adapter_map["stabilityai/stable-diffusion-xl-1.0"][
model
],
"strength": controlnets["strength"][i],
}
if model is not None:
is_controlled = True
control_mode = controlnets["control_mode"]
for i in controlnets["hint"]:
hints.append[i]
submit_pipe_kwargs = {
"base_model_id": base_model_id,
"height": height,
"width": width,
"batch_size": batch_size,
"precision": precision,
"device": device,
"target_triple": target_triple,
"custom_vae": custom_vae,
"num_loras": num_loras,
"import_ir": import_ir,
"is_controlled": is_controlled,
"steps": steps,
"scheduler": scheduler,
}
submit_prep_kwargs = {
"custom_weights": custom_weights,
"adapters": adapters,
"embeddings": embeddings,
"is_img2img": is_img2img,
"compiled_pipeline": compiled_pipeline,
}
submit_run_kwargs = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"image": sd_init_image,
"strength": strength,
"guidance_scale": guidance_scale,
"seed": seed,
"ondemand": ondemand,
"resample_type": resample_type,
"control_mode": control_mode,
"hints": hints,
}
if (
not global_obj.get_sd_obj()
or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
):
print("\n[LOG] Initializing new pipeline...")
global_obj.clear_cache()
gc.collect()
# Initializes the pipeline and retrieves IR based on all
# parameters that are static in the turbine output format,
# which is currently MLIR in the torch dialect.
sd_pipe = StableDiffusion(
**submit_pipe_kwargs,
)
global_obj.set_sd_obj(sd_pipe)
global_obj.set_pipe_kwargs(submit_pipe_kwargs)
if (
not global_obj.get_prep_kwargs()
or global_obj.get_prep_kwargs() != submit_prep_kwargs
):
global_obj.set_prep_kwargs(submit_prep_kwargs)
global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)
generated_imgs = []
for current_batch in range(batch_count):
start_time = time.time()
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
if not isinstance(out_imgs, list):
out_imgs = [out_imgs]
# total_time = time.time() - start_time
# text_output = f"Total image(s) generation time: {total_time:.4f}sec"
# print(f"\n[LOG] {text_output}")
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
# break
# else:
for batch in range(batch_size):
save_output_img(
out_imgs[batch],
seed,
sd_kwargs,
)
generated_imgs.extend(out_imgs)
# TODO: make seed changes over batch counts more configurable.
submit_run_kwargs["seed"] = submit_run_kwargs["seed"] + 1
yield generated_imgs, status_label(
"Stable Diffusion", current_batch + 1, batch_count, batch_size
)
return (generated_imgs, "")
def unload_sd():
print("Unloading models.")
import apps.shark_studio.web.utils.globals as global_obj
global_obj.clear_cache()
gc.collect()
def cancel_sd():
print("Inject call to cancel longer API calls.")
return
def view_json_file(file_path):
content = ""
with open(file_path, "r") as fopen:
content = fopen.read()
return content
def safe_name(name):
return name.replace("/", "_").replace("\\", "_").replace(".", "_")
if __name__ == "__main__":
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj
global_obj._init()
sd_json = view_json_file(
get_resource_path(os.path.join(cmd_opts.config_dir, "default_sd_config.json"))
)
sd_kwargs = json.loads(sd_json)
for arg in vars(cmd_opts):
if arg in sd_kwargs:
sd_kwargs[arg] = getattr(cmd_opts, arg)
for i in shark_sd_fn_dict_input(sd_kwargs):
print(i)

View File

@@ -1,12 +1,389 @@
import os
import sys
import numpy as np
import json
from random import (
randint,
seed as seed_random,
getstate as random_getstate,
setstate as random_setstate,
)
from pathlib import Path
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from cpuinfo import get_cpu_info
# TODO: migrate these utils to studio
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
get_iree_vulkan_runtime_flags,
)
def get_available_devices():
return ["cpu-task"]
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except:
print(f"{driver_name} devices are not available.")
else:
cpu_name = get_cpu_info()["brand_raw"]
for i, device in enumerate(device_list_dict):
device_name = (
cpu_name if device["name"] == "default" else device["name"]
)
if "local" in driver_name:
device_list.append(
f"{device_name} => {driver_name.replace('local', 'cpu')}"
)
else:
# for drivers with single devices
# let the default device be selected without any indexing
if len(device_list_dict) == 1:
device_list.append(f"{device_name} => {driver_name}")
else:
device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
available_devices = []
rocm_devices = get_devices_by_name("rocm")
available_devices.extend(rocm_devices)
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
)
vulkaninfo_list = get_all_vulkan_devices()
vulkan_devices = []
id = 0
for device in vulkaninfo_list:
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
id += 1
if id != 0:
print(f"vulkan devices are available.")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
hip_devices = get_devices_by_name("hip")
available_devices.extend(hip_devices)
for idx, device_str in enumerate(available_devices):
if "AMD Radeon(TM) Graphics =>" in device_str:
igpu_id_candidates = [
x.split("w/")[-1].split("=>")[0]
for x in available_devices
if "M Graphics" in x
]
for igpu_name in igpu_id_candidates:
if igpu_name:
available_devices[idx] = device_str.replace(
"AMD Radeon(TM) Graphics", igpu_name
)
break
return available_devices
def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)
def set_init_device_flags():
if "vulkan" in cmd_opts.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
if not cmd_opts.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
cmd_opts.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple "
f"{cmd_opts.iree_vulkan_target_triple}."
)
elif "cuda" in cmd_opts.device:
cmd_opts.device = "cuda"
elif "metal" in cmd_opts.device:
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
if not cmd_opts.iree_metal_target_platform:
from shark.iree_utils.metal_utils import get_metal_target_triple
triple = get_metal_target_triple(device_name)
if triple is not None:
cmd_opts.iree_metal_target_platform = triple.split("-")[-1]
print(
f"Found device {device_name}. Using target triple "
f"{cmd_opts.iree_metal_target_platform}."
)
elif "cpu" in cmd_opts.device:
cmd_opts.device = "cpu"
def set_iree_runtime_flags():
# TODO: This function should be device-agnostic and piped properly
# to general runtime driver init.
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
if cmd_opts.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
if cmd_opts.device_allocator_heap_key:
vulkan_runtime_flags += [
f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
def parse_device(device_str, target_override=""):
from shark.iree_utils.compile_utils import (
clean_device_info,
get_iree_target_triple,
iree_target_map,
)
rt_driver, device_id = clean_device_info(device_str)
target_backend = iree_target_map(rt_driver)
if device_id:
rt_device = f"{rt_driver}://{device_id}"
else:
rt_device = rt_driver
if target_override:
return target_backend, rt_device, target_override
match target_backend:
case "vulkan-spirv":
triple = get_iree_target_triple(device_str)
return target_backend, rt_device, triple
case "rocm":
triple = get_rocm_target_chip(device_str)
return target_backend, rt_device, triple
case "llvm-cpu":
return "llvm-cpu", "local-task", "x86_64-linux-gnu"
def get_rocm_target_chip(device_str):
# TODO: Use a data file to map device_str to target chip.
rocm_chip_map = {
"6700": "gfx1031",
"6800": "gfx1030",
"6900": "gfx1030",
"7900": "gfx1100",
"MI300X": "gfx942",
"MI300A": "gfx940",
"MI210": "gfx90a",
"MI250": "gfx90a",
"MI100": "gfx908",
"MI50": "gfx906",
"MI60": "gfx906",
"780M": "gfx1103",
}
for key in rocm_chip_map:
if key in device_str:
return rocm_chip_map[key]
raise AssertionError(
f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues."
)
def get_all_devices(driver_name):
"""
Inputs: driver_name
Returns a list of all the available devices for a given driver sorted by
the iree path names of the device as in --list_devices option in iree.
"""
from iree.runtime import get_driver
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return device_list_src
def get_device_mapping(driver, key_combination=3):
"""This method ensures consistent device ordering when choosing
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired
combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
driver = iree_device_map(driver)
device_list = get_all_devices(driver)
device_map = dict()
def get_output_value(dev_dict):
if key_combination == 1:
return f"{driver}://{dev_dict['path']}"
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return dev_dict["name"], f"{driver}://{dev_dict['path']}"
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
for i, device in enumerate(device_list):
# mapping with index
device_map[f"{driver}://{i}"] = get_output_value(device)
# mapping with full path
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
return device_map
def get_opt_flags(model, precision="fp16"):
iree_flags = []
if len(cmd_opts.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
)
if "rocm" in cmd_opts.device:
from shark.iree_utils.gpu_utils import get_iree_rocm_args
rocm_args = get_iree_rocm_args()
iree_flags.extend(rocm_args)
if cmd_opts.iree_constant_folding == False:
iree_flags.append("--iree-opt-const-expr-hoisting=False")
iree_flags.append(
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
)
if cmd_opts.data_tiling == False:
iree_flags.append("--iree-opt-data-tiling=False")
if "vae" not in model:
# Due to lack of support for multi-reduce, we always collapse reduction
# dims before dispatch formation right now.
iree_flags += ["--iree-flow-collapse-reduction-dims"]
return iree_flags
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user
selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for
the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
try:
device_mapping = device_map[device]
except KeyError:
raise ValueError(f"Device '{device}' is not a valid device.")
return device_mapping
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except:
print(f"{driver_name} devices are not available.")
else:
cpu_name = get_cpu_info()["brand_raw"]
for i, device in enumerate(device_list_dict):
device_name = (
cpu_name if device["name"] == "default" else device["name"]
)
if "local" in driver_name:
device_list.append(
f"{device_name} => {driver_name.replace('local', 'cpu')}"
)
else:
# for drivers with single devices
# let the default device be selected without any indexing
if len(device_list_dict) == 1:
device_list.append(f"{device_name} => {driver_name}")
else:
device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
available_devices = []
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
)
vulkaninfo_list = get_all_vulkan_devices()
vulkan_devices = []
id = 0
for device in vulkaninfo_list:
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
id += 1
if id != 0:
print(f"vulkan devices are available.")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
rocm_devices = get_devices_by_name("rocm")
available_devices.extend(rocm_devices)
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
return available_devices
# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed: int | str):
seed = int(seed)
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed
# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):
if isinstance(seed_input, str):
try:
seed_input = json.loads(seed_input)
except (ValueError, TypeError):
seed_input = None
if isinstance(seed_input, int):
return [seed_input]
if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
return seed_input
raise TypeError(
"Seed input must be an integer or an array of integers in JSON format"
)

View File

@@ -0,0 +1,145 @@
import os
import json
import re
import requests
import torch
import safetensors
from shark_turbine.aot.params import (
ParameterArchiveBuilder,
)
from io import BytesIO
from pathlib import Path
from tqdm import tqdm
from omegaconf import OmegaConf
from diffusers import StableDiffusionPipeline
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
create_vae_diffusers_config,
convert_ldm_vae_checkpoint,
)
def get_path_to_diffusers_checkpoint(custom_weights, precision="fp16"):
path = Path(custom_weights)
diffusers_path = path.parent.absolute()
diffusers_directory_name = os.path.join("diffusers", path.stem + f"_{precision}")
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
path_to_diffusers = complete_path_to_diffusers.as_posix()
return path_to_diffusers
def preprocessCKPT(custom_weights, precision="fp16", is_inpaint=False):
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights, precision)
if next(Path(path_to_diffusers).iterdir(), None):
print("Checkpoint already loaded at : ", path_to_diffusers)
return path_to_diffusers
else:
print(
"Diffusers' checkpoint will be identified here : ",
path_to_diffusers,
)
from_safetensors = (
True if custom_weights.lower().endswith(".safetensors") else False
)
# EMA weights usually yield higher quality images for inference but
# non-EMA weights have been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if
# they want to go for EMA weight extraction or not.
extract_ema = False
print("Loading diffusers' pipeline from original stable diffusion checkpoint")
num_in_channels = 9 if is_inpaint else 4
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path_or_dict=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
num_in_channels=num_in_channels,
)
if precision == "fp16":
pipe.to(dtype=torch.float16)
pipe.save_pretrained(path_to_diffusers)
del pipe
print("Loading complete")
return path_to_diffusers
def save_irpa(weights_path, prepend_str):
weights = safetensors.torch.load_file(weights_path)
archive = ParameterArchiveBuilder()
for key in weights.keys():
new_key = prepend_str + key
archive.add_tensor(new_key, weights[key])
irpa_file = weights_path.replace(".safetensors", ".irpa")
archive.save(irpa_file)
return irpa_file
def convert_original_vae(vae_checkpoint):
vae_state_dict = {}
for key in list(vae_checkpoint.keys()):
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
config_url = (
"https://raw.githubusercontent.com/CompVis/stable-diffusion/"
"main/configs/stable-diffusion/v1-inference.yaml"
)
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file)
vae_config = create_vae_diffusers_config(original_config, image_size=512)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_state_dict, vae_config)
return converted_vae_checkpoint
def process_custom_pipe_weights(custom_weights):
if custom_weights != "":
if custom_weights.startswith("https://civitai.com/api/"):
# download the checkpoint from civitai if we don't already have it
weights_path = get_civitai_checkpoint(custom_weights)
# act as if we were given the local file as custom_weights originally
custom_weights_tgt = get_path_to_diffusers_checkpoint(weights_path)
custom_weights_params = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights)
custom_weights_params = custom_weights
return custom_weights_params, custom_weights_tgt
def get_civitai_checkpoint(url: str):
with requests.get(url, allow_redirects=True, stream=True) as response:
response.raise_for_status()
# civitai api returns the filename in the content disposition
base_filename = re.findall(
'"([^"]*)"', response.headers["Content-Disposition"]
)[0]
destination_path = Path.cwd() / (cmd_opts.model_dir or "models") / base_filename
# we don't have this model downloaded yet
if not destination_path.is_file():
print(f"downloading civitai model from {url} to {destination_path}")
size = int(response.headers["content-length"], 0)
progress_bar = tqdm(total=size, unit="iB", unit_scale=True)
with open(destination_path, "wb") as f:
for chunk in response.iter_content(chunk_size=65536):
f.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
# we already have this model downloaded
else:
print(f"civitai model already downloaded to {destination_path}")
response.close()
return destination_path.as_posix()

View File

@@ -0,0 +1,185 @@
import os
import sys
import torch
import json
import safetensors
from dataclasses import dataclass
from safetensors.torch import load_file
from apps.shark_studio.web.utils.file_utils import (
get_checkpoint_pathfile,
get_path_stem,
)
@dataclass
class LoRAweight:
up: torch.tensor
down: torch.tensor
mid: torch.tensor
alpha: torch.float32 = 1.0
def processLoRA(model, use_lora, splitting_prefix, lora_strength=0.75):
state_dict = ""
if ".safetensors" in use_lora:
state_dict = load_file(use_lora)
else:
state_dict = torch.load(use_lora)
# gather the weights from the LoRA in a more convenient form, assumes
# everything will have an up.weight.
weight_dict: dict[str, LoRAweight] = {}
for key in state_dict:
if key.startswith(splitting_prefix) and key.endswith("up.weight"):
stem = key.split("up.weight")[0]
weight_key = stem.removesuffix(".lora_")
weight_key = weight_key.removesuffix("_lora_")
weight_key = weight_key.removesuffix(".lora_linear_layer.")
if weight_key not in weight_dict:
weight_dict[weight_key] = LoRAweight(
state_dict[f"{stem}up.weight"],
state_dict[f"{stem}down.weight"],
state_dict.get(f"{stem}mid.weight", None),
(
state_dict[f"{weight_key}.alpha"]
/ state_dict[f"{stem}up.weight"].shape[1]
if f"{weight_key}.alpha" in state_dict
else 1.0
),
)
# Directly update weight in model
# Mostly adaptions of https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py
# and similar code in https://github.com/huggingface/diffusers/issues/3064
# TODO: handle mid weights (how do they even work?)
for key, lora_weight in weight_dict.items():
curr_layer = model
layer_infos = key.split(".")[0].split(splitting_prefix)[-1].split("_")
# find the target layer
temp_name = layer_infos.pop(0)
while len(layer_infos) > -1:
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
weight = curr_layer.weight.data
scale = lora_weight.alpha * lora_strength
if len(weight.size()) == 2:
if len(lora_weight.up.shape) == 4:
weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
change = torch.mm(lora_weight.up, lora_weight.down)
elif lora_weight.down.size()[2:4] == (1, 1):
weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
change = torch.nn.functional.conv2d(
lora_weight.down.permute(1, 0, 2, 3),
lora_weight.up,
).permute(1, 0, 2, 3)
curr_layer.weight.data += change * scale
return model
def update_lora_weight_for_unet(unet, use_lora, lora_strength):
extensions = [".bin", ".safetensors", ".pt"]
if not any([extension in use_lora for extension in extensions]):
# We assume if it is a HF ID with standalone LoRA weights.
unet.load_attn_procs(use_lora)
return unet
main_file_name = get_path_stem(use_lora)
if ".bin" in use_lora:
main_file_name += ".bin"
elif ".safetensors" in use_lora:
main_file_name += ".safetensors"
elif ".pt" in use_lora:
main_file_name += ".pt"
else:
sys.exit("Only .bin and .safetensors format for LoRA is supported")
try:
dir_name = os.path.dirname(use_lora)
unet.load_attn_procs(dir_name, weight_name=main_file_name)
return unet
except:
return processLoRA(unet, use_lora, "lora_unet_", lora_strength)
def update_lora_weight(model, use_lora, model_name, lora_strength=1.0):
if "unet" in model_name:
return update_lora_weight_for_unet(model, use_lora, lora_strength)
try:
return processLoRA(model, use_lora, "lora_te_", lora_strength)
except:
return None
def get_lora_metadata(lora_filename):
# get the metadata from the file
filename = get_checkpoint_pathfile(lora_filename, "lora")
with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
metadata = f.metadata()
# guard clause for if there isn't any metadata
if not metadata:
return None
# metadata is a dictionary of strings, the values of the keys we're
# interested in are actually json, and need to be loaded as such
tag_frequencies = json.loads(metadata.get("ss_tag_frequency", str("{}")))
dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", str("{}")))
tag_dirs = [dir for dir in tag_frequencies.keys()]
# gather the tag frequency information for all the datasets trained
all_frequencies = {}
for dataset in tag_dirs:
frequencies = sorted(
[entry for entry in tag_frequencies[dataset].items()],
reverse=True,
key=lambda x: x[1],
)
# get a figure for the total number of images processed for this dataset
# either then number actually listed or in its dataset_dir entry or
# the highest frequency's number if that doesn't exist
img_count = dataset_dirs.get(dir, {}).get("img_count", frequencies[0][1])
# add the dataset frequencies to the overall frequencies replacing the
# frequency counts on the tags with a percentage/ratio
all_frequencies.update(
[(entry[0], entry[1] / img_count) for entry in frequencies]
)
trained_model_id = " ".join(
[
metadata.get("ss_sd_model_hash", ""),
metadata.get("ss_sd_model_name", ""),
metadata.get("ss_base_model_version", ""),
]
).strip()
# return the topmost <count> of all frequencies in all datasets
return {
"model": trained_model_id,
"frequencies": sorted(
all_frequencies.items(), reverse=True, key=lambda x: x[1]
),
}

View File

@@ -0,0 +1,202 @@
import os
import re
import json
import torch
import numpy as np
from csv import DictWriter
from PIL import Image, PngImagePlugin
from pathlib import Path
from datetime import datetime as dt
from base64 import decode
resamplers = {
"Lanczos": Image.Resampling.LANCZOS,
"Nearest Neighbor": Image.Resampling.NEAREST,
"Bilinear": Image.Resampling.BILINEAR,
"Bicubic": Image.Resampling.BICUBIC,
"Hamming": Image.Resampling.HAMMING,
"Box": Image.Resampling.BOX,
}
resampler_list = resamplers.keys()
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed, extra_info=None):
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
if extra_info is None:
extra_info = {}
generated_imgs_path = Path(
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
)
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", extra_info["prompt"][0][:15])
out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
img_model = extra_info["base_model_id"]
if extra_info["custom_weights"] not in [None, "None"]:
img_model = Path(os.path.basename(extra_info["custom_weights"])).stem
img_vae = None
if extra_info["custom_vae"]:
img_vae = Path(os.path.basename(extra_info["custom_vae"])).stem
img_loras = None
if extra_info["embeddings"]:
img_lora = []
for i in extra_info["embeddings"]:
img_lora += Path(os.path.basename(cmd_opts.use_lora)).stem
img_loras = ", ".join(img_lora)
if cmd_opts.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
else:
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
pngInfo = PngImagePlugin.PngInfo()
if cmd_opts.write_metadata_to_png:
# Using a conditional expression caused problems, so setting a new
# variable for now.
# if cmd_opts.use_hiresfix:
# png_size_text = (
# f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
# )
# else:
png_size_text = f"{extra_info['width']}x{extra_info['height']}"
pngInfo.add_text(
"parameters",
f"{extra_info['prompt'][0]}"
f"\nNegative prompt: {extra_info['negative_prompt'][0]}"
f"\nSteps: {extra_info['steps']},"
f"Sampler: {extra_info['scheduler']}, "
f"CFG scale: {extra_info['guidance_scale']}, "
f"Seed: {img_seed},"
f"Size: {png_size_text}, "
f"Model: {img_model}, "
f"VAE: {img_vae}, "
f"LoRA: {img_loras}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if cmd_opts.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {cmd_opts.output_img_format} is not "
f"supported yet. Image saved as png instead."
f"Supported formats: png / jpg"
)
# To be as low-impact as possible to the existing CSV format, we append
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
# importance for each data point. Something to consider.
new_entry = {}
new_entry.update(extra_info)
csv_mode = "a" if os.path.isfile(csv_path) else "w"
with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
if csv_mode == "w":
dictwriter_obj.writeheader()
dictwriter_obj.writerow(new_entry)
csv_obj.close()
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
# For stencil, the input image can be of any size, but we need to ensure that
# it conforms with our model constraints :-
# Both width and height should be in the range of [128, 768] and multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.
def resize_stencil(image: Image.Image, width, height, resampler_type=None):
aspect_ratio = width / height
min_size = min(width, height)
if min_size < 128:
n_size = 128
if width == min_size:
width = n_size
height = n_size / aspect_ratio
else:
height = n_size
width = n_size * aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
min_size = min(width, height)
if min_size > 768:
n_size = 768
if width == min_size:
height = n_size
width = n_size * aspect_ratio
else:
width = n_size
height = n_size / aspect_ratio
width = int(width)
height = int(height)
n_width = width // 8
n_height = height // 8
n_width *= 8
n_height *= 8
if resampler_type in resamplers:
resampler = resamplers[resampler_type]
else:
resampler = resamplers["Nearest Neighbor"]
new_image = image.resize((n_width, n_height), resampler=resampler)
return new_image, n_width, n_height
def process_sd_init_image(self, sd_init_image, resample_type):
if isinstance(sd_init_image, list):
images = []
for img in sd_init_image:
img, _ = self.process_sd_init_image(img, resample_type)
images.append(img)
is_img2img = True
return images, is_img2img
if isinstance(sd_init_image, str):
if os.path.isfile(sd_init_image):
sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB")
image, is_img2img = self.process_sd_init_image(sd_init_image, resample_type)
else:
image = None
is_img2img = False
elif isinstance(sd_init_image, Image.Image):
image = sd_init_image.convert("RGB")
elif sd_init_image:
image = sd_init_image["image"].convert("RGB")
else:
image = None
is_img2img = False
if image:
resample_type = (
resamplers[resample_type]
if resample_type in resampler_list
# Fallback to Lanczos
else Image.Resampling.LANCZOS
)
image = image.resize((self.width, self.height), resample=resample_type)
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
image_arr = 2 * (image_arr - 0.5)
is_img2img = True
image = image_arr
return image, is_img2img

View File

@@ -0,0 +1,37 @@
import sys
class Logger:
def __init__(self, filename, filter=None):
self.terminal = sys.stdout
self.log = open(filename, "w")
self.filter = filter
def write(self, message):
for x in message.split("\n"):
if self.filter in x:
self.log.write(message)
else:
self.terminal.write(message)
def flush(self):
self.terminal.flush()
self.log.flush()
def isatty(self):
return False
def logger_test(x):
print("[LOG] This is a test")
print(f"This is another test, without the filter")
return x
def read_sd_logs():
sys.stdout.flush()
with open("shark_tmp/sd.log", "r") as f:
return f.read()
sys.stdout = Logger("shark_tmp/sd.log", filter="[LOG]")

View File

@@ -0,0 +1,205 @@
from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
load_vmfb_using_mmap,
clean_device_info,
get_iree_target_triple,
)
from apps.shark_studio.web.utils.file_utils import (
get_checkpoints_path,
get_resource_path,
)
from apps.shark_studio.modules.shared_cmd_opts import (
cmd_opts,
)
from iree import runtime as ireert
from pathlib import Path
import gc
import os
class SharkPipelineBase:
# This class is a lightweight base for managing an
# inference API class. It should provide methods for:
# - compiling a set (model map) of torch IR modules
# - preparing weights for an inference job
# - loading weights for an inference job
# - utilites like benchmarks, tests
def __init__(
self,
model_map: dict,
base_model_id: str,
static_kwargs: dict,
device: str,
import_mlir: bool = True,
):
self.model_map = model_map
self.pipe_map = {}
self.static_kwargs = static_kwargs
self.base_model_id = base_model_id
self.triple = get_iree_target_triple(device)
self.device, self.device_id = clean_device_info(device)
self.import_mlir = import_mlir
self.iree_module_dict = {}
self.tmp_dir = get_resource_path(cmd_opts.tmp_dir)
if not os.path.exists(self.tmp_dir):
os.mkdir(self.tmp_dir)
self.tempfiles = {}
self.pipe_vmfb_path = ""
def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
# First checks whether we have .vmfbs precompiled, then populates the map
# with the precompiled executables and fetches executables for the rest of the map.
# The weights aren't static here anymore so this function should be a part of pipeline
# initialization. As soon as you have a pipeline ID unique to your static torch IR parameters,
# and your model map is populated with any IR - unique model IDs and their static params,
# call this method to get the artifacts associated with your map.
self.pipe_id = self.safe_name(pipe_id)
self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True)
if submodel == "None":
print("\n[LOG] Gathering any pre-compiled artifacts....")
for key in self.model_map:
self.get_compiled_map(pipe_id, submodel=key)
else:
self.pipe_map[submodel] = {}
self.get_precompiled(self.pipe_id, submodel)
ireec_flags = []
if submodel in self.iree_module_dict:
return
elif "vmfb_path" in self.pipe_map[submodel]:
return
elif submodel not in self.tempfiles:
print(
f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR..."
)
if submodel in self.static_kwargs:
init_kwargs = self.static_kwargs[submodel]
for key in self.static_kwargs["pipe"]:
if key not in init_kwargs:
init_kwargs[key] = self.static_kwargs["pipe"][key]
self.import_torch_ir(submodel, init_kwargs)
self.get_compiled_map(pipe_id, submodel)
else:
ireec_flags = (
self.model_map[submodel]["ireec_flags"]
if "ireec_flags" in self.model_map[submodel]
else []
)
weights_path = self.get_io_params(submodel)
if weights_path:
ireec_flags.append("--iree-opt-const-eval=False")
self.iree_module_dict[submodel] = get_iree_compiled_module(
self.tempfiles[submodel],
device=self.device,
frontend="torch",
mmap=True,
external_weight_file=weights_path,
extra_args=ireec_flags,
write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"),
)
return
def get_io_params(self, submodel):
if "external_weight_file" in self.static_kwargs[submodel]:
# we are using custom weights
weights_path = self.static_kwargs[submodel]["external_weight_file"]
elif "external_weight_path" in self.static_kwargs[submodel]:
# we are using the default weights for the HF model
weights_path = self.static_kwargs[submodel]["external_weight_path"]
else:
# assume the torch IR contains the weights.
weights_path = None
return weights_path
def get_precompiled(self, pipe_id, submodel="None"):
if submodel == "None":
for model in self.model_map:
self.get_precompiled(pipe_id, model)
vmfbs = []
for dirpath, dirnames, filenames in os.walk(self.pipe_vmfb_path):
vmfbs.extend(filenames)
break
for file in vmfbs:
if submodel in file:
self.pipe_map[submodel]["vmfb_path"] = os.path.join(
self.pipe_vmfb_path, file
)
return
def import_torch_ir(self, submodel, kwargs):
torch_ir = self.model_map[submodel]["initializer"](
**self.safe_dict(kwargs), compile_to="torch"
)
if submodel == "clip":
# clip.export_clip_model returns (torch_ir, tokenizer)
torch_ir = torch_ir[0]
self.tempfiles[submodel] = os.path.join(
self.tmp_dir, f"{submodel}.torch.tempfile"
)
with open(self.tempfiles[submodel], "w+") as f:
f.write(torch_ir)
del torch_ir
gc.collect()
return
def load_submodels(self, submodels: list):
for submodel in submodels:
if submodel in self.iree_module_dict:
print(f"\n[LOG] {submodel} is ready for inference.")
continue
if "vmfb_path" in self.pipe_map[submodel]:
weights_path = self.get_io_params(submodel)
# print(
# f"\n[LOG] Loading .vmfb for {submodel} from {self.pipe_map[submodel]['vmfb_path']}"
# )
self.iree_module_dict[submodel] = {}
(
self.iree_module_dict[submodel]["vmfb"],
self.iree_module_dict[submodel]["config"],
self.iree_module_dict[submodel]["temp_file_to_unlink"],
) = load_vmfb_using_mmap(
self.pipe_map[submodel]["vmfb_path"],
self.device,
device_idx=0,
rt_flags=[],
external_weight_file=weights_path,
)
else:
self.get_compiled_map(self.pipe_id, submodel)
return
def unload_submodels(self, submodels: list):
for submodel in submodels:
if submodel in self.iree_module_dict:
del self.iree_module_dict[submodel]
gc.collect()
return
def run(self, submodel, inputs):
if not isinstance(inputs, list):
inputs = [inputs]
inp = [
ireert.asdevicearray(
self.iree_module_dict[submodel]["config"].device, input
)
for input in inputs
]
return self.iree_module_dict[submodel]["vmfb"]["main"](*inp)
def safe_name(self, name):
return name.replace("/", "_").replace("-", "_").replace("\\", "_")
def safe_dict(self, kwargs: dict):
flat_args = {}
for i in kwargs:
if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]:
flat_args[i] = [kwargs[i][j] for j in kwargs[i]]
else:
flat_args[i] = kwargs[i]
return flat_args

View File

@@ -0,0 +1,376 @@
from typing import List, Optional, Union
from iree import runtime as ireert
import re
import torch
import numpy as np
re_attention = re.compile(
r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs:
text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included.
"""
tokens = []
weights = []
truncated = False
for text in prompt:
texts_and_weights = parse_prompt_attention(text)
text_token = []
text_weight = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(text_token) > max_length:
truncated = True
break
# truncate
if len(text_token) > max_length:
truncated = True
text_token = text_token[:max_length]
text_weight = text_weight[:max_length]
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print(
"Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
)
return tokens, weights
def pad_tokens_and_weights(
tokens,
weights,
max_length,
bos,
eos,
no_boseos_middle=True,
chunk_length=77,
):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = (
max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
)
for i in range(len(tokens)):
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
if no_boseos_middle:
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
else:
w = []
if len(weights[i]) == 0:
w = [1.0] * weights_length
else:
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][
j
* (chunk_length - 2) : min(
len(weights[i]), (j + 1) * (chunk_length - 2)
)
]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
return tokens, weights
def get_unweighted_text_embeddings(
pipe,
text_input,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
):
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
"""
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[
:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = pipe.run("clip", text_input_chunk)[0].to_host()
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
# SHARK: Convert the result to tensor
# text_embeddings = torch.concat(text_embeddings, axis=1)
text_embeddings_np = np.concatenate(np.array(text_embeddings))
text_embeddings = torch.from_numpy(text_embeddings_np)
else:
text_embeddings = pipe.run("clip", text_input)[0]
text_embeddings = torch.from_numpy(text_embeddings.to_host())
return text_embeddings
# This function deals with NoneType values occuring in tokens after padding
# It switches out None with 49407 as truncating None values causes matrix dimension errors,
def filter_nonetype_tokens(tokens: List[List]):
return [[49407 if token is None else token for token in tokens[0]]]
def get_weighted_text_embeddings(
pipe,
prompt: List[str],
uncond_prompt: List[str] = None,
max_embeddings_multiples: Optional[int] = 8,
no_boseos_middle: Optional[bool] = True,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
):
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(
pipe, prompt, max_length - 2
)
if uncond_prompt is not None:
uncond_tokens, uncond_weights = get_prompts_with_weights(
pipe, uncond_prompt, max_length - 2
)
else:
prompt_tokens = [
token[1:-1]
for token in pipe.tokenizer(
prompt, max_length=max_length, truncation=True
).input_ids
]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens = [
token[1:-1]
for token in pipe.tokenizer(
uncond_prompt, max_length=max_length, truncation=True
).input_ids
]
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
if uncond_prompt is not None:
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
max_embeddings_multiples = min(
max_embeddings_multiples,
(max_length - 1) // (pipe.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
# pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# FIXME: This is a hacky fix caused by tokenizer padding with None values
prompt_tokens = filter_nonetype_tokens(prompt_tokens)
# prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
if uncond_prompt is not None:
uncond_tokens, uncond_weights = pad_tokens_and_weights(
uncond_tokens,
uncond_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
# FIXME: This is a hacky fix caused by tokenizer padding with None values
uncond_tokens = filter_nonetype_tokens(uncond_tokens)
# uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device="cpu")
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe,
prompt_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
prompt_weights = torch.tensor(prompt_weights, dtype=torch.float, device="cpu")
if uncond_prompt is not None:
uncond_embeddings = get_unweighted_text_embeddings(
pipe,
uncond_tokens,
pipe.model_max_length,
no_boseos_middle=no_boseos_middle,
)
# uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
uncond_weights = torch.tensor(uncond_weights, dtype=torch.float, device="cpu")
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
if (not skip_parsing) and (not skip_weighting):
previous_mean = (
text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = (
text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
)
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
previous_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = (
uncond_embeddings.float()
.mean(axis=[-2, -1])
.to(uncond_embeddings.dtype)
)
uncond_embeddings *= (
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
)
if uncond_prompt is not None:
return text_embeddings, uncond_embeddings
return text_embeddings, None

View File

@@ -0,0 +1,118 @@
# from shark_turbine.turbine_models.schedulers import export_scheduler_model
from diffusers import (
LCMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
DDPMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
def get_schedulers(model_id):
# TODO: switch over to turbine and run all on GPU
print(f"\n[LOG] Initializing schedulers from model id: {model_id}")
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
# schedulers["DDPM"] = DDPMScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["DDIM"] = DDIMScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
# model_id, subfolder="scheduler", algorithm_type="dpmsolver"
# )
# schedulers["DPMSolverMultistep++"] = DPMSolverMultistepScheduler.from_pretrained(
# model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
# )
# schedulers["DPMSolverMultistepKarras"] = (
# DPMSolverMultistepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# use_karras_sigmas=True,
# )
# )
# schedulers["DPMSolverMultistepKarras++"] = (
# DPMSolverMultistepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# algorithm_type="dpmsolver++",
# use_karras_sigmas=True,
# )
# )
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["EulerAncestralDiscrete"] = (
EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
)
# schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["DPMSolverSinglestep"] = DPMSolverSinglestepScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# schedulers["KDPM2AncestralDiscrete"] = (
# KDPM2AncestralDiscreteScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
# )
# schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
# model_id,
# subfolder="scheduler",
# )
return schedulers
def export_scheduler_model(model):
return "None", "None"
scheduler_model_map = {
"PNDM": export_scheduler_model("PNDMScheduler"),
# "DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
"EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
"EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"),
# "LCM": export_scheduler_model("LCMScheduler"),
# "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
# "DDPM": export_scheduler_model("DDPMScheduler"),
# "DDIM": export_scheduler_model("DDIMScheduler"),
# "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
# "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
# "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
# "DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
# "KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
# "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
}

View File

@@ -0,0 +1,66 @@
import numpy as np
import json
from random import (
randint,
seed as seed_random,
getstate as random_getstate,
setstate as random_setstate,
)
# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed: int | str):
seed = int(seed)
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed
# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):
if isinstance(seed_input, str):
try:
seed_input = json.loads(seed_input)
except (ValueError, TypeError):
seed_input = None
if isinstance(seed_input, int):
return [seed_input]
if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
return seed_input
raise TypeError(
"Seed input must be an integer or an array of integers in JSON format"
)
# Generate a set of seeds from an input expression for batch_count batches,
# optionally using that input as the rng seed for any randomly generated seeds.
def batch_seeds(seed_input: str | list | int, batch_count: int, repeatable=False):
# turn the input into a list if possible
seeds = parse_seed_input(seed_input)
# slice or pad the list to be of batch_count length
seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))
if repeatable:
if all(seed < 0 for seed in seeds):
seeds[0] = sanitize_seed(seeds[0])
# set seed for the rng based on what we have so far
saved_random_state = random_getstate()
seed_random(str([n for n in seeds if n > -1]))
# generate any seeds that are unspecified
seeds = [sanitize_seed(seed) for seed in seeds]
if repeatable:
# reset the rng back to normal
random_setstate(saved_random_state)
return seeds

View File

@@ -0,0 +1,791 @@
import argparse
import os
from pathlib import Path
from apps.shark_studio.modules.img_processing import resampler_list
def path_expand(s):
return Path(s).expanduser().resolve()
def is_valid_file(arg):
if not os.path.exists(arg):
return None
else:
return arg
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
##############################################################################
# Stable Diffusion Params
##############################################################################
p.add_argument(
"-a",
"--app",
default="txt2img",
help="Which app to use, one of: txt2img, img2img, outpaint, inpaint.",
)
p.add_argument(
"-p",
"--prompt",
nargs="+",
default=[
"a photo taken of the front of a super-car drifting on a road near "
"mountains at high speeds with smoke coming off the tires, front "
"angle, front point of view, trees in the mountains of the "
"background, ((sharp focus))"
],
help="Text of which images to be generated.",
)
p.add_argument(
"--negative_prompt",
nargs="+",
default=[
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
"blurry, ugly, blur, oversaturated, cropped"
],
help="Text you don't want to see in the generated image.",
)
p.add_argument(
"--sd_init_image",
type=str,
help="Path to the image input for img2img/inpainting.",
)
p.add_argument(
"--steps",
type=int,
default=50,
help="The number of steps to do the sampling.",
)
p.add_argument(
"--seed",
type=str,
default=-1,
help="The seed or list of seeds to use. -1 for a random one.",
)
p.add_argument(
"--batch_size",
type=int,
default=1,
choices=range(1, 4),
help="The number of inferences to be made in a single `batch_count`.",
)
p.add_argument(
"--height",
type=int,
default=512,
choices=range(128, 1025, 8),
help="The height of the output image.",
)
p.add_argument(
"--width",
type=int,
default=512,
choices=range(128, 1025, 8),
help="The width of the output image.",
)
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
help="The value to be used for guidance scaling.",
)
p.add_argument(
"--noise_level",
type=int,
default=20,
help="The value to be used for noise level of upscaler.",
)
p.add_argument(
"--max_length",
type=int,
default=64,
help="Max length of the tokenizer output, options are 64 and 77.",
)
p.add_argument(
"--max_embeddings_multiples",
type=int,
default=5,
help="The max multiple length of prompt embeddings compared to the max "
"output length of text encoder.",
)
p.add_argument(
"--strength",
type=float,
default=0.8,
help="The strength of change applied on the given input image for " "img2img.",
)
p.add_argument(
"--use_hiresfix",
type=bool,
default=False,
help="Use Hires Fix to do higher resolution images, while trying to "
"avoid the issues that come with it. This is accomplished by first "
"generating an image using txt2img, then running it through img2img.",
)
p.add_argument(
"--hiresfix_height",
type=int,
default=768,
choices=range(128, 769, 8),
help="The height of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_width",
type=int,
default=768,
choices=range(128, 769, 8),
help="The width of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_strength",
type=float,
default=0.6,
help="The denoising strength to apply for the Hires Fix.",
)
p.add_argument(
"--resample_type",
type=str,
default="Nearest Neighbor",
choices=resampler_list,
help="The resample type to use when resizing an image before being run "
"through stable diffusion.",
)
##############################################################################
# Stable Diffusion Training Params
##############################################################################
p.add_argument(
"--lora_save_dir",
type=str,
default="models/lora/",
help="Directory to save the lora fine tuned model.",
)
p.add_argument(
"--training_images_dir",
type=str,
default="models/lora/training_images/",
help="Directory containing images that are an example of the prompt.",
)
p.add_argument(
"--training_steps",
type=int,
default=2000,
help="The number of steps to train.",
)
##############################################################################
# Inpainting and Outpainting Params
##############################################################################
p.add_argument(
"--mask_path",
type=str,
help="Path to the mask image input for inpainting.",
)
p.add_argument(
"--inpaint_full_res",
default=False,
action=argparse.BooleanOptionalAction,
help="If inpaint only masked area or whole picture.",
)
p.add_argument(
"--inpaint_full_res_padding",
type=int,
default=32,
choices=range(0, 257, 4),
help="Number of pixels for only masked padding.",
)
p.add_argument(
"--pixels",
type=int,
default=128,
choices=range(8, 257, 8),
help="Number of expended pixels for one direction for outpainting.",
)
p.add_argument(
"--mask_blur",
type=int,
default=8,
choices=range(0, 65),
help="Number of blur pixels for outpainting.",
)
p.add_argument(
"--left",
default=False,
action=argparse.BooleanOptionalAction,
help="If extend left for outpainting.",
)
p.add_argument(
"--right",
default=False,
action=argparse.BooleanOptionalAction,
help="If extend right for outpainting.",
)
p.add_argument(
"--up",
"--top",
default=False,
action=argparse.BooleanOptionalAction,
help="If extend top for outpainting.",
)
p.add_argument(
"--down",
"--bottom",
default=False,
action=argparse.BooleanOptionalAction,
help="If extend bottom for outpainting.",
)
p.add_argument(
"--noise_q",
type=float,
default=1.0,
help="Fall-off exponent for outpainting (lower=higher detail) "
"(min=0.0, max=4.0).",
)
p.add_argument(
"--color_variation",
type=float,
default=0.05,
help="Color variation for outpainting (min=0.0, max=1.0).",
)
##############################################################################
# Model Config and Usage Params
##############################################################################
p.add_argument("--device", type=str, default="vulkan", help="Device to run the model.")
p.add_argument(
"--precision", type=str, default="fp16", help="Precision to run the model."
)
p.add_argument(
"--import_mlir",
default=True,
action=argparse.BooleanOptionalAction,
help="Imports the model from torch module to shark_module otherwise "
"downloads the model from shark_tank.",
)
p.add_argument(
"--use_tuned",
default=False,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available.",
)
p.add_argument(
"--use_base_vae",
default=False,
action=argparse.BooleanOptionalAction,
help="Do conversion from the VAE output to pixel space on cpu.",
)
p.add_argument(
"--scheduler",
type=str,
default="DDIM",
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
"DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
"DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
"DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, "
"HeunDiscrete].",
)
p.add_argument(
"--output_img_format",
type=str,
default="png",
help="Specify the format in which output image is save. "
"Supported options: jpg / png.",
)
p.add_argument(
"--output_dir",
type=str,
default=os.path.join(os.getcwd(), "generated_imgs"),
help="Directory path to save the output images and json.",
)
p.add_argument(
"--batch_count",
type=int,
default=1,
help="Number of batches to be generated with random seeds in " "single execution.",
)
p.add_argument(
"--repeatable_seeds",
default=False,
action=argparse.BooleanOptionalAction,
help="The seed of the first batch will be used as the rng seed to "
"generate the subsequent seeds for subsequent batches in that run.",
)
p.add_argument(
"--custom_weights",
type=str,
default="",
help="Path to a .safetensors or .ckpt file for SD pipeline weights.",
)
p.add_argument(
"--custom_vae",
type=str,
default="",
help="HuggingFace repo-id or path to SD model's checkpoint whose VAE "
"needs to be plugged in.",
)
p.add_argument(
"--base_model_id",
type=str,
default="stabilityai/stable-diffusion-2-1-base",
help="The repo-id of hugging face.",
)
p.add_argument(
"--low_cpu_mem_usage",
default=False,
action=argparse.BooleanOptionalAction,
help="Use the accelerate package to reduce cpu memory consumption.",
)
p.add_argument(
"--attention_slicing",
type=str,
default="none",
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', "
"or an integer).",
)
p.add_argument(
"--use_stencil",
choices=["canny", "openpose", "scribble", "zoedepth"],
help="Enable the stencil feature.",
)
p.add_argument(
"--control_mode",
choices=["Prompt", "Balanced", "Controlnet"],
default="Balanced",
help="How Controlnet injection should be prioritized.",
)
p.add_argument(
"--use_lora",
type=str,
default="",
help="Use standalone LoRA weight using a HF ID or a checkpoint " "file (~3 MB).",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="Runs the quantized version of stable diffusion model. "
"This is currently in experimental phase. "
"Currently, only runs the stable-diffusion-2-1-base model in "
"int8 quantization.",
)
p.add_argument(
"--lowvram",
default=False,
action=argparse.BooleanOptionalAction,
help="Load and unload models for low VRAM.",
)
p.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication tokens for models like Llama2.",
)
p.add_argument(
"--external_weights",
type=str,
default=None,
help="What type of externalized weights to use. Currently options are 'safetensors' and defaults to inlined weights.",
)
p.add_argument(
"--device_allocator_heap_key",
type=str,
default="",
help="Specify heap key for device caching allocator."
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################
p.add_argument(
"--iree_vulkan_target_triple",
type=str,
default="",
help="Specify target triple for vulkan.",
)
p.add_argument(
"--iree_metal_target_platform",
type=str,
default="",
help="Specify target triple for metal.",
)
##############################################################################
# Misc. Debug and Optimization flags
##############################################################################
p.add_argument(
"--use_compiled_scheduler",
default=True,
action=argparse.BooleanOptionalAction,
help="Use the default scheduler precompiled into the model if available.",
)
p.add_argument(
"--local_tank_cache",
default="",
help="Specify where to save downloaded shark_tank artifacts. "
"If this is not set, the default is ~/.local/shark_tank/.",
)
p.add_argument(
"--dump_isa",
default=False,
action="store_true",
help="When enabled call amdllpc to get ISA dumps. " "Use with dispatch benchmarks.",
)
p.add_argument(
"--dispatch_benchmarks",
default=None,
help="Dispatches to return benchmark data on. "
'Use "All" for all, and None for none.',
)
p.add_argument(
"--dispatch_benchmarks_dir",
default="temp_dispatch_benchmarks",
help="Directory where you want to store dispatch data "
'generated with "--dispatch_benchmarks".',
)
p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for inserting debug frames between iterations " "for use with rgp.",
)
p.add_argument(
"--hide_steps",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for hiding the details of iteration/sec for each step.",
)
p.add_argument(
"--warmup_count",
type=int,
default=0,
help="Flag setting warmup count for CLIP and VAE [>= 0].",
)
p.add_argument(
"--clear_all",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag to clear all mlir and vmfb from common locations. "
"Recompiling will take several minutes.",
)
p.add_argument(
"--save_metadata_to_json",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for whether or not to save a generation information "
"json file with the image.",
)
p.add_argument(
"--write_metadata_to_png",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for whether or not to save generation information in "
"PNG chunk text to generated images.",
)
p.add_argument(
"--import_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="If import_mlir is True, saves mlir via the debug option "
"in shark importer. Does nothing if import_mlir is false (the default).",
)
p.add_argument(
"--compile_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag to toggle debug assert/verify flags for imported IR in the"
"iree-compiler. Default to false.",
)
p.add_argument(
"--iree_constant_folding",
default=True,
action=argparse.BooleanOptionalAction,
help="Controls constant folding in iree-compile for all SD models.",
)
p.add_argument(
"--data_tiling",
default=False,
action=argparse.BooleanOptionalAction,
help="Controls data tiling in iree-compile for all SD models.",
)
p.add_argument(
"--quantization",
type=str,
default="None",
help="Quantization to be used for api-exposed model.",
)
##############################################################################
# Web UI flags
##############################################################################
p.add_argument(
"--webui",
default=True,
action=argparse.BooleanOptionalAction,
help="controls whether the webui is launched.",
)
p.add_argument(
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for removing the progress bar animation during " "image generation.",
)
p.add_argument(
"--tmp_dir",
type=str,
default=os.path.join(os.getcwd(), "shark_tmp"),
help="Path to tmp directory",
)
p.add_argument(
"--config_dir",
type=str,
default=os.path.join(os.getcwd(), "configs"),
help="Path to config directory",
)
p.add_argument(
"--model_dir",
type=str,
default=os.path.join(os.getcwd(), "models"),
help="Path to directory where all .ckpts are stored in order to populate "
"them in the web UI.",
)
# TODO: replace API flag when these can be run together
p.add_argument(
"--ui",
type=str,
default="app" if os.name == "nt" else "web",
help="One of: [api, app, web].",
)
p.add_argument(
"--share",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for generating a public URL.",
)
p.add_argument(
"--server_port",
type=int,
default=8080,
help="Flag for setting server port.",
)
p.add_argument(
"--api",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for enabling rest API.",
)
p.add_argument(
"--api_accept_origin",
action="append",
type=str,
help="An origin to be accepted by the REST api for Cross Origin"
"Resource Sharing (CORS). Use multiple times for multiple origins, "
'or use --api_accept_origin="*" to accept all origins. If no origins '
"are set no CORS headers will be returned by the api. Use, for "
"instance, if you need to access the REST api from Javascript running "
"in a web browser.",
)
p.add_argument(
"--debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for enabling debugging log in WebUI.",
)
p.add_argument(
"--output_gallery",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for removing the output gallery tab, and avoid exposing "
"images under --output_dir in the UI.",
)
p.add_argument(
"--configs_path",
default=None,
type=str,
help="Path to .json config directory.",
)
p.add_argument(
"--output_gallery_followlinks",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for whether the output gallery tab in the UI should "
"follow symlinks when listing subdirectories under --output_dir.",
)
p.add_argument(
"--api_log",
default=False,
action=argparse.BooleanOptionalAction,
help="Enables Compatibility API logging.",
)
##############################################################################
# SD model auto-annotation flags
##############################################################################
p.add_argument(
"--annotation_output",
type=path_expand,
default="./",
help="Directory to save the annotated mlir file.",
)
p.add_argument(
"--annotation_model",
type=str,
default="unet",
help="Options are unet and vae.",
)
p.add_argument(
"--save_annotation",
default=False,
action=argparse.BooleanOptionalAction,
help="Save annotated mlir file.",
)
##############################################################################
# SD model auto-tuner flags
##############################################################################
p.add_argument(
"--tuned_config_dir",
type=path_expand,
default="./",
help="Directory to save the tuned config file.",
)
p.add_argument(
"--num_iters",
type=int,
default=400,
help="Number of iterations for tuning.",
)
p.add_argument(
"--search_op",
type=str,
default="all",
help="Op to be optimized, options are matmul, bmm, conv and all.",
)
##############################################################################
# DocuChat Flags
##############################################################################
p.add_argument(
"--run_docuchat_web",
default=False,
action=argparse.BooleanOptionalAction,
help="Specifies whether the docuchat's web version is running or not.",
)
##############################################################################
# rocm Flags
##############################################################################
p.add_argument(
"--iree_rocm_target_chip",
type=str,
default="",
help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Use `hipinfo` "
"or `iree-run-module --dump_devices=rocm` or `hipinfo` to get desired arch name",
)
cmd_opts, unknown = p.parse_known_args()
if cmd_opts.import_debug:
os.environ["IREE_SAVE_TEMPS"] = os.path.join(
os.getcwd(), cmd_opts.hf_model_id.replace("/", "_")
)

View File

@@ -0,0 +1,106 @@
import time
import argparse
class TimerSubcategory:
def __init__(self, timer, category):
self.timer = timer
self.category = category
self.start = None
self.original_base_category = timer.base_category
def __enter__(self):
self.start = time.time()
self.timer.base_category = self.original_base_category + self.category + "/"
self.timer.subcategory_level += 1
if self.timer.print_log:
print(f"{' ' * self.timer.subcategory_level}{self.category}:")
def __exit__(self, exc_type, exc_val, exc_tb):
elapsed_for_subcategroy = time.time() - self.start
self.timer.base_category = self.original_base_category
self.timer.add_time_to_record(
self.original_base_category + self.category,
elapsed_for_subcategroy,
)
self.timer.subcategory_level -= 1
self.timer.record(self.category, disable_log=True)
class Timer:
def __init__(self, print_log=False):
self.start = time.time()
self.records = {}
self.total = 0
self.base_category = ""
self.print_log = print_log
self.subcategory_level = 0
def elapsed(self):
end = time.time()
res = end - self.start
self.start = end
return res
def add_time_to_record(self, category, amount):
if category not in self.records:
self.records[category] = 0
self.records[category] += amount
def record(self, category, extra_time=0, disable_log=False):
e = self.elapsed()
self.add_time_to_record(self.base_category + category, e + extra_time)
self.total += e + extra_time
if self.print_log and not disable_log:
print(
f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s"
)
def subcategory(self, name):
self.elapsed()
subcat = TimerSubcategory(self, name)
return subcat
def summary(self):
res = f"{self.total:.1f}s"
additions = [
(category, time_taken)
for category, time_taken in self.records.items()
if time_taken >= 0.1 and "/" not in category
]
if not additions:
return res
res += " ("
res += ", ".join(
[f"{category}: {time_taken:.1f}s" for category, time_taken in additions]
)
res += ")"
return res
def dump(self):
return {"total": self.total, "records": self.records}
def reset(self):
self.__init__()
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument(
"--log-startup",
action="store_true",
help="print a detailed log of what's happening at startup",
)
args = parser.parse_known_args()[0]
startup_timer = Timer(print_log=args.log_startup)
startup_record = None

View File

@@ -0,0 +1,48 @@
# -*- mode: python ; coding: utf-8 -*-
from apps.shark_studio.studio_imports import pathex, datas, hiddenimports
binaries = []
block_cipher = None
a = Analysis(
['web/index.py'],
pathex=pathex,
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
module_collection_mode={
'gradio': 'py', # Collect gradio package as source .py files
},
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='nodai_shark_studio',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=False,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -0,0 +1,68 @@
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules
import sys
sys.setrecursionlimit(sys.getrecursionlimit() * 5)
# python path for pyinstaller
pathex = [
".",
]
# datafiles for pyinstaller
datas = []
datas += copy_metadata("torch")
datas += copy_metadata("tokenizers")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
datas += copy_metadata("packaging")
datas += copy_metadata("filelock")
datas += copy_metadata("numpy")
datas += copy_metadata("importlib_metadata")
datas += copy_metadata("omegaconf")
datas += copy_metadata("safetensors")
datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += copy_metadata("gradio")
datas += copy_metadata("scipy")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("accelerate")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("gradio")
datas += collect_data_files("gradio_client")
datas += collect_data_files("iree", include_py_files=True)
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("scipy", include_py_files=True)
datas += [
("web/ui/css/*", "ui/css"),
("web/ui/js/*", "ui/js"),
("web/ui/logos/*", "logos"),
]
# hidden imports for pyinstaller
hiddenimports = ["shark", "apps"]
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("diffusers") if "tests" not in x]
blacklist = ["tests", "convert"]
hiddenimports += [
x
for x in collect_submodules("transformers")
if not any(kw in x for kw in blacklist)
]
hiddenimports += [x for x in collect_submodules("iree") if "test" not in x]
hiddenimports += ["iree._runtime"]
hiddenimports += [x for x in collect_submodules("scipy") if "test" not in x]

View File

@@ -6,27 +6,51 @@
import logging
import unittest
from apps.shark_studio.api.llm import LanguageModel
import json
import gc
from apps.shark_studio.api.llm import LanguageModel, llm_chat_api
from apps.shark_studio.api.sd import shark_sd_fn_dict_input, view_json_file
from apps.shark_studio.web.utils.file_utils import get_resource_path
# class SDAPITest(unittest.TestCase):
# def testSDSimple(self):
# from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
# import apps.shark_studio.web.utils.globals as global_obj
# global_obj._init()
# sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json"))
# sd_kwargs = json.loads(sd_json)
# for arg in vars(cmd_opts):
# if arg in sd_kwargs:
# sd_kwargs[arg] = getattr(cmd_opts, arg)
# for i in shark_sd_fn_dict_input(sd_kwargs):
# print(i)
class LLMAPITest(unittest.TestCase):
def testLLMSimple(self):
def test01_LLMSmall(self):
lm = LanguageModel(
"Trelis/Llama-2-7b-chat-hf-function-calling-v2",
"TinyPixel/small-llama2",
hf_auth_token=None,
device="cpu-task",
external_weights="safetensors",
device="cpu",
precision="fp32",
quantization="None",
streaming_llm=True,
)
count = 0
label = "Turkishoure Turkish"
for msg, _ in lm.chat("hi, what are you?"):
# skip first token output
if count == 0:
count += 1
continue
assert (
msg.strip(" ") == "Hello"
), f"LLM API failed to return correct response, expected 'Hello', received {msg}"
msg.strip(" ") == label
), f"LLM API failed to return correct response, expected '{label}', received {msg}"
break
del lm
gc.collect()
if __name__ == "__main__":

View File

@@ -0,0 +1,41 @@
import torch
from diffusers import (
UNet2DConditionModel,
)
from torch.fx.experimental.proxy_tensor import make_fx
class UnetModel(torch.nn.Module):
def __init__(self, hf_model_name):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
hf_model_name,
subfolder="unet",
)
def forward(self, sample, timestep, encoder_hidden_states, guidance_scale):
samples = torch.cat([sample] * 2)
unet_out = self.unet.forward(
samples, timestep, encoder_hidden_states, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
if __name__ == "__main__":
hf_model_name = "CompVis/stable-diffusion-v1-4"
unet = UnetModel(hf_model_name)
inputs = (torch.randn(1, 4, 64, 64), 1, torch.randn(2, 77, 768), 7.5)
fx_g = make_fx(
unet,
decomposition_table={},
tracing_mode="symbolic",
_allow_non_fake_inputs=True,
_allow_fake_constant=False,
)(*inputs)
print(fx_g)

Binary file not shown.

After

Width:  |  Height:  |  Size: 347 KiB

View File

@@ -0,0 +1,45 @@
import requests
from PIL import Image
import base64
from io import BytesIO
import json
def llm_chat_test(verbose=False):
# Define values here
prompt = "What is the significance of the number 42?"
url = "http://127.0.0.1:8080/v1/chat/completions"
headers = {
"User-Agent": "PythonTest",
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
}
data = {
"model": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
"messages": [
{
"role": "",
"content": prompt,
}
],
"device": "vulkan://0",
"max_tokens": 4096,
}
res = requests.post(url=url, json=data, headers=headers, timeout=1000)
res_dict = json.loads(res.content.decode("utf-8"))
print(f"[chat] response from server was : {res.status_code} {res.reason}")
if verbose or res.status_code != 200:
print(f"\n{res_dict['choices'][0]['message']['content']}\n")
if __name__ == "__main__":
# "Exercises the chatbot REST API of Shark. Make sure "
# "Shark is running in API mode on 127.0.0.1:8080 before running"
# "this script."
llm_chat_test(verbose=True)

View File

@@ -0,0 +1,286 @@
import base64
import io
import os
import time
import datetime
import uvicorn
import ipaddress
import requests
import threading
import collections
import gradio as gr
from PIL import Image, PngImagePlugin
from threading import Lock
from io import BytesIO
from fastapi import APIRouter, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
# from sdapi_v1 import shark_sd_api
from apps.shark_studio.api.llm import llm_chat_api
def decode_base64_to_image(encoding):
if encoding.startswith("http://") or encoding.startswith("https://"):
headers = {}
response = requests.get(encoding, timeout=30, headers=headers)
try:
image = Image.open(BytesIO(response.content))
return image
except Exception as e:
raise HTTPException(status_code=500, detail="Invalid image url") from e
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as e:
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
image.save(
output_bytes,
format="PNG",
pnginfo=(metadata if use_metadata else None),
)
bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data)
# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
class FIFOLock(object):
def __init__(self):
self._lock = threading.Lock()
self._inner_lock = threading.Lock()
self._pending_threads = collections.deque()
def acquire(self, blocking=True):
with self._inner_lock:
lock_acquired = self._lock.acquire(False)
if lock_acquired:
return True
elif not blocking:
return False
release_event = threading.Event()
self._pending_threads.append(release_event)
release_event.wait()
return self._lock.acquire()
def release(self):
with self._inner_lock:
if self._pending_threads:
release_event = self._pending_threads.popleft()
release_event.set()
self._lock.release()
__enter__ = acquire
def __exit__(self, t, v, tb):
self.release()
def api_middleware(app: FastAPI):
rich_available = False
try:
if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
import anyio # importing just so it can be placed on silent list
import starlette # importing just so it can be placed on silent list
from rich.console import Console
console = Console()
rich_available = True
except Exception:
pass
@app.middleware("http")
async def log_and_time(req: Request, call_next):
ts = time.time()
res: Response = await call_next(req)
duration = str(round(time.time() - ts, 4))
res.headers["X-Process-Time"] = duration
endpoint = req.scope.get("path", "err")
if cmd_opts.api_log and endpoint.startswith("/sdapi"):
print(
"API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}".format(
t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
code=res.status_code,
ver=req.scope.get("http_version", "0.0"),
cli=req.scope.get("client", ("0:0.0.0", 0))[0],
prot=req.scope.get("scheme", "err"),
method=req.scope.get("method", "err"),
endpoint=endpoint,
duration=duration,
)
)
return res
def handle_exception(request: Request, e: Exception):
err = {
"error": type(e).__name__,
"detail": vars(e).get("detail", ""),
"body": vars(e).get("body", ""),
"errors": str(e),
}
if not isinstance(
e, HTTPException
): # do not print backtrace on known httpexceptions
message = f"API error: {request.method}: {request.url} {err}"
if rich_available:
print(message)
console.print_exception(
show_locals=True,
max_frames=2,
extra_lines=1,
suppress=[anyio, starlette],
word_wrap=False,
width=min([console.width, 200]),
)
else:
print(message)
raise (e)
return JSONResponse(
status_code=vars(e).get("status_code", 500),
content=jsonable_encoder(err),
)
@app.middleware("http")
async def exception_handling(request: Request, call_next):
try:
return await call_next(request)
except Exception as e:
return handle_exception(request, e)
@app.exception_handler(Exception)
async def fastapi_exception_handler(request: Request, e: Exception):
return handle_exception(request, e)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, e: HTTPException):
return handle_exception(request, e)
class ApiCompat:
def __init__(self, app: FastAPI, queue_lock: Lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
api_middleware(self.app)
# self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["POST"])
# self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["POST"])
# self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["POST"])
# self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
# self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
# self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
# self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
# self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
# self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
# self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
# self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
# self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
# self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
# self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
# self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
# self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
# self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
# self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
# self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
# self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
# self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
# self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
# self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
# self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
# self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
# self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
# self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
# self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
# self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
# self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
# self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
# self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
# self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
# self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
# self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
# chat APIs needed for compatibility with multiple extensions using OpenAI API
self.add_api_route("/v1/chat/completions", llm_chat_api, methods=["POST"])
self.add_api_route("/v1/completions", llm_chat_api, methods=["POST"])
self.add_api_route("/chat/completions", llm_chat_api, methods=["POST"])
self.add_api_route("/completions", llm_chat_api, methods=["POST"])
self.add_api_route(
"/v1/engines/codegen/completions", llm_chat_api, methods=["POST"]
)
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
def add_api_route(self, path: str, endpoint, **kwargs):
return self.app.add_api_route(path, endpoint, **kwargs)
# def refresh_checkpoints(self):
# with self.queue_lock:
# studio_data.refresh_checkpoints()
# def refresh_vae(self):
# with self.queue_lock:
# studio_data.refresh_vae_list()
# def unloadapi(self):
# unload_model_weights()
# return {}
# def reloadapi(self):
# reload_model_weights()
# return {}
# def skip(self):
# studio.state.skip()
def launch(self, server_name, port, root_path):
self.app.include_router(self.router)
uvicorn.run(
self.app,
host=server_name,
port=port,
root_path=root_path,
)
# def kill_studio(self):
# restart.stop_program()
# def restart_studio(self):
# if restart.is_restartable():
# restart.restart_program()
# return Response(status_code=501)
# def preprocess(self, args: dict):
# try:
# studio.state.begin(job="preprocess")
# preprocess(**args)
# studio.state.end()
# return models.PreprocessResponse(info="preprocess complete")
# except:
# studio.state.end()
# def stop_studio(request):
# studio.state.server_command = "stop"
# return Response("Stopping.")

View File

@@ -0,0 +1 @@

View File

@@ -1,20 +1,59 @@
from multiprocessing import Process, freeze_support
freeze_support()
from PIL import Image
import os
import time
import sys
import logging
from ui.chat import chat_element
import apps.shark_studio.api.initializers as initialize
from apps.shark_studio.modules import timer
startup_timer = timer.startup_timer
startup_timer.record("launcher")
initialize.imports()
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
# import before IREE to avoid MLIR library issues
import torch_mlir
# import PIL, transformers, sentencepiece # ensures inclusion in pysintaller exe generation
# from apps.stable_diffusion.src import args, clear_all
# import apps.stable_diffusion.web.utils.global_obj as global_obj
def create_api(app):
from apps.shark_studio.web.api.compat import ApiCompat, FIFOLock
queue_lock = FIFOLock()
api = ApiCompat(app, queue_lock)
return api
def launch_app(address):
def api_only():
from fastapi import FastAPI
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
initialize.initialize()
app = FastAPI()
initialize.setup_middleware(app)
api = create_api(app)
# from modules import script_callbacks
# script_callbacks.before_ui_callback()
# script_callbacks.app_started_callback(None, app)
print(f"Startup time: {startup_timer.summary()}.")
api.launch(
server_name="0.0.0.0",
port=cmd_opts.server_port,
root_path="",
)
def launch_webui(address):
from tkinter import Tk
import webview
@@ -34,138 +73,78 @@ def launch_app(address):
webview.start(private_mode=False, storage_path=os.getcwd())
if __name__ == "__main__":
# if args.debug:
logging.basicConfig(level=logging.DEBUG)
def webui():
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.web.ui.utils import (
amdicon_loc,
amdlogo_loc,
)
launch_api = cmd_opts.api
initialize.initialize()
from ui.chat import chat_element
from ui.sd import sd_element
from ui.outputgallery import outputgallery_element
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
# if args.api or "api" in args.ui.split(","):
# from apps.stable_diffusion.web.ui import (
# txt2img_api,
# img2img_api,
# upscaler_api,
# inpaint_api,
# outpaint_api,
# llm_chat_api,
# )
#
# from fastapi import FastAPI, APIRouter
# import uvicorn
#
# # init global sd pipeline and config
# global_obj._init()
#
# app = FastAPI()
# app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
# app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
# app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
# app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
# app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
#
# # chat APIs needed for compatibility with multiple extensions using OpenAI API
# app.add_api_route(
# "/v1/chat/completions", llm_chat_api, methods=["post"]
# )
# app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
# app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
# app.add_api_route("/completions", llm_chat_api, methods=["post"])
# app.add_api_route(
# "/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
# )
# app.include_router(APIRouter())
# uvicorn.run(app, host="0.0.0.0", port=args.server_port)
# sys.exit(0)
#
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
# from apps.stable_diffusion.web.utils.gradio_configs import (
# config_gradio_tmp_imgs_folder,
# )
# config_gradio_tmp_imgs_folder()
# if args.api or "api" in args.ui.split(","):
# from apps.shark_studio.api.llm import (
# chat,
# )
# from apps.shark_studio.web.api import sdapi
#
# from fastapi import FastAPI, APIRouter
# from fastapi.middleware.cors import CORSMiddleware
# import uvicorn
#
# # init global sd pipeline and config
# global_obj._init()
#
# api = FastAPI()
# api.mount("/sdapi/", sdapi)
#
# # chat APIs needed for compatibility with multiple extensions using OpenAI API
# api.add_api_route(
# "/v1/chat/completions", llm_chat_api, methods=["post"]
# )
# api.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
# api.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
# api.add_api_route("/completions", llm_chat_api, methods=["post"])
# api.add_api_route(
# "/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
# )
# api.include_router(APIRouter())
#
# # deal with CORS requests if CORS accept origins are set
# if args.api_accept_origin:
# print(
# f"API Configured for CORS. Accepting origins: { args.api_accept_origin }"
# )
# api.add_middleware(
# CORSMiddleware,
# allow_origins=args.api_accept_origin,
# allow_methods=["GET", "POST"],
# allow_headers=["*"],
# )
# else:
# print("API not configured for CORS")
#
# uvicorn.run(api, host="0.0.0.0", port=args.server_port)
# sys.exit(0)
import gradio as gr
# Create custom models folders if they don't exist
# from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
# create_custom_models_folders()
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
gradio_workarounds = resource_path("ui/js/sd_gradio_workarounds.js")
# from apps.stable_diffusion.web.ui import (
# txt2img_web,
# txt2img_custom_model,
# txt2img_gallery,
# txt2img_png_info_img,
# txt2img_status,
# txt2img_sendto_img2img,
# txt2img_sendto_inpaint,
# txt2img_sendto_outpaint,
# txt2img_sendto_upscaler,
## h2ogpt_upload,
## h2ogpt_web,
# img2img_web,
# img2img_custom_model,
# img2img_gallery,
# img2img_init_image,
# img2img_status,
# img2img_sendto_inpaint,
# img2img_sendto_outpaint,
# img2img_sendto_upscaler,
# inpaint_web,
# inpaint_custom_model,
# inpaint_gallery,
# inpaint_init_image,
# inpaint_status,
# inpaint_sendto_img2img,
# inpaint_sendto_outpaint,
# inpaint_sendto_upscaler,
# outpaint_web,
# outpaint_custom_model,
# outpaint_gallery,
# outpaint_init_image,
# outpaint_status,
# outpaint_sendto_img2img,
# outpaint_sendto_inpaint,
# outpaint_sendto_upscaler,
# upscaler_web,
# upscaler_custom_model,
# upscaler_gallery,
# upscaler_init_image,
# upscaler_status,
# upscaler_sendto_img2img,
# upscaler_sendto_inpaint,
# upscaler_sendto_outpaint,
## lora_train_web,
## model_web,
## model_config_web,
# hf_models,
# modelmanager_sendto_txt2img,
# modelmanager_sendto_img2img,
# modelmanager_sendto_inpaint,
# modelmanager_sendto_outpaint,
# modelmanager_sendto_upscaler,
# stablelm_chat,
# minigpt4_web,
# outputgallery_web,
# outputgallery_tab_select,
# outputgallery_watch,
# outputgallery_filename,
# outputgallery_sendto_txt2img,
# outputgallery_sendto_img2img,
# outputgallery_sendto_inpaint,
# outputgallery_sendto_outpaint,
# outputgallery_sendto_upscaler,
# )
# init global sd pipeline and config
# global_obj._init()
# from apps.shark_studio.web.ui import load_ui_from_script
def register_button_click(button, selectedid, inputs, outputs):
button.click(
@@ -177,17 +156,6 @@ if __name__ == "__main__":
outputs,
)
def register_modelmanager_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
"None",
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_outputgallery_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
@@ -199,8 +167,19 @@ if __name__ == "__main__":
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Shark Studio 2.0 Beta"
) as sd_web:
css=dark_theme,
js=gradio_workarounds,
analytics_enabled=False,
title="Shark Studio 2.0 Beta",
) as studio_web:
amd_logo = Image.open(amdlogo_loc)
gr.Image(
value=amd_logo,
show_label=False,
interactive=False,
elem_id="tab_bar_logo",
show_download_button=False,
)
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
# have a unique id that doesn't clash with any of the other tabs,
@@ -211,216 +190,33 @@ if __name__ == "__main__":
# destination of one of the 'send to' buttons. If you do have to change
# that id, make sure you update the relevant register_button_click calls
# further down with the new id.
# with gr.TabItem(label="Text-to-Image", id=0):
# txt2img_web.render()
# with gr.TabItem(label="Image-to-Image", id=1):
# img2img_web.render()
# with gr.TabItem(label="Inpainting", id=2):
# inpaint_web.render()
# with gr.TabItem(label="Outpainting", id=3):
# outpaint_web.render()
# with gr.TabItem(label="Upscaler", id=4):
# upscaler_web.render()
# if args.output_gallery:
# with gr.TabItem(label="Output Gallery", id=5) as og_tab:
# outputgallery_web.render()
# # extra output gallery configuration
# outputgallery_tab_select(og_tab.select)
# outputgallery_watch(
# [
# txt2img_status,
# img2img_status,
# inpaint_status,
# outpaint_status,
# upscaler_status,
# ]
# )
## with gr.TabItem(label="Model Manager", id=6):
## model_web.render()
## with gr.TabItem(label="LoRA Training (Experimental)", id=7):
## lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=0):
with gr.TabItem(label="Stable Diffusion", id=0):
sd_element.render()
with gr.TabItem(label="Output Gallery", id=1):
outputgallery_element.render()
with gr.TabItem(label="Chat Bot", id=2):
chat_element.render()
## with gr.TabItem(
## label="Generate Sharding Config (Experimental)", id=9
## ):
## model_config_web.render()
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
# minigpt4_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
# send to buttons
# register_button_click(
# txt2img_sendto_img2img,
# 1,
# [txt2img_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_inpaint,
# 2,
# [txt2img_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_outpaint,
# 3,
# [txt2img_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_upscaler,
# 4,
# [txt2img_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_inpaint,
# 2,
# [img2img_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_outpaint,
# 3,
# [img2img_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_upscaler,
# 4,
# [img2img_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_img2img,
# 1,
# [inpaint_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_outpaint,
# 3,
# [inpaint_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_upscaler,
# 4,
# [inpaint_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_img2img,
# 1,
# [outpaint_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_inpaint,
# 2,
# [outpaint_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_upscaler,
# 4,
# [outpaint_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_img2img,
# 1,
# [upscaler_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_inpaint,
# 2,
# [upscaler_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_outpaint,
# 3,
# [upscaler_gallery],
# [outpaint_init_image, tabs],
# )
# if args.output_gallery:
# register_outputgallery_button(
# outputgallery_sendto_txt2img,
# 0,
# [outputgallery_filename],
# [txt2img_png_info_img, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_img2img,
# 1,
# [outputgallery_filename],
# [img2img_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_inpaint,
# 2,
# [outputgallery_filename],
# [inpaint_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_outpaint,
# 3,
# [outputgallery_filename],
# [outpaint_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_upscaler,
# 4,
# [outputgallery_filename],
# [upscaler_init_image, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_txt2img,
# 0,
# [hf_models],
# [txt2img_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_img2img,
# 1,
# [hf_models],
# [img2img_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_inpaint,
# 2,
# [hf_models],
# [inpaint_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_outpaint,
# 3,
# [hf_models],
# [outpaint_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_upscaler,
# 4,
# [hf_models],
# [upscaler_custom_model, tabs],
# )
studio_web.queue()
sd_web.queue()
# if args.ui == "app":
# t = Process(
# target=launch_app, args=[f"http://localhost:{args.server_port}"]
# )
# t.start()
sd_web.launch(
share=True,
studio_web.launch(
share=cmd_opts.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=11911, # args.server_port,
server_port=cmd_opts.server_port,
favicon_path=amdicon_loc,
)
if __name__ == "__main__":
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
if cmd_opts.webui == False:
api_only()
else:
webui()

View File

@@ -5,13 +5,18 @@ from pathlib import Path
from datetime import datetime as dt
import json
import sys
from apps.shark_studio.api.utils import (
get_available_devices,
)
from apps.shark_studio.api.llm import (
llm_model_map,
LanguageModel,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj
B_SYS, E_SYS = "<s>", "</s>"
B_SYS, E_SYS = "<s>", "</s>"
B_SYS, E_SYS = "<s>", "</s>"
def user(message, history):
@@ -19,13 +24,15 @@ def user(message, history):
return "", history + [[message, ""]]
def append_bot_prompt(history, input_prompt):
user_prompt = f"{input_prompt} {E_SYS} {E_SYS}"
history += user_prompt
return history
language_model = None
def create_prompt(model_name, history, prompt_prefix):
return ""
def get_default_config():
return False
@@ -41,9 +48,13 @@ def chat_fn(
precision,
download_vmfb,
config_file,
streaming_llm,
cli=False,
):
global language_model
if streaming_llm and prompt_prefix == "Clear":
language_model = None
return "Clearing history...", ""
if language_model is None:
history[-1][-1] = "Getting the model ready..."
yield history, ""
@@ -52,8 +63,9 @@ def chat_fn(
device=device,
precision=precision,
external_weights="safetensors",
external_weight_file="llama2_7b.safetensors",
use_system_prompt=prompt_prefix,
streaming_llm=streaming_llm,
hf_auth_token=cmd_opts.hf_auth_token,
)
history[-1][-1] = "Getting the model ready... Done"
yield history, ""
@@ -63,7 +75,7 @@ def chat_fn(
prefill_time = 0
is_first = True
for text, exec_time in language_model.chat(history):
history[-1][-1] = text
history[-1][-1] = f"{text}{E_SYS}"
if is_first:
prefill_time = exec_time
is_first = False
@@ -75,101 +87,6 @@ def chat_fn(
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
def llm_chat_api(InputData: dict):
return None
print(f"Input keys : {InputData.keys()}")
# print(f"model : {InputData['model']}")
is_chat_completion_api = (
"messages" in InputData.keys()
) # else it is the legacy `completion` api
# For Debugging input data from API
# if is_chat_completion_api:
# print(f"message -> role : {InputData['messages'][0]['role']}")
# print(f"message -> content : {InputData['messages'][0]['content']}")
# else:
# print(f"prompt : {InputData['prompt']}")
# print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now
global vicuna_model
model_name = InputData["model"] if "model" in InputData.keys() else "codegen"
model_path = llm_model_map[model_name]
device = "cpu-task"
precision = "fp16"
max_toks = None if "max_tokens" not in InputData.keys() else InputData["max_tokens"]
if max_toks is None:
max_toks = 128 if model_name == "codegen" else 512
# make it working for codegen first
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)
device_id = None
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
else:
print("unrecognized device")
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=True,
load_mlir_from_shark_tank=True,
device_id=device_id,
)
# TODO: add role dict for different models
if is_chat_completion_api:
# TODO: add funtionality for multiple messages
prompt = create_prompt(model_name, [(InputData["messages"][0]["content"], "")])
else:
prompt = InputData["prompt"]
print("prompt = ", prompt)
res = vicuna_model.generate(prompt)
res_op = None
for op in res:
res_op = op
if is_chat_completion_api:
choices = [
{
"index": 0,
"message": {
"role": "assistant",
"content": res_op, # since we are yeilding the result
},
"finish_reason": "stop", # or length
}
]
else:
choices = [
{
"text": res_op,
"index": 0,
"logprobs": None,
"finish_reason": "stop", # or length
}
]
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
return {
"id": end_time,
"object": "chat.completion" if is_chat_completion_api else "text_completion",
"created": int(end_time),
"choices": choices,
}
def view_json_file(file_obj):
content = ""
with open(file_obj.name, "r") as fopen:
@@ -186,7 +103,7 @@ with gr.Blocks(title="Chat") as chat_element:
choices=model_choices,
allow_custom_value=True,
)
supported_devices = get_available_devices()
supported_devices = global_obj.get_device_list()
enabled = True
if len(supported_devices) == 0:
supported_devices = ["cpu-task"]
@@ -200,7 +117,7 @@ with gr.Blocks(title="Chat") as chat_element:
)
precision = gr.Radio(
label="Precision",
value="int4",
value="fp32",
choices=[
# "int4",
# "int8",
@@ -213,12 +130,19 @@ with gr.Blocks(title="Chat") as chat_element:
with gr.Column():
download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available",
value=True,
value=False,
interactive=True,
visible=False,
)
streaming_llm = gr.Checkbox(
label="Run in streaming mode (requires recompilation)",
value=True,
interactive=False,
visible=False,
)
prompt_prefix = gr.Checkbox(
label="Add System Prompt",
value=False,
value=True,
interactive=True,
)
@@ -241,8 +165,8 @@ with gr.Blocks(title="Chat") as chat_element:
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(label="Upload sharding configuration", visible=False)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
json_view_button = gr.Button("View as JSON", visible=False)
json_view = gr.JSON(visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
@@ -262,6 +186,7 @@ with gr.Blocks(title="Chat") as chat_element:
precision,
download_vmfb,
config_file,
streaming_llm,
],
outputs=[chatbot, tokens_time],
show_progress=False,
@@ -283,6 +208,7 @@ with gr.Blocks(title="Chat") as chat_element:
precision,
download_vmfb,
config_file,
streaming_llm,
],
outputs=[chatbot, tokens_time],
show_progress=False,
@@ -295,4 +221,19 @@ with gr.Blocks(title="Chat") as chat_element:
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)
clear.click(
fn=chat_fn,
inputs=[
clear,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
streaming_llm,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
).then(lambda: None, None, [chatbot], queue=False)

View File

@@ -0,0 +1,67 @@
from apps.shark_studio.web.ui.utils import (
HSLHue,
hsl_color,
)
from apps.shark_studio.modules.embeddings import get_lora_metadata
# Answers HTML to show the most frequent tags used when a LoRA was trained,
# taken from the metadata of its .safetensors file.
def lora_changed(lora_files):
# tag frequency percentage, that gets maximum amount of the staring hue
TAG_COLOR_THRESHOLD = 0.55
# tag frequency percentage, above which a tag is displayed
TAG_DISPLAY_THRESHOLD = 0.65
# template for the html used to display a tag
TAG_HTML_TEMPLATE = (
'<span class="lora-tag" style="border: 1px solid {color};">{tag}</span>'
)
output = []
for lora_file in lora_files:
if lora_file == "":
output.extend(["<div><i>No LoRA selected</i></div>"])
elif not lora_file.lower().endswith(".safetensors"):
output.extend(
[
"<div><i>Only metadata queries for .safetensors files are currently supported</i></div>"
]
)
else:
metadata = get_lora_metadata(lora_file)
if metadata:
frequencies = metadata["frequencies"]
output.extend(
[
"".join(
[
f'<div class="lora-model">Trained against weights in: {metadata["model"]}</div>'
]
+ [
TAG_HTML_TEMPLATE.format(
color=hsl_color(
(tag[1] - TAG_COLOR_THRESHOLD)
/ (1 - TAG_COLOR_THRESHOLD),
start=HSLHue.RED,
end=HSLHue.GREEN,
),
tag=tag[0],
)
for tag in frequencies
if tag[1] > TAG_DISPLAY_THRESHOLD
],
)
]
)
elif metadata is None:
output.extend(
[
"<div><i>This LoRA does not publish tag frequency metadata</i></div>"
]
)
else:
output.extend(
[
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
]
)
return output

View File

@@ -0,0 +1,373 @@
/*
Apply Gradio dark theme to the default Gradio theme.
Procedure to upgrade the dark theme:
- Using your browser, visit http://localhost:8080/?__theme=dark
- Open your browser inspector, search for the .dark css class
- Copy .dark class declarations, apply them here into :root
*/
:root {
--body-background-fill: var(--background-fill-primary);
--body-text-color: var(--neutral-100);
--color-accent-soft: var(--neutral-700);
--background-fill-primary: var(--neutral-950);
--background-fill-secondary: var(--neutral-900);
--border-color-accent: var(--neutral-600);
--border-color-primary: var(--neutral-700);
--link-text-color-active: var(--secondary-500);
--link-text-color: var(--secondary-500);
--link-text-color-hover: var(--secondary-400);
--link-text-color-visited: var(--secondary-600);
--body-text-color-subdued: var(--neutral-400);
--shadow-spread: 1px;
--block-background-fill: var(--neutral-800);
--block-border-color: var(--border-color-primary);
--block_border_width: None;
--block-info-text-color: var(--body-text-color-subdued);
--block-label-background-fill: var(--background-fill-secondary);
--block-label-border-color: var(--border-color-primary);
--block_label_border_width: None;
--block-label-text-color: var(--neutral-200);
--block_shadow: None;
--block_title_background_fill: None;
--block_title_border_color: None;
--block_title_border_width: None;
--block-title-text-color: var(--neutral-200);
--panel-background-fill: var(--background-fill-secondary);
--panel-border-color: var(--border-color-primary);
--panel_border_width: None;
--checkbox-background-color: var(--neutral-800);
--checkbox-background-color-focus: var(--checkbox-background-color);
--checkbox-background-color-hover: var(--checkbox-background-color);
--checkbox-background-color-selected: var(--secondary-600);
--checkbox-border-color: var(--neutral-700);
--checkbox-border-color-focus: var(--secondary-500);
--checkbox-border-color-hover: var(--neutral-600);
--checkbox-border-color-selected: var(--secondary-600);
--checkbox-border-width: var(--input-border-width);
--checkbox-label-background-fill: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-fill-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800));
--checkbox-label-background-fill-selected: var(--checkbox-label-background-fill);
--checkbox-label-border-color: var(--border-color-primary);
--checkbox-label-border-color-hover: var(--checkbox-label-border-color);
--checkbox-label-border-width: var(--input-border-width);
--checkbox-label-text-color: var(--body-text-color);
--checkbox-label-text-color-selected: var(--checkbox-label-text-color);
--error-background-fill: var(--background-fill-primary);
--error-border-color: var(--border-color-primary);
--error_border_width: None;
--error-text-color: #ef4444;
--input-background-fill: var(--neutral-800);
--input-background-fill-focus: var(--secondary-600);
--input-background-fill-hover: var(--input-background-fill);
--input-border-color: var(--border-color-primary);
--input-border-color-focus: var(--neutral-700);
--input-border-color-hover: var(--input-border-color);
--input_border_width: None;
--input-placeholder-color: var(--neutral-500);
--input_shadow: None;
--input-shadow-focus: 0 0 0 var(--shadow-spread) var(--neutral-700), var(--shadow-inset);
--loader_color: None;
--slider_color: None;
--stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600));
--table-border-color: var(--neutral-700);
--table-even-background-fill: var(--neutral-950);
--table-odd-background-fill: var(--neutral-900);
--table-row-focus: var(--color-accent-soft);
--button-border-width: var(--input-border-width);
--button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c);
--button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626);
--button-cancel-border-color: #dc2626;
--button-cancel-border-color-hover: var(--button-cancel-border-color);
--button-cancel-text-color: white;
--button-cancel-text-color-hover: var(--button-cancel-text-color);
--button-primary-background-fill: linear-gradient(to bottom right, var(--primary-500), var(--primary-600));
--button-primary-background-fill-hover: linear-gradient(to bottom right, var(--primary-500), var(--primary-500));
--button-primary-border-color: var(--primary-500);
--button-primary-border-color-hover: var(--button-primary-border-color);
--button-primary-text-color: white;
--button-primary-text-color-hover: var(--button-primary-text-color);
--button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700));
--button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600));
--button-secondary-border-color: var(--neutral-600);
--button-secondary-border-color-hover: var(--button-secondary-border-color);
--button-secondary-text-color: white;
--button-secondary-text-color-hover: var(--button-secondary-text-color);
--block-border-width: 1px;
--block-label-border-width: 1px;
--form-gap-width: 1px;
--error-border-width: 1px;
--input-border-width: 1px;
}
/* SHARK theme */
body {
background-color: var(--background-fill-primary);
}
.generating.svelte-zlszon.svelte-zlszon {
border: none;
}
.generating {
border: none !important;
}
#chatbot {
height: 100% !important;
}
/* display in full width for desktop devices, but see below */
@media (min-width: 1536px)
{
.gradio-container {
max-width: var(--size-full) !important;
}
}
/* media rules in custom css are don't appear to be applied in
gradio versions > 4.7, so we have to define a class which
we will manually need add and remove using javascript.
Remove this once this fixed in gradio.
*/
.gradio-container-size-full {
max-width: var(--size-full) !important;
}
.gradio-container .contain {
padding: 0 var(--size-4) !important;
}
#top_logo {
color: transparent;
background-color: transparent;
border-radius: 0 !important;
border: 0;
}
#ui_title {
padding: var(--size-2) 0 0 var(--size-1);
}
#demo_title_outer {
border-radius: 0;
}
#prompt_box_outer div:first-child {
border-radius: 0 !important
}
#prompt_box textarea, #negative_prompt_box textarea {
background-color: var(--background-fill-primary) !important;
}
#prompt_examples {
margin: 0 !important;
}
#prompt_examples svg {
display: none !important;
}
#ui_body {
padding: var(--size-2) !important;
border-radius: 0.5em !important;
}
#img_result+div {
display: none !important;
}
footer {
display: none !important;
}
#gallery + div {
border-radius: 0 !important;
}
/* Gallery: Remove the default square ratio thumbnail and limit images height to the container */
#gallery .thumbnail-item.thumbnail-lg {
aspect-ratio: unset;
max-height: calc(55vh - (2 * var(--spacing-lg)));
}
/* fix width and height of gallery items when on very large desktop screens, but see below */
@media (min-width: 1921px) {
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
#gallery .grid-wrap, #gallery .preview{
min-height: calc(768px + 4px + var(--size-14));
max-height: calc(768px + 4px + var(--size-14));
}
/* Limit height to 768px_height + 2px_margin_height for the thumbnails */
#gallery .thumbnail-item.thumbnail-lg {
max-height: 770px !important;
}
}
/* media rules in custom css are don't appear to be applied in
gradio versions > 4.7, so we have to define classes which
we will manually need add and remove using javascript.
Remove this once this fixed in gradio.
*/
.gallery-force-height768 .grid-wrap, .gallery-force-height768 .preview {
min-height: calc(768px + 4px + var(--size-14)) !important;
max-height: calc(768px + 4px + var(--size-14)) !important;
}
.gallery-limit-height768 .thumbnail-item.thumbnail-lg {
max-height: 770px !important;
}
/* Don't upscale when viewing in solo image mode */
#gallery .preview img {
object-fit: scale-down;
}
/* Navbar images in cover mode*/
#gallery .preview .thumbnail-item img {
object-fit: cover;
}
/* Limit the stable diffusion text output height */
#std_output textarea {
max-height: 215px;
}
/* Prevent progress bar to block gallery navigation while building images (Gradio V3.19.0) */
#gallery .wrap.default {
pointer-events: none;
}
/* Import Png info box */
#txt2img_prompt_image {
height: var(--size-32) !important;
}
/* Hide "remove buttons" from ui dropdowns */
#custom_model .token-remove.remove-all,
#lora_weights .token-remove.remove-all,
#scheduler .token-remove.remove-all,
#device .token-remove.remove-all,
#stencil_model .token-remove.remove-all {
display: none;
}
/* Hide selected items from ui dropdowns */
#custom_model .options .item .inner-item,
#scheduler .options .item .inner-item,
#device .options .item .inner-item,
#stencil_model .options .item .inner-item {
display:none;
}
/* workarounds for container=false not currently working for dropdowns */
.dropdown_no_container {
padding: 0 !important;
}
#output_subdir_container :first-child {
border: none;
}
/* reduced animation load when generating */
.generating {
animation-play-state: paused !important;
}
/* better clarity when progress bars are minimal */
.meta-text {
background-color: var(--block-label-background-fill);
}
/* lora tag pills */
.lora-tags {
border: 1px solid var(--border-color-primary);
color: var(--block-info-text-color) !important;
padding: var(--block-padding);
}
.lora-tag {
display: inline-block;
height: 2em;
color: rgb(212 212 212) !important;
margin-right: 5pt;
margin-bottom: 5pt;
padding: 2pt 5pt;
border-radius: 5pt;
white-space: nowrap;
}
.lora-model {
margin-bottom: var(--spacing-lg);
color: var(--block-info-text-color) !important;
line-height: var(--line-sm);
}
/* output gallery tab */
.output_parameters_dataframe table.table {
/* works around a gradio bug that always shows scrollbars */
overflow: clip auto;
}
.output_parameters_dataframe tbody td {
font-size: small;
line-height: var(--line-xs);
}
.output_icon_button {
max-width: 30px;
align-self: end;
padding-bottom: 8px;
}
.outputgallery_sendto {
min-width: 7em !important;
}
/* output gallery should take up most of the viewport height regardless of image size/number */
#outputgallery_gallery .fixed-height {
min-height: 89vh !important;
}
.sd-right-panel {
height: calc(100vmin - var(--size-32) - var(--size-10)) !important;
overflow-y: scroll;
}
.sd-right-panel .fill {
flex: 1;
}
/* don't stretch non-square images to be square, breaking their aspect ratio */
#outputgallery_gallery .thumbnail-item.thumbnail-lg > img {
object-fit: contain !important;
}
/* centered logo for when there are no images */
#top_logo.logo_centered {
height: 100%;
width: 100%;
}
#top_logo.logo_centered img {
object-fit: scale-down;
position: absolute;
width: 80%;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
}
#tab_bar_logo {
overflow: visible !important;
border-width: 0 !important;
height: 0px !important;
padding: 0;
margin: 0;
}
#tab_bar_logo .image-container {
object-fit: scale-down;
position: absolute !important;
top: 10px;
right: 0px;
height: 36px;
}

View File

@@ -0,0 +1,49 @@
// workaround gradio after 4.7, not applying any @media rules form the custom .css file
() => {
console.log(`innerWidth: ${window.innerWidth}` )
// 1536px rules
const mediaQuery1536 = window.matchMedia('(min-width: 1536px)')
function handleWidth1536(event) {
// display in full width for desktop devices
document.querySelectorAll(".gradio-container")
.forEach( (node) => {
if (event.matches) {
node.classList.add("gradio-container-size-full");
} else {
node.classList.remove("gradio-container-size-full")
}
});
}
mediaQuery1536.addEventListener("change", handleWidth1536);
mediaQuery1536.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1536}));
// 1921px rules
const mediaQuery1921 = window.matchMedia('(min-width: 1921px)')
function handleWidth1921(event) {
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
/* Limit height to 768px_height + 2px_margin_height for the thumbnails */
document.querySelectorAll("#gallery")
.forEach( (node) => {
if (event.matches) {
node.classList.add("gallery-force-height768");
node.classList.add("gallery-limit-height768");
} else {
node.classList.remove("gallery-force-height768");
node.classList.remove("gallery-limit-height768");
}
});
}
mediaQuery1921.addEventListener("change", handleWidth1921);
mediaQuery1921.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1921}));
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.4 KiB

View File

@@ -0,0 +1,406 @@
import glob
import gradio as gr
import os
import subprocess
import sys
from PIL import Image
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)
from apps.shark_studio.web.ui.utils import amdlogo_loc
from apps.shark_studio.web.utils.metadata import displayable_metadata
# -- Functions for file, directory and image info querying
output_dir = get_generated_imgs_path()
def outputgallery_filenames(subdir) -> list[str]:
new_dir_path = os.path.join(output_dir, subdir)
if os.path.exists(new_dir_path):
filenames = [
glob.glob(new_dir_path + "/" + ext) for ext in ("*.png", "*.jpg", "*.jpeg")
]
return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True)
else:
return []
def output_subdirs() -> list[str]:
# Gets a list of subdirectories of output_dir and below, as relative paths.
relative_paths = [
os.path.relpath(entry[0], output_dir)
for entry in os.walk(
output_dir, followlinks=cmd_opts.output_gallery_followlinks
)
]
# It is less confusing to always including the subdir that will take any
# images generated today even if it doesn't exist yet
if get_generated_imgs_todays_subdir() not in relative_paths:
relative_paths.append(get_generated_imgs_todays_subdir())
# sort subdirectories so that the date named ones we probably
# created in this or previous sessions come first, sorted with the most
# recent first. Other subdirs are listed after.
generated_paths = sorted(
[path for path in relative_paths if path.isnumeric()], reverse=True
)
result_paths = generated_paths + sorted(
[path for path in relative_paths if (not path.isnumeric()) and path != "."]
)
return result_paths
# --- Define UI layout for Gradio
with gr.Blocks() as outputgallery_element:
amd_logo = Image.open(amdlogo_loc)
with gr.Row(elem_id="outputgallery_gallery"):
# needed to workaround gradio issue:
# https://github.com/gradio-app/gradio/issues/2907
dev_null = gr.Textbox("", visible=False)
gallery_files = gr.State(value=[])
subdirectory_paths = gr.State(value=[])
with gr.Column(scale=6):
logo = gr.Image(
label="Getting subdirectories...",
value=amd_logo,
interactive=False,
visible=True,
show_label=True,
elem_id="top_logo",
elem_classes="logo_centered",
show_download_button=False,
)
gallery = gr.Gallery(
label="",
value=gallery_files.value,
visible=False,
show_label=True,
columns=4,
)
with gr.Column(scale=4):
with gr.Group():
with gr.Row():
with gr.Column(
scale=15,
min_width=160,
elem_id="output_subdir_container",
):
subdirectories = gr.Dropdown(
label=f"Subdirectories of {output_dir}",
type="value",
choices=subdirectory_paths.value,
value="",
interactive=True,
elem_classes="dropdown_no_container",
allow_custom_value=True,
)
with gr.Column(
scale=1,
min_width=32,
elem_classes="output_icon_button",
):
open_subdir = gr.Button(
variant="secondary",
value="\U0001F5C1", # unicode open folder
interactive=False,
size="sm",
)
with gr.Column(
scale=1,
min_width=32,
elem_classes="output_icon_button",
):
refresh = gr.Button(
variant="secondary",
value="\u21BB", # unicode clockwise arrow circle
size="sm",
)
image_columns = gr.Slider(
label="Columns shown", value=4, minimum=1, maximum=16, step=1
)
outputgallery_filename = gr.Textbox(
label="Filename",
value="None",
interactive=False,
show_copy_button=True,
)
with gr.Accordion(
label="Parameter Information", open=False
) as parameters_accordian:
image_parameters = gr.DataFrame(
headers=["Parameter", "Value"],
col_count=2,
wrap=True,
elem_classes="output_parameters_dataframe",
value=[["Status", "No image selected"]],
interactive=True,
)
with gr.Accordion(label="Send To", open=True):
with gr.Row():
outputgallery_sendto_sd = gr.Button(
value="Stable Diffusion",
interactive=False,
elem_classes="outputgallery_sendto",
size="sm",
)
# --- Event handlers
def on_clear_gallery():
return [
gr.Gallery(
value=[],
visible=False,
),
gr.Image(
visible=True,
),
]
def on_image_columns_change(columns):
return gr.Gallery(columns=columns)
def on_select_subdir(subdir) -> list:
# evt.value is the subdirectory name
new_images = outputgallery_filenames(subdir)
new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)}"
return [
new_images,
gr.Gallery(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
]
def on_open_subdir(subdir):
subdir_path = os.path.normpath(os.path.join(output_dir, subdir))
if os.path.isdir(subdir_path):
if sys.platform == "linux":
subprocess.run(["xdg-open", subdir_path])
elif sys.platform == "darwin":
subprocess.run(["open", subdir_path])
elif sys.platform == "win32":
os.startfile(subdir_path)
def on_refresh(current_subdir: str) -> list:
# get an up-to-date subdirectory list
refreshed_subdirs = output_subdirs()
# get the images using either the current subdirectory or the most
# recent valid one
new_subdir = (
current_subdir
if current_subdir in refreshed_subdirs
else refreshed_subdirs[0]
)
new_images = outputgallery_filenames(new_subdir)
new_label = (
f"{len(new_images)} images in " f"{os.path.join(output_dir, new_subdir)}"
)
return [
gr.Dropdown(
choices=refreshed_subdirs,
value=new_subdir,
),
refreshed_subdirs,
new_images,
gr.Gallery(value=new_images, label=new_label, visible=len(new_images) > 0),
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
]
def on_new_image(subdir, subdir_paths, status) -> list:
# prevent error triggered when an image generates before the tab
# has even been selected
subdir_paths = (
subdir_paths
if len(subdir_paths) > 0
else [get_generated_imgs_todays_subdir()]
)
# only update if the current subdir is the most recent one as
# new images only go there
if subdir_paths[0] == subdir:
new_images = outputgallery_filenames(subdir)
new_label = (
f"{len(new_images)} images in "
f"{os.path.join(output_dir, subdir)} - {status}"
)
return [
new_images,
gr.Gallery(
value=new_images,
label=new_label,
visible=len(new_images) > 0,
),
gr.Image(
label=new_label,
visible=len(new_images) == 0,
),
]
else:
# otherwise change nothing,
# (only untyped gradio gr.update() does this)
return [gr.update(), gr.update(), gr.update()]
def on_select_image(images: list[str], evt: gr.SelectData) -> list:
# evt.index is an index into the full list of filenames for
# the current subdirectory
filename = images[evt.index]
params = displayable_metadata(filename)
if params:
if params["source"] == "missing":
return [
"Could not find this image file, refresh the gallery and update the images",
[["Status", "File missing"]],
]
else:
return [
filename,
list(map(list, params["parameters"].items())),
]
return [
filename,
[["Status", "No parameters found"]],
]
def on_outputgallery_filename_change(filename: str) -> list:
exists = filename != "None" and os.path.exists(filename)
return [
# disable or enable each of the sendto button based on whether
# an image is selected
gr.Button(interactive=exists),
]
# The time first our tab is selected we need to do an initial refresh
# to populate the subdirectory select box and the images from the most
# recent subdirectory.
#
# We do it at this point rather than setting this up in the controls'
# definitions as when you refresh the browser you always get what was
# *initially* set, which won't include any new subdirectories or images
# that might have created since the application was started. Doing it
# this way means a browser refresh/reload always gets the most
# up-to-date data.
def on_select_tab(subdir_paths, request: gr.Request):
local_client = request.headers["host"].startswith(
"127.0.0.1:"
) or request.headers["host"].startswith("localhost:")
if len(subdir_paths) == 0:
return on_refresh("") + [gr.update(interactive=local_client)]
else:
return (
# Change nothing, (only untyped gr.update() does this)
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
# clearing images when we need to completely change what's in the
# gallery avoids current images being shown replacing piecemeal and
# prevents weirdness and errors if the user selects an image during the
# replacement phase.
clear_gallery = dict(
fn=on_clear_gallery,
inputs=None,
outputs=[gallery, logo],
queue=False,
)
subdirectories.select(**clear_gallery).then(
on_select_subdir,
[subdirectories],
[gallery_files, gallery, logo],
queue=False,
)
open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False)
refresh.click(**clear_gallery).then(
on_refresh,
[subdirectories],
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
queue=False,
)
image_columns.change(
fn=on_image_columns_change,
inputs=[image_columns],
outputs=[gallery],
queue=False,
)
gallery.select(
on_select_image,
[gallery_files],
[outputgallery_filename, image_parameters],
queue=False,
)
outputgallery_filename.change(
on_outputgallery_filename_change,
[outputgallery_filename],
[
outputgallery_sendto_sd,
],
queue=False,
)
# We should have been given the .select function for our tab, so set it up
def outputgallery_tab_select(select):
select(
fn=on_select_tab,
inputs=[subdirectory_paths],
outputs=[
subdirectories,
subdirectory_paths,
gallery_files,
gallery,
logo,
open_subdir,
],
queue=False,
)
# We should have been passed a list of components on other tabs that update
# when a new image has generated on that tab, so set things up so the user
# will see that new image if they are looking at today's subdirectory
def outputgallery_watch(components: gr.Textbox):
for component in components:
component.change(
on_new_image,
inputs=[subdirectories, subdirectory_paths, component],
outputs=[gallery_files, gallery, logo],
queue=False,
)

View File

@@ -0,0 +1,777 @@
import os
import json
import gradio as gr
import numpy as np
from inspect import signature
from PIL import Image
from pathlib import Path
from datetime import datetime as dt
from gradio.components.image_editor import (
EditorValue,
)
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
get_checkpoints_path,
get_checkpoints,
get_configs_path,
write_default_sd_configs,
)
from apps.shark_studio.api.sd import (
shark_sd_fn_dict_input,
cancel_sd,
unload_sd,
)
from apps.shark_studio.api.controlnet import (
cnet_preview,
)
from apps.shark_studio.modules.schedulers import (
scheduler_model_map,
)
from apps.shark_studio.modules.img_processing import (
resampler_list,
resize_stencil,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.web.ui.utils import (
amdlogo_loc,
none_to_str_none,
str_none_to_none,
)
from apps.shark_studio.web.utils.state import (
status_label,
)
from apps.shark_studio.web.ui.common_events import lora_changed
from apps.shark_studio.modules import logger
import apps.shark_studio.web.utils.globals as global_obj
sd_default_models = [
"runwayml/stable-diffusion-v1-5",
"stabilityai/stable-diffusion-2-1-base",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-xl-base-1.0",
"stabilityai/sdxl-turbo",
]
def view_json_file(file_path):
content = ""
with open(file_path, "r") as fopen:
content = fopen.read()
return content
def submit_to_cnet_config(
stencil: str,
preprocessed_hint: str,
cnet_strength: int,
control_mode: str,
curr_config: dict,
):
if any(i in [None, ""] for i in [stencil, preprocessed_hint]):
return gr.update()
if curr_config is not None:
if "controlnets" in curr_config:
curr_config["controlnets"]["control_mode"] = control_mode
curr_config["controlnets"]["model"].append(stencil)
curr_config["controlnets"]["hint"].append(preprocessed_hint)
curr_config["controlnets"]["strength"].append(cnet_strength)
return curr_config
cnet_map = {}
cnet_map["controlnets"] = {
"control_mode": control_mode,
"model": [stencil],
"hint": [preprocessed_hint],
"strength": [cnet_strength],
}
return cnet_map
def update_embeddings_json(embedding):
return {"embeddings": [embedding]}
def submit_to_main_config(input_cfg: dict, main_cfg: dict):
if main_cfg in [None, "", {}]:
return input_cfg
for base_key in input_cfg:
main_cfg[base_key] = input_cfg[base_key]
return main_cfg
def pull_sd_configs(
prompt,
negative_prompt,
sd_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
target_triple,
ondemand,
compiled_pipeline,
resample_type,
controlnets,
embeddings,
):
sd_args = str_none_to_none(locals())
sd_cfg = {}
for arg in sd_args:
if arg in [
"prompt",
"negative_prompt",
"sd_init_image",
]:
sd_cfg[arg] = [sd_args[arg]]
elif arg in ["controlnets", "embeddings"]:
if isinstance(arg, dict):
sd_cfg[arg] = json.loads(sd_args[arg])
else:
sd_cfg[arg] = {}
else:
sd_cfg[arg] = sd_args[arg]
return json.dumps(sd_cfg)
def load_sd_cfg(sd_json: dict, load_sd_config: str):
new_sd_config = none_to_str_none(json.loads(view_json_file(load_sd_config)))
if sd_json:
for key in new_sd_config:
sd_json[key] = new_sd_config[key]
else:
sd_json = new_sd_config
for i in sd_json["sd_init_image"]:
if i is not None:
if os.path.isfile(i):
sd_image = [Image.open(i, mode="r")]
else:
sd_image = None
return [
sd_json["prompt"][0],
sd_json["negative_prompt"][0],
sd_image,
sd_json["height"],
sd_json["width"],
sd_json["steps"],
sd_json["strength"],
sd_json["guidance_scale"],
sd_json["seed"],
sd_json["batch_count"],
sd_json["batch_size"],
sd_json["scheduler"],
sd_json["base_model_id"],
sd_json["custom_weights"],
sd_json["custom_vae"],
sd_json["precision"],
sd_json["device"],
sd_json["target_triple"],
sd_json["ondemand"],
sd_json["compiled_pipeline"],
sd_json["resample_type"],
sd_json["controlnets"],
sd_json["embeddings"],
sd_json,
]
def save_sd_cfg(config: dict, save_name: str):
if os.path.exists(save_name):
filepath = save_name
elif cmd_opts.configs_path:
filepath = os.path.join(cmd_opts.configs_path, save_name)
else:
filepath = os.path.join(get_configs_path(), save_name)
if ".json" not in filepath:
filepath += ".json"
with open(filepath, mode="w") as f:
f.write(json.dumps(config))
return "..."
def create_canvas(width, height):
data = Image.fromarray(
np.zeros(
shape=(height, width, 3),
dtype=np.uint8,
)
+ 255
)
img_dict = {
"background": data,
"layers": [],
"composite": None,
}
return EditorValue(img_dict)
def import_original(original_img, width, height):
if original_img is None:
resized_img = create_canvas(width, height)
return resized_img
else:
resized_img, _, _ = resize_stencil(original_img, width, height)
img_dict = {
"background": resized_img,
"layers": [],
"composite": None,
}
return EditorValue(img_dict)
def base_model_changed(base_model_id):
new_choices = get_checkpoints(
os.path.join("checkpoints", os.path.basename(str(base_model_id)))
) + get_checkpoints(model_type="checkpoints")
return gr.Dropdown(
value=new_choices[0] if len(new_choices) > 0 else "None",
choices=["None"] + new_choices,
)
with gr.Blocks(title="Stable Diffusion") as sd_element:
with gr.Column(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=2, min_width=600):
with gr.Accordion(
label="\U0001F4D0\U0000FE0F Device Settings", open=False
):
device = gr.Dropdown(
elem_id="device",
label="Device",
value=global_obj.get_device_list()[0],
choices=global_obj.get_device_list(),
allow_custom_value=False,
)
target_triple = gr.Textbox(
elem_id="target_triple",
label="Architecture",
value="",
)
with gr.Row():
ondemand = gr.Checkbox(
value=cmd_opts.lowvram,
label="Low VRAM",
interactive=True,
)
precision = gr.Radio(
label="Precision",
value=cmd_opts.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
sd_model_info = f"Checkpoint Path: {str(get_checkpoints_path())}"
base_model_id = gr.Dropdown(
label="\U000026F0\U0000FE0F Base Model",
info="Select or enter HF model ID",
elem_id="custom_model",
value="stabilityai/stable-diffusion-2-1-base",
choices=sd_default_models,
allow_custom_value=True,
) # base_model_id
with gr.Row():
height = gr.Slider(
384,
1024,
value=cmd_opts.height,
step=8,
label="\U00002195\U0000FE0F Height",
)
width = gr.Slider(
384,
1024,
value=cmd_opts.width,
step=8,
label="\U00002194\U0000FE0F Width",
)
with gr.Accordion(
label="\U00002696\U0000FE0F Model Weights", open=False
):
with gr.Column():
custom_weights = gr.Dropdown(
label="Checkpoint Weights",
info="Select or enter HF model ID",
elem_id="custom_model",
value="None",
allow_custom_value=True,
choices=["None"]
+ get_checkpoints(os.path.basename(str(base_model_id))),
) # custom_weights
base_model_id.change(
fn=base_model_changed,
inputs=[base_model_id],
outputs=[custom_weights],
)
sd_vae_info = (str(get_checkpoints_path("vae"))).replace(
"\\", "\n\\"
)
sd_vae_info = f"VAE Path: {sd_vae_info}"
custom_vae = gr.Dropdown(
label=f"VAE Model",
info=sd_vae_info,
elem_id="custom_model",
value=(
os.path.basename(cmd_opts.custom_vae)
if cmd_opts.custom_vae
else "None"
),
choices=["None"] + get_checkpoints("vae"),
allow_custom_value=True,
scale=1,
)
sd_lora_info = (str(get_checkpoints_path("loras"))).replace(
"\\", "\n\\"
)
lora_opt = gr.Dropdown(
allow_custom_value=True,
label=f"Standalone LoRA Weights",
info=sd_lora_info,
elem_id="lora_weights",
value=None,
multiselect=True,
choices=[] + get_checkpoints("lora"),
scale=2,
)
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
embeddings_config = gr.JSON(
label="Embeddings Options", min_width=50, scale=1
)
gr.on(
triggers=[lora_opt.change],
fn=lora_changed,
inputs=[lora_opt],
outputs=[lora_tags],
queue=True,
show_progress=False,
).then(
fn=update_embeddings_json,
inputs=[lora_opt],
outputs=[embeddings_config],
show_progress=False,
)
with gr.Accordion(
label="\U0001F9EA\U0000FE0F Input Image Processing", open=False
):
strength = gr.Slider(
0,
1,
value=cmd_opts.strength,
step=0.01,
label="Denoising Strength",
)
resample_type = gr.Dropdown(
value=cmd_opts.resample_type,
choices=resampler_list,
label="Resample Type",
allow_custom_value=True,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="\U00002795\U0000FE0F Prompt",
value=cmd_opts.prompt[0],
lines=2,
elem_id="prompt_box",
show_copy_button=True,
)
negative_prompt = gr.Textbox(
label="\U00002796\U0000FE0F Negative Prompt",
value=cmd_opts.negative_prompt[0],
lines=2,
elem_id="negative_prompt_box",
show_copy_button=True,
)
with gr.Row(equal_height=True):
seed = gr.Textbox(
value=cmd_opts.seed,
label="\U0001F331\U0000FE0F Seed",
info="An integer or a JSON list of integers, -1 for random",
show_copy_button=True,
)
scheduler = gr.Dropdown(
elem_id="scheduler",
label="\U0001F4C5\U0000FE0F Scheduler",
info="\U000E0020", # forces same height as seed
value="EulerDiscrete",
choices=scheduler_model_map.keys(),
allow_custom_value=False,
)
with gr.Row():
steps = gr.Slider(
1,
100,
value=cmd_opts.steps,
step=1,
label="\U0001F3C3\U0000FE0F Steps",
)
guidance_scale = gr.Slider(
0,
50,
value=cmd_opts.guidance_scale,
step=0.1,
label="\U0001F5C3\U0000FE0F CFG Scale",
)
with gr.Accordion(
label="Controlnet Options",
open=False,
visible=False,
):
preprocessed_hints = gr.State([])
with gr.Column():
sd_cnet_info = (
str(get_checkpoints_path("controlnet"))
).replace("\\", "\n\\")
with gr.Row():
cnet_config = gr.JSON()
with gr.Column():
clear_config = gr.ClearButton(
value="Clear Controlnet Config",
size="sm",
components=cnet_config,
)
control_mode = gr.Radio(
choices=["Prompt", "Balanced", "Controlnet"],
value="Balanced",
label="Control Mode",
)
with gr.Row():
with gr.Column(scale=1):
cnet_model = gr.Dropdown(
allow_custom_value=True,
label=f"Controlnet Model",
info=sd_cnet_info,
value="None",
choices=[
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
]
+ get_checkpoints("controlnet"),
)
cnet_strength = gr.Slider(
label="Controlnet Strength",
minimum=0,
maximum=100,
value=50,
step=1,
)
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=8,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=8,
)
make_canvas = gr.Button(
value="Make Canvas!",
)
use_input_img = gr.Button(
value="Use Original Image",
size="sm",
)
cnet_input = gr.Image(
value=None,
type="pil",
image_mode="RGB",
interactive=True,
)
with gr.Column(scale=1):
cnet_output = gr.Image(
value=None,
visible=True,
label="Preprocessed Hint",
interactive=False,
show_label=True,
)
cnet_gen = gr.Button(
value="Preprocess controlnet input",
)
use_result = gr.Button(
"Submit",
size="sm",
)
make_canvas.click(
fn=create_canvas,
inputs=[canvas_width, canvas_height],
outputs=[cnet_input],
queue=False,
)
cnet_gen.click(
fn=cnet_preview,
inputs=[
cnet_model,
cnet_input,
],
outputs=[
cnet_output,
preprocessed_hints,
],
)
use_result.click(
fn=submit_to_cnet_config,
inputs=[
cnet_model,
cnet_output,
cnet_strength,
control_mode,
cnet_config,
],
outputs=[
cnet_config,
],
queue=False,
)
with gr.Column(scale=3, min_width=600):
with gr.Tabs() as sd_tabs:
sd_element.load(
# Workaround for Gradio issue #7085
# TODO: revert to setting selected= in gr.Tabs declaration
# once this is resolved in Gradio
lambda: gr.Tabs(selected=101),
outputs=[sd_tabs],
)
with gr.Tab(label="Input Image", id=100) as sd_tab_init_image:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
# TODO: make this import image prompt info if it exists
sd_init_image = gr.Image(
type="pil",
interactive=True,
show_label=False,
)
use_input_img.click(
fn=import_original,
inputs=[
sd_init_image,
canvas_width,
canvas_height,
],
outputs=[cnet_input],
queue=False,
)
with gr.Tab(label="Generate Images", id=101) as sd_tab_gallery:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
sd_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
columns=2,
object_fit="fit",
preview=True,
)
with gr.Row():
batch_count = gr.Slider(
1,
100,
value=cmd_opts.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=cmd_opts.batch_size,
step=1,
label="Batch Size",
interactive=True,
visible=True,
)
compiled_pipeline = gr.Checkbox(
False,
label="Faster txt2img (SDXL only)",
)
with gr.Row():
stable_diffusion = gr.Button("Start")
unload = gr.Button("Unload Models")
unload.click(
fn=unload_sd,
queue=False,
show_progress=False,
)
stop_batch = gr.Button("Stop")
with gr.Tab(label="Config", id=102) as sd_tab_config:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
Path(get_configs_path()).mkdir(
parents=True, exist_ok=True
)
default_config_file = os.path.join(
get_configs_path(),
"default_sd_config.json",
)
write_default_sd_configs(get_configs_path())
sd_json = gr.JSON(
elem_classes=["fill"],
value=view_json_file(default_config_file),
)
with gr.Row():
with gr.Column(scale=3):
load_sd_config = gr.FileExplorer(
label="Load Config",
file_count="single",
root_dir=(
cmd_opts.configs_path
if cmd_opts.configs_path
else get_configs_path()
),
height=75,
)
with gr.Column(scale=1):
save_sd_config = gr.Button(
value="Save Config", size="sm"
)
clear_sd_config = gr.ClearButton(
value="Clear Config",
size="sm",
components=sd_json,
)
with gr.Row():
sd_config_name = gr.Textbox(
value="Config Name",
info="Name of the file this config will be saved to.",
interactive=True,
show_label=False,
)
load_sd_config.change(
fn=load_sd_cfg,
inputs=[sd_json, load_sd_config],
outputs=[
prompt,
negative_prompt,
sd_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
target_triple,
ondemand,
compiled_pipeline,
resample_type,
cnet_config,
embeddings_config,
sd_json,
],
)
save_sd_config.click(
fn=save_sd_cfg,
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
save_sd_config.click(
fn=save_sd_cfg,
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
with gr.Tab(label="Log", id=103) as sd_tab_log:
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=2,
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
pull_kwargs = dict(
fn=pull_sd_configs,
inputs=[
prompt,
negative_prompt,
sd_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
target_triple,
ondemand,
compiled_pipeline,
resample_type,
cnet_config,
embeddings_config,
],
outputs=[
sd_json,
],
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Stable Diffusion", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=sd_status,
)
gen_kwargs = dict(
fn=shark_sd_fn_dict_input,
inputs=[sd_json],
outputs=[
sd_gallery,
sd_status,
],
)
prompt_submit = prompt.submit(**status_kwargs).then(**pull_kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(**pull_kwargs)
generate_click = (
stable_diffusion.click(**status_kwargs).then(**pull_kwargs).then(**gen_kwargs)
)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -0,0 +1,43 @@
from enum import IntEnum
import math
import sys
import os
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)
amdlogo_loc = resource_path("logos/amd-logo.jpg")
amdicon_loc = resource_path("logos/amd-icon.jpg")
class HSLHue(IntEnum):
RED = 0
YELLOW = 60
GREEN = 120
CYAN = 180
BLUE = 240
MAGENTA = 300
def hsl_color(alpha: float, start, end):
b = (end - start) * (alpha if alpha > 0 else 0)
result = b + start
# Return a CSS HSL string
return f"hsl({math.floor(result)}, 80%, 35%)"
def none_to_str_none(props: dict):
for key in props:
props[key] = "None" if props[key] == None else props[key]
return props
def str_none_to_none(props: dict):
for key in props:
props[key] = None if props[key] == "None" else props[key]
return props

View File

@@ -0,0 +1,12 @@
import os
import sys
def get_available_devices():
return ["cpu-task"]
def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)

View File

View File

@@ -0,0 +1,95 @@
default_sd_config = r"""{
"prompt": [
"a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 50,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-2-1-base",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": false,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""
sdxl_30steps = r"""{
"prompt": [
"a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 1024,
"width": 1024,
"steps": 30,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-xl-base-1.0",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": true,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""
sdxl_turbo = r"""{
"prompt": [
"A cat wearing a hat that says 'TURBO' on it. The cat is sitting on a skateboard."
],
"negative_prompt": [
""
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 2,
"strength": 0.8,
"guidance_scale": 0,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerAncestralDiscrete",
"base_model_id": "stabilityai/sdxl-turbo",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": true,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""
default_sd_configs = {
"default_sd_config.json": default_sd_config,
"sdxl-30steps.json": sdxl_30steps,
"sdxl-turbo.json": sdxl_turbo,
}

View File

@@ -0,0 +1,102 @@
import os
import sys
import glob
from datetime import datetime as dt
from pathlib import Path
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
checkpoints_filetypes = (
"*.ckpt",
"*.safetensors",
)
from apps.shark_studio.web.utils.default_configs import default_sd_configs
def write_default_sd_configs(path):
for key in default_sd_configs.keys():
config_fpath = os.path.join(path, key)
with open(config_fpath, "w") as f:
f.write(default_sd_configs[key])
def safe_name(name):
return name.split("/")[-1].replace("-", "_")
def get_path_stem(path):
path = Path(path)
return path.stem
def get_resource_path(path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
if os.path.isabs(path):
return path
else:
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
result = Path(os.path.join(base_path, path)).resolve(strict=False)
return result
def get_configs_path() -> Path:
configs = get_resource_path(cmd_opts.config_dir)
if not os.path.exists(configs):
os.mkdir(configs)
return Path(configs)
def get_generated_imgs_path() -> Path:
outputs = get_resource_path(cmd_opts.output_dir)
if not os.path.exists(outputs):
os.mkdir(outputs)
return Path(outputs)
def get_tmp_path() -> Path:
tmpdir = get_resource_path(cmd_opts.model_dir)
if not os.path.exists(tmpdir):
os.mkdir(tmpdir)
return Path(tmpdir)
def get_generated_imgs_todays_subdir() -> str:
return dt.now().strftime("%Y%m%d")
def create_model_folders():
dir = ["checkpoints", "vae", "lora", "vmfb"]
if not os.path.isdir(cmd_opts.model_dir):
try:
os.makedirs(cmd_opts.model_dir)
except OSError:
sys.exit(
f"Invalid --model_dir argument, "
f"{cmd_opts.model_dir} folder does not exist, and cannot be created."
)
for root in dir:
Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True)
def get_checkpoints_path(model_type=""):
return get_resource_path(os.path.join(cmd_opts.model_dir, model_type))
def get_checkpoints(model_type="checkpoints"):
ckpt_files = []
file_types = checkpoints_filetypes
if model_type == "lora":
file_types = file_types + ("*.pt", "*.bin")
for extn in file_types:
files = [
os.path.basename(x)
for x in glob.glob(os.path.join(get_checkpoints_path(model_type), extn))
]
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)
def get_checkpoint_pathfile(checkpoint_name, model_type="checkpoints"):
return os.path.join(get_checkpoints_path(model_type), checkpoint_name)

View File

@@ -0,0 +1,134 @@
import gc
from ...api.utils import get_available_devices
"""
The global objects include SD pipeline and config.
Maintaining the global objects would avoid creating extra pipeline objects when switching modes.
Also we could avoid memory leak when switching models by clearing the cache.
"""
def _init():
global _sd_obj
global _llm_obj
global _devices
global _pipe_kwargs
global _prep_kwargs
global _gen_kwargs
global _schedulers
_sd_obj = None
_llm_obj = None
_devices = None
_pipe_kwargs = None
_prep_kwargs = None
_gen_kwargs = None
_schedulers = None
set_devices()
def set_sd_obj(value):
global _sd_obj
global _llm_obj
_llm_obj = None
_sd_obj = value
def set_llm_obj(value):
global _sd_obj
global _llm_obj
_llm_obj = value
_sd_obj = None
def set_devices():
global _devices
_devices = get_available_devices()
def set_sd_scheduler(key):
global _sd_obj
_sd_obj.scheduler = _schedulers[key]
def set_sd_status(value):
global _sd_obj
_sd_obj.status = value
def set_pipe_kwargs(value):
global _pipe_kwargs
_pipe_kwargs = value
def set_prep_kwargs(value):
global _prep_kwargs
_prep_kwargs = value
def set_gen_kwargs(value):
global _gen_kwargs
_gen_kwargs = value
def set_schedulers(value):
global _schedulers
_schedulers = value
def get_sd_obj():
global _sd_obj
return _sd_obj
def get_llm_obj():
global _llm_obj
return _llm_obj
def get_device_list():
global _devices
return _devices
def get_sd_status():
global _sd_obj
return _sd_obj.status
def get_pipe_kwargs():
global _pipe_kwargs
return _pipe_kwargs
def get_prep_kwargs():
global _prep_kwargs
return _prep_kwargs
def get_gen_kwargs():
global _gen_kwargs
return _gen_kwargs
def get_scheduler(key):
global _schedulers
return _schedulers[key]
def clear_cache():
global _sd_obj
global _llm_obj
global _pipe_kwargs
global _prep_kwargs
global _gen_kwargs
global _schedulers
del _sd_obj
del _llm_obj
del _schedulers
gc.collect()
_sd_obj = None
_llm_obj = None
_pipe_kwargs = None
_prep_kwargs = None
_gen_kwargs = None
_schedulers = None

View File

@@ -0,0 +1,6 @@
from .png_metadata import (
import_png_metadata,
)
from .display import (
displayable_metadata,
)

View File

@@ -0,0 +1,43 @@
import csv
import os
from .format import humanize, humanizable
def csv_path(image_filename: str):
return os.path.join(os.path.dirname(image_filename), "imgs_details.csv")
def has_csv(image_filename: str) -> bool:
return os.path.exists(csv_path(image_filename))
def matching_filename(image_filename: str, row):
# we assume the final column of the csv has the original filename with full path and match that
# against the image_filename if we are given a list. Otherwise we assume a dict and and take
# the value of the OUTPUT key
return os.path.basename(image_filename) in (
row[-1] if isinstance(row, list) else row["OUTPUT"]
)
def parse_csv(image_filename: str):
csv_filename = csv_path(image_filename)
with open(csv_filename, "r", newline="") as csv_file:
# We use a reader or DictReader here for images_details.csv depending on whether we think it
# has headers or not. Having headers means less guessing of the format.
has_header = csv.Sniffer().has_header(csv_file.read(2048))
csv_file.seek(0)
reader = csv.DictReader(csv_file) if has_header else csv.reader(csv_file)
matches = [
# we rely on humanize and humanizable to work out the parsing of the individual .csv rows
humanize(row)
for row in reader
if row
and (has_header or humanizable(row))
and matching_filename(image_filename, row)
]
return matches[0] if matches else {}

View File

@@ -0,0 +1,53 @@
import json
import os
from PIL import Image
from .png_metadata import parse_generation_parameters
from .exif_metadata import has_exif, parse_exif
from .csv_metadata import has_csv, parse_csv
from .format import compact, humanize
def displayable_metadata(image_filename: str) -> dict:
if not os.path.isfile(image_filename):
return {"source": "missing", "parameters": {}}
pil_image = Image.open(image_filename)
# we have PNG generation parameters (preferred, as it's what the txt2img dropzone reads,
# and we go via that for SendTo, and is directly tied to the image)
if "parameters" in pil_image.info:
return {
"source": "png",
"parameters": compact(
parse_generation_parameters(pil_image.info["parameters"])
),
}
# we have a matching json file (next most likely to be accurate when it's there)
json_path = os.path.splitext(image_filename)[0] + ".json"
if os.path.isfile(json_path):
with open(json_path) as params_file:
return {
"source": "json",
"parameters": compact(
humanize(json.load(params_file), includes_filename=False)
),
}
# we have a CSV file so try that (can be different shapes, and it usually has no
# headers/param names so of the things we we *know* have parameters, it's the
# last resort)
if has_csv(image_filename):
params = parse_csv(image_filename)
if params: # we might not have found the filename in the csv
return {
"source": "csv",
"parameters": compact(params), # already humanized
}
# EXIF data, probably a .jpeg, may well not include parameters, but at least it's *something*
if has_exif(image_filename):
return {"source": "exif", "parameters": parse_exif(pil_image)}
# we've got nothing
return None

View File

@@ -0,0 +1,52 @@
from PIL import Image
from PIL.ExifTags import Base as EXIFKeys, TAGS, IFD, GPSTAGS
def has_exif(image_filename: str) -> bool:
return True if Image.open(image_filename).getexif() else False
def parse_exif(pil_image: Image) -> dict:
img_exif = pil_image.getexif()
# See this stackoverflow answer for where most this comes from: https://stackoverflow.com/a/75357594
# I did try to use the exif library but it broke just as much as my initial attempt at this (albeit I
# I was probably using it wrong) so I reverted back to using PIL with more filtering and saved a
# dependency
exif_tags = {
TAGS.get(key, key): str(val)
for (key, val) in img_exif.items()
if key in TAGS
and key not in (EXIFKeys.ExifOffset, EXIFKeys.GPSInfo)
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
def try_get_ifd(ifd_id):
try:
return img_exif.get_ifd(ifd_id).items()
except KeyError:
return {}
ifd_tags = {
TAGS.get(key, key): str(val)
for ifd_id in IFD
for (key, val) in try_get_ifd(ifd_id)
if ifd_id != IFD.GPSInfo
and key in TAGS
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
gps_tags = {
GPSTAGS.get(key, key): str(val)
for (key, val) in try_get_ifd(IFD.GPSInfo)
if key in GPSTAGS
and val
and (not isinstance(val, bytes))
and (not str(val).isspace())
}
return {**exif_tags, **ifd_tags, **gps_tags}

View File

@@ -0,0 +1,139 @@
# As SHARK has evolved more columns have been added to images_details.csv. However, since
# no version of the CSV has any headers (yet) we don't actually have anything within the
# file that tells us which parameter each column is for. So this is a list of known patterns
# indexed by length which is what we're going to have to use to guess which columns are the
# right ones for the file we're looking at.
# The same ordering is used for JSON, but these do have key names, however they are not very
# human friendly, nor do they match up with the what is written to the .png headers
# So these are functions to try and get something consistent out the raw input from all
# these sources
PARAMS_FORMATS = {
9: {
"VARIANT": "Model",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"OUTPUT": "Filename",
},
10: {
"MODEL": "Model",
"VARIANT": "Variant",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"OUTPUT": "Filename",
},
12: {
"VARIANT": "Model",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"HEIGHT": "Height",
"WIDTH": "Width",
"MAX_LENGTH": "Max Length",
"OUTPUT": "Filename",
},
}
PARAMS_FORMAT_CURRENT = {
"VARIANT": "Model",
"VAE": "VAE",
"LORA": "LoRA",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"HEIGHT": "Height",
"WIDTH": "Width",
"MAX_LENGTH": "Max Length",
"OUTPUT": "Filename",
}
def compact(metadata: dict) -> dict:
# we don't want to alter the original dictionary
result = dict(metadata)
# discard the filename because we should already have it
if result.keys() & {"Filename"}:
result.pop("Filename")
# make showing the sizes more compact by using only one line each
if result.keys() & {"Size-1", "Size-2"}:
result["Size"] = f"{result.pop('Size-1')}x{result.pop('Size-2')}"
elif result.keys() & {"Height", "Width"}:
result["Size"] = f"{result.pop('Height')}x{result.pop('Width')}"
if result.keys() & {"Hires resize-1", "Hires resize-1"}:
hires_y = result.pop("Hires resize-1")
hires_x = result.pop("Hires resize-2")
if hires_x == 0 and hires_y == 0:
result["Hires resize"] = "None"
else:
result["Hires resize"] = f"{hires_y}x{hires_x}"
# remove VAE if it exists and is empty
if (result.keys() & {"VAE"}) and (not result["VAE"] or result["VAE"] == "None"):
result.pop("VAE")
# remove LoRA if it exists and is empty
if (result.keys() & {"LoRA"}) and (not result["LoRA"] or result["LoRA"] == "None"):
result.pop("LoRA")
return result
def humanizable(metadata: dict | list[str], includes_filename=True) -> dict:
lookup_key = len(metadata) + (0 if includes_filename else 1)
return lookup_key in PARAMS_FORMATS.keys()
def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
lookup_key = len(metadata) + (0 if includes_filename else 1)
# For lists we can only work based on the length, we have no other information
if isinstance(metadata, list):
if humanizable(metadata, includes_filename):
return dict(zip(PARAMS_FORMATS[lookup_key].values(), metadata))
else:
raise KeyError(
f"Humanize could not find the format for a parameter list of length {len(metadata)}"
)
# For dictionaries we try to use the matching length parameter format if
# available, otherwise we just use the current format which is assumed to
# have everything currently known about. Then we swap keys in the metadata
# that match keys in the format for the friendlier name that we have set
# in the format value
if isinstance(metadata, dict):
if humanizable(metadata, includes_filename):
format = PARAMS_FORMATS[lookup_key]
else:
format = PARAMS_FORMAT_CURRENT
return {
format[key]: metadata[key]
for key in format.keys()
if key in metadata.keys() and metadata[key]
}
raise TypeError("Can only humanize parameter lists or dictionaries")

View File

@@ -0,0 +1,216 @@
import re
from pathlib import Path
from apps.shark_studio.web.utils.file_utils import (
get_checkpoint_pathfile,
)
from apps.shark_studio.api.sd import EMPTY_SD_MAP as sd_model_map
from apps.shark_studio.modules.schedulers import (
scheduler_model_map,
)
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
def parse_generation_parameters(x: str):
res = {}
prompt = ""
negative_prompt = ""
done_with_prompt = False
*lines, lastline = x.strip().split("\n")
if len(re_param.findall(lastline)) < 3:
lines.append(lastline)
lastline = ""
for i, line in enumerate(lines):
line = line.strip()
if line.startswith("Negative prompt:"):
done_with_prompt = True
line = line[16:].strip()
if done_with_prompt:
negative_prompt += ("" if negative_prompt == "" else "\n") + line
else:
prompt += ("" if prompt == "" else "\n") + line
res["Prompt"] = prompt
res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline):
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
m = re_imagesize.match(v)
if m is not None:
res[k + "-1"] = m.group(1)
res[k + "-2"] = m.group(2)
else:
res[k] = v
# Missing CLIP skip means it was set to 1 (the default)
if "Clip skip" not in res:
res["Clip skip"] = "1"
hypernet = res.get("Hypernet", None)
if hypernet is not None:
res[
"Prompt"
] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
if "Hires resize-1" not in res:
res["Hires resize-1"] = 0
res["Hires resize-2"] = 0
return res
def try_find_model_base_from_png_metadata(file: str, folder: str = "models") -> str:
custom = ""
# Remove extension from file info
if file.endswith(".safetensors") or file.endswith(".ckpt"):
file = Path(file).stem
# Check for the file name match with one of the local ckpt or safetensors files
if Path(get_checkpoint_pathfile(file + ".ckpt", folder)).is_file():
custom = file + ".ckpt"
if Path(get_checkpoint_pathfile(file + ".safetensors", folder)).is_file():
custom = file + ".safetensors"
return custom
def find_model_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
png_hf_id = ""
png_custom = ""
if key in metadata:
model_file = metadata[key]
png_custom = try_find_model_base_from_png_metadata(model_file)
# Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0")
if model_file in sd_model_map:
png_custom = model_file
# If nothing had matched, check vendor/hf_model_id
if not png_custom and model_file.count("/"):
png_hf_id = model_file
# No matching model was found
if not png_custom and not png_hf_id:
print(
"Import PNG info: Unable to find a matching model for %s" % model_file
)
return png_custom, png_hf_id
def find_vae_from_png_metadata(key: str, metadata: dict[str, str | int]) -> str:
vae_custom = ""
if key in metadata:
vae_file = metadata[key]
vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae")
# VAE input is optional, should not print or throw an error if missing
return vae_custom
def find_lora_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
lora_hf_id = ""
lora_custom = ""
if key in metadata:
lora_file = metadata[key]
lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
# If nothing had matched, check vendor/hf_model_id
if not lora_custom and lora_file.count("/"):
lora_hf_id = lora_file
# LoRA input is optional, should not print or throw an error if missing
return lora_custom, lora_hf_id
def import_png_metadata(
pil_data,
prompt,
negative_prompt,
steps,
sampler,
cfg_scale,
seed,
width,
height,
custom_model,
custom_lora,
hf_lora_id,
custom_vae,
):
try:
png_info = pil_data.info["parameters"]
metadata = parse_generation_parameters(png_info)
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
"Model", metadata
)
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
"LoRA", metadata
)
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
negative_prompt = metadata["Negative prompt"]
steps = int(metadata["Steps"])
cfg_scale = float(metadata["CFG scale"])
seed = int(metadata["Seed"])
width = float(metadata["Size-1"])
height = float(metadata["Size-2"])
if "Model" in metadata and png_custom_model:
custom_model = png_custom_model
elif "Model" in metadata and png_hf_model_id:
custom_model = png_hf_model_id
if "LoRA" in metadata and lora_custom_model:
custom_lora = lora_custom_model
hf_lora_id = ""
if "LoRA" in metadata and lora_hf_model_id:
custom_lora = "None"
hf_lora_id = lora_hf_model_id
if "VAE" in metadata and vae_custom_model:
custom_vae = vae_custom_model
if "Prompt" in metadata:
prompt = metadata["Prompt"]
if "Sampler" in metadata:
if metadata["Sampler"] in scheduler_model_map:
sampler = metadata["Sampler"]
else:
print(
"Import PNG info: Unable to find a scheduler for %s"
% metadata["Sampler"]
)
except Exception as ex:
if pil_data and pil_data.info.get("parameters"):
print("import_png_metadata failed with %s" % ex)
pass
return (
None,
prompt,
negative_prompt,
steps,
sampler,
cfg_scale,
seed,
width,
height,
custom_model,
custom_lora,
hf_lora_id,
custom_vae,
)

View File

@@ -0,0 +1,39 @@
import apps.shark_studio.web.utils.globals as global_obj
import gc
def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1):
if batch_index < batch_count:
bs = f"x{batch_size}" if batch_size > 1 else ""
return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}"
else:
return f"{tab_name} complete"
def get_generation_text_info(seeds, device):
cfg_dump = {}
for cfg in global_obj.get_config_dict():
cfg_dump[cfg] = cfg
text_output = f"prompt={cfg_dump['prompts']}"
text_output += f"\nnegative prompt={cfg_dump['negative_prompts']}"
text_output += (
f"\nmodel_id={cfg_dump['hf_model_id']}, " f"ckpt_loc={cfg_dump['ckpt_loc']}"
)
text_output += f"\nscheduler={cfg_dump['scheduler']}, " f"device={device}"
text_output += (
f"\nsteps={cfg_dump['steps']}, "
f"guidance_scale={cfg_dump['guidance_scale']}, "
f"seed={seeds}"
)
text_output += (
f"\nsize={cfg_dump['height']}x{cfg_dump['width']}, "
if not cfg_dump.use_hiresfix
else f"\nsize={cfg_dump['hiresfix_height']}x{cfg_dump['hiresfix_width']}, "
)
text_output += (
f"batch_count={cfg_dump['batch_count']}, "
f"batch_size={cfg_dump['batch_size']}, "
f"max_length={cfg_dump['max_length']}"
)
return text_output

View File

@@ -0,0 +1,75 @@
import os
import shutil
from time import time
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
shark_tmp = cmd_opts.tmp_dir # os.path.join(os.getcwd(), "shark_tmp/")
def clear_tmp_mlir():
cleanup_start = time()
print("Clearing .mlir temporary files from a prior run. This may take some time...")
mlir_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.endswith(".mlir")
]
for filename in mlir_files:
os.remove(os.path.join(shark_tmp, filename))
print(f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds.")
def clear_tmp_imgs():
# tell gradio to use a directory under shark_tmp for its temporary
# image files unless somewhere else has been set
if "GRADIO_TEMP_DIR" not in os.environ:
os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
print(
f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
+ "You may change this by setting the GRADIO_TEMP_DIR environment variable."
)
# Clear all gradio tmp images from the last session
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
cleanup_start = time()
print(
"Clearing gradio UI temporary image files from a prior run. This may take some time..."
)
shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
print(
f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
)
# older SHARK versions had to workaround gradio bugs and stored things differently
else:
image_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.startswith("tmp")
and filename.endswith(".png")
]
if len(image_files) > 0:
print(
"Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
)
cleanup_start = time()
for filename in image_files:
os.remove(shark_tmp + filename)
print(
f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
)
else:
print("No temporary images files to clear.")
def config_tmp():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
clear_tmp_mlir()
clear_tmp_imgs()

View File

@@ -10,7 +10,7 @@ from utils import get_datasets
shark_root = Path(__file__).parent.parent
demo_css = shark_root.joinpath("web/demo.css").resolve()
nodlogo_loc = shark_root.joinpath("web/models/stable_diffusion/logos/nod-logo.png")
nodlogo_loc = shark_root.joinpath("web/models/stable_diffusion/logos/amd-logo.jpg")
with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:

View File

@@ -1,3 +1,3 @@
# SHARK Annotator
gradio==3.34.0
gradio==4.19.2
jsonlines

View File

@@ -5,6 +5,7 @@
from distutils.sysconfig import get_python_lib
import fileinput
from pathlib import Path
import os
# Temporary workaround for transformers/__init__.py.
path_to_transformers_hook = Path(
@@ -16,51 +17,16 @@ else:
with open(path_to_transformers_hook, "w") as f:
f.write("module_collection_mode = 'pyz+py'")
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")
paths_to_skipfiles = [Path(get_python_lib() + "/torch/_dynamo/skipfiles.py"), Path(get_python_lib() + "/torch/_dynamo/trace_rules.py")]
modules_to_comment = ["abc,", "os,", "posixpath,", "_collections_abc,"]
startMonitoring = 0
for line in fileinput.input(path_to_skipfiles, inplace=True):
if "SKIP_DIRS = " in line:
startMonitoring = 1
print(line, end="")
elif startMonitoring in [1, 2]:
if "]" in line:
startMonitoring += 1
for path in paths_to_skipfiles:
if not os.path.isfile(path):
continue
for line in fileinput.input(path, inplace=True):
if "[_module_dir(m) for m in BUILTIN_SKIPLIST]" in line and "x.__name__ for x in BUILTIN_SKIPLIST" not in line:
print(f"{line.rstrip()} + [x.__name__ for x in BUILTIN_SKIPLIST]")
elif "(_module_dir(m) for m in BUILTIN_SKIPLIST)" in line and "x.__name__ for x in BUILTIN_SKIPLIST" not in line:
print(line, end="")
print(f"SKIP_DIRS.extend(filter(None, (x.__name__ for x in BUILTIN_SKIPLIST)))")
else:
flag = True
for module in modules_to_comment:
if module in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
flag = False
break
if flag:
print(line, end="")
else:
print(line, end="")
# For getting around scikit-image's packaging, laze_loader has had a patch merged but yet to be released.
# Refer: https://github.com/scientific-python/lazy_loader
path_to_lazy_loader = Path(get_python_lib() + "/lazy_loader/__init__.py")
for line in fileinput.input(path_to_lazy_loader, inplace=True):
if 'stubfile = filename if filename.endswith("i")' in line:
print(
' stubfile = (filename if filename.endswith("i") else f"{os.path.splitext(filename)[0]}.pyi")',
end="",
)
else:
print(line, end="")
# For getting around timm's packaging.
# Refer: https://github.com/pyinstaller/pyinstaller/issues/5673#issuecomment-808731505
path_to_timm_activations = Path(get_python_lib() + "/timm/layers/activations_jit.py")
for line in fileinput.input(path_to_timm_activations, inplace=True):
if "@torch.jit.script" in line:
print("@torch.jit._script_if_tracing", end="\n")
else:
print(line, end="")
print(line, end="")

View File

@@ -1,34 +0,0 @@
-f https://download.pytorch.org/whl/nightly/cpu/
--pre
numpy
torch
torchvision
tqdm
#iree-compiler | iree-runtime should already be installed
transformers
#jax[cpu]
# tflitehub dependencies.
Pillow
# web dependecies.
gradio
altair
# Testing and support.
#lit
#pyyaml
#ONNX and ORT for benchmarking
#--extra-index-url https://test.pypi.org/simple/
#protobuf
#coloredlogs
#flatbuffers
#sympy
#psutil
#onnx-weekly
#ort-nightly

View File

@@ -1,41 +0,0 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
numpy>1.22.4
pytorch-triton
torchvision
tabulate
tqdm
#iree-compiler | iree-runtime should already be installed
iree-tools-xla
# Modelling and JAX.
gin-config
transformers
diffusers
#jax[cpu]
Pillow
# Testing and support.
lit
pyyaml
python-dateutil
sacremoses
sentencepiece
# web dependecies.
gradio==3.44.3
altair
scipy
#ONNX and ORT for benchmarking
#--extra-index-url https://test.pypi.org/simple/
#protobuf
#coloredlogs
#flatbuffers
#sympy
#psutil
#onnx-weekly
#ort-nightly

View File

@@ -1,12 +1,16 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
-f https://openxla.github.io/iree/pip-release-links.html
-f https://download.pytorch.org/whl/nightly/cpu
-f https://iree.dev/pip-release-links.html
--pre
setuptools
wheel
shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine#egg=turbine-models&subdirectory=python/turbine_models
torch==2.3.0
shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@deprecated-constraints#subdirectory=models
diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark
brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
# SHARK Runner
tqdm
@@ -16,8 +20,6 @@ google-cloud-storage
# Testing
pytest
pytest-xdist
pytest-forked
Pillow
parameterized
@@ -25,30 +27,19 @@ parameterized
#accelerate is now required for diffusers import from ckpt.
accelerate
scipy
transformers==4.37.1
torchsde # Required for Stable Diffusion SDE schedulers.
ftfy
gradio==4.8.0
gradio==4.29.0
altair
omegaconf
# 0.3.2 doesn't have binaries for arm64
safetensors==0.3.1
opencv-python
scikit-image
pytorch_lightning # for runwayml models
tk
pywebview
sentencepiece
py-cpuinfo
tiktoken # for codegen
joblib # for langchain
timm # for MiniGPT4
langchain
einops # for zoedepth
pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
mpmath==1.3.0
optimum
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
pyinstaller
# For quantized GPTQ models
optimum
auto_gptq

View File

@@ -7,7 +7,7 @@ import glob
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.5"
PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "2.0.0"
backend_deps = []
setup(

View File

@@ -7,13 +7,13 @@
It checks the Python version installed and installs any required build
dependencies into a Python virtual environment.
If that environment does not exist, it creates it.
.PARAMETER update-src
git pulls latest version
.PARAMETER force
removes and recreates venv to force update of all dependencies
.EXAMPLE
.\setup_venv.ps1 --force
@@ -39,7 +39,7 @@ if ($arguments -eq "--force"){
Write-Host "deactivating..."
Deactivate
}
if (Test-Path .\shark.venv\) {
Write-Host "removing and recreating venv..."
Remove-Item .\shark.venv -Force -Recurse
@@ -88,10 +88,8 @@ else {python -m venv .\shark.venv\}
.\shark.venv\Scripts\activate
python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir torchvision torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
Write-Host "Build and installation completed successfully"
pip install --pre -r requirements.txt
pip install --force-reinstall https://github.com/nod-ai/SRT/releases/download/candidate-20240528.279/iree_compiler-20240528.279-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240528.279/iree_runtime-20240528.279-cp311-cp311-win_amd64.whl
pip install -e .
Write-Host "Source your venv with ./shark.venv/Scripts/activate"

View File

@@ -49,58 +49,20 @@ Red=`tput setaf 1`
Green=`tput setaf 2`
Yellow=`tput setaf 3`
# Assume no binary torch-mlir.
# Currently available for macOS m1&intel (3.11) and Linux(3.8,3.10,3.11)
torch_mlir_bin=false
if [[ $(uname -s) = 'Darwin' ]]; then
echo "${Yellow}Apple macOS detected"
if [[ $(uname -m) == 'arm64' ]]; then
echo "${Yellow}Apple M1 Detected"
hash rustc 2>/dev/null
if [ $? -eq 0 ];then
echo "${Green}rustc found to compile HF tokenizers"
else
echo "${Red}Could not find rustc" >&2
echo "${Red}Please run:"
echo "${Red}curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh"
exit 1
fi
fi
echo "${Yellow}Run the following commands to setup your SSL certs for your Python version if you see SSL errors with tests"
echo "${Yellow}/Applications/Python\ 3.XX/Install\ Certificates.command"
if [ "$PYTHON_VERSION_X_Y" == "3.11" ]; then
torch_mlir_bin=true
fi
elif [[ $(uname -s) = 'Linux' ]]; then
echo "${Yellow}Linux detected"
if [ "$PYTHON_VERSION_X_Y" == "3.8" ] || [ "$PYTHON_VERSION_X_Y" == "3.10" ] || [ "$PYTHON_VERSION_X_Y" == "3.11" ] ; then
torch_mlir_bin=true
fi
else
echo "${Red}OS not detected. Pray and Play"
fi
# Upgrade pip and install requirements.
$PYTHON -m pip install --upgrade pip || die "Could not upgrade pip"
$PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
if [ "$torch_mlir_bin" = true ]; then
if [[ $(uname -s) = 'Darwin' ]]; then
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
$PYTHON -m pip uninstall -y timm #TEMP FIX FOR MAC
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
else
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
if [ $? -eq 0 ];then
echo "Successfully Installed torch-mlir"
else
echo "Could not install torch-mlir" >&2
fi
fi
if [[ $(uname -s) = 'Darwin' ]]; then
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
$PYTHON -m pip uninstall -y timm #TEMP FIX FOR MAC
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
else
echo "${Red}No binaries found for Python $PYTHON_VERSION_X_Y on $(uname -s)"
echo "${Yello}Python 3.11 supported on macOS and 3.8,3.10 and 3.11 on Linux"
echo "${Red}Please build torch-mlir from source in your environment"
exit 1
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
if [ $? -eq 0 ];then
echo "Successfully Installed torch-mlir"
else
echo "Could not install torch-mlir" >&2
fi
fi
if [[ -z "${USE_IREE}" ]]; then
rm .use-iree
@@ -116,40 +78,13 @@ else
echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
fi
if [[ ! -z "${IMPORTER}" ]]; then
echo "${Yellow}Installing importer tools.."
if [[ $(uname -s) = 'Linux' ]]; then
echo "${Yellow}Linux detected.. installing Linux importer tools"
#Always get the importer tools from upstream IREE
$PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer.txt" -f https://openxla.github.io/iree/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
elif [[ $(uname -s) = 'Darwin' ]]; then
echo "${Yellow}macOS detected.. installing macOS importer tools"
#Conda seems to have some problems installing these packages and hope they get resolved upstream.
$PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer-macos.txt" -f ${RUNTIME} --extra-index-url https://download.pytorch.org/whl/nightly/cpu
fi
fi
if [[ $(uname -s) = 'Darwin' ]]; then
PYTORCH_URL=https://download.pytorch.org/whl/nightly/torch/
else
PYTORCH_URL=https://download.pytorch.org/whl/nightly/cpu/
fi
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f ${PYTORCH_URL}
if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then
T_VER=$($PYTHON -m pip show torch | grep Version)
T_VER_MIN=${T_VER:14:12}
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
TV_VER_MAJ=${TV_VER:9:6}
$PYTHON -m pip uninstall -y torchvision
$PYTHON -m pip install torchvision==${TV_VER_MAJ}${T_VER_MIN} --no-deps -f https://download.pytorch.org/whl/nightly/cpu/torchvision/
if [ $? -eq 0 ];then
echo "Successfully Installed torch + cu118."
else
echo "Could not install torch + cu118." >&2
fi
fi
$PYTHON -m pip install --no-warn-conflicts -e . -f ${RUNTIME} -f ${PYTORCH_URL}
if [[ -z "${NO_BREVITAS}" ]]; then
$PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@dev

View File

@@ -76,6 +76,7 @@ _IREE_DEVICE_MAP = {
"vulkan": "vulkan",
"metal": "metal",
"rocm": "rocm",
"hip": "hip",
"intel-gpu": "level_zero",
}
@@ -91,9 +92,10 @@ _IREE_TARGET_MAP = {
"cpu-task": "llvm-cpu",
"cpu-sync": "llvm-cpu",
"cuda": "cuda",
"vulkan": "vulkan",
"vulkan": "vulkan-spirv",
"metal": "metal",
"rocm": "rocm",
"hip": "rocm",
"intel-gpu": "opencl-spirv",
}
@@ -122,9 +124,7 @@ def check_device_drivers(device):
)
return True
except RuntimeError as re:
print(
f"[ERR] Failed to get driver for {device} with error:\n{repr(re)}"
)
print(f"[ERR] Failed to get driver for {device} with error:\n{repr(re)}")
return True
# Unknown device. We assume drivers are installed.

View File

@@ -62,8 +62,19 @@ def get_iree_device_args(device, extra_args=[]):
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
if device == "hip":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args, hip_driver=True)
return []
def get_iree_target_triple(device):
args = get_iree_device_args(device)
for flag in args:
if "triple" in flag:
triple = flag.split("=")[-1]
return triple
return ""
def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by Studio pipeline
@@ -81,9 +92,9 @@ def clean_device_info(raw_device):
if len(device_id) <= 2:
device_id = int(device_id)
if device not in ["rocm", "vulkan"]:
if device not in ["hip", "rocm", "vulkan"]:
device_id = None
if device in ["rocm", "vulkan"] and device_id == None:
if device in ["hip", "rocm", "vulkan"] and device_id == None:
device_id = 0
return device, device_id
@@ -105,9 +116,8 @@ def get_iree_frontend_args(frontend):
# Common args to be used given any frontend or device.
def get_iree_common_args(debug=False):
common_args = [
"--iree-stream-resource-max-allocation-size=4294967295",
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-util-zero-fill-elided-attrs",
"--mlir-elide-elementsattrs-if-larger=10",
]
if debug == True:
common_args.extend(

View File

@@ -52,7 +52,7 @@ def check_rocm_device_arch_in_args(extra_args):
return None
def get_rocm_device_arch(device_num=0, extra_args=[]):
def get_rocm_device_arch(device_num=0, extra_args=[], hip_driver=False):
# ROCM Device Arch selection:
# 1 : User given device arch using `--iree-rocm-target-chip` flag
# 2 : Device arch from `iree-run-module --dump_devices=rocm` for device on index <device_num>
@@ -68,15 +68,23 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
arch_in_device_dump = None
# get rocm arch from iree dump devices
def get_devices_info_from_dump(dump):
def get_devices_info_from_dump(dump, driver):
from os import linesep
dump_clean = list(
filter(
lambda s: "--device=rocm" in s or "gpu-arch-name:" in s,
dump.split(linesep),
if driver == "hip":
dump_clean = list(
filter(
lambda s: "AMD" in s,
dump.split(linesep),
)
)
else:
dump_clean = list(
filter(
lambda s: f"--device={driver}" in s or "gpu-arch-name:" in s,
dump.split(linesep),
)
)
)
arch_pairs = [
(
dump_clean[i].split("=")[1].strip(),
@@ -87,16 +95,17 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
return arch_pairs
dump_device_info = None
driver = "hip" if hip_driver else "rocm"
try:
dump_device_info = run_cmd(
"iree-run-module --dump_devices=rocm", raise_err=True
"iree-run-module --dump_devices=" + driver, raise_err=True
)
except Exception as e:
print("could not execute `iree-run-module --dump_devices=rocm`")
print("could not execute `iree-run-module --dump_devices=" + driver + "`")
if dump_device_info is not None:
device_num = 0 if device_num is None else device_num
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0])
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0], driver)
if len(device_arch_pairs) > device_num: # can find arch in the list
arch_in_device_dump = device_arch_pairs[device_num][1]
@@ -107,24 +116,22 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
default_rocm_arch = "gfx1100"
print(
"Did not find ROCm architecture from `--iree-rocm-target-chip` flag"
"\n or from `iree-run-module --dump_devices=rocm` command."
"\n or from `iree-run-module --dump_devices` command."
f"\nUsing {default_rocm_arch} as ROCm arch for compilation."
)
return default_rocm_arch
# Get the default gpu args given the architecture.
def get_iree_rocm_args(device_num=0, extra_args=[]):
def get_iree_rocm_args(device_num=0, extra_args=[], hip_driver=False):
ireert.flags.FUNCTION_INPUT_VALIDATION = False
rocm_flags = ["--iree-rocm-link-bc=true"]
rocm_flags = []
if check_rocm_device_arch_in_args(extra_args) is None:
rocm_arch = get_rocm_device_arch(device_num, extra_args)
rocm_arch = get_rocm_device_arch(device_num, extra_args, hip_driver=hip_driver)
rocm_flags.append(f"--iree-rocm-target-chip={rocm_arch}")
return rocm_flags
# Some constants taken from cuda.h
CUDA_SUCCESS = 0
CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16

View File

@@ -33,7 +33,7 @@ def get_vulkan_target_env(vulkan_target_triple):
device_type = get_device_type(triple)
# get capabilities
capabilities = get_vulkan_target_capabilities(triple)
target_env = f"#vk.target_env<{version}, r({revision}), {extensions}, {vendor}:{device_type}, #vk.caps< {capabilities} >>"
target_env = f"<#spirv.vce<{version}, r({revision}), {extensions}>, {vendor}:{device_type}, #spirv.resource_limits< {capabilities} >>"
return target_env
@@ -63,62 +63,62 @@ def get_extensions(triple):
arch, product, os = triple
if arch == "m1":
ext = [
"VK_KHR_16bit_storage",
"VK_KHR_8bit_storage",
"VK_KHR_shader_float16_int8",
"VK_KHR_storage_buffer_storage_class",
"VK_KHR_variable_pointers",
"SPV_KHR_16bit_storage",
"SPV_KHR_8bit_storage",
"SPV_KHR_shader_float16_int8",
"SPV_KHR_storage_buffer_storage_class",
"SPV_KHR_variable_pointers",
]
return make_ext_list(ext_list=ext)
if arch == "valhall":
ext = [
"VK_KHR_16bit_storage",
"VK_KHR_8bit_storage",
"VK_KHR_shader_float16_int8",
"VK_KHR_spirv_1_4",
"VK_KHR_storage_buffer_storage_class",
"VK_KHR_variable_pointers",
"SPV_KHR_16bit_storage",
"SPV_KHR_8bit_storage",
"SPV_KHR_shader_float16_int8",
"SPV_KHR_spirv_1_4",
"SPV_KHR_storage_buffer_storage_class",
"SPV_KHR_variable_pointers",
]
return make_ext_list(ext_list=ext)
if arch == "adreno":
ext = [
"VK_KHR_16bit_storage",
"VK_KHR_shader_float16_int8",
"VK_KHR_spirv_1_4",
"VK_KHR_storage_buffer_storage_class",
"VK_KHR_variable_pointers",
"SPV_KHR_16bit_storage",
"SPV_KHR_shader_float16_int8",
"SPV_KHR_spirv_1_4",
"SPV_KHR_storage_buffer_storage_class",
"SPV_KHR_variable_pointers",
]
if os == "android31":
ext.append("VK_KHR_8bit_storage")
ext.append("SPV_KHR_8bit_storage")
return make_ext_list(ext_list=ext)
if get_vendor(triple) == "SwiftShader":
ext = ["VK_KHR_storage_buffer_storage_class"]
ext = ["SPV_KHR_storage_buffer_storage_class"]
return make_ext_list(ext_list=ext)
if arch == "unknown":
ext = [
"VK_KHR_storage_buffer_storage_class",
"VK_KHR_variable_pointers",
"SPV_KHR_storage_buffer_storage_class",
"SPV_KHR_variable_pointers",
]
return make_ext_list(ext_list=ext)
ext = [
"VK_KHR_16bit_storage",
"VK_KHR_8bit_storage",
"VK_KHR_shader_float16_int8",
"VK_KHR_spirv_1_4",
"VK_KHR_storage_buffer_storage_class",
"VK_KHR_variable_pointers",
"SPV_KHR_16bit_storage",
"SPV_KHR_8bit_storage",
"SPV_KHR_shader_float16_int8",
"SPV_KHR_spirv_1_4",
"SPV_KHR_storage_buffer_storage_class",
"SPV_KHR_variable_pointers",
"VK_EXT_subgroup_size_control",
]
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
ext.append("VK_KHR_cooperative_matrix")
ext.append("SPV_KHR_cooperative_matrix")
if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]:
ext.append("VK_KHR_shader_integer_dot_product")
ext.append("SPV_KHR_shader_integer_dot_product")
return make_ext_list(ext_list=ext)
@@ -186,13 +186,13 @@ def get_vulkan_target_capabilities(triple):
"Quad": 128,
"PartitionedNV": 256,
}
cap["maxComputeSharedMemorySize"] = 16384
cap["maxComputeWorkGroupInvocations"] = 128
cap["maxComputeWorkGroupSize"] = [128, 128, 64]
cap["subgroupSize"] = 32
cap["max_compute_shared_memory_size"] = 16384
cap["max_compute_workgroup_invocations"] = 128
cap["max_compute_workgroup_size"] = [128, 128, 64]
cap["subgroup_size"] = 32
cap["subgroupFeatures"] = ["Basic"]
cap["minSubgroupSize"] = None
cap["maxSubgroupSize"] = None
cap["min_subgroup_size"] = None
cap["max_subgroup_size"] = None
cap["shaderFloat16"] = False
cap["shaderFloat64"] = False
cap["shaderInt8"] = False
@@ -209,13 +209,13 @@ def get_vulkan_target_capabilities(triple):
cap["coopmatCases"] = None
if arch in ["rdna1", "rdna2", "rdna3"]:
cap["maxComputeSharedMemorySize"] = 65536
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
cap["max_compute_shared_memory_size"] = 65536
cap["max_compute_workgroup_invocations"] = 1024
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
cap["subgroupSize"] = 64
cap["minSubgroupSize"] = 32
cap["maxSubgroupSize"] = 64
cap["subgroup_size"] = 64
cap["min_subgroup_size"] = 32
cap["max_subgroup_size"] = 64
cap["subgroupFeatures"] = [
"Basic",
"Vote",
@@ -244,7 +244,8 @@ def get_vulkan_target_capabilities(triple):
if arch == "rdna3":
# TODO: Get scope value
cap["coopmatCases"] = [
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>"
"m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>",
"m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>"
]
if product == "rx5700xt":
@@ -252,11 +253,11 @@ def get_vulkan_target_capabilities(triple):
cap["storagePushConstant8"] = False
elif arch in ["rgcn5", "rgcn4", "rgcn3"]:
cap["maxComputeSharedMemorySize"] = 65536
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
cap["max_compute_shared_memory_size"] = 65536
cap["max_compute_workgroup_invocations"] = 1024
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
cap["subgroupSize"] = 64
cap["subgroup_size"] = 64
cap["subgroupFeatures"] = [
"Basic",
"Vote",
@@ -267,8 +268,8 @@ def get_vulkan_target_capabilities(triple):
"Clustered",
"Quad",
]
cap["minSubgroupSize"] = 64
cap["maxSubgroupSize"] = 64
cap["min_subgroup_size"] = 64
cap["max_subgroup_size"] = 64
if arch == "rgcn5":
cap["shaderFloat16"] = True
@@ -290,11 +291,11 @@ def get_vulkan_target_capabilities(triple):
cap["variablePointersStorageBuffer"] = True
elif arch == "m1":
cap["maxComputeSharedMemorySize"] = 32768
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
cap["max_compute_shared_memory_size"] = 32768
cap["max_compute_workgroup_invocations"] = 1024
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
cap["subgroupSize"] = 32
cap["subgroup_size"] = 32
cap["subgroupFeatures"] = [
"Basic",
"Vote",
@@ -321,11 +322,11 @@ def get_vulkan_target_capabilities(triple):
cap["variablePointersStorageBuffer"] = True
elif arch == "valhall":
cap["maxComputeSharedMemorySize"] = 32768
cap["maxComputeWorkGroupInvocations"] = 512
cap["maxComputeWorkGroupSize"] = [512, 512, 512]
cap["max_compute_shared_memory_size"] = 32768
cap["max_compute_workgroup_invocations"] = 512
cap["max_compute_workgroup_size"] = [512, 512, 512]
cap["subgroupSize"] = 16
cap["subgroup_size"] = 16
cap["subgroupFeatures"] = [
"Basic",
"Vote",
@@ -352,11 +353,11 @@ def get_vulkan_target_capabilities(triple):
cap["variablePointersStorageBuffer"] = True
elif arch == "arc":
cap["maxComputeSharedMemorySize"] = 32768
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 64]
cap["max_compute_shared_memory_size"] = 32768
cap["max_compute_workgroup_invocations"] = 1024
cap["max_compute_workgroup_size"] = [1024, 1024, 64]
cap["subgroupSize"] = 32
cap["subgroup_size"] = 32
cap["subgroupFeatures"] = [
"Basic",
"Vote",
@@ -385,8 +386,8 @@ def get_vulkan_target_capabilities(triple):
elif arch == "cpu":
if product == "swiftshader":
cap["maxComputeSharedMemorySize"] = 16384
cap["subgroupSize"] = 4
cap["max_compute_shared_memory_size"] = 16384
cap["subgroup_size"] = 4
cap["subgroupFeatures"] = [
"Basic",
"Vote",
@@ -397,13 +398,13 @@ def get_vulkan_target_capabilities(triple):
]
elif arch in ["pascal"]:
cap["maxComputeSharedMemorySize"] = 49152
cap["maxComputeWorkGroupInvocations"] = 1536
cap["maxComputeWorkGroupSize"] = [1536, 1024, 64]
cap["max_compute_shared_memory_size"] = 49152
cap["max_compute_workgroup_invocations"] = 1536
cap["max_compute_workgroup_size"] = [1536, 1024, 64]
cap["subgroupSize"] = 32
cap["minSubgroupSize"] = 32
cap["maxSubgroupSize"] = 32
cap["subgroup_size"] = 32
cap["min_subgroup_size"] = 32
cap["max_subgroup_size"] = 32
cap["subgroupFeatures"] = [
"Basic",
"Vote",
@@ -431,13 +432,13 @@ def get_vulkan_target_capabilities(triple):
cap["variablePointersStorageBuffer"] = True
elif arch in ["ampere", "turing"]:
cap["maxComputeSharedMemorySize"] = 49152
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
cap["max_compute_shared_memory_size"] = 49152
cap["max_compute_workgroup_invocations"] = 1024
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
cap["subgroupSize"] = 32
cap["minSubgroupSize"] = 32
cap["maxSubgroupSize"] = 32
cap["subgroup_size"] = 32
cap["min_subgroup_size"] = 32
cap["max_subgroup_size"] = 32
cap["subgroupFeatures"] = [
"Basic",
"Vote",
@@ -471,11 +472,11 @@ def get_vulkan_target_capabilities(triple):
]
elif arch == "adreno":
cap["maxComputeSharedMemorySize"] = 32768
cap["maxComputeWorkGroupInvocations"] = 1024
cap["maxComputeWorkGroupSize"] = [1024, 1024, 64]
cap["max_compute_shared_memory_size"] = 32768
cap["max_compute_workgroup_invocations"] = 1024
cap["max_compute_workgroup_size"] = [1024, 1024, 64]
cap["subgroupSize"] = 64
cap["subgroup_size"] = 64
cap["subgroupFeatures"] = [
"Basic",
"Vote",
@@ -491,14 +492,14 @@ def get_vulkan_target_capabilities(triple):
cap["shaderInt16"] = True
cap["storageBuffer16BitAccess"] = True
if os == "andorid31":
if os == "android31":
cap["uniformAndStorageBuffer8BitAccess"] = True
cap["variablePointers"] = True
cap["variablePointersStorageBuffer"] = True
elif arch == "unknown":
cap["subgroupSize"] = 64
cap["subgroup_size"] = 64
cap["variablePointers"] = False
cap["variablePointersStorageBuffer"] = False
else:
@@ -521,14 +522,14 @@ def get_vulkan_target_capabilities(triple):
res += f"{k} = {'unit' if v == True else None}, "
elif isinstance(v, list):
if k == "subgroupFeatures":
res += f"subgroupFeatures = {get_subgroup_val(v)}: i32, "
elif k == "maxComputeWorkGroupSize":
res += f"maxComputeWorkGroupSize = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, "
res += f"subgroup_features = {get_subgroup_val(v)}: i32, "
elif k == "max_compute_workgroup_size":
res += f"max_compute_workgroup_size = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, "
elif k == "coopmatCases":
cmc = ""
for case in v:
cmc += f"#vk.coop_matrix_props<{case}>, "
res += f"cooperativeMatrixPropertiesKHR = [{cmc[:-2]}], "
cmc += f"#spirv.coop_matrix_props_khr<{case}>, "
res += f"cooperative_matrix_properties_khr = [{cmc[:-2]}], "
else:
res += f"{k} = {get_comma_sep_str(v)}, "
else:

View File

@@ -144,6 +144,8 @@ def get_vulkan_target_triple(device_name):
# Intel Targets
elif any(x in device_name for x in ("A770", "A750")):
triple = f"arc-770-{system_os}"
elif "v620" in device_name:
triple = f"rdna2-v620-{system_os}"
# Adreno Targets
elif all(x in device_name for x in ("Adreno", "740")):
@@ -169,7 +171,7 @@ def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]):
print(
f"Found vulkan device {vulkan_device}. Using target triple {triple}"
)
return f"-iree-vulkan-target-triple={triple}"
return f"--iree-vulkan-target-triple={triple}"
print(
"""Optimized kernel for your target device is not added yet.
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
@@ -184,7 +186,8 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
res_vulkan_flag = []
res_vulkan_flag += [
"--iree-stream-resource-max-allocation-size=3221225472"
"--iree-stream-resource-max-allocation-size=3221225472",
"--iree-flow-inline-constants-max-byte-length=0"
]
vulkan_triple_flag = None
for arg in extra_args:
@@ -197,10 +200,8 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
vulkan_triple_flag = get_vulkan_triple_flag(
device_num=device_num, extra_args=extra_args
)
res_vulkan_flag += [vulkan_triple_flag]
if vulkan_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)
res_vulkan_flag.append(vulkan_target_env)
return res_vulkan_flag

View File

@@ -6,6 +6,7 @@ import tempfile
import os
import hashlib
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
def create_hash(file_name):
with open(file_name, "rb") as f:
@@ -120,7 +121,7 @@ class SharkImporter:
is_dynamic=False,
tracing_required=False,
func_name="forward",
save_dir="./shark_tmp/",
save_dir=cmd_opts.tmp_dir, #"./shark_tmp/",
mlir_type="linalg",
):
if self.frontend in ["torch", "pytorch"]:
@@ -806,7 +807,7 @@ def save_mlir(
model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
)
if dir == "":
dir = os.path.join(".", "shark_tmp")
dir = cmd_opts.tmp_dir, #os.path.join(".", "shark_tmp")
mlir_path = os.path.join(dir, model_name_mlir)
print(f"saving {model_name_mlir} to {dir}")
if not os.path.exists(dir):