mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-11 14:58:11 -05:00
Compare commits
344 Commits
20230516.7
...
fp16cpu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
489a858af1 | ||
|
|
b83d32fafe | ||
|
|
0a618e1863 | ||
|
|
a731eb6ed4 | ||
|
|
2004d16945 | ||
|
|
6e409bfb77 | ||
|
|
77727d149c | ||
|
|
66f6e79d68 | ||
|
|
3b825579a7 | ||
|
|
9f0a421764 | ||
|
|
c28682110c | ||
|
|
caf6cc5d8f | ||
|
|
8614a18474 | ||
|
|
86c1c0c215 | ||
|
|
8bb364bcb8 | ||
|
|
7abddd01ec | ||
|
|
2a451fa0c7 | ||
|
|
9c4610b9da | ||
|
|
a38cc9d216 | ||
|
|
1c382449ec | ||
|
|
7cc9b3f8e8 | ||
|
|
e54517e967 | ||
|
|
326327a799 | ||
|
|
785b65c7b0 | ||
|
|
0d16c81687 | ||
|
|
8dd7850c69 | ||
|
|
e930ba85b4 | ||
|
|
cd732e7a38 | ||
|
|
8e0f8b3227 | ||
|
|
b8210ef796 | ||
|
|
94594542a9 | ||
|
|
82f833e87d | ||
|
|
c9d6870105 | ||
|
|
4fec03a6cc | ||
|
|
9a27f51378 | ||
|
|
ad1a0f35ff | ||
|
|
6773278ec2 | ||
|
|
9a0efffcca | ||
|
|
61c6f153d9 | ||
|
|
effd42e8f5 | ||
|
|
b5fbb1a8a0 | ||
|
|
ded74d09cd | ||
|
|
79267931c1 | ||
|
|
9eceba69b7 | ||
|
|
ca609afb6a | ||
|
|
11bdce9790 | ||
|
|
684943a4a6 | ||
|
|
b817bb8455 | ||
|
|
780f520f02 | ||
|
|
c61b6f8d65 | ||
|
|
c854208d49 | ||
|
|
c5dcfc1f13 | ||
|
|
bde63ee8ae | ||
|
|
9681d494eb | ||
|
|
ede6bf83e2 | ||
|
|
2c2693fb7d | ||
|
|
1d31b2b2c6 | ||
|
|
d2f64eefa3 | ||
|
|
87ae14b6ff | ||
|
|
1ccafa1fc1 | ||
|
|
4c3d8a0a7f | ||
|
|
3601dc7c3b | ||
|
|
671881cf87 | ||
|
|
4e9be6be59 | ||
|
|
9c8cbaf498 | ||
|
|
9e348a114e | ||
|
|
51f90a4d56 | ||
|
|
310d5d0a49 | ||
|
|
9697981004 | ||
|
|
450c231171 | ||
|
|
07f6f4a2f7 | ||
|
|
610813c72f | ||
|
|
8e3860c9e6 | ||
|
|
e37d6720eb | ||
|
|
16160d9a7d | ||
|
|
79075a1a07 | ||
|
|
db990826d3 | ||
|
|
7ee3e4ba5d | ||
|
|
05889a8fe1 | ||
|
|
b87efe7686 | ||
|
|
82b462de3a | ||
|
|
d8f0f7bade | ||
|
|
79bd0b84a1 | ||
|
|
8738571d1e | ||
|
|
a4c354ce54 | ||
|
|
cc53efa89f | ||
|
|
9ae8bc921e | ||
|
|
32eb78f0f9 | ||
|
|
cb509343d9 | ||
|
|
6da391c9b1 | ||
|
|
9dee7ae652 | ||
|
|
343dfd901c | ||
|
|
57260b9c37 | ||
|
|
18e7d2d061 | ||
|
|
51a1009796 | ||
|
|
045c3c3852 | ||
|
|
0139dd58d9 | ||
|
|
c96571855a | ||
|
|
4f61d69d86 | ||
|
|
531d447768 | ||
|
|
16f46f8de9 | ||
|
|
c4723f469f | ||
|
|
d804f45a61 | ||
|
|
d22177f936 | ||
|
|
75e68f02f4 | ||
|
|
4dc9c59611 | ||
|
|
18801dcabc | ||
|
|
3c577f7168 | ||
|
|
f5e4fa6ffe | ||
|
|
48de445325 | ||
|
|
8e90f1b81a | ||
|
|
e8c1203be2 | ||
|
|
e4d7abb519 | ||
|
|
96185c9dc1 | ||
|
|
bc22a81925 | ||
|
|
5203679f1f | ||
|
|
bf073f8f37 | ||
|
|
cec6eda6b4 | ||
|
|
9e37e03741 | ||
|
|
9b8c4401b5 | ||
|
|
a9f95a218b | ||
|
|
872bd72d0b | ||
|
|
fd1c4db5d0 | ||
|
|
759664bb48 | ||
|
|
14fd0cdd87 | ||
|
|
a57eccc997 | ||
|
|
a686d7d89f | ||
|
|
ed484b8253 | ||
|
|
7fe57ebaaf | ||
|
|
c287fd2be8 | ||
|
|
51ec1a1360 | ||
|
|
bd30044c0b | ||
|
|
c9de2729b2 | ||
|
|
a5b13fcc2f | ||
|
|
6bb329c4af | ||
|
|
98fb6c52df | ||
|
|
206c1b70f4 | ||
|
|
cdb037ee54 | ||
|
|
ce2fd84538 | ||
|
|
4684afad34 | ||
|
|
8d65456b7a | ||
|
|
d6759a852b | ||
|
|
ab57af43c1 | ||
|
|
4d5c55dd9f | ||
|
|
07399ad65c | ||
|
|
776a9c2293 | ||
|
|
9d399eb988 | ||
|
|
927b662aa7 | ||
|
|
47f8a79c75 | ||
|
|
289f983f41 | ||
|
|
453e46562f | ||
|
|
5497af1f56 | ||
|
|
f3cb63fc9c | ||
|
|
d7092aafaa | ||
|
|
a415f3f70e | ||
|
|
c292e5c9d7 | ||
|
|
03c4d9e171 | ||
|
|
3662224c04 | ||
|
|
db3f222933 | ||
|
|
68b3021325 | ||
|
|
336469154d | ||
|
|
41e5088908 | ||
|
|
0a8f7673f4 | ||
|
|
c482ab78da | ||
|
|
4be80f7158 | ||
|
|
536aba1424 | ||
|
|
dd738a0e02 | ||
|
|
8927cb0a2c | ||
|
|
8c317e4809 | ||
|
|
b0136593df | ||
|
|
11f62d7fac | ||
|
|
14559dd620 | ||
|
|
e503a3e8d6 | ||
|
|
22a4254adf | ||
|
|
ab01f0f048 | ||
|
|
c471d17cca | ||
|
|
a2a436eb0c | ||
|
|
1adb51b29d | ||
|
|
aab2233e25 | ||
|
|
e20cd71314 | ||
|
|
5ec91143f5 | ||
|
|
7cf19230e2 | ||
|
|
1bcf6b2c5b | ||
|
|
91027f8719 | ||
|
|
a909fc2e78 | ||
|
|
247f69cf9d | ||
|
|
3b8f7cc231 | ||
|
|
6e8dbf72bd | ||
|
|
38e5b62d80 | ||
|
|
1c7eecc981 | ||
|
|
be417f0bf4 | ||
|
|
a517e217b0 | ||
|
|
9fcae4f808 | ||
|
|
788d469c5b | ||
|
|
8a59f7cc27 | ||
|
|
1c2ec3c7a2 | ||
|
|
af0f715e20 | ||
|
|
47ec7275e6 | ||
|
|
3a24cff901 | ||
|
|
1f72907886 | ||
|
|
06c8aabd01 | ||
|
|
55a12cc0c4 | ||
|
|
7dcbbde523 | ||
|
|
1b62dc4529 | ||
|
|
c5a47887f4 | ||
|
|
c72d0eaf87 | ||
|
|
c41f58042a | ||
|
|
043e5a5c7a | ||
|
|
a1b1ce935c | ||
|
|
bc6fee1a0c | ||
|
|
91ab594744 | ||
|
|
4015793f84 | ||
|
|
d63ce76dd8 | ||
|
|
1c32915570 | ||
|
|
6d286c0609 | ||
|
|
7392b22731 | ||
|
|
534de05791 | ||
|
|
5779e8c039 | ||
|
|
d496053590 | ||
|
|
6274a813c9 | ||
|
|
1d6a1f9f8a | ||
|
|
75672c0e28 | ||
|
|
74a7202173 | ||
|
|
27a08735db | ||
|
|
eaa49cce17 | ||
|
|
10657d6fb1 | ||
|
|
e3ab844cd1 | ||
|
|
5ce6001b41 | ||
|
|
501d0ca52e | ||
|
|
b444528715 | ||
|
|
6e6c90f62b | ||
|
|
8cdb38496e | ||
|
|
726d73d6ba | ||
|
|
4d55e51d46 | ||
|
|
6ef78ee7ba | ||
|
|
4002da7161 | ||
|
|
ecb5e8e5d8 | ||
|
|
28e0919321 | ||
|
|
28f4d44a6b | ||
|
|
97f7e79391 | ||
|
|
44a8f2f8db | ||
|
|
8822b9acd7 | ||
|
|
0ca3b9fce3 | ||
|
|
045f2bb147 | ||
|
|
a811b867b9 | ||
|
|
cdd505e2dd | ||
|
|
1b0f39107c | ||
|
|
b9b8955f74 | ||
|
|
6f7a85eee3 | ||
|
|
18c8e9e51e | ||
|
|
a202bb466a | ||
|
|
07c1e1d712 | ||
|
|
18daec78c8 | ||
|
|
1a8e2024d6 | ||
|
|
d61b6641fb | ||
|
|
88cc2423cc | ||
|
|
ccf944c1bd | ||
|
|
0def74f520 | ||
|
|
3fb72e192e | ||
|
|
855435ee24 | ||
|
|
6f9f868fc0 | ||
|
|
fb865f1b99 | ||
|
|
3e5c50f07b | ||
|
|
a544f30a8f | ||
|
|
1fe56d460a | ||
|
|
fafd713141 | ||
|
|
015d0132c3 | ||
|
|
20ddd96ef7 | ||
|
|
ee33cfd2d1 | ||
|
|
a3cba21d5b | ||
|
|
a7b6ec4095 | ||
|
|
d80b087d95 | ||
|
|
297a209608 | ||
|
|
b204113563 | ||
|
|
f60ab1f4fa | ||
|
|
b203779462 | ||
|
|
38570a9bbb | ||
|
|
a5c882f296 | ||
|
|
eb6d11cfed | ||
|
|
46184a81ac | ||
|
|
149165a2f0 | ||
|
|
bec82a665f | ||
|
|
9551490341 | ||
|
|
49b3ecdbca | ||
|
|
f53e3594c3 | ||
|
|
5562d1dfda | ||
|
|
c7b0c2961e | ||
|
|
44273b0791 | ||
|
|
0a4c8fcb3e | ||
|
|
2fec3c8169 | ||
|
|
5e7d5930dd | ||
|
|
b6dbd20250 | ||
|
|
34f1295349 | ||
|
|
1980d7b2c3 | ||
|
|
2cfacc5051 | ||
|
|
436f58ddc4 | ||
|
|
6b29bd17c8 | ||
|
|
2c3485ca3e | ||
|
|
f206ecc635 | ||
|
|
a187e05ae6 | ||
|
|
8c21960486 | ||
|
|
be62fce676 | ||
|
|
f23b778a6c | ||
|
|
436edf900d | ||
|
|
ed58c2553f | ||
|
|
f2ca58e844 | ||
|
|
1dbcc736eb | ||
|
|
a83808ddc5 | ||
|
|
a07fe80530 | ||
|
|
d0ba3ef8fa | ||
|
|
8400529c2c | ||
|
|
7eaee9c242 | ||
|
|
8230eebce5 | ||
|
|
6296ea4be9 | ||
|
|
4151ec3a8f | ||
|
|
a2467e8d43 | ||
|
|
e677178bcc | ||
|
|
7ef1bea953 | ||
|
|
ad89bb1413 | ||
|
|
218ed78c40 | ||
|
|
6046f36ab6 | ||
|
|
5915bf7de3 | ||
|
|
f0a4e59758 | ||
|
|
1ddef26af5 | ||
|
|
ba8eddb12f | ||
|
|
47b346d428 | ||
|
|
1b4f4f5f4d | ||
|
|
73cd7e8320 | ||
|
|
19c0ae3702 | ||
|
|
54e57f7771 | ||
|
|
6d64b8e273 | ||
|
|
a8ea0326f5 | ||
|
|
58e9194553 | ||
|
|
eb360e255d | ||
|
|
a6f88d7f72 | ||
|
|
8e571d165f | ||
|
|
3cddd01b10 | ||
|
|
64c2b2d96b | ||
|
|
f5ce121988 | ||
|
|
991f144598 | ||
|
|
09bea17e59 | ||
|
|
aefcf80b48 | ||
|
|
512235892e | ||
|
|
6602a2f5ba |
2
.flake8
2
.flake8
@@ -2,4 +2,4 @@
|
||||
count = 1
|
||||
show-source = 1
|
||||
select = E9,F63,F7,F82
|
||||
exclude = lit.cfg.py
|
||||
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py
|
||||
|
||||
30
.github/workflows/nightly.yml
vendored
30
.github/workflows/nightly.yml
vendored
@@ -50,27 +50,13 @@ jobs:
|
||||
shell: powershell
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
|
||||
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
|
||||
python process_skipfiles.py
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd.spec
|
||||
mv ./dist/shark_sd.exe ./dist/shark_sd_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_${{ env.package_version_ }}.exe
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd_cli.spec
|
||||
python process_skipfiles.py
|
||||
mv ./dist/shark_sd_cli.exe ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/shark_sd_cli_${{ env.package_version_ }}.exe
|
||||
|
||||
|
||||
# GHA windows VM OOMs so disable for now
|
||||
#- name: Build and validate the SHARK Runtime package
|
||||
# shell: powershell
|
||||
# run: |
|
||||
# $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
|
||||
# pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
|
||||
#- uses: actions/upload-artifact@v2
|
||||
# with:
|
||||
# path: dist/*
|
||||
|
||||
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
|
||||
|
||||
- name: Upload Release Assets
|
||||
id: upload-release-assets
|
||||
uses: dwenegar/upload-release-assets@v1
|
||||
@@ -78,7 +64,7 @@ jobs:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
assets_path: ./dist/*
|
||||
assets_path: ./dist/nodai*
|
||||
#asset_content_type: application/vnd.microsoft.portable-executable
|
||||
|
||||
- name: Publish Release
|
||||
@@ -118,7 +104,7 @@ jobs:
|
||||
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install flake8 pytest toml
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html; fi
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
@@ -158,7 +144,7 @@ jobs:
|
||||
source shark.venv/bin/activate
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
SHARK_PACKAGE_VERSION=${package_version} \
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
|
||||
# Install the built wheel
|
||||
pip install ./wheelhouse/nodai*
|
||||
# Validate the Models
|
||||
|
||||
14
.github/workflows/test-models.yml
vendored
14
.github/workflows/test-models.yml
vendored
@@ -35,6 +35,8 @@ jobs:
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
suite: lint
|
||||
- os: MacStudio
|
||||
suite: metal
|
||||
exclude:
|
||||
- os: ubuntu-latest
|
||||
suite: vulkan
|
||||
@@ -46,6 +48,8 @@ jobs:
|
||||
suite: cuda
|
||||
- os: MacStudio
|
||||
suite: cpu
|
||||
- os: MacStudio
|
||||
suite: vulkan
|
||||
- os: icelake
|
||||
suite: vulkan
|
||||
- os: icelake
|
||||
@@ -61,7 +65,6 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
if: matrix.os != '7950x'
|
||||
|
||||
- name: Set Environment Variables
|
||||
if: matrix.os != '7950x'
|
||||
@@ -84,9 +87,6 @@ jobs:
|
||||
#cache-dependency-path: |
|
||||
# **/requirements-importer.txt
|
||||
# **/requirements.txt
|
||||
|
||||
- uses: actions/checkout@v2
|
||||
if: matrix.os == '7950x'
|
||||
|
||||
- name: Install dependencies
|
||||
if: matrix.suite == 'lint'
|
||||
@@ -115,6 +115,7 @@ jobs:
|
||||
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
|
||||
python build_tools/vicuna_testing.py
|
||||
|
||||
- name: Validate Models on NVIDIA GPU
|
||||
if: matrix.suite == 'cuda'
|
||||
@@ -129,15 +130,14 @@ jobs:
|
||||
# python build_tools/stable_diffusion_testing.py --device=cuda
|
||||
|
||||
- name: Validate Vulkan Models (MacOS)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
|
||||
if: matrix.suite == 'metal' && matrix.os == 'MacStudio'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
export DYLD_LIBRARY_PATH=/usr/local/lib/
|
||||
echo $PATH
|
||||
pip list | grep -E "torch|iree"
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
|
||||
|
||||
- name: Validate Vulkan Models (a100)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
|
||||
|
||||
14
.gitignore
vendored
14
.gitignore
vendored
@@ -2,6 +2,8 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.mlir
|
||||
*.vmfb
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
@@ -157,7 +159,7 @@ cython_debug/
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
.idea/
|
||||
|
||||
# vscode related
|
||||
.vscode
|
||||
@@ -187,3 +189,13 @@ apps/stable_diffusion/web/models/
|
||||
|
||||
# Stencil annotators.
|
||||
stencil_annotator/
|
||||
|
||||
# For DocuChat
|
||||
apps/language_models/langchain/user_path/
|
||||
db_dir_UserData
|
||||
|
||||
# Embeded browser cache and other
|
||||
apps/stable_diffusion/web/EBWebView/
|
||||
|
||||
# Llama2 tokenizer configs
|
||||
llama2_tokenizer_configs/
|
||||
|
||||
2
.gitmodules
vendored
2
.gitmodules
vendored
@@ -1,4 +1,4 @@
|
||||
[submodule "inference/thirdparty/shark-runtime"]
|
||||
path = inference/thirdparty/shark-runtime
|
||||
url =https://github.com/nod-ai/SHARK-Runtime.git
|
||||
url =https://github.com/nod-ai/SRT.git
|
||||
branch = shark-06032022
|
||||
|
||||
@@ -10,7 +10,7 @@ High Performance Machine Learning Distribution
|
||||
<summary>Prerequisites - Drivers </summary>
|
||||
|
||||
#### Install your Windows hardware drivers
|
||||
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
|
||||
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
|
||||
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
|
||||
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
|
||||
|
||||
@@ -170,7 +170,7 @@ python -m pip install --upgrade pip
|
||||
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
|
||||
|
||||
```shell
|
||||
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
```
|
||||
|
||||
### Run shark tank model tests.
|
||||
|
||||
16
apps/language_models/README.md
Normal file
16
apps/language_models/README.md
Normal file
@@ -0,0 +1,16 @@
|
||||
## CodeGen Setup using SHARK-server
|
||||
|
||||
### Setup Server
|
||||
- clone SHARK and setup the venv
|
||||
- host the server using `python apps/stable_diffusion/web/index.py --api --server_port=<PORT>`
|
||||
- default server address is `http://0.0.0.0:8080`
|
||||
|
||||
### Setup Client
|
||||
1. fauxpilot-vscode (VSCode Extension):
|
||||
- Code for the extension can be found [here](https://github.com/Venthe/vscode-fauxpilot)
|
||||
- PreReq: VSCode extension (will need [`nodejs` and `npm`](https://nodejs.org/en/download) to compile and run the extension)
|
||||
- Compile and Run the extension on VSCode (press F5 on VSCode), this opens a new VSCode window with the extension running
|
||||
- Open VSCode settings, search for fauxpilot in settings and modify `server : http://<IP>:<PORT>`, `Model : codegen` , `Max Lines : 30`
|
||||
|
||||
2. Others (REST API curl, OpenAI Python bindings) as shown [here](https://github.com/fauxpilot/fauxpilot/blob/main/documentation/client.md)
|
||||
- using Github Copilot VSCode extension with SHARK-server needs more work to be functional.
|
||||
18
apps/language_models/langchain/README.md
Normal file
18
apps/language_models/langchain/README.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# Langchain
|
||||
|
||||
## How to run the model
|
||||
|
||||
1.) Install all the dependencies by running:
|
||||
```shell
|
||||
pip install -r apps/language_models/langchain/langchain_requirements.txt
|
||||
sudo apt-get install -y libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
|
||||
```
|
||||
|
||||
2.) Create a folder named `user_path` in `apps/language_models/langchain/` directory.
|
||||
|
||||
Now, you are ready to use the model.
|
||||
|
||||
3.) To run the model, run the following command:
|
||||
```shell
|
||||
python apps/language_models/langchain/gen.py --cli=True
|
||||
```
|
||||
186
apps/language_models/langchain/cli.py
Normal file
186
apps/language_models/langchain/cli.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import copy
|
||||
import torch
|
||||
|
||||
from evaluate_params import eval_func_param_names
|
||||
from gen import Langchain
|
||||
from prompter import non_hf_types
|
||||
from utils import clear_torch_cache, NullContext, get_kwargs
|
||||
|
||||
|
||||
def run_cli( # for local function:
|
||||
base_model=None,
|
||||
lora_weights=None,
|
||||
inference_server=None,
|
||||
debug=None,
|
||||
chat_context=None,
|
||||
examples=None,
|
||||
memory_restriction_level=None,
|
||||
# for get_model:
|
||||
score_model=None,
|
||||
load_8bit=None,
|
||||
load_4bit=None,
|
||||
load_half=None,
|
||||
load_gptq=None,
|
||||
use_safetensors=None,
|
||||
infer_devices=None,
|
||||
tokenizer_base_model=None,
|
||||
gpu_id=None,
|
||||
local_files_only=None,
|
||||
resume_download=None,
|
||||
use_auth_token=None,
|
||||
trust_remote_code=None,
|
||||
offload_folder=None,
|
||||
compile_model=None,
|
||||
# for some evaluate args
|
||||
stream_output=None,
|
||||
prompt_type=None,
|
||||
prompt_dict=None,
|
||||
temperature=None,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
num_beams=None,
|
||||
max_new_tokens=None,
|
||||
min_new_tokens=None,
|
||||
early_stopping=None,
|
||||
max_time=None,
|
||||
repetition_penalty=None,
|
||||
num_return_sequences=None,
|
||||
do_sample=None,
|
||||
chat=None,
|
||||
langchain_mode=None,
|
||||
langchain_action=None,
|
||||
document_choice=None,
|
||||
top_k_docs=None,
|
||||
chunk=None,
|
||||
chunk_size=None,
|
||||
# for evaluate kwargs
|
||||
src_lang=None,
|
||||
tgt_lang=None,
|
||||
concurrency_count=None,
|
||||
save_dir=None,
|
||||
sanitize_bot_response=None,
|
||||
model_state0=None,
|
||||
max_max_new_tokens=None,
|
||||
is_public=None,
|
||||
max_max_time=None,
|
||||
raise_generate_gpu_exceptions=None,
|
||||
load_db_if_exists=None,
|
||||
dbs=None,
|
||||
user_path=None,
|
||||
detect_user_path_changes_every_query=None,
|
||||
use_openai_embedding=None,
|
||||
use_openai_model=None,
|
||||
hf_embedding_model=None,
|
||||
db_type=None,
|
||||
n_jobs=None,
|
||||
first_para=None,
|
||||
text_limit=None,
|
||||
verbose=None,
|
||||
cli=None,
|
||||
reverse_docs=None,
|
||||
use_cache=None,
|
||||
auto_reduce_chunks=None,
|
||||
max_chunks=None,
|
||||
model_lock=None,
|
||||
force_langchain_evaluate=None,
|
||||
model_state_none=None,
|
||||
# unique to this function:
|
||||
cli_loop=None,
|
||||
):
|
||||
Langchain.check_locals(**locals())
|
||||
|
||||
score_model = "" # FIXME: For now, so user doesn't have to pass
|
||||
n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
|
||||
device = "cpu" if n_gpus == 0 else "cuda"
|
||||
context_class = NullContext if n_gpus > 1 or n_gpus == 0 else torch.device
|
||||
|
||||
with context_class(device):
|
||||
from functools import partial
|
||||
|
||||
# get score model
|
||||
smodel, stokenizer, sdevice = Langchain.get_score_model(
|
||||
reward_type=True,
|
||||
**get_kwargs(
|
||||
Langchain.get_score_model,
|
||||
exclude_names=["reward_type"],
|
||||
**locals()
|
||||
)
|
||||
)
|
||||
|
||||
model, tokenizer, device = Langchain.get_model(
|
||||
reward_type=False,
|
||||
**get_kwargs(
|
||||
Langchain.get_model, exclude_names=["reward_type"], **locals()
|
||||
)
|
||||
)
|
||||
model_dict = dict(
|
||||
base_model=base_model,
|
||||
tokenizer_base_model=tokenizer_base_model,
|
||||
lora_weights=lora_weights,
|
||||
inference_server=inference_server,
|
||||
prompt_type=prompt_type,
|
||||
prompt_dict=prompt_dict,
|
||||
)
|
||||
model_state = dict(model=model, tokenizer=tokenizer, device=device)
|
||||
model_state.update(model_dict)
|
||||
my_db_state = [None]
|
||||
fun = partial(
|
||||
Langchain.evaluate,
|
||||
model_state,
|
||||
my_db_state,
|
||||
**get_kwargs(
|
||||
Langchain.evaluate,
|
||||
exclude_names=["model_state", "my_db_state"]
|
||||
+ eval_func_param_names,
|
||||
**locals()
|
||||
)
|
||||
)
|
||||
|
||||
example1 = examples[-1] # pick reference example
|
||||
all_generations = []
|
||||
while True:
|
||||
clear_torch_cache()
|
||||
instruction = input("\nEnter an instruction: ")
|
||||
if instruction == "exit":
|
||||
break
|
||||
|
||||
eval_vars = copy.deepcopy(example1)
|
||||
eval_vars[eval_func_param_names.index("instruction")] = eval_vars[
|
||||
eval_func_param_names.index("instruction_nochat")
|
||||
] = instruction
|
||||
eval_vars[eval_func_param_names.index("iinput")] = eval_vars[
|
||||
eval_func_param_names.index("iinput_nochat")
|
||||
] = "" # no input yet
|
||||
eval_vars[
|
||||
eval_func_param_names.index("context")
|
||||
] = "" # no context yet
|
||||
|
||||
# grab other parameters, like langchain_mode
|
||||
for k in eval_func_param_names:
|
||||
if k in locals():
|
||||
eval_vars[eval_func_param_names.index(k)] = locals()[k]
|
||||
|
||||
gener = fun(*tuple(eval_vars))
|
||||
outr = ""
|
||||
res_old = ""
|
||||
for gen_output in gener:
|
||||
res = gen_output["response"]
|
||||
extra = gen_output["sources"]
|
||||
if base_model not in non_hf_types or base_model in ["llama"]:
|
||||
if not stream_output:
|
||||
print(res)
|
||||
else:
|
||||
# then stream output for gradio that has full output each generation, so need here to show only new chars
|
||||
diff = res[len(res_old) :]
|
||||
print(diff, end="", flush=True)
|
||||
res_old = res
|
||||
outr = res # don't accumulate
|
||||
else:
|
||||
outr += res # just is one thing
|
||||
if extra:
|
||||
# show sources at end after model itself had streamed to std rest of response
|
||||
print(extra, flush=True)
|
||||
all_generations.append(outr + "\n")
|
||||
if not cli_loop:
|
||||
break
|
||||
return all_generations
|
||||
2187
apps/language_models/langchain/create_data.py
Normal file
2187
apps/language_models/langchain/create_data.py
Normal file
File diff suppressed because it is too large
Load Diff
103
apps/language_models/langchain/enums.py
Normal file
103
apps/language_models/langchain/enums.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class PromptType(Enum):
|
||||
custom = -1
|
||||
plain = 0
|
||||
instruct = 1
|
||||
quality = 2
|
||||
human_bot = 3
|
||||
dai_faq = 4
|
||||
summarize = 5
|
||||
simple_instruct = 6
|
||||
instruct_vicuna = 7
|
||||
instruct_with_end = 8
|
||||
human_bot_orig = 9
|
||||
prompt_answer = 10
|
||||
open_assistant = 11
|
||||
wizard_lm = 12
|
||||
wizard_mega = 13
|
||||
instruct_vicuna2 = 14
|
||||
instruct_vicuna3 = 15
|
||||
wizard2 = 16
|
||||
wizard3 = 17
|
||||
instruct_simple = 18
|
||||
wizard_vicuna = 19
|
||||
openai = 20
|
||||
openai_chat = 21
|
||||
gptj = 22
|
||||
prompt_answer_openllama = 23
|
||||
vicuna11 = 24
|
||||
mptinstruct = 25
|
||||
mptchat = 26
|
||||
falcon = 27
|
||||
|
||||
|
||||
class DocumentChoices(Enum):
|
||||
All_Relevant = 0
|
||||
All_Relevant_Only_Sources = 1
|
||||
Only_All_Sources = 2
|
||||
Just_LLM = 3
|
||||
|
||||
|
||||
non_query_commands = [
|
||||
DocumentChoices.All_Relevant_Only_Sources.name,
|
||||
DocumentChoices.Only_All_Sources.name,
|
||||
]
|
||||
|
||||
|
||||
class LangChainMode(Enum):
|
||||
"""LangChain mode"""
|
||||
|
||||
DISABLED = "Disabled"
|
||||
CHAT_LLM = "ChatLLM"
|
||||
LLM = "LLM"
|
||||
ALL = "All"
|
||||
WIKI = "wiki"
|
||||
WIKI_FULL = "wiki_full"
|
||||
USER_DATA = "UserData"
|
||||
MY_DATA = "MyData"
|
||||
GITHUB_H2OGPT = "github h2oGPT"
|
||||
H2O_DAI_DOCS = "DriverlessAI docs"
|
||||
|
||||
|
||||
class LangChainAction(Enum):
|
||||
"""LangChain action"""
|
||||
|
||||
QUERY = "Query"
|
||||
# WIP:
|
||||
# SUMMARIZE_MAP = "Summarize_map_reduce"
|
||||
SUMMARIZE_MAP = "Summarize"
|
||||
SUMMARIZE_ALL = "Summarize_all"
|
||||
SUMMARIZE_REFINE = "Summarize_refine"
|
||||
|
||||
|
||||
no_server_str = no_lora_str = no_model_str = "[None/Remove]"
|
||||
|
||||
# from site-packages/langchain/llms/openai.py
|
||||
# but needed since ChatOpenAI doesn't have this information
|
||||
model_token_mapping = {
|
||||
"gpt-4": 8192,
|
||||
"gpt-4-0314": 8192,
|
||||
"gpt-4-32k": 32768,
|
||||
"gpt-4-32k-0314": 32768,
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"gpt-3.5-turbo-16k": 16 * 1024,
|
||||
"gpt-3.5-turbo-0301": 4096,
|
||||
"text-ada-001": 2049,
|
||||
"ada": 2049,
|
||||
"text-babbage-001": 2040,
|
||||
"babbage": 2049,
|
||||
"text-curie-001": 2049,
|
||||
"curie": 2049,
|
||||
"davinci": 2049,
|
||||
"text-davinci-003": 4097,
|
||||
"text-davinci-002": 4097,
|
||||
"code-davinci-002": 8001,
|
||||
"code-davinci-001": 8001,
|
||||
"code-cushman-002": 2048,
|
||||
"code-cushman-001": 2048,
|
||||
}
|
||||
|
||||
source_prefix = "Sources [Score | Link]:"
|
||||
source_postfix = "End Sources<p>"
|
||||
53
apps/language_models/langchain/evaluate_params.py
Normal file
53
apps/language_models/langchain/evaluate_params.py
Normal file
@@ -0,0 +1,53 @@
|
||||
no_default_param_names = [
|
||||
"instruction",
|
||||
"iinput",
|
||||
"context",
|
||||
"instruction_nochat",
|
||||
"iinput_nochat",
|
||||
]
|
||||
|
||||
gen_hyper = [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"num_beams",
|
||||
"max_new_tokens",
|
||||
"min_new_tokens",
|
||||
"early_stopping",
|
||||
"max_time",
|
||||
"repetition_penalty",
|
||||
"num_return_sequences",
|
||||
"do_sample",
|
||||
]
|
||||
|
||||
eval_func_param_names = (
|
||||
[
|
||||
"instruction",
|
||||
"iinput",
|
||||
"context",
|
||||
"stream_output",
|
||||
"prompt_type",
|
||||
"prompt_dict",
|
||||
]
|
||||
+ gen_hyper
|
||||
+ [
|
||||
"chat",
|
||||
"instruction_nochat",
|
||||
"iinput_nochat",
|
||||
"langchain_mode",
|
||||
"langchain_action",
|
||||
"top_k_docs",
|
||||
"chunk",
|
||||
"chunk_size",
|
||||
"document_choice",
|
||||
]
|
||||
)
|
||||
|
||||
# form evaluate defaults for submit_nochat_api
|
||||
eval_func_param_names_defaults = eval_func_param_names.copy()
|
||||
for k in no_default_param_names:
|
||||
if k in eval_func_param_names_defaults:
|
||||
eval_func_param_names_defaults.remove(k)
|
||||
|
||||
|
||||
eval_extra_columns = ["prompt", "response", "score"]
|
||||
846
apps/language_models/langchain/expanded_pipelines.py
Normal file
846
apps/language_models/langchain/expanded_pipelines.py
Normal file
@@ -0,0 +1,846 @@
|
||||
from __future__ import annotations
|
||||
from typing import (
|
||||
Any,
|
||||
Mapping,
|
||||
Optional,
|
||||
Dict,
|
||||
List,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
Protocol,
|
||||
)
|
||||
import inspect
|
||||
import json
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
from abc import ABC, abstractmethod
|
||||
import langchain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.question_answering import stuff_prompt
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import LLMResult, PromptValue
|
||||
from pydantic import Extra, Field, root_validator, validator
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
|
||||
return langchain.verbose
|
||||
|
||||
|
||||
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
||||
"""Format a document into a string based on a prompt template."""
|
||||
base_info = {"page_content": doc.page_content}
|
||||
base_info.update(doc.metadata)
|
||||
missing_metadata = set(prompt.input_variables).difference(base_info)
|
||||
if len(missing_metadata) > 0:
|
||||
required_metadata = [
|
||||
iv for iv in prompt.input_variables if iv != "page_content"
|
||||
]
|
||||
raise ValueError(
|
||||
f"Document prompt requires documents to have metadata variables: "
|
||||
f"{required_metadata}. Received document with missing metadata: "
|
||||
f"{list(missing_metadata)}."
|
||||
)
|
||||
document_info = {k: base_info[k] for k in prompt.input_variables}
|
||||
return prompt.format(**document_info)
|
||||
|
||||
|
||||
class Chain(Serializable, ABC):
|
||||
"""Base interface that all chains should implement."""
|
||||
|
||||
memory: Optional[BaseMemory] = None
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(
|
||||
default=None, exclude=True
|
||||
)
|
||||
verbose: bool = Field(
|
||||
default_factory=_get_verbosity
|
||||
) # Whether to print the response text
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
raise NotImplementedError("Saving not supported for this chain type.")
|
||||
|
||||
@root_validator()
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
||||
@validator("verbose", pre=True, always=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
"""If verbose is None, set it.
|
||||
|
||||
This allows users to pass in None as verbose to access the global setting.
|
||||
"""
|
||||
if verbose is None:
|
||||
return _get_verbosity()
|
||||
else:
|
||||
return verbose
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Input keys this chain expects."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Output keys this chain expects."""
|
||||
|
||||
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Check that all inputs are present."""
|
||||
missing_keys = set(self.input_keys).difference(inputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some input keys: {missing_keys}")
|
||||
|
||||
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
|
||||
missing_keys = set(self.output_keys).difference(outputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some output keys: {missing_keys}")
|
||||
|
||||
@abstractmethod
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and return the output."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param.
|
||||
return_only_outputs: boolean for whether to return only outputs in the
|
||||
response. If True, only new keys generated by this chain will be
|
||||
returned. If False, both input keys and new keys generated by this
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. If not provided, will
|
||||
use the callbacks provided to the chain.
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
"""
|
||||
input_docs = inputs["input_documents"]
|
||||
missing_keys = set(self.input_keys).difference(inputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some input keys: {missing_keys}")
|
||||
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose, tags, self.tags
|
||||
)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
)
|
||||
|
||||
if "is_first" in inputs.keys() and not inputs["is_first"]:
|
||||
run_manager_ = run_manager
|
||||
input_list = [inputs]
|
||||
stop = None
|
||||
prompts = []
|
||||
for inputs in input_list:
|
||||
selected_inputs = {
|
||||
k: inputs[k] for k in self.prompt.input_variables
|
||||
}
|
||||
prompt = self.prompt.format_prompt(**selected_inputs)
|
||||
_colored_text = get_colored_text(prompt.to_string(), "green")
|
||||
_text = "Prompt after formatting:\n" + _colored_text
|
||||
if run_manager_:
|
||||
run_manager_.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if "stop" in inputs and inputs["stop"] != stop:
|
||||
raise ValueError(
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
)
|
||||
prompts.append(prompt)
|
||||
|
||||
prompt_strings = [p.to_string() for p in prompts]
|
||||
prompts = prompt_strings
|
||||
callbacks = run_manager_.get_child() if run_manager_ else None
|
||||
tags = None
|
||||
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
# If string is passed in directly no errors will be raised but outputs will
|
||||
# not make sense.
|
||||
if not isinstance(prompts, list):
|
||||
raise ValueError(
|
||||
"Argument 'prompts' is expected to be of type List[str], received"
|
||||
f" argument of type {type(prompts)}."
|
||||
)
|
||||
params = self.llm.dict()
|
||||
params["stop"] = stop
|
||||
options = {"stop": stop}
|
||||
disregard_cache = self.llm.cache is not None and not self.llm.cache
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks,
|
||||
self.llm.callbacks,
|
||||
self.llm.verbose,
|
||||
tags,
|
||||
self.llm.tags,
|
||||
)
|
||||
if langchain.llm_cache is None or disregard_cache:
|
||||
# This happens when langchain.cache is None, but self.cache is True
|
||||
if self.llm.cache is not None and self.cache:
|
||||
raise ValueError(
|
||||
"Asked to cache, but no cache found at `langchain.cache`."
|
||||
)
|
||||
run_manager_ = callback_manager.on_llm_start(
|
||||
dumpd(self),
|
||||
prompts,
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
)
|
||||
|
||||
generations = []
|
||||
for prompt in prompts:
|
||||
inputs_ = prompt
|
||||
num_workers = None
|
||||
batch_size = None
|
||||
|
||||
if num_workers is None:
|
||||
if self.llm.pipeline._num_workers is None:
|
||||
num_workers = 0
|
||||
else:
|
||||
num_workers = self.llm.pipeline._num_workers
|
||||
if batch_size is None:
|
||||
if self.llm.pipeline._batch_size is None:
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = self.llm.pipeline._batch_size
|
||||
|
||||
preprocess_params = {}
|
||||
generate_kwargs = {}
|
||||
preprocess_params.update(generate_kwargs)
|
||||
forward_params = generate_kwargs
|
||||
postprocess_params = {}
|
||||
# Fuse __init__ params and __call__ params without modifying the __init__ ones.
|
||||
preprocess_params = {
|
||||
**self.llm.pipeline._preprocess_params,
|
||||
**preprocess_params,
|
||||
}
|
||||
forward_params = {
|
||||
**self.llm.pipeline._forward_params,
|
||||
**forward_params,
|
||||
}
|
||||
postprocess_params = {
|
||||
**self.llm.pipeline._postprocess_params,
|
||||
**postprocess_params,
|
||||
}
|
||||
|
||||
self.llm.pipeline.call_count += 1
|
||||
if (
|
||||
self.llm.pipeline.call_count > 10
|
||||
and self.llm.pipeline.framework == "pt"
|
||||
and self.llm.pipeline.device.type == "cuda"
|
||||
):
|
||||
warnings.warn(
|
||||
"You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a"
|
||||
" dataset",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
model_inputs = self.llm.pipeline.preprocess(
|
||||
inputs_, **preprocess_params
|
||||
)
|
||||
model_outputs = self.llm.pipeline.forward(
|
||||
model_inputs, **forward_params
|
||||
)
|
||||
model_outputs["process"] = False
|
||||
return model_outputs
|
||||
output = LLMResult(generations=generations)
|
||||
run_manager_.on_llm_end(output)
|
||||
if run_manager_:
|
||||
output.run = RunInfo(run_id=run_manager_.run_id)
|
||||
response = output
|
||||
|
||||
outputs = [
|
||||
# Get the text of the top generated string.
|
||||
{self.output_key: generation[0].text}
|
||||
for generation in response.generations
|
||||
][0]
|
||||
run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
else:
|
||||
_run_manager = (
|
||||
run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
)
|
||||
docs = inputs[self.input_key]
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys = {
|
||||
k: v for k, v in inputs.items() if k != self.input_key
|
||||
}
|
||||
doc_strings = [
|
||||
format_document(doc, self.document_prompt) for doc in docs
|
||||
]
|
||||
# Join the documents together to put them in the prompt.
|
||||
inputs = {
|
||||
k: v
|
||||
for k, v in other_keys.items()
|
||||
if k in self.llm_chain.prompt.input_variables
|
||||
}
|
||||
inputs[self.document_variable_name] = self.document_separator.join(
|
||||
doc_strings
|
||||
)
|
||||
inputs["is_first"] = False
|
||||
inputs["input_documents"] = input_docs
|
||||
|
||||
# Call predict on the LLM.
|
||||
output = self.llm_chain(inputs, callbacks=_run_manager.get_child())
|
||||
if "process" in output.keys() and not output["process"]:
|
||||
return output
|
||||
output = output[self.llm_chain.output_key]
|
||||
extra_return_dict = {}
|
||||
extra_return_dict[self.output_key] = output
|
||||
outputs = extra_return_dict
|
||||
run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
|
||||
def prep_outputs(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
outputs: Dict[str, str],
|
||||
return_only_outputs: bool = False,
|
||||
) -> Dict[str, str]:
|
||||
"""Validate and prep outputs."""
|
||||
self._validate_outputs(outputs)
|
||||
if self.memory is not None:
|
||||
self.memory.save_context(inputs, outputs)
|
||||
if return_only_outputs:
|
||||
return outputs
|
||||
else:
|
||||
return {**inputs, **outputs}
|
||||
|
||||
def prep_inputs(
|
||||
self, inputs: Union[Dict[str, Any], Any]
|
||||
) -> Dict[str, str]:
|
||||
"""Validate and prep inputs."""
|
||||
if not isinstance(inputs, dict):
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(
|
||||
self.memory.memory_variables
|
||||
)
|
||||
if len(_input_keys) != 1:
|
||||
raise ValueError(
|
||||
f"A single string input was passed in, but this chain expects "
|
||||
f"multiple inputs ({_input_keys}). When a chain expects "
|
||||
f"multiple inputs, please call it by passing in a dictionary, "
|
||||
"eg `chain({'foo': 1, 'bar': 2})`"
|
||||
)
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = self.memory.load_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
self._validate_inputs(inputs)
|
||||
return inputs
|
||||
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Call the chain on all inputs in the list."""
|
||||
return [self(inputs, callbacks=callbacks) for inputs in input_list]
|
||||
|
||||
def run(
|
||||
self,
|
||||
*args: Any,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the chain as text in, text out or multiple variables, text out."""
|
||||
if len(self.output_keys) != 1:
|
||||
raise ValueError(
|
||||
f"`run` not supported when there is not exactly "
|
||||
f"one output key. Got {self.output_keys}."
|
||||
)
|
||||
|
||||
if args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError(
|
||||
"`run` supports only one positional argument."
|
||||
)
|
||||
return self(args[0], callbacks=callbacks, tags=tags)[
|
||||
self.output_keys[0]
|
||||
]
|
||||
|
||||
if kwargs and not args:
|
||||
return self(kwargs, callbacks=callbacks, tags=tags)[
|
||||
self.output_keys[0]
|
||||
]
|
||||
|
||||
if not kwargs and not args:
|
||||
raise ValueError(
|
||||
"`run` supported with either positional arguments or keyword arguments,"
|
||||
" but none were provided."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"`run` supported with either positional arguments or keyword arguments"
|
||||
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||
)
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of chain."""
|
||||
if self.memory is not None:
|
||||
raise ValueError("Saving of memory is not yet supported.")
|
||||
_dict = super().dict()
|
||||
_dict["_type"] = self._chain_type
|
||||
return _dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Save the chain.
|
||||
|
||||
Args:
|
||||
file_path: Path to file to save the chain to.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
chain.save(file_path="path/chain.yaml")
|
||||
"""
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fetch dictionary to save
|
||||
chain_dict = self.dict()
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(chain_dict, f, indent=4)
|
||||
elif save_path.suffix == ".yaml":
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(chain_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
|
||||
class BaseCombineDocumentsChain(Chain, ABC):
|
||||
"""Base interface for chains combining documents."""
|
||||
|
||||
input_key: str = "input_documents" #: :meta private:
|
||||
output_key: str = "output_text" #: :meta private:
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def prompt_length(
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Optional[int]:
|
||||
"""Return the prompt length given the documents passed in.
|
||||
|
||||
Returns None if the method does not depend on the prompt length.
|
||||
"""
|
||||
return None
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, List[Document]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = (
|
||||
run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
)
|
||||
docs = inputs[self.input_key]
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
|
||||
doc_strings = [
|
||||
format_document(doc, self.document_prompt) for doc in docs
|
||||
]
|
||||
# Join the documents together to put them in the prompt.
|
||||
inputs = {
|
||||
k: v
|
||||
for k, v in other_keys.items()
|
||||
if k in self.llm_chain.prompt.input_variables
|
||||
}
|
||||
inputs[self.document_variable_name] = self.document_separator.join(
|
||||
doc_strings
|
||||
)
|
||||
|
||||
# Call predict on the LLM.
|
||||
output, extra_return_dict = (
|
||||
self.llm_chain(inputs, callbacks=_run_manager.get_child())[
|
||||
self.llm_chain.output_key
|
||||
],
|
||||
{},
|
||||
)
|
||||
|
||||
extra_return_dict[self.output_key] = output
|
||||
return extra_return_dict
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Generation(Serializable):
|
||||
"""Output of a single generation."""
|
||||
|
||||
text: str
|
||||
"""Generated text output."""
|
||||
|
||||
generation_info: Optional[Dict[str, Any]] = None
|
||||
"""Raw generation info response from the provider"""
|
||||
"""May include things like reason for finishing (e.g. in OpenAI)"""
|
||||
# TODO: add log probs
|
||||
|
||||
|
||||
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
||||
|
||||
|
||||
class LLMChain(Chain):
|
||||
"""Chain to run queries against LLMs.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import LLMChain, OpenAI, PromptTemplate
|
||||
prompt_template = "Tell me a {adjective} joke"
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["adjective"], template=prompt_template
|
||||
)
|
||||
llm = LLMChain(llm=OpenAI(), prompt=prompt)
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
prompt: BasePromptTemplate
|
||||
"""Prompt object to use."""
|
||||
llm: BaseLanguageModel
|
||||
output_key: str = "text" #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Will be whatever keys the prompt expects.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return self.prompt.input_variables
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Will always return text key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
prompts, stop = self.prep_prompts([inputs], run_manager=run_manager)
|
||||
response = self.llm.generate_prompt(
|
||||
prompts,
|
||||
stop,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
return self.create_outputs(response)[0]
|
||||
|
||||
def prep_prompts(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Tuple[List[PromptValue], Optional[List[str]]]:
|
||||
"""Prepare prompts from inputs."""
|
||||
stop = None
|
||||
if "stop" in input_list[0]:
|
||||
stop = input_list[0]["stop"]
|
||||
prompts = []
|
||||
for inputs in input_list:
|
||||
selected_inputs = {
|
||||
k: inputs[k] for k in self.prompt.input_variables
|
||||
}
|
||||
prompt = self.prompt.format_prompt(**selected_inputs)
|
||||
_colored_text = get_colored_text(prompt.to_string(), "green")
|
||||
_text = "Prompt after formatting:\n" + _colored_text
|
||||
if run_manager:
|
||||
run_manager.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if "stop" in inputs and inputs["stop"] != stop:
|
||||
raise ValueError(
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
)
|
||||
prompts.append(prompt)
|
||||
return prompts, stop
|
||||
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Utilize the LLM generate method for speed gains."""
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
{"input_list": input_list},
|
||||
)
|
||||
try:
|
||||
response = self.generate(input_list, run_manager=run_manager)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
outputs = self.create_outputs(response)
|
||||
run_manager.on_chain_end({"outputs": outputs})
|
||||
return outputs
|
||||
|
||||
def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
|
||||
"""Create outputs from response."""
|
||||
return [
|
||||
# Get the text of the top generated string.
|
||||
{self.output_key: generation[0].text}
|
||||
for generation in response.generations
|
||||
]
|
||||
|
||||
def predict_and_parse(
|
||||
self, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Union[str, List[str], Dict[str, Any]]:
|
||||
"""Call predict and then parse the results."""
|
||||
result = self.predict(callbacks=callbacks, **kwargs)
|
||||
if self.prompt.output_parser is not None:
|
||||
return self.prompt.output_parser.parse(result)
|
||||
else:
|
||||
return result
|
||||
|
||||
def apply_and_parse(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
"""Call apply and then parse the results."""
|
||||
result = self.apply(input_list, callbacks=callbacks)
|
||||
return self._parse_result(result)
|
||||
|
||||
def _parse_result(
|
||||
self, result: List[Dict[str, str]]
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
if self.prompt.output_parser is not None:
|
||||
return [
|
||||
self.prompt.output_parser.parse(res[self.output_key])
|
||||
for res in result
|
||||
]
|
||||
else:
|
||||
return result
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_chain"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, llm: BaseLanguageModel, template: str) -> LLMChain:
|
||||
"""Create LLMChain from LLM and template."""
|
||||
prompt_template = PromptTemplate.from_template(template)
|
||||
return cls(llm=llm, prompt=prompt_template)
|
||||
|
||||
|
||||
def _get_default_document_prompt() -> PromptTemplate:
|
||||
return PromptTemplate(
|
||||
input_variables=["page_content"], template="{page_content}"
|
||||
)
|
||||
|
||||
|
||||
class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
"""Chain that combines documents by stuffing into context."""
|
||||
|
||||
llm_chain: LLMChain
|
||||
"""LLM wrapper to use after formatting documents."""
|
||||
document_prompt: BasePromptTemplate = Field(
|
||||
default_factory=_get_default_document_prompt
|
||||
)
|
||||
"""Prompt to use to format each document."""
|
||||
document_variable_name: str
|
||||
"""The variable name in the llm_chain to put the documents in.
|
||||
If only one variable in the llm_chain, this need not be provided."""
|
||||
document_separator: str = "\n\n"
|
||||
"""The string with which to join the formatted documents"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def get_default_document_variable_name(cls, values: Dict) -> Dict:
|
||||
"""Get default document variable name, if not provided."""
|
||||
llm_chain_variables = values["llm_chain"].prompt.input_variables
|
||||
if "document_variable_name" not in values:
|
||||
if len(llm_chain_variables) == 1:
|
||||
values["document_variable_name"] = llm_chain_variables[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"document_variable_name must be provided if there are "
|
||||
"multiple llm_chain_variables"
|
||||
)
|
||||
else:
|
||||
if values["document_variable_name"] not in llm_chain_variables:
|
||||
raise ValueError(
|
||||
f"document_variable_name {values['document_variable_name']} was "
|
||||
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
||||
)
|
||||
return values
|
||||
|
||||
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
|
||||
# Format each document according to the prompt
|
||||
doc_strings = [
|
||||
format_document(doc, self.document_prompt) for doc in docs
|
||||
]
|
||||
# Join the documents together to put them in the prompt.
|
||||
inputs = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k in self.llm_chain.prompt.input_variables
|
||||
}
|
||||
inputs[self.document_variable_name] = self.document_separator.join(
|
||||
doc_strings
|
||||
)
|
||||
return inputs
|
||||
|
||||
def prompt_length(
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Optional[int]:
|
||||
"""Get the prompt length by formatting the prompt."""
|
||||
inputs = self._get_inputs(docs, **kwargs)
|
||||
prompt = self.llm_chain.prompt.format(**inputs)
|
||||
return self.llm_chain.llm.get_num_tokens(prompt)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "stuff_documents_chain"
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
"""Interface for loading the combine documents chain."""
|
||||
|
||||
def __call__(
|
||||
self, llm: BaseLanguageModel, **kwargs: Any
|
||||
) -> BaseCombineDocumentsChain:
|
||||
"""Callable to load the combine documents chain."""
|
||||
|
||||
|
||||
def _load_stuff_chain(
|
||||
llm: BaseLanguageModel,
|
||||
prompt: Optional[BasePromptTemplate] = None,
|
||||
document_variable_name: str = "context",
|
||||
verbose: Optional[bool] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> StuffDocumentsChain:
|
||||
_prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=_prompt,
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
# TODO: document prompt
|
||||
return StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_variable_name=document_variable_name,
|
||||
verbose=verbose,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def load_qa_chain(
|
||||
llm: BaseLanguageModel,
|
||||
chain_type: str = "stuff",
|
||||
verbose: Optional[bool] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseCombineDocumentsChain:
|
||||
"""Load question answering chain.
|
||||
|
||||
Args:
|
||||
llm: Language Model to use in the chain.
|
||||
chain_type: Type of document combining chain to use. Should be one of "stuff",
|
||||
"map_reduce", "map_rerank", and "refine".
|
||||
verbose: Whether chains should be run in verbose mode or not. Note that this
|
||||
applies to all chains that make up the final chain.
|
||||
callback_manager: Callback manager to use for the chain.
|
||||
|
||||
Returns:
|
||||
A chain to use for question answering.
|
||||
"""
|
||||
loader_mapping: Mapping[str, LoadingCallable] = {
|
||||
"stuff": _load_stuff_chain,
|
||||
}
|
||||
if chain_type not in loader_mapping:
|
||||
raise ValueError(
|
||||
f"Got unsupported chain type: {chain_type}. "
|
||||
f"Should be one of {loader_mapping.keys()}"
|
||||
)
|
||||
return loader_mapping[chain_type](
|
||||
llm, verbose=verbose, callback_manager=callback_manager, **kwargs
|
||||
)
|
||||
1945
apps/language_models/langchain/gen.py
Normal file
1945
apps/language_models/langchain/gen.py
Normal file
File diff suppressed because it is too large
Load Diff
380
apps/language_models/langchain/gpt4all_llm.py
Normal file
380
apps/language_models/langchain/gpt4all_llm.py
Normal file
@@ -0,0 +1,380 @@
|
||||
import inspect
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Dict, Any, Optional, List
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from pydantic import root_validator
|
||||
from langchain.llms import gpt4all
|
||||
from dotenv import dotenv_values
|
||||
|
||||
from utils import FakeTokenizer
|
||||
|
||||
|
||||
def get_model_tokenizer_gpt4all(base_model, **kwargs):
|
||||
# defaults (some of these are generation parameters, so need to be passed in at generation time)
|
||||
model_kwargs = dict(
|
||||
n_threads=os.cpu_count() // 2,
|
||||
temp=kwargs.get("temperature", 0.2),
|
||||
top_p=kwargs.get("top_p", 0.75),
|
||||
top_k=kwargs.get("top_k", 40),
|
||||
n_ctx=2048 - 256,
|
||||
)
|
||||
env_gpt4all_file = ".env_gpt4all"
|
||||
model_kwargs.update(dotenv_values(env_gpt4all_file))
|
||||
# make int or float if can to satisfy types for class
|
||||
for k, v in model_kwargs.items():
|
||||
try:
|
||||
if float(v) == int(v):
|
||||
model_kwargs[k] = int(v)
|
||||
else:
|
||||
model_kwargs[k] = float(v)
|
||||
except:
|
||||
pass
|
||||
|
||||
if base_model == "llama":
|
||||
if "model_path_llama" not in model_kwargs:
|
||||
raise ValueError("No model_path_llama in %s" % env_gpt4all_file)
|
||||
model_path = model_kwargs.pop("model_path_llama")
|
||||
# FIXME: GPT4All version of llama doesn't handle new quantization, so use llama_cpp_python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# llama sets some things at init model time, not generation time
|
||||
func_names = list(inspect.signature(Llama.__init__).parameters)
|
||||
model_kwargs = {
|
||||
k: v for k, v in model_kwargs.items() if k in func_names
|
||||
}
|
||||
model_kwargs["n_ctx"] = int(model_kwargs["n_ctx"])
|
||||
model = Llama(model_path=model_path, **model_kwargs)
|
||||
elif base_model in "gpt4all_llama":
|
||||
if (
|
||||
"model_name_gpt4all_llama" not in model_kwargs
|
||||
and "model_path_gpt4all_llama" not in model_kwargs
|
||||
):
|
||||
raise ValueError(
|
||||
"No model_name_gpt4all_llama or model_path_gpt4all_llama in %s"
|
||||
% env_gpt4all_file
|
||||
)
|
||||
model_name = model_kwargs.pop("model_name_gpt4all_llama")
|
||||
model_type = "llama"
|
||||
from gpt4all import GPT4All as GPT4AllModel
|
||||
|
||||
model = GPT4AllModel(model_name=model_name, model_type=model_type)
|
||||
elif base_model in "gptj":
|
||||
if (
|
||||
"model_name_gptj" not in model_kwargs
|
||||
and "model_path_gptj" not in model_kwargs
|
||||
):
|
||||
raise ValueError(
|
||||
"No model_name_gpt4j or model_path_gpt4j in %s"
|
||||
% env_gpt4all_file
|
||||
)
|
||||
model_name = model_kwargs.pop("model_name_gptj")
|
||||
model_type = "gptj"
|
||||
from gpt4all import GPT4All as GPT4AllModel
|
||||
|
||||
model = GPT4AllModel(model_name=model_name, model_type=model_type)
|
||||
else:
|
||||
raise ValueError("No such base_model %s" % base_model)
|
||||
return model, FakeTokenizer(), "cpu"
|
||||
|
||||
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
|
||||
class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
# streaming to std already occurs without this
|
||||
# sys.stdout.write(token)
|
||||
# sys.stdout.flush()
|
||||
pass
|
||||
|
||||
|
||||
def get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=[]):
|
||||
# default from class
|
||||
model_kwargs = {
|
||||
k: v.default
|
||||
for k, v in dict(inspect.signature(cls).parameters).items()
|
||||
if k not in exclude_list
|
||||
}
|
||||
# from our defaults
|
||||
model_kwargs.update(default_kwargs)
|
||||
# from user defaults
|
||||
model_kwargs.update(env_kwargs)
|
||||
# ensure only valid keys
|
||||
func_names = list(inspect.signature(cls).parameters)
|
||||
model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
|
||||
return model_kwargs
|
||||
|
||||
|
||||
def get_llm_gpt4all(
|
||||
model_name,
|
||||
model=None,
|
||||
max_new_tokens=256,
|
||||
temperature=0.1,
|
||||
repetition_penalty=1.0,
|
||||
top_k=40,
|
||||
top_p=0.7,
|
||||
streaming=False,
|
||||
callbacks=None,
|
||||
prompter=None,
|
||||
verbose=False,
|
||||
):
|
||||
assert prompter is not None
|
||||
env_gpt4all_file = ".env_gpt4all"
|
||||
env_kwargs = dotenv_values(env_gpt4all_file)
|
||||
n_ctx = env_kwargs.pop("n_ctx", 2048 - max_new_tokens)
|
||||
default_kwargs = dict(
|
||||
context_erase=0.5,
|
||||
n_batch=1,
|
||||
n_ctx=n_ctx,
|
||||
n_predict=max_new_tokens,
|
||||
repeat_last_n=64 if repetition_penalty != 1.0 else 0,
|
||||
repeat_penalty=repetition_penalty,
|
||||
temp=temperature,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
use_mlock=True,
|
||||
verbose=verbose,
|
||||
)
|
||||
if model_name == "llama":
|
||||
cls = H2OLlamaCpp
|
||||
model_path = (
|
||||
env_kwargs.pop("model_path_llama") if model is None else model
|
||||
)
|
||||
model_kwargs = get_model_kwargs(
|
||||
env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"]
|
||||
)
|
||||
model_kwargs.update(
|
||||
dict(
|
||||
model_path=model_path,
|
||||
callbacks=callbacks,
|
||||
streaming=streaming,
|
||||
prompter=prompter,
|
||||
)
|
||||
)
|
||||
llm = cls(**model_kwargs)
|
||||
llm.client.verbose = verbose
|
||||
elif model_name == "gpt4all_llama":
|
||||
cls = H2OGPT4All
|
||||
model_path = (
|
||||
env_kwargs.pop("model_path_gpt4all_llama")
|
||||
if model is None
|
||||
else model
|
||||
)
|
||||
model_kwargs = get_model_kwargs(
|
||||
env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"]
|
||||
)
|
||||
model_kwargs.update(
|
||||
dict(
|
||||
model=model_path,
|
||||
backend="llama",
|
||||
callbacks=callbacks,
|
||||
streaming=streaming,
|
||||
prompter=prompter,
|
||||
)
|
||||
)
|
||||
llm = cls(**model_kwargs)
|
||||
elif model_name == "gptj":
|
||||
cls = H2OGPT4All
|
||||
model_path = (
|
||||
env_kwargs.pop("model_path_gptj") if model is None else model
|
||||
)
|
||||
model_kwargs = get_model_kwargs(
|
||||
env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"]
|
||||
)
|
||||
model_kwargs.update(
|
||||
dict(
|
||||
model=model_path,
|
||||
backend="gptj",
|
||||
callbacks=callbacks,
|
||||
streaming=streaming,
|
||||
prompter=prompter,
|
||||
)
|
||||
)
|
||||
llm = cls(**model_kwargs)
|
||||
else:
|
||||
raise RuntimeError("No such model_name %s" % model_name)
|
||||
return llm
|
||||
|
||||
|
||||
class H2OGPT4All(gpt4all.GPT4All):
|
||||
model: Any
|
||||
prompter: Any
|
||||
"""Path to the pre-trained GPT4All model file."""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that the python package exists in the environment."""
|
||||
try:
|
||||
if isinstance(values["model"], str):
|
||||
from gpt4all import GPT4All as GPT4AllModel
|
||||
|
||||
full_path = values["model"]
|
||||
model_path, delimiter, model_name = full_path.rpartition("/")
|
||||
model_path += delimiter
|
||||
|
||||
values["client"] = GPT4AllModel(
|
||||
model_name=model_name,
|
||||
model_path=model_path or None,
|
||||
model_type=values["backend"],
|
||||
allow_download=False,
|
||||
)
|
||||
if values["n_threads"] is not None:
|
||||
# set n_threads
|
||||
values["client"].model.set_thread_count(
|
||||
values["n_threads"]
|
||||
)
|
||||
else:
|
||||
values["client"] = values["model"]
|
||||
try:
|
||||
values["backend"] = values["client"].model_type
|
||||
except AttributeError:
|
||||
# The below is for compatibility with GPT4All Python bindings <= 0.2.3.
|
||||
values["backend"] = values["client"].model.model_type
|
||||
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import gpt4all python package. "
|
||||
"Please install it with `pip install gpt4all`."
|
||||
)
|
||||
return values
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
# Roughly 4 chars per token if natural language
|
||||
prompt = prompt[-self.n_ctx * 4 :]
|
||||
|
||||
# use instruct prompting
|
||||
data_point = dict(context="", instruction=prompt, input="")
|
||||
prompt = self.prompter.generate_prompt(data_point)
|
||||
|
||||
verbose = False
|
||||
if verbose:
|
||||
print("_call prompt: %s" % prompt, flush=True)
|
||||
# FIXME: GPT4ALl doesn't support yield during generate, so cannot support streaming except via itself to stdout
|
||||
return super()._call(prompt, stop=stop, run_manager=run_manager)
|
||||
|
||||
|
||||
from langchain.llms import LlamaCpp
|
||||
|
||||
|
||||
class H2OLlamaCpp(LlamaCpp):
|
||||
model_path: Any
|
||||
prompter: Any
|
||||
"""Path to the pre-trained GPT4All model file."""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that llama-cpp-python library is installed."""
|
||||
if isinstance(values["model_path"], str):
|
||||
model_path = values["model_path"]
|
||||
model_param_names = [
|
||||
"lora_path",
|
||||
"lora_base",
|
||||
"n_ctx",
|
||||
"n_parts",
|
||||
"seed",
|
||||
"f16_kv",
|
||||
"logits_all",
|
||||
"vocab_only",
|
||||
"use_mlock",
|
||||
"n_threads",
|
||||
"n_batch",
|
||||
"use_mmap",
|
||||
"last_n_tokens_size",
|
||||
]
|
||||
model_params = {k: values[k] for k in model_param_names}
|
||||
# For backwards compatibility, only include if non-null.
|
||||
if values["n_gpu_layers"] is not None:
|
||||
model_params["n_gpu_layers"] = values["n_gpu_layers"]
|
||||
|
||||
try:
|
||||
from llama_cpp import Llama
|
||||
|
||||
values["client"] = Llama(model_path, **model_params)
|
||||
except ImportError:
|
||||
raise ModuleNotFoundError(
|
||||
"Could not import llama-cpp-python library. "
|
||||
"Please install the llama-cpp-python library to "
|
||||
"use this embedding model: pip install llama-cpp-python"
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Could not load Llama model from path: {model_path}. "
|
||||
f"Received error {e}"
|
||||
)
|
||||
else:
|
||||
values["client"] = values["model_path"]
|
||||
return values
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
verbose = False
|
||||
# tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
|
||||
# still have to avoid crazy sizes, else hit llama_tokenize: too many tokens -- might still hit, not fatal
|
||||
prompt = prompt[-self.n_ctx * 4 :]
|
||||
prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
|
||||
num_prompt_tokens = len(prompt_tokens)
|
||||
if num_prompt_tokens > self.n_ctx:
|
||||
# conservative by using int()
|
||||
chars_per_token = int(len(prompt) / num_prompt_tokens)
|
||||
prompt = prompt[-self.n_ctx * chars_per_token :]
|
||||
if verbose:
|
||||
print(
|
||||
"reducing tokens, assuming average of %s chars/token: %s"
|
||||
% chars_per_token,
|
||||
flush=True,
|
||||
)
|
||||
prompt_tokens2 = self.client.tokenize(
|
||||
b" " + prompt.encode("utf-8")
|
||||
)
|
||||
num_prompt_tokens2 = len(prompt_tokens2)
|
||||
print(
|
||||
"reduced tokens from %d -> %d"
|
||||
% (num_prompt_tokens, num_prompt_tokens2),
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# use instruct prompting
|
||||
data_point = dict(context="", instruction=prompt, input="")
|
||||
prompt = self.prompter.generate_prompt(data_point)
|
||||
|
||||
if verbose:
|
||||
print("_call prompt: %s" % prompt, flush=True)
|
||||
|
||||
if self.streaming:
|
||||
text_callback = None
|
||||
if run_manager:
|
||||
text_callback = partial(
|
||||
run_manager.on_llm_new_token, verbose=self.verbose
|
||||
)
|
||||
# parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
|
||||
if text_callback:
|
||||
text_callback(prompt)
|
||||
text = ""
|
||||
for token in self.stream(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager
|
||||
):
|
||||
text_chunk = token["choices"][0]["text"]
|
||||
# self.stream already calls text_callback
|
||||
# if text_callback:
|
||||
# text_callback(text_chunk)
|
||||
text += text_chunk
|
||||
return text
|
||||
else:
|
||||
params = self._get_parameters(stop)
|
||||
params = {**params, **kwargs}
|
||||
result = self.client(prompt=prompt, **params)
|
||||
return result["choices"][0]["text"]
|
||||
3137
apps/language_models/langchain/gpt_langchain.py
Normal file
3137
apps/language_models/langchain/gpt_langchain.py
Normal file
File diff suppressed because it is too large
Load Diff
93
apps/language_models/langchain/gradio_utils/grclient.py
Normal file
93
apps/language_models/langchain/gradio_utils/grclient.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import traceback
|
||||
from typing import Callable
|
||||
import os
|
||||
|
||||
from gradio_client.client import Job
|
||||
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
|
||||
from gradio_client import Client
|
||||
|
||||
|
||||
class GradioClient(Client):
|
||||
"""
|
||||
Parent class of gradio client
|
||||
To handle automatically refreshing client if detect gradio server changed
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
super().__init__(*args, **kwargs)
|
||||
self.server_hash = self.get_server_hash()
|
||||
|
||||
def get_server_hash(self):
|
||||
"""
|
||||
Get server hash using super without any refresh action triggered
|
||||
Returns: git hash of gradio server
|
||||
"""
|
||||
return super().submit(api_name="/system_hash").result()
|
||||
|
||||
def refresh_client_if_should(self):
|
||||
# get current hash in order to update api_name -> fn_index map in case gradio server changed
|
||||
# FIXME: Could add cli api as hash
|
||||
server_hash = self.get_server_hash()
|
||||
if self.server_hash != server_hash:
|
||||
self.refresh_client()
|
||||
self.server_hash = server_hash
|
||||
else:
|
||||
self.reset_session()
|
||||
|
||||
def refresh_client(self):
|
||||
"""
|
||||
Ensure every client call is independent
|
||||
Also ensure map between api_name and fn_index is updated in case server changed (e.g. restarted with new code)
|
||||
Returns:
|
||||
"""
|
||||
# need session hash to be new every time, to avoid "generator already executing"
|
||||
self.reset_session()
|
||||
|
||||
client = Client(*self.args, **self.kwargs)
|
||||
for k, v in client.__dict__.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
def submit(
|
||||
self,
|
||||
*args,
|
||||
api_name: str | None = None,
|
||||
fn_index: int | None = None,
|
||||
result_callbacks: Callable | list[Callable] | None = None,
|
||||
) -> Job:
|
||||
# Note predict calls submit
|
||||
try:
|
||||
self.refresh_client_if_should()
|
||||
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
|
||||
except Exception as e:
|
||||
print("Hit e=%s" % str(e), flush=True)
|
||||
# force reconfig in case only that
|
||||
self.refresh_client()
|
||||
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
|
||||
|
||||
# see if immediately failed
|
||||
e = job.future._exception
|
||||
if e is not None:
|
||||
print(
|
||||
"GR job failed: %s %s"
|
||||
% (str(e), "".join(traceback.format_tb(e.__traceback__))),
|
||||
flush=True,
|
||||
)
|
||||
# force reconfig in case only that
|
||||
self.refresh_client()
|
||||
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
|
||||
e2 = job.future._exception
|
||||
if e2 is not None:
|
||||
print(
|
||||
"GR job failed again: %s\n%s"
|
||||
% (
|
||||
str(e2),
|
||||
"".join(traceback.format_tb(e2.__traceback__)),
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
|
||||
return job
|
||||
765
apps/language_models/langchain/h2oai_pipeline.py
Normal file
765
apps/language_models/langchain/h2oai_pipeline.py
Normal file
@@ -0,0 +1,765 @@
|
||||
import os
|
||||
from apps.stable_diffusion.src.utils.utils import _compile_module
|
||||
from io import BytesIO
|
||||
import torch_mlir
|
||||
|
||||
from stopping import get_stopping
|
||||
from prompter import Prompter, PromptType
|
||||
|
||||
from transformers import TextGenerationPipeline
|
||||
from transformers.pipelines.text_generation import ReturnType
|
||||
from transformers.generation import (
|
||||
GenerationConfig,
|
||||
LogitsProcessorList,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
import copy
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
import gc
|
||||
from pathlib import Path
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
# Brevitas
|
||||
from typing import List, Tuple
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
# fmt: off
|
||||
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
if len(lhs) == 3 and len(rhs) == 2:
|
||||
return [lhs[0], lhs[1], rhs[0]]
|
||||
elif len(lhs) == 2 and len(rhs) == 2:
|
||||
return [lhs[0], rhs[0]]
|
||||
else:
|
||||
raise ValueError("Input shapes not supported.")
|
||||
|
||||
|
||||
def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
|
||||
# output dtype is the dtype of the lhs float input
|
||||
lhs_rank, lhs_dtype = lhs_rank_dtype
|
||||
return lhs_dtype
|
||||
|
||||
|
||||
def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
|
||||
return
|
||||
|
||||
|
||||
brevitas_matmul_rhs_group_quant_library = [
|
||||
quant〇matmul_rhs_group_quant〡shape,
|
||||
quant〇matmul_rhs_group_quant〡dtype,
|
||||
quant〇matmul_rhs_group_quant〡has_value_semantics]
|
||||
# fmt: on
|
||||
|
||||
global_device = "cuda"
|
||||
global_precision = "fp16"
|
||||
|
||||
if not args.run_docuchat_web:
|
||||
args.device = global_device
|
||||
args.precision = global_precision
|
||||
tensor_device = "cpu" if args.device == "cpu" else "cuda"
|
||||
|
||||
|
||||
class H2OGPTModel(torch.nn.Module):
|
||||
def __init__(self, device, precision):
|
||||
super().__init__()
|
||||
torch_dtype = (
|
||||
torch.float32
|
||||
if precision == "fp32" or device == "cpu"
|
||||
else torch.float16
|
||||
)
|
||||
device_map = {"": "cpu"} if device == "cpu" else {"": 0}
|
||||
model_kwargs = {
|
||||
"local_files_only": False,
|
||||
"torch_dtype": torch_dtype,
|
||||
"resume_download": True,
|
||||
"use_auth_token": False,
|
||||
"trust_remote_code": True,
|
||||
"offload_folder": "offline_folder",
|
||||
"device_map": device_map,
|
||||
}
|
||||
config = AutoConfig.from_pretrained(
|
||||
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
use_auth_token=False,
|
||||
trust_remote_code=True,
|
||||
offload_folder="offline_folder",
|
||||
)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
config=config,
|
||||
**model_kwargs,
|
||||
)
|
||||
if precision in ["int4", "int8"]:
|
||||
print("Applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
self.model.transformer.h,
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=128,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
input_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": None,
|
||||
"use_cache": True,
|
||||
}
|
||||
output = self.model(
|
||||
**input_dict,
|
||||
return_dict=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
return output.logits[:, -1, :]
|
||||
|
||||
|
||||
class H2OGPTSHARKModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
model_name = "h2ogpt_falcon_7b"
|
||||
extended_model_name = (
|
||||
model_name + "_" + args.precision + "_" + args.device
|
||||
)
|
||||
vmfb_path = Path(extended_model_name + ".vmfb")
|
||||
mlir_path = Path(model_name + "_" + args.precision + ".mlir")
|
||||
shark_module = None
|
||||
|
||||
need_to_compile = False
|
||||
if not vmfb_path.exists():
|
||||
need_to_compile = True
|
||||
# Downloading VMFB from shark_tank
|
||||
print("Trying to download pre-compiled vmfb from shark tank.")
|
||||
download_public_file(
|
||||
"gs://shark_tank/langchain/" + str(vmfb_path),
|
||||
vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if vmfb_path.exists():
|
||||
print(
|
||||
"Pre-compiled vmfb downloaded from shark tank successfully."
|
||||
)
|
||||
need_to_compile = False
|
||||
|
||||
if need_to_compile:
|
||||
if not mlir_path.exists():
|
||||
print("Trying to download pre-generated mlir from shark tank.")
|
||||
# Downloading MLIR from shark_tank
|
||||
download_public_file(
|
||||
"gs://shark_tank/langchain/" + str(mlir_path),
|
||||
mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
# Generating the mlir
|
||||
bytecode = self.get_bytecode(tensor_device, args.precision)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
print(f"[DEBUG] generating vmfb.")
|
||||
shark_module = _compile_module(
|
||||
shark_module, extended_model_name, []
|
||||
)
|
||||
print("Saved newly generated vmfb.")
|
||||
|
||||
if shark_module is None:
|
||||
if vmfb_path.exists():
|
||||
print("Compiled vmfb found. Loading it from: ", vmfb_path)
|
||||
shark_module = SharkInference(
|
||||
None, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.load_module(str(vmfb_path))
|
||||
print("Compiled vmfb loaded successfully.")
|
||||
else:
|
||||
raise ValueError("Unable to download/generate a vmfb.")
|
||||
|
||||
self.model = shark_module
|
||||
|
||||
def get_bytecode(self, device, precision):
|
||||
h2ogpt_model = H2OGPTModel(device, precision)
|
||||
|
||||
compilation_input_ids = torch.randint(
|
||||
low=1, high=10000, size=(1, 400)
|
||||
).to(device=device)
|
||||
compilation_attention_mask = torch.ones(1, 400, dtype=torch.int64).to(
|
||||
device=device
|
||||
)
|
||||
|
||||
h2ogptCompileInput = (
|
||||
compilation_input_ids,
|
||||
compilation_attention_mask,
|
||||
)
|
||||
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = import_with_fx(
|
||||
h2ogpt_model,
|
||||
h2ogptCompileInput,
|
||||
is_f16=False,
|
||||
precision=precision,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
del h2ogpt_model
|
||||
del self.src_model
|
||||
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
if precision in ["int4", "int8"]:
|
||||
from torch_mlir.compiler_utils import (
|
||||
run_pipeline_with_repro_report,
|
||||
)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*h2ogptCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
module,
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*h2ogptCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
|
||||
print(f"[DEBUG] converting to bytecode")
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
bytecode = save_mlir(
|
||||
bytecode,
|
||||
model_name=f"h2ogpt_{precision}",
|
||||
frontend="torch",
|
||||
)
|
||||
return bytecode
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
result = torch.from_numpy(
|
||||
self.model(
|
||||
"forward",
|
||||
(input_ids.to(device="cpu"), attention_mask.to(device="cpu")),
|
||||
)
|
||||
).to(device=tensor_device)
|
||||
return result
|
||||
|
||||
|
||||
def decode_tokens(tokenizer, res_tokens):
|
||||
for i in range(len(res_tokens)):
|
||||
if type(res_tokens[i]) != int:
|
||||
res_tokens[i] = int(res_tokens[i][0])
|
||||
|
||||
res_str = tokenizer.decode(res_tokens, skip_special_tokens=True)
|
||||
return res_str
|
||||
|
||||
|
||||
def generate_token(h2ogpt_shark_model, model, tokenizer, **generate_kwargs):
|
||||
del generate_kwargs["max_time"]
|
||||
generate_kwargs["input_ids"] = generate_kwargs["input_ids"].to(
|
||||
device=tensor_device
|
||||
)
|
||||
generate_kwargs["attention_mask"] = generate_kwargs["attention_mask"].to(
|
||||
device=tensor_device
|
||||
)
|
||||
truncated_input_ids = []
|
||||
stopping_criteria = generate_kwargs["stopping_criteria"]
|
||||
|
||||
generation_config_ = GenerationConfig.from_model_config(model.config)
|
||||
generation_config = copy.deepcopy(generation_config_)
|
||||
model_kwargs = generation_config.update(**generate_kwargs)
|
||||
|
||||
logits_processor = LogitsProcessorList()
|
||||
stopping_criteria = (
|
||||
stopping_criteria
|
||||
if stopping_criteria is not None
|
||||
else StoppingCriteriaList()
|
||||
)
|
||||
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
(
|
||||
inputs_tensor,
|
||||
model_input_name,
|
||||
model_kwargs,
|
||||
) = model._prepare_model_inputs(
|
||||
None, generation_config.bos_token_id, model_kwargs
|
||||
)
|
||||
|
||||
model_kwargs["output_attentions"] = generation_config.output_attentions
|
||||
model_kwargs[
|
||||
"output_hidden_states"
|
||||
] = generation_config.output_hidden_states
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
|
||||
input_ids = (
|
||||
inputs_tensor
|
||||
if model_input_name == "input_ids"
|
||||
else model_kwargs.pop("input_ids")
|
||||
)
|
||||
|
||||
input_ids_seq_length = input_ids.shape[-1]
|
||||
|
||||
generation_config.max_length = (
|
||||
generation_config.max_new_tokens + input_ids_seq_length
|
||||
)
|
||||
|
||||
logits_processor = model._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids_seq_length,
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
stopping_criteria = model._get_stopping_criteria(
|
||||
generation_config=generation_config,
|
||||
stopping_criteria=stopping_criteria,
|
||||
)
|
||||
|
||||
logits_warper = model._get_logits_warper(generation_config)
|
||||
|
||||
(
|
||||
input_ids,
|
||||
model_kwargs,
|
||||
) = model._expand_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
expand_size=generation_config.num_return_sequences, # 1
|
||||
is_encoder_decoder=model.config.is_encoder_decoder, # False
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id_tensor = (
|
||||
torch.tensor(eos_token_id).to(device=tensor_device)
|
||||
if eos_token_id is not None
|
||||
else None
|
||||
)
|
||||
|
||||
pad_token_id = generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id
|
||||
|
||||
output_scores = generation_config.output_scores # False
|
||||
return_dict_in_generate = (
|
||||
generation_config.return_dict_in_generate # False
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
unfinished_sequences = torch.ones(
|
||||
input_ids.shape[0],
|
||||
dtype=torch.long,
|
||||
device=input_ids.device,
|
||||
)
|
||||
|
||||
timesRan = 0
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
print("\n")
|
||||
|
||||
res_tokens = []
|
||||
while True:
|
||||
model_inputs = model.prepare_inputs_for_generation(
|
||||
input_ids, **model_kwargs
|
||||
)
|
||||
|
||||
outputs = h2ogpt_shark_model.forward(
|
||||
model_inputs["input_ids"], model_inputs["attention_mask"]
|
||||
)
|
||||
|
||||
if args.precision == "fp16":
|
||||
outputs = outputs.to(dtype=torch.float32)
|
||||
next_token_logits = outputs
|
||||
|
||||
# pre-process distribution
|
||||
next_token_scores = logits_processor(input_ids, next_token_logits)
|
||||
next_token_scores = logits_warper(input_ids, next_token_scores)
|
||||
|
||||
# sample
|
||||
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
|
||||
|
||||
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if eos_token_id is not None:
|
||||
if pad_token_id is None:
|
||||
raise ValueError(
|
||||
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
|
||||
)
|
||||
next_token = next_token * unfinished_sequences + pad_token_id * (
|
||||
1 - unfinished_sequences
|
||||
)
|
||||
|
||||
input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1)
|
||||
|
||||
model_kwargs["past_key_values"] = None
|
||||
if "attention_mask" in model_kwargs:
|
||||
attention_mask = model_kwargs["attention_mask"]
|
||||
model_kwargs["attention_mask"] = torch.cat(
|
||||
[
|
||||
attention_mask,
|
||||
attention_mask.new_ones((attention_mask.shape[0], 1)),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
truncated_input_ids.append(input_ids[:, 0])
|
||||
input_ids = input_ids[:, 1:]
|
||||
model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, 1:]
|
||||
|
||||
new_word = tokenizer.decode(
|
||||
next_token.cpu().numpy(),
|
||||
add_special_tokens=False,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
|
||||
res_tokens.append(next_token)
|
||||
if new_word == "<0x0A>":
|
||||
print("\n", end="", flush=True)
|
||||
else:
|
||||
print(f"{new_word}", end=" ", flush=True)
|
||||
|
||||
part_str = decode_tokens(tokenizer, res_tokens)
|
||||
yield part_str
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id_tensor is not None:
|
||||
unfinished_sequences = unfinished_sequences.mul(
|
||||
next_token.tile(eos_token_id_tensor.shape[0], 1)
|
||||
.ne(eos_token_id_tensor.unsqueeze(1))
|
||||
.prod(dim=0)
|
||||
)
|
||||
# stop when each sentence is finished
|
||||
if unfinished_sequences.max() == 0 or stopping_criteria(
|
||||
input_ids, scores
|
||||
):
|
||||
break
|
||||
timesRan = timesRan + 1
|
||||
|
||||
end = time.time()
|
||||
print(
|
||||
"\n\nTime taken is {:.2f} seconds/token\n".format(
|
||||
(end - start) / timesRan
|
||||
)
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
res_str = decode_tokens(tokenizer, res_tokens)
|
||||
yield res_str
|
||||
|
||||
|
||||
def pad_or_truncate_inputs(
|
||||
input_ids, attention_mask, max_padding_length=400, do_truncation=False
|
||||
):
|
||||
inp_shape = input_ids.shape
|
||||
if inp_shape[1] < max_padding_length:
|
||||
# do padding
|
||||
num_add_token = max_padding_length - inp_shape[1]
|
||||
padded_input_ids = torch.cat(
|
||||
[
|
||||
torch.tensor([[11] * num_add_token]).to(device=tensor_device),
|
||||
input_ids,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
padded_attention_mask = torch.cat(
|
||||
[
|
||||
torch.tensor([[0] * num_add_token]).to(device=tensor_device),
|
||||
attention_mask,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
return padded_input_ids, padded_attention_mask
|
||||
elif inp_shape[1] > max_padding_length or do_truncation:
|
||||
# do truncation
|
||||
num_remove_token = inp_shape[1] - max_padding_length
|
||||
truncated_input_ids = input_ids[:, num_remove_token:]
|
||||
truncated_attention_mask = attention_mask[:, num_remove_token:]
|
||||
return truncated_input_ids, truncated_attention_mask
|
||||
else:
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
||||
class H2OTextGenerationPipeline(TextGenerationPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
debug=False,
|
||||
chat=False,
|
||||
stream_output=False,
|
||||
sanitize_bot_response=False,
|
||||
use_prompter=True,
|
||||
prompter=None,
|
||||
prompt_type=None,
|
||||
prompt_dict=None,
|
||||
max_input_tokens=2048 - 256,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
HF-like pipeline, but handle instruction prompting and stopping (for some models)
|
||||
:param args:
|
||||
:param debug:
|
||||
:param chat:
|
||||
:param stream_output:
|
||||
:param sanitize_bot_response:
|
||||
:param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter
|
||||
:param prompter: prompter, can pass if have already
|
||||
:param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py.
|
||||
If use_prompter, then will make prompter and use it.
|
||||
:param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom
|
||||
:param max_input_tokens:
|
||||
:param kwargs:
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.prompt_text = None
|
||||
self.use_prompter = use_prompter
|
||||
self.prompt_type = prompt_type
|
||||
self.prompt_dict = prompt_dict
|
||||
self.prompter = prompter
|
||||
if self.use_prompter:
|
||||
if self.prompter is not None:
|
||||
assert self.prompter.prompt_type is not None
|
||||
else:
|
||||
self.prompter = Prompter(
|
||||
self.prompt_type,
|
||||
self.prompt_dict,
|
||||
debug=debug,
|
||||
chat=chat,
|
||||
stream_output=stream_output,
|
||||
)
|
||||
self.human = self.prompter.humanstr
|
||||
self.bot = self.prompter.botstr
|
||||
self.can_stop = True
|
||||
else:
|
||||
self.prompter = None
|
||||
self.human = None
|
||||
self.bot = None
|
||||
self.can_stop = False
|
||||
self.sanitize_bot_response = sanitize_bot_response
|
||||
self.max_input_tokens = (
|
||||
max_input_tokens # not for generate, so ok that not kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
|
||||
verbose = bool(int(os.getenv("VERBOSE_PIPELINE", "0")))
|
||||
|
||||
if hasattr(tokenizer, "model_max_length"):
|
||||
# model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
|
||||
model_max_length = tokenizer.model_max_length
|
||||
if max_prompt_length is not None:
|
||||
model_max_length = min(model_max_length, max_prompt_length)
|
||||
# cut at some upper likely limit to avoid excessive tokenization etc
|
||||
# upper bound of 10 chars/token, e.g. special chars sometimes are long
|
||||
if len(prompt_text) > model_max_length * 10:
|
||||
len0 = len(prompt_text)
|
||||
prompt_text = prompt_text[-model_max_length * 10 :]
|
||||
if verbose:
|
||||
print(
|
||||
"Cut of input: %s -> %s" % (len0, len(prompt_text)),
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
# unknown
|
||||
model_max_length = None
|
||||
|
||||
num_prompt_tokens = None
|
||||
if model_max_length is not None:
|
||||
# can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
|
||||
# For https://github.com/h2oai/h2ogpt/issues/192
|
||||
for trial in range(0, 3):
|
||||
prompt_tokens = tokenizer(prompt_text)["input_ids"]
|
||||
num_prompt_tokens = len(prompt_tokens)
|
||||
if num_prompt_tokens > model_max_length:
|
||||
# conservative by using int()
|
||||
chars_per_token = int(len(prompt_text) / num_prompt_tokens)
|
||||
# keep tail, where question is if using langchain
|
||||
prompt_text = prompt_text[
|
||||
-model_max_length * chars_per_token :
|
||||
]
|
||||
if verbose:
|
||||
print(
|
||||
"reducing %s tokens, assuming average of %s chars/token for %s characters"
|
||||
% (
|
||||
num_prompt_tokens,
|
||||
chars_per_token,
|
||||
len(prompt_text),
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
if verbose:
|
||||
print(
|
||||
"using %s tokens with %s chars"
|
||||
% (num_prompt_tokens, len(prompt_text)),
|
||||
flush=True,
|
||||
)
|
||||
break
|
||||
|
||||
return prompt_text, num_prompt_tokens
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
prompt_text,
|
||||
prefix="",
|
||||
handle_long_generation=None,
|
||||
**generate_kwargs,
|
||||
):
|
||||
(
|
||||
prompt_text,
|
||||
num_prompt_tokens,
|
||||
) = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
|
||||
|
||||
data_point = dict(context="", instruction=prompt_text, input="")
|
||||
if self.prompter is not None:
|
||||
prompt_text = self.prompter.generate_prompt(data_point)
|
||||
self.prompt_text = prompt_text
|
||||
if handle_long_generation is None:
|
||||
# forces truncation of inputs to avoid critical failure
|
||||
handle_long_generation = None # disable with new approaches
|
||||
return super().preprocess(
|
||||
prompt_text,
|
||||
prefix=prefix,
|
||||
handle_long_generation=handle_long_generation,
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
model_outputs,
|
||||
return_type=ReturnType.FULL_TEXT,
|
||||
clean_up_tokenization_spaces=True,
|
||||
):
|
||||
records = super().postprocess(
|
||||
model_outputs,
|
||||
return_type=return_type,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
)
|
||||
for rec in records:
|
||||
if self.use_prompter:
|
||||
outputs = rec["generated_text"]
|
||||
outputs = self.prompter.get_response(
|
||||
outputs,
|
||||
prompt=self.prompt_text,
|
||||
sanitize_bot_response=self.sanitize_bot_response,
|
||||
)
|
||||
elif self.bot and self.human:
|
||||
outputs = (
|
||||
rec["generated_text"]
|
||||
.split(self.bot)[1]
|
||||
.split(self.human)[0]
|
||||
)
|
||||
else:
|
||||
outputs = rec["generated_text"]
|
||||
rec["generated_text"] = outputs
|
||||
print(
|
||||
"prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs),
|
||||
flush=True,
|
||||
)
|
||||
return records
|
||||
|
||||
def _forward(self, model_inputs, **generate_kwargs):
|
||||
if self.can_stop:
|
||||
stopping_criteria = get_stopping(
|
||||
self.prompt_type,
|
||||
self.prompt_dict,
|
||||
self.tokenizer,
|
||||
self.device,
|
||||
human=self.human,
|
||||
bot=self.bot,
|
||||
model_max_length=self.tokenizer.model_max_length,
|
||||
)
|
||||
generate_kwargs["stopping_criteria"] = stopping_criteria
|
||||
# return super()._forward(model_inputs, **generate_kwargs)
|
||||
return self.__forward(model_inputs, **generate_kwargs)
|
||||
|
||||
# FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
|
||||
# FIXME: https://github.com/h2oai/h2ogpt/issues/172
|
||||
def __forward(self, model_inputs, **generate_kwargs):
|
||||
input_ids = model_inputs["input_ids"]
|
||||
attention_mask = model_inputs.get("attention_mask", None)
|
||||
# Allow empty prompts
|
||||
if input_ids.shape[1] == 0:
|
||||
input_ids = None
|
||||
attention_mask = None
|
||||
in_b = 1
|
||||
else:
|
||||
in_b = input_ids.shape[0]
|
||||
prompt_text = model_inputs.pop("prompt_text")
|
||||
|
||||
## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
|
||||
## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
|
||||
# generate_kwargs = copy.deepcopy(generate_kwargs)
|
||||
prefix_length = generate_kwargs.pop("prefix_length", 0)
|
||||
if prefix_length > 0:
|
||||
has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
|
||||
"generation_config" in generate_kwargs
|
||||
and generate_kwargs["generation_config"].max_new_tokens
|
||||
is not None
|
||||
)
|
||||
if not has_max_new_tokens:
|
||||
generate_kwargs["max_length"] = (
|
||||
generate_kwargs.get("max_length")
|
||||
or self.model.config.max_length
|
||||
)
|
||||
generate_kwargs["max_length"] += prefix_length
|
||||
has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
|
||||
"generation_config" in generate_kwargs
|
||||
and generate_kwargs["generation_config"].min_new_tokens
|
||||
is not None
|
||||
)
|
||||
if not has_min_new_tokens and "min_length" in generate_kwargs:
|
||||
generate_kwargs["min_length"] += prefix_length
|
||||
|
||||
# BS x SL
|
||||
# pad or truncate the input_ids and attention_mask
|
||||
max_padding_length = 400
|
||||
input_ids, attention_mask = pad_or_truncate_inputs(
|
||||
input_ids, attention_mask, max_padding_length=max_padding_length
|
||||
)
|
||||
|
||||
return_dict = {
|
||||
"model": self.model,
|
||||
"tokenizer": self.tokenizer,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
return_dict = {**return_dict, **generate_kwargs}
|
||||
return return_dict
|
||||
247
apps/language_models/langchain/image_captions.py
Normal file
247
apps/language_models/langchain/image_captions.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
Based upon ImageCaptionLoader in LangChain version: langchain/document_loaders/image_captions.py
|
||||
But accepts preloaded model to avoid slowness in use and CUDA forking issues
|
||||
|
||||
Loader that loads image captions
|
||||
By default, the loader utilizes the pre-trained BLIP image captioning model.
|
||||
https://huggingface.co/Salesforce/blip-image-captioning-base
|
||||
|
||||
"""
|
||||
from typing import List, Union, Any, Tuple
|
||||
|
||||
import requests
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders import ImageCaptionLoader
|
||||
|
||||
from utils import get_device, NullContext
|
||||
|
||||
import pkg_resources
|
||||
|
||||
try:
|
||||
assert pkg_resources.get_distribution("bitsandbytes") is not None
|
||||
have_bitsandbytes = True
|
||||
except (pkg_resources.DistributionNotFound, AssertionError):
|
||||
have_bitsandbytes = False
|
||||
|
||||
|
||||
class H2OImageCaptionLoader(ImageCaptionLoader):
|
||||
"""Loader that loads the captions of an image"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path_images: Union[str, List[str]] = None,
|
||||
blip_processor: str = None,
|
||||
blip_model: str = None,
|
||||
caption_gpu=True,
|
||||
load_in_8bit=True,
|
||||
# True doesn't seem to work, even though https://huggingface.co/Salesforce/blip2-flan-t5-xxl#in-8-bit-precision-int8
|
||||
load_half=False,
|
||||
load_gptq="",
|
||||
use_safetensors=False,
|
||||
min_new_tokens=20,
|
||||
max_tokens=50,
|
||||
):
|
||||
if blip_model is None or blip_model is None:
|
||||
blip_processor = "Salesforce/blip-image-captioning-base"
|
||||
blip_model = "Salesforce/blip-image-captioning-base"
|
||||
|
||||
super().__init__(path_images, blip_processor, blip_model)
|
||||
self.blip_processor = blip_processor
|
||||
self.blip_model = blip_model
|
||||
self.processor = None
|
||||
self.model = None
|
||||
self.caption_gpu = caption_gpu
|
||||
self.context_class = NullContext
|
||||
self.device = "cpu"
|
||||
self.load_in_8bit = (
|
||||
load_in_8bit and have_bitsandbytes
|
||||
) # only for blip2
|
||||
self.load_half = load_half
|
||||
self.load_gptq = load_gptq
|
||||
self.use_safetensors = use_safetensors
|
||||
self.gpu_id = "auto"
|
||||
# default prompt
|
||||
self.prompt = "image of"
|
||||
self.min_new_tokens = min_new_tokens
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
def set_context(self):
|
||||
if get_device() == "cuda" and self.caption_gpu:
|
||||
import torch
|
||||
|
||||
n_gpus = (
|
||||
torch.cuda.device_count() if torch.cuda.is_available else 0
|
||||
)
|
||||
if n_gpus > 0:
|
||||
self.context_class = torch.device
|
||||
self.device = "cuda"
|
||||
|
||||
def load_model(self):
|
||||
try:
|
||||
import transformers
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"`transformers` package not found, please install with "
|
||||
"`pip install transformers`."
|
||||
)
|
||||
self.set_context()
|
||||
if self.caption_gpu:
|
||||
if self.gpu_id == "auto":
|
||||
# blip2 has issues with multi-GPU. Error says need to somehow set language model in device map
|
||||
# device_map = 'auto'
|
||||
device_map = {"": 0}
|
||||
else:
|
||||
if self.device == "cuda":
|
||||
device_map = {"": self.gpu_id}
|
||||
else:
|
||||
device_map = {"": "cpu"}
|
||||
else:
|
||||
device_map = {"": "cpu"}
|
||||
import torch
|
||||
|
||||
with torch.no_grad():
|
||||
with self.context_class(self.device):
|
||||
context_class_cast = (
|
||||
NullContext if self.device == "cpu" else torch.autocast
|
||||
)
|
||||
with context_class_cast(self.device):
|
||||
if "blip2" in self.blip_processor.lower():
|
||||
from transformers import (
|
||||
Blip2Processor,
|
||||
Blip2ForConditionalGeneration,
|
||||
)
|
||||
|
||||
if self.load_half and not self.load_in_8bit:
|
||||
self.processor = Blip2Processor.from_pretrained(
|
||||
self.blip_processor, device_map=device_map
|
||||
).half()
|
||||
self.model = (
|
||||
Blip2ForConditionalGeneration.from_pretrained(
|
||||
self.blip_model, device_map=device_map
|
||||
).half()
|
||||
)
|
||||
else:
|
||||
self.processor = Blip2Processor.from_pretrained(
|
||||
self.blip_processor,
|
||||
load_in_8bit=self.load_in_8bit,
|
||||
device_map=device_map,
|
||||
)
|
||||
self.model = (
|
||||
Blip2ForConditionalGeneration.from_pretrained(
|
||||
self.blip_model,
|
||||
load_in_8bit=self.load_in_8bit,
|
||||
device_map=device_map,
|
||||
)
|
||||
)
|
||||
else:
|
||||
from transformers import (
|
||||
BlipForConditionalGeneration,
|
||||
BlipProcessor,
|
||||
)
|
||||
|
||||
self.load_half = False # not supported
|
||||
if self.caption_gpu:
|
||||
if device_map == "auto":
|
||||
# Blip doesn't support device_map='auto'
|
||||
if self.device == "cuda":
|
||||
if self.gpu_id == "auto":
|
||||
device_map = {"": 0}
|
||||
else:
|
||||
device_map = {"": self.gpu_id}
|
||||
else:
|
||||
device_map = {"": "cpu"}
|
||||
else:
|
||||
device_map = {"": "cpu"}
|
||||
self.processor = BlipProcessor.from_pretrained(
|
||||
self.blip_processor, device_map=device_map
|
||||
)
|
||||
self.model = (
|
||||
BlipForConditionalGeneration.from_pretrained(
|
||||
self.blip_model, device_map=device_map
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def set_image_paths(self, path_images: Union[str, List[str]]):
|
||||
"""
|
||||
Load from a list of image files
|
||||
"""
|
||||
if isinstance(path_images, str):
|
||||
self.image_paths = [path_images]
|
||||
else:
|
||||
self.image_paths = path_images
|
||||
|
||||
def load(self, prompt=None) -> List[Document]:
|
||||
if self.processor is None or self.model is None:
|
||||
self.load_model()
|
||||
results = []
|
||||
for path_image in self.image_paths:
|
||||
caption, metadata = self._get_captions_and_metadata(
|
||||
model=self.model,
|
||||
processor=self.processor,
|
||||
path_image=path_image,
|
||||
prompt=prompt,
|
||||
)
|
||||
doc = Document(page_content=caption, metadata=metadata)
|
||||
results.append(doc)
|
||||
|
||||
return results
|
||||
|
||||
def _get_captions_and_metadata(
|
||||
self, model: Any, processor: Any, path_image: str, prompt=None
|
||||
) -> Tuple[str, dict]:
|
||||
"""
|
||||
Helper function for getting the captions and metadata of an image
|
||||
"""
|
||||
if prompt is None:
|
||||
prompt = self.prompt
|
||||
try:
|
||||
from PIL import Image
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"`PIL` package not found, please install with `pip install pillow`"
|
||||
)
|
||||
|
||||
try:
|
||||
if path_image.startswith("http://") or path_image.startswith(
|
||||
"https://"
|
||||
):
|
||||
image = Image.open(
|
||||
requests.get(path_image, stream=True).raw
|
||||
).convert("RGB")
|
||||
else:
|
||||
image = Image.open(path_image).convert("RGB")
|
||||
except Exception:
|
||||
raise ValueError(f"Could not get image data for {path_image}")
|
||||
|
||||
import torch
|
||||
|
||||
with torch.no_grad():
|
||||
with self.context_class(self.device):
|
||||
context_class_cast = (
|
||||
NullContext if self.device == "cpu" else torch.autocast
|
||||
)
|
||||
with context_class_cast(self.device):
|
||||
if self.load_half:
|
||||
inputs = processor(
|
||||
image, prompt, return_tensors="pt"
|
||||
).half()
|
||||
else:
|
||||
inputs = processor(image, prompt, return_tensors="pt")
|
||||
min_length = len(prompt) // 4 + self.min_new_tokens
|
||||
self.max_tokens = max(self.max_tokens, min_length)
|
||||
output = model.generate(
|
||||
**inputs,
|
||||
min_length=min_length,
|
||||
max_length=self.max_tokens,
|
||||
)
|
||||
|
||||
caption: str = processor.decode(
|
||||
output[0], skip_special_tokens=True
|
||||
)
|
||||
prompti = caption.find(prompt)
|
||||
if prompti >= 0:
|
||||
caption = caption[prompti + len(prompt) :]
|
||||
metadata: dict = {"image_path": path_image}
|
||||
|
||||
return caption, metadata
|
||||
120
apps/language_models/langchain/langchain_requirements.txt
Normal file
120
apps/language_models/langchain/langchain_requirements.txt
Normal file
@@ -0,0 +1,120 @@
|
||||
# for generate (gradio server) and finetune
|
||||
datasets==2.13.0
|
||||
sentencepiece==0.1.99
|
||||
huggingface_hub==0.16.4
|
||||
appdirs==1.4.4
|
||||
fire==0.5.0
|
||||
docutils==0.20.1
|
||||
evaluate==0.4.0
|
||||
rouge_score==0.1.2
|
||||
sacrebleu==2.3.1
|
||||
scikit-learn==1.2.2
|
||||
alt-profanity-check==1.2.2
|
||||
better-profanity==0.7.0
|
||||
numpy==1.24.3
|
||||
pandas==2.0.2
|
||||
matplotlib==3.7.1
|
||||
loralib==0.1.1
|
||||
bitsandbytes==0.39.0
|
||||
accelerate==0.20.3
|
||||
peft==0.4.0
|
||||
# 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026)
|
||||
transformers==4.30.2
|
||||
tokenizers==0.13.3
|
||||
APScheduler==3.10.1
|
||||
|
||||
# optional for generate
|
||||
pynvml==11.5.0
|
||||
psutil==5.9.5
|
||||
boto3==1.26.101
|
||||
botocore==1.29.101
|
||||
|
||||
# optional for finetune
|
||||
tensorboard==2.13.0
|
||||
neptune==1.2.0
|
||||
|
||||
# for gradio client
|
||||
gradio_client==0.2.10
|
||||
beautifulsoup4==4.12.2
|
||||
markdown==3.4.3
|
||||
|
||||
# data and testing
|
||||
pytest==7.2.2
|
||||
pytest-xdist==3.2.1
|
||||
nltk==3.8.1
|
||||
textstat==0.7.3
|
||||
# pandoc==2.3
|
||||
pypandoc==1.11; sys_platform == "darwin" and platform_machine == "arm64"
|
||||
pypandoc_binary==1.11; platform_machine == "x86_64"
|
||||
pypandoc_binary==1.11; sys_platform == "win32"
|
||||
openpyxl==3.1.2
|
||||
lm_dataformat==0.0.20
|
||||
bioc==2.0
|
||||
|
||||
# falcon
|
||||
einops==0.6.1
|
||||
instructorembedding==1.0.1
|
||||
|
||||
# for gpt4all .env file, but avoid worrying about imports
|
||||
python-dotenv==1.0.0
|
||||
|
||||
text-generation==0.6.0
|
||||
# for tokenization when don't have HF tokenizer
|
||||
tiktoken==0.4.0
|
||||
# optional: for OpenAI endpoint or embeddings (requires key)
|
||||
openai==0.27.8
|
||||
|
||||
# optional for chat with PDF
|
||||
langchain==0.0.202
|
||||
pypdf==3.12.2
|
||||
# avoid textract, requires old six
|
||||
#textract==1.6.5
|
||||
|
||||
# for HF embeddings
|
||||
sentence_transformers==2.2.2
|
||||
|
||||
# local vector db
|
||||
chromadb==0.3.25
|
||||
# server vector db
|
||||
#pymilvus==2.2.8
|
||||
|
||||
# weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
|
||||
# unstructured==0.8.1
|
||||
|
||||
# strong support for images
|
||||
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
|
||||
unstructured[local-inference]==0.7.4
|
||||
#pdf2image==1.16.3
|
||||
#pytesseract==0.3.10
|
||||
pillow
|
||||
|
||||
pdfminer.six==20221105
|
||||
urllib3
|
||||
requests_file
|
||||
|
||||
#pdf2image==1.16.3
|
||||
#pytesseract==0.3.10
|
||||
tabulate==0.9.0
|
||||
# FYI pandoc already part of requirements.txt
|
||||
|
||||
# JSONLoader, but makes some trouble for some users
|
||||
# jq==1.4.1
|
||||
|
||||
# to check licenses
|
||||
# Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
|
||||
pip-licenses==4.3.0
|
||||
|
||||
# weaviate vector db
|
||||
weaviate-client==3.22.1
|
||||
|
||||
gpt4all==1.0.5
|
||||
llama-cpp-python==0.1.73
|
||||
|
||||
arxiv==1.4.8
|
||||
pymupdf==1.22.5 # AGPL license
|
||||
# extract-msg==0.41.1 # GPL3
|
||||
|
||||
# sometimes unstructured fails, these work in those cases. See https://github.com/h2oai/h2ogpt/issues/320
|
||||
playwright==1.36.0
|
||||
# requires Chrome binary to be in path
|
||||
selenium==4.10.0
|
||||
124
apps/language_models/langchain/llama_flash_attn_monkey_patch.py
Normal file
124
apps/language_models/langchain/llama_flash_attn_monkey_patch.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
||||
from flash_attn.bert_padding import unpad_input, pad_input
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[
|
||||
torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]
|
||||
]:
|
||||
"""Input shape: Batch x Time x Channel
|
||||
attention_mask: [bsz, q_len]
|
||||
"""
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = (
|
||||
self.q_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
key_states = (
|
||||
self.k_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
value_states = (
|
||||
self.v_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
# [bsz, q_len, nh, hd]
|
||||
# [bsz, nh, q_len, hd]
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
assert past_key_value is None, "past_key_value is not supported"
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
assert not output_attentions, "output_attentions is not supported"
|
||||
assert not use_cache, "use_cache is not supported"
|
||||
|
||||
# Flash attention codes from
|
||||
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
|
||||
|
||||
# transform the data into the format required by flash attention
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||
# the attention_mask should be the same as the key_padding_mask
|
||||
key_padding_mask = attention_mask
|
||||
|
||||
if key_padding_mask is None:
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
max_s = q_len
|
||||
cu_q_lens = torch.arange(
|
||||
0,
|
||||
(bsz + 1) * q_len,
|
||||
step=q_len,
|
||||
dtype=torch.int32,
|
||||
device=qkv.device,
|
||||
)
|
||||
output = flash_attn_unpadded_qkvpacked_func(
|
||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
else:
|
||||
nheads = qkv.shape[-2]
|
||||
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
||||
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
||||
x_unpad = rearrange(
|
||||
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
|
||||
)
|
||||
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
||||
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
output = rearrange(
|
||||
pad_input(
|
||||
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
|
||||
indices,
|
||||
bsz,
|
||||
q_len,
|
||||
),
|
||||
"b s (h d) -> b s h d",
|
||||
h=nheads,
|
||||
)
|
||||
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
):
|
||||
# [bsz, seq_len]
|
||||
return attention_mask
|
||||
|
||||
|
||||
def replace_llama_attn_with_flash_attn():
|
||||
print(
|
||||
"Replacing original LLaMa attention with flash attention", flush=True
|
||||
)
|
||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
|
||||
_prepare_decoder_attention_mask
|
||||
)
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
||||
109
apps/language_models/langchain/loaders.py
Normal file
109
apps/language_models/langchain/loaders.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import functools
|
||||
|
||||
|
||||
def get_loaders(model_name, reward_type, llama_type=None, load_gptq=""):
|
||||
# NOTE: Some models need specific new prompt_type
|
||||
# E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
|
||||
if load_gptq:
|
||||
from transformers import AutoTokenizer
|
||||
from auto_gptq import AutoGPTQForCausalLM
|
||||
|
||||
use_triton = False
|
||||
functools.partial(
|
||||
AutoGPTQForCausalLM.from_quantized,
|
||||
quantize_config=None,
|
||||
use_triton=use_triton,
|
||||
)
|
||||
return AutoGPTQForCausalLM.from_quantized, AutoTokenizer
|
||||
if llama_type is None:
|
||||
llama_type = "llama" in model_name.lower()
|
||||
if llama_type:
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
return LlamaForCausalLM.from_pretrained, LlamaTokenizer
|
||||
elif "distilgpt2" in model_name.lower():
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
return AutoModelForCausalLM.from_pretrained, AutoTokenizer
|
||||
elif "gpt2" in model_name.lower():
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
|
||||
return GPT2LMHeadModel.from_pretrained, GPT2Tokenizer
|
||||
elif "mbart-" in model_name.lower():
|
||||
from transformers import (
|
||||
MBartForConditionalGeneration,
|
||||
MBart50TokenizerFast,
|
||||
)
|
||||
|
||||
return (
|
||||
MBartForConditionalGeneration.from_pretrained,
|
||||
MBart50TokenizerFast,
|
||||
)
|
||||
elif (
|
||||
"t5" == model_name.lower()
|
||||
or "t5-" in model_name.lower()
|
||||
or "flan-" in model_name.lower()
|
||||
):
|
||||
from transformers import AutoTokenizer, T5ForConditionalGeneration
|
||||
|
||||
return T5ForConditionalGeneration.from_pretrained, AutoTokenizer
|
||||
elif "bigbird" in model_name:
|
||||
from transformers import (
|
||||
BigBirdPegasusForConditionalGeneration,
|
||||
AutoTokenizer,
|
||||
)
|
||||
|
||||
return (
|
||||
BigBirdPegasusForConditionalGeneration.from_pretrained,
|
||||
AutoTokenizer,
|
||||
)
|
||||
elif (
|
||||
"bart-large-cnn-samsum" in model_name
|
||||
or "flan-t5-base-samsum" in model_name
|
||||
):
|
||||
from transformers import pipeline
|
||||
|
||||
return pipeline, "summarization"
|
||||
elif (
|
||||
reward_type
|
||||
or "OpenAssistant/reward-model".lower() in model_name.lower()
|
||||
):
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
)
|
||||
|
||||
return (
|
||||
AutoModelForSequenceClassification.from_pretrained,
|
||||
AutoTokenizer,
|
||||
)
|
||||
else:
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
model_loader = AutoModelForCausalLM
|
||||
tokenizer_loader = AutoTokenizer
|
||||
return model_loader.from_pretrained, tokenizer_loader
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
tokenizer_loader,
|
||||
tokenizer_base_model,
|
||||
local_files_only,
|
||||
resume_download,
|
||||
use_auth_token,
|
||||
):
|
||||
tokenizer = tokenizer_loader.from_pretrained(
|
||||
tokenizer_base_model,
|
||||
local_files_only=local_files_only,
|
||||
resume_download=resume_download,
|
||||
use_auth_token=use_auth_token,
|
||||
padding_side="left",
|
||||
)
|
||||
|
||||
tokenizer.pad_token_id = 0 # different from the eos token
|
||||
# when generating, we will use the logits of right-most token to predict the next token
|
||||
# so the padding should be on the left,
|
||||
# e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
|
||||
tokenizer.padding_side = "left" # Allow batched inference
|
||||
|
||||
return tokenizer
|
||||
203
apps/language_models/langchain/make_db.py
Normal file
203
apps/language_models/langchain/make_db.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import os
|
||||
|
||||
from gpt_langchain import (
|
||||
path_to_docs,
|
||||
get_some_dbs_from_hf,
|
||||
all_db_zips,
|
||||
some_db_zips,
|
||||
create_or_update_db,
|
||||
)
|
||||
from utils import get_ngpus_vis
|
||||
|
||||
|
||||
def glob_to_db(
|
||||
user_path,
|
||||
chunk=True,
|
||||
chunk_size=512,
|
||||
verbose=False,
|
||||
fail_any_exception=False,
|
||||
n_jobs=-1,
|
||||
url=None,
|
||||
enable_captions=True,
|
||||
captions_model=None,
|
||||
caption_loader=None,
|
||||
enable_ocr=False,
|
||||
):
|
||||
sources1 = path_to_docs(
|
||||
user_path,
|
||||
verbose=verbose,
|
||||
fail_any_exception=fail_any_exception,
|
||||
n_jobs=n_jobs,
|
||||
chunk=chunk,
|
||||
chunk_size=chunk_size,
|
||||
url=url,
|
||||
enable_captions=enable_captions,
|
||||
captions_model=captions_model,
|
||||
caption_loader=caption_loader,
|
||||
enable_ocr=enable_ocr,
|
||||
)
|
||||
return sources1
|
||||
|
||||
|
||||
def make_db_main(
|
||||
use_openai_embedding: bool = False,
|
||||
hf_embedding_model: str = None,
|
||||
persist_directory: str = "db_dir_UserData",
|
||||
user_path: str = "user_path",
|
||||
url: str = None,
|
||||
add_if_exists: bool = True,
|
||||
collection_name: str = "UserData",
|
||||
verbose: bool = False,
|
||||
chunk: bool = True,
|
||||
chunk_size: int = 512,
|
||||
fail_any_exception: bool = False,
|
||||
download_all: bool = False,
|
||||
download_some: bool = False,
|
||||
download_one: str = None,
|
||||
download_dest: str = "./",
|
||||
n_jobs: int = -1,
|
||||
enable_captions: bool = True,
|
||||
captions_model: str = "Salesforce/blip-image-captioning-base",
|
||||
pre_load_caption_model: bool = False,
|
||||
caption_gpu: bool = True,
|
||||
enable_ocr: bool = False,
|
||||
db_type: str = "chroma",
|
||||
):
|
||||
"""
|
||||
# To make UserData db for generate.py, put pdfs, etc. into path user_path and run:
|
||||
python make_db.py
|
||||
|
||||
# once db is made, can use in generate.py like:
|
||||
|
||||
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b --langchain_mode=UserData
|
||||
|
||||
or zip-up the db_dir_UserData and share:
|
||||
|
||||
zip -r db_dir_UserData.zip db_dir_UserData
|
||||
|
||||
# To get all db files (except large wiki_full) do:
|
||||
python make_db.py --download_some=True
|
||||
|
||||
# To get a single db file from HF:
|
||||
python make_db.py --download_one=db_dir_DriverlessAI_docs.zip
|
||||
|
||||
:param use_openai_embedding: Whether to use OpenAI embedding
|
||||
:param hf_embedding_model: HF embedding model to use. Like generate.py, uses 'hkunlp/instructor-large' if have GPUs, else "sentence-transformers/all-MiniLM-L6-v2"
|
||||
:param persist_directory: where to persist db
|
||||
:param user_path: where to pull documents from (None means url is not None. If url is not None, this is ignored.)
|
||||
:param url: url to generate documents from (None means user_path is not None)
|
||||
:param add_if_exists: Add to db if already exists, but will not add duplicate sources
|
||||
:param collection_name: Collection name for new db if not adding
|
||||
:param verbose: whether to show verbose messages
|
||||
:param chunk: whether to chunk data
|
||||
:param chunk_size: chunk size for chunking
|
||||
:param fail_any_exception: whether to fail if any exception hit during ingestion of files
|
||||
:param download_all: whether to download all (including 23GB Wikipedia) example databases from h2o.ai HF
|
||||
:param download_some: whether to download some small example databases from h2o.ai HF
|
||||
:param download_one: whether to download one chosen example databases from h2o.ai HF
|
||||
:param download_dest: Destination for downloads
|
||||
:param n_jobs: Number of cores to use for ingesting multiple files
|
||||
:param enable_captions: Whether to enable captions on images
|
||||
:param captions_model: See generate.py
|
||||
:param pre_load_caption_model: See generate.py
|
||||
:param caption_gpu: Caption images on GPU if present
|
||||
:param enable_ocr: Whether to enable OCR on images
|
||||
:param db_type: Type of db to create. Currently only 'chroma' and 'weaviate' is supported.
|
||||
:return: None
|
||||
"""
|
||||
db = None
|
||||
|
||||
# match behavior of main() in generate.py for non-HF case
|
||||
n_gpus = get_ngpus_vis()
|
||||
if n_gpus == 0:
|
||||
if hf_embedding_model is None:
|
||||
# if no GPUs, use simpler embedding model to avoid cost in time
|
||||
hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
else:
|
||||
if hf_embedding_model is None:
|
||||
# if still None, then set default
|
||||
hf_embedding_model = "hkunlp/instructor-large"
|
||||
|
||||
if download_all:
|
||||
print("Downloading all (and unzipping): %s" % all_db_zips, flush=True)
|
||||
get_some_dbs_from_hf(download_dest, db_zips=all_db_zips)
|
||||
if verbose:
|
||||
print("DONE", flush=True)
|
||||
return db, collection_name
|
||||
elif download_some:
|
||||
print(
|
||||
"Downloading some (and unzipping): %s" % some_db_zips, flush=True
|
||||
)
|
||||
get_some_dbs_from_hf(download_dest, db_zips=some_db_zips)
|
||||
if verbose:
|
||||
print("DONE", flush=True)
|
||||
return db, collection_name
|
||||
elif download_one:
|
||||
print("Downloading %s (and unzipping)" % download_one, flush=True)
|
||||
get_some_dbs_from_hf(
|
||||
download_dest, db_zips=[[download_one, "", "Unknown License"]]
|
||||
)
|
||||
if verbose:
|
||||
print("DONE", flush=True)
|
||||
return db, collection_name
|
||||
|
||||
if enable_captions and pre_load_caption_model:
|
||||
# preload, else can be too slow or if on GPU have cuda context issues
|
||||
# Inside ingestion, this will disable parallel loading of multiple other kinds of docs
|
||||
# However, if have many images, all those images will be handled more quickly by preloaded model on GPU
|
||||
from image_captions import H2OImageCaptionLoader
|
||||
|
||||
caption_loader = H2OImageCaptionLoader(
|
||||
None,
|
||||
blip_model=captions_model,
|
||||
blip_processor=captions_model,
|
||||
caption_gpu=caption_gpu,
|
||||
).load_model()
|
||||
else:
|
||||
if enable_captions:
|
||||
caption_loader = "gpu" if caption_gpu else "cpu"
|
||||
else:
|
||||
caption_loader = False
|
||||
|
||||
if verbose:
|
||||
print("Getting sources", flush=True)
|
||||
assert (
|
||||
user_path is not None or url is not None
|
||||
), "Can't have both user_path and url as None"
|
||||
if not url:
|
||||
assert os.path.isdir(user_path), (
|
||||
"user_path=%s does not exist" % user_path
|
||||
)
|
||||
sources = glob_to_db(
|
||||
user_path,
|
||||
chunk=chunk,
|
||||
chunk_size=chunk_size,
|
||||
verbose=verbose,
|
||||
fail_any_exception=fail_any_exception,
|
||||
n_jobs=n_jobs,
|
||||
url=url,
|
||||
enable_captions=enable_captions,
|
||||
captions_model=captions_model,
|
||||
caption_loader=caption_loader,
|
||||
enable_ocr=enable_ocr,
|
||||
)
|
||||
exceptions = [x for x in sources if x.metadata.get("exception")]
|
||||
print("Exceptions: %s" % exceptions, flush=True)
|
||||
sources = [x for x in sources if "exception" not in x.metadata]
|
||||
|
||||
assert len(sources) > 0, "No sources found"
|
||||
db = create_or_update_db(
|
||||
db_type,
|
||||
persist_directory,
|
||||
collection_name,
|
||||
sources,
|
||||
use_openai_embedding,
|
||||
add_if_exists,
|
||||
verbose,
|
||||
hf_embedding_model,
|
||||
)
|
||||
|
||||
assert db is not None
|
||||
if verbose:
|
||||
print("DONE", flush=True)
|
||||
return db, collection_name
|
||||
1103
apps/language_models/langchain/prompter.py
Normal file
1103
apps/language_models/langchain/prompter.py
Normal file
File diff suppressed because it is too large
Load Diff
403
apps/language_models/langchain/read_wiki_full.py
Normal file
403
apps/language_models/langchain/read_wiki_full.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""Load Data from a MediaWiki dump xml."""
|
||||
import ast
|
||||
import glob
|
||||
import pickle
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
import os
|
||||
import bz2
|
||||
import csv
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders import MWDumpLoader
|
||||
|
||||
# path where downloaded wiki files exist, to be processed
|
||||
root_path = "/data/jon/h2o-llm"
|
||||
|
||||
|
||||
def unescape(x):
|
||||
try:
|
||||
x = ast.literal_eval(x)
|
||||
except:
|
||||
try:
|
||||
x = x.encode("ascii", "ignore").decode("unicode_escape")
|
||||
except:
|
||||
pass
|
||||
return x
|
||||
|
||||
|
||||
def get_views():
|
||||
# views = pd.read_csv('wiki_page_views_more_1000month.csv')
|
||||
views = pd.read_csv("wiki_page_views_more_5000month.csv")
|
||||
views.index = views["title"]
|
||||
views = views["views"]
|
||||
views = views.to_dict()
|
||||
views = {str(unescape(str(k))): v for k, v in views.items()}
|
||||
views2 = {k.replace("_", " "): v for k, v in views.items()}
|
||||
# views has _ but pages has " "
|
||||
views.update(views2)
|
||||
return views
|
||||
|
||||
|
||||
class MWDumpDirectLoader(MWDumpLoader):
|
||||
def __init__(
|
||||
self,
|
||||
data: str,
|
||||
encoding: Optional[str] = "utf8",
|
||||
title_words_limit=None,
|
||||
use_views=True,
|
||||
verbose=True,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self.data = data
|
||||
self.encoding = encoding
|
||||
self.title_words_limit = title_words_limit
|
||||
self.verbose = verbose
|
||||
if use_views:
|
||||
# self.views = get_views()
|
||||
# faster to use global shared values
|
||||
self.views = global_views
|
||||
else:
|
||||
self.views = None
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load from file path."""
|
||||
import mwparserfromhell
|
||||
import mwxml
|
||||
|
||||
dump = mwxml.Dump.from_page_xml(self.data)
|
||||
|
||||
docs = []
|
||||
|
||||
for page in dump.pages:
|
||||
if self.views is not None and page.title not in self.views:
|
||||
if self.verbose:
|
||||
print("Skipped %s low views" % page.title, flush=True)
|
||||
continue
|
||||
for revision in page:
|
||||
if self.title_words_limit is not None:
|
||||
num_words = len(" ".join(page.title.split("_")).split(" "))
|
||||
if num_words > self.title_words_limit:
|
||||
if self.verbose:
|
||||
print("Skipped %s" % page.title, flush=True)
|
||||
continue
|
||||
if self.verbose:
|
||||
if self.views is not None:
|
||||
print(
|
||||
"Kept %s views: %s"
|
||||
% (page.title, self.views[page.title]),
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
print("Kept %s" % page.title, flush=True)
|
||||
|
||||
code = mwparserfromhell.parse(revision.text)
|
||||
text = code.strip_code(
|
||||
normalize=True, collapse=True, keep_template_params=False
|
||||
)
|
||||
title_url = str(page.title).replace(" ", "_")
|
||||
metadata = dict(
|
||||
title=page.title,
|
||||
source="https://en.wikipedia.org/wiki/" + title_url,
|
||||
id=page.id,
|
||||
redirect=page.redirect,
|
||||
views=self.views[page.title]
|
||||
if self.views is not None
|
||||
else -1,
|
||||
)
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def search_index(search_term, index_filename):
|
||||
byte_flag = False
|
||||
data_length = start_byte = 0
|
||||
index_file = open(index_filename, "r")
|
||||
csv_reader = csv.reader(index_file, delimiter=":")
|
||||
for line in csv_reader:
|
||||
if not byte_flag and search_term == line[2]:
|
||||
start_byte = int(line[0])
|
||||
byte_flag = True
|
||||
elif byte_flag and int(line[0]) != start_byte:
|
||||
data_length = int(line[0]) - start_byte
|
||||
break
|
||||
index_file.close()
|
||||
return start_byte, data_length
|
||||
|
||||
|
||||
def get_start_bytes(index_filename):
|
||||
index_file = open(index_filename, "r")
|
||||
csv_reader = csv.reader(index_file, delimiter=":")
|
||||
start_bytes = set()
|
||||
for line in csv_reader:
|
||||
start_bytes.add(int(line[0]))
|
||||
index_file.close()
|
||||
return sorted(start_bytes)
|
||||
|
||||
|
||||
def get_wiki_filenames():
|
||||
# requires
|
||||
# wget http://ftp.acc.umu.se/mirror/wikimedia.org/dumps/enwiki/20230401/enwiki-20230401-pages-articles-multistream-index.txt.bz2
|
||||
base_path = os.path.join(
|
||||
root_path, "enwiki-20230401-pages-articles-multistream"
|
||||
)
|
||||
index_file = "enwiki-20230401-pages-articles-multistream-index.txt"
|
||||
index_filename = os.path.join(base_path, index_file)
|
||||
wiki_filename = os.path.join(
|
||||
base_path, "enwiki-20230401-pages-articles-multistream.xml.bz2"
|
||||
)
|
||||
return index_filename, wiki_filename
|
||||
|
||||
|
||||
def get_documents_by_search_term(search_term):
|
||||
index_filename, wiki_filename = get_wiki_filenames()
|
||||
start_byte, data_length = search_index(search_term, index_filename)
|
||||
with open(wiki_filename, "rb") as wiki_file:
|
||||
wiki_file.seek(start_byte)
|
||||
data = bz2.BZ2Decompressor().decompress(wiki_file.read(data_length))
|
||||
|
||||
loader = MWDumpDirectLoader(data.decode())
|
||||
documents = loader.load()
|
||||
return documents
|
||||
|
||||
|
||||
def get_one_chunk(
|
||||
wiki_filename,
|
||||
start_byte,
|
||||
end_byte,
|
||||
return_file=True,
|
||||
title_words_limit=None,
|
||||
use_views=True,
|
||||
):
|
||||
data_length = end_byte - start_byte
|
||||
with open(wiki_filename, "rb") as wiki_file:
|
||||
wiki_file.seek(start_byte)
|
||||
data = bz2.BZ2Decompressor().decompress(wiki_file.read(data_length))
|
||||
|
||||
loader = MWDumpDirectLoader(
|
||||
data.decode(), title_words_limit=title_words_limit, use_views=use_views
|
||||
)
|
||||
documents1 = loader.load()
|
||||
if return_file:
|
||||
base_tmp = "temp_wiki"
|
||||
if not os.path.isdir(base_tmp):
|
||||
os.makedirs(base_tmp, exist_ok=True)
|
||||
filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle")
|
||||
with open(filename, "wb") as f:
|
||||
pickle.dump(documents1, f)
|
||||
return filename
|
||||
return documents1
|
||||
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
global_views = get_views()
|
||||
|
||||
|
||||
def get_all_documents(small_test=2, n_jobs=None, use_views=True):
|
||||
print("DO get all wiki docs: %s" % small_test, flush=True)
|
||||
index_filename, wiki_filename = get_wiki_filenames()
|
||||
start_bytes = get_start_bytes(index_filename)
|
||||
end_bytes = start_bytes[1:]
|
||||
start_bytes = start_bytes[:-1]
|
||||
|
||||
if small_test:
|
||||
start_bytes = start_bytes[:small_test]
|
||||
end_bytes = end_bytes[:small_test]
|
||||
if n_jobs is None:
|
||||
n_jobs = 5
|
||||
else:
|
||||
if n_jobs is None:
|
||||
n_jobs = os.cpu_count() // 4
|
||||
|
||||
# default loky backend leads to name space conflict problems
|
||||
return_file = True # large return from joblib hangs
|
||||
documents = Parallel(n_jobs=n_jobs, verbose=10, backend="multiprocessing")(
|
||||
delayed(get_one_chunk)(
|
||||
wiki_filename,
|
||||
start_byte,
|
||||
end_byte,
|
||||
return_file=return_file,
|
||||
use_views=use_views,
|
||||
)
|
||||
for start_byte, end_byte in zip(start_bytes, end_bytes)
|
||||
)
|
||||
if return_file:
|
||||
# then documents really are files
|
||||
files = documents.copy()
|
||||
documents = []
|
||||
for fil in files:
|
||||
with open(fil, "rb") as f:
|
||||
documents.extend(pickle.load(f))
|
||||
os.remove(fil)
|
||||
else:
|
||||
from functools import reduce
|
||||
from operator import concat
|
||||
|
||||
documents = reduce(concat, documents)
|
||||
assert isinstance(documents, list)
|
||||
|
||||
print("DONE get all wiki docs", flush=True)
|
||||
return documents
|
||||
|
||||
|
||||
def test_by_search_term():
|
||||
search_term = "Apollo"
|
||||
assert len(get_documents_by_search_term(search_term)) == 100
|
||||
|
||||
search_term = "Abstract (law)"
|
||||
assert len(get_documents_by_search_term(search_term)) == 100
|
||||
|
||||
search_term = "Artificial languages"
|
||||
assert len(get_documents_by_search_term(search_term)) == 100
|
||||
|
||||
|
||||
def test_start_bytes():
|
||||
index_filename, wiki_filename = get_wiki_filenames()
|
||||
assert len(get_start_bytes(index_filename)) == 227850
|
||||
|
||||
|
||||
def test_get_all_documents():
|
||||
small_test = 20 # 227850
|
||||
n_jobs = os.cpu_count() // 4
|
||||
|
||||
assert (
|
||||
len(
|
||||
get_all_documents(
|
||||
small_test=small_test, n_jobs=n_jobs, use_views=False
|
||||
)
|
||||
)
|
||||
== small_test * 100
|
||||
)
|
||||
|
||||
assert (
|
||||
len(
|
||||
get_all_documents(
|
||||
small_test=small_test, n_jobs=n_jobs, use_views=True
|
||||
)
|
||||
)
|
||||
== 429
|
||||
)
|
||||
|
||||
|
||||
def get_one_pageviews(fil):
|
||||
df1 = pd.read_csv(
|
||||
fil,
|
||||
sep=" ",
|
||||
header=None,
|
||||
names=["region", "title", "views", "foo"],
|
||||
quoting=csv.QUOTE_NONE,
|
||||
)
|
||||
df1.index = df1["title"]
|
||||
df1 = df1[df1["region"] == "en"]
|
||||
df1 = df1.drop("region", axis=1)
|
||||
df1 = df1.drop("foo", axis=1)
|
||||
df1 = df1.drop("title", axis=1) # already index
|
||||
|
||||
base_tmp = "temp_wiki_pageviews"
|
||||
if not os.path.isdir(base_tmp):
|
||||
os.makedirs(base_tmp, exist_ok=True)
|
||||
filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.csv")
|
||||
df1.to_csv(filename, index=True)
|
||||
return filename
|
||||
|
||||
|
||||
def test_agg_pageviews(gen_files=False):
|
||||
if gen_files:
|
||||
path = os.path.join(
|
||||
root_path,
|
||||
"wiki_pageviews/dumps.wikimedia.org/other/pageviews/2023/2023-04",
|
||||
)
|
||||
files = glob.glob(os.path.join(path, "pageviews*.gz"))
|
||||
# files = files[:2] # test
|
||||
n_jobs = os.cpu_count() // 2
|
||||
csv_files = Parallel(
|
||||
n_jobs=n_jobs, verbose=10, backend="multiprocessing"
|
||||
)(delayed(get_one_pageviews)(fil) for fil in files)
|
||||
else:
|
||||
# to continue without redoing above
|
||||
csv_files = glob.glob(
|
||||
os.path.join(root_path, "temp_wiki_pageviews/*.csv")
|
||||
)
|
||||
|
||||
df_list = []
|
||||
for csv_file in csv_files:
|
||||
print(csv_file)
|
||||
df1 = pd.read_csv(csv_file)
|
||||
df_list.append(df1)
|
||||
df = pd.concat(df_list, axis=0)
|
||||
df = df.groupby("title")["views"].sum().reset_index()
|
||||
df.to_csv("wiki_page_views.csv", index=True)
|
||||
|
||||
|
||||
def test_reduce_pageview():
|
||||
filename = "wiki_page_views.csv"
|
||||
df = pd.read_csv(filename)
|
||||
df = df[df["views"] < 1e7]
|
||||
#
|
||||
plt.hist(df["views"], bins=100, log=True)
|
||||
views_avg = np.mean(df["views"])
|
||||
views_median = np.median(df["views"])
|
||||
plt.title("Views avg: %s median: %s" % (views_avg, views_median))
|
||||
plt.savefig(filename.replace(".csv", ".png"))
|
||||
plt.close()
|
||||
#
|
||||
views_limit = 5000
|
||||
df = df[df["views"] > views_limit]
|
||||
filename = "wiki_page_views_more_5000month.csv"
|
||||
df.to_csv(filename, index=True)
|
||||
#
|
||||
plt.hist(df["views"], bins=100, log=True)
|
||||
views_avg = np.mean(df["views"])
|
||||
views_median = np.median(df["views"])
|
||||
plt.title("Views avg: %s median: %s" % (views_avg, views_median))
|
||||
plt.savefig(filename.replace(".csv", ".png"))
|
||||
plt.close()
|
||||
|
||||
|
||||
@pytest.mark.skip("Only if doing full processing again, some manual steps")
|
||||
def test_do_wiki_full_all():
|
||||
# Install other requirements for wiki specific conversion:
|
||||
# pip install -r reqs_optional/requirements_optional_wikiprocessing.txt
|
||||
|
||||
# Use "Transmission" in Ubuntu to get wiki dump using torrent:
|
||||
# See: https://meta.wikimedia.org/wiki/Data_dump_torrents
|
||||
# E.g. magnet:?xt=urn:btih:b2c74af2b1531d0b63f1166d2011116f44a8fed0&dn=enwiki-20230401-pages-articles-multistream.xml.bz2&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337
|
||||
|
||||
# Get index
|
||||
os.system(
|
||||
"wget http://ftp.acc.umu.se/mirror/wikimedia.org/dumps/enwiki/20230401/enwiki-20230401-pages-articles-multistream-index.txt.bz2"
|
||||
)
|
||||
|
||||
# Test that can use LangChain to get docs from subset of wiki as sampled out of full wiki directly using bzip multistream
|
||||
test_get_all_documents()
|
||||
|
||||
# Check can search wiki multistream
|
||||
test_by_search_term()
|
||||
|
||||
# Test can get all start bytes in index
|
||||
test_start_bytes()
|
||||
|
||||
# Get page views, e.g. for entire month of April 2023
|
||||
os.system(
|
||||
"wget -b -m -k -o wget.log -e robots=off https://dumps.wikimedia.org/other/pageviews/2023/2023-04/"
|
||||
)
|
||||
|
||||
# Aggregate page views from many files into single file
|
||||
test_agg_pageviews(gen_files=True)
|
||||
|
||||
# Reduce page views to some limit, so processing of full wiki is not too large
|
||||
test_reduce_pageview()
|
||||
|
||||
# Start generate.py with requesting wiki_full in prep. This will use page views as referenced in get_views.
|
||||
# Note get_views as global() function done once is required to avoid very slow processing
|
||||
# WARNING: Requires alot of memory to handle, used up to 300GB system RAM at peak
|
||||
"""
|
||||
python generate.py --langchain_mode='wiki_full' --visible_langchain_modes="['wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']" &> lc_out.log
|
||||
"""
|
||||
121
apps/language_models/langchain/stopping.py
Normal file
121
apps/language_models/langchain/stopping.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import torch
|
||||
from transformers import StoppingCriteria, StoppingCriteriaList
|
||||
|
||||
from enums import PromptType
|
||||
|
||||
|
||||
class StoppingCriteriaSub(StoppingCriteria):
|
||||
def __init__(
|
||||
self, stops=[], encounters=[], device="cuda", model_max_length=None
|
||||
):
|
||||
super().__init__()
|
||||
assert (
|
||||
len(stops) % len(encounters) == 0
|
||||
), "Number of stops and encounters must match"
|
||||
self.encounters = encounters
|
||||
self.stops = [stop.to(device) for stop in stops]
|
||||
self.num_stops = [0] * len(stops)
|
||||
self.model_max_length = model_max_length
|
||||
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
for stopi, stop in enumerate(self.stops):
|
||||
if torch.all((stop == input_ids[0][-len(stop) :])).item():
|
||||
self.num_stops[stopi] += 1
|
||||
if (
|
||||
self.num_stops[stopi]
|
||||
>= self.encounters[stopi % len(self.encounters)]
|
||||
):
|
||||
# print("Stopped", flush=True)
|
||||
return True
|
||||
if (
|
||||
self.model_max_length is not None
|
||||
and input_ids[0].shape[0] >= self.model_max_length
|
||||
):
|
||||
# critical limit
|
||||
return True
|
||||
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
|
||||
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
|
||||
return False
|
||||
|
||||
|
||||
def get_stopping(
|
||||
prompt_type,
|
||||
prompt_dict,
|
||||
tokenizer,
|
||||
device,
|
||||
human="<human>:",
|
||||
bot="<bot>:",
|
||||
model_max_length=None,
|
||||
):
|
||||
# FIXME: prompt_dict unused currently
|
||||
if prompt_type in [
|
||||
PromptType.human_bot.name,
|
||||
PromptType.instruct_vicuna.name,
|
||||
PromptType.instruct_with_end.name,
|
||||
]:
|
||||
if prompt_type == PromptType.human_bot.name:
|
||||
# encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
|
||||
# stopping only starts once output is beyond prompt
|
||||
# 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
|
||||
stop_words = [human, bot, "\n" + human, "\n" + bot]
|
||||
encounters = [1, 2]
|
||||
elif prompt_type == PromptType.instruct_vicuna.name:
|
||||
# even below is not enough, generic strings and many ways to encode
|
||||
stop_words = [
|
||||
"### Human:",
|
||||
"""
|
||||
### Human:""",
|
||||
"""
|
||||
### Human:
|
||||
""",
|
||||
"### Assistant:",
|
||||
"""
|
||||
### Assistant:""",
|
||||
"""
|
||||
### Assistant:
|
||||
""",
|
||||
]
|
||||
encounters = [1, 2]
|
||||
else:
|
||||
# some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
|
||||
stop_words = ["### End"]
|
||||
encounters = [1]
|
||||
stop_words_ids = [
|
||||
tokenizer(stop_word, return_tensors="pt")["input_ids"].squeeze()
|
||||
for stop_word in stop_words
|
||||
]
|
||||
# handle single token case
|
||||
stop_words_ids = [
|
||||
x if len(x.shape) > 0 else torch.tensor([x])
|
||||
for x in stop_words_ids
|
||||
]
|
||||
stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
|
||||
# avoid padding in front of tokens
|
||||
if (
|
||||
tokenizer._pad_token
|
||||
): # use hidden variable to avoid annoying properly logger bug
|
||||
stop_words_ids = [
|
||||
x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x
|
||||
for x in stop_words_ids
|
||||
]
|
||||
# handle fake \n added
|
||||
stop_words_ids = [
|
||||
x[1:] if y[0] == "\n" else x
|
||||
for x, y in zip(stop_words_ids, stop_words)
|
||||
]
|
||||
# build stopper
|
||||
stopping_criteria = StoppingCriteriaList(
|
||||
[
|
||||
StoppingCriteriaSub(
|
||||
stops=stop_words_ids,
|
||||
encounters=encounters,
|
||||
device=device,
|
||||
model_max_length=model_max_length,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
stopping_criteria = StoppingCriteriaList()
|
||||
return stopping_criteria
|
||||
1070
apps/language_models/langchain/utils.py
Normal file
1070
apps/language_models/langchain/utils.py
Normal file
File diff suppressed because it is too large
Load Diff
69
apps/language_models/langchain/utils_langchain.py
Normal file
69
apps/language_models/langchain/utils_langchain.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
import time
|
||||
import queue
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
|
||||
class StreamingGradioCallbackHandler(BaseCallbackHandler):
|
||||
"""
|
||||
Similar to H2OTextIteratorStreamer that is for HF backend, but here LangChain backend
|
||||
"""
|
||||
|
||||
def __init__(self, timeout: Optional[float] = None, block=True):
|
||||
super().__init__()
|
||||
self.text_queue = queue.SimpleQueue()
|
||||
self.stop_signal = None
|
||||
self.do_stop = False
|
||||
self.timeout = timeout
|
||||
self.block = block
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running. Clean the queue."""
|
||||
while not self.text_queue.empty():
|
||||
try:
|
||||
self.text_queue.get(block=False)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
self.text_queue.put(token)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.text_queue.put(self.stop_signal)
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.text_queue.put(self.stop_signal)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
while True:
|
||||
try:
|
||||
value = (
|
||||
self.stop_signal
|
||||
) # value looks unused in pycharm, not true
|
||||
if self.do_stop:
|
||||
print("hit stop", flush=True)
|
||||
# could raise or break, maybe best to raise and make parent see if any exception in thread
|
||||
raise StopIteration()
|
||||
# break
|
||||
value = self.text_queue.get(
|
||||
block=self.block, timeout=self.timeout
|
||||
)
|
||||
break
|
||||
except queue.Empty:
|
||||
time.sleep(0.01)
|
||||
if value == self.stop_signal:
|
||||
raise StopIteration()
|
||||
else:
|
||||
return value
|
||||
442
apps/language_models/scripts/llama_ir_conversion_utils.py
Normal file
442
apps/language_models/scripts/llama_ir_conversion_utils.py
Normal file
@@ -0,0 +1,442 @@
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
from argparse import RawTextHelpFormatter
|
||||
import re, gc
|
||||
|
||||
"""
|
||||
This script can be used as a standalone utility to convert IRs to dynamic + combine them.
|
||||
Following are the various ways this script can be used :-
|
||||
a. To convert a single Linalg IR to dynamic IR:
|
||||
--dynamic --first_ir_path=<PATH TO FIRST IR>
|
||||
b. To convert two Linalg IRs to dynamic IR:
|
||||
--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>
|
||||
c. To combine two Linalg IRs into one:
|
||||
--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
|
||||
d. To convert both IRs into dynamic as well as combine the IRs:
|
||||
--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
|
||||
|
||||
NOTE: For dynamic you'll also need to provide the following set of flags:-
|
||||
i. For First Llama : --dynamic_input_size (DEFAULT: 19)
|
||||
ii. For Second Llama: --model_name (DEFAULT: llama2_7b)
|
||||
--precision (DEFAULT: 'int4')
|
||||
You may use --save_dynamic to also save the dynamic IR in option d above.
|
||||
Else for option a. and b. the dynamic IR(s) will get saved by default.
|
||||
"""
|
||||
|
||||
|
||||
def combine_mlir_scripts(
|
||||
first_vicuna_mlir,
|
||||
second_vicuna_mlir,
|
||||
output_name,
|
||||
return_ir=True,
|
||||
):
|
||||
print(f"[DEBUG] combining first and second mlir")
|
||||
print(f"[DEBUG] output_name = {output_name}")
|
||||
maps1 = []
|
||||
maps2 = []
|
||||
constants = set()
|
||||
f1 = []
|
||||
f2 = []
|
||||
|
||||
print(f"[DEBUG] processing first vicuna mlir")
|
||||
first_vicuna_mlir = first_vicuna_mlir.splitlines()
|
||||
while first_vicuna_mlir:
|
||||
line = first_vicuna_mlir.pop(0)
|
||||
if re.search("#map\d*\s*=", line):
|
||||
maps1.append(line)
|
||||
elif re.search("arith.constant", line):
|
||||
constants.add(line)
|
||||
elif not re.search("module", line):
|
||||
line = re.sub("forward", "first_vicuna_forward", line)
|
||||
f1.append(line)
|
||||
f1 = f1[:-1]
|
||||
del first_vicuna_mlir
|
||||
gc.collect()
|
||||
|
||||
for i, map_line in enumerate(maps1):
|
||||
map_var = map_line.split(" ")[0]
|
||||
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line)
|
||||
maps1[i] = map_line
|
||||
f1 = [
|
||||
re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line)
|
||||
for func_line in f1
|
||||
]
|
||||
|
||||
print(f"[DEBUG] processing second vicuna mlir")
|
||||
second_vicuna_mlir = second_vicuna_mlir.splitlines()
|
||||
while second_vicuna_mlir:
|
||||
line = second_vicuna_mlir.pop(0)
|
||||
if re.search("#map\d*\s*=", line):
|
||||
maps2.append(line)
|
||||
elif "global_seed" in line:
|
||||
continue
|
||||
elif re.search("arith.constant", line):
|
||||
constants.add(line)
|
||||
elif not re.search("module", line):
|
||||
line = re.sub("forward", "second_vicuna_forward", line)
|
||||
f2.append(line)
|
||||
f2 = f2[:-1]
|
||||
del second_vicuna_mlir
|
||||
gc.collect()
|
||||
|
||||
for i, map_line in enumerate(maps2):
|
||||
map_var = map_line.split(" ")[0]
|
||||
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line)
|
||||
maps2[i] = map_line
|
||||
f2 = [
|
||||
re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line)
|
||||
for func_line in f2
|
||||
]
|
||||
|
||||
module_start = 'module attributes {torch.debug_module_name = "_lambda"} {'
|
||||
module_end = "}"
|
||||
|
||||
global_vars = []
|
||||
vnames = []
|
||||
global_var_loading1 = []
|
||||
global_var_loading2 = []
|
||||
|
||||
print(f"[DEBUG] processing constants")
|
||||
counter = 0
|
||||
constants = list(constants)
|
||||
while constants:
|
||||
constant = constants.pop(0)
|
||||
vname, vbody = constant.split("=")
|
||||
vname = re.sub("%", "", vname)
|
||||
vname = vname.strip()
|
||||
vbody = re.sub("arith.constant", "", vbody)
|
||||
vbody = vbody.strip()
|
||||
if len(vbody.split(":")) < 2:
|
||||
print(constant)
|
||||
vdtype = vbody.split(":")[-1].strip()
|
||||
fixed_vdtype = vdtype
|
||||
if "c1_i64" in vname:
|
||||
print(constant)
|
||||
counter += 1
|
||||
if counter == 2:
|
||||
counter = 0
|
||||
print("detected duplicate")
|
||||
continue
|
||||
vnames.append(vname)
|
||||
if "true" not in vname:
|
||||
global_vars.append(
|
||||
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
|
||||
)
|
||||
global_var_loading1.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
|
||||
)
|
||||
global_var_loading2.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
|
||||
)
|
||||
else:
|
||||
global_vars.append(
|
||||
f"ml_program.global private @{vname}({vbody}) : i1"
|
||||
)
|
||||
global_var_loading1.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
|
||||
)
|
||||
global_var_loading2.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
|
||||
)
|
||||
|
||||
new_f1, new_f2 = [], []
|
||||
|
||||
print(f"[DEBUG] processing f1")
|
||||
for line in f1:
|
||||
if "func.func" in line:
|
||||
new_f1.append(line)
|
||||
for global_var in global_var_loading1:
|
||||
new_f1.append(global_var)
|
||||
else:
|
||||
new_f1.append(line)
|
||||
|
||||
print(f"[DEBUG] processing f2")
|
||||
for line in f2:
|
||||
if "func.func" in line:
|
||||
new_f2.append(line)
|
||||
for global_var in global_var_loading2:
|
||||
if (
|
||||
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
|
||||
in global_var
|
||||
):
|
||||
print(global_var)
|
||||
new_f2.append(global_var)
|
||||
else:
|
||||
new_f2.append(line)
|
||||
|
||||
f1 = new_f1
|
||||
f2 = new_f2
|
||||
|
||||
del new_f1
|
||||
del new_f2
|
||||
gc.collect()
|
||||
|
||||
print(
|
||||
[
|
||||
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x
|
||||
for x in [maps1, maps2, global_vars, f1, f2]
|
||||
]
|
||||
)
|
||||
|
||||
# doing it this way rather than assembling the whole string
|
||||
# to prevent OOM with 64GiB RAM when encoding the file.
|
||||
|
||||
print(f"[DEBUG] Saving mlir to {output_name}")
|
||||
with open(output_name, "w+") as f_:
|
||||
f_.writelines(line + "\n" for line in maps1)
|
||||
f_.writelines(line + "\n" for line in maps2)
|
||||
f_.writelines(line + "\n" for line in [module_start])
|
||||
f_.writelines(line + "\n" for line in global_vars)
|
||||
f_.writelines(line + "\n" for line in f1)
|
||||
f_.writelines(line + "\n" for line in f2)
|
||||
f_.writelines(line + "\n" for line in [module_end])
|
||||
|
||||
del maps1
|
||||
del maps2
|
||||
del module_start
|
||||
del global_vars
|
||||
del f1
|
||||
del f2
|
||||
del module_end
|
||||
gc.collect()
|
||||
|
||||
if return_ir:
|
||||
print(f"[DEBUG] Reading combined mlir back in")
|
||||
with open(output_name, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def write_in_dynamic_inputs0(module, dynamic_input_size):
|
||||
print("[DEBUG] writing dynamic inputs to first vicuna")
|
||||
# Current solution for ensuring mlir files support dynamic inputs
|
||||
# TODO: find a more elegant way to implement this
|
||||
new_lines = []
|
||||
module = module.splitlines()
|
||||
while module:
|
||||
line = module.pop(0)
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim", line)
|
||||
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
|
||||
new_lines.append("%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>")
|
||||
if "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" in line:
|
||||
continue
|
||||
|
||||
new_lines.append(line)
|
||||
return "\n".join(new_lines)
|
||||
|
||||
|
||||
def write_in_dynamic_inputs1(module, model_name, precision):
|
||||
print("[DEBUG] writing dynamic inputs to second vicuna")
|
||||
|
||||
def remove_constant_dim(line):
|
||||
if "c19_i64" in line:
|
||||
line = re.sub("c19_i64", "dim_i64", line)
|
||||
if "19x" in line:
|
||||
line = re.sub("19x", "?x", line)
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)",
|
||||
"tensor.empty(%dim, %dim)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub("c19", "dim", line)
|
||||
if " 19," in line:
|
||||
line = re.sub(" 19,", " %dim,", line)
|
||||
if "x20x" in line or "<20x" in line:
|
||||
line = re.sub("20x", "?x", line)
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dimp1)", line)
|
||||
if " 20," in line:
|
||||
line = re.sub(" 20,", " %dimp1,", line)
|
||||
return line
|
||||
|
||||
module = module.splitlines()
|
||||
new_lines = []
|
||||
|
||||
# Using a while loop and the pop method to avoid creating a copy of module
|
||||
if "llama2_13b" in model_name:
|
||||
pkv_tensor_shape = "tensor<1x40x?x128x"
|
||||
elif "llama2_70b" in model_name:
|
||||
pkv_tensor_shape = "tensor<1x8x?x128x"
|
||||
else:
|
||||
pkv_tensor_shape = "tensor<1x32x?x128x"
|
||||
if precision in ["fp16", "int4", "int8"]:
|
||||
pkv_tensor_shape += "f16>"
|
||||
else:
|
||||
pkv_tensor_shape += "f32>"
|
||||
|
||||
while module:
|
||||
line = module.pop(0)
|
||||
if "%c19_i64 = arith.constant 19 : i64" in line:
|
||||
new_lines.append("%c2 = arith.constant 2 : index")
|
||||
new_lines.append(
|
||||
f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
|
||||
)
|
||||
new_lines.append(
|
||||
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
|
||||
)
|
||||
continue
|
||||
if "%c2 = arith.constant 2 : index" in line:
|
||||
continue
|
||||
if "%c20_i64 = arith.constant 20 : i64" in line:
|
||||
new_lines.append("%c1_i64 = arith.constant 1 : i64")
|
||||
new_lines.append("%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64")
|
||||
new_lines.append(
|
||||
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
|
||||
)
|
||||
continue
|
||||
line = remove_constant_dim(line)
|
||||
new_lines.append(line)
|
||||
|
||||
return "\n".join(new_lines)
|
||||
|
||||
|
||||
def save_dynamic_ir(ir_to_save, output_file):
|
||||
if not ir_to_save:
|
||||
return
|
||||
# We only get string output from the dynamic conversion utility.
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
with redirect_stdout(f):
|
||||
print(ir_to_save)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="llama ir utility",
|
||||
description="\tThis script can be used as a standalone utility to convert IRs to dynamic + combine them.\n"
|
||||
+ "\tFollowing are the various ways this script can be used :-\n"
|
||||
+ "\t\ta. To convert a single Linalg IR to dynamic IR:\n"
|
||||
+ "\t\t\t--dynamic --first_ir_path=<PATH TO FIRST IR>\n"
|
||||
+ "\t\tb. To convert two Linalg IRs to dynamic IR:\n"
|
||||
+ "\t\t\t--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>\n"
|
||||
+ "\t\tc. To combine two Linalg IRs into one:\n"
|
||||
+ "\t\t\t--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n"
|
||||
+ "\t\td. To convert both IRs into dynamic as well as combine the IRs:\n"
|
||||
+ "\t\t\t--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n\n"
|
||||
+ "\tNOTE: For dynamic you'll also need to provide the following set of flags:-\n"
|
||||
+ "\t\t i. For First Llama : --dynamic_input_size (DEFAULT: 19)\n"
|
||||
+ "\t\tii. For Second Llama: --model_name (DEFAULT: llama2_7b)\n"
|
||||
+ "\t\t\t--precision (DEFAULT: 'int4')\n"
|
||||
+ "\t You may use --save_dynamic to also save the dynamic IR in option d above.\n"
|
||||
+ "\t Else for option a. and b. the dynamic IR(s) will get saved by default.\n",
|
||||
formatter_class=RawTextHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
"-p",
|
||||
default="int4",
|
||||
choices=["fp32", "fp16", "int8", "int4"],
|
||||
help="Precision of the concerned IR",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="llama2_7b",
|
||||
choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
|
||||
help="Specify which model to run.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first_ir_path",
|
||||
default=None,
|
||||
help="path to first llama mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--second_ir_path",
|
||||
default=None,
|
||||
help="path to second llama mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dynamic_input_size",
|
||||
type=int,
|
||||
default=19,
|
||||
help="Specify the static input size to replace with dynamic dim.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dynamic",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Converts the IR(s) to dynamic",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_dynamic",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Save the individual IR(s) after converting to dynamic",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--combine",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Converts the IR(s) to dynamic",
|
||||
)
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
dynamic = args.dynamic
|
||||
combine = args.combine
|
||||
assert (
|
||||
dynamic or combine
|
||||
), "neither `dynamic` nor `combine` flag is turned on"
|
||||
first_ir_path = args.first_ir_path
|
||||
second_ir_path = args.second_ir_path
|
||||
assert first_ir_path or second_ir_path, "no input ir has been provided"
|
||||
if combine:
|
||||
assert (
|
||||
first_ir_path and second_ir_path
|
||||
), "you will need to provide both IRs to combine"
|
||||
precision = args.precision
|
||||
model_name = args.model_name
|
||||
dynamic_input_size = args.dynamic_input_size
|
||||
save_dynamic = args.save_dynamic
|
||||
|
||||
print(f"Dynamic conversion utility is turned {'ON' if dynamic else 'OFF'}")
|
||||
print(f"Combining IR utility is turned {'ON' if combine else 'OFF'}")
|
||||
|
||||
if dynamic and not combine:
|
||||
save_dynamic = True
|
||||
|
||||
first_ir = None
|
||||
first_dynamic_ir_name = None
|
||||
second_ir = None
|
||||
second_dynamic_ir_name = None
|
||||
if first_ir_path:
|
||||
first_dynamic_ir_name = f"{Path(first_ir_path).stem}_dynamic"
|
||||
with open(first_ir_path, "r") as f:
|
||||
first_ir = f.read()
|
||||
if second_ir_path:
|
||||
second_dynamic_ir_name = f"{Path(second_ir_path).stem}_dynamic"
|
||||
with open(second_ir_path, "r") as f:
|
||||
second_ir = f.read()
|
||||
if dynamic:
|
||||
first_ir = (
|
||||
write_in_dynamic_inputs0(first_ir, dynamic_input_size)
|
||||
if first_ir
|
||||
else None
|
||||
)
|
||||
second_ir = (
|
||||
write_in_dynamic_inputs1(second_ir, model_name, precision)
|
||||
if second_ir
|
||||
else None
|
||||
)
|
||||
if save_dynamic:
|
||||
save_dynamic_ir(first_ir, f"{first_dynamic_ir_name}.mlir")
|
||||
save_dynamic_ir(second_ir, f"{second_dynamic_ir_name}.mlir")
|
||||
|
||||
if combine:
|
||||
combine_mlir_scripts(
|
||||
first_ir,
|
||||
second_ir,
|
||||
f"{model_name}_{precision}.mlir",
|
||||
return_ir=False,
|
||||
)
|
||||
@@ -1,27 +1,15 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
pipeline,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
TextIteratorStreamer,
|
||||
)
|
||||
import time
|
||||
import numpy as np
|
||||
from torch.nn import functional as F
|
||||
import os
|
||||
from threading import Thread
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from shark.shark_downloader import download_public_file
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import (
|
||||
get_torch_mlir_module_bytecode,
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
@@ -51,142 +39,37 @@ def user(message, history):
|
||||
return "", history + [[message, ""]]
|
||||
|
||||
|
||||
def get_torch_mlir_module_bytecode(model, model_inputs):
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
# tracing_mode='symbolic',
|
||||
)(*model_inputs)
|
||||
print("Got FX_G")
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
def transform_fx(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
transform_fx(fx_g)
|
||||
fx_g.recompile()
|
||||
removed_none_indexes = _remove_nones(fx_g)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_g)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
print("FX_G recompile")
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
print("Got TS_G")
|
||||
return ts_g
|
||||
|
||||
|
||||
def compile_stableLM(model, model_inputs, model_name, model_vmfb_name):
|
||||
# ADD Device Arg
|
||||
def compile_stableLM(
|
||||
model,
|
||||
model_inputs,
|
||||
model_name,
|
||||
model_vmfb_name,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
debug=False,
|
||||
):
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
vmfb_path = Path(model_vmfb_name + ".vmfb")
|
||||
if vmfb_path.exists():
|
||||
print("Loading ", vmfb_path)
|
||||
shark_module = SharkInference(
|
||||
None, device="cuda", mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.load_module(vmfb_path)
|
||||
print("Successfully loaded vmfb")
|
||||
# device = "cuda" # "cpu"
|
||||
# TODO: vmfb and mlir name should include precision and device
|
||||
vmfb_path = (
|
||||
Path(model_name + f"_{device}.vmfb")
|
||||
if model_vmfb_name is None
|
||||
else Path(model_vmfb_name)
|
||||
)
|
||||
shark_module = get_vmfb_from_path(
|
||||
vmfb_path, device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
if shark_module is not None:
|
||||
return shark_module
|
||||
|
||||
mlir_path = Path(model_name + ".mlir")
|
||||
print(
|
||||
f"[DEBUG] mlir path { mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
|
||||
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path) as f:
|
||||
bytecode = f.read("rb")
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
|
||||
module = torch_mlir.compile(
|
||||
@@ -205,13 +88,13 @@ def compile_stableLM(model, model_inputs, model_name, model_vmfb_name):
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device="cuda", mlir_dialect="tm_tensor"
|
||||
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
import os
|
||||
|
||||
path = shark_module.save_module(os.getcwd(), model_vmfb_name, [])
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
return shark_module
|
||||
@@ -244,60 +127,85 @@ def get_tokenizer():
|
||||
model_path = "stabilityai/stablelm-tuned-alpha-3b"
|
||||
tok = AutoTokenizer.from_pretrained(model_path)
|
||||
tok.add_special_tokens({"pad_token": "<PAD>"})
|
||||
print(f"Sucessfully loaded the tokenizer to the memory")
|
||||
print("Sucessfully loaded the tokenizer to the memory")
|
||||
return tok
|
||||
|
||||
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
# sharkStableLM = compile_stableLM
|
||||
# (
|
||||
# None,
|
||||
# tuple([input_ids, attention_mask]),
|
||||
# "stableLM_linalg_f32_seqLen256",
|
||||
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
|
||||
# )
|
||||
def generate(
|
||||
new_text,
|
||||
streamer,
|
||||
max_new_tokens,
|
||||
do_sample,
|
||||
top_p,
|
||||
top_k,
|
||||
temperature,
|
||||
num_beams,
|
||||
stopping_criteria,
|
||||
sharkStableLM,
|
||||
tok=None,
|
||||
input_ids=torch.randint(3, (1, 256)),
|
||||
attention_mask=torch.randint(3, (1, 256)),
|
||||
tokenizer=None,
|
||||
):
|
||||
if tok == None:
|
||||
tok = get_tokenizer
|
||||
# Construct the input message string for the model by concatenating the current system message and conversation history
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer()
|
||||
# Construct the input message string for the model by
|
||||
# concatenating the current system message and conversation history
|
||||
# Tokenize the messages string
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
# sharkStableLM = compile_stableLM
|
||||
# (
|
||||
# None,
|
||||
# tuple([input_ids, attention_mask]),
|
||||
# "stableLM_linalg_f32_seqLen256",
|
||||
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
|
||||
# )
|
||||
words_list = []
|
||||
for i in range(max_new_tokens):
|
||||
numWords = len(new_text.split())
|
||||
# numWords = len(new_text.split())
|
||||
# if(numWords>220):
|
||||
# break
|
||||
model_inputs = tok(
|
||||
[new_text],
|
||||
padding="max_length",
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
params = {
|
||||
"new_text": new_text,
|
||||
}
|
||||
generated_token_op = generate_new_token(
|
||||
sharkStableLM, tokenizer, params
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
output = sharkStableLM(
|
||||
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
|
||||
)
|
||||
output = torch.from_numpy(output)
|
||||
next_toks = torch.topk(output, 1)
|
||||
if shouldStop(next_toks.indices):
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
if stop_generation:
|
||||
break
|
||||
# streamer.put(next_toks.indices[0][int(sum_attentionmask)-1])
|
||||
new_word = tok.decode(
|
||||
next_toks.indices[0][int(sum_attentionmask) - 1],
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
print(new_word, end="", flush=True)
|
||||
words_list.append(new_word)
|
||||
if new_word == "":
|
||||
print(detok, end="", flush=True)
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
new_text = new_text + new_word
|
||||
new_text = new_text + detok
|
||||
return words_list
|
||||
|
||||
|
||||
def generate_new_token(shark_model, tokenizer, params):
|
||||
new_text = params["new_text"]
|
||||
model_inputs = tokenizer(
|
||||
[new_text],
|
||||
padding="max_length",
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
output = shark_model(
|
||||
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
|
||||
)
|
||||
output = torch.from_numpy(output)
|
||||
next_toks = torch.topk(output, 1)
|
||||
stop_generation = False
|
||||
if shouldStop(next_toks.indices):
|
||||
stop_generation = True
|
||||
new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
|
||||
detok = tokenizer.decode(
|
||||
new_token,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
ret_dict = {
|
||||
"new_token": new_token,
|
||||
"detok": detok,
|
||||
"stop_generation": stop_generation,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,777 +0,0 @@
|
||||
import sys
|
||||
import warnings
|
||||
import gradio as gr
|
||||
import time
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
sys.path.insert(0, "D:\S\SB\I\python_packages\iree_compiler")
|
||||
sys.path.insert(0, "D:\S\SB\I\python_packages\iree_runtime")
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import transform_fx as transform_fx_
|
||||
import re
|
||||
from shark.shark_inference import SharkInference
|
||||
from tqdm import tqdm
|
||||
from torch_mlir import TensorPlaceholder
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
|
||||
|
||||
class FirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, hidden_states, attention_mask, position_ids):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class SecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=(
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
),
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class CompiledFirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CompiledSecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
pkv0 = past_key_value[0].detach()
|
||||
pkv1 = past_key_value[1].detach()
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv0,
|
||||
pkv1,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ShardedVicunaModel(torch.nn.Module):
|
||||
def __init__(self, model, layers0, layers1):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
assert len(layers0) == len(model.model.layers)
|
||||
# self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
|
||||
self.model.model.config.use_cache = True
|
||||
self.model.model.config.output_attentions = False
|
||||
self.layers0 = layers0
|
||||
self.layers1 = layers1
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
is_first=True,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
if is_first:
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers0
|
||||
)
|
||||
return self.model.forward(input_ids, attention_mask=attention_mask)
|
||||
else:
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers1
|
||||
)
|
||||
return self.model.forward(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
|
||||
def write_in_dynamic_inputs0(module, dynamic_input_size):
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim", line)
|
||||
new_lines.append(line)
|
||||
new_module = "\n".join(new_lines)
|
||||
return new_module
|
||||
|
||||
|
||||
def write_in_dynamic_inputs1(module, dynamic_input_size):
|
||||
new_lines = []
|
||||
for line in module.splitlines():
|
||||
if "dim_42 =" in line:
|
||||
continue
|
||||
if f"%c{dynamic_input_size}_i64 =" in line:
|
||||
new_lines.append(
|
||||
"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
|
||||
)
|
||||
new_lines.append(
|
||||
f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
|
||||
)
|
||||
continue
|
||||
line = re.sub(f"{dynamic_input_size}x", "?x", line)
|
||||
if "?x" in line:
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim_42)", line)
|
||||
line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim_42\)",
|
||||
"tensor.empty(%dim_42, %dim_42)",
|
||||
line,
|
||||
)
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
|
||||
new_lines.append(line)
|
||||
new_module = "\n".join(new_lines)
|
||||
return new_module
|
||||
|
||||
|
||||
def compile_vicuna_layer(
|
||||
vicuna_layer,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0=None,
|
||||
past_key_value1=None,
|
||||
):
|
||||
hidden_states_placeholder = TensorPlaceholder.like(
|
||||
hidden_states, dynamic_axes=[1]
|
||||
)
|
||||
attention_mask_placeholder = TensorPlaceholder.like(
|
||||
attention_mask, dynamic_axes=[2, 3]
|
||||
)
|
||||
position_ids_placeholder = TensorPlaceholder.like(
|
||||
position_ids, dynamic_axes=[1]
|
||||
)
|
||||
|
||||
if past_key_value0 is None and past_key_value1 is None:
|
||||
fx_g = make_fx(
|
||||
vicuna_layer,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(hidden_states, attention_mask, position_ids)
|
||||
|
||||
else:
|
||||
fx_g = make_fx(
|
||||
vicuna_layer,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
)
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
def transform_fx(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
transform_fx(fx_g)
|
||||
fx_g.recompile()
|
||||
removed_none_indexes = _remove_nones(fx_g)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_g)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
print("FX_G recompile")
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
return ts_g
|
||||
|
||||
|
||||
path = "TheBloke/vicuna-7B-1.1-HF"
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
|
||||
|
||||
|
||||
def compile_to_vmfb(inputs, layers, is_first=True):
|
||||
mlirs, modules = [], []
|
||||
for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
|
||||
if is_first:
|
||||
mlir_path = Path(f"{idx}_0.mlir")
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
else:
|
||||
mlir_path = Path(f"{idx}_1.mlir")
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if vmfb_path.exists():
|
||||
continue
|
||||
if mlir_path.exists():
|
||||
# print(f"Found layer {idx} mlir")
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
else:
|
||||
hidden_states_placeholder = TensorPlaceholder.like(
|
||||
inputs[0], dynamic_axes=[1]
|
||||
)
|
||||
attention_mask_placeholder = TensorPlaceholder.like(
|
||||
inputs[1], dynamic_axes=[3]
|
||||
)
|
||||
position_ids_placeholder = TensorPlaceholder.like(
|
||||
inputs[2], dynamic_axes=[1]
|
||||
)
|
||||
if not is_first:
|
||||
pkv0_placeholder = TensorPlaceholder.like(
|
||||
inputs[3], dynamic_axes=[2]
|
||||
)
|
||||
pkv1_placeholder = TensorPlaceholder.like(
|
||||
inputs[4], dynamic_axes=[2]
|
||||
)
|
||||
print(f"Compiling layer {idx} mlir")
|
||||
if is_first:
|
||||
ts_g = compile_vicuna_layer(
|
||||
layer, inputs[0], inputs[1], inputs[2]
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
hidden_states_placeholder,
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
else:
|
||||
ts_g = compile_vicuna_layer(
|
||||
layer,
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
inputs[3],
|
||||
inputs[4],
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
inputs[0],
|
||||
attention_mask_placeholder,
|
||||
inputs[2],
|
||||
pkv0_placeholder,
|
||||
pkv1_placeholder,
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# bytecode_stream = BytesIO()
|
||||
# module.operation.write_bytecode(bytecode_stream)
|
||||
# bytecode = bytecode_stream.getvalue()
|
||||
|
||||
if is_first:
|
||||
module = write_in_dynamic_inputs0(str(module), 137)
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
|
||||
else:
|
||||
module = write_in_dynamic_inputs1(str(module), 138)
|
||||
if idx in [0, 5, 6, 7]:
|
||||
module_str = module
|
||||
module_str = module_str.splitlines()
|
||||
new_lines = []
|
||||
for line in module_str:
|
||||
if len(line) < 1000:
|
||||
new_lines.append(line)
|
||||
else:
|
||||
new_lines.append(line[:999])
|
||||
module_str = "\n".join(new_lines)
|
||||
f1_ = open(f"{idx}_1_test.mlir", "w+")
|
||||
f1_.write(module_str)
|
||||
f1_.close()
|
||||
|
||||
bytecode = module.encode("UTF-8")
|
||||
bytecode_stream = BytesIO(bytecode)
|
||||
bytecode = bytecode_stream.read()
|
||||
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
mlirs.append(bytecode)
|
||||
|
||||
for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
|
||||
if is_first:
|
||||
vmfb_path = Path(f"{idx}_0.vmfb")
|
||||
if idx < 25:
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
mlirs[idx], device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_0",
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
else:
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if idx < 25:
|
||||
device = "vulkan"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
None, device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
module = SharkInference(
|
||||
mlirs[idx], device=device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_1",
|
||||
extra_args=[
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
|
||||
return mlirs, modules
|
||||
|
||||
|
||||
def get_sharded_model():
|
||||
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess
|
||||
# please don't change it
|
||||
SAMPLE_INPUT_LEN = 137
|
||||
global vicuna_model
|
||||
|
||||
placeholder_input0 = (
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
|
||||
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
|
||||
)
|
||||
|
||||
placeholder_input1 = (
|
||||
torch.zeros([1, 1, 4096]),
|
||||
torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
|
||||
torch.zeros([1, 1], dtype=torch.int64),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
)
|
||||
|
||||
layers0 = [FirstVicunaLayer(layer) for layer in vicuna_model.model.layers]
|
||||
_, modules0 = compile_to_vmfb(placeholder_input0, layers0, is_first=True)
|
||||
shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
|
||||
|
||||
layers1 = [SecondVicunaLayer(layer) for layer in vicuna_model.model.layers]
|
||||
_, modules1 = compile_to_vmfb(placeholder_input1, layers1, is_first=False)
|
||||
shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
|
||||
|
||||
sharded_model = ShardedVicunaModel(
|
||||
vicuna_model, shark_layers0, shark_layers1
|
||||
)
|
||||
return sharded_model
|
||||
|
||||
|
||||
sharded_model = get_sharded_model()
|
||||
|
||||
|
||||
def user(message, history):
|
||||
print("msg=", message)
|
||||
print("history=", history)
|
||||
# Append the user's message to the conversation history
|
||||
return "", history + [[message, ""]]
|
||||
|
||||
|
||||
def chat(curr_system_message, history):
|
||||
global sharded_model
|
||||
past_key_values = None
|
||||
messages = curr_system_message + "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
print(messages)
|
||||
prompt = messages.strip()
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
tokens = input_ids
|
||||
new_sentence = []
|
||||
max_response_len = 1000
|
||||
partial_sentence = []
|
||||
partial_text = ""
|
||||
start_time = time.time()
|
||||
for iteration in range(max_response_len):
|
||||
original_input_ids = input_ids
|
||||
input_id_len = len(input_ids)
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
|
||||
if iteration == 0:
|
||||
output = sharded_model.forward(input_ids, is_first=True)
|
||||
else:
|
||||
output = sharded_model.forward(
|
||||
input_ids, past_key_values=past_key_values, is_first=False
|
||||
)
|
||||
logits = output["logits"]
|
||||
past_key_values = output["past_key_values"]
|
||||
new_token = int(torch.argmax(logits[:, -1, :], dim=1)[0])
|
||||
if new_token == 2:
|
||||
break
|
||||
new_sentence += [new_token]
|
||||
partial_sentence += [new_token]
|
||||
if iteration > 0 and iteration % 2 == 0:
|
||||
new_text = tokenizer.decode(partial_sentence)
|
||||
partial_sentence = []
|
||||
print(new_text, " ")
|
||||
partial_text += new_text + " "
|
||||
history[-1][1] = partial_text
|
||||
yield history
|
||||
|
||||
tokens.append(new_token)
|
||||
original_input_ids.append(new_token)
|
||||
input_ids = [new_token]
|
||||
end_time = time.time()
|
||||
print(
|
||||
f"Total time taken to generated response is {end_time-start_time} seconds"
|
||||
)
|
||||
|
||||
for i in range(len(tokens)):
|
||||
if type(tokens[i]) != int:
|
||||
tokens[i] = int(tokens[i][0])
|
||||
new_sentence_str = tokenizer.decode(new_sentence)
|
||||
print(new_sentence_str)
|
||||
history[-1][1] = new_sentence_str
|
||||
return history
|
||||
|
||||
|
||||
system_msg = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
# history_eg = [["hi hello how are you", ""]]
|
||||
# print(chat(system_msg, history_eg))
|
||||
|
||||
with gr.Blocks(title="Chatbot") as vicuna_chat:
|
||||
with gr.Row():
|
||||
model = gr.Dropdown(
|
||||
label="Select Model",
|
||||
value="TheBloke/vicuna-7B-1.1-HF",
|
||||
choices=[
|
||||
"TheBloke/vicuna-7B-1.1-HF",
|
||||
],
|
||||
)
|
||||
device_value = None
|
||||
for d in available_devices:
|
||||
if "vulkan" in d:
|
||||
device_value = d
|
||||
break
|
||||
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=device_value if device_value else available_devices[0],
|
||||
interactive=False,
|
||||
choices=available_devices,
|
||||
)
|
||||
chatbot = gr.Chatbot().style(height=500)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
msg = gr.Textbox(
|
||||
label="Chat Message Box",
|
||||
placeholder="Chat Message Box",
|
||||
show_label=False,
|
||||
).style(container=False)
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
submit = gr.Button("Submit")
|
||||
stop = gr.Button("Stop")
|
||||
clear = gr.Button("Clear")
|
||||
system_msg = gr.Textbox(
|
||||
system_msg, label="System Message", interactive=False, visible=False
|
||||
)
|
||||
|
||||
submit_event = msg.submit(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
fn=None,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
cancels=[submit_event, submit_click_event],
|
||||
queue=False,
|
||||
)
|
||||
clear.click(lambda: None, None, [chatbot], queue=False)
|
||||
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
p.add_argument(
|
||||
"--share",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for generating a public URL",
|
||||
)
|
||||
p.add_argument(
|
||||
"--server_port",
|
||||
type=int,
|
||||
default=8080,
|
||||
help="flag for setting server port",
|
||||
)
|
||||
args, unknown = p.parse_known_args()
|
||||
|
||||
vicuna_chat.queue()
|
||||
vicuna_chat.launch(
|
||||
share=args.share,
|
||||
inbrowser=True,
|
||||
server_name="0.0.0.0",
|
||||
server_port=args.server_port,
|
||||
)
|
||||
94
apps/language_models/shark_llama_cli.spec
Normal file
94
apps/language_models/shark_llama_cli.spec
Normal file
@@ -0,0 +1,94 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import collect_submodules
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
|
||||
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
|
||||
datas = []
|
||||
datas += collect_data_files('torch')
|
||||
datas += copy_metadata('torch')
|
||||
datas += copy_metadata('tqdm')
|
||||
datas += copy_metadata('regex')
|
||||
datas += copy_metadata('requests')
|
||||
datas += copy_metadata('packaging')
|
||||
datas += copy_metadata('filelock')
|
||||
datas += copy_metadata('numpy')
|
||||
datas += copy_metadata('tokenizers')
|
||||
datas += copy_metadata('importlib_metadata')
|
||||
datas += copy_metadata('torch-mlir')
|
||||
datas += copy_metadata('omegaconf')
|
||||
datas += copy_metadata('safetensors')
|
||||
datas += copy_metadata('huggingface-hub')
|
||||
datas += copy_metadata('sentencepiece')
|
||||
datas += copy_metadata("pyyaml")
|
||||
datas += collect_data_files("tokenizers")
|
||||
datas += collect_data_files("tiktoken")
|
||||
datas += collect_data_files("accelerate")
|
||||
datas += collect_data_files('diffusers')
|
||||
datas += collect_data_files('transformers')
|
||||
datas += collect_data_files('opencv-python')
|
||||
datas += collect_data_files('pytorch_lightning')
|
||||
datas += collect_data_files('skimage')
|
||||
datas += collect_data_files('gradio')
|
||||
datas += collect_data_files('gradio_client')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
datas += collect_data_files('py-cpuinfo')
|
||||
datas += collect_data_files("shark", include_py_files=True)
|
||||
datas += collect_data_files("timm", include_py_files=True)
|
||||
datas += collect_data_files("tqdm")
|
||||
datas += collect_data_files("tkinter")
|
||||
datas += collect_data_files("webview")
|
||||
datas += collect_data_files("sentencepiece")
|
||||
datas += collect_data_files("jsonschema")
|
||||
datas += collect_data_files("jsonschema_specifications")
|
||||
datas += collect_data_files("cpuinfo")
|
||||
datas += collect_data_files("langchain")
|
||||
|
||||
binaries = []
|
||||
|
||||
block_cipher = None
|
||||
|
||||
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['scripts/vicuna.py'],
|
||||
pathex=['.'],
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=hiddenimports,
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
excludes=[],
|
||||
win_no_prefer_redirects=False,
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
a.binaries,
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
[],
|
||||
name='shark_llama_cli',
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx_exclude=[],
|
||||
runtime_tmpdir=None,
|
||||
console=True,
|
||||
disable_windowed_traceback=False,
|
||||
argv_emulation=False,
|
||||
target_arch=None,
|
||||
codesign_identity=None,
|
||||
entitlements_file=None,
|
||||
)
|
||||
22
apps/language_models/src/model_wrappers/falcon_model.py
Normal file
22
apps/language_models/src/model_wrappers/falcon_model.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
|
||||
|
||||
class FalconModel(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
input_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": None,
|
||||
"use_cache": True,
|
||||
}
|
||||
output = self.model(
|
||||
**input_dict,
|
||||
return_dict=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
)[0]
|
||||
return output[:, -1, :]
|
||||
503
apps/language_models/src/model_wrappers/minigpt4.py
Normal file
503
apps/language_models/src/model_wrappers/minigpt4.py
Normal file
@@ -0,0 +1,503 @@
|
||||
import torch
|
||||
import dataclasses
|
||||
from enum import auto, Enum
|
||||
from typing import List, Any
|
||||
from transformers import StoppingCriteria
|
||||
|
||||
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
ret = super().forward(x.type(torch.float32))
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
class VisionModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ln_vision,
|
||||
visual_encoder,
|
||||
precision="fp32",
|
||||
weight_group_size=128,
|
||||
):
|
||||
super().__init__()
|
||||
self.ln_vision = ln_vision
|
||||
self.visual_encoder = visual_encoder
|
||||
if precision in ["int4", "int8"]:
|
||||
print("Vision Model applying weight quantization to ln_vision")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
self.ln_vision,
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
print(
|
||||
"Vision Model applying weight quantization to visual_encoder"
|
||||
)
|
||||
quantize_model(
|
||||
self.visual_encoder,
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
def forward(self, image):
|
||||
image_embeds = self.ln_vision(self.visual_encoder(image))
|
||||
return image_embeds
|
||||
|
||||
|
||||
class QformerBertModel(torch.nn.Module):
|
||||
def __init__(self, qformer_bert):
|
||||
super().__init__()
|
||||
self.qformer_bert = qformer_bert
|
||||
|
||||
def forward(self, query_tokens, image_embeds, image_atts):
|
||||
query_output = self.qformer_bert(
|
||||
query_embeds=query_tokens,
|
||||
encoder_hidden_states=image_embeds,
|
||||
encoder_attention_mask=image_atts,
|
||||
return_dict=True,
|
||||
)
|
||||
return query_output.last_hidden_state
|
||||
|
||||
|
||||
class FirstLlamaModel(torch.nn.Module):
|
||||
def __init__(self, model, precision="fp32", weight_group_size=128):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
print("SHARK: Loading LLAMA Done")
|
||||
if precision in ["int4", "int8"]:
|
||||
print("First Llama applying weight quantization")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
self.model,
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
def forward(self, inputs_embeds, position_ids, attention_mask):
|
||||
print("************************************")
|
||||
print(
|
||||
"inputs_embeds: ",
|
||||
inputs_embeds.shape,
|
||||
" dtype: ",
|
||||
inputs_embeds.dtype,
|
||||
)
|
||||
print(
|
||||
"position_ids: ",
|
||||
position_ids.shape,
|
||||
" dtype: ",
|
||||
position_ids.dtype,
|
||||
)
|
||||
print(
|
||||
"attention_mask: ",
|
||||
attention_mask.shape,
|
||||
" dtype: ",
|
||||
attention_mask.dtype,
|
||||
)
|
||||
print("************************************")
|
||||
config = {
|
||||
"inputs_embeds": inputs_embeds,
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": None,
|
||||
"use_cache": True,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
output = self.model(
|
||||
**config,
|
||||
return_dict=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
return_vals = []
|
||||
return_vals.append(output.logits)
|
||||
temp_past_key_values = output.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class SecondLlamaModel(torch.nn.Module):
|
||||
def __init__(self, model, precision="fp32", weight_group_size=128):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
print("SHARK: Loading LLAMA Done")
|
||||
if precision in ["int4", "int8"]:
|
||||
print("Second Llama applying weight quantization")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
self.model,
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
i1,
|
||||
i2,
|
||||
i3,
|
||||
i4,
|
||||
i5,
|
||||
i6,
|
||||
i7,
|
||||
i8,
|
||||
i9,
|
||||
i10,
|
||||
i11,
|
||||
i12,
|
||||
i13,
|
||||
i14,
|
||||
i15,
|
||||
i16,
|
||||
i17,
|
||||
i18,
|
||||
i19,
|
||||
i20,
|
||||
i21,
|
||||
i22,
|
||||
i23,
|
||||
i24,
|
||||
i25,
|
||||
i26,
|
||||
i27,
|
||||
i28,
|
||||
i29,
|
||||
i30,
|
||||
i31,
|
||||
i32,
|
||||
i33,
|
||||
i34,
|
||||
i35,
|
||||
i36,
|
||||
i37,
|
||||
i38,
|
||||
i39,
|
||||
i40,
|
||||
i41,
|
||||
i42,
|
||||
i43,
|
||||
i44,
|
||||
i45,
|
||||
i46,
|
||||
i47,
|
||||
i48,
|
||||
i49,
|
||||
i50,
|
||||
i51,
|
||||
i52,
|
||||
i53,
|
||||
i54,
|
||||
i55,
|
||||
i56,
|
||||
i57,
|
||||
i58,
|
||||
i59,
|
||||
i60,
|
||||
i61,
|
||||
i62,
|
||||
i63,
|
||||
i64,
|
||||
):
|
||||
print("************************************")
|
||||
print("input_ids: ", input_ids.shape, " dtype: ", input_ids.dtype)
|
||||
print(
|
||||
"position_ids: ",
|
||||
position_ids.shape,
|
||||
" dtype: ",
|
||||
position_ids.dtype,
|
||||
)
|
||||
print(
|
||||
"attention_mask: ",
|
||||
attention_mask.shape,
|
||||
" dtype: ",
|
||||
attention_mask.dtype,
|
||||
)
|
||||
print("past_key_values: ", i1.shape, i2.shape, i63.shape, i64.shape)
|
||||
print("past_key_values dtype: ", i1.dtype)
|
||||
print("************************************")
|
||||
config = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": (
|
||||
(i1, i2),
|
||||
(
|
||||
i3,
|
||||
i4,
|
||||
),
|
||||
(
|
||||
i5,
|
||||
i6,
|
||||
),
|
||||
(
|
||||
i7,
|
||||
i8,
|
||||
),
|
||||
(
|
||||
i9,
|
||||
i10,
|
||||
),
|
||||
(
|
||||
i11,
|
||||
i12,
|
||||
),
|
||||
(
|
||||
i13,
|
||||
i14,
|
||||
),
|
||||
(
|
||||
i15,
|
||||
i16,
|
||||
),
|
||||
(
|
||||
i17,
|
||||
i18,
|
||||
),
|
||||
(
|
||||
i19,
|
||||
i20,
|
||||
),
|
||||
(
|
||||
i21,
|
||||
i22,
|
||||
),
|
||||
(
|
||||
i23,
|
||||
i24,
|
||||
),
|
||||
(
|
||||
i25,
|
||||
i26,
|
||||
),
|
||||
(
|
||||
i27,
|
||||
i28,
|
||||
),
|
||||
(
|
||||
i29,
|
||||
i30,
|
||||
),
|
||||
(
|
||||
i31,
|
||||
i32,
|
||||
),
|
||||
(
|
||||
i33,
|
||||
i34,
|
||||
),
|
||||
(
|
||||
i35,
|
||||
i36,
|
||||
),
|
||||
(
|
||||
i37,
|
||||
i38,
|
||||
),
|
||||
(
|
||||
i39,
|
||||
i40,
|
||||
),
|
||||
(
|
||||
i41,
|
||||
i42,
|
||||
),
|
||||
(
|
||||
i43,
|
||||
i44,
|
||||
),
|
||||
(
|
||||
i45,
|
||||
i46,
|
||||
),
|
||||
(
|
||||
i47,
|
||||
i48,
|
||||
),
|
||||
(
|
||||
i49,
|
||||
i50,
|
||||
),
|
||||
(
|
||||
i51,
|
||||
i52,
|
||||
),
|
||||
(
|
||||
i53,
|
||||
i54,
|
||||
),
|
||||
(
|
||||
i55,
|
||||
i56,
|
||||
),
|
||||
(
|
||||
i57,
|
||||
i58,
|
||||
),
|
||||
(
|
||||
i59,
|
||||
i60,
|
||||
),
|
||||
(
|
||||
i61,
|
||||
i62,
|
||||
),
|
||||
(
|
||||
i63,
|
||||
i64,
|
||||
),
|
||||
),
|
||||
"use_cache": True,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
output = self.model(
|
||||
**config,
|
||||
return_dict=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
return_vals = []
|
||||
return_vals.append(output.logits)
|
||||
temp_past_key_values = output.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
"""Different separator style."""
|
||||
|
||||
SINGLE = auto()
|
||||
TWO = auto()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Conversation:
|
||||
"""A class that keeps all conversation history."""
|
||||
|
||||
system: str
|
||||
roles: List[str]
|
||||
messages: List[List[str]]
|
||||
offset: int
|
||||
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
||||
sep: str = "###"
|
||||
sep2: str = None
|
||||
|
||||
skip_next: bool = False
|
||||
conv_id: Any = None
|
||||
|
||||
def get_prompt(self):
|
||||
if self.sep_style == SeparatorStyle.SINGLE:
|
||||
ret = self.system + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + ": " + message + self.sep
|
||||
else:
|
||||
ret += role + ":"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = self.system + seps[0]
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
ret += role + ": " + message + seps[i % 2]
|
||||
else:
|
||||
ret += role + ":"
|
||||
return ret
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
def append_message(self, role, message):
|
||||
self.messages.append([role, message])
|
||||
|
||||
def to_gradio_chatbot(self):
|
||||
ret = []
|
||||
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
||||
if i % 2 == 0:
|
||||
ret.append([msg, None])
|
||||
else:
|
||||
ret[-1][-1] = msg
|
||||
return ret
|
||||
|
||||
def copy(self):
|
||||
return Conversation(
|
||||
system=self.system,
|
||||
roles=self.roles,
|
||||
messages=[[x, y] for x, y in self.messages],
|
||||
offset=self.offset,
|
||||
sep_style=self.sep_style,
|
||||
sep=self.sep,
|
||||
sep2=self.sep2,
|
||||
conv_id=self.conv_id,
|
||||
)
|
||||
|
||||
def dict(self):
|
||||
return {
|
||||
"system": self.system,
|
||||
"roles": self.roles,
|
||||
"messages": self.messages,
|
||||
"offset": self.offset,
|
||||
"sep": self.sep,
|
||||
"sep2": self.sep2,
|
||||
"conv_id": self.conv_id,
|
||||
}
|
||||
|
||||
|
||||
class StoppingCriteriaSub(StoppingCriteria):
|
||||
def __init__(self, stops=[], encounters=1):
|
||||
super().__init__()
|
||||
self.stops = stops
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
||||
for stop in self.stops:
|
||||
if torch.all((stop == input_ids[0][-len(stop) :])).item():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
CONV_VISION = Conversation(
|
||||
system="Give the following image: <Img>ImageContent</Img>. "
|
||||
"You will be able to see the image once I provide it to you. Please answer my questions.",
|
||||
roles=("Human", "Assistant"),
|
||||
messages=[],
|
||||
offset=2,
|
||||
sep_style=SeparatorStyle.SINGLE,
|
||||
sep="###",
|
||||
)
|
||||
15
apps/language_models/src/model_wrappers/stablelm_model.py
Normal file
15
apps/language_models/src/model_wrappers/stablelm_model.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch
|
||||
|
||||
|
||||
class StableLMModel(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
combine_input_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
output = self.model(**combine_input_dict)
|
||||
return output.logits
|
||||
876
apps/language_models/src/model_wrappers/vicuna4.py
Normal file
876
apps/language_models/src/model_wrappers/vicuna4.py
Normal file
@@ -0,0 +1,876 @@
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import iree.runtime
|
||||
import itertools
|
||||
import subprocess
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir import TensorPlaceholder
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
LlamaPreTrainedModel,
|
||||
)
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
)
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
|
||||
FirstVicunaLayer,
|
||||
SecondVicunaLayer,
|
||||
CompiledVicunaLayer,
|
||||
ShardedVicunaModel,
|
||||
LMHead,
|
||||
LMHeadCompiled,
|
||||
VicunaEmbedding,
|
||||
VicunaEmbeddingCompiled,
|
||||
VicunaNorm,
|
||||
VicunaNormCompiled,
|
||||
)
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
FirstVicuna,
|
||||
SecondVicuna7B,
|
||||
)
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import get_f16_inputs
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer,
|
||||
LlamaRMSNorm,
|
||||
_make_causal_mask,
|
||||
_expand_mask,
|
||||
)
|
||||
from torch import nn
|
||||
from time import time
|
||||
|
||||
|
||||
class LlamaModel(LlamaPreTrainedModel):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
|
||||
|
||||
Args:
|
||||
config: LlamaConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(
|
||||
config.vocab_size, config.hidden_size, self.padding_idx
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
LlamaDecoderLayer(config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
self,
|
||||
attention_mask,
|
||||
input_shape,
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
):
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape,
|
||||
inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
expanded_attn_mask = _expand_mask(
|
||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
).to(inputs_embeds.device)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask
|
||||
if combined_attention_mask is None
|
||||
else expanded_attn_mask + combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
t1 = time()
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = (
|
||||
use_cache if use_cache is not None else self.config.use_cache
|
||||
)
|
||||
|
||||
return_dict = (
|
||||
return_dict
|
||||
if return_dict is not None
|
||||
else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = (
|
||||
seq_length_with_past + past_key_values_length
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
device = (
|
||||
input_ids.device
|
||||
if input_ids is not None
|
||||
else inputs_embeds.device
|
||||
)
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.compressedlayers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = (
|
||||
past_key_values[8 * idx : 8 * (idx + 1)]
|
||||
if past_key_values is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer.forward(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[1:],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
try:
|
||||
hidden_states = np.asarray(hidden_states, hidden_states.dtype)
|
||||
except:
|
||||
_ = 10
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
next_cache = tuple(itertools.chain.from_iterable(next_cache))
|
||||
print(f"Token generated in {time() - t1} seconds")
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class EightLayerLayerSV(torch.nn.Module):
|
||||
def __init__(self, layers):
|
||||
super().__init__()
|
||||
assert len(layers) == 8
|
||||
self.layers = layers
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv00,
|
||||
pkv01,
|
||||
pkv10,
|
||||
pkv11,
|
||||
pkv20,
|
||||
pkv21,
|
||||
pkv30,
|
||||
pkv31,
|
||||
pkv40,
|
||||
pkv41,
|
||||
pkv50,
|
||||
pkv51,
|
||||
pkv60,
|
||||
pkv61,
|
||||
pkv70,
|
||||
pkv71,
|
||||
):
|
||||
pkvs = [
|
||||
(pkv00, pkv01),
|
||||
(pkv10, pkv11),
|
||||
(pkv20, pkv21),
|
||||
(pkv30, pkv31),
|
||||
(pkv40, pkv41),
|
||||
(pkv50, pkv51),
|
||||
(pkv60, pkv61),
|
||||
(pkv70, pkv71),
|
||||
]
|
||||
new_pkvs = []
|
||||
for layer, pkv in zip(self.layers, pkvs):
|
||||
outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=(
|
||||
pkv[0],
|
||||
pkv[1],
|
||||
),
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
new_pkvs.append(
|
||||
(
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
)
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
) = new_pkvs
|
||||
return (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
)
|
||||
|
||||
|
||||
class EightLayerLayerFV(torch.nn.Module):
|
||||
def __init__(self, layers):
|
||||
super().__init__()
|
||||
assert len(layers) == 8
|
||||
self.layers = layers
|
||||
|
||||
def forward(self, hidden_states, attention_mask, position_ids):
|
||||
new_pkvs = []
|
||||
for layer in self.layers:
|
||||
outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=None,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
new_pkvs.append(
|
||||
(
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
)
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
) = new_pkvs
|
||||
return (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
)
|
||||
|
||||
|
||||
class CompiledEightLayerLayerSV(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
(
|
||||
(pkv00, pkv01),
|
||||
(pkv10, pkv11),
|
||||
(pkv20, pkv21),
|
||||
(pkv30, pkv31),
|
||||
(pkv40, pkv41),
|
||||
(pkv50, pkv51),
|
||||
(pkv60, pkv61),
|
||||
(pkv70, pkv71),
|
||||
) = past_key_value
|
||||
pkv00 = pkv00.detatch()
|
||||
pkv01 = pkv01.detatch()
|
||||
pkv10 = pkv10.detatch()
|
||||
pkv11 = pkv11.detatch()
|
||||
pkv20 = pkv20.detatch()
|
||||
pkv21 = pkv21.detatch()
|
||||
pkv30 = pkv30.detatch()
|
||||
pkv31 = pkv31.detatch()
|
||||
pkv40 = pkv40.detatch()
|
||||
pkv41 = pkv41.detatch()
|
||||
pkv50 = pkv50.detatch()
|
||||
pkv51 = pkv51.detatch()
|
||||
pkv60 = pkv60.detatch()
|
||||
pkv61 = pkv61.detatch()
|
||||
pkv70 = pkv70.detatch()
|
||||
pkv71 = pkv71.detatch()
|
||||
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv00,
|
||||
pkv01,
|
||||
pkv10,
|
||||
pkv11,
|
||||
pkv20,
|
||||
pkv21,
|
||||
pkv30,
|
||||
pkv31,
|
||||
pkv40,
|
||||
pkv41,
|
||||
pkv50,
|
||||
pkv51,
|
||||
pkv60,
|
||||
pkv61,
|
||||
pkv70,
|
||||
pkv71,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
return (
|
||||
output[0],
|
||||
(output[1][0], output[1][1]),
|
||||
(output[2][0], output[2][1]),
|
||||
(output[3][0], output[3][1]),
|
||||
(output[4][0], output[4][1]),
|
||||
(output[5][0], output[5][1]),
|
||||
(output[6][0], output[6][1]),
|
||||
(output[7][0], output[7][1]),
|
||||
(output[8][0], output[8][1]),
|
||||
)
|
||||
|
||||
|
||||
def forward_compressed(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
device = (
|
||||
input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
)
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.compressedlayers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = (
|
||||
past_key_values[8 * idx : 8 * (idx + 1)]
|
||||
if past_key_values is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (
|
||||
layer_outputs[2 if output_attentions else 1],
|
||||
)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class CompiledEightLayerLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
t2 = time()
|
||||
if past_key_value is None:
|
||||
try:
|
||||
hidden_states = np.asarray(hidden_states, hidden_states.dtype)
|
||||
except:
|
||||
pass
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
t1 = time()
|
||||
|
||||
output = self.model(
|
||||
"first_vicuna_forward",
|
||||
(hidden_states, attention_mask, position_ids),
|
||||
send_to_host=False,
|
||||
)
|
||||
output2 = (
|
||||
output[0],
|
||||
(
|
||||
output[1],
|
||||
output[2],
|
||||
),
|
||||
(
|
||||
output[3],
|
||||
output[4],
|
||||
),
|
||||
(
|
||||
output[5],
|
||||
output[6],
|
||||
),
|
||||
(
|
||||
output[7],
|
||||
output[8],
|
||||
),
|
||||
(
|
||||
output[9],
|
||||
output[10],
|
||||
),
|
||||
(
|
||||
output[11],
|
||||
output[12],
|
||||
),
|
||||
(
|
||||
output[13],
|
||||
output[14],
|
||||
),
|
||||
(
|
||||
output[15],
|
||||
output[16],
|
||||
),
|
||||
)
|
||||
return output2
|
||||
else:
|
||||
(
|
||||
(pkv00, pkv01),
|
||||
(pkv10, pkv11),
|
||||
(pkv20, pkv21),
|
||||
(pkv30, pkv31),
|
||||
(pkv40, pkv41),
|
||||
(pkv50, pkv51),
|
||||
(pkv60, pkv61),
|
||||
(pkv70, pkv71),
|
||||
) = past_key_value
|
||||
|
||||
try:
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
pkv00 = pkv00.detach()
|
||||
pkv01 = pkv01.detach()
|
||||
pkv10 = pkv10.detach()
|
||||
pkv11 = pkv11.detach()
|
||||
pkv20 = pkv20.detach()
|
||||
pkv21 = pkv21.detach()
|
||||
pkv30 = pkv30.detach()
|
||||
pkv31 = pkv31.detach()
|
||||
pkv40 = pkv40.detach()
|
||||
pkv41 = pkv41.detach()
|
||||
pkv50 = pkv50.detach()
|
||||
pkv51 = pkv51.detach()
|
||||
pkv60 = pkv60.detach()
|
||||
pkv61 = pkv61.detach()
|
||||
pkv70 = pkv70.detach()
|
||||
pkv71 = pkv71.detach()
|
||||
except:
|
||||
x = 10
|
||||
|
||||
t1 = time()
|
||||
if type(hidden_states) == iree.runtime.array_interop.DeviceArray:
|
||||
hidden_states = np.array(hidden_states, hidden_states.dtype)
|
||||
hidden_states = torch.tensor(hidden_states)
|
||||
hidden_states = hidden_states.detach()
|
||||
|
||||
output = self.model(
|
||||
"second_vicuna_forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv00,
|
||||
pkv01,
|
||||
pkv10,
|
||||
pkv11,
|
||||
pkv20,
|
||||
pkv21,
|
||||
pkv30,
|
||||
pkv31,
|
||||
pkv40,
|
||||
pkv41,
|
||||
pkv50,
|
||||
pkv51,
|
||||
pkv60,
|
||||
pkv61,
|
||||
pkv70,
|
||||
pkv71,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
print(f"{time() - t1}")
|
||||
del pkv00
|
||||
del pkv01
|
||||
del pkv10
|
||||
del pkv11
|
||||
del pkv20
|
||||
del pkv21
|
||||
del pkv30
|
||||
del pkv31
|
||||
del pkv40
|
||||
del pkv41
|
||||
del pkv50
|
||||
del pkv51
|
||||
del pkv60
|
||||
del pkv61
|
||||
del pkv70
|
||||
del pkv71
|
||||
output2 = (
|
||||
output[0],
|
||||
(
|
||||
output[1],
|
||||
output[2],
|
||||
),
|
||||
(
|
||||
output[3],
|
||||
output[4],
|
||||
),
|
||||
(
|
||||
output[5],
|
||||
output[6],
|
||||
),
|
||||
(
|
||||
output[7],
|
||||
output[8],
|
||||
),
|
||||
(
|
||||
output[9],
|
||||
output[10],
|
||||
),
|
||||
(
|
||||
output[11],
|
||||
output[12],
|
||||
),
|
||||
(
|
||||
output[13],
|
||||
output[14],
|
||||
),
|
||||
(
|
||||
output[15],
|
||||
output[16],
|
||||
),
|
||||
)
|
||||
return output2
|
||||
1170
apps/language_models/src/model_wrappers/vicuna_model.py
Normal file
1170
apps/language_models/src/model_wrappers/vicuna_model.py
Normal file
File diff suppressed because it is too large
Load Diff
1165
apps/language_models/src/model_wrappers/vicuna_model_gpu.py
Normal file
1165
apps/language_models/src/model_wrappers/vicuna_model_gpu.py
Normal file
File diff suppressed because it is too large
Load Diff
231
apps/language_models/src/model_wrappers/vicuna_sharded_model.py
Normal file
231
apps/language_models/src/model_wrappers/vicuna_sharded_model.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import torch
|
||||
|
||||
|
||||
class FirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, hidden_states, attention_mask, position_ids):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class SecondVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
):
|
||||
outputs = self.model(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=(
|
||||
past_key_value0,
|
||||
past_key_value1,
|
||||
),
|
||||
use_cache=True,
|
||||
)
|
||||
next_hidden_states = outputs[0]
|
||||
past_key_value_out0, past_key_value_out1 = (
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
|
||||
return (
|
||||
next_hidden_states,
|
||||
past_key_value_out0,
|
||||
past_key_value_out1,
|
||||
)
|
||||
|
||||
|
||||
class ShardedVicunaModel(torch.nn.Module):
|
||||
def __init__(self, model, layers, lmhead, embedding, norm):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
# assert len(layers) == len(model.model.layers)
|
||||
self.model.model.config.use_cache = True
|
||||
self.model.model.config.output_attentions = False
|
||||
self.layers = layers
|
||||
self.norm = norm
|
||||
self.embedding = embedding
|
||||
self.lmhead = lmhead
|
||||
self.model.model.norm = self.norm
|
||||
self.model.model.embed_tokens = self.embedding
|
||||
self.model.lm_head = self.lmhead
|
||||
self.model.model.layers = torch.nn.modules.container.ModuleList(
|
||||
self.layers
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
is_first=True,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
return self.model.forward(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
|
||||
class LMHead(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, hidden_states):
|
||||
output = self.model(hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
class LMHeadCompiled(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.detach()
|
||||
output = self.model("forward", (hidden_states,))
|
||||
output = torch.tensor(output)
|
||||
return output
|
||||
|
||||
|
||||
class VicunaNorm(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, hidden_states):
|
||||
output = self.model(hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
class VicunaNormCompiled(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(self, hidden_states):
|
||||
try:
|
||||
hidden_states.detach()
|
||||
except:
|
||||
pass
|
||||
output = self.model("forward", (hidden_states,))
|
||||
output = torch.tensor(output)
|
||||
return output
|
||||
|
||||
|
||||
class VicunaEmbedding(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, input_ids):
|
||||
output = self.model(input_ids)
|
||||
return output
|
||||
|
||||
|
||||
class VicunaEmbeddingCompiled(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(self, input_ids):
|
||||
input_ids.detach()
|
||||
output = self.model("forward", (input_ids,))
|
||||
output = torch.tensor(output)
|
||||
return output
|
||||
|
||||
|
||||
class CompiledVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
if past_key_value is None:
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
output = self.model(
|
||||
"first_vicuna_forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
else:
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
pkv0 = past_key_value[0].detach()
|
||||
pkv1 = past_key_value[1].detach()
|
||||
output = self.model(
|
||||
"second_vicuna_forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv0,
|
||||
pkv1,
|
||||
),
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
|
||||
return (
|
||||
output0,
|
||||
(
|
||||
output1,
|
||||
output2,
|
||||
),
|
||||
)
|
||||
44
apps/language_models/src/pipelines/SharkLLMBase.py
Normal file
44
apps/language_models/src/pipelines/SharkLLMBase.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class SharkLLMBase(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path=None,
|
||||
max_num_tokens=512,
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self.hf_model_path = hf_model_path
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.shark_model = None
|
||||
self.device = "cpu"
|
||||
self.precision = "fp32"
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def compile(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def generate(self, prompt):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def generate_new_token(self, params):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_tokenizer(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_src_model(self):
|
||||
pass
|
||||
|
||||
def load_init_from_config(self):
|
||||
pass
|
||||
567
apps/language_models/src/pipelines/falcon_pipeline.py
Normal file
567
apps/language_models/src/pipelines/falcon_pipeline.py
Normal file
@@ -0,0 +1,567 @@
|
||||
from apps.language_models.src.model_wrappers.falcon_model import FalconModel
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from contextlib import redirect_stdout
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from shark.shark_inference import SharkInference
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig
|
||||
from transformers.generation import (
|
||||
GenerationConfig,
|
||||
LogitsProcessorList,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
import copy
|
||||
|
||||
import re
|
||||
import torch
|
||||
import torch_mlir
|
||||
import os
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="falcon runner",
|
||||
description="runs a falcon model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
|
||||
)
|
||||
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
|
||||
parser.add_argument(
|
||||
"--falcon_vmfb_path", default=None, help="path to falcon's vmfb"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--falcon_mlir_path",
|
||||
default=None,
|
||||
help="path to falcon's mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_precompiled_model",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="use the precompiled vmfb",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_mlir_from_shark_tank",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="download precompile mlir from shark tank",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cli",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Run model in cli mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf_auth_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify your own huggingface authentication token for falcon-180B model.",
|
||||
)
|
||||
|
||||
|
||||
class Falcon(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="tiiuae/falcon-7b-instruct",
|
||||
hf_auth_token: str = None,
|
||||
max_num_tokens=150,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
falcon_mlir_path=None,
|
||||
falcon_vmfb_path=None,
|
||||
debug=False,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
print("hf_model_path: ", self.hf_model_path)
|
||||
|
||||
if "180b" in self.model_name and hf_auth_token == None:
|
||||
raise ValueError(
|
||||
""" HF auth token required for falcon-180b. Pass it using
|
||||
--hf_auth_token flag. You can ask for the access to the model
|
||||
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
|
||||
)
|
||||
self.hf_auth_token = hf_auth_token
|
||||
self.max_padding_length = 100
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.falcon_vmfb_path = falcon_vmfb_path
|
||||
self.falcon_mlir_path = falcon_mlir_path
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.src_model = self.get_src_model()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path,
|
||||
trust_remote_code=True,
|
||||
token=self.hf_auth_token,
|
||||
)
|
||||
tokenizer.padding_side = "left"
|
||||
tokenizer.pad_token_id = 11
|
||||
return tokenizer
|
||||
|
||||
def get_src_model(self):
|
||||
print("Loading src model: ", self.model_name)
|
||||
kwargs = {
|
||||
"torch_dtype": torch.float,
|
||||
"trust_remote_code": True,
|
||||
"token": self.hf_auth_token,
|
||||
}
|
||||
if self.precision == "int4":
|
||||
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
|
||||
kwargs["quantization_config"] = quantization_config
|
||||
kwargs["load_gptq_on_cpu"] = True
|
||||
kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
|
||||
falcon_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
if self.precision == "int4":
|
||||
falcon_model = falcon_model.to(torch.float32)
|
||||
return falcon_model
|
||||
|
||||
def compile(self):
|
||||
if args.use_precompiled_model:
|
||||
if not self.falcon_vmfb_path.exists():
|
||||
# Downloading VMFB from shark_tank
|
||||
download_public_file(
|
||||
"gs://shark_tank/falcon/"
|
||||
+ "falcon_"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "_"
|
||||
+ self.precision
|
||||
+ "_"
|
||||
+ self.device
|
||||
+ ".vmfb",
|
||||
self.falcon_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
vmfb = get_vmfb_from_path(
|
||||
self.falcon_vmfb_path, self.device, "linalg"
|
||||
)
|
||||
if vmfb is not None:
|
||||
return vmfb
|
||||
|
||||
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
|
||||
if self.falcon_mlir_path.exists():
|
||||
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
|
||||
with open(self.falcon_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
mlir_generated = False
|
||||
print(
|
||||
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
if args.load_mlir_from_shark_tank:
|
||||
# Downloading MLIR from shark_tank
|
||||
print(f"[DEBUG] Trying to download mlir from shark_tank")
|
||||
download_public_file(
|
||||
"gs://shark_tank/falcon/"
|
||||
+ "falcon_"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "_"
|
||||
+ self.precision
|
||||
+ ".mlir",
|
||||
self.falcon_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.falcon_mlir_path.exists():
|
||||
print(
|
||||
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
|
||||
)
|
||||
mlir_generated = True
|
||||
|
||||
if not mlir_generated:
|
||||
print(f"[DEBUG] generating MLIR locally")
|
||||
compilation_input_ids = torch.randint(
|
||||
low=1, high=10000, size=(1, 100)
|
||||
)
|
||||
compilation_attention_mask = torch.ones(
|
||||
1, 100, dtype=torch.int64
|
||||
)
|
||||
falconCompileInput = (
|
||||
compilation_input_ids,
|
||||
compilation_attention_mask,
|
||||
)
|
||||
model = FalconModel(self.src_model)
|
||||
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
falconCompileInput,
|
||||
is_f16=self.precision in ["fp16", "int4"],
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
is_gptq=self.precision == "int4",
|
||||
)
|
||||
del model
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*falconCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
|
||||
print(f"[DEBUG] converting to bytecode")
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
f_ = open(self.falcon_mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
|
||||
f_.close()
|
||||
del bytecode
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=self.falcon_mlir_path,
|
||||
device=self.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
self.falcon_vmfb_path.parent.absolute(),
|
||||
self.falcon_vmfb_path.stem,
|
||||
extra_args=[
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
]
|
||||
+ [
|
||||
"--iree-llvmcpu-use-fast-min-max-ops",
|
||||
]
|
||||
if self.precision == "int4"
|
||||
else [],
|
||||
debug=self.debug,
|
||||
)
|
||||
print("Saved falcon vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
|
||||
return shark_module
|
||||
|
||||
def generate(self, prompt):
|
||||
model_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.max_padding_length,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
model_inputs["prompt_text"] = prompt
|
||||
|
||||
input_ids = model_inputs["input_ids"]
|
||||
attention_mask = model_inputs.get("attention_mask", None)
|
||||
|
||||
# Allow empty prompts
|
||||
if input_ids.shape[1] == 0:
|
||||
input_ids = None
|
||||
attention_mask = None
|
||||
in_b = 1
|
||||
else:
|
||||
in_b = input_ids.shape[0]
|
||||
|
||||
generate_kwargs = {
|
||||
"max_length": self.max_num_tokens,
|
||||
"do_sample": True,
|
||||
"top_k": 10,
|
||||
"num_return_sequences": 1,
|
||||
"eos_token_id": 11,
|
||||
}
|
||||
generate_kwargs["input_ids"] = input_ids
|
||||
generate_kwargs["attention_mask"] = attention_mask
|
||||
generation_config_ = GenerationConfig.from_model_config(
|
||||
self.src_model.config
|
||||
)
|
||||
generation_config = copy.deepcopy(generation_config_)
|
||||
model_kwargs = generation_config.update(**generate_kwargs)
|
||||
|
||||
logits_processor = LogitsProcessorList()
|
||||
stopping_criteria = StoppingCriteriaList()
|
||||
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
(
|
||||
inputs_tensor,
|
||||
model_input_name,
|
||||
model_kwargs,
|
||||
) = self.src_model._prepare_model_inputs(
|
||||
None, generation_config.bos_token_id, model_kwargs
|
||||
)
|
||||
batch_size = inputs_tensor.shape[0]
|
||||
|
||||
model_kwargs["output_attentions"] = generation_config.output_attentions
|
||||
model_kwargs[
|
||||
"output_hidden_states"
|
||||
] = generation_config.output_hidden_states
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
|
||||
input_ids = (
|
||||
inputs_tensor
|
||||
if model_input_name == "input_ids"
|
||||
else model_kwargs.pop("input_ids")
|
||||
)
|
||||
|
||||
self.logits_processor = self.src_model._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids.shape[-1],
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
self.stopping_criteria = self.src_model._get_stopping_criteria(
|
||||
generation_config=generation_config,
|
||||
stopping_criteria=stopping_criteria,
|
||||
)
|
||||
|
||||
self.logits_warper = self.src_model._get_logits_warper(
|
||||
generation_config
|
||||
)
|
||||
|
||||
(
|
||||
self.input_ids,
|
||||
self.model_kwargs,
|
||||
) = self.src_model._expand_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
expand_size=generation_config.num_return_sequences, # 1
|
||||
is_encoder_decoder=self.src_model.config.is_encoder_decoder, # False
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
self.eos_token_id_tensor = (
|
||||
torch.tensor(eos_token_id) if eos_token_id is not None else None
|
||||
)
|
||||
|
||||
self.pad_token_id = generation_config.pad_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
output_scores = generation_config.output_scores # False
|
||||
output_attentions = generation_config.output_attentions # False
|
||||
output_hidden_states = generation_config.output_hidden_states # False
|
||||
return_dict_in_generate = (
|
||||
generation_config.return_dict_in_generate # False
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
self.scores = (
|
||||
() if (return_dict_in_generate and output_scores) else None
|
||||
)
|
||||
decoder_attentions = (
|
||||
() if (return_dict_in_generate and output_attentions) else None
|
||||
)
|
||||
cross_attentions = (
|
||||
() if (return_dict_in_generate and output_attentions) else None
|
||||
)
|
||||
decoder_hidden_states = (
|
||||
() if (return_dict_in_generate and output_hidden_states) else None
|
||||
)
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
self.unfinished_sequences = torch.ones(
|
||||
input_ids.shape[0], dtype=torch.long, device=input_ids.device
|
||||
)
|
||||
|
||||
all_text = prompt
|
||||
|
||||
for i in range(self.max_num_tokens - 1):
|
||||
next_token = self.generate_new_token()
|
||||
new_word = self.tokenizer.decode(
|
||||
next_token.cpu().numpy(),
|
||||
add_special_tokens=False,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
|
||||
all_text = all_text + new_word
|
||||
|
||||
print(f"{new_word}", end="", flush=True)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if self.eos_token_id_tensor is not None:
|
||||
self.unfinished_sequences = self.unfinished_sequences.mul(
|
||||
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
|
||||
.ne(self.eos_token_id_tensor.unsqueeze(1))
|
||||
.prod(dim=0)
|
||||
)
|
||||
# stop when each sentence is finished
|
||||
if (
|
||||
self.unfinished_sequences.max() == 0
|
||||
or self.stopping_criteria(input_ids, self.scores)
|
||||
):
|
||||
break
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
return all_text
|
||||
|
||||
def generate_new_token(self):
|
||||
model_inputs = self.src_model.prepare_inputs_for_generation(
|
||||
self.input_ids, **self.model_kwargs
|
||||
)
|
||||
outputs = torch.from_numpy(
|
||||
self.shark_model(
|
||||
"forward",
|
||||
(model_inputs["input_ids"], model_inputs["attention_mask"]),
|
||||
)
|
||||
)
|
||||
if self.precision in ["fp16", "int4"]:
|
||||
outputs = outputs.to(dtype=torch.float32)
|
||||
next_token_logits = outputs
|
||||
|
||||
# pre-process distribution
|
||||
next_token_scores = self.logits_processor(
|
||||
self.input_ids, next_token_logits
|
||||
)
|
||||
next_token_scores = self.logits_warper(
|
||||
self.input_ids, next_token_scores
|
||||
)
|
||||
|
||||
# sample
|
||||
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
|
||||
|
||||
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if self.eos_token_id is not None:
|
||||
if self.pad_token_id is None:
|
||||
raise ValueError(
|
||||
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
|
||||
)
|
||||
next_token = (
|
||||
next_token * self.unfinished_sequences
|
||||
+ self.pad_token_id * (1 - self.unfinished_sequences)
|
||||
)
|
||||
|
||||
self.input_ids = torch.cat(
|
||||
[self.input_ids, next_token[:, None]], dim=-1
|
||||
)
|
||||
|
||||
self.model_kwargs["past_key_values"] = None
|
||||
if "attention_mask" in self.model_kwargs:
|
||||
attention_mask = self.model_kwargs["attention_mask"]
|
||||
self.model_kwargs["attention_mask"] = torch.cat(
|
||||
[
|
||||
attention_mask,
|
||||
attention_mask.new_ones((attention_mask.shape[0], 1)),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
self.input_ids = self.input_ids[:, 1:]
|
||||
self.model_kwargs["attention_mask"] = self.model_kwargs[
|
||||
"attention_mask"
|
||||
][:, 1:]
|
||||
|
||||
return next_token
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
falcon_mlir_path = (
|
||||
Path(
|
||||
"falcon_"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "_"
|
||||
+ args.precision
|
||||
+ ".mlir"
|
||||
)
|
||||
if args.falcon_mlir_path is None
|
||||
else Path(args.falcon_mlir_path)
|
||||
)
|
||||
falcon_vmfb_path = (
|
||||
Path(
|
||||
"falcon_"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "_"
|
||||
+ args.precision
|
||||
+ "_"
|
||||
+ args.device
|
||||
+ ".vmfb"
|
||||
)
|
||||
if args.falcon_vmfb_path is None
|
||||
else Path(args.falcon_vmfb_path)
|
||||
)
|
||||
|
||||
if args.precision == "int4":
|
||||
if args.falcon_variant_to_use == "180b":
|
||||
hf_model_path_value = "TheBloke/Falcon-180B-Chat-GPTQ"
|
||||
else:
|
||||
hf_model_path_value = (
|
||||
"TheBloke/falcon-"
|
||||
+ args.falcon_variant_to_use
|
||||
+ "-instruct-GPTQ"
|
||||
)
|
||||
else:
|
||||
if args.falcon_variant_to_use == "180b":
|
||||
hf_model_path_value = "tiiuae/falcon-180B-chat"
|
||||
else:
|
||||
hf_model_path_value = (
|
||||
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
|
||||
)
|
||||
|
||||
falcon = Falcon(
|
||||
model_name="falcon_" + args.falcon_variant_to_use,
|
||||
hf_model_path=hf_model_path_value,
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
falcon_mlir_path=falcon_mlir_path,
|
||||
falcon_vmfb_path=falcon_vmfb_path,
|
||||
)
|
||||
|
||||
import gc
|
||||
|
||||
default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
|
||||
continue_execution = True
|
||||
|
||||
print("\n-----\nScript executing for the following config: \n")
|
||||
print("Falcon Model: ", falcon.model_name)
|
||||
print("Precision: ", args.precision)
|
||||
print("Device: ", args.device)
|
||||
|
||||
while continue_execution:
|
||||
use_default_prompt = input(
|
||||
"\nDo you wish to use the default prompt text? Y/N ?: "
|
||||
)
|
||||
if use_default_prompt in ["Y", "y"]:
|
||||
prompt = default_prompt_text
|
||||
else:
|
||||
prompt = input("Please enter the prompt text: ")
|
||||
print("\nPrompt Text: ", prompt)
|
||||
|
||||
prompt_template = f"""A helpful assistant who helps the user with any questions asked.
|
||||
User: {prompt}
|
||||
Assistant:"""
|
||||
|
||||
res_str = falcon.generate(prompt_template)
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
print(
|
||||
"\n\n-----\nHere's the complete formatted result: \n\n",
|
||||
res_str,
|
||||
)
|
||||
continue_execution = input(
|
||||
"\nDo you wish to run script one more time? Y/N ?: "
|
||||
)
|
||||
continue_execution = (
|
||||
True if continue_execution in ["Y", "y"] else False
|
||||
)
|
||||
1449
apps/language_models/src/pipelines/minigpt4_pipeline.py
Normal file
1449
apps/language_models/src/pipelines/minigpt4_pipeline.py
Normal file
File diff suppressed because it is too large
Load Diff
1297
apps/language_models/src/pipelines/minigpt4_utils/Qformer.py
Normal file
1297
apps/language_models/src/pipelines/minigpt4_utils/Qformer.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
from omegaconf import OmegaConf
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
|
||||
class BaseProcessor:
|
||||
def __init__(self):
|
||||
self.transform = lambda x: x
|
||||
return
|
||||
|
||||
def __call__(self, item):
|
||||
return self.transform(item)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg=None):
|
||||
return cls()
|
||||
|
||||
def build(self, **kwargs):
|
||||
cfg = OmegaConf.create(kwargs)
|
||||
|
||||
return self.from_config(cfg)
|
||||
|
||||
|
||||
class BlipImageBaseProcessor(BaseProcessor):
|
||||
def __init__(self, mean=None, std=None):
|
||||
if mean is None:
|
||||
mean = (0.48145466, 0.4578275, 0.40821073)
|
||||
if std is None:
|
||||
std = (0.26862954, 0.26130258, 0.27577711)
|
||||
|
||||
self.normalize = transforms.Normalize(mean, std)
|
||||
|
||||
|
||||
class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
|
||||
def __init__(self, image_size=224, mean=None, std=None):
|
||||
super().__init__(mean=mean, std=std)
|
||||
|
||||
self.transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(
|
||||
(image_size, image_size),
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
),
|
||||
transforms.ToTensor(),
|
||||
self.normalize,
|
||||
]
|
||||
)
|
||||
|
||||
def __call__(self, item):
|
||||
return self.transform(item)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg=None):
|
||||
if cfg is None:
|
||||
cfg = OmegaConf.create()
|
||||
|
||||
image_size = cfg.get("image_size", 224)
|
||||
|
||||
mean = cfg.get("mean", None)
|
||||
std = cfg.get("std", None)
|
||||
|
||||
return cls(image_size=image_size, mean=mean, std=std)
|
||||
@@ -0,0 +1,5 @@
|
||||
datasets:
|
||||
cc_sbu_align:
|
||||
data_type: images
|
||||
build_info:
|
||||
storage: /path/to/cc_sbu_align/
|
||||
@@ -0,0 +1,33 @@
|
||||
model:
|
||||
arch: mini_gpt4
|
||||
|
||||
# vit encoder
|
||||
image_size: 224
|
||||
drop_path_rate: 0
|
||||
use_grad_checkpoint: False
|
||||
vit_precision: "fp16"
|
||||
freeze_vit: True
|
||||
freeze_qformer: True
|
||||
|
||||
# Q-Former
|
||||
num_query_token: 32
|
||||
|
||||
# Vicuna
|
||||
llama_model: "lmsys/vicuna-7b-v1.3"
|
||||
|
||||
# generation configs
|
||||
prompt: ""
|
||||
|
||||
preprocess:
|
||||
vis_processor:
|
||||
train:
|
||||
name: "blip2_image_train"
|
||||
image_size: 224
|
||||
eval:
|
||||
name: "blip2_image_eval"
|
||||
image_size: 224
|
||||
text_processor:
|
||||
train:
|
||||
name: "blip_caption"
|
||||
eval:
|
||||
name: "blip_caption"
|
||||
@@ -0,0 +1,25 @@
|
||||
model:
|
||||
arch: mini_gpt4
|
||||
model_type: pretrain_vicuna
|
||||
freeze_vit: True
|
||||
freeze_qformer: True
|
||||
max_txt_len: 160
|
||||
end_sym: "###"
|
||||
low_resource: False
|
||||
prompt_path: "apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt"
|
||||
prompt_template: '###Human: {} ###Assistant: '
|
||||
ckpt: 'prerained_minigpt4_7b.pth'
|
||||
|
||||
|
||||
datasets:
|
||||
cc_sbu_align:
|
||||
vis_processor:
|
||||
train:
|
||||
name: "blip2_image_eval"
|
||||
image_size: 224
|
||||
text_processor:
|
||||
train:
|
||||
name: "blip_caption"
|
||||
|
||||
run:
|
||||
task: image_text_pretrain
|
||||
629
apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py
Normal file
629
apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py
Normal file
@@ -0,0 +1,629 @@
|
||||
# Based on EVA, BEIT, timm and DeiT code bases
|
||||
# https://github.com/baaivision/EVA
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
||||
# https://github.com/microsoft/unilm/tree/master/beit
|
||||
# https://github.com/facebookresearch/deit/
|
||||
# https://github.com/facebookresearch/dino
|
||||
# --------------------------------------------------------'
|
||||
import math
|
||||
import requests
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
||||
|
||||
|
||||
def _cfg(url="", **kwargs):
|
||||
return {
|
||||
"url": url,
|
||||
"num_classes": 1000,
|
||||
"input_size": (3, 224, 224),
|
||||
"pool_size": None,
|
||||
"crop_pct": 0.9,
|
||||
"interpolation": "bicubic",
|
||||
"mean": (0.5, 0.5, 0.5),
|
||||
"std": (0.5, 0.5, 0.5),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return "p={}".format(self.drop_prob)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
# x = self.drop(x)
|
||||
# commit this for the orignal BERT implement
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
window_size=None,
|
||||
attn_head_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
if attn_head_dim is not None:
|
||||
head_dim = attn_head_dim
|
||||
all_head_dim = head_dim * self.num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
||||
if qkv_bias:
|
||||
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
else:
|
||||
self.q_bias = None
|
||||
self.v_bias = None
|
||||
|
||||
if window_size:
|
||||
self.window_size = window_size
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
||||
2 * window_size[1] - 1
|
||||
) + 3
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, num_heads)
|
||||
) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(
|
||||
torch.meshgrid([coords_h, coords_w])
|
||||
) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = (
|
||||
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||
) # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(
|
||||
1, 2, 0
|
||||
).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += (
|
||||
window_size[0] - 1
|
||||
) # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(window_size[0] * window_size[1] + 1,) * 2,
|
||||
dtype=relative_coords.dtype,
|
||||
)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(
|
||||
-1
|
||||
) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer(
|
||||
"relative_position_index", relative_position_index
|
||||
)
|
||||
else:
|
||||
self.window_size = None
|
||||
self.relative_position_bias_table = None
|
||||
self.relative_position_index = None
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(all_head_dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, rel_pos_bias=None):
|
||||
B, N, C = x.shape
|
||||
qkv_bias = None
|
||||
if self.q_bias is not None:
|
||||
qkv_bias = torch.cat(
|
||||
(
|
||||
self.q_bias,
|
||||
torch.zeros_like(self.v_bias, requires_grad=False),
|
||||
self.v_bias,
|
||||
)
|
||||
)
|
||||
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = (
|
||||
qkv[0],
|
||||
qkv[1],
|
||||
qkv[2],
|
||||
) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
|
||||
if self.relative_position_bias_table is not None:
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)
|
||||
].view(
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
-1,
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1
|
||||
).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if rel_pos_bias is not None:
|
||||
attn = attn + rel_pos_bias
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
init_values=None,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
window_size=None,
|
||||
attn_head_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
window_size=window_size,
|
||||
attn_head_dim=attn_head_dim,
|
||||
)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = (
|
||||
DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
)
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
if init_values is not None and init_values > 0:
|
||||
self.gamma_1 = nn.Parameter(
|
||||
init_values * torch.ones((dim)), requires_grad=True
|
||||
)
|
||||
self.gamma_2 = nn.Parameter(
|
||||
init_values * torch.ones((dim)), requires_grad=True
|
||||
)
|
||||
else:
|
||||
self.gamma_1, self.gamma_2 = None, None
|
||||
|
||||
def forward(self, x, rel_pos_bias=None):
|
||||
if self.gamma_1 is None:
|
||||
x = x + self.drop_path(
|
||||
self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
|
||||
)
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
else:
|
||||
x = x + self.drop_path(
|
||||
self.gamma_1
|
||||
* self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
|
||||
)
|
||||
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Image to Patch Embedding"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
num_patches = (img_size[1] // patch_size[1]) * (
|
||||
img_size[0] // patch_size[0]
|
||||
)
|
||||
self.patch_shape = (
|
||||
img_size[0] // patch_size[0],
|
||||
img_size[1] // patch_size[1],
|
||||
)
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_patches = num_patches
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
|
||||
)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
# FIXME look at relaxing size constraints
|
||||
assert (
|
||||
H == self.img_size[0] and W == self.img_size[1]
|
||||
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x).flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class RelativePositionBias(nn.Module):
|
||||
def __init__(self, window_size, num_heads):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
||||
2 * window_size[1] - 1
|
||||
) + 3
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, num_heads)
|
||||
) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = (
|
||||
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||
) # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(
|
||||
1, 2, 0
|
||||
).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(window_size[0] * window_size[1] + 1,) * 2,
|
||||
dtype=relative_coords.dtype,
|
||||
)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(
|
||||
-1
|
||||
) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer(
|
||||
"relative_position_index", relative_position_index
|
||||
)
|
||||
|
||||
# trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
|
||||
def forward(self):
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)
|
||||
].view(
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
-1,
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
return relative_position_bias.permute(
|
||||
2, 0, 1
|
||||
).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
"""Vision Transformer with support for patch or hybrid CNN input stage"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop_rate=0.0,
|
||||
attn_drop_rate=0.0,
|
||||
drop_path_rate=0.0,
|
||||
norm_layer=nn.LayerNorm,
|
||||
init_values=None,
|
||||
use_abs_pos_emb=True,
|
||||
use_rel_pos_bias=False,
|
||||
use_shared_rel_pos_bias=False,
|
||||
use_mean_pooling=True,
|
||||
init_scale=0.001,
|
||||
use_checkpoint=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_size = img_size
|
||||
self.num_classes = num_classes
|
||||
self.num_features = (
|
||||
self.embed_dim
|
||||
) = embed_dim # num_features for consistency with other models
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
if use_abs_pos_emb:
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches + 1, embed_dim)
|
||||
)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
if use_shared_rel_pos_bias:
|
||||
self.rel_pos_bias = RelativePositionBias(
|
||||
window_size=self.patch_embed.patch_shape, num_heads=num_heads
|
||||
)
|
||||
else:
|
||||
self.rel_pos_bias = None
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
||||
] # stochastic depth decay rule
|
||||
self.use_rel_pos_bias = use_rel_pos_bias
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
init_values=init_values,
|
||||
window_size=self.patch_embed.patch_shape
|
||||
if use_rel_pos_bias
|
||||
else None,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
||||
# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
||||
# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
if self.pos_embed is not None:
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
trunc_normal_(self.cls_token, std=0.02)
|
||||
# trunc_normal_(self.mask_token, std=.02)
|
||||
# if isinstance(self.head, nn.Linear):
|
||||
# trunc_normal_(self.head.weight, std=.02)
|
||||
self.apply(self._init_weights)
|
||||
self.fix_init_weight()
|
||||
|
||||
# if isinstance(self.head, nn.Linear):
|
||||
# self.head.weight.data.mul_(init_scale)
|
||||
# self.head.bias.data.mul_(init_scale)
|
||||
|
||||
def fix_init_weight(self):
|
||||
def rescale(param, layer_id):
|
||||
param.div_(math.sqrt(2.0 * layer_id))
|
||||
|
||||
for layer_id, layer in enumerate(self.blocks):
|
||||
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
||||
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=""):
|
||||
self.num_classes = num_classes
|
||||
self.head = (
|
||||
nn.Linear(self.embed_dim, num_classes)
|
||||
if num_classes > 0
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
batch_size, seq_len, _ = x.size()
|
||||
|
||||
cls_tokens = self.cls_token.expand(
|
||||
batch_size, -1, -1
|
||||
) # stole cls_tokens impl from Phil Wang, thanks
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
rel_pos_bias = (
|
||||
self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
||||
)
|
||||
for blk in self.blocks:
|
||||
if self.use_checkpoint:
|
||||
x = checkpoint.checkpoint(blk, x, rel_pos_bias)
|
||||
else:
|
||||
x = blk(x, rel_pos_bias)
|
||||
return x
|
||||
|
||||
# x = self.norm(x)
|
||||
|
||||
# if self.fc_norm is not None:
|
||||
# t = x[:, 1:, :]
|
||||
# return self.fc_norm(t.mean(1))
|
||||
# else:
|
||||
# return x[:, 0]
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
# x = self.head(x)
|
||||
return x
|
||||
|
||||
def get_intermediate_layers(self, x):
|
||||
x = self.patch_embed(x)
|
||||
batch_size, seq_len, _ = x.size()
|
||||
|
||||
cls_tokens = self.cls_token.expand(
|
||||
batch_size, -1, -1
|
||||
) # stole cls_tokens impl from Phil Wang, thanks
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
features = []
|
||||
rel_pos_bias = (
|
||||
self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
||||
)
|
||||
for blk in self.blocks:
|
||||
x = blk(x, rel_pos_bias)
|
||||
features.append(x)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def interpolate_pos_embed(model, checkpoint_model):
|
||||
if "pos_embed" in checkpoint_model:
|
||||
pos_embed_checkpoint = checkpoint_model["pos_embed"].float()
|
||||
embedding_size = pos_embed_checkpoint.shape[-1]
|
||||
num_patches = model.patch_embed.num_patches
|
||||
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
||||
# height (== width) for the checkpoint position embedding
|
||||
orig_size = int(
|
||||
(pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5
|
||||
)
|
||||
# height (== width) for the new position embedding
|
||||
new_size = int(num_patches**0.5)
|
||||
# class_token and dist_token are kept unchanged
|
||||
if orig_size != new_size:
|
||||
print(
|
||||
"Position interpolate from %dx%d to %dx%d"
|
||||
% (orig_size, orig_size, new_size, new_size)
|
||||
)
|
||||
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
||||
# only the position tokens are interpolated
|
||||
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
||||
pos_tokens = pos_tokens.reshape(
|
||||
-1, orig_size, orig_size, embedding_size
|
||||
).permute(0, 3, 1, 2)
|
||||
pos_tokens = torch.nn.functional.interpolate(
|
||||
pos_tokens,
|
||||
size=(new_size, new_size),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
||||
checkpoint_model["pos_embed"] = new_pos_embed
|
||||
|
||||
|
||||
def convert_weights_to_fp16(model: nn.Module):
|
||||
"""Convert applicable model parameters to fp16"""
|
||||
|
||||
def _convert_weights_to_fp16(l):
|
||||
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
||||
# l.weight.data = l.weight.data.half()
|
||||
l.weight.data = l.weight.data
|
||||
if l.bias is not None:
|
||||
# l.bias.data = l.bias.data.half()
|
||||
l.bias.data = l.bias.data
|
||||
|
||||
# if isinstance(l, (nn.MultiheadAttention, Attention)):
|
||||
# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
||||
# tensor = getattr(l, attr)
|
||||
# if tensor is not None:
|
||||
# tensor.data = tensor.data.half()
|
||||
|
||||
model.apply(_convert_weights_to_fp16)
|
||||
|
||||
|
||||
def create_eva_vit_g(
|
||||
img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"
|
||||
):
|
||||
model = VisionTransformer(
|
||||
img_size=img_size,
|
||||
patch_size=14,
|
||||
use_mean_pooling=False,
|
||||
embed_dim=1408,
|
||||
depth=39,
|
||||
num_heads=1408 // 88,
|
||||
mlp_ratio=4.3637,
|
||||
qkv_bias=True,
|
||||
drop_path_rate=drop_path_rate,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
use_checkpoint=use_checkpoint,
|
||||
)
|
||||
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
|
||||
|
||||
local_filename = "eva_vit_g.pth"
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
with open(local_filename, "wb") as f:
|
||||
f.write(response.content)
|
||||
print("File downloaded successfully.")
|
||||
state_dict = torch.load(local_filename, map_location="cpu")
|
||||
interpolate_pos_embed(model, state_dict)
|
||||
|
||||
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if precision == "fp16":
|
||||
# model.to("cuda")
|
||||
convert_weights_to_fp16(model)
|
||||
return model
|
||||
@@ -0,0 +1,4 @@
|
||||
<Img><ImageHere></Img> Describe this image in detail.
|
||||
<Img><ImageHere></Img> Take a look at this image and describe what you notice.
|
||||
<Img><ImageHere></Img> Please provide a detailed description of the picture.
|
||||
<Img><ImageHere></Img> Could you describe the contents of this image for me?
|
||||
187
apps/language_models/src/pipelines/stablelm_pipeline.py
Normal file
187
apps/language_models/src/pipelines/stablelm_pipeline.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import AutoTokenizer, StoppingCriteria, AutoModelForCausalLM
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import (
|
||||
get_torch_mlir_module_bytecode,
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.src.model_wrappers.stablelm_model import (
|
||||
StableLMModel,
|
||||
)
|
||||
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
stop_ids = [50278, 50279, 50277, 1, 0]
|
||||
for stop_id in stop_ids:
|
||||
if input_ids[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class SharkStableLM(SharkLLMBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="stabilityai/stablelm-tuned-alpha-3b",
|
||||
max_num_tokens=512,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
debug="False",
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_len = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
def shouldStop(self, tokens):
|
||||
stop_ids = [50278, 50279, 50277, 1, 0]
|
||||
for stop_id in stop_ids:
|
||||
if tokens[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_src_model(self):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, torch_dtype=torch.float32
|
||||
)
|
||||
return model
|
||||
|
||||
def get_model_inputs(self):
|
||||
input_ids = torch.randint(3, (1, self.max_sequence_len))
|
||||
attention_mask = torch.randint(3, (1, self.max_sequence_len))
|
||||
return input_ids, attention_mask
|
||||
|
||||
def compile(self):
|
||||
tmp_model_name = (
|
||||
f"stableLM_linalg_{self.precision}_seqLen{self.max_sequence_len}"
|
||||
)
|
||||
|
||||
# device = "cuda" # "cpu"
|
||||
# TODO: vmfb and mlir name should include precision and device
|
||||
model_vmfb_name = None
|
||||
vmfb_path = (
|
||||
Path(tmp_model_name + f"_{self.device}.vmfb")
|
||||
if model_vmfb_name is None
|
||||
else Path(model_vmfb_name)
|
||||
)
|
||||
shark_module = get_vmfb_from_path(
|
||||
vmfb_path, self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
if shark_module is not None:
|
||||
return shark_module
|
||||
|
||||
mlir_path = Path(tmp_model_name + ".mlir")
|
||||
print(
|
||||
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
model = StableLMModel(self.get_src_model())
|
||||
model_inputs = self.get_model_inputs()
|
||||
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*model_inputs],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(tmp_model_name + ".mlir", "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved mlir")
|
||||
f_.close()
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
return shark_module
|
||||
|
||||
def get_tokenizer(self):
|
||||
tok = AutoTokenizer.from_pretrained(self.hf_model_path)
|
||||
tok.add_special_tokens({"pad_token": "<PAD>"})
|
||||
# print("[DEBUG] Sucessfully loaded the tokenizer to the memory")
|
||||
return tok
|
||||
|
||||
def generate(self, prompt):
|
||||
words_list = []
|
||||
for i in range(self.max_num_tokens):
|
||||
params = {
|
||||
"new_text": prompt,
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params)
|
||||
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
|
||||
if stop_generation:
|
||||
break
|
||||
|
||||
print(detok, end="", flush=True) # this is for CLI and DEBUG
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
prompt = prompt + detok
|
||||
return words_list
|
||||
|
||||
def generate_new_token(self, params):
|
||||
new_text = params["new_text"]
|
||||
model_inputs = self.tokenizer(
|
||||
[new_text],
|
||||
padding="max_length",
|
||||
max_length=self.max_sequence_len,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
output = self.shark_model(
|
||||
"forward", [model_inputs.input_ids, model_inputs.attention_mask]
|
||||
)
|
||||
output = torch.from_numpy(output)
|
||||
next_toks = torch.topk(output, 1)
|
||||
stop_generation = False
|
||||
if self.shouldStop(next_toks.indices):
|
||||
stop_generation = True
|
||||
new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
|
||||
detok = self.tokenizer.decode(
|
||||
new_token,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
ret_dict = {
|
||||
"new_token": new_token,
|
||||
"detok": detok,
|
||||
"stop_generation": stop_generation,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
|
||||
# Initialize a StopOnTokens object
|
||||
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
48
apps/language_models/utils.py
Normal file
48
apps/language_models/utils.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from shark.shark_downloader import download_public_file
|
||||
|
||||
|
||||
# expects a Path / str as arg
|
||||
# returns None if path not found or SharkInference module
|
||||
def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
|
||||
if not isinstance(vmfb_path, Path):
|
||||
vmfb_path = Path(vmfb_path)
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
if not vmfb_path.exists():
|
||||
return None
|
||||
|
||||
print("Loading vmfb from: ", vmfb_path)
|
||||
print("Device from get_vmfb_from_path - ", device)
|
||||
shark_module = SharkInference(
|
||||
None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id
|
||||
)
|
||||
shark_module.load_module(vmfb_path)
|
||||
print("Successfully loaded vmfb")
|
||||
return shark_module
|
||||
|
||||
|
||||
def get_vmfb_from_config(
|
||||
shark_container,
|
||||
model,
|
||||
precision,
|
||||
device,
|
||||
vmfb_path,
|
||||
padding=None,
|
||||
device_id=None,
|
||||
):
|
||||
vmfb_url = (
|
||||
f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
|
||||
)
|
||||
if padding:
|
||||
vmfb_url = vmfb_url + f"_{padding}"
|
||||
vmfb_url = vmfb_url + ".vmfb"
|
||||
download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
|
||||
return get_vmfb_from_path(
|
||||
vmfb_path, device, "tm_tensor", device_id=device_id
|
||||
)
|
||||
@@ -7,16 +7,16 @@ Compile Commands FP32/FP16:
|
||||
|
||||
```shell
|
||||
Vulkan AMD:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
|
||||
# use –iree-input-type=mhlo for tf models
|
||||
# use –iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models
|
||||
|
||||
CUDA NVIDIA:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
CPU:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu /path/to/input/mlir -o /path/to/output/vmfb
|
||||
```
|
||||
|
||||
|
||||
|
||||
@@ -103,6 +103,7 @@ def main():
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil=use_stencil,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
|
||||
@@ -58,11 +58,8 @@ def main():
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
|
||||
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
|
||||
for current_batch in range(args.batch_count):
|
||||
if current_batch > 0:
|
||||
seed = -1
|
||||
seed = utils.sanitize_seed(seed)
|
||||
|
||||
start_time = time.time()
|
||||
generated_imgs = inpaint_obj.generate_images(
|
||||
args.prompts,
|
||||
@@ -76,11 +73,12 @@ def main():
|
||||
args.inpaint_full_res_padding,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seed,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
@@ -89,7 +87,10 @@ def main():
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
|
||||
)
|
||||
text_output += f"seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
|
||||
@@ -51,11 +51,8 @@ def main():
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
|
||||
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
|
||||
for current_batch in range(args.batch_count):
|
||||
if current_batch > 0:
|
||||
seed = -1
|
||||
seed = utils.sanitize_seed(seed)
|
||||
|
||||
start_time = time.time()
|
||||
generated_imgs = outpaint_obj.generate_images(
|
||||
args.prompts,
|
||||
@@ -74,11 +71,12 @@ def main():
|
||||
args.width,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seed,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
@@ -87,7 +85,10 @@ def main():
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
|
||||
)
|
||||
text_output += f"seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
|
||||
@@ -34,7 +34,7 @@ from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.cross_attention import LoRACrossAttnProcessor
|
||||
from diffusers.models.attention_processor import LoRAXFormersAttnProcessor
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir.dynamo import make_simple_dynamo_backend
|
||||
@@ -223,7 +223,8 @@ def lora_train(
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
"Please provide either custom model or huggingface model ID, both must not be "
|
||||
"empty.",
|
||||
)
|
||||
args.hf_model_id = hf_model_id
|
||||
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
|
||||
@@ -286,7 +287,7 @@ def lora_train(
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
|
||||
lora_attn_procs[name] = LoRACrossAttnProcessor(
|
||||
lora_attn_procs[name] = LoRAXFormersAttnProcessor(
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
@@ -17,6 +17,10 @@ from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
||||
|
||||
|
||||
def load_mlir_module():
|
||||
if "upscaler" in args.hf_model_id:
|
||||
is_upscaler = True
|
||||
else:
|
||||
is_upscaler = False
|
||||
sd_model = SharkifyStableDiffusionModel(
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
@@ -27,6 +31,7 @@ def load_mlir_module():
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
is_upscaler=is_upscaler,
|
||||
use_tuned=False,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
return_mlir=True,
|
||||
|
||||
@@ -42,11 +42,8 @@ def main():
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
|
||||
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
|
||||
for current_batch in range(args.batch_count):
|
||||
if current_batch > 0:
|
||||
seed = -1
|
||||
seed = utils.sanitize_seed(seed)
|
||||
|
||||
start_time = time.time()
|
||||
generated_imgs = txt2img_obj.generate_images(
|
||||
args.prompts,
|
||||
@@ -56,11 +53,12 @@ def main():
|
||||
args.width,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seed,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
@@ -69,7 +67,12 @@ def main():
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
|
||||
text_output += (
|
||||
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
|
||||
)
|
||||
text_output += (
|
||||
f"seed={seeds[current_batch]}, size={args.height}x{args.width}"
|
||||
)
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
|
||||
@@ -21,7 +21,7 @@ if __name__ == "__main__":
|
||||
print("Flag --img_path is required.")
|
||||
exit()
|
||||
|
||||
# When the models get uploaded, it should be default to False.
|
||||
# When the models get uploaded, it should be defaulted to False.
|
||||
args.import_mlir = True
|
||||
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
@@ -73,6 +73,7 @@ if __name__ == "__main__":
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
|
||||
@@ -1,56 +1,13 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
from PyInstaller.utils.hooks import collect_submodules
|
||||
|
||||
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
|
||||
datas = []
|
||||
datas += collect_data_files('torch')
|
||||
datas += copy_metadata('torch')
|
||||
datas += copy_metadata('tqdm')
|
||||
datas += copy_metadata('regex')
|
||||
datas += copy_metadata('requests')
|
||||
datas += copy_metadata('packaging')
|
||||
datas += copy_metadata('filelock')
|
||||
datas += copy_metadata('numpy')
|
||||
datas += copy_metadata('tokenizers')
|
||||
datas += copy_metadata('importlib_metadata')
|
||||
datas += copy_metadata('torch-mlir')
|
||||
datas += copy_metadata('omegaconf')
|
||||
datas += copy_metadata('safetensors')
|
||||
datas += collect_data_files('diffusers')
|
||||
datas += collect_data_files('transformers')
|
||||
datas += collect_data_files('pytorch_lightning')
|
||||
datas += collect_data_files('opencv-python')
|
||||
datas += collect_data_files('skimage')
|
||||
datas += collect_data_files('gradio')
|
||||
datas += collect_data_files('gradio_client')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
datas += collect_data_files('shark')
|
||||
datas += collect_data_files('tkinter')
|
||||
datas += collect_data_files('webview')
|
||||
datas += collect_data_files('sentencepiece')
|
||||
datas += [
|
||||
( 'src/utils/resources/prompts.json', 'resources' ),
|
||||
( 'src/utils/resources/model_db.json', 'resources' ),
|
||||
( 'src/utils/resources/opt_flags.json', 'resources' ),
|
||||
( 'src/utils/resources/base_model.json', 'resources' ),
|
||||
( 'web/ui/css/*', 'ui/css' ),
|
||||
( 'web/ui/logos/*', 'logos' )
|
||||
]
|
||||
from apps.stable_diffusion.shark_studio_imports import pathex, datas, hiddenimports
|
||||
|
||||
binaries = []
|
||||
|
||||
block_cipher = None
|
||||
|
||||
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['web/index.py'],
|
||||
pathex=['.'],
|
||||
pathex=pathex,
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=hiddenimports,
|
||||
@@ -72,11 +29,11 @@ exe = EXE(
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
[],
|
||||
name='shark_sd',
|
||||
name='nodai_shark_studio',
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx=False,
|
||||
upx_exclude=[],
|
||||
runtime_tmpdir=None,
|
||||
console=True,
|
||||
|
||||
@@ -29,6 +29,7 @@ datas += collect_data_files('gradio_client')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
datas += collect_data_files('shark')
|
||||
datas += collect_data_files('py-cpuinfo')
|
||||
datas += [
|
||||
( 'src/utils/resources/prompts.json', 'resources' ),
|
||||
( 'src/utils/resources/model_db.json', 'resources' ),
|
||||
@@ -42,6 +43,7 @@ block_cipher = None
|
||||
|
||||
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['scripts/main.py'],
|
||||
|
||||
87
apps/stable_diffusion/shark_studio_imports.py
Normal file
87
apps/stable_diffusion/shark_studio_imports.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
from PyInstaller.utils.hooks import collect_submodules
|
||||
|
||||
import sys
|
||||
|
||||
sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
|
||||
# python path for pyinstaller
|
||||
pathex = [
|
||||
".",
|
||||
"./apps/language_models/langchain",
|
||||
"./apps/language_models/src/pipelines/minigpt4_utils",
|
||||
]
|
||||
|
||||
# datafiles for pyinstaller
|
||||
datas = []
|
||||
datas += copy_metadata("torch")
|
||||
datas += copy_metadata("tokenizers")
|
||||
datas += copy_metadata("tqdm")
|
||||
datas += copy_metadata("regex")
|
||||
datas += copy_metadata("requests")
|
||||
datas += copy_metadata("packaging")
|
||||
datas += copy_metadata("filelock")
|
||||
datas += copy_metadata("numpy")
|
||||
datas += copy_metadata("importlib_metadata")
|
||||
datas += copy_metadata("torch-mlir")
|
||||
datas += copy_metadata("omegaconf")
|
||||
datas += copy_metadata("safetensors")
|
||||
datas += copy_metadata("Pillow")
|
||||
datas += copy_metadata("sentencepiece")
|
||||
datas += copy_metadata("pyyaml")
|
||||
datas += copy_metadata("huggingface-hub")
|
||||
datas += collect_data_files("torch")
|
||||
datas += collect_data_files("tokenizers")
|
||||
datas += collect_data_files("tiktoken")
|
||||
datas += collect_data_files("accelerate")
|
||||
datas += collect_data_files("diffusers")
|
||||
datas += collect_data_files("transformers")
|
||||
datas += collect_data_files("pytorch_lightning")
|
||||
datas += collect_data_files("skimage")
|
||||
datas += collect_data_files("gradio")
|
||||
datas += collect_data_files("gradio_client")
|
||||
datas += collect_data_files("iree")
|
||||
datas += collect_data_files("shark", include_py_files=True)
|
||||
datas += collect_data_files("timm", include_py_files=True)
|
||||
datas += collect_data_files("tqdm")
|
||||
datas += collect_data_files("tkinter")
|
||||
datas += collect_data_files("webview")
|
||||
datas += collect_data_files("sentencepiece")
|
||||
datas += collect_data_files("jsonschema")
|
||||
datas += collect_data_files("jsonschema_specifications")
|
||||
datas += collect_data_files("cpuinfo")
|
||||
datas += collect_data_files("langchain")
|
||||
datas += collect_data_files("cv2")
|
||||
datas += [
|
||||
("src/utils/resources/prompts.json", "resources"),
|
||||
("src/utils/resources/model_db.json", "resources"),
|
||||
("src/utils/resources/opt_flags.json", "resources"),
|
||||
("src/utils/resources/base_model.json", "resources"),
|
||||
("web/ui/css/*", "ui/css"),
|
||||
("web/ui/logos/*", "logos"),
|
||||
(
|
||||
"../language_models/src/pipelines/minigpt4_utils/configs/*",
|
||||
"minigpt4_utils/configs",
|
||||
),
|
||||
(
|
||||
"../language_models/src/pipelines/minigpt4_utils/prompts/*",
|
||||
"minigpt4_utils/prompts",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# hidden imports for pyinstaller
|
||||
hiddenimports = ["shark", "shark.shark_inference", "apps"]
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [
|
||||
x for x in collect_submodules("diffusers") if "tests" not in x
|
||||
]
|
||||
blacklist = ["tests", "convert"]
|
||||
hiddenimports += [
|
||||
x
|
||||
for x in collect_submodules("transformers")
|
||||
if not any(kw in x for kw in blacklist)
|
||||
]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]
|
||||
@@ -45,6 +45,7 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
|
||||
new_shape.append(width * mul_val)
|
||||
elif "/" in shape[i]:
|
||||
import math
|
||||
|
||||
div_val = int(shape[i].split("/")[1])
|
||||
if "batch_size" in shape[i]:
|
||||
new_shape.append(math.ceil(batch_size / div_val))
|
||||
@@ -59,7 +60,9 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
|
||||
|
||||
def check_compilation(model, model_name):
|
||||
if not model:
|
||||
raise Exception(f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues")
|
||||
raise Exception(
|
||||
f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues"
|
||||
)
|
||||
|
||||
|
||||
class SharkifyStableDiffusionModel:
|
||||
@@ -97,16 +100,22 @@ class SharkifyStableDiffusionModel:
|
||||
if "civitai" in custom_weights:
|
||||
weights_id = custom_weights.split("/")[-1]
|
||||
# TODO: use model name and identify file type by civitai rest api
|
||||
weights_path = str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
|
||||
weights_path = (
|
||||
str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
|
||||
)
|
||||
if not os.path.isfile(weights_path):
|
||||
subprocess.run(["wget", custom_weights, "-O", weights_path])
|
||||
subprocess.run(
|
||||
["wget", custom_weights, "-O", weights_path]
|
||||
)
|
||||
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
|
||||
self.custom_weights = weights_path
|
||||
else:
|
||||
assert custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
|
||||
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
|
||||
custom_weights = get_path_to_diffusers_checkpoint(
|
||||
custom_weights
|
||||
)
|
||||
self.model_id = model_id if custom_weights == "" else custom_weights
|
||||
# TODO: remove the following line when stable-diffusion-2-1 works
|
||||
if self.model_id == "stabilityai/stable-diffusion-2-1":
|
||||
@@ -126,7 +135,7 @@ class SharkifyStableDiffusionModel:
|
||||
+ "_"
|
||||
+ precision
|
||||
)
|
||||
print(f'use_tuned? sharkify: {use_tuned}')
|
||||
print(f"use_tuned? sharkify: {use_tuned}")
|
||||
self.use_tuned = use_tuned
|
||||
if use_tuned:
|
||||
self.model_name = self.model_name + "_tuned"
|
||||
@@ -163,14 +172,26 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
def get_extended_name_for_all_model(self):
|
||||
model_name = {}
|
||||
sub_model_list = ["clip", "unet", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
|
||||
sub_model_list = [
|
||||
"clip",
|
||||
"unet",
|
||||
"unet512",
|
||||
"stencil_unet",
|
||||
"stencil_unet_512",
|
||||
"vae",
|
||||
"vae_encode",
|
||||
"stencil_adaptor",
|
||||
"stencil_adaptor_512",
|
||||
]
|
||||
index = 0
|
||||
for model in sub_model_list:
|
||||
sub_model = model
|
||||
model_config = self.model_name
|
||||
if "vae" == model:
|
||||
if self.custom_vae != "":
|
||||
model_config = model_config + get_path_stem(self.custom_vae)
|
||||
model_config = model_config + get_path_stem(
|
||||
self.custom_vae
|
||||
)
|
||||
if self.base_vae:
|
||||
sub_model = "base_vae"
|
||||
if "stencil_adaptor" == model and self.use_stencil is not None:
|
||||
@@ -197,7 +218,11 @@ class SharkifyStableDiffusionModel:
|
||||
tensor = None
|
||||
if isinstance(shape, list):
|
||||
clean_shape = replace_shape_str(
|
||||
shape, self.max_len, self.width, self.height, self.batch_size
|
||||
shape,
|
||||
self.max_len,
|
||||
self.width,
|
||||
self.height,
|
||||
self.batch_size,
|
||||
)
|
||||
if dtype == torch.int64:
|
||||
tensor = torch.randint(1, 3, tuple(clean_shape))
|
||||
@@ -209,10 +234,12 @@ class SharkifyStableDiffusionModel:
|
||||
sys.exit("shape isn't specified correctly.")
|
||||
input_map.append(tensor)
|
||||
return input_map
|
||||
|
||||
|
||||
def get_vae_encode(self):
|
||||
class VaeEncodeModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
def __init__(
|
||||
self, model_id=self.model_id, low_cpu_mem_usage=False
|
||||
):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
@@ -226,7 +253,11 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
vae_encode = VaeEncodeModel()
|
||||
inputs = tuple(self.inputs["vae_encode"])
|
||||
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
|
||||
is_f16 = (
|
||||
True
|
||||
if not self.is_upscaler and self.precision == "fp16"
|
||||
else False
|
||||
)
|
||||
shark_vae_encode, vae_encode_mlir = compile_through_fx(
|
||||
vae_encode,
|
||||
inputs,
|
||||
@@ -243,7 +274,13 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
def get_vae(self):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae, low_cpu_mem_usage=False):
|
||||
def __init__(
|
||||
self,
|
||||
model_id=self.model_id,
|
||||
base_vae=self.base_vae,
|
||||
custom_vae=self.custom_vae,
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.vae = None
|
||||
if custom_vae == "":
|
||||
@@ -279,7 +316,11 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
inputs = tuple(self.inputs["vae"])
|
||||
is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
|
||||
is_f16 = (
|
||||
True
|
||||
if not self.is_upscaler and self.precision == "fp16"
|
||||
else False
|
||||
)
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
|
||||
if self.debug:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
@@ -300,10 +341,13 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return shark_vae, vae_mlir
|
||||
|
||||
def get_controlled_unet(self):
|
||||
def get_controlled_unet(self, use_large=False):
|
||||
class ControlledUnetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
|
||||
self,
|
||||
model_id=self.model_id,
|
||||
low_cpu_mem_usage=False,
|
||||
use_lora=self.use_lora,
|
||||
):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
@@ -316,12 +360,43 @@ class SharkifyStableDiffusionModel:
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward( self, latent, timestep, text_embedding, guidance_scale, control1,
|
||||
control2, control3, control4, control5, control6, control7,
|
||||
control8, control9, control10, control11, control12, control13,
|
||||
def forward(
|
||||
self,
|
||||
latent,
|
||||
timestep,
|
||||
text_embedding,
|
||||
guidance_scale,
|
||||
control1,
|
||||
control2,
|
||||
control3,
|
||||
control4,
|
||||
control5,
|
||||
control6,
|
||||
control7,
|
||||
control8,
|
||||
control9,
|
||||
control10,
|
||||
control11,
|
||||
control12,
|
||||
control13,
|
||||
):
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
db_res_samples = tuple([ control1, control2, control3, control4, control5, control6, control7, control8, control9, control10, control11, control12,])
|
||||
db_res_samples = tuple(
|
||||
[
|
||||
control1,
|
||||
control2,
|
||||
control3,
|
||||
control4,
|
||||
control5,
|
||||
control6,
|
||||
control7,
|
||||
control8,
|
||||
control9,
|
||||
control10,
|
||||
control11,
|
||||
control12,
|
||||
]
|
||||
)
|
||||
mb_res_samples = control13
|
||||
latents = torch.cat([latent] * 2)
|
||||
unet_out = self.unet.forward(
|
||||
@@ -342,23 +417,51 @@ class SharkifyStableDiffusionModel:
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
input_mask = [True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True,]
|
||||
model_name = "stencil_unet"
|
||||
if use_large:
|
||||
pad = (0, 0) * (len(inputs[2].shape) - 2)
|
||||
pad = pad + (0, 512 - inputs[2].shape[1])
|
||||
inputs = (
|
||||
inputs[:2]
|
||||
+ (torch.nn.functional.pad(inputs[2], pad),)
|
||||
+ inputs[3:]
|
||||
)
|
||||
model_name = "stencil_unet_512"
|
||||
input_mask = [
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
]
|
||||
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name["stencil_unet"],
|
||||
extended_model_name=self.model_name[model_name],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="stencil_unet",
|
||||
model_name=model_name,
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
return shark_controlled_unet, controlled_unet_mlir
|
||||
|
||||
def get_control_net(self):
|
||||
def get_control_net(self, use_large=False):
|
||||
class StencilControlNetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, model_id=self.use_stencil, low_cpu_mem_usage=False
|
||||
@@ -386,38 +489,67 @@ class SharkifyStableDiffusionModel:
|
||||
stencil_image = torch.cat(
|
||||
[stencil_image_input] * 2
|
||||
) # needs to be same as controlledUNET latents
|
||||
down_block_res_samples, mid_block_res_sample = self.cnet.forward(
|
||||
(
|
||||
down_block_res_samples,
|
||||
mid_block_res_sample,
|
||||
) = self.cnet.forward(
|
||||
latents,
|
||||
timestep,
|
||||
encoder_hidden_states=text_embedding,
|
||||
controlnet_cond=stencil_image,
|
||||
return_dict=False,
|
||||
)
|
||||
return tuple(list(down_block_res_samples) + [mid_block_res_sample])
|
||||
return tuple(
|
||||
list(down_block_res_samples) + [mid_block_res_sample]
|
||||
)
|
||||
|
||||
scnet = StencilControlNetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
scnet = StencilControlNetModel(
|
||||
low_cpu_mem_usage=self.low_cpu_mem_usage
|
||||
)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
|
||||
inputs = tuple(self.inputs["stencil_adaptor"])
|
||||
if use_large:
|
||||
pad = (0, 0) * (len(inputs[2].shape) - 2)
|
||||
pad = pad + (0, 512 - inputs[2].shape[1])
|
||||
inputs = (
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
torch.nn.functional.pad(inputs[2], pad),
|
||||
inputs[3],
|
||||
)
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
|
||||
)
|
||||
else:
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["stencil_adaptor"]
|
||||
)
|
||||
input_mask = [True, True, True, True]
|
||||
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
|
||||
shark_cnet, cnet_mlir = compile_through_fx(
|
||||
scnet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name["stencil_adaptor"],
|
||||
extended_model_name=self.model_name[model_name],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="stencil_adaptor",
|
||||
model_name=model_name,
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
return shark_cnet, cnet_mlir
|
||||
|
||||
def get_unet(self):
|
||||
def get_unet(self, use_large=False):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
|
||||
def __init__(
|
||||
self,
|
||||
model_id=self.model_id,
|
||||
low_cpu_mem_usage=False,
|
||||
use_lora=self.use_lora,
|
||||
):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
@@ -426,17 +558,26 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(self.unet, use_lora, "unet")
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.in_channels = self.unet.config.in_channels
|
||||
self.train(False)
|
||||
if(args.attention_slicing is not None and args.attention_slicing != "none"):
|
||||
if(args.attention_slicing.isdigit()):
|
||||
self.unet.set_attention_slice(int(args.attention_slicing))
|
||||
if (
|
||||
args.attention_slicing is not None
|
||||
and args.attention_slicing != "none"
|
||||
):
|
||||
if args.attention_slicing.isdigit():
|
||||
self.unet.set_attention_slice(
|
||||
int(args.attention_slicing)
|
||||
)
|
||||
else:
|
||||
self.unet.set_attention_slice(args.attention_slicing)
|
||||
|
||||
# TODO: Instead of flattening the `control` try to use the list.
|
||||
def forward(
|
||||
self, latent, timestep, text_embedding, guidance_scale,
|
||||
self,
|
||||
latent,
|
||||
timestep,
|
||||
text_embedding,
|
||||
guidance_scale,
|
||||
):
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
latents = torch.cat([latent] * 2)
|
||||
@@ -452,17 +593,33 @@ class SharkifyStableDiffusionModel:
|
||||
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
if use_large:
|
||||
pad = (0, 0) * (len(inputs[2].shape) - 2)
|
||||
pad = pad + (0, 512 - inputs[2].shape[1])
|
||||
inputs = (
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
torch.nn.functional.pad(inputs[2], pad),
|
||||
inputs[3],
|
||||
)
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["unet512"]
|
||||
)
|
||||
else:
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["unet"]
|
||||
)
|
||||
input_mask = [True, True, True, False]
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
|
||||
if self.debug:
|
||||
os.makedirs(
|
||||
save_dir,
|
||||
exist_ok=True,
|
||||
)
|
||||
model_name = "unet512" if use_large else "unet"
|
||||
shark_unet, unet_mlir = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name["unet"],
|
||||
extended_model_name=self.model_name[model_name],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
@@ -471,15 +628,17 @@ class SharkifyStableDiffusionModel:
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="unet",
|
||||
model_name=model_name,
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
return shark_unet, unet_mlir
|
||||
|
||||
def get_unet_upscaler(self):
|
||||
def get_unet_upscaler(self, use_large=False):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
|
||||
def __init__(
|
||||
self, model_id=self.model_id, low_cpu_mem_usage=False
|
||||
):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
@@ -502,17 +661,27 @@ class SharkifyStableDiffusionModel:
|
||||
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
if use_large:
|
||||
pad = (0, 0) * (len(inputs[2].shape) - 2)
|
||||
pad = pad + (0, 512 - inputs[2].shape[1])
|
||||
inputs = (
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
torch.nn.functional.pad(inputs[2], pad),
|
||||
inputs[3],
|
||||
)
|
||||
input_mask = [True, True, True, False]
|
||||
model_name = "unet512" if use_large else "unet"
|
||||
shark_unet, unet_mlir = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name["unet"],
|
||||
extended_model_name=self.model_name[model_name],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="unet",
|
||||
model_name=model_name,
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
@@ -520,7 +689,12 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
def get_clip(self):
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
|
||||
def __init__(
|
||||
self,
|
||||
model_id=self.model_id,
|
||||
low_cpu_mem_usage=False,
|
||||
use_lora=self.use_lora,
|
||||
):
|
||||
super().__init__()
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id,
|
||||
@@ -528,14 +702,19 @@ class SharkifyStableDiffusionModel:
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(self.text_encoder, use_lora, "text_encoder")
|
||||
update_lora_weight(
|
||||
self.text_encoder, use_lora, "text_encoder"
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
|
||||
save_dir = ""
|
||||
if self.debug:
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["clip"]
|
||||
)
|
||||
os.makedirs(
|
||||
save_dir,
|
||||
exist_ok=True,
|
||||
@@ -567,34 +746,47 @@ class SharkifyStableDiffusionModel:
|
||||
vae_checkpoint = None
|
||||
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
||||
if custom_vae.endswith(".ckpt"):
|
||||
vae_checkpoint = torch.load(self.custom_vae, map_location="cpu")
|
||||
vae_checkpoint = torch.load(
|
||||
self.custom_vae, map_location="cpu"
|
||||
)
|
||||
else:
|
||||
vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
|
||||
vae_checkpoint = safetensors.torch.load_file(
|
||||
self.custom_vae, device="cpu"
|
||||
)
|
||||
if "state_dict" in vae_checkpoint:
|
||||
vae_checkpoint = vae_checkpoint["state_dict"]
|
||||
|
||||
try:
|
||||
vae_checkpoint = convert_original_vae(vae_checkpoint)
|
||||
finally:
|
||||
vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
vae_dict = {
|
||||
k: v
|
||||
for k, v in vae_checkpoint.items()
|
||||
if k[0:4] != "loss" and k not in vae_ignore_keys
|
||||
}
|
||||
return vae_dict
|
||||
|
||||
def compile_unet_variants(self, model):
|
||||
def compile_unet_variants(self, model, use_large=False):
|
||||
if model == "unet":
|
||||
if self.is_upscaler:
|
||||
return self.get_unet_upscaler()
|
||||
return self.get_unet_upscaler(use_large=use_large)
|
||||
# TODO: Plug the experimental "int8" support at right place.
|
||||
elif self.use_quantize == "int8":
|
||||
from apps.stable_diffusion.src.models.opt_params import get_unet
|
||||
from apps.stable_diffusion.src.models.opt_params import (
|
||||
get_unet,
|
||||
)
|
||||
|
||||
return get_unet()
|
||||
else:
|
||||
return self.get_unet()
|
||||
return self.get_unet(use_large=use_large)
|
||||
else:
|
||||
return self.get_controlled_unet()
|
||||
return self.get_controlled_unet(use_large=use_large)
|
||||
|
||||
def vae_encode(self):
|
||||
try:
|
||||
self.inputs["vae_encode"] = self.get_input_info_for(base_models["vae_encode"])
|
||||
self.inputs["vae_encode"] = self.get_input_info_for(
|
||||
base_models["vae_encode"]
|
||||
)
|
||||
compiled_vae_encode, vae_encode_mlir = self.get_vae_encode()
|
||||
|
||||
check_compilation(compiled_vae_encode, "Vae Encode")
|
||||
@@ -616,25 +808,35 @@ class SharkifyStableDiffusionModel:
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
def unet(self):
|
||||
def unet(self, use_large=False):
|
||||
try:
|
||||
model = "stencil_unet" if self.use_stencil is not None else "unet"
|
||||
compiled_unet = None
|
||||
unet_inputs = base_models[model]
|
||||
|
||||
if self.base_model_id != "":
|
||||
self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
|
||||
compiled_unet, unet_mlir = self.compile_unet_variants(model)
|
||||
self.inputs["unet"] = self.get_input_info_for(
|
||||
unet_inputs[self.base_model_id]
|
||||
)
|
||||
compiled_unet, unet_mlir = self.compile_unet_variants(
|
||||
model, use_large=use_large
|
||||
)
|
||||
else:
|
||||
for model_id in unet_inputs:
|
||||
self.base_model_id = model_id
|
||||
self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
|
||||
self.inputs["unet"] = self.get_input_info_for(
|
||||
unet_inputs[model_id]
|
||||
)
|
||||
|
||||
try:
|
||||
compiled_unet, unet_mlir = self.compile_unet_variants(model)
|
||||
compiled_unet, unet_mlir = self.compile_unet_variants(
|
||||
model, use_large=use_large
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("Retrying with a different base model configuration")
|
||||
print(
|
||||
"Retrying with a different base model configuration"
|
||||
)
|
||||
continue
|
||||
|
||||
# -- Once a successful compilation has taken place we'd want to store
|
||||
@@ -657,7 +859,11 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
def vae(self):
|
||||
try:
|
||||
vae_input = base_models["vae"]["vae_upscaler"] if self.is_upscaler else base_models["vae"]["vae"]
|
||||
vae_input = (
|
||||
base_models["vae"]["vae_upscaler"]
|
||||
if self.is_upscaler
|
||||
else base_models["vae"]["vae"]
|
||||
)
|
||||
self.inputs["vae"] = self.get_input_info_for(vae_input)
|
||||
|
||||
is_base_vae = self.base_vae
|
||||
@@ -673,10 +879,14 @@ class SharkifyStableDiffusionModel:
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
def controlnet(self):
|
||||
def controlnet(self, use_large=False):
|
||||
try:
|
||||
self.inputs["stencil_adaptor"] = self.get_input_info_for(base_models["stencil_adaptor"])
|
||||
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
|
||||
self.inputs["stencil_adaptor"] = self.get_input_info_for(
|
||||
base_models["stencil_adaptor"]
|
||||
)
|
||||
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
|
||||
use_large=use_large
|
||||
)
|
||||
|
||||
check_compilation(compiled_stencil_adaptor, "Stencil")
|
||||
if self.return_mlir:
|
||||
|
||||
@@ -17,9 +17,13 @@ hf_model_variant_map = {
|
||||
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
|
||||
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
|
||||
"runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
|
||||
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
|
||||
"stabilityai/stable-diffusion-2-inpainting": [
|
||||
"stablediffusion",
|
||||
"inpaint_v2",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# TODO: Add the quantized model as a part model_db.json.
|
||||
# This is currently in experimental phase.
|
||||
def get_quantize_model():
|
||||
@@ -27,9 +31,12 @@ def get_quantize_model():
|
||||
model_key = "unet_int8"
|
||||
iree_flags = get_opt_flags("unet", precision="fp16")
|
||||
if args.height != 512 and args.width != 512 and args.max_length != 77:
|
||||
sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
|
||||
sys.exit(
|
||||
"The int8 quantized model currently requires the height and width to be 512, and max_length to be 77"
|
||||
)
|
||||
return bucket_key, model_key, iree_flags
|
||||
|
||||
|
||||
def get_variant_version(hf_model_id):
|
||||
return hf_model_variant_map[hf_model_id]
|
||||
|
||||
|
||||
@@ -15,6 +15,11 @@ from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
@@ -38,6 +43,11 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
@@ -74,13 +84,35 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
num_inference_steps,
|
||||
strength,
|
||||
dtype,
|
||||
resample_type,
|
||||
):
|
||||
# Pre process image -> get image encoded -> process latents
|
||||
|
||||
# TODO: process with variable HxW combos
|
||||
|
||||
# Pre process image
|
||||
image = image.resize((width, height))
|
||||
# Pre-process image
|
||||
if resample_type == "Lanczos":
|
||||
resample_type = Image.LANCZOS
|
||||
elif resample_type == "Nearest Neighbor":
|
||||
resample_type = Image.NEAREST
|
||||
elif resample_type == "Bilinear":
|
||||
resample_type = Image.BILINEAR
|
||||
elif resample_type == "Bicubic":
|
||||
resample_type = Image.BICUBIC
|
||||
elif resample_type == "Adaptive":
|
||||
resample_type = Image.ADAPTIVE
|
||||
elif resample_type == "Antialias":
|
||||
resample_type = Image.ANTIALIAS
|
||||
elif resample_type == "Box":
|
||||
resample_type = Image.BOX
|
||||
elif resample_type == "Affine":
|
||||
resample_type = Image.AFFINE
|
||||
elif resample_type == "Cubic":
|
||||
resample_type = Image.CUBIC
|
||||
else: # Fallback to Lanczos
|
||||
resample_type = Image.LANCZOS
|
||||
|
||||
image = image.resize((width, height), resample=resample_type)
|
||||
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
|
||||
image_arr = image_arr / 255.0
|
||||
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
|
||||
@@ -135,7 +167,9 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
use_stencil,
|
||||
resample_type,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
@@ -156,7 +190,10 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts, neg_prompts, max_length
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
@@ -172,6 +209,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
num_inference_steps=num_inference_steps,
|
||||
strength=strength,
|
||||
dtype=dtype,
|
||||
resample_type=resample_type,
|
||||
)
|
||||
|
||||
# Get Image latents
|
||||
|
||||
@@ -14,6 +14,11 @@ from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
@@ -37,6 +42,11 @@ class InpaintPipeline(StableDiffusionPipeline):
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
@@ -378,6 +388,7 @@ class InpaintPipeline(StableDiffusionPipeline):
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
@@ -408,7 +419,10 @@ class InpaintPipeline(StableDiffusionPipeline):
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts, neg_prompts, max_length
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
|
||||
@@ -14,6 +14,11 @@ from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
@@ -38,6 +43,11 @@ class OutpaintPipeline(StableDiffusionPipeline):
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
@@ -379,6 +389,7 @@ class OutpaintPipeline(StableDiffusionPipeline):
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
@@ -409,7 +420,10 @@ class OutpaintPipeline(StableDiffusionPipeline):
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts, neg_prompts, max_length
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
|
||||
@@ -14,6 +14,12 @@ from diffusers import (
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
@@ -38,6 +44,12 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
@@ -46,6 +58,7 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.controlnet = None
|
||||
self.controlnet_512 = None
|
||||
|
||||
def load_controlnet(self):
|
||||
if self.controlnet is not None:
|
||||
@@ -56,6 +69,15 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
del self.controlnet
|
||||
self.controlnet = None
|
||||
|
||||
def load_controlnet_512(self):
|
||||
if self.controlnet_512 is not None:
|
||||
return
|
||||
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
|
||||
|
||||
def unload_controlnet_512(self):
|
||||
del self.controlnet_512
|
||||
self.controlnet_512 = None
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
@@ -99,8 +121,12 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
self.load_unet()
|
||||
self.load_controlnet()
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_unet()
|
||||
self.load_controlnet()
|
||||
else:
|
||||
self.load_unet_512()
|
||||
self.load_controlnet_512()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype)
|
||||
@@ -123,43 +149,82 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
).to(dtype)
|
||||
else:
|
||||
latent_model_input_1 = latent_model_input
|
||||
control = self.controlnet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
control = self.controlnet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
control = self.controlnet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
timestep = timestep.detach().numpy()
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
control[0],
|
||||
control[1],
|
||||
control[2],
|
||||
control[3],
|
||||
control[4],
|
||||
control[5],
|
||||
control[6],
|
||||
control[7],
|
||||
control[8],
|
||||
control[9],
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
control[0],
|
||||
control[1],
|
||||
control[2],
|
||||
control[3],
|
||||
control[4],
|
||||
control[5],
|
||||
control[6],
|
||||
control[7],
|
||||
control[8],
|
||||
control[9],
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
print(self.unet_512)
|
||||
noise_pred = self.unet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
control[0],
|
||||
control[1],
|
||||
control[2],
|
||||
control[3],
|
||||
control[4],
|
||||
control[5],
|
||||
control[6],
|
||||
control[7],
|
||||
control[8],
|
||||
control[9],
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
end_profiling(profile_device)
|
||||
|
||||
if cpu_scheduling:
|
||||
@@ -179,7 +244,9 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
self.unload_controlnet()
|
||||
self.unload_controlnet_512()
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -204,7 +271,9 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
use_stencil,
|
||||
resample_type,
|
||||
):
|
||||
# Control Embedding check & conversion
|
||||
# TODO: 1. Change `num_images_per_prompt`.
|
||||
@@ -230,7 +299,10 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts, neg_prompts, max_length
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
|
||||
@@ -13,6 +13,10 @@ from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
@@ -34,6 +38,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
@@ -81,6 +89,7 @@ class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
@@ -112,7 +121,10 @@ class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts, neg_prompts, max_length
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
|
||||
@@ -17,9 +17,14 @@ from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_IDLE,
|
||||
SD_STATE_CANCEL,
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
@@ -65,6 +70,11 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
low_res_scheduler: Union[
|
||||
DDIMScheduler,
|
||||
@@ -76,6 +86,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
@@ -84,6 +98,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.low_res_scheduler = low_res_scheduler
|
||||
self.status = SD_STATE_IDLE
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
accepts_eta = "eta" in set(
|
||||
@@ -164,7 +179,11 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
self.load_unet()
|
||||
self.status = SD_STATE_IDLE
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_unet()
|
||||
else:
|
||||
self.load_unet_512()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
@@ -178,15 +197,26 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
noise_level,
|
||||
),
|
||||
)
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
noise_level,
|
||||
),
|
||||
)
|
||||
else:
|
||||
noise_pred = self.unet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
noise_level,
|
||||
),
|
||||
)
|
||||
end_profiling(profile_device)
|
||||
noise_pred = torch.from_numpy(noise_pred)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
@@ -210,8 +240,12 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
# )
|
||||
step_time_sum += step_time
|
||||
|
||||
if self.status == SD_STATE_CANCEL:
|
||||
break
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -236,6 +270,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
@@ -257,7 +292,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
|
||||
# Get text embeddings with weight emphasis from prompts
|
||||
text_embeddings = self.encode_prompts_weight(
|
||||
prompts, neg_prompts, max_length
|
||||
prompts,
|
||||
neg_prompts,
|
||||
max_length,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
)
|
||||
|
||||
# 4. Preprocess image
|
||||
|
||||
@@ -15,6 +15,9 @@ from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
@@ -48,6 +51,10 @@ class StableDiffusionPipeline:
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
@@ -57,6 +64,7 @@ class StableDiffusionPipeline:
|
||||
self.vae = None
|
||||
self.text_encoder = None
|
||||
self.unet = None
|
||||
self.unet_512 = None
|
||||
self.model_max_length = 77
|
||||
self.scheduler = scheduler
|
||||
# TODO: Implement using logging python utility.
|
||||
@@ -66,7 +74,8 @@ class StableDiffusionPipeline:
|
||||
self.import_mlir = import_mlir
|
||||
self.use_lora = use_lora
|
||||
self.ondemand = ondemand
|
||||
# TODO: Find a better workaround for fetching base_model_id early enough for CLIPTokenizer.
|
||||
# TODO: Find a better workaround for fetching base_model_id early
|
||||
# enough for CLIPTokenizer.
|
||||
try:
|
||||
self.tokenizer = get_tokenizer()
|
||||
except:
|
||||
@@ -81,13 +90,15 @@ class StableDiffusionPipeline:
|
||||
if self.import_mlir or self.use_lora:
|
||||
if not self.import_mlir:
|
||||
print(
|
||||
"Warning: LoRA provided but import_mlir not specified. Importing MLIR anyways."
|
||||
"Warning: LoRA provided but import_mlir not specified. "
|
||||
"Importing MLIR anyways."
|
||||
)
|
||||
self.text_encoder = self.sd_model.clip()
|
||||
else:
|
||||
try:
|
||||
self.text_encoder = get_clip()
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.text_encoder = self.sd_model.clip()
|
||||
|
||||
@@ -104,7 +115,8 @@ class StableDiffusionPipeline:
|
||||
else:
|
||||
try:
|
||||
self.unet = get_unet()
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.unet = self.sd_model.unet()
|
||||
|
||||
@@ -112,6 +124,24 @@ class StableDiffusionPipeline:
|
||||
del self.unet
|
||||
self.unet = None
|
||||
|
||||
def load_unet_512(self):
|
||||
if self.unet_512 is not None:
|
||||
return
|
||||
|
||||
if self.import_mlir or self.use_lora:
|
||||
self.unet_512 = self.sd_model.unet(use_large=True)
|
||||
else:
|
||||
try:
|
||||
self.unet_512 = get_unet(use_large=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.unet_512 = self.sd_model.unet(use_large=True)
|
||||
|
||||
def unload_unet_512(self):
|
||||
del self.unet_512
|
||||
self.unet_512 = None
|
||||
|
||||
def load_vae(self):
|
||||
if self.vae is not None:
|
||||
return
|
||||
@@ -121,7 +151,8 @@ class StableDiffusionPipeline:
|
||||
else:
|
||||
try:
|
||||
self.vae = get_vae()
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.vae = self.sd_model.vae()
|
||||
|
||||
@@ -200,7 +231,10 @@ class StableDiffusionPipeline:
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
self.load_unet()
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_unet()
|
||||
else:
|
||||
self.load_unet_512()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
@@ -219,16 +253,28 @@ class StableDiffusionPipeline:
|
||||
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
noise_pred = self.unet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
end_profiling(profile_device)
|
||||
|
||||
if cpu_scheduling:
|
||||
@@ -251,6 +297,7 @@ class StableDiffusionPipeline:
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -272,6 +319,10 @@ class StableDiffusionPipeline:
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
import_mlir: bool,
|
||||
model_id: str,
|
||||
@@ -356,16 +407,21 @@ class StableDiffusionPipeline:
|
||||
prompt (`str` or `list(int)`):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
The prompt or prompts not to guide the image generation.
|
||||
Ignored when not using guidance
|
||||
(i.e., ignored if `guidance_scale` is less than `1`).
|
||||
model_max_length (int):
|
||||
SHARK: pass the max length instead of relying on pipe.tokenizer.model_max_length
|
||||
SHARK: pass the max length instead of relying on
|
||||
pipe.tokenizer.model_max_length
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not,
|
||||
SHARK: must be set to True as we always expect neg embeddings (defaulted to True)
|
||||
SHARK: must be set to True as we always expect neg embeddings
|
||||
(defaulted to True)
|
||||
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
||||
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
||||
SHARK: max_embeddings_multiples>1 produce a tensor shape error (defaulted to 1)
|
||||
The max multiple length of prompt embeddings compared to the
|
||||
max output length of text encoder.
|
||||
SHARK: max_embeddings_multiples>1 produce a tensor shape error
|
||||
(defaulted to 1)
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
SHARK: num_images_per_prompt is not used (defaulted to 1)
|
||||
@@ -384,9 +440,11 @@ class StableDiffusionPipeline:
|
||||
negative_prompt = [negative_prompt] * batch_size
|
||||
if batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
f"`negative_prompt`: "
|
||||
f"{negative_prompt} has batch size {len(negative_prompt)}, "
|
||||
f"but `prompt`: {prompt} has batch size {batch_size}. "
|
||||
f"Please make sure that passed `negative_prompt` matches "
|
||||
"the batch size of `prompt`."
|
||||
)
|
||||
|
||||
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
|
||||
@@ -399,16 +457,43 @@ class StableDiffusionPipeline:
|
||||
)
|
||||
# SHARK: we are not using num_images_per_prompt
|
||||
# bs_embed, seq_len, _ = text_embeddings.shape
|
||||
# text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
# text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
# text_embeddings = text_embeddings.repeat(
|
||||
# 1,
|
||||
# num_images_per_prompt,
|
||||
# 1
|
||||
# )
|
||||
# text_embeddings = (
|
||||
# text_embeddings.view(
|
||||
# bs_embed * num_images_per_prompt,
|
||||
# seq_len,
|
||||
# -1
|
||||
# )
|
||||
# )
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# SHARK: we are not using num_images_per_prompt
|
||||
# bs_embed, seq_len, _ = uncond_embeddings.shape
|
||||
# uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
# uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
# uncond_embeddings = (
|
||||
# uncond_embeddings.repeat(
|
||||
# 1,
|
||||
# num_images_per_prompt,
|
||||
# 1
|
||||
# )
|
||||
# )
|
||||
# uncond_embeddings = (
|
||||
# uncond_embeddings.view(
|
||||
# bs_embed * num_images_per_prompt,
|
||||
# seq_len,
|
||||
# -1
|
||||
# )
|
||||
# )
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
if text_embeddings.shape[1] > model_max_length:
|
||||
pad = (0, 0) * (len(text_embeddings.shape) - 2)
|
||||
pad = pad + (0, 512 - text_embeddings.shape[1])
|
||||
text_embeddings = torch.nn.functional.pad(text_embeddings, pad)
|
||||
|
||||
# SHARK: Report clip inference time
|
||||
clip_inf_time = (time.time() - clip_inf_start) * 1000
|
||||
if self.ondemand:
|
||||
@@ -443,7 +528,8 @@ re_attention = re.compile(
|
||||
|
||||
def parse_prompt_attention(text):
|
||||
"""
|
||||
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
||||
Parses a string with attention tokens and returns a list of pairs:
|
||||
text and its associated weight.
|
||||
Accepted tokens are:
|
||||
(abc) - increases attention to abc by a multiplier of 1.1
|
||||
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
||||
@@ -670,6 +756,12 @@ def get_unweighted_text_embeddings(
|
||||
return text_embeddings
|
||||
|
||||
|
||||
# This function deals with NoneType values occuring in tokens after padding
|
||||
# It switches out None with 49407 as truncating None values causes matrix dimension errors,
|
||||
def filter_nonetype_tokens(tokens: List[List]):
|
||||
return [[49407 if token is None else token for token in tokens[0]]]
|
||||
|
||||
|
||||
def get_weighted_text_embeddings(
|
||||
pipe: StableDiffusionPipeline,
|
||||
prompt: Union[str, List[str]],
|
||||
@@ -761,6 +853,10 @@ def get_weighted_text_embeddings(
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
chunk_length=pipe.model_max_length,
|
||||
)
|
||||
|
||||
# FIXME: This is a hacky fix caused by tokenizer padding with None values
|
||||
prompt_tokens = filter_nonetype_tokens(prompt_tokens)
|
||||
|
||||
# prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
|
||||
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
|
||||
if uncond_prompt is not None:
|
||||
@@ -773,6 +869,10 @@ def get_weighted_text_embeddings(
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
chunk_length=pipe.model_max_length,
|
||||
)
|
||||
|
||||
# FIXME: This is a hacky fix caused by tokenizer padding with None values
|
||||
uncond_tokens = filter_nonetype_tokens(uncond_tokens)
|
||||
|
||||
# uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
|
||||
uncond_tokens = torch.tensor(
|
||||
uncond_tokens, dtype=torch.long, device="cpu"
|
||||
|
||||
@@ -8,6 +8,9 @@ from diffusers import (
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
@@ -38,9 +41,28 @@ def get_schedulers(model_id):
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverMultistep"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id, subfolder="scheduler", algorithm_type="dpmsolver"
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverMultistep++"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverMultistepKarras"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
use_karras_sigmas=True,
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverMultistepKarras++"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
algorithm_type="dpmsolver++",
|
||||
use_karras_sigmas=True,
|
||||
)
|
||||
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
@@ -62,5 +84,21 @@ def get_schedulers(model_id):
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverSinglestep"
|
||||
] = DPMSolverSinglestepScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"KDPM2AncestralDiscrete"
|
||||
] = KDPM2AncestralDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["SharkEulerDiscrete"].compile()
|
||||
return schedulers
|
||||
|
||||
@@ -84,9 +84,6 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
|
||||
def _import(self):
|
||||
scaling_model = ScalingModel()
|
||||
|
||||
@@ -28,11 +28,16 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
sanitize_seed,
|
||||
parse_seed_input,
|
||||
batch_seeds,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
get_generation_text_info,
|
||||
update_lora_weight,
|
||||
resize_stencil,
|
||||
_compile_module,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,9 @@ from apps.stable_diffusion.src.utils.stable_args import args
|
||||
|
||||
# Helper function to profile the vulkan device.
|
||||
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
|
||||
if args.vulkan_debug_utils and "vulkan" in args.device:
|
||||
from shark.parser import shark_args
|
||||
|
||||
if shark_args.vulkan_debug_utils and "vulkan" in args.device:
|
||||
import iree
|
||||
|
||||
print(f"Profiling and saving to {file_path}.")
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"stablediffusion/untuned":"gs://shark_tank/nightly"
|
||||
},
|
||||
{
|
||||
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan",
|
||||
"stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
"stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
"stablediffusion/v1_4/vae/fp16/length_64/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
"stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan",
|
||||
|
||||
@@ -5,4 +5,7 @@
|
||||
["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"],
|
||||
["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"],
|
||||
["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"],
|
||||
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"]]
|
||||
["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"],
|
||||
["A photo of a beach, sunset, calm, beautiful landscape, waves, water"],
|
||||
["(a large body of water with snowy mountains in the background), (fog, foggy, rolling fog), (clouds, cloudy, rolling clouds), dramatic sky and landscape, extraordinary landscape, (beautiful snow capped mountain background), (forest, dirt path)"],
|
||||
["a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smokes coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"]]
|
||||
|
||||
@@ -109,14 +109,14 @@ def load_lower_configs(base_model_id=None):
|
||||
spec = spec.split("-")[0]
|
||||
|
||||
if args.annotation_model == "vae":
|
||||
if not spec or spec in ["rdna3", "sm_80"]:
|
||||
if not spec or spec in ["sm_80"]:
|
||||
config_name = (
|
||||
f"{args.annotation_model}_{args.precision}_{device}.json"
|
||||
)
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
|
||||
else:
|
||||
if not spec or spec in ["rdna3", "sm_80"]:
|
||||
if not spec or spec in ["sm_80"]:
|
||||
if (
|
||||
version in ["v2_1", "v2_1base"]
|
||||
and args.height == 768
|
||||
@@ -125,12 +125,42 @@ def load_lower_configs(base_model_id=None):
|
||||
config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
|
||||
elif spec in ["rdna3"] and version in [
|
||||
"v2_1",
|
||||
"v2_1base",
|
||||
"v1_4",
|
||||
"v1_5",
|
||||
]:
|
||||
config_name = (
|
||||
f"{args.annotation_model}_"
|
||||
f"{version}_"
|
||||
f"{args.max_length}_"
|
||||
f"{args.precision}_"
|
||||
f"{device}_"
|
||||
f"{spec}_"
|
||||
f"{args.width}x{args.height}.json"
|
||||
)
|
||||
elif spec in ["rdna2"] and version in ["v2_1", "v2_1base", "v1_4"]:
|
||||
config_name = (
|
||||
f"{args.annotation_model}_"
|
||||
f"{version}_"
|
||||
f"{args.precision}_"
|
||||
f"{device}_"
|
||||
f"{spec}_"
|
||||
f"{args.width}x{args.height}.json"
|
||||
)
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
|
||||
config_name = (
|
||||
f"{args.annotation_model}_"
|
||||
f"{version}_"
|
||||
f"{args.precision}_"
|
||||
f"{device}_"
|
||||
f"{spec}.json"
|
||||
)
|
||||
|
||||
full_gs_url = config_bucket + config_name
|
||||
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
|
||||
print("Loading lowering config file from ", lowering_config_dir)
|
||||
full_gs_url = config_bucket + config_name
|
||||
download_public_file(full_gs_url, lowering_config_dir, True)
|
||||
return lowering_config_dir
|
||||
|
||||
@@ -171,9 +201,22 @@ def dump_after_mlir(input_mlir, use_winograd):
|
||||
|
||||
device, device_spec_args = get_device_args()
|
||||
if use_winograd:
|
||||
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
preprocess_flag = (
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module"
|
||||
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
|
||||
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"iree-preprocessing-convert-conv2d-to-img2col,"
|
||||
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
|
||||
"iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
)
|
||||
else:
|
||||
preprocess_flag = "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
preprocess_flag = (
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module"
|
||||
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
|
||||
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"iree-preprocessing-convert-conv2d-to-img2col,"
|
||||
"iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
)
|
||||
|
||||
dump_module = ireec.compile_str(
|
||||
input_mlir,
|
||||
|
||||
@@ -19,48 +19,56 @@ p = argparse.ArgumentParser(
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Stable Diffusion Params
|
||||
# Stable Diffusion Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"-a",
|
||||
"--app",
|
||||
default="txt2img",
|
||||
help="which app to use, one of: txt2img, img2img, outpaint, inpaint",
|
||||
help="Which app to use, one of: txt2img, img2img, outpaint, inpaint.",
|
||||
)
|
||||
p.add_argument(
|
||||
"-p",
|
||||
"--prompts",
|
||||
nargs="+",
|
||||
default=["cyberpunk forest by Salvador Dali"],
|
||||
help="text of which images to be generated.",
|
||||
default=[
|
||||
"a photo taken of the front of a super-car drifting on a road near "
|
||||
"mountains at high speeds with smokes coming off the tires, front "
|
||||
"angle, front point of view, trees in the mountains of the "
|
||||
"background, ((sharp focus))"
|
||||
],
|
||||
help="Text of which images to be generated.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--negative_prompts",
|
||||
nargs="+",
|
||||
default=["trees, green"],
|
||||
help="text you don't want to see in the generated image.",
|
||||
default=[
|
||||
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
|
||||
"blurry, ugly, blur, oversaturated, cropped"
|
||||
],
|
||||
help="Text you don't want to see in the generated image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--img_path",
|
||||
type=str,
|
||||
help="Path to the image input for img2img/inpainting",
|
||||
help="Path to the image input for img2img/inpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="the no. of steps to do the sampling.",
|
||||
help="The number of steps to do the sampling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
type=str,
|
||||
default=-1,
|
||||
help="the seed to use. -1 for a random one.",
|
||||
help="The seed or list of seeds to use. -1 for a random one.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -68,7 +76,7 @@ p.add_argument(
|
||||
type=int,
|
||||
default=1,
|
||||
choices=range(1, 4),
|
||||
help="the number of inferences to be made in a single `batch_count`.",
|
||||
help="The number of inferences to be made in a single `batch_count`.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -76,7 +84,7 @@ p.add_argument(
|
||||
type=int,
|
||||
default=512,
|
||||
choices=range(128, 769, 8),
|
||||
help="the height of the output image.",
|
||||
help="The height of the output image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -84,77 +92,137 @@ p.add_argument(
|
||||
type=int,
|
||||
default=512,
|
||||
choices=range(128, 769, 8),
|
||||
help="the width of the output image.",
|
||||
help="The width of the output image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
help="the value to be used for guidance scaling.",
|
||||
help="The value to be used for guidance scaling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--noise_level",
|
||||
type=int,
|
||||
default=20,
|
||||
help="the value to be used for noise level of upscaler.",
|
||||
help="The value to be used for noise level of upscaler.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--max_length",
|
||||
type=int,
|
||||
default=64,
|
||||
help="max length of the tokenizer output, options are 64 and 77.",
|
||||
help="Max length of the tokenizer output, options are 64 and 77.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--max_embeddings_multiples",
|
||||
type=int,
|
||||
default=5,
|
||||
help="The max multiple length of prompt embeddings compared to the max "
|
||||
"output length of text encoder.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--strength",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="the strength of change applied on the given input image for img2img",
|
||||
help="The strength of change applied on the given input image for "
|
||||
"img2img.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_hiresfix",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Use Hires Fix to do higher resolution images, while trying to "
|
||||
"avoid the issues that come with it. This is accomplished by first "
|
||||
"generating an image using txt2img, then running it through img2img.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hiresfix_height",
|
||||
type=int,
|
||||
default=768,
|
||||
choices=range(128, 769, 8),
|
||||
help="The height of the Hires Fix image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hiresfix_width",
|
||||
type=int,
|
||||
default=768,
|
||||
choices=range(128, 769, 8),
|
||||
help="The width of the Hires Fix image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hiresfix_strength",
|
||||
type=float,
|
||||
default=0.6,
|
||||
help="The denoising strength to apply for the Hires Fix.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--resample_type",
|
||||
type=str,
|
||||
default="Nearest Neighbor",
|
||||
choices=[
|
||||
"Lanczos",
|
||||
"Nearest Neighbor",
|
||||
"Bilinear",
|
||||
"Bicubic",
|
||||
"Adaptive",
|
||||
"Antialias",
|
||||
"Box",
|
||||
"Affine",
|
||||
"Cubic",
|
||||
],
|
||||
help="The resample type to use when resizing an image before being run "
|
||||
"through stable diffusion.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Stable Diffusion Training Params
|
||||
# Stable Diffusion Training Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--lora_save_dir",
|
||||
type=str,
|
||||
default="models/lora/",
|
||||
help="Directory to save the lora fine tuned model",
|
||||
help="Directory to save the lora fine tuned model.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--training_images_dir",
|
||||
type=str,
|
||||
default="models/lora/training_images/",
|
||||
help="Directory containing images that are an example of the prompt",
|
||||
help="Directory containing images that are an example of the prompt.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--training_steps",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="The no. of steps to train",
|
||||
help="The number of steps to train.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Inpainting and Outpainting Params
|
||||
# Inpainting and Outpainting Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--mask_path",
|
||||
type=str,
|
||||
help="Path to the mask image input for inpainting",
|
||||
help="Path to the mask image input for inpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--inpaint_full_res",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If inpaint only masked area or whole picture",
|
||||
help="If inpaint only masked area or whole picture.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -162,7 +230,7 @@ p.add_argument(
|
||||
type=int,
|
||||
default=32,
|
||||
choices=range(0, 257, 4),
|
||||
help="Number of pixels for only masked padding",
|
||||
help="Number of pixels for only masked padding.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -170,7 +238,7 @@ p.add_argument(
|
||||
type=int,
|
||||
default=128,
|
||||
choices=range(8, 257, 8),
|
||||
help="Number of expended pixels for one direction for outpainting",
|
||||
help="Number of expended pixels for one direction for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -178,89 +246,92 @@ p.add_argument(
|
||||
type=int,
|
||||
default=8,
|
||||
choices=range(0, 65),
|
||||
help="Number of blur pixels for outpainting",
|
||||
help="Number of blur pixels for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--left",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend left for outpainting",
|
||||
help="If expend left for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--right",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend right for outpainting",
|
||||
help="If expend right for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--top",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend top for outpainting",
|
||||
help="If expend top for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--bottom",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If expend bottom for outpainting",
|
||||
help="If expend bottom for outpainting.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--noise_q",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Fall-off exponent for outpainting (lower=higher detail) (min=0.0, max=4.0)",
|
||||
help="Fall-off exponent for outpainting (lower=higher detail) "
|
||||
"(min=0.0, max=4.0).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--color_variation",
|
||||
type=float,
|
||||
default=0.05,
|
||||
help="Color variation for outpainting (min=0.0, max=1.0)",
|
||||
help="Color variation for outpainting (min=0.0, max=1.0).",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Model Config and Usage Params
|
||||
# Model Config and Usage Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--device", type=str, default="vulkan", help="device to run the model."
|
||||
"--device", type=str, default="vulkan", help="Device to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--precision", type=str, default="fp16", help="precision to run the model."
|
||||
"--precision", type=str, default="fp16", help="Precision to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--import_mlir",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
|
||||
help="Imports the model from torch module to shark_module otherwise "
|
||||
"downloads the model from shark_tank.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--load_vmfb",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
|
||||
help="Attempts to load the model from a precompiled flat-buffer "
|
||||
"and compiles + saves it if not found.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_vmfb",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="saves the compiled flatbuffer to the local directory",
|
||||
help="Saves the compiled flat-buffer to the local directory.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_tuned",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Download and use the tuned version of the model if available",
|
||||
help="Download and use the tuned version of the model if available.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -274,28 +345,42 @@ p.add_argument(
|
||||
"--scheduler",
|
||||
type=str,
|
||||
default="SharkEulerDiscrete",
|
||||
help="other supported schedulers are [PNDM, DDIM, LMSDiscrete, EulerDiscrete, DPMSolverMultistep]",
|
||||
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
|
||||
"DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
|
||||
"DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
|
||||
"DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, "
|
||||
"HeunDiscrete].",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_img_format",
|
||||
type=str,
|
||||
default="png",
|
||||
help="specify the format in which output image is save. Supported options: jpg / png",
|
||||
help="Specify the format in which output image is save. "
|
||||
"Supported options: jpg / png.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory path to save the output images and json",
|
||||
help="Directory path to save the output images and json.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--batch_count",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of batch to be generated with random seeds in single execution",
|
||||
help="Number of batches to be generated with random seeds in "
|
||||
"single execution.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--repeatable_seeds",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="The seed of the first batch will be used as the rng seed to "
|
||||
"generate the subsequent seeds for subsequent batches in that run.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -309,7 +394,8 @@ p.add_argument(
|
||||
"--custom_vae",
|
||||
type=str,
|
||||
default="",
|
||||
help="HuggingFace repo-id or path to SD model's checkpoint whose Vae needs to be plugged in.",
|
||||
help="HuggingFace repo-id or path to SD model's checkpoint whose VAE "
|
||||
"needs to be plugged in.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -323,14 +409,15 @@ p.add_argument(
|
||||
"--low_cpu_mem_usage",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Use the accelerate package to reduce cpu memory consumption",
|
||||
help="Use the accelerate package to reduce cpu memory consumption.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--attention_slicing",
|
||||
type=str,
|
||||
default="none",
|
||||
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', or an integer)",
|
||||
help="Amount of attention slicing to use (one of 'max', 'auto', 'none', "
|
||||
"or an integer).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -343,193 +430,250 @@ p.add_argument(
|
||||
"--use_lora",
|
||||
type=str,
|
||||
default="",
|
||||
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
|
||||
help="Use standalone LoRA weight using a HF ID or a checkpoint "
|
||||
"file (~3 MB).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_quantize",
|
||||
type=str,
|
||||
default="none",
|
||||
help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
|
||||
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
|
||||
help="Runs the quantized version of stable diffusion model. "
|
||||
"This is currently in experimental phase. "
|
||||
"Currently, only runs the stable-diffusion-2-1-base model in "
|
||||
"int8 quantization.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--ondemand",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Load and unload models for low VRAM",
|
||||
help="Load and unload models for low VRAM.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hf_auth_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify your own huggingface authentication tokens for models like Llama2.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--device_allocator_heap_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify heap key for device caching allocator."
|
||||
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
|
||||
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
|
||||
)
|
||||
##############################################################################
|
||||
### IREE - Vulkan supported flags
|
||||
# IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--iree_vulkan_target_triple",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify target triple for vulkan",
|
||||
help="Specify target triple for vulkan.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_debug_utils",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Profiles vulkan device and collects the .rdc info",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_large_heap_block_size",
|
||||
default="2073741824",
|
||||
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_validation_layers",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for disabling vulkan validation layers when benchmarking",
|
||||
"--iree_metal_target_platform",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify target triple for metal.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Misc. Debug and Optimization flags
|
||||
# Misc. Debug and Optimization flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--use_compiled_scheduler",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="use the default scheduler precompiled into the model if available",
|
||||
help="Use the default scheduler precompiled into the model if available.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--local_tank_cache",
|
||||
default="",
|
||||
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
|
||||
help="Specify where to save downloaded shark_tank artifacts. "
|
||||
"If this is not set, the default is ~/.local/shark_tank/.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dump_isa",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
|
||||
help="When enabled call amdllpc to get ISA dumps. "
|
||||
"Use with dispatch benchmarks.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dispatch_benchmarks",
|
||||
default=None,
|
||||
help='dispatches to return benchamrk data on. use "All" for all, and None for none.',
|
||||
help="Dispatches to return benchmark data on. "
|
||||
'Use "All" for all, and None for none.',
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dispatch_benchmarks_dir",
|
||||
default="temp_dispatch_benchmarks",
|
||||
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
|
||||
help="Directory where you want to store dispatch data "
|
||||
'generated with "--dispatch_benchmarks".',
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--enable_rgp",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for inserting debug frames between iterations for use with rgp.",
|
||||
help="Flag for inserting debug frames between iterations "
|
||||
"for use with rgp.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--hide_steps",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for hiding the details of iteration/sec for each step.",
|
||||
help="Flag for hiding the details of iteration/sec for each step.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--warmup_count",
|
||||
type=int,
|
||||
default=0,
|
||||
help="flag setting warmup count for clip and vae [>= 0].",
|
||||
help="Flag setting warmup count for CLIP and VAE [>= 0].",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--clear_all",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
|
||||
help="Flag to clear all mlir and vmfb from common locations. "
|
||||
"Recompiling will take several minutes.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_metadata_to_json",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for whether or not to save a generation information json file with the image.",
|
||||
help="Flag for whether or not to save a generation information "
|
||||
"json file with the image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--write_metadata_to_png",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
|
||||
help="Flag for whether or not to save generation information in "
|
||||
"PNG chunk text to generated images.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--import_debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="if import_mlir is True, saves mlir via the debug option in shark importer. Does nothing if import_mlir is false (the default)",
|
||||
help="If import_mlir is True, saves mlir via the debug option "
|
||||
"in shark importer. Does nothing if import_mlir is false (the default).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--compile_debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag to toggle debug assert/verify flags for imported IR in the"
|
||||
"iree-compiler. Default to false.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--iree_constant_folding",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Controls constant folding in iree-compile for all SD models.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Web UI flags
|
||||
# Web UI flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--progress_bar",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for removing the progress bar animation during image generation",
|
||||
help="Flag for removing the progress bar animation during "
|
||||
"image generation.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--ckpt_dir",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
|
||||
help="Path to directory where all .ckpts are stored in order to populate "
|
||||
"them in the web UI.",
|
||||
)
|
||||
# TODO: replace API flag when these can be run together
|
||||
p.add_argument(
|
||||
"--ui",
|
||||
type=str,
|
||||
default="app" if os.name == "nt" else "web",
|
||||
help="one of: [api, app, web]",
|
||||
help="One of: [api, app, web].",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--share",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for generating a public URL",
|
||||
help="Flag for generating a public URL.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--server_port",
|
||||
type=int,
|
||||
default=8080,
|
||||
help="flag for setting server port",
|
||||
help="Flag for setting server port.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--api",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for enabling rest API",
|
||||
help="Flag for enabling rest API.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for enabling debugging log in WebUI.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for removing the output gallery tab, and avoid exposing "
|
||||
"images under --output_dir in the UI.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--output_gallery_followlinks",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for whether the output gallery tab in the UI should "
|
||||
"follow symlinks when listing subdirectories under --output_dir.",
|
||||
)
|
||||
|
||||
|
||||
##############################################################################
|
||||
### SD model auto-annotation flags
|
||||
# SD model auto-annotation flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--annotation_output",
|
||||
type=path_expand,
|
||||
default="./",
|
||||
help="Directory to save the annotated mlir file",
|
||||
help="Directory to save the annotated mlir file.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -543,33 +687,43 @@ p.add_argument(
|
||||
"--save_annotation",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Save annotated mlir file",
|
||||
help="Save annotated mlir file.",
|
||||
)
|
||||
##############################################################################
|
||||
### SD model auto-tuner flags
|
||||
# SD model auto-tuner flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--tuned_config_dir",
|
||||
type=path_expand,
|
||||
default="./",
|
||||
help="Directory to save the tuned config file",
|
||||
help="Directory to save the tuned config file.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--num_iters",
|
||||
type=int,
|
||||
default=400,
|
||||
help="Number of iterations for tuning",
|
||||
help="Number of iterations for tuning.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--search_op",
|
||||
type=str,
|
||||
default="all",
|
||||
help="Op to be optimized, options are matmul, bmm, conv and all",
|
||||
help="Op to be optimized, options are matmul, bmm, conv and all.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# DocuChat Flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--run_docuchat_web",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Specifies whether the docuchat's web version is running or not.",
|
||||
)
|
||||
|
||||
args, unknown = p.parse_known_args()
|
||||
if args.import_debug:
|
||||
|
||||
@@ -8,17 +8,24 @@ from datetime import datetime as dt
|
||||
from csv import DictWriter
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from random import (
|
||||
randint,
|
||||
seed as seed_random,
|
||||
getstate as random_getstate,
|
||||
setstate as random_setstate,
|
||||
)
|
||||
import tempfile
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
get_iree_vulkan_runtime_flags,
|
||||
)
|
||||
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
|
||||
from shark.iree_utils.metal_utils import get_metal_target_triple
|
||||
from shark.iree_utils.gpu_utils import get_cuda_sm_cc, get_iree_rocm_args
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.resources import opt_flags
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
@@ -31,6 +38,7 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
import requests
|
||||
from io import BytesIO
|
||||
from omegaconf import OmegaConf
|
||||
from cpuinfo import get_cpu_info
|
||||
|
||||
|
||||
def get_extended_name(model_name):
|
||||
@@ -47,6 +55,7 @@ def get_vmfb_path_name(model_name):
|
||||
def _load_vmfb(shark_module, vmfb_path, model, precision):
|
||||
model = "vae" if "base_vae" in model or "vae_encode" in model else model
|
||||
model = "unet" if "stencil" in model else model
|
||||
model = "unet" if "unet512" in model else model
|
||||
precision = "fp32" if "clip" in model else precision
|
||||
extra_args = get_opt_flags(model, precision)
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
@@ -69,7 +78,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), model_name, extra_args
|
||||
os.getcwd(), model_name, extra_args, debug=args.compile_debug
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
@@ -78,12 +87,13 @@ def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
|
||||
|
||||
# Downloads the model from shark_tank and returns the shark_module.
|
||||
def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
def get_shark_model(tank_url, model_name, extra_args=None):
|
||||
if extra_args is None:
|
||||
extra_args = []
|
||||
from shark.parser import shark_args
|
||||
|
||||
# Set local shark_tank cache directory.
|
||||
shark_args.local_tank_cache = args.local_tank_cache
|
||||
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
if "cuda" in args.device:
|
||||
@@ -111,12 +121,15 @@ def compile_through_fx(
|
||||
save_dir=tempfile.gettempdir(),
|
||||
debug=False,
|
||||
generate_vmfb=True,
|
||||
extra_args=[],
|
||||
extra_args=None,
|
||||
base_model_id=None,
|
||||
model_name=None,
|
||||
precision=None,
|
||||
return_mlir=False,
|
||||
device=None,
|
||||
):
|
||||
if extra_args is None:
|
||||
extra_args = []
|
||||
if not return_mlir and model_name is not None:
|
||||
vmfb_path = get_vmfb_path_name(extended_model_name)
|
||||
if os.path.isfile(vmfb_path):
|
||||
@@ -141,20 +154,31 @@ def compile_through_fx(
|
||||
f16_input_mask=f16_input_mask,
|
||||
debug=debug,
|
||||
model_name=extended_model_name,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
|
||||
if use_tuned:
|
||||
if "vae" in extended_model_name.split("_")[0]:
|
||||
args.annotation_model = "vae"
|
||||
if "unet" in model_name.split("_")[0]:
|
||||
if (
|
||||
"unet" in model_name.split("_")[0]
|
||||
or "unet_512" in model_name.split("_")[0]
|
||||
):
|
||||
args.annotation_model = "unet"
|
||||
mlir_module = sd_model_annotation(
|
||||
mlir_module, extended_model_name, base_model_id
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_dir):
|
||||
save_dir = ""
|
||||
|
||||
mlir_module = save_mlir(
|
||||
mlir_module,
|
||||
model_name=extended_model_name,
|
||||
dir=save_dir,
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
device=args.device if device is None else device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
if generate_vmfb:
|
||||
@@ -163,20 +187,22 @@ def compile_through_fx(
|
||||
mlir_module,
|
||||
)
|
||||
|
||||
del mlir_module
|
||||
gc.collect()
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
|
||||
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
|
||||
]
|
||||
# TODO: This function should be device-agnostic and piped properly
|
||||
# to general runtime driver init.
|
||||
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
|
||||
if args.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
f"--vulkan_debug_utils=true",
|
||||
]
|
||||
if args.device_allocator_heap_key:
|
||||
vulkan_runtime_flags += [
|
||||
f"--device_allocator=caching:device_local={args.device_allocator_heap_key}",
|
||||
]
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
|
||||
@@ -199,13 +225,15 @@ def get_device_mapping(driver, key_combination=3):
|
||||
specific devices for execution
|
||||
Args:
|
||||
driver (str): execution driver (vulkan, cuda, rocm, etc)
|
||||
key_combination (int, optional): choice for mapping value for device name.
|
||||
key_combination (int, optional): choice for mapping value for
|
||||
device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
Defaults to 3.
|
||||
Returns:
|
||||
dict: map to possible device names user can input mapped to desired combination of name/path.
|
||||
dict: map to possible device names user can input mapped to desired
|
||||
combination of name/path.
|
||||
"""
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
|
||||
@@ -219,7 +247,7 @@ def get_device_mapping(driver, key_combination=3):
|
||||
if key_combination == 2:
|
||||
return dev_dict["name"]
|
||||
if key_combination == 3:
|
||||
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
|
||||
return dev_dict["name"], f"{driver}://{dev_dict['path']}"
|
||||
|
||||
# mapping driver name to default device (driver://0)
|
||||
device_map[f"{driver}"] = get_output_value(device_list[0])
|
||||
@@ -232,10 +260,12 @@ def get_device_mapping(driver, key_combination=3):
|
||||
|
||||
|
||||
def map_device_to_name_path(device, key_combination=3):
|
||||
"""Gives the appropriate device data (supported name/path) for user selected execution device
|
||||
"""Gives the appropriate device data (supported name/path) for user
|
||||
selected execution device
|
||||
Args:
|
||||
device (str): user
|
||||
key_combination (int, optional): choice for mapping value for device name.
|
||||
key_combination (int, optional): choice for mapping value for
|
||||
device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
@@ -243,7 +273,8 @@ def map_device_to_name_path(device, key_combination=3):
|
||||
Raises:
|
||||
ValueError:
|
||||
Returns:
|
||||
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
|
||||
str / tuple: returns the mapping str or tuple of mapping str for
|
||||
the device depending on key_combination value
|
||||
"""
|
||||
driver = device.split("://")[0]
|
||||
device_map = get_device_mapping(driver, key_combination)
|
||||
@@ -266,10 +297,21 @@ def set_init_device_flags():
|
||||
if triple is not None:
|
||||
args.iree_vulkan_target_triple = triple
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
|
||||
f"Found device {device_name}. Using target triple "
|
||||
f"{args.iree_vulkan_target_triple}."
|
||||
)
|
||||
elif "cuda" in args.device:
|
||||
args.device = "cuda"
|
||||
elif "metal" in args.device:
|
||||
device_name, args.device = map_device_to_name_path(args.device)
|
||||
if not args.iree_metal_target_platform:
|
||||
triple = get_metal_target_triple(device_name)
|
||||
if triple is not None:
|
||||
args.iree_metal_target_platform = triple.split("-")[-1]
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple "
|
||||
f"{args.iree_metal_target_platform}."
|
||||
)
|
||||
elif "cpu" in args.device:
|
||||
args.device = "cpu"
|
||||
|
||||
@@ -294,13 +336,24 @@ def set_init_device_flags():
|
||||
if (
|
||||
args.precision != "fp16"
|
||||
or args.height not in [512, 768]
|
||||
or (args.height == 512 and args.width != 512)
|
||||
or (args.height == 768 and args.width != 768)
|
||||
or (args.height == 512 and args.width not in [512, 768])
|
||||
or (args.height == 768 and args.width not in [512, 768])
|
||||
or args.batch_size != 1
|
||||
or ("vulkan" not in args.device and "cuda" not in args.device)
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif (
|
||||
args.height != args.width
|
||||
and "rdna2" in args.iree_vulkan_target_triple
|
||||
and base_model_id
|
||||
not in [
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
]
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif base_model_id not in [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"dreamlike-art/dreamlike-diffusion-1.0",
|
||||
@@ -338,13 +391,26 @@ def set_init_device_flags():
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
]
|
||||
or "rdna3" not in args.iree_vulkan_target_triple
|
||||
or "rdna" not in args.iree_vulkan_target_triple
|
||||
)
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif "rdna2" in args.iree_vulkan_target_triple and (
|
||||
base_model_id
|
||||
not in [
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
]
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
if args.use_tuned:
|
||||
print(f"Using tuned models for {base_model_id}/fp16/{args.device}.")
|
||||
print(
|
||||
f"Using tuned models for {base_model_id}(fp16) on "
|
||||
f"device {args.device}."
|
||||
)
|
||||
else:
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
@@ -401,18 +467,45 @@ def get_available_devices():
|
||||
except:
|
||||
print(f"{driver_name} devices are not available.")
|
||||
else:
|
||||
cpu_name = get_cpu_info()["brand_raw"]
|
||||
for i, device in enumerate(device_list_dict):
|
||||
device_list.append(f"{device['name']} => {driver_name}://{i}")
|
||||
device_name = (
|
||||
cpu_name if device["name"] == "default" else device["name"]
|
||||
)
|
||||
if "local" in driver_name:
|
||||
device_list.append(
|
||||
f"{device_name} => {driver_name.replace('local', 'cpu')}"
|
||||
)
|
||||
else:
|
||||
device_list.append(f"{device_name} => {driver_name}://{i}")
|
||||
return device_list
|
||||
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
vulkan_devices = get_devices_by_name("vulkan")
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
)
|
||||
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
vulkan_devices = []
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
|
||||
id += 1
|
||||
if id != 0:
|
||||
print(f"vulkan devices are available.")
|
||||
available_devices.extend(vulkan_devices)
|
||||
metal_devices = get_devices_by_name("metal")
|
||||
available_devices.extend(metal_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
available_devices.append("device => cpu")
|
||||
rocm_devices = get_devices_by_name("rocm")
|
||||
available_devices.extend(rocm_devices)
|
||||
cpu_device = get_devices_by_name("cpu-sync")
|
||||
available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
available_devices.extend(cpu_device)
|
||||
return available_devices
|
||||
|
||||
|
||||
@@ -432,10 +525,15 @@ def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
if "rocm" in args.device:
|
||||
rocm_args = get_iree_rocm_args()
|
||||
iree_flags.extend(rocm_args)
|
||||
print(iree_flags)
|
||||
if args.iree_constant_folding == False:
|
||||
iree_flags.append("--iree-opt-const-expr-hoisting=False")
|
||||
iree_flags.append(
|
||||
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
|
||||
)
|
||||
|
||||
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
|
||||
iree_flags += opt_flags[model][is_tuned][precision][
|
||||
@@ -489,17 +587,17 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
|
||||
from_safetensors = (
|
||||
True if custom_weights.lower().endswith(".safetensors") else False
|
||||
)
|
||||
# EMA weights usually yield higher quality images for inference but non-EMA weights have
|
||||
# been yielding better results in our case.
|
||||
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
|
||||
# weight extraction or not.
|
||||
# EMA weights usually yield higher quality images for inference but
|
||||
# non-EMA weights have been yielding better results in our case.
|
||||
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if
|
||||
# they want to go for EMA weight extraction or not.
|
||||
extract_ema = False
|
||||
print(
|
||||
"Loading diffusers' pipeline from original stable diffusion checkpoint"
|
||||
)
|
||||
num_in_channels = 9 if is_inpaint else 4
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path=custom_weights,
|
||||
checkpoint_path_or_dict=custom_weights,
|
||||
extract_ema=extract_ema,
|
||||
from_safetensors=from_safetensors,
|
||||
num_in_channels=num_in_channels,
|
||||
@@ -513,7 +611,10 @@ def convert_original_vae(vae_checkpoint):
|
||||
for key in list(vae_checkpoint.keys()):
|
||||
vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key)
|
||||
|
||||
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
config_url = (
|
||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/"
|
||||
"main/configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=512)
|
||||
@@ -634,7 +735,7 @@ def update_lora_weight(model, use_lora, model_name):
|
||||
|
||||
|
||||
# `fetch_and_update_base_model_id` is a resource utility function which
|
||||
# helps maintaining mapping of the model to run with its base model.
|
||||
# helps to maintain mapping of the model to run with its base model.
|
||||
# If `base_model` is "", then this function tries to fetch the base model
|
||||
# info for the `model_to_run`.
|
||||
def fetch_and_update_base_model_id(model_to_run, base_model=""):
|
||||
@@ -651,14 +752,17 @@ def fetch_and_update_base_model_id(model_to_run, base_model=""):
|
||||
return base_model
|
||||
elif base_model == "":
|
||||
return base_model
|
||||
# Update JSON data to contain an entry mapping model_to_run with base_model.
|
||||
# Update JSON data to contain an entry mapping model_to_run with
|
||||
# base_model.
|
||||
json_data.update(data)
|
||||
with open(variants_path, "w", encoding="utf-8") as jsonFile:
|
||||
json.dump(json_data, jsonFile)
|
||||
|
||||
|
||||
# Generate and return a new seed if the provided one is not in the supported range (including -1)
|
||||
def sanitize_seed(seed):
|
||||
# Generate and return a new seed if the provided one is not in the
|
||||
# supported range (including -1)
|
||||
def sanitize_seed(seed: int | str):
|
||||
seed = int(seed)
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
@@ -666,6 +770,56 @@ def sanitize_seed(seed):
|
||||
return seed
|
||||
|
||||
|
||||
# take a seed expression in an input format and convert it to
|
||||
# a list of integers, where possible
|
||||
def parse_seed_input(seed_input: str | list | int):
|
||||
if isinstance(seed_input, str):
|
||||
try:
|
||||
seed_input = json.loads(seed_input)
|
||||
except (ValueError, TypeError):
|
||||
seed_input = None
|
||||
|
||||
if isinstance(seed_input, int):
|
||||
return [seed_input]
|
||||
|
||||
if isinstance(seed_input, list) and all(
|
||||
type(seed) is int for seed in seed_input
|
||||
):
|
||||
return seed_input
|
||||
|
||||
raise TypeError(
|
||||
"Seed input must be an integer or an array of integers in JSON format"
|
||||
)
|
||||
|
||||
|
||||
# Generate a set of seeds from an input expression for batch_count batches,
|
||||
# optionally using that input as the rng seed for any randomly generated seeds.
|
||||
def batch_seeds(
|
||||
seed_input: str | list | int, batch_count: int, repeatable=False
|
||||
):
|
||||
# turn the input into a list if possible
|
||||
seeds = parse_seed_input(seed_input)
|
||||
|
||||
# slice or pad the list to be of batch_count length
|
||||
seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))
|
||||
|
||||
if repeatable:
|
||||
# set seed for the rng based on what we have so far
|
||||
saved_random_state = random_getstate()
|
||||
if all(seed < 0 for seed in seeds):
|
||||
seeds[0] = sanitize_seed(seeds[0])
|
||||
seed_random(str(seeds))
|
||||
|
||||
# generate any seeds that are unspecified
|
||||
seeds = [sanitize_seed(seed) for seed in seeds]
|
||||
|
||||
if repeatable:
|
||||
# reset the rng back to normal
|
||||
random_setstate(saved_random_state)
|
||||
|
||||
return seeds
|
||||
|
||||
|
||||
# clear all the cached objects to recompile cleanly.
|
||||
def clear_all():
|
||||
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
|
||||
@@ -676,7 +830,8 @@ def clear_all():
|
||||
for vmfb in vmfbs:
|
||||
if os.path.exists(vmfb):
|
||||
os.remove(vmfb)
|
||||
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
|
||||
# Temporary workaround of deleting yaml files to incorporate
|
||||
# diffusers' pipeline.
|
||||
# TODO: Remove this once we have better weight updation logic.
|
||||
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
|
||||
for yaml in inference_yaml:
|
||||
@@ -692,26 +847,45 @@ def clear_all():
|
||||
elif os.name == "unix":
|
||||
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
|
||||
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
|
||||
if args.local_tank_cache != "":
|
||||
shutil.rmtree(args.local_tank_cache)
|
||||
|
||||
|
||||
def get_generated_imgs_path() -> Path:
|
||||
return Path(
|
||||
args.output_dir if args.output_dir else Path.cwd(), "generated_imgs"
|
||||
)
|
||||
|
||||
|
||||
def get_generated_imgs_todays_subdir() -> str:
|
||||
return dt.now().strftime("%Y%m%d")
|
||||
|
||||
|
||||
# save output images and the inputs corresponding to it.
|
||||
def save_output_img(output_img, img_seed, extra_info={}):
|
||||
output_path = args.output_dir if args.output_dir else Path.cwd()
|
||||
def save_output_img(output_img, img_seed, extra_info=None):
|
||||
if extra_info is None:
|
||||
extra_info = {}
|
||||
generated_imgs_path = Path(
|
||||
output_path, "generated_imgs", dt.now().strftime("%Y%m%d")
|
||||
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
|
||||
)
|
||||
generated_imgs_path.mkdir(parents=True, exist_ok=True)
|
||||
csv_path = Path(generated_imgs_path, "imgs_details.csv")
|
||||
|
||||
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
|
||||
out_img_name = (
|
||||
f"{prompt_slice}_{img_seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
|
||||
)
|
||||
out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
|
||||
|
||||
img_model = args.hf_model_id
|
||||
if args.ckpt_loc:
|
||||
img_model = Path(os.path.basename(args.ckpt_loc)).stem
|
||||
|
||||
img_vae = None
|
||||
if args.custom_vae:
|
||||
img_vae = Path(os.path.basename(args.custom_vae)).stem
|
||||
|
||||
img_lora = None
|
||||
if args.use_lora:
|
||||
img_lora = Path(os.path.basename(args.use_lora)).stem
|
||||
|
||||
if args.output_img_format == "jpg":
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
output_img.save(out_img_path, quality=95, subsampling=0)
|
||||
@@ -722,17 +896,30 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
if args.write_metadata_to_png:
|
||||
pngInfo.add_text(
|
||||
"parameters",
|
||||
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
|
||||
f"{args.prompts[0]}"
|
||||
f"\nNegative prompt: {args.negative_prompts[0]}"
|
||||
f"\nSteps: {args.steps},"
|
||||
f"Sampler: {args.scheduler}, "
|
||||
f"CFG scale: {args.guidance_scale}, "
|
||||
f"Seed: {img_seed},"
|
||||
f"Size: {args.width}x{args.height}, "
|
||||
f"Model: {img_model}, "
|
||||
f"VAE: {img_vae}, "
|
||||
f"LoRA: {img_lora}",
|
||||
)
|
||||
|
||||
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
|
||||
|
||||
if args.output_img_format not in ["png", "jpg"]:
|
||||
print(
|
||||
f"[ERROR] Format {args.output_img_format} is not supported yet."
|
||||
"Image saved as png instead. Supported formats: png / jpg"
|
||||
f"[ERROR] Format {args.output_img_format} is not "
|
||||
f"supported yet. Image saved as png instead."
|
||||
f"Supported formats: png / jpg"
|
||||
)
|
||||
|
||||
# To be as low-impact as possible to the existing CSV format, we append
|
||||
# "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
|
||||
# importance for each data point. Something to consider.
|
||||
new_entry = {
|
||||
"VARIANT": img_model,
|
||||
"SCHEDULER": args.scheduler,
|
||||
@@ -746,12 +933,17 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
"WIDTH": args.width,
|
||||
"MAX_LENGTH": args.max_length,
|
||||
"OUTPUT": out_img_path,
|
||||
"VAE": img_vae,
|
||||
"LORA": img_lora,
|
||||
}
|
||||
|
||||
new_entry.update(extra_info)
|
||||
|
||||
with open(csv_path, "a", encoding="utf-8") as csv_obj:
|
||||
csv_mode = "a" if os.path.isfile(csv_path) else "w"
|
||||
with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
|
||||
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
|
||||
if csv_mode == "w":
|
||||
dictwriter_obj.writeheader()
|
||||
dictwriter_obj.writerow(new_entry)
|
||||
csv_obj.close()
|
||||
|
||||
@@ -765,16 +957,27 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
def get_generation_text_info(seeds, device):
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={args.height}x{args.width}, batch_count={args.batch_count}, batch_size={args.batch_size}, max_length={args.max_length}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, " f"device={device}"
|
||||
text_output += (
|
||||
f"\nsteps={args.steps}, "
|
||||
f"guidance_scale={args.guidance_scale}, "
|
||||
f"seed={seeds}"
|
||||
)
|
||||
text_output += (
|
||||
f"\nsize={args.height}x{args.width}, "
|
||||
f"batch_count={args.batch_count}, "
|
||||
f"batch_size={args.batch_size}, "
|
||||
f"max_length={args.max_length}"
|
||||
)
|
||||
|
||||
return text_output
|
||||
|
||||
|
||||
# For stencil, the input image can be of any size but we need to ensure that
|
||||
# it conforms with our model contraints :-
|
||||
# For stencil, the input image can be of any size, but we need to ensure that
|
||||
# it conforms with our model constraints :-
|
||||
# Both width and height should be in the range of [128, 768] and multiple of 8.
|
||||
# This utility function performs the transformation on the input image while
|
||||
# also maintaining the aspect ratio before sending it to the stencil pipeline.
|
||||
|
||||
51
apps/stable_diffusion/studio_bundle.spec
Normal file
51
apps/stable_diffusion/studio_bundle.spec
Normal file
@@ -0,0 +1,51 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from apps.stable_diffusion.shark_studio_imports import pathex, datas, hiddenimports
|
||||
|
||||
binaries = []
|
||||
|
||||
block_cipher = None
|
||||
|
||||
a = Analysis(
|
||||
['web\\index.py'],
|
||||
pathex=pathex,
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=hiddenimports,
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
excludes=[],
|
||||
win_no_prefer_redirects=False,
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
[],
|
||||
exclude_binaries=True,
|
||||
name='studio_bundle',
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
console=True,
|
||||
disable_windowed_traceback=False,
|
||||
argv_emulation=False,
|
||||
target_arch=None,
|
||||
codesign_identity=None,
|
||||
entitlements_file=None,
|
||||
)
|
||||
coll = COLLECT(
|
||||
exe,
|
||||
a.binaries,
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx_exclude=[],
|
||||
name='studio_bundle',
|
||||
)
|
||||
@@ -1,12 +1,21 @@
|
||||
from multiprocessing import Process, freeze_support
|
||||
import os
|
||||
import sys
|
||||
import transformers
|
||||
import logging
|
||||
|
||||
if sys.platform == "darwin":
|
||||
# import before IREE to avoid torch-MLIR library issues
|
||||
import torch_mlir
|
||||
|
||||
import shutil
|
||||
import PIL, transformers, sentencepiece # ensures inclusion in pysintaller exe generation
|
||||
from apps.stable_diffusion.src import args, clear_all
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
if sys.platform == "darwin":
|
||||
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
|
||||
# import before IREE to avoid MLIR library issues
|
||||
import torch_mlir
|
||||
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
@@ -18,16 +27,23 @@ def launch_app(address):
|
||||
|
||||
window = Tk()
|
||||
|
||||
# getting screen width and height of display
|
||||
width = window.winfo_screenwidth()
|
||||
height = window.winfo_screenheight()
|
||||
# get screen width and height of display and make it more reasonably
|
||||
# sized as we aren't making it full-screen or maximized
|
||||
width = int(window.winfo_screenwidth() * 0.81)
|
||||
height = int(window.winfo_screenheight() * 0.91)
|
||||
webview.create_window(
|
||||
"SHARK AI Studio", url=address, width=width, height=height
|
||||
"SHARK AI Studio",
|
||||
url=address,
|
||||
width=width,
|
||||
height=height,
|
||||
text_select=True,
|
||||
)
|
||||
webview.start(private_mode=False)
|
||||
webview.start(private_mode=False, storage_path=os.getcwd())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.debug:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
# required to do multiprocessing in a pyinstaller freeze
|
||||
freeze_support()
|
||||
if args.api or "api" in args.ui.split(","):
|
||||
@@ -36,7 +52,10 @@ if __name__ == "__main__":
|
||||
img2img_api,
|
||||
upscaler_api,
|
||||
inpaint_api,
|
||||
outpaint_api,
|
||||
llm_chat_api,
|
||||
)
|
||||
|
||||
from fastapi import FastAPI, APIRouter
|
||||
import uvicorn
|
||||
|
||||
@@ -47,23 +66,36 @@ if __name__ == "__main__":
|
||||
app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
|
||||
# app.add_api_route(
|
||||
# "/sdapi/v1/outpaint", outpaint_api, methods=["post"]
|
||||
# )
|
||||
app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
|
||||
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
|
||||
|
||||
# chat APIs needed for compatibility with multiple extensions using OpenAI API
|
||||
app.add_api_route(
|
||||
"/v1/chat/completions", llm_chat_api, methods=["post"]
|
||||
)
|
||||
app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
|
||||
app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
|
||||
app.add_api_route("/completions", llm_chat_api, methods=["post"])
|
||||
app.add_api_route(
|
||||
"/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
|
||||
)
|
||||
app.include_router(APIRouter())
|
||||
uvicorn.run(app, host="127.0.0.1", port=args.server_port)
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.server_port)
|
||||
sys.exit(0)
|
||||
|
||||
import gradio as gr
|
||||
# Setup to use shark_tmp for gradio's temporary image files and clear any
|
||||
# existing temporary images there if they exist. Then we can import gradio.
|
||||
# It has to be in this order or gradio ignores what we've set up.
|
||||
from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
clear_gradio_tmp_imgs_folder,
|
||||
config_gradio_tmp_imgs_folder,
|
||||
)
|
||||
|
||||
config_gradio_tmp_imgs_folder()
|
||||
import gradio as gr
|
||||
|
||||
# Create custom models folders if they don't exist
|
||||
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
|
||||
|
||||
# Clear all gradio tmp images from the last session
|
||||
clear_gradio_tmp_imgs_folder()
|
||||
# Create custom models folders if they don't exist
|
||||
create_custom_models_folders()
|
||||
|
||||
def resource_path(relative_path):
|
||||
@@ -80,15 +112,20 @@ if __name__ == "__main__":
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
txt2img_gallery,
|
||||
txt2img_png_info_img,
|
||||
txt2img_status,
|
||||
txt2img_sendto_img2img,
|
||||
txt2img_sendto_inpaint,
|
||||
txt2img_sendto_outpaint,
|
||||
txt2img_sendto_upscaler,
|
||||
# h2ogpt_upload,
|
||||
# h2ogpt_web,
|
||||
img2img_web,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_status,
|
||||
img2img_sendto_inpaint,
|
||||
img2img_sendto_outpaint,
|
||||
img2img_sendto_upscaler,
|
||||
@@ -97,6 +134,7 @@ if __name__ == "__main__":
|
||||
inpaint_hf_model_id,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_status,
|
||||
inpaint_sendto_img2img,
|
||||
inpaint_sendto_outpaint,
|
||||
inpaint_sendto_upscaler,
|
||||
@@ -105,6 +143,7 @@ if __name__ == "__main__":
|
||||
outpaint_hf_model_id,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_status,
|
||||
outpaint_sendto_img2img,
|
||||
outpaint_sendto_inpaint,
|
||||
outpaint_sendto_upscaler,
|
||||
@@ -113,11 +152,13 @@ if __name__ == "__main__":
|
||||
upscaler_hf_model_id,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_status,
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
lora_train_web,
|
||||
model_web,
|
||||
# lora_train_web,
|
||||
# model_web,
|
||||
# model_config_web,
|
||||
hf_models,
|
||||
modelmanager_sendto_txt2img,
|
||||
modelmanager_sendto_img2img,
|
||||
@@ -125,6 +166,16 @@ if __name__ == "__main__":
|
||||
modelmanager_sendto_outpaint,
|
||||
modelmanager_sendto_upscaler,
|
||||
stablelm_chat,
|
||||
minigpt4_web,
|
||||
outputgallery_web,
|
||||
outputgallery_tab_select,
|
||||
outputgallery_watch,
|
||||
outputgallery_filename,
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
outputgallery_sendto_upscaler,
|
||||
)
|
||||
|
||||
# init global sd pipeline and config
|
||||
@@ -151,10 +202,29 @@ if __name__ == "__main__":
|
||||
outputs,
|
||||
)
|
||||
|
||||
def register_outputgallery_button(button, selectedid, inputs, outputs):
|
||||
button.click(
|
||||
lambda x: (
|
||||
x,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
)
|
||||
|
||||
with gr.Blocks(
|
||||
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
|
||||
) as sd_web:
|
||||
with gr.Tabs() as tabs:
|
||||
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
|
||||
# have a unique id that doesn't clash with any of the other tabs,
|
||||
# and that the order in the code here is the order they should
|
||||
# appear in the ui, as the id value doesn't determine the order.
|
||||
|
||||
# Where possible, avoid changing the id of any tab that is the
|
||||
# destination of one of the 'send to' buttons. If you do have to change
|
||||
# that id, make sure you update the relevant register_button_click calls
|
||||
# further down with the new id.
|
||||
with gr.TabItem(label="Text-to-Image", id=0):
|
||||
txt2img_web.render()
|
||||
with gr.TabItem(label="Image-to-Image", id=1):
|
||||
@@ -165,13 +235,39 @@ if __name__ == "__main__":
|
||||
outpaint_web.render()
|
||||
with gr.TabItem(label="Upscaler", id=4):
|
||||
upscaler_web.render()
|
||||
with gr.TabItem(label="Model Manager", id=5):
|
||||
model_web.render()
|
||||
with gr.TabItem(label="Chat Bot(Experimental)", id=6):
|
||||
stablelm_chat.render()
|
||||
with gr.TabItem(label="LoRA Training(Experimental)", id=7):
|
||||
lora_train_web.render()
|
||||
if args.output_gallery:
|
||||
with gr.TabItem(label="Output Gallery", id=5) as og_tab:
|
||||
outputgallery_web.render()
|
||||
|
||||
# extra output gallery configuration
|
||||
outputgallery_tab_select(og_tab.select)
|
||||
outputgallery_watch(
|
||||
[
|
||||
txt2img_status,
|
||||
img2img_status,
|
||||
inpaint_status,
|
||||
outpaint_status,
|
||||
upscaler_status,
|
||||
]
|
||||
)
|
||||
# with gr.TabItem(label="Model Manager", id=6):
|
||||
# model_web.render()
|
||||
# with gr.TabItem(label="LoRA Training (Experimental)", id=7):
|
||||
# lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot", id=8):
|
||||
stablelm_chat.render()
|
||||
# with gr.TabItem(
|
||||
# label="Generate Sharding Config (Experimental)", id=9
|
||||
# ):
|
||||
# model_config_web.render()
|
||||
with gr.TabItem(label="MultiModal (Experimental)", id=10):
|
||||
minigpt4_web.render()
|
||||
# with gr.TabItem(label="DocuChat Upload", id=11):
|
||||
# h2ogpt_upload.render()
|
||||
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
|
||||
# h2ogpt_web.render()
|
||||
|
||||
# send to buttons
|
||||
register_button_click(
|
||||
txt2img_sendto_img2img,
|
||||
1,
|
||||
@@ -268,6 +364,37 @@ if __name__ == "__main__":
|
||||
[upscaler_gallery],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
if args.output_gallery:
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_txt2img,
|
||||
0,
|
||||
[outputgallery_filename],
|
||||
[txt2img_png_info_img, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_img2img,
|
||||
1,
|
||||
[outputgallery_filename],
|
||||
[img2img_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_inpaint,
|
||||
2,
|
||||
[outputgallery_filename],
|
||||
[inpaint_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_outpaint,
|
||||
3,
|
||||
[outputgallery_filename],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
outputgallery_sendto_upscaler,
|
||||
4,
|
||||
[outputgallery_filename],
|
||||
[upscaler_init_image, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_txt2img,
|
||||
0,
|
||||
|
||||
@@ -5,6 +5,8 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
txt2img_gallery,
|
||||
txt2img_png_info_img,
|
||||
txt2img_status,
|
||||
txt2img_sendto_img2img,
|
||||
txt2img_sendto_inpaint,
|
||||
txt2img_sendto_outpaint,
|
||||
@@ -18,6 +20,7 @@ from apps.stable_diffusion.web.ui.img2img_ui import (
|
||||
img2img_hf_model_id,
|
||||
img2img_gallery,
|
||||
img2img_init_image,
|
||||
img2img_status,
|
||||
img2img_sendto_inpaint,
|
||||
img2img_sendto_outpaint,
|
||||
img2img_sendto_upscaler,
|
||||
@@ -30,6 +33,7 @@ from apps.stable_diffusion.web.ui.inpaint_ui import (
|
||||
inpaint_hf_model_id,
|
||||
inpaint_gallery,
|
||||
inpaint_init_image,
|
||||
inpaint_status,
|
||||
inpaint_sendto_img2img,
|
||||
inpaint_sendto_outpaint,
|
||||
inpaint_sendto_upscaler,
|
||||
@@ -42,6 +46,7 @@ from apps.stable_diffusion.web.ui.outpaint_ui import (
|
||||
outpaint_hf_model_id,
|
||||
outpaint_gallery,
|
||||
outpaint_init_image,
|
||||
outpaint_status,
|
||||
outpaint_sendto_img2img,
|
||||
outpaint_sendto_inpaint,
|
||||
outpaint_sendto_upscaler,
|
||||
@@ -54,6 +59,7 @@ from apps.stable_diffusion.web.ui.upscaler_ui import (
|
||||
upscaler_hf_model_id,
|
||||
upscaler_gallery,
|
||||
upscaler_init_image,
|
||||
upscaler_status,
|
||||
upscaler_sendto_img2img,
|
||||
upscaler_sendto_inpaint,
|
||||
upscaler_sendto_outpaint,
|
||||
@@ -68,4 +74,20 @@ from apps.stable_diffusion.web.ui.model_manager import (
|
||||
modelmanager_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web
|
||||
from apps.stable_diffusion.web.ui.stablelm_ui import stablelm_chat
|
||||
from apps.stable_diffusion.web.ui.stablelm_ui import (
|
||||
stablelm_chat,
|
||||
llm_chat_api,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.generate_config import model_config_web
|
||||
from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web
|
||||
from apps.stable_diffusion.web.ui.outputgallery_ui import (
|
||||
outputgallery_web,
|
||||
outputgallery_tab_select,
|
||||
outputgallery_watch,
|
||||
outputgallery_filename,
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
outputgallery_sendto_upscaler,
|
||||
)
|
||||
|
||||
@@ -117,16 +117,12 @@ body {
|
||||
padding: 0 var(--size-4) !important;
|
||||
}
|
||||
|
||||
.container {
|
||||
background-color: black !important;
|
||||
padding-top: var(--size-5) !important;
|
||||
}
|
||||
|
||||
#ui_title {
|
||||
padding: var(--size-2) 0 0 var(--size-1);
|
||||
}
|
||||
|
||||
#top_logo {
|
||||
color: transparent;
|
||||
background-color: transparent;
|
||||
border-radius: 0 !important;
|
||||
border: 0;
|
||||
@@ -227,6 +223,66 @@ footer {
|
||||
}
|
||||
|
||||
/* Hide the download icon from the nod logo */
|
||||
#top_logo .download {
|
||||
#top_logo button {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* workarounds for container=false not currently working for dropdowns */
|
||||
.dropdown_no_container {
|
||||
padding: 0 !important;
|
||||
}
|
||||
|
||||
#output_subdir_container :first-child {
|
||||
border: none;
|
||||
}
|
||||
|
||||
/* reduced animation load when generating */
|
||||
.generating {
|
||||
animation-play-state: paused !important;
|
||||
}
|
||||
|
||||
/* better clarity when progress bars are minimal */
|
||||
.meta-text {
|
||||
background-color: var(--block-label-background-fill);
|
||||
}
|
||||
|
||||
/* output gallery tab */
|
||||
.output_parameters_dataframe tbody td {
|
||||
font-size: small;
|
||||
line-height: var(--line-xs)
|
||||
}
|
||||
|
||||
.output_icon_button {
|
||||
max-width: 30px;
|
||||
align-self: end;
|
||||
padding-bottom: 8px;
|
||||
}
|
||||
|
||||
.outputgallery_sendto {
|
||||
min-width: 7em !important;
|
||||
}
|
||||
|
||||
/* output gallery should take up most of the viewport height regardless of image size/number */
|
||||
#outputgallery_gallery .fixed-height {
|
||||
min-height: 89vh !important;
|
||||
}
|
||||
|
||||
/* don't stretch non-square images to be square, breaking their aspect ratio */
|
||||
#outputgallery_gallery .thumbnail-item.thumbnail-lg > img {
|
||||
object-fit: contain !important;
|
||||
}
|
||||
|
||||
/* centered logo for when there are no images */
|
||||
#top_logo.logo_centered {
|
||||
height: 100%;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
#top_logo.logo_centered img{
|
||||
object-fit: scale-down;
|
||||
position: absolute;
|
||||
width: 80%;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
}
|
||||
|
||||
41
apps/stable_diffusion/web/ui/generate_config.py
Normal file
41
apps/stable_diffusion/web/ui/generate_config.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import gradio as gr
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import CombinedModel
|
||||
from shark.shark_generate_model_config import GenerateConfigFile
|
||||
|
||||
|
||||
def get_model_config():
|
||||
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
|
||||
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
|
||||
compilation_prompt = "".join(["0" for _ in range(17)])
|
||||
compilation_input_ids = tokenizer(
|
||||
compilation_prompt,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
|
||||
[1, 19]
|
||||
)
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
|
||||
model = CombinedModel()
|
||||
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
|
||||
return c.split_into_layers()
|
||||
|
||||
|
||||
with gr.Blocks() as model_config_web:
|
||||
with gr.Row():
|
||||
hf_models = gr.Dropdown(
|
||||
label="Model List",
|
||||
choices=["Vicuna"],
|
||||
value="Vicuna",
|
||||
visible=True,
|
||||
)
|
||||
get_model_config_btn = gr.Button(value="Get Model Config")
|
||||
json_view = gr.JSON()
|
||||
|
||||
get_model_config_btn.click(
|
||||
fn=get_model_config,
|
||||
inputs=[],
|
||||
outputs=[json_view],
|
||||
)
|
||||
367
apps/stable_diffusion/web/ui/h2ogpt.py
Normal file
367
apps/stable_diffusion/web/ui/h2ogpt.py
Normal file
@@ -0,0 +1,367 @@
|
||||
import gradio as gr
|
||||
import torch
|
||||
import os
|
||||
from pathlib import Path
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
|
||||
from apps.language_models.langchain.enums import (
|
||||
DocumentChoices,
|
||||
LangChainAction,
|
||||
)
|
||||
import apps.language_models.langchain.gen as gen
|
||||
from gpt_langchain import (
|
||||
path_to_docs,
|
||||
create_or_update_db,
|
||||
)
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
|
||||
def user(message, history):
|
||||
# Append the user's message to the conversation history
|
||||
return "", history + [[message, ""]]
|
||||
|
||||
|
||||
sharkModel = 0
|
||||
h2ogpt_model = 0
|
||||
|
||||
|
||||
# NOTE: Each `model_name` should have its own start message
|
||||
start_message = """
|
||||
SHARK DocuChat
|
||||
Chat with an AI, contextualized with provided files.
|
||||
"""
|
||||
|
||||
|
||||
def create_prompt(history):
|
||||
system_message = start_message
|
||||
for item in history:
|
||||
print("His item: ", item)
|
||||
|
||||
conversation = "<|endoftext|>".join(
|
||||
[
|
||||
"<|endoftext|><|answer|>".join([item[0], item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
return msg
|
||||
|
||||
|
||||
def chat(curr_system_message, history, device, precision):
|
||||
args.run_docuchat_web = True
|
||||
global h2ogpt_model
|
||||
global sharkModel
|
||||
global h2ogpt_tokenizer
|
||||
global model_state
|
||||
global langchain
|
||||
global userpath_selector
|
||||
from apps.language_models.langchain.h2oai_pipeline import generate_token
|
||||
|
||||
if h2ogpt_model == 0:
|
||||
if "cuda" in device:
|
||||
shark_device = "cuda"
|
||||
elif "sync" in device:
|
||||
shark_device = "cpu"
|
||||
elif "task" in device:
|
||||
shark_device = "cpu"
|
||||
elif "vulkan" in device:
|
||||
shark_device = "vulkan"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
|
||||
device = "cpu" if shark_device == "cpu" else "cuda"
|
||||
|
||||
args.device = shark_device
|
||||
args.precision = precision
|
||||
|
||||
from apps.language_models.langchain.gen import Langchain
|
||||
|
||||
langchain = Langchain(device, precision)
|
||||
h2ogpt_model, h2ogpt_tokenizer, _ = langchain.get_model(
|
||||
load_4bit=True
|
||||
if device == "cuda"
|
||||
else False, # load model in 4bit if device is cuda to save memory
|
||||
load_gptq="",
|
||||
use_safetensors=False,
|
||||
infer_devices=True,
|
||||
device=device,
|
||||
base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
inference_server="",
|
||||
tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
lora_weights="",
|
||||
gpu_id=0,
|
||||
reward_type=None,
|
||||
local_files_only=False,
|
||||
resume_download=True,
|
||||
use_auth_token=False,
|
||||
trust_remote_code=True,
|
||||
offload_folder=None,
|
||||
compile_model=False,
|
||||
verbose=False,
|
||||
)
|
||||
model_state = dict(
|
||||
model=h2ogpt_model,
|
||||
tokenizer=h2ogpt_tokenizer,
|
||||
device=device,
|
||||
base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
lora_weights="",
|
||||
inference_server="",
|
||||
prompt_type=None,
|
||||
prompt_dict=None,
|
||||
)
|
||||
from apps.language_models.langchain.h2oai_pipeline import (
|
||||
H2OGPTSHARKModel,
|
||||
)
|
||||
|
||||
sharkModel = H2OGPTSHARKModel()
|
||||
|
||||
prompt = create_prompt(history)
|
||||
output_dict = langchain.evaluate(
|
||||
model_state=model_state,
|
||||
my_db_state=None,
|
||||
instruction=prompt,
|
||||
iinput="",
|
||||
context="",
|
||||
stream_output=True,
|
||||
prompt_type="prompt_answer",
|
||||
prompt_dict={
|
||||
"promptA": "",
|
||||
"promptB": "",
|
||||
"PreInstruct": "<|prompt|>",
|
||||
"PreInput": None,
|
||||
"PreResponse": "<|answer|>",
|
||||
"terminate_response": [
|
||||
"<|prompt|>",
|
||||
"<|answer|>",
|
||||
"<|endoftext|>",
|
||||
],
|
||||
"chat_sep": "<|endoftext|>",
|
||||
"chat_turn_sep": "<|endoftext|>",
|
||||
"humanstr": "<|prompt|>",
|
||||
"botstr": "<|answer|>",
|
||||
"generates_leading_space": False,
|
||||
},
|
||||
temperature=0.1,
|
||||
top_p=0.75,
|
||||
top_k=40,
|
||||
num_beams=1,
|
||||
max_new_tokens=256,
|
||||
min_new_tokens=0,
|
||||
early_stopping=False,
|
||||
max_time=180,
|
||||
repetition_penalty=1.07,
|
||||
num_return_sequences=1,
|
||||
do_sample=False,
|
||||
chat=True,
|
||||
instruction_nochat=prompt,
|
||||
iinput_nochat="",
|
||||
langchain_mode="UserData",
|
||||
langchain_action=LangChainAction.QUERY.value,
|
||||
top_k_docs=3,
|
||||
chunk=True,
|
||||
chunk_size=512,
|
||||
document_choice=[DocumentChoices.All_Relevant.name],
|
||||
concurrency_count=1,
|
||||
memory_restriction_level=2,
|
||||
raise_generate_gpu_exceptions=False,
|
||||
chat_context="",
|
||||
use_openai_embedding=False,
|
||||
use_openai_model=False,
|
||||
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
db_type="chroma",
|
||||
n_jobs=-1,
|
||||
first_para=False,
|
||||
max_max_time=60 * 2,
|
||||
model_state0=model_state,
|
||||
model_lock=True,
|
||||
user_path=userpath_selector.value,
|
||||
)
|
||||
|
||||
output = generate_token(sharkModel, **output_dict)
|
||||
for partial_text in output:
|
||||
history[-1][1] = partial_text
|
||||
yield history
|
||||
return history
|
||||
|
||||
|
||||
userpath_selector = gr.Textbox(
|
||||
label="Document Directory",
|
||||
value=str(os.path.abspath("apps/language_models/langchain/user_path/")),
|
||||
interactive=True,
|
||||
container=True,
|
||||
)
|
||||
|
||||
with gr.Blocks(title="DocuChat") as h2ogpt_web:
|
||||
with gr.Row():
|
||||
supported_devices = available_devices
|
||||
enabled = len(supported_devices) > 0
|
||||
# show cpu-task device first in list for chatbot
|
||||
supported_devices = supported_devices[-1:] + supported_devices[:-1]
|
||||
supported_devices = [x for x in supported_devices if "sync" not in x]
|
||||
print(supported_devices)
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0]
|
||||
if enabled
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp16",
|
||||
choices=[
|
||||
"int4",
|
||||
"int8",
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
chatbot = gr.Chatbot(height=500)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
msg = gr.Textbox(
|
||||
label="Chat Message Box",
|
||||
placeholder="Chat Message Box",
|
||||
show_label=False,
|
||||
interactive=enabled,
|
||||
container=False,
|
||||
)
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
submit = gr.Button("Submit", interactive=enabled)
|
||||
stop = gr.Button("Stop", interactive=enabled)
|
||||
clear = gr.Button("Clear", interactive=enabled)
|
||||
system_msg = gr.Textbox(
|
||||
start_message, label="System Message", interactive=False, visible=False
|
||||
)
|
||||
|
||||
submit_event = msg.submit(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, device, precision],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, device, precision],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
fn=None,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
cancels=[submit_event, submit_click_event],
|
||||
queue=False,
|
||||
)
|
||||
clear.click(lambda: None, None, [chatbot], queue=False)
|
||||
|
||||
|
||||
with gr.Blocks(title="DocuChat Upload") as h2ogpt_upload:
|
||||
import pathlib
|
||||
|
||||
upload_path = None
|
||||
database = None
|
||||
database_directory = os.path.abspath(
|
||||
"apps/language_models/langchain/db_path/"
|
||||
)
|
||||
|
||||
def read_path():
|
||||
global upload_path
|
||||
filenames = [
|
||||
[f]
|
||||
for f in os.listdir(upload_path)
|
||||
if os.path.isfile(os.path.join(upload_path, f))
|
||||
]
|
||||
filenames.sort()
|
||||
return filenames
|
||||
|
||||
def upload_file(f):
|
||||
names = []
|
||||
for tmpfile in f:
|
||||
name = tmpfile.name.split("/")[-1]
|
||||
basename = os.path.join(upload_path, name)
|
||||
with open(basename, "wb") as w:
|
||||
with open(tmpfile.name, "rb") as r:
|
||||
w.write(r.read())
|
||||
update_or_create_db()
|
||||
return read_path()
|
||||
|
||||
def update_userpath(newpath):
|
||||
global upload_path
|
||||
upload_path = newpath
|
||||
pathlib.Path(upload_path).mkdir(parents=True, exist_ok=True)
|
||||
return read_path()
|
||||
|
||||
def update_or_create_db():
|
||||
global database
|
||||
global upload_path
|
||||
|
||||
sources = path_to_docs(
|
||||
upload_path,
|
||||
verbose=True,
|
||||
fail_any_exception=False,
|
||||
n_jobs=-1,
|
||||
chunk=True,
|
||||
chunk_size=512,
|
||||
url=None,
|
||||
enable_captions=False,
|
||||
captions_model=None,
|
||||
caption_loader=None,
|
||||
enable_ocr=False,
|
||||
)
|
||||
|
||||
pathlib.Path(database_directory).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
database = create_or_update_db(
|
||||
"chroma",
|
||||
database_directory,
|
||||
"UserData",
|
||||
sources,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
"sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
|
||||
def first_run():
|
||||
global database
|
||||
if database is None:
|
||||
update_or_create_db()
|
||||
|
||||
update_userpath(
|
||||
os.path.abspath("apps/language_models/langchain/user_path/")
|
||||
)
|
||||
h2ogpt_upload.load(fn=first_run)
|
||||
h2ogpt_web.load(fn=first_run)
|
||||
|
||||
with gr.Column():
|
||||
text = gr.DataFrame(
|
||||
col_count=(1, "fixed"),
|
||||
type="array",
|
||||
label="Documents",
|
||||
value=read_path(),
|
||||
)
|
||||
with gr.Row():
|
||||
upload = gr.UploadButton(
|
||||
label="Upload documents",
|
||||
file_count="multiple",
|
||||
)
|
||||
upload.upload(fn=upload_file, inputs=upload, outputs=text)
|
||||
userpath_selector.render()
|
||||
userpath_selector.input(
|
||||
fn=update_userpath, inputs=userpath_selector, outputs=text
|
||||
).then(fn=update_or_create_db)
|
||||
@@ -1,10 +1,9 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
import PIL
|
||||
from math import ceil
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
@@ -26,10 +25,13 @@ from apps.stable_diffusion.src import (
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -49,7 +51,7 @@ def img2img_inf(
|
||||
steps: int,
|
||||
strength: float,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
seed: str | int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
@@ -65,6 +67,8 @@ def img2img_inf(
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
resample_type: str,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -103,7 +107,8 @@ def img2img_inf(
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
@@ -131,7 +136,8 @@ def img2img_inf(
|
||||
image, width, height = resize_stencil(image)
|
||||
elif "Shark" in args.scheduler:
|
||||
print(
|
||||
f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
|
||||
f"Shark schedulers are not supported. Switching to EulerDiscrete "
|
||||
f"scheduler"
|
||||
)
|
||||
args.scheduler = "EulerDiscrete"
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
@@ -226,13 +232,14 @@ def img2img_inf(
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
extra_info = {"STRENGTH": strength}
|
||||
text_output = ""
|
||||
try:
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
except TypeError as error:
|
||||
raise gr.Error(str(error)) from None
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
if current_batch > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
@@ -240,30 +247,39 @@ def img2img_inf(
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
ceil(steps / strength),
|
||||
strength,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil=use_stencil,
|
||||
resample_type=resample_type,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output = get_generation_text_info(
|
||||
seeds[: current_batch + 1], device
|
||||
)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed, extra_info)
|
||||
save_output_img(
|
||||
out_imgs[0],
|
||||
seeds[current_batch],
|
||||
extra_info,
|
||||
)
|
||||
generated_imgs.extend(out_imgs)
|
||||
# yield generated_imgs, text_output
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Image-to-Image", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
@@ -300,7 +316,9 @@ def img2img_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}.'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
res = img2img_inf(
|
||||
@@ -332,7 +350,13 @@ def img2img_api(
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
resample_type="Lanczos",
|
||||
)
|
||||
|
||||
# Converts generator type to subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
@@ -350,13 +374,21 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
width=150,
|
||||
height=50,
|
||||
)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
i2i_model_info = (str(get_custom_model_path())).replace(
|
||||
"\\", "\n\\"
|
||||
)
|
||||
i2i_model_info = f"Custom Model Path: {i2i_model_info}"
|
||||
img2img_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
label=f"Models",
|
||||
info=i2i_model_info,
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
@@ -364,43 +396,56 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
img2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
"e.g: SG161222/Realistic_Vision_V1.3, "
|
||||
"https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
i2i_vae_info = (str(get_custom_model_path("vae"))).replace(
|
||||
"\\", "\n\\"
|
||||
)
|
||||
i2i_vae_info = f"VAE Path: {i2i_vae_info}"
|
||||
custom_vae = gr.Dropdown(
|
||||
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
|
||||
label=f"Custom VAE Models",
|
||||
info=i2i_vae_info,
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.custom_vae)
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
# TODO: make this import image prompt info if it exists
|
||||
img2img_init_image = gr.Image(
|
||||
label="Input Image",
|
||||
source="upload",
|
||||
tool="sketch",
|
||||
type="pil",
|
||||
).style(height=300)
|
||||
height=300,
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Stencil Options", open=False):
|
||||
with gr.Row():
|
||||
@@ -409,6 +454,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
label="Stencil model",
|
||||
value="None",
|
||||
choices=["None", "canny", "openpose", "scribble"],
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
def show_canvas(choice):
|
||||
@@ -463,15 +509,25 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
i2i_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
allow_custom_value=True,
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=i2i_lora_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
@@ -483,6 +539,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -502,15 +559,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
width = gr.Slider(
|
||||
384, 768, value=args.width, step=8, label="Width"
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=args.max_length,
|
||||
@@ -521,21 +569,48 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
strength = gr.Slider(
|
||||
0,
|
||||
1,
|
||||
value=args.strength,
|
||||
step=0.01,
|
||||
label="Denoising Strength",
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
strength = gr.Slider(
|
||||
0,
|
||||
1,
|
||||
value=args.strength,
|
||||
step=0.01,
|
||||
label="Denoising Strength",
|
||||
)
|
||||
resample_type = gr.Dropdown(
|
||||
value=args.resample_type,
|
||||
choices=[
|
||||
"Lanczos",
|
||||
"Nearest Neighbor",
|
||||
"Bilinear",
|
||||
"Bicubic",
|
||||
"Adaptive",
|
||||
"Antialias",
|
||||
"Box",
|
||||
"Affine",
|
||||
"Cubic",
|
||||
],
|
||||
label="Resample Type",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=args.ondemand,
|
||||
label="Low VRAM",
|
||||
interactive=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=args.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
@@ -554,6 +629,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
args.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
@@ -563,28 +643,29 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
label="Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -592,17 +673,17 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
output_dir = (
|
||||
args.output_dir if args.output_dir else Path.cwd()
|
||||
columns=2,
|
||||
object_fit="contain",
|
||||
)
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {output_dir}",
|
||||
value=f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
img2img_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
img2img_sendto_outpaint = gr.Button(
|
||||
@@ -639,14 +720,24 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
],
|
||||
outputs=[img2img_gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
outputs=[img2img_gallery, std_output, img2img_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Image-to-Image", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=img2img_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
@@ -26,7 +25,11 @@ from apps.stable_diffusion.src import (
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
@@ -46,7 +49,7 @@ def inpaint_inf(
|
||||
inpaint_full_res_padding: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
seed: str | int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
@@ -61,6 +64,7 @@ def inpaint_inf(
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: int,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -89,7 +93,8 @@ def inpaint_inf(
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
@@ -176,14 +181,15 @@ def inpaint_inf(
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
image = image_dict["image"]
|
||||
mask_image = image_dict["mask"]
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
try:
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
except TypeError as error:
|
||||
raise gr.Error(str(error)) from None
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
@@ -196,24 +202,28 @@ def inpaint_inf(
|
||||
inpaint_full_res_padding,
|
||||
steps,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output = get_generation_text_info(
|
||||
seeds[: current_batch + 1], device
|
||||
)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
save_output_img(out_imgs[0], seeds[current_batch])
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Inpaint", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output
|
||||
|
||||
@@ -252,7 +262,9 @@ def inpaint_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}.'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["image"])
|
||||
mask = decode_base64_to_image(InputData["mask"])
|
||||
@@ -273,7 +285,7 @@ def inpaint_api(
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
@@ -283,7 +295,12 @@ def inpaint_api(
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Converts generator type to subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
@@ -301,13 +318,23 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
width=150,
|
||||
height=50,
|
||||
)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
inpaint_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
inpaint_model_info = (
|
||||
f"Custom Model Path: {inpaint_model_info}"
|
||||
)
|
||||
inpaint_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
label=f"Models",
|
||||
info=inpaint_model_info,
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
@@ -317,34 +344,46 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
custom_checkpoint_type="inpainting"
|
||||
)
|
||||
+ predefined_paint_models,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
inpaint_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
|
||||
"https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
inpaint_vae_info = (
|
||||
str(get_custom_model_path("vae"))
|
||||
).replace("\\", "\n\\")
|
||||
inpaint_vae_info = f"VAE Path: {inpaint_vae_info}"
|
||||
custom_vae = gr.Dropdown(
|
||||
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
|
||||
label=f"Custom VAE Models",
|
||||
info=inpaint_vae_info,
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.custom_vae)
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
@@ -353,19 +392,30 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
source="upload",
|
||||
tool="sketch",
|
||||
type="pil",
|
||||
).style(height=350)
|
||||
height=350,
|
||||
)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
inpaint_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
inpaint_lora_info = f"LoRA Path: {inpaint_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=inpaint_lora_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
@@ -377,6 +427,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -455,6 +506,11 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
args.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
@@ -464,28 +520,29 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
label="Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -493,17 +550,18 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
output_dir = (
|
||||
args.output_dir if args.output_dir else Path.cwd()
|
||||
columns=[2],
|
||||
object_fit="contain",
|
||||
)
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {output_dir}",
|
||||
value=f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
inpaint_status = gr.Textbox(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
inpaint_sendto_outpaint = gr.Button(
|
||||
@@ -540,14 +598,22 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
outputs=[inpaint_gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
outputs=[inpaint_gallery, std_output, inpaint_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Inpaint", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=inpaint_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import lora_train
|
||||
from apps.stable_diffusion.src import prompt_examples, args
|
||||
from apps.stable_diffusion.src import prompt_examples, args, utils
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -24,15 +24,25 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
width=150,
|
||||
height=50,
|
||||
)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
train_lora_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
train_lora_model_info = (
|
||||
f"Custom Model Path: {train_lora_model_info}"
|
||||
)
|
||||
custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
label=f"Models",
|
||||
info=train_lora_model_info,
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
@@ -40,25 +50,38 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
|
||||
placeholder="Select 'None' in the Models "
|
||||
"dropdown on the left and enter model ID here "
|
||||
"e.g: SG161222/Realistic_Vision_V1.3",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
train_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
train_lora_info = f"LoRA Path: {train_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights to initialize weights (Path: {get_custom_model_path('lora')})",
|
||||
label=f"Standalone LoRA weights to initialize weights",
|
||||
info=train_lora_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use a "
|
||||
"standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID to initialize weights",
|
||||
lines=3,
|
||||
@@ -74,7 +97,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
@@ -84,6 +107,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=scheduler_list,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
@@ -147,22 +171,25 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
value=utils.parse_seed_input(args.seed)[0],
|
||||
precision=0,
|
||||
label="Seed",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
queue=False,
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
train_lora = gr.Button("Train LoRA")
|
||||
@@ -215,7 +242,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
),
|
||||
],
|
||||
outputs=[std_output],
|
||||
show_progress=args.progress_bar,
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
|
||||
194
apps/stable_diffusion/web/ui/minigpt4_ui.py
Normal file
194
apps/stable_diffusion/web/ui/minigpt4_ui.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# ========================================
|
||||
# Gradio Setting
|
||||
# ========================================
|
||||
import gradio as gr
|
||||
|
||||
# from apps.language_models.src.pipelines.minigpt4_pipeline import (
|
||||
# # MiniGPT4,
|
||||
# CONV_VISION,
|
||||
# )
|
||||
from pathlib import Path
|
||||
|
||||
chat = None
|
||||
|
||||
|
||||
def gradio_reset(chat_state, img_list):
|
||||
if chat_state is not None:
|
||||
chat_state.messages = []
|
||||
if img_list is not None:
|
||||
img_list = []
|
||||
return (
|
||||
None,
|
||||
gr.update(value=None, interactive=True),
|
||||
gr.update(
|
||||
placeholder="Please upload your image first", interactive=False
|
||||
),
|
||||
gr.update(value="Upload & Start Chat", interactive=True),
|
||||
chat_state,
|
||||
img_list,
|
||||
)
|
||||
|
||||
|
||||
def upload_img(gr_img, text_input, chat_state, device, precision, _compile):
|
||||
global chat
|
||||
if chat is None:
|
||||
from apps.language_models.src.pipelines.minigpt4_pipeline import (
|
||||
MiniGPT4,
|
||||
CONV_VISION,
|
||||
)
|
||||
|
||||
vision_model_precision = precision
|
||||
if precision in ["int4", "int8"]:
|
||||
vision_model_precision = "fp16"
|
||||
vision_model_vmfb_path = Path(
|
||||
f"vision_model_{vision_model_precision}_{device}.vmfb"
|
||||
)
|
||||
qformer_vmfb_path = Path(f"qformer_fp32_{device}.vmfb")
|
||||
chat = MiniGPT4(
|
||||
model_name="MiniGPT4",
|
||||
hf_model_path=None,
|
||||
max_new_tokens=30,
|
||||
device=device,
|
||||
precision=precision,
|
||||
_compile=_compile,
|
||||
vision_model_vmfb_path=vision_model_vmfb_path,
|
||||
qformer_vmfb_path=qformer_vmfb_path,
|
||||
)
|
||||
if gr_img is None:
|
||||
return None, None, gr.update(interactive=True), chat_state, None
|
||||
chat_state = CONV_VISION.copy()
|
||||
img_list = []
|
||||
llm_message = chat.upload_img(gr_img, chat_state, img_list)
|
||||
return (
|
||||
gr.update(interactive=False),
|
||||
gr.update(interactive=True, placeholder="Type and press Enter"),
|
||||
gr.update(value="Start Chatting", interactive=False),
|
||||
chat_state,
|
||||
img_list,
|
||||
)
|
||||
|
||||
|
||||
def gradio_ask(user_message, chatbot, chat_state):
|
||||
if len(user_message) == 0:
|
||||
return (
|
||||
gr.update(
|
||||
interactive=True, placeholder="Input should not be empty!"
|
||||
),
|
||||
chatbot,
|
||||
chat_state,
|
||||
)
|
||||
chat.ask(user_message, chat_state)
|
||||
chatbot = chatbot + [[user_message, None]]
|
||||
return "", chatbot, chat_state
|
||||
|
||||
|
||||
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
|
||||
llm_message = chat.answer(
|
||||
conv=chat_state,
|
||||
img_list=img_list,
|
||||
num_beams=num_beams,
|
||||
temperature=temperature,
|
||||
max_new_tokens=300,
|
||||
max_length=2000,
|
||||
)[0]
|
||||
print(llm_message)
|
||||
print("************")
|
||||
chatbot[-1][1] = llm_message
|
||||
return chatbot, chat_state, img_list
|
||||
|
||||
|
||||
title = """<h1 align="center">MultiModal SHARK (experimental)</h1>"""
|
||||
description = """<h3>Upload your images and start chatting!</h3>"""
|
||||
article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
|
||||
"""
|
||||
|
||||
# TODO show examples below
|
||||
|
||||
with gr.Blocks() as minigpt4_web:
|
||||
gr.Markdown(title)
|
||||
gr.Markdown(description)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
image = gr.Image(type="pil")
|
||||
upload_button = gr.Button(
|
||||
value="Upload & Start Chat",
|
||||
interactive=True,
|
||||
variant="primary",
|
||||
)
|
||||
clear = gr.Button("Restart")
|
||||
|
||||
num_beams = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
value=1,
|
||||
step=1,
|
||||
interactive=True,
|
||||
label="beam search numbers)",
|
||||
)
|
||||
|
||||
temperature = gr.Slider(
|
||||
minimum=0.1,
|
||||
maximum=2.0,
|
||||
value=1.0,
|
||||
step=0.1,
|
||||
interactive=True,
|
||||
label="Temperature",
|
||||
)
|
||||
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value="cuda",
|
||||
# if enabled
|
||||
# else "Only CUDA Supported for now",
|
||||
choices=["cuda"],
|
||||
interactive=False,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
chat_state = gr.State()
|
||||
img_list = gr.State()
|
||||
chatbot = gr.Chatbot(label="MiniGPT-4")
|
||||
text_input = gr.Textbox(
|
||||
label="User",
|
||||
placeholder="Please upload your image first",
|
||||
interactive=False,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="int8",
|
||||
choices=[
|
||||
"int8",
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
_compile = gr.Checkbox(
|
||||
value=False,
|
||||
label="Compile",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
upload_button.click(
|
||||
upload_img,
|
||||
[image, text_input, chat_state, device, precision, _compile],
|
||||
[image, text_input, upload_button, chat_state, img_list],
|
||||
)
|
||||
|
||||
text_input.submit(
|
||||
gradio_ask,
|
||||
[text_input, chatbot, chat_state],
|
||||
[text_input, chatbot, chat_state],
|
||||
).then(
|
||||
gradio_answer,
|
||||
[chatbot, chat_state, img_list, num_beams, temperature],
|
||||
[chatbot, chat_state, img_list],
|
||||
)
|
||||
clear.click(
|
||||
gradio_reset,
|
||||
[chat_state, img_list],
|
||||
[chatbot, image, text_input, upload_button, chat_state, img_list],
|
||||
queue=False,
|
||||
)
|
||||
@@ -19,7 +19,10 @@ def get_hf_list(num_of_models=20):
|
||||
|
||||
|
||||
def get_civit_list(num_of_models=50):
|
||||
path = f"https://civitai.com/api/v1/models?limit={num_of_models}&types=Checkpoint"
|
||||
path = (
|
||||
f"https://civitai.com/api/v1/models?limit="
|
||||
f"{num_of_models}&types=Checkpoint"
|
||||
)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
raw_json = requests.get(path, headers=headers).json()
|
||||
models = list(raw_json.items())[0][1]
|
||||
@@ -79,7 +82,7 @@ with gr.Blocks() as model_web:
|
||||
type="value",
|
||||
label="Model Source",
|
||||
)
|
||||
model_numebr = gr.Slider(
|
||||
model_number = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=10,
|
||||
@@ -95,6 +98,7 @@ with gr.Blocks() as model_web:
|
||||
choices=None,
|
||||
value=None,
|
||||
visible=False,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
# TODO: select and SendTo
|
||||
civit_models = gr.Gallery(
|
||||
@@ -111,9 +115,9 @@ with gr.Blocks() as model_web:
|
||||
modelmanager_sendto_outpaint = gr.Button(value="SendTo Outpaint")
|
||||
modelmanager_sendto_upscaler = gr.Button(value="SendTo Upscaler")
|
||||
|
||||
def get_model_list(model_source, model_numebr):
|
||||
def get_model_list(model_source, model_number):
|
||||
if model_source == "Hugging Face":
|
||||
hf_model_list = get_hf_list(model_numebr)
|
||||
hf_model_list = get_hf_list(model_number)
|
||||
models = []
|
||||
for model in hf_model_list:
|
||||
# TODO: add model info
|
||||
@@ -124,7 +128,7 @@ with gr.Blocks() as model_web:
|
||||
gr.Row.update(visible=True),
|
||||
)
|
||||
elif model_source == "Civitai":
|
||||
civit_model_list = get_civit_list(model_numebr)
|
||||
civit_model_list = get_civit_list(model_number)
|
||||
models = []
|
||||
for model in civit_model_list:
|
||||
image = get_image_from_model(model)
|
||||
@@ -148,7 +152,7 @@ with gr.Blocks() as model_web:
|
||||
|
||||
get_model_btn.click(
|
||||
fn=get_model_list,
|
||||
inputs=[model_source, model_numebr],
|
||||
inputs=[model_source, model_number],
|
||||
outputs=[
|
||||
hf_models,
|
||||
civit_models,
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
@@ -23,11 +21,13 @@ from apps.stable_diffusion.src import (
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
@@ -49,7 +49,7 @@ def outpaint_inf(
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
seed: str,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
@@ -64,6 +64,7 @@ def outpaint_inf(
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -91,7 +92,8 @@ def outpaint_inf(
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
@@ -176,8 +178,10 @@ def outpaint_inf(
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
try:
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
except TypeError as error:
|
||||
raise gr.Error(str(error)) from None
|
||||
|
||||
left = True if "left" in directions else False
|
||||
right = True if "right" in directions else False
|
||||
@@ -185,9 +189,7 @@ def outpaint_inf(
|
||||
bottom = True if "down" in directions else False
|
||||
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
for current_batch in range(batch_count):
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
@@ -205,26 +207,30 @@ def outpaint_inf(
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output = get_generation_text_info(
|
||||
seeds[: current_batch + 1], device
|
||||
)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
save_output_img(out_imgs[0], seeds[current_batch])
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Outpaint", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
@@ -261,7 +267,9 @@ def outpaint_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
res = outpaint_inf(
|
||||
@@ -284,7 +292,7 @@ def outpaint_api(
|
||||
custom_model="None",
|
||||
hf_model_id=InputData["hf_model_id"]
|
||||
if "hf_model_id" in InputData.keys()
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
else "stabilityai/stable-diffusion-2-inpainting",
|
||||
custom_vae="None",
|
||||
precision="fp16",
|
||||
device=available_devices[0],
|
||||
@@ -294,7 +302,12 @@ def outpaint_api(
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
# Convert Generator to Subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
@@ -312,13 +325,23 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
width=150,
|
||||
height=50,
|
||||
)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
outpaint_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
outpaint_model_info = (
|
||||
f"Custom Model Path: {outpaint_model_info}"
|
||||
)
|
||||
outpaint_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
label=f"Models",
|
||||
info=outpaint_model_info,
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
@@ -328,52 +351,76 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
custom_checkpoint_type="inpainting"
|
||||
)
|
||||
+ predefined_paint_models,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
outpaint_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: ghunkins/stable-diffusion-liberty-inpainting, https://civitai.com/api/download/models/3433",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
"e.g: ghunkins/stable-diffusion-liberty-inpainting, "
|
||||
"https://civitai.com/api/download/models/3433",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
outpaint_vae_info = (
|
||||
str(get_custom_model_path("vae"))
|
||||
).replace("\\", "\n\\")
|
||||
outpaint_vae_info = f"VAE Path: {outpaint_vae_info}"
|
||||
custom_vae = gr.Dropdown(
|
||||
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
|
||||
label=f"Custom VAE Models",
|
||||
info=outpaint_vae_info,
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.custom_vae)
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
outpaint_init_image = gr.Image(
|
||||
label="Input Image", type="pil"
|
||||
).style(height=300)
|
||||
label="Input Image",
|
||||
type="pil",
|
||||
height=300,
|
||||
)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
outpaint_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
outpaint_lora_info = f"LoRA Path: {outpaint_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=outpaint_lora_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
@@ -385,6 +432,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -485,6 +533,12 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
args.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
@@ -494,28 +548,29 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
label="Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -523,17 +578,17 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
output_dir = (
|
||||
args.output_dir if args.output_dir else Path.cwd()
|
||||
columns=[2],
|
||||
object_fit="contain",
|
||||
)
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {output_dir}",
|
||||
value=f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
outpaint_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -571,14 +626,22 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
outputs=[outpaint_gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
outputs=[outpaint_gallery, std_output, outpaint_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Outpaint", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=outpaint_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
|
||||
492
apps/stable_diffusion/web/ui/outputgallery_ui.py
Normal file
492
apps/stable_diffusion/web/ui/outputgallery_ui.py
Normal file
@@ -0,0 +1,492 @@
|
||||
import glob
|
||||
import gradio as gr
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from PIL import Image
|
||||
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import nodlogo_loc
|
||||
from apps.stable_diffusion.web.utils.metadata import displayable_metadata
|
||||
|
||||
# -- Functions for file, directory and image info querying
|
||||
|
||||
output_dir = get_generated_imgs_path()
|
||||
|
||||
|
||||
def outputgallery_filenames(subdir) -> list[str]:
|
||||
new_dir_path = os.path.join(output_dir, subdir)
|
||||
if os.path.exists(new_dir_path):
|
||||
filenames = [
|
||||
glob.glob(new_dir_path + "/" + ext)
|
||||
for ext in ("*.png", "*.jpg", "*.jpeg")
|
||||
]
|
||||
|
||||
return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True)
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def output_subdirs() -> list[str]:
|
||||
# Gets a list of subdirectories of output_dir and below, as relative paths.
|
||||
relative_paths = [
|
||||
os.path.relpath(entry[0], output_dir)
|
||||
for entry in os.walk(
|
||||
output_dir, followlinks=args.output_gallery_followlinks
|
||||
)
|
||||
]
|
||||
|
||||
# It is less confusing to always including the subdir that will take any
|
||||
# images generated today even if it doesn't exist yet
|
||||
if get_generated_imgs_todays_subdir() not in relative_paths:
|
||||
relative_paths.append(get_generated_imgs_todays_subdir())
|
||||
|
||||
# sort subdirectories so that the date named ones we probably
|
||||
# created in this or previous sessions come first, sorted with the most
|
||||
# recent first. Other subdirs are listed after.
|
||||
generated_paths = sorted(
|
||||
[path for path in relative_paths if path.isnumeric()], reverse=True
|
||||
)
|
||||
result_paths = generated_paths + sorted(
|
||||
[
|
||||
path
|
||||
for path in relative_paths
|
||||
if (not path.isnumeric()) and path != "."
|
||||
]
|
||||
)
|
||||
|
||||
return result_paths
|
||||
|
||||
|
||||
# --- Define UI layout for Gradio
|
||||
|
||||
with gr.Blocks() as outputgallery_web:
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
|
||||
with gr.Row(elem_id="outputgallery_gallery"):
|
||||
# needed to workaround gradio issue:
|
||||
# https://github.com/gradio-app/gradio/issues/2907
|
||||
dev_null = gr.Textbox("", visible=False)
|
||||
|
||||
gallery_files = gr.State(value=[])
|
||||
subdirectory_paths = gr.State(value=[])
|
||||
|
||||
with gr.Column(scale=6):
|
||||
logo = gr.Image(
|
||||
label="Getting subdirectories...",
|
||||
value=nod_logo,
|
||||
interactive=False,
|
||||
visible=True,
|
||||
show_label=True,
|
||||
elem_id="top_logo",
|
||||
elem_classes="logo_centered",
|
||||
)
|
||||
|
||||
gallery = gr.Gallery(
|
||||
label="",
|
||||
value=gallery_files.value,
|
||||
visible=False,
|
||||
show_label=True,
|
||||
columns=2,
|
||||
)
|
||||
|
||||
with gr.Column(scale=4):
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
with gr.Column(
|
||||
scale=15,
|
||||
min_width=160,
|
||||
elem_id="output_subdir_container",
|
||||
):
|
||||
subdirectories = gr.Dropdown(
|
||||
label=f"Subdirectories of {output_dir}",
|
||||
type="value",
|
||||
choices=subdirectory_paths.value,
|
||||
value="",
|
||||
interactive=True,
|
||||
elem_classes="dropdown_no_container",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Column(
|
||||
scale=1,
|
||||
min_width=32,
|
||||
elem_classes="output_icon_button",
|
||||
):
|
||||
open_subdir = gr.Button(
|
||||
variant="secondary",
|
||||
value="\U0001F5C1", # unicode open folder
|
||||
interactive=False,
|
||||
size="sm",
|
||||
)
|
||||
with gr.Column(
|
||||
scale=1,
|
||||
min_width=32,
|
||||
elem_classes="output_icon_button",
|
||||
):
|
||||
refresh = gr.Button(
|
||||
variant="secondary",
|
||||
value="\u21BB", # unicode clockwise arrow circle
|
||||
size="sm",
|
||||
)
|
||||
|
||||
image_columns = gr.Slider(
|
||||
label="Columns shown", value=4, minimum=1, maximum=16, step=1
|
||||
)
|
||||
outputgallery_filename = gr.Textbox(
|
||||
label="Filename",
|
||||
value="None",
|
||||
interactive=False,
|
||||
show_copy_button=True,
|
||||
)
|
||||
|
||||
with gr.Accordion(
|
||||
label="Parameter Information", open=False
|
||||
) as parameters_accordian:
|
||||
image_parameters = gr.DataFrame(
|
||||
headers=["Parameter", "Value"],
|
||||
col_count=2,
|
||||
wrap=True,
|
||||
elem_classes="output_parameters_dataframe",
|
||||
value=[["Status", "No image selected"]],
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Send To", open=True):
|
||||
with gr.Row():
|
||||
outputgallery_sendto_txt2img = gr.Button(
|
||||
value="Txt2Img",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
size="sm",
|
||||
)
|
||||
|
||||
outputgallery_sendto_img2img = gr.Button(
|
||||
value="Img2Img",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
size="sm",
|
||||
)
|
||||
|
||||
outputgallery_sendto_inpaint = gr.Button(
|
||||
value="Inpaint",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
size="sm",
|
||||
)
|
||||
|
||||
outputgallery_sendto_outpaint = gr.Button(
|
||||
value="Outpaint",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
size="sm",
|
||||
)
|
||||
|
||||
outputgallery_sendto_upscaler = gr.Button(
|
||||
value="Upscaler",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
size="sm",
|
||||
)
|
||||
|
||||
# --- Event handlers
|
||||
|
||||
def on_clear_gallery():
|
||||
return [
|
||||
gr.Gallery.update(
|
||||
value=[],
|
||||
visible=False,
|
||||
),
|
||||
gr.Image.update(
|
||||
visible=True,
|
||||
),
|
||||
]
|
||||
|
||||
def on_select_subdir(subdir) -> list:
|
||||
# evt.value is the subdirectory name
|
||||
new_images = outputgallery_filenames(subdir)
|
||||
new_label = (
|
||||
f"{len(new_images)} images in {os.path.join(output_dir, subdir)}"
|
||||
)
|
||||
return [
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
value=new_images,
|
||||
label=new_label,
|
||||
visible=len(new_images) > 0,
|
||||
),
|
||||
gr.Image.update(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
]
|
||||
|
||||
def on_open_subdir(subdir):
|
||||
subdir_path = os.path.normpath(os.path.join(output_dir, subdir))
|
||||
|
||||
if os.path.isdir(subdir_path):
|
||||
if sys.platform == "linux":
|
||||
subprocess.run(["xdg-open", subdir_path])
|
||||
elif sys.platform == "darwin":
|
||||
subprocess.run(["open", subdir_path])
|
||||
elif sys.platform == "win32":
|
||||
os.startfile(subdir_path)
|
||||
|
||||
def on_refresh(current_subdir: str) -> list:
|
||||
# get an up-to-date subdirectory list
|
||||
refreshed_subdirs = output_subdirs()
|
||||
# get the images using either the current subdirectory or the most
|
||||
# recent valid one
|
||||
new_subdir = (
|
||||
current_subdir
|
||||
if current_subdir in refreshed_subdirs
|
||||
else refreshed_subdirs[0]
|
||||
)
|
||||
new_images = outputgallery_filenames(new_subdir)
|
||||
new_label = (
|
||||
f"{len(new_images)} images in "
|
||||
f"{os.path.join(output_dir, new_subdir)}"
|
||||
)
|
||||
|
||||
return [
|
||||
gr.Dropdown.update(
|
||||
choices=refreshed_subdirs,
|
||||
value=new_subdir,
|
||||
),
|
||||
refreshed_subdirs,
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
value=new_images, label=new_label, visible=len(new_images) > 0
|
||||
),
|
||||
gr.Image.update(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
]
|
||||
|
||||
def on_new_image(subdir, subdir_paths, status) -> list:
|
||||
# prevent error triggered when an image generates before the tab
|
||||
# has even been selected
|
||||
subdir_paths = (
|
||||
subdir_paths
|
||||
if len(subdir_paths) > 0
|
||||
else [get_generated_imgs_todays_subdir()]
|
||||
)
|
||||
|
||||
# only update if the current subdir is the most recent one as
|
||||
# new images only go there
|
||||
if subdir_paths[0] == subdir:
|
||||
new_images = outputgallery_filenames(subdir)
|
||||
new_label = (
|
||||
f"{len(new_images)} images in "
|
||||
f"{os.path.join(output_dir, subdir)} - {status}"
|
||||
)
|
||||
|
||||
return [
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
value=new_images,
|
||||
label=new_label,
|
||||
visible=len(new_images) > 0,
|
||||
),
|
||||
gr.Image.update(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
]
|
||||
else:
|
||||
# otherwise change nothing,
|
||||
# (only untyped gradio gr.update() does this)
|
||||
return [gr.update(), gr.update(), gr.update()]
|
||||
|
||||
def on_select_image(images: list[str], evt: gr.SelectData) -> list:
|
||||
# evt.index is an index into the full list of filenames for
|
||||
# the current subdirectory
|
||||
filename = images[evt.index]
|
||||
params = displayable_metadata(filename)
|
||||
|
||||
if params:
|
||||
if params["source"] == "missing":
|
||||
return [
|
||||
"Could not find this image file, refresh the gallery and update the images",
|
||||
[["Status", "File missing"]],
|
||||
]
|
||||
else:
|
||||
return [
|
||||
filename,
|
||||
list(map(list, params["parameters"].items())),
|
||||
]
|
||||
|
||||
return [
|
||||
filename,
|
||||
[["Status", "No parameters found"]],
|
||||
]
|
||||
|
||||
def on_outputgallery_filename_change(filename: str) -> list:
|
||||
exists = filename != "None" and os.path.exists(filename)
|
||||
return [
|
||||
# disable or enable each of the sendto button based on whether
|
||||
# an image is selected
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
]
|
||||
|
||||
# The time first our tab is selected we need to do an initial refresh
|
||||
# to populate the subdirectory select box and the images from the most
|
||||
# recent subdirectory.
|
||||
#
|
||||
# We do it at this point rather than setting this up in the controls'
|
||||
# definitions as when you refresh the browser you always get what was
|
||||
# *initially* set, which won't include any new subdirectories or images
|
||||
# that might have created since the application was started. Doing it
|
||||
# this way means a browser refresh/reload always gets the most
|
||||
# up-to-date data.
|
||||
def on_select_tab(subdir_paths, request: gr.Request):
|
||||
local_client = request.headers["host"].startswith(
|
||||
"127.0.0.1:"
|
||||
) or request.headers["host"].startswith("localhost:")
|
||||
|
||||
if len(subdir_paths) == 0:
|
||||
return on_refresh("") + [gr.update(interactive=local_client)]
|
||||
else:
|
||||
return (
|
||||
# Change nothing, (only untyped gr.update() does this)
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
# Unfortunately as of gradio 3.34.0 gr.update against Galleries doesn't
|
||||
# support things set with .style, nor the elem_classes kwarg, so we have
|
||||
# to directly set things up via JavaScript if we want the client to take
|
||||
# notice of our changes to the number of columns after it decides to put
|
||||
# them back to the original number when we change something
|
||||
def js_set_columns_in_browser(timeout_length):
|
||||
return f"""
|
||||
(new_cols) => {{
|
||||
setTimeout(() => {{
|
||||
required_style = "auto ".repeat(new_cols).trim();
|
||||
gallery = document.querySelector('#outputgallery_gallery .grid-container');
|
||||
if (gallery) {{
|
||||
gallery.style.gridTemplateColumns = required_style
|
||||
}}
|
||||
}}, {timeout_length});
|
||||
return []; // prevents console error from gradio
|
||||
}}
|
||||
"""
|
||||
|
||||
# --- Wire handlers up to the actions
|
||||
|
||||
# Many actions reset the number of columns shown in the gallery on the
|
||||
# browser end, so we have to set them back to what we think they should
|
||||
# be after the initial action.
|
||||
#
|
||||
# None of the actions on this tab trigger inference, and we want the
|
||||
# user to be able to do them whilst other tabs have ongoing inference
|
||||
# running. Waiting in the queue behind inference jobs would mean the UI
|
||||
# can't fully respond until the inference tasks complete,
|
||||
# hence queue=False on all of these.
|
||||
set_gallery_columns_immediate = dict(
|
||||
fn=None,
|
||||
inputs=[image_columns],
|
||||
# gradio blanks the UI on Chrome on Linux on gallery select if
|
||||
# I don't put an output here
|
||||
outputs=[dev_null],
|
||||
_js=js_set_columns_in_browser(0),
|
||||
queue=False,
|
||||
)
|
||||
|
||||
# setting columns after selecting a gallery item needs a real
|
||||
# timeout length for the number of columns to actually be applied.
|
||||
# Not really sure why, maybe something has to finish animating?
|
||||
set_gallery_columns_delayed = dict(
|
||||
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
|
||||
)
|
||||
|
||||
# clearing images when we need to completely change what's in the
|
||||
# gallery avoids current images being shown replacing piecemeal and
|
||||
# prevents weirdness and errors if the user selects an image during the
|
||||
# replacement phase.
|
||||
clear_gallery = dict(
|
||||
fn=on_clear_gallery,
|
||||
inputs=None,
|
||||
outputs=[gallery, logo],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
image_columns.change(**set_gallery_columns_immediate)
|
||||
|
||||
subdirectories.select(**clear_gallery).then(
|
||||
on_select_subdir,
|
||||
[subdirectories],
|
||||
[gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
|
||||
open_subdir.click(
|
||||
on_open_subdir, inputs=[subdirectories], queue=False
|
||||
).then(**set_gallery_columns_immediate)
|
||||
|
||||
refresh.click(**clear_gallery).then(
|
||||
on_refresh,
|
||||
[subdirectories],
|
||||
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
|
||||
gallery.select(
|
||||
on_select_image,
|
||||
[gallery_files],
|
||||
[outputgallery_filename, image_parameters],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_delayed)
|
||||
|
||||
outputgallery_filename.change(
|
||||
on_outputgallery_filename_change,
|
||||
[outputgallery_filename],
|
||||
[
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
outputgallery_sendto_upscaler,
|
||||
],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
# We should have been given the .select function for our tab, so set it up
|
||||
def outputgallery_tab_select(select):
|
||||
select(
|
||||
fn=on_select_tab,
|
||||
inputs=[subdirectory_paths],
|
||||
outputs=[
|
||||
subdirectories,
|
||||
subdirectory_paths,
|
||||
gallery_files,
|
||||
gallery,
|
||||
logo,
|
||||
open_subdir,
|
||||
],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
|
||||
# We should have been passed a list of components on other tabs that update
|
||||
# when a new image has generated on that tab, so set things up so the user
|
||||
# will see that new image if they are looking at today's subdirectory
|
||||
def outputgallery_watch(components: gr.Textbox):
|
||||
for component in components:
|
||||
component.change(
|
||||
on_new_image,
|
||||
inputs=[subdirectories, subdirectory_paths, component],
|
||||
outputs=[gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
@@ -1,26 +1,14 @@
|
||||
import gradio as gr
|
||||
import torch
|
||||
import os
|
||||
from apps.language_models.scripts.stablelm import (
|
||||
compile_stableLM,
|
||||
StopOnTokens,
|
||||
generate,
|
||||
get_tokenizer,
|
||||
StableLMModel,
|
||||
)
|
||||
from pathlib import Path
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
TextIteratorStreamer,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
|
||||
start_message = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
def user(message, history):
|
||||
@@ -28,183 +16,491 @@ def user(message, history):
|
||||
return "", history + [[message, ""]]
|
||||
|
||||
|
||||
input_ids = torch.randint(3, (1, 256))
|
||||
attention_mask = torch.randint(3, (1, 256))
|
||||
|
||||
|
||||
sharkModel = 0
|
||||
sharded_model = 0
|
||||
vicuna_model = 0
|
||||
|
||||
|
||||
start_message_vicuna = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
past_key_values = None
|
||||
|
||||
model_map = {
|
||||
"llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"llama2_13b": "meta-llama/Llama-2-13b-chat-hf",
|
||||
"llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
|
||||
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
|
||||
}
|
||||
|
||||
def chat(curr_system_message, history, model):
|
||||
global sharded_model
|
||||
global past_key_values
|
||||
if "vicuna" in model:
|
||||
from apps.language_models.scripts.sharded_vicuna_fp32 import (
|
||||
tokenizer,
|
||||
get_sharded_model,
|
||||
# NOTE: Each `model_name` should have its own start message
|
||||
start_message = {
|
||||
"llama2_7b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
),
|
||||
"llama2_13b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
),
|
||||
"llama2_70b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
),
|
||||
"vicuna": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def create_prompt(model_name, history, prompt_prefix):
|
||||
system_message = ""
|
||||
if prompt_prefix:
|
||||
system_message = start_message[model_name]
|
||||
|
||||
if "llama2" in model_name:
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||
conversation = "".join(
|
||||
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
|
||||
)
|
||||
|
||||
SAMPLE_INPUT_LEN = 137
|
||||
curr_system_message = start_message_vicuna
|
||||
if sharded_model == 0:
|
||||
sharded_model = get_sharded_model()
|
||||
messages = curr_system_message + "".join(
|
||||
msg = f"{B_INST} {B_SYS} {system_message} {E_SYS} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
elif model_name in ["vicuna"]:
|
||||
conversation = "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
prompt = messages.strip()
|
||||
print("prompt = ", prompt)
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
new_sentence = ""
|
||||
for _ in range(200):
|
||||
original_input_ids = input_ids
|
||||
input_id_len = len(input_ids)
|
||||
pad_len = SAMPLE_INPUT_LEN - input_id_len
|
||||
attention_mask = torch.ones([1, input_id_len], dtype=torch.int64)
|
||||
input_ids = torch.tensor(input_ids)
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
attention_mask = torch.nn.functional.pad(
|
||||
torch.tensor(attention_mask),
|
||||
(0, pad_len),
|
||||
mode="constant",
|
||||
value=0,
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
else:
|
||||
conversation = "".join(
|
||||
["".join([item[0], item[1]]) for item in history]
|
||||
)
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
return msg
|
||||
|
||||
|
||||
def set_vicuna_model(model):
|
||||
global vicuna_model
|
||||
vicuna_model = model
|
||||
|
||||
|
||||
def get_default_config():
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
|
||||
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
|
||||
compilation_prompt = "".join(["0" for _ in range(17)])
|
||||
compilation_input_ids = tokenizer(
|
||||
compilation_prompt,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
|
||||
[1, 19]
|
||||
)
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
CombinedModel,
|
||||
)
|
||||
from shark.shark_generate_model_config import GenerateConfigFile
|
||||
|
||||
model = CombinedModel()
|
||||
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
|
||||
c.split_into_layers()
|
||||
|
||||
|
||||
model_vmfb_key = ""
|
||||
|
||||
|
||||
# TODO: Make chat reusable for UI and API
|
||||
def chat(
|
||||
prompt_prefix,
|
||||
history,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
cli=False,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
global past_key_values
|
||||
global model_vmfb_key
|
||||
global vicuna_model
|
||||
|
||||
device_id = None
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device_id = int(device.split("://")[1])
|
||||
device = "vulkan"
|
||||
elif "rocm" in device:
|
||||
device = "rocm"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
|
||||
from apps.language_models.scripts.vicuna import ShardedVicuna
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
|
||||
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
|
||||
model_vmfb_key = new_model_vmfb_key
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
|
||||
# get iree flags that need to be overridden, from commandline args
|
||||
_extra_args = []
|
||||
# vulkan target triple
|
||||
vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
get_vulkan_target_triple,
|
||||
)
|
||||
|
||||
if device == "vulkan":
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
if vulkan_target_triple == "":
|
||||
# We already have the device_id extracted via WebUI, so we directly use
|
||||
# that to find the target triple.
|
||||
vulkan_target_triple = get_vulkan_target_triple(
|
||||
vulkaninfo_list[device_id]
|
||||
)
|
||||
_extra_args.append(
|
||||
f"-iree-vulkan-target-triple={vulkan_target_triple}"
|
||||
)
|
||||
if "rdna" in vulkan_target_triple:
|
||||
flags_to_add = [
|
||||
"--iree-spirv-index-bits=64",
|
||||
]
|
||||
_extra_args = _extra_args + flags_to_add
|
||||
|
||||
if device_id is None:
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
target_triple = get_vulkan_target_triple(
|
||||
vulkaninfo_list[id]
|
||||
)
|
||||
if target_triple == vulkan_target_triple:
|
||||
device_id = id
|
||||
break
|
||||
id += 1
|
||||
|
||||
assert (
|
||||
device_id
|
||||
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
|
||||
|
||||
print(f"Will use target triple : {vulkan_target_triple}")
|
||||
|
||||
if model_name == "vicuna4":
|
||||
vicuna_model = ShardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
compressed=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
else:
|
||||
# if config_file is None:
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
device=device,
|
||||
vulkan_target_triple=vulkan_target_triple,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=download_vmfb,
|
||||
load_mlir_from_shark_tank=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
if _ == 0:
|
||||
output = sharded_model.forward(input_ids, is_first=True)
|
||||
else:
|
||||
output = sharded_model.forward(
|
||||
input_ids, past_key_values=past_key_values, is_first=False
|
||||
)
|
||||
logits = output["logits"]
|
||||
past_key_values = output["past_key_values"]
|
||||
new_word = tokenizer.decode(torch.argmax(logits[:, -1, :], dim=1))
|
||||
if new_word == "</s>":
|
||||
break
|
||||
new_sentence += " " + new_word
|
||||
history[-1][1] = new_sentence
|
||||
yield history
|
||||
next_token = torch.argmax(logits[:, input_id_len - 1, :], dim=1)
|
||||
original_input_ids.append(next_token)
|
||||
input_ids = [next_token]
|
||||
print(new_sentence)
|
||||
return history
|
||||
if vicuna_model is None:
|
||||
sys.exit("Unable to instantiate the model object, exiting.")
|
||||
|
||||
prompt = create_prompt(model_name, history, prompt_prefix)
|
||||
|
||||
global sharkModel
|
||||
print("In chat")
|
||||
if sharkModel == 0:
|
||||
tok = get_tokenizer()
|
||||
# sharkModel = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/disk/phaneesh/stablelm_3b_f32_cuda_2048_newflags.vmfb")
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
"stabilityai/stablelm-tuned-alpha-3b", torch_dtype=torch.float32
|
||||
)
|
||||
stableLMModel = StableLMModel(m)
|
||||
sharkModel = compile_stableLM(
|
||||
stableLMModel,
|
||||
tuple([input_ids, attention_mask]),
|
||||
"stableLM_linalg_f32_seqLen256",
|
||||
os.getcwd(),
|
||||
)
|
||||
# Initialize a StopOnTokens object
|
||||
stop = StopOnTokens()
|
||||
# Construct the input message string for the model by concatenating the current system message and conversation history
|
||||
if len(curr_system_message.split()) > 160:
|
||||
print("clearing context")
|
||||
curr_system_message = start_message
|
||||
messages = curr_system_message + "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
# print(messages)
|
||||
# Tokenize the messages string
|
||||
streamer = TextIteratorStreamer(
|
||||
tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
||||
)
|
||||
generate_kwargs = dict(
|
||||
new_text=messages,
|
||||
streamer=streamer,
|
||||
max_new_tokens=512,
|
||||
do_sample=True,
|
||||
top_p=0.95,
|
||||
top_k=1000,
|
||||
temperature=1.0,
|
||||
num_beams=1,
|
||||
stopping_criteria=StoppingCriteriaList([stop]),
|
||||
sharkStableLM=sharkModel,
|
||||
)
|
||||
words_list = generate(**generate_kwargs)
|
||||
partial_text = ""
|
||||
for new_text in words_list:
|
||||
# print(new_text)
|
||||
partial_text += new_text
|
||||
history[-1][1] = partial_text
|
||||
# Yield an empty string to cleanup the message textbox and the updated conversation history
|
||||
yield history
|
||||
return words_list
|
||||
token_count = 0
|
||||
total_time_ms = 0.001 # In order to avoid divide by zero error
|
||||
prefill_time = 0
|
||||
is_first = True
|
||||
for text, msg, exec_time in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=cli),
|
||||
desc="generating response",
|
||||
):
|
||||
if msg is None:
|
||||
if is_first:
|
||||
prefill_time = exec_time
|
||||
is_first = False
|
||||
else:
|
||||
total_time_ms += exec_time
|
||||
token_count += 1
|
||||
partial_text += text + " "
|
||||
history[-1][1] = partial_text
|
||||
yield history, f"Prefill: {prefill_time:.2f}"
|
||||
elif "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
tokens_per_sec = (token_count / total_time_ms) * 1000
|
||||
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
|
||||
else:
|
||||
sys.exit(
|
||||
"unexpected message from the vicuna generate call, exiting."
|
||||
)
|
||||
|
||||
return history, ""
|
||||
|
||||
|
||||
def llm_chat_api(InputData: dict):
|
||||
print(f"Input keys : {InputData.keys()}")
|
||||
# print(f"model : {InputData['model']}")
|
||||
is_chat_completion_api = (
|
||||
"messages" in InputData.keys()
|
||||
) # else it is the legacy `completion` api
|
||||
# For Debugging input data from API
|
||||
# if is_chat_completion_api:
|
||||
# print(f"message -> role : {InputData['messages'][0]['role']}")
|
||||
# print(f"message -> content : {InputData['messages'][0]['content']}")
|
||||
# else:
|
||||
# print(f"prompt : {InputData['prompt']}")
|
||||
# print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now
|
||||
global vicuna_model
|
||||
model_name = (
|
||||
InputData["model"] if "model" in InputData.keys() else "codegen"
|
||||
)
|
||||
model_path = model_map[model_name]
|
||||
device = "cpu-task"
|
||||
precision = "fp16"
|
||||
max_toks = (
|
||||
None
|
||||
if "max_tokens" not in InputData.keys()
|
||||
else InputData["max_tokens"]
|
||||
)
|
||||
if max_toks is None:
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
|
||||
# make it working for codegen first
|
||||
from apps.language_models.scripts.vicuna import (
|
||||
UnshardedVicuna,
|
||||
)
|
||||
|
||||
device_id = None
|
||||
if vicuna_model == 0:
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device_id = int(device.split("://")[1])
|
||||
device = "vulkan"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=True,
|
||||
load_mlir_from_shark_tank=True,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
# TODO: add role dict for different models
|
||||
if is_chat_completion_api:
|
||||
# TODO: add funtionality for multiple messages
|
||||
prompt = create_prompt(
|
||||
model_name, [(InputData["messages"][0]["content"], "")]
|
||||
)
|
||||
else:
|
||||
prompt = InputData["prompt"]
|
||||
print("prompt = ", prompt)
|
||||
|
||||
res = vicuna_model.generate(prompt)
|
||||
res_op = None
|
||||
for op in res:
|
||||
res_op = op
|
||||
|
||||
if is_chat_completion_api:
|
||||
choices = [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": res_op, # since we are yeilding the result
|
||||
},
|
||||
"finish_reason": "stop", # or length
|
||||
}
|
||||
]
|
||||
else:
|
||||
choices = [
|
||||
{
|
||||
"text": res_op,
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop", # or length
|
||||
}
|
||||
]
|
||||
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
|
||||
return {
|
||||
"id": end_time,
|
||||
"object": "chat.completion"
|
||||
if is_chat_completion_api
|
||||
else "text_completion",
|
||||
"created": int(end_time),
|
||||
"choices": choices,
|
||||
}
|
||||
|
||||
|
||||
def view_json_file(file_obj):
|
||||
content = ""
|
||||
with open(file_obj.name, "r") as fopen:
|
||||
content = fopen.read()
|
||||
return content
|
||||
|
||||
|
||||
with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
with gr.Row():
|
||||
model_choices = list(
|
||||
map(lambda x: f"{x[0]: <10} => {x[1]}", model_map.items())
|
||||
)
|
||||
model = gr.Dropdown(
|
||||
label="Select Model",
|
||||
value="TheBloke/vicuna-7B-1.1-HF",
|
||||
choices=[
|
||||
"stabilityai/stablelm-tuned-alpha-3b",
|
||||
"TheBloke/vicuna-7B-1.1-HF",
|
||||
],
|
||||
value=model_choices[0],
|
||||
choices=model_choices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
device_value = None
|
||||
for d in available_devices:
|
||||
if "vulkan" in d:
|
||||
device_value = d
|
||||
break
|
||||
|
||||
supported_devices = available_devices
|
||||
enabled = len(supported_devices) > 0
|
||||
# show cpu-task device first in list for chatbot
|
||||
supported_devices = supported_devices[-1:] + supported_devices[:-1]
|
||||
supported_devices = [x for x in supported_devices if "sync" not in x]
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=device_value if device_value else available_devices[0],
|
||||
interactive=False,
|
||||
choices=available_devices,
|
||||
value=supported_devices[0]
|
||||
if enabled
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
allow_custom_value=True,
|
||||
# multiselect=True,
|
||||
)
|
||||
chatbot = gr.Chatbot().style(height=500)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="int4",
|
||||
choices=[
|
||||
"int4",
|
||||
"int8",
|
||||
"fp16",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
with gr.Column():
|
||||
download_vmfb = gr.Checkbox(
|
||||
label="Download vmfb from Shark tank if available",
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
prompt_prefix = gr.Checkbox(
|
||||
label="Add System Prompt",
|
||||
value=False,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Group():
|
||||
config_file = gr.File(
|
||||
label="Upload sharding configuration", visible=False
|
||||
)
|
||||
json_view_button = gr.Button(label="View as JSON", visible=False)
|
||||
json_view = gr.JSON(interactive=True, visible=False)
|
||||
json_view_button.click(
|
||||
fn=view_json_file, inputs=[config_file], outputs=[json_view]
|
||||
)
|
||||
chatbot = gr.Chatbot(height=500)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
msg = gr.Textbox(
|
||||
label="Chat Message Box",
|
||||
placeholder="Chat Message Box",
|
||||
show_label=False,
|
||||
).style(container=False)
|
||||
interactive=enabled,
|
||||
container=False,
|
||||
)
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
submit = gr.Button("Submit")
|
||||
stop = gr.Button("Stop")
|
||||
clear = gr.Button("Clear")
|
||||
system_msg = gr.Textbox(
|
||||
start_message, label="System Message", interactive=False, visible=False
|
||||
)
|
||||
submit = gr.Button("Submit", interactive=enabled)
|
||||
stop = gr.Button("Stop", interactive=enabled)
|
||||
clear = gr.Button("Clear", interactive=enabled)
|
||||
|
||||
submit_event = msg.submit(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model],
|
||||
outputs=[chatbot],
|
||||
inputs=[
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model],
|
||||
outputs=[chatbot],
|
||||
inputs=[
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from math import ceil
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
@@ -17,7 +17,8 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.png_metadata import import_png_metadata
|
||||
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Text2ImagePipeline,
|
||||
@@ -26,11 +27,16 @@ from apps.stable_diffusion.src import (
|
||||
utils,
|
||||
save_output_img,
|
||||
prompt_examples,
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_iree_metal_target_platform = args.iree_metal_target_platform
|
||||
init_use_tuned = args.use_tuned
|
||||
init_import_mlir = args.import_mlir
|
||||
|
||||
@@ -42,7 +48,7 @@ def txt2img_inf(
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
seed: str | int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
@@ -57,6 +63,12 @@ def txt2img_inf(
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
use_hiresfix: bool,
|
||||
hiresfix_height: int,
|
||||
hiresfix_width: int,
|
||||
hiresfix_strength: float,
|
||||
resample_type: str,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -83,7 +95,8 @@ def txt2img_inf(
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
@@ -134,6 +147,7 @@ def txt2img_inf(
|
||||
args.width = width
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
|
||||
args.iree_metal_target_platform = init_iree_metal_target_platform
|
||||
args.use_tuned = init_use_tuned
|
||||
args.import_mlir = init_import_mlir
|
||||
args.img_path = None
|
||||
@@ -171,12 +185,13 @@ def txt2img_inf(
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
text_output = ""
|
||||
for i in range(batch_count):
|
||||
if i > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
try:
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
except TypeError as error:
|
||||
raise gr.Error(str(error)) from None
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
@@ -185,26 +200,105 @@ def txt2img_inf(
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
seeds.append(img_seed)
|
||||
# TODO: allow user to save original image
|
||||
# TODO: add option to let user keep both pipelines loaded, and unload
|
||||
# either at will
|
||||
# TODO: add custom step value slider
|
||||
# TODO: add option to use secondary model for the img2img pass
|
||||
if use_hiresfix is True:
|
||||
new_config_obj = Config(
|
||||
"img2img",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
precision,
|
||||
1,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil="None",
|
||||
ondemand=ondemand,
|
||||
)
|
||||
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(args.scheduler)
|
||||
|
||||
global_obj.set_sd_obj(
|
||||
Image2ImagePipeline.from_pretrained(
|
||||
scheduler_obj,
|
||||
args.import_mlir,
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
args.precision,
|
||||
args.max_length,
|
||||
1,
|
||||
hiresfix_height,
|
||||
hiresfix_width,
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_sd_scheduler(args.scheduler)
|
||||
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
out_imgs[0],
|
||||
batch_size,
|
||||
hiresfix_height,
|
||||
hiresfix_width,
|
||||
ceil(steps / hiresfix_strength),
|
||||
hiresfix_strength,
|
||||
guidance_scale,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil="None",
|
||||
resample_type=resample_type,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(seeds, device)
|
||||
text_output = get_generation_text_info(
|
||||
seeds[: current_batch + 1], device
|
||||
)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], img_seed)
|
||||
save_output_img(out_imgs[0], seeds[current_batch])
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Text-to-Image", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
return generated_imgs, text_output
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def encode_pil_to_base64(images):
|
||||
@@ -230,7 +324,9 @@ def txt2img_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}.'
|
||||
)
|
||||
res = txt2img_inf(
|
||||
InputData["prompt"],
|
||||
@@ -256,7 +352,17 @@ def txt2img_api(
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
use_hiresfix=False,
|
||||
hiresfix_height=512,
|
||||
hiresfix_width=512,
|
||||
hiresfix_strength=0.6,
|
||||
resample_type="Nearest Neighbor",
|
||||
)
|
||||
|
||||
# Convert Generator to Subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
@@ -274,15 +380,25 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
width=150,
|
||||
height=50,
|
||||
)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
t2i_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
t2i_model_info = (
|
||||
f"Custom Model Path: {t2i_model_info}"
|
||||
)
|
||||
txt2img_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
label=f"Models",
|
||||
info=t2i_model_info,
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
@@ -290,25 +406,35 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
choices=["None"]
|
||||
+ get_custom_model_files()
|
||||
+ predefined_models,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
txt2img_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
placeholder="Select 'None' in the dropdown "
|
||||
"on the left and enter model ID here.",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL.",
|
||||
lines=3,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
t2i_vae_info = (
|
||||
str(get_custom_model_path("vae"))
|
||||
).replace("\\", "\n\\")
|
||||
t2i_vae_info = f"VAE Path: {t2i_vae_info}"
|
||||
custom_vae = gr.Dropdown(
|
||||
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
|
||||
label=f"VAE Models",
|
||||
info=t2i_vae_info,
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.custom_vae)
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"]
|
||||
+ get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Column(scale=1, min_width=170):
|
||||
png_info_img = gr.Image(
|
||||
txt2img_png_info_img = gr.Image(
|
||||
label="Import PNG info",
|
||||
elem_id="txt2img_prompt_image",
|
||||
type="pil",
|
||||
@@ -320,26 +446,36 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
t2i_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
t2i_lora_info = f"LoRA Path: {t2i_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=t2i_lora_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
@@ -351,8 +487,9 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
choices=scheduler_list,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
with gr.Column():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
@@ -397,21 +534,67 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=args.ondemand,
|
||||
label="Low VRAM",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Group():
|
||||
with gr.Row():
|
||||
use_hiresfix = gr.Checkbox(
|
||||
value=args.use_hiresfix,
|
||||
label="Use Hires Fix",
|
||||
interactive=True,
|
||||
)
|
||||
resample_type = gr.Dropdown(
|
||||
value=args.resample_type,
|
||||
choices=[
|
||||
"Lanczos",
|
||||
"Nearest Neighbor",
|
||||
"Bilinear",
|
||||
"Bicubic",
|
||||
"Adaptive",
|
||||
"Antialias",
|
||||
"Box",
|
||||
"Affine",
|
||||
"Cubic",
|
||||
],
|
||||
label="Resample Type",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
hiresfix_height = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
value=args.hiresfix_height,
|
||||
step=8,
|
||||
label="Hires Fix Height",
|
||||
)
|
||||
hiresfix_width = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
value=args.hiresfix_width,
|
||||
step=8,
|
||||
label="Hires Fix Width",
|
||||
)
|
||||
hiresfix_strength = gr.Slider(
|
||||
0,
|
||||
1,
|
||||
value=args.hiresfix_strength,
|
||||
step=0.01,
|
||||
label="Hires Fix Denoising Strength",
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
@@ -431,29 +614,23 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
args.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
label="Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
examples=prompt_examples,
|
||||
@@ -468,17 +645,29 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
output_dir = (
|
||||
args.output_dir if args.output_dir else Path.cwd()
|
||||
columns=[2],
|
||||
object_fit="contain",
|
||||
)
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {output_dir}",
|
||||
value=f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
txt2img_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -513,23 +702,37 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
use_hiresfix,
|
||||
hiresfix_height,
|
||||
hiresfix_width,
|
||||
hiresfix_strength,
|
||||
resample_type,
|
||||
],
|
||||
outputs=[txt2img_gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
outputs=[txt2img_gallery, std_output, txt2img_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Text-to-Image", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=txt2img_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
png_info_img.change(
|
||||
txt2img_png_info_img.change(
|
||||
fn=import_png_metadata,
|
||||
inputs=[
|
||||
png_info_img,
|
||||
txt2img_png_info_img,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
@@ -540,9 +743,12 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
height,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
],
|
||||
outputs=[
|
||||
png_info_img,
|
||||
txt2img_png_info_img,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
@@ -553,5 +759,8 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
height,
|
||||
txt2img_custom_model,
|
||||
txt2img_hf_model_id,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
@@ -17,16 +15,16 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_upscaler_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
UpscalerPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
from apps.stable_diffusion.src.utils import get_generated_imgs_path
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
@@ -44,7 +42,7 @@ def upscaler_inf(
|
||||
steps: int,
|
||||
noise_level: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
seed: str,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
@@ -59,6 +57,7 @@ def upscaler_inf(
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -66,6 +65,9 @@ def upscaler_inf(
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
@@ -87,7 +89,8 @@ def upscaler_inf(
|
||||
if not hf_model_id:
|
||||
return (
|
||||
None,
|
||||
"Please provide either custom model or huggingface model ID, both must not be empty",
|
||||
"Please provide either custom model or huggingface model ID, "
|
||||
"both must not be empty.",
|
||||
)
|
||||
if "civitai" in hf_model_id:
|
||||
args.ckpt_loc = hf_model_id
|
||||
@@ -174,12 +177,13 @@ def upscaler_inf(
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
img_seed = utils.sanitize_seed(seed)
|
||||
extra_info = {"NOISE LEVEL": noise_level}
|
||||
try:
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
except TypeError as error:
|
||||
raise gr.Error(str(error)) from None
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
if current_batch > 0:
|
||||
img_seed = utils.sanitize_seed(-1)
|
||||
low_res_img = image
|
||||
high_res_img = Image.new("RGB", (height * 4, width * 4))
|
||||
|
||||
@@ -196,31 +200,56 @@ def upscaler_inf(
|
||||
steps,
|
||||
noise_level,
|
||||
guidance_scale,
|
||||
img_seed,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
high_res_img.paste(upscaled_image[0], (j * 4, i * 4))
|
||||
|
||||
save_output_img(high_res_img, img_seed, extra_info)
|
||||
generated_imgs.append(high_res_img)
|
||||
seeds.append(img_seed)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, global_obj.get_sd_obj().log
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
text_output += f"\nscheduler={args.scheduler}, device={device}"
|
||||
text_output += f"\nsteps={steps}, noise_level={noise_level}, guidance_scale={guidance_scale}, seed={seeds}"
|
||||
text_output += f"\nsize={height}x{width}, batch_count={batch_count}, batch_size={batch_size}, max_length={args.max_length}"
|
||||
text_output += global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, " f"device={device}"
|
||||
text_output += (
|
||||
f"\nsteps={steps}, "
|
||||
f"noise_level={noise_level}, "
|
||||
f"guidance_scale={guidance_scale}, "
|
||||
f"seed={seeds[:current_batch + 1]}"
|
||||
)
|
||||
text_output += (
|
||||
f"\ninput size={height}x{width}, "
|
||||
f"output size={height*4}x{width*4}, "
|
||||
f"batch_count={batch_count}, "
|
||||
f"batch_size={batch_size}, "
|
||||
f"max_length={args.max_length}\n"
|
||||
)
|
||||
|
||||
yield generated_imgs, text_output
|
||||
text_output += global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(high_res_img, seeds[current_batch], extra_info)
|
||||
generated_imgs.append(high_res_img)
|
||||
global_obj.get_sd_obj().log += "\n"
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Upscaler", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
|
||||
yield generated_imgs, text_output, ""
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
@@ -257,7 +286,9 @@ def upscaler_api(
|
||||
InputData: dict,
|
||||
):
|
||||
print(
|
||||
f'Prompt: {InputData["prompt"]}, Negative Prompt: {InputData["negative_prompt"]}, Seed: {InputData["seed"]}'
|
||||
f'Prompt: {InputData["prompt"]}, '
|
||||
f'Negative Prompt: {InputData["negative_prompt"]}, '
|
||||
f'Seed: {InputData["seed"]}'
|
||||
)
|
||||
init_image = decode_base64_to_image(InputData["init_images"][0])
|
||||
res = upscaler_inf(
|
||||
@@ -286,7 +317,11 @@ def upscaler_api(
|
||||
lora_weights="None",
|
||||
lora_hf_id="",
|
||||
ondemand=False,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
# Converts generator type to subscriptable
|
||||
res = next(res)
|
||||
|
||||
return {
|
||||
"images": encode_pil_to_base64(res[0]),
|
||||
"parameters": {},
|
||||
@@ -304,13 +339,23 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=50)
|
||||
width=150,
|
||||
height=50,
|
||||
)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
upscaler_model_info = (
|
||||
str(get_custom_model_path())
|
||||
).replace("\\", "\n\\")
|
||||
upscaler_model_info = (
|
||||
f"Custom Model Path: {upscaler_model_info}"
|
||||
)
|
||||
upscaler_custom_model = gr.Dropdown(
|
||||
label=f"Models (Custom Model path: {get_custom_model_path()})",
|
||||
label=f"Models",
|
||||
info=upscaler_model_info,
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
@@ -320,52 +365,76 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
custom_checkpoint_type="upscaler"
|
||||
)
|
||||
+ predefined_upscaler_models,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
upscaler_hf_model_id = gr.Textbox(
|
||||
elem_id="hf_model_id",
|
||||
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236",
|
||||
placeholder="Select 'None' in the Models dropdown "
|
||||
"on the left and enter model ID here "
|
||||
"e.g: SG161222/Realistic_Vision_V1.3, "
|
||||
"https://civitai.com/api/download/models/15236",
|
||||
value="",
|
||||
label="HuggingFace Model ID or Civitai model download URL",
|
||||
label="HuggingFace Model ID or Civitai model "
|
||||
"download URL",
|
||||
lines=3,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
upscaler_vae_info = (
|
||||
str(get_custom_model_path("vae"))
|
||||
).replace("\\", "\n\\")
|
||||
upscaler_vae_info = f"VAE Path: {upscaler_vae_info}"
|
||||
custom_vae = gr.Dropdown(
|
||||
label=f"Custom Vae Models (Path: {get_custom_model_path('vae')})",
|
||||
label=f"Custom VAE Models",
|
||||
info=upscaler_vae_info,
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.custom_vae)
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
choices=["None"] + get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=1,
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
upscaler_init_image = gr.Image(
|
||||
label="Input Image", type="pil"
|
||||
).style(height=300)
|
||||
label="Input Image",
|
||||
type="pil",
|
||||
height=300,
|
||||
)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
upscaler_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
upscaler_lora_info = f"LoRA Path: {upscaler_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standlone LoRA weights (Path: {get_custom_model_path('lora')})",
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=upscaler_lora_info,
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standlone LoRA weights dropdown on the left if you want to use a standalone HuggingFace model ID for LoRA here e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
@@ -377,6 +446,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
label="Scheduler",
|
||||
value="DDIM",
|
||||
choices=scheduler_list_cpu_only,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
@@ -456,6 +526,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
args.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
@@ -465,28 +540,29 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
label="Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
None,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
_js="() => -1",
|
||||
)
|
||||
with gr.Column(scale=6):
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -494,17 +570,18 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
).style(columns=[2], object_fit="contain")
|
||||
output_dir = (
|
||||
args.output_dir if args.output_dir else Path.cwd()
|
||||
columns=[2],
|
||||
object_fit="contain",
|
||||
)
|
||||
output_dir = Path(output_dir, "generated_imgs")
|
||||
std_output = gr.Textbox(
|
||||
value=f"Images will be saved at {output_dir}",
|
||||
value=f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
upscaler_status = gr.Textbox(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
@@ -538,14 +615,23 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
outputs=[upscaler_gallery, std_output],
|
||||
show_progress=args.progress_bar,
|
||||
outputs=[upscaler_gallery, std_output, upscaler_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
)
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Upscaler", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=upscaler_status,
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**kwargs)
|
||||
generate_click = stable_diffusion.click(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=None, cancels=[prompt_submit, neg_prompt_submit, generate_click]
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ class Config:
|
||||
device: str
|
||||
use_lora: str
|
||||
use_stencil: str
|
||||
ondemand: str
|
||||
ondemand: str # should this be expecting a bool instead?
|
||||
|
||||
|
||||
custom_model_filetypes = (
|
||||
@@ -39,8 +39,16 @@ scheduler_list_cpu_only = [
|
||||
"LMSDiscrete",
|
||||
"KDPM2Discrete",
|
||||
"DPMSolverMultistep",
|
||||
"DPMSolverMultistep++",
|
||||
"DPMSolverMultistepKarras",
|
||||
"DPMSolverMultistepKarras++",
|
||||
"EulerDiscrete",
|
||||
"EulerAncestralDiscrete",
|
||||
"DEISMultistep",
|
||||
"KDPM2AncestralDiscrete",
|
||||
"DPMSolverSinglestep",
|
||||
"DDPM",
|
||||
"HeunDiscrete",
|
||||
]
|
||||
scheduler_list = scheduler_list_cpu_only + [
|
||||
"SharkEulerDiscrete",
|
||||
@@ -50,6 +58,7 @@ predefined_models = [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"prompthero/openjourney",
|
||||
"wavymulder/Analog-Diffusion",
|
||||
"xzuyn/PhotoMerge",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
@@ -58,6 +67,7 @@ predefined_models = [
|
||||
predefined_paint_models = [
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
"xzuyn/PhotoMerge-inpainting",
|
||||
]
|
||||
predefined_upscaler_models = [
|
||||
"stabilityai/stable-diffusion-x4-upscaler",
|
||||
@@ -79,7 +89,8 @@ def create_custom_models_folders():
|
||||
else:
|
||||
if not os.path.isdir(args.ckpt_dir):
|
||||
sys.exit(
|
||||
f"Invalid --ckpt_dir argument, {args.ckpt_dir} folder does not exists."
|
||||
f"Invalid --ckpt_dir argument, "
|
||||
f"{args.ckpt_dir} folder does not exists."
|
||||
)
|
||||
for root in dir:
|
||||
get_custom_model_path(root).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
9
apps/stable_diffusion/web/utils/common_label_calc.py
Normal file
9
apps/stable_diffusion/web/utils/common_label_calc.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# functions for generating labels used in common by tabs across the UI
|
||||
|
||||
|
||||
def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1):
|
||||
if batch_index < batch_count:
|
||||
bs = f"x{batch_size}" if batch_size > 1 else ""
|
||||
return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}"
|
||||
else:
|
||||
return f"{tab_name} complete"
|
||||
@@ -1,31 +1,54 @@
|
||||
import os
|
||||
import tempfile
|
||||
import gradio
|
||||
from os import listdir
|
||||
import shutil
|
||||
from time import time
|
||||
|
||||
gradio_tmp_imgs_folder = os.path.join(os.getcwd(), "shark_tmp/")
|
||||
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
|
||||
|
||||
|
||||
# Clear all gradio tmp images
|
||||
def clear_gradio_tmp_imgs_folder():
|
||||
if not os.path.exists(gradio_tmp_imgs_folder):
|
||||
return
|
||||
for fileName in listdir(gradio_tmp_imgs_folder):
|
||||
# Delete tmp png files
|
||||
if fileName.startswith("tmp") and fileName.endswith(".png"):
|
||||
os.remove(gradio_tmp_imgs_folder + fileName)
|
||||
def config_gradio_tmp_imgs_folder():
|
||||
# create shark_tmp if it does not exist
|
||||
if not os.path.exists(shark_tmp):
|
||||
os.mkdir(shark_tmp)
|
||||
|
||||
# tell gradio to use a directory under shark_tmp for its temporary
|
||||
# image files unless somewhere else has been set
|
||||
if "GRADIO_TEMP_DIR" not in os.environ:
|
||||
os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
|
||||
|
||||
# Overwrite save_pil_to_file from gradio to save tmp images generated by gradio into our own tmp folder
|
||||
def save_pil_to_file(pil_image, dir=None):
|
||||
if not os.path.exists(gradio_tmp_imgs_folder):
|
||||
os.mkdir(gradio_tmp_imgs_folder)
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=".png", dir=gradio_tmp_imgs_folder
|
||||
print(
|
||||
f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
|
||||
+ "You may change this by setting the GRADIO_TEMP_DIR environment variable."
|
||||
)
|
||||
pil_image.save(file_obj)
|
||||
return file_obj
|
||||
|
||||
# Clear all gradio tmp images from the last session
|
||||
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
|
||||
cleanup_start = time()
|
||||
print(
|
||||
"Clearing gradio UI temporary image files from a prior run. This may take some time..."
|
||||
)
|
||||
shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
|
||||
print(
|
||||
f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
|
||||
)
|
||||
|
||||
# Register save_pil_to_file override
|
||||
gradio.processing_utils.save_pil_to_file = save_pil_to_file
|
||||
# older SHARK versions had to workaround gradio bugs and stored things differently
|
||||
else:
|
||||
image_files = [
|
||||
filename
|
||||
for filename in os.listdir(shark_tmp)
|
||||
if os.path.isfile(os.path.join(shark_tmp, filename))
|
||||
and filename.startswith("tmp")
|
||||
and filename.endswith(".png")
|
||||
]
|
||||
if len(image_files) > 0:
|
||||
print(
|
||||
"Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
|
||||
)
|
||||
cleanup_start = time()
|
||||
for filename in image_files:
|
||||
os.remove(shark_tmp + filename)
|
||||
print(
|
||||
f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
|
||||
)
|
||||
else:
|
||||
print("No temporary images files to clear.")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user