Mirror of <https://github.com/invoke-ai/InvokeAI.git>, synced 2026-01-23 09:28:01 -05:00.

Comparing `speed-up-h`...`v4.0.0rc2` — 196 commits.
Commit SHA1s in this range:

330e1354b4, 21621eebf0, c24f2046e7, 297408d67e, 0131e7d928, 06ff105a1f, bb8e6bbee6, 328dc99f3a,
ef55077e84, ba3d8af161, b07b7af710, 19d66d5ec7, ed20255abf, fed1f983db, a386544a1d, 0851de9090,
1bd8e33f8c, e3f29ed320, 3fd824306c, 2584a950aa, 1adaf63253, b9f1a4bd65, 731942dbed, 4117cea5bf,
21617f3bc1, 9fcd67b5c0, a4be935458, eb6e6548ed, 8287fcf097, dd475e28ed, 24e741e2d1, e0bf9ce5c6,
c66e8b395e, 4c417adc82, 437a413ca3, 4492bedd19, db12ce95a8, ee3a1a95ef, 4bb5aba70e, cd55c23713,
1d2743af1b, 99d2099ccd, b64a693f16, 9d523a3094, af660163ca, 7e4b462fca, 4468dd6948, 4f39e248dd,
44b3e5d43f, 8894a9e48a, c73f58e486, 614fece147, 8ef8082d65, d93d4afbb7, 01207a2fa5, d0800c4888,
2a300ecada, 90340a39c7, ee77abb4fe, 004bca5c42, 5ad048a161, 6369ccd05e, 3a5314f1ca, 4c0896e436,
f7cd3cf1f4, efea1a8a7d, d0d695c020, 2a648da557, 54f1a1f952, 8d2a4db902, 7b393656de, 43948e0758,
cc03fcbcb6, d1e445fa49, adba8489f2, d919022ba5, e076898798, 9f19b766a4, 4688623711, be951da99d,
9ee2e7ff25, 149ff758b9, 65d415d5aa, c74c1927ec, c454ccc65c, 46fd3465ce, 97afa6e2a6, 96730107d1,
6a9dede66f, 8c2ff794d5, 145bb45858, 9376b13435, eec82afd89, c47dbf7258, 92b2e8186a, 70a88c6b99,
56e7c04475, bd5b43c00d, 631e789195, 133c90e116, 4433b78e59, daeb766468, 92b0d13d0e, 67d26cd633,
9e28317a12, 5b51ebf1c4, 59228643a9, b24657df11, d4686b7f64, 67163c2224, f01e81d382, a50e0a4802,
df0a5aa92a, 0bd9a0a9ea, 4ae2cd242e, 0696094d95, fb1ae55010, deb1d4eb14, d156fd2093, c41e87160a,
eba1fc1355, 96702c395e, 3361aec065, 8ba4b2a150, df12e12e09, ee38fbe89c, 6e2cef1db5, b1f5ac4548,
e52274ecac, 66f0ff5b13, cab5b64f0b, a42812d78d, 281222df3c, d5674150fa, 0cb2cf6644, da87266c9c,
35731a6f51, a3dfa161a8, 42d606f07c, 9063b1ae61, 6aae88bd88, 57c1954da7, 2410ed689a, a10dccdd43,
a3570901f7, fd457955bc, 1f69613f5d, 7a87ebb3b2, 4ee4a801c6, 53b7f6be37, dbd7c94e7c, 50bb9a6b41,
13bb3c5e15, 80c2a4b925, 8ce485b036, 6fc3e86061, 33ded359e6, effbd8a1ba, ddde355b09, fe2c6f621a,
d0fcdbe8a3, a28547b3dd, c7b2bdb846, 4a20377fef, ed803640f7, 576bb4a61d, b6065d6328, 04229f4a21,
73a190fb6e, 952d97741e, afd08c5f46, d1f859a446, 5118160282, 8e694992bb, 4077dfe0c3, fe8e391aad,
ac8f606d99, 0aa2070ce0, ff66779aa3, 2ca65ab9fa, b34624a2a8, b8aa9752f1, 1b5d8eb9e7, 773182f425,
6386109fc5, c008704bc8, a3a42d25d3, 8959d1bf51, 8fd9342712, f0b815aa9b, 3a5b0b819c, bbcbcd9b63,
fdecb886b2, 2f0a653a7f, b0add805c5, ed4e8624dd
`.github/workflows/frontend-checks.yml` (28 changed lines):

````diff
@@ -1,7 +1,7 @@
 # Runs frontend code quality checks.
 #
 # Checks for changes to frontend files before running the checks.
-# When manually triggered or when called from another workflow, always runs the checks.
+# If always_run is true, always runs the checks.
 
 name: 'frontend checks'
 
@@ -16,7 +16,19 @@ on:
       - 'synchronize'
   merge_group:
   workflow_dispatch:
+    inputs:
+      always_run:
+        description: 'Always run the checks'
+        required: true
+        type: boolean
+        default: true
   workflow_call:
+    inputs:
+      always_run:
+        description: 'Always run the checks'
+        required: true
+        type: boolean
+        default: true
 
 defaults:
   run:
@@ -30,7 +42,7 @@ jobs:
       - uses: actions/checkout@v4
 
       - name: check for changed frontend files
-        if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
+        if: ${{ inputs.always_run != true }}
         id: changed-files
         uses: tj-actions/changed-files@v42
         with:
@@ -39,30 +51,30 @@ jobs:
             - 'invokeai/frontend/web/**'
 
       - name: install dependencies
-        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
+        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
        uses: ./.github/actions/install-frontend-deps
 
       - name: tsc
-        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
+        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
        run: 'pnpm lint:tsc'
        shell: bash
 
       - name: dpdm
-        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
+        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
        run: 'pnpm lint:dpdm'
        shell: bash
 
       - name: eslint
-        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
+        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
        run: 'pnpm lint:eslint'
        shell: bash
 
       - name: prettier
-        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
+        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
        run: 'pnpm lint:prettier'
        shell: bash
 
       - name: knip
-        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
+        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
        run: 'pnpm lint:knip'
        shell: bash
````
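The new `always_run` input replaces the event-name checks and keeps PR behavior intact: for `pull_request` and `merge_group` events the `inputs` context is empty, so `${{ inputs.always_run != true }}` evaluates true and the changed-files gate still applies. The same pattern is repeated in the three workflows below and consumed by `release.yml`. As a sketch, a manual dispatch can also be triggered through GitHub's REST API; the token here is a placeholder:

```python
# Minimal sketch: trigger the workflow via the workflow_dispatch REST
# endpoint instead of the Actions UI. GITHUB_TOKEN is a placeholder for
# a token with permission to write to Actions on this repo.
import requests

GITHUB_TOKEN = "<your token>"  # assumption: supplied by the caller
url = (
    "https://api.github.com/repos/invoke-ai/InvokeAI"
    "/actions/workflows/frontend-checks.yml/dispatches"
)
resp = requests.post(
    url,
    headers={
        "Accept": "application/vnd.github+json",
        "Authorization": f"Bearer {GITHUB_TOKEN}",
    },
    # workflow_dispatch inputs are sent as strings, even for boolean inputs
    json={"ref": "main", "inputs": {"always_run": "true"}},
    timeout=30,
)
resp.raise_for_status()  # GitHub returns 204 No Content on success
```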
`.github/workflows/frontend-tests.yml` (20 changed lines):

````diff
@@ -1,7 +1,7 @@
 # Runs frontend tests.
 #
 # Checks for changes to frontend files before running the tests.
-# When manually triggered or called from another workflow, always runs the tests.
+# If always_run is true, always runs the tests.
 
 name: 'frontend tests'
 
@@ -16,7 +16,19 @@ on:
       - 'synchronize'
   merge_group:
   workflow_dispatch:
+    inputs:
+      always_run:
+        description: 'Always run the tests'
+        required: true
+        type: boolean
+        default: true
   workflow_call:
+    inputs:
+      always_run:
+        description: 'Always run the tests'
+        required: true
+        type: boolean
+        default: true
 
 defaults:
   run:
@@ -30,7 +42,7 @@ jobs:
       - uses: actions/checkout@v4
 
       - name: check for changed frontend files
-        if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
+        if: ${{ inputs.always_run != true }}
         id: changed-files
         uses: tj-actions/changed-files@v42
         with:
@@ -39,10 +51,10 @@ jobs:
             - 'invokeai/frontend/web/**'
 
       - name: install dependencies
-        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
+        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
        uses: ./.github/actions/install-frontend-deps
 
       - name: vitest
-        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
+        if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
        run: 'pnpm test:no-watch'
        shell: bash
````
`.github/workflows/python-checks.yml` (24 changed lines):

````diff
@@ -1,7 +1,7 @@
 # Runs python code quality checks.
 #
 # Checks for changes to python files before running the checks.
-# When manually triggered or called from another workflow, always runs the tests.
+# If always_run is true, always runs the checks.
 #
 # TODO: Add mypy or pyright to the checks.
 
@@ -18,7 +18,19 @@ on:
       - 'synchronize'
   merge_group:
   workflow_dispatch:
+    inputs:
+      always_run:
+        description: 'Always run the checks'
+        required: true
+        type: boolean
+        default: true
   workflow_call:
+    inputs:
+      always_run:
+        description: 'Always run the checks'
+        required: true
+        type: boolean
+        default: true
 
 jobs:
   python-checks:
@@ -29,7 +41,7 @@ jobs:
         uses: actions/checkout@v4
 
       - name: check for changed python files
-        if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
+        if: ${{ inputs.always_run != true }}
         id: changed-files
         uses: tj-actions/changed-files@v42
         with:
@@ -41,7 +53,7 @@ jobs:
             - 'tests/**'
 
       - name: setup python
-        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
+        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
@@ -49,16 +61,16 @@ jobs:
          cache-dependency-path: pyproject.toml
 
       - name: install ruff
-        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
+        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
        run: pip install ruff
        shell: bash
 
       - name: ruff check
-        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
+        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
        run: ruff check --output-format=github .
        shell: bash
 
       - name: ruff format
-        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
+        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
        run: ruff format --check .
        shell: bash
````
`.github/workflows/python-tests.yml` (22 changed lines):

````diff
@@ -1,7 +1,7 @@
 # Runs python tests on a matrix of python versions and platforms.
 #
 # Checks for changes to python files before running the tests.
-# When manually triggered or called from another workflow, always runs the tests.
+# If always_run is true, always runs the tests.
 
 name: 'python tests'
 
@@ -16,7 +16,19 @@ on:
       - 'synchronize'
   merge_group:
   workflow_dispatch:
+    inputs:
+      always_run:
+        description: 'Always run the tests'
+        required: true
+        type: boolean
+        default: true
   workflow_call:
+    inputs:
+      always_run:
+        description: 'Always run the tests'
+        required: true
+        type: boolean
+        default: true
 
 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
@@ -63,7 +75,7 @@ jobs:
         uses: actions/checkout@v4
 
       - name: check for changed python files
-        if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
+        if: ${{ inputs.always_run != true }}
         id: changed-files
         uses: tj-actions/changed-files@v42
         with:
@@ -75,7 +87,7 @@ jobs:
             - 'tests/**'
 
       - name: setup python
-        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
+        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
@@ -83,12 +95,12 @@ jobs:
          cache-dependency-path: pyproject.toml
 
       - name: install dependencies
-        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
+        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
        env:
          PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
        run: >
          pip3 install --editable=".[test]"
 
       - name: run pytest
-        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
+        if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
        run: pytest
````
`.github/workflows/release.yml` (12 changed lines):

````diff
@@ -30,15 +30,23 @@ jobs:
 
   frontend-checks:
     uses: ./.github/workflows/frontend-checks.yml
+    with:
+      always_run: true
 
   frontend-tests:
     uses: ./.github/workflows/frontend-tests.yml
+    with:
+      always_run: true
 
   python-checks:
     uses: ./.github/workflows/python-checks.yml
+    with:
+      always_run: true
 
   python-tests:
     uses: ./.github/workflows/python-tests.yml
+    with:
+      always_run: true
 
   build:
     uses: ./.github/workflows/build-installer.yml
@@ -58,6 +66,8 @@ jobs:
     environment:
       name: testpypi
       url: https://test.pypi.org/p/invokeai
+    permissions:
+      id-token: write
     steps:
       - name: download distribution from build job
         uses: actions/download-artifact@v4
@@ -85,6 +95,8 @@ jobs:
     environment:
       name: pypi
       url: https://pypi.org/p/invokeai
+    permissions:
+      id-token: write
     steps:
       - name: download distribution from build job
         uses: actions/download-artifact@v4
````
`Makefile` (27 changed lines):

````diff
@@ -6,17 +6,18 @@ default: help
 help:
	@echo Developer commands:
	@echo
-	@echo "ruff                     Run ruff, fixing any safely-fixable errors and formatting"
-	@echo "ruff-unsafe              Run ruff, fixing all fixable errors and formatting"
-	@echo "mypy                     Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
-	@echo "mypy-all                 Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
-	@echo "test                     Run the unit tests."
-	@echo "frontend-install         Install the pnpm modules needed for the front end"
-	@echo "frontend-build           Build the frontend in order to run on localhost:9090"
-	@echo "frontend-dev             Run the frontend in developer mode on localhost:5173"
-	@echo "frontend-typegen         Generate types for the frontend from the OpenAPI schema"
-	@echo "installer-zip            Build the installer .zip file for the current version"
-	@echo "tag-release              Tag the GitHub repository with the current version (use at release time only!)"
+	@echo "ruff                     Run ruff, fixing any safely-fixable errors and formatting"
+	@echo "ruff-unsafe              Run ruff, fixing all fixable errors and formatting"
+	@echo "mypy                     Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
+	@echo "mypy-all                 Run mypy ignoring the config in pyproject.toml but still ignoring missing imports"
+	@echo "test                     Run the unit tests."
+	@echo "update-config-docstring  Update the app's config docstring so mkdocs can autogenerate it correctly."
+	@echo "frontend-install         Install the pnpm modules needed for the front end"
+	@echo "frontend-build           Build the frontend in order to run on localhost:9090"
+	@echo "frontend-dev             Run the frontend in developer mode on localhost:5173"
+	@echo "frontend-typegen         Generate types for the frontend from the OpenAPI schema"
+	@echo "installer-zip            Build the installer .zip file for the current version"
+	@echo "tag-release              Tag the GitHub repository with the current version (use at release time only!)"
 
 # Runs ruff, fixing any safely-fixable errors and formatting
 ruff:
@@ -41,6 +42,10 @@ mypy-all:
 test:
	pytest ./tests
 
+# Update config docstring
+update-config-docstring:
+	python scripts/update_config_docstring.py
+
 # Install the pnpm modules needed for the front end
 frontend-install:
	rm -rf invokeai/frontend/web/node_modules
````
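The new `update-config-docstring` target runs `scripts/update_config_docstring.py`, which keeps the config class docstring in sync so mkdocs can autogenerate the settings docs (see the configuration documentation change further down). The script itself is not shown in this diff; as a rough, hypothetical sketch of the idea — deriving documentation text from a pydantic model's field metadata:

```python
# Hypothetical sketch only: generate docs text from a pydantic model's
# field descriptions. The real scripts/update_config_docstring.py is not
# shown in this diff and almost certainly differs in its details.
from pydantic import BaseModel, Field


class ExampleConfig(BaseModel):
    """Example settings."""

    host: str = Field(default="localhost", description="IP address to bind to")
    port: int = Field(default=9090, description="Port to bind to")


def render_field_docs(model: type[BaseModel]) -> str:
    # pydantic v2: model_fields maps field name -> FieldInfo
    lines = [
        f"{name}: {info.description} (default: {info.default!r})"
        for name, info in model.model_fields.items()
    ]
    return "\n".join(lines)


print(render_field_docs(ExampleConfig))
```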
Model manager documentation:

````diff
@@ -16,11 +16,6 @@ model. These are the:
   information. It is also responsible for managing the InvokeAI
   `models` directory and its contents.
 
-* _ModelMetadataStore_ and _ModelMetaDataFetch_ Backend modules that
-  are able to retrieve metadata from online model repositories,
-  transform them into Pydantic models, and cache them to the InvokeAI
-  SQL database.
-
 * _DownloadQueueServiceBase_
   A multithreaded downloader responsible
   for downloading models from a remote source to disk. The download
@@ -382,17 +377,14 @@ functionality:
 
 * Downloading a model from an arbitrary URL and installing it in
   `models_dir`.
 
-* Special handling for Civitai model URLs which allow the user to
-  paste in a model page's URL or download link
-
 * Special handling for HuggingFace repo_ids to recursively download
   the contents of the repository, paying attention to alternative
   variants such as fp16.
 
 * Saving tags and other metadata about the model into the invokeai database
   when fetching from a repo that provides that type of information,
-  (currently only Civitai and HuggingFace).
+  (currently only HuggingFace).
@@ -436,7 +428,6 @@ required parameters:
 | `app_config` | InvokeAIAppConfig | InvokeAI app configuration object |
 | `record_store` | ModelRecordServiceBase | Config record storage database |
 | `download_queue` | DownloadQueueServiceBase | Download queue object |
-| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
 | `session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |
 
 Once initialized, the installer will provide the following methods:
@@ -580,33 +571,7 @@ The `AnyHttpUrl` class can be imported from `pydantic.networks`.
 
 Ordinarily, no metadata is retrieved from these sources. However,
 there is special-case code in the installer that looks for HuggingFace
-and Civitai URLs and fetches the corresponding model metadata from
-the corresponding repo.
-
-#### CivitaiModelSource
-
-This is used for a model that is hosted by the Civitai web site.
-
-| **Argument** | **Type** | **Default** | **Description** |
-|------------------|------------------------------|-------------|-------------------------------------------|
-| `version_id` | int | None | The ID of the particular version of the desired model. |
-| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
-
-Civitai has two model IDs, both of which are integers. The `model_id`
-corresponds to a collection of model versions that may different in
-arbitrary ways, such as derivation from different checkpoint training
-steps, SFW vs NSFW generation, pruned vs non-pruned, etc. The
-`version_id` points to a specific version. Please use the latter.
-
-Some Civitai models require an access token to download. These can be
-generated from the Civitai profile page of a logged-in
-account. Somewhat annoyingly, if you fail to provide the access token
-when downloading a model that needs it, Civitai generates a redirect
-to a login page rather than a 403 Forbidden error. The installer
-attempts to catch this event and issue an informative error
-message. Otherwise you will get an "unrecognized model suffix" error
-when the model prober tries to identify the type of the HTML login
-page.
+and fetches the corresponding model metadata from the corresponding repo.
 
 #### HFModelSource
@@ -1253,9 +1218,9 @@ queue and have not yet reached a terminal state.
 
 The modules found under `invokeai.backend.model_manager.metadata`
 provide a straightforward API for fetching model metadata from online
-repositories. Currently two repositories are supported: HuggingFace
-and Civitai. However, the modules are easily extended for additional
-repos, provided that they have defined APIs for metadata access.
+repositories. Currently only HuggingFace is supported. However, the
+modules are easily extended for additional repos, provided that they
+have defined APIs for metadata access.
 
 Metadata comprises any descriptive information that is not essential
 for getting the model to run. For example "author" is metadata, while
@@ -1267,37 +1232,16 @@ model's config, as defined in `invokeai.backend.model_manager.config`.
 ```
 from invokeai.backend.model_manager.metadata import (
     AnyModelRepoMetadata,
-    CivitaiMetadataFetch,
-    CivitaiMetadata
-    ModelMetadataStore,
 )
-# to access the initialized sql database
-from invokeai.app.api.dependencies import ApiDependencies
 
-civitai = CivitaiMetadataFetch()
 hf = HuggingFaceMetadataFetch()
 
 # fetch the metadata
-model_metadata = civitai.from_url("https://civitai.com/models/215796")
 model_metadata = hf.from_id("<repo_id>")
 
 # get some common metadata fields
 author = model_metadata.author
 tags = model_metadata.tags
 
-# get some Civitai-specific fields
-assert isinstance(model_metadata, CivitaiMetadata)
-
-trained_words = model_metadata.trained_words
-base_model = model_metadata.base_model_trained_on
-thumbnail = model_metadata.thumbnail_url
-
-# cache the metadata to the database using the key corresponding to
-# an existing model config record in the `model_config` table
-sql_cache = ModelMetadataStore(ApiDependencies.invoker.services.db)
-sql_cache.add_metadata('fb237ace520b6716adc98bcb16e8462c', model_metadata)
-
-# now we can search the database by tag, author or model name
-# matches will contain a list of model keys that match the search
-matches = sql_cache.search_by_tag({"tool", "turbo"})
+assert isinstance(model_metadata, HuggingFaceMetadata)
 ```
@@ -1334,52 +1278,14 @@ This descends from `ModelMetadataBase` and adds the following fields:
 | `last_modified`| datetime | Date of last commit of this model to the repo |
 | `files` | List[Path] | List of the files in the model repo |
 
-#### `CivitaiMetadata`
-
-This descends from `ModelMetadataBase` and adds the following fields:
-
-| **Field Name** | **Type** | **Description** |
-|----------------|-----------------|------------------|
-| `type` | Literal["civitai"] | Used for the discriminated union of metadata classes|
-| `id` | int | Civitai model id |
-| `version_name` | str | Name of this version of the model (distinct from model name) |
-| `version_id` | int | Civitai model version id (distinct from model id) |
-| `created` | datetime | Date this version of the model was created |
-| `updated` | datetime | Date this version of the model was last updated |
-| `published` | datetime | Date this version of the model was published to Civitai |
-| `description` | str | Model description. Quite verbose and contains HTML tags |
-| `version_description` | str | Model version description, usually describes changes to the model |
-| `nsfw` | bool | Whether the model tends to generate NSFW content |
-| `restrictions` | LicenseRestrictions | An object that describes what is and isn't allowed with this model |
-| `trained_words`| Set[str] | Trigger words for this model, if any |
-| `download_url` | AnyHttpUrl | URL for downloading this version of the model |
-| `base_model_trained_on` | str | Name of the model that this version was trained on |
-| `thumbnail_url` | AnyHttpUrl | URL to access a representative thumbnail image of the model's output |
-| `weight_min` | int | For LoRA sliders, the minimum suggested weight to apply |
-| `weight_max` | int | For LoRA sliders, the maximum suggested weight to apply |
-
-Note that `weight_min` and `weight_max` are not currently populated
-and take the default values of (-1.0, +2.0). The issue is that these
-values aren't part of the structured data but appear in the text
-description. Some regular expression or LLM coding may be able to
-extract these values.
-
-Also be aware that `base_model_trained_on` is free text and doesn't
-correspond to our `ModelType` enum.
-
-`CivitaiMetadata` also defines some convenience properties relating to
-licensing restrictions: `credit_required`, `allow_commercial_use`,
-`allow_derivatives` and `allow_different_license`.
-
 #### `AnyModelRepoMetadata`
 
-This is a discriminated Union of `CivitaiMetadata` and
-`HuggingFaceMetadata`.
+This is a discriminated Union of `HuggingFaceMetadata`.
 
 ### Fetching Metadata from Online Repos
 
-The `HuggingFaceMetadataFetch` and `CivitaiMetadataFetch` classes will
-retrieve metadata from their corresponding repositories and return
+The `HuggingFaceMetadataFetch` class will
+retrieve metadata from its corresponding repository and return
 `AnyModelRepoMetadata` objects. Its base class
 `ModelMetadataFetchBase` is an abstract class that defines two
 methods: `from_url()` and `from_id()`. The former accepts the type of
@@ -1397,96 +1303,17 @@ provide a `requests.Session` argument. This allows you to customize
 the low-level HTTP fetch requests and is used, for instance, in the
 testing suite to avoid hitting the internet.
 
-The HuggingFace and Civitai fetcher subclasses add additional
-repo-specific fetching methods:
+The HuggingFace fetcher subclass adds additional repo-specific fetching methods:
 
 #### HuggingFaceMetadataFetch
 
 This overrides its base class `from_json()` method to return a
 `HuggingFaceMetadata` object directly.
 
-#### CivitaiMetadataFetch
-
-This adds the following methods:
-
-`from_civitai_modelid()` This takes the ID of a model, finds the
-default version of the model, and then retrieves the metadata for
-that version, returning a `CivitaiMetadata` object directly.
-
-`from_civitai_versionid()` This takes the ID of a model version and
-retrieves its metadata. Functionally equivalent to `from_id()`, the
-only difference is that it returna a `CivitaiMetadata` object rather
-than an `AnyModelRepoMetadata`.
-
-### Metadata Storage
-
-The `ModelMetadataStore` provides a simple facility to store model
-metadata in the `invokeai.db` database. The data is stored as a JSON
-blob, with a few common fields (`name`, `author`, `tags`) broken out
-to be searchable.
-
-When a metadata object is saved to the database, it is identified
-using the model key, _and this key must correspond to an existing
-model key in the model_config table_. There is a foreign key integrity
-constraint between the `model_config.id` field and the
-`model_metadata.id` field such that if you attempt to save metadata
-under an unknown key, the attempt will result in an
-`UnknownModelException`. Likewise, when a model is deleted from
-`model_config`, the deletion of the corresponding metadata record will
-be triggered.
-
-Tags are stored in a normalized fashion in the tables `model_tags` and
-`tags`. Triggers keep the tag table in sync with the `model_metadata`
-table.
-
-To create the storage object, initialize it with the InvokeAI
-`SqliteDatabase` object. This is often done this way:
-
-```
-from invokeai.app.api.dependencies import ApiDependencies
-metadata_store = ModelMetadataStore(ApiDependencies.invoker.services.db)
-```
-
-You can then access the storage with the following methods:
-
-#### `add_metadata(key, metadata)`
-
-Add the metadata using a previously-defined model key.
-
-There is currently no `delete_metadata()` method. The metadata will
-persist until the matching config is deleted from the `model_config`
-table.
-
-#### `get_metadata(key) -> AnyModelRepoMetadata`
-
-Retrieve the metadata corresponding to the model key.
-
-#### `update_metadata(key, new_metadata)`
-
-Update an existing metadata record with new metadata.
-
-#### `search_by_tag(tags: Set[str]) -> Set[str]`
-
-Given a set of tags, find models that are tagged with them. If
-multiple tags are provided then a matching model must be tagged with
-*all* the tags in the set. This method returns a set of model keys and
-is intended to be used in conjunction with the `ModelRecordService`:
-
-```
-model_config_store = ApiDependencies.invoker.services.model_records
-matches = metadata_store.search_by_tag({'license:other'})
-models = [model_config_store.get(x) for x in matches]
-```
-
-#### `search_by_name(name: str) -> Set[str]
-
-Find all model metadata records that have the given name and return a
-set of keys to the corresponding model config objects.
-
-#### `search_by_author(author: str) -> Set[str]
-
-Find all model metadata records that have the given author and return
-a set of keys to the corresponding model config objects.
+The `ModelConfigBase` stores this response in the `source_api_response` field
+as a JSON blob.
 
 ***
````
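Since `ModelMetadataFetchBase` accepts a `requests.Session`, the same hook the test suite uses to avoid the network can also tune networking in real code. A small sketch, assuming the fetcher takes the session as its constructor argument (per the docs above); the retry policy is illustrative:

```python
# Sketch: fetch HuggingFace metadata through a session with retries.
# Assumes HuggingFaceMetadataFetch accepts a requests.Session, as the
# docs above describe; the retry settings are an illustrative choice.
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch

session = requests.Session()
retries = Retry(total=3, backoff_factor=0.5, status_forcelist=[429, 500, 502, 503])
session.mount("https://", HTTPAdapter(max_retries=retries))

fetcher = HuggingFaceMetadataFetch(session)
metadata = fetcher.from_id("stabilityai/sd-turbo")  # any repo_id works here
print(metadata.author, sorted(metadata.tags))
```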
`docs/contributing/frontend/OVERVIEW.md` (new file, 133 lines). The full file:

````markdown
# Invoke UI

Invoke's UI is made possible by many contributors and open-source libraries. Thank you!

## Dev environment

### Setup

1. Install [node] and [pnpm].
1. Run `pnpm i` to install all packages.

#### Run in dev mode

1. From `invokeai/frontend/web/`, run `pnpm dev`.
1. From repo root, run `python scripts/invokeai-web.py`.
1. Point your browser to the dev server address, e.g. <http://localhost:5173/>

### Package scripts

- `dev`: run the frontend in dev mode, enabling hot reloading
- `build`: run all checks (madge, eslint, prettier, tsc) and then build the frontend
- `typegen`: generate types from the OpenAPI schema (see [Type generation])
- `lint:dpdm`: check circular dependencies
- `lint:eslint`: check code quality
- `lint:prettier`: check code formatting
- `lint:tsc`: check type issues
- `lint:knip`: check for unused exports or objects (failures here are just suggestions, not hard fails)
- `lint`: run all checks concurrently
- `fix`: run `eslint` and `prettier`, fixing fixable issues

### Type generation

We use [openapi-typescript] to generate types from the app's OpenAPI schema.

The generated types are committed to the repo in [schema.ts].

```sh
# from the repo root, start the server
python scripts/invokeai-web.py
# from invokeai/frontend/web/, run the script
pnpm typegen
```

### Localization

We use [i18next] for localization, but translation to languages other than English happens on our [Weblate] project.

Only the English source strings should be changed on this repo.

### VSCode

#### Example debugger config

```jsonc
{
  "version": "0.2.0",
  "configurations": [
    {
      "type": "chrome",
      "request": "launch",
      "name": "Invoke UI",
      "url": "http://localhost:5173",
      "webRoot": "${workspaceFolder}/invokeai/frontend/web"
    }
  ]
}
```

#### Remote dev

We've noticed an intermittent timeout issue with the VSCode remote dev port forwarding.

We suggest disabling the editor's port forwarding feature and doing it manually via SSH:

```sh
ssh -L 9090:localhost:9090 -L 5173:localhost:5173 user@host
```

## Contributing Guidelines

Thanks for your interest in contributing to the Invoke Web UI!

Please follow these guidelines when contributing.

### Check in before investing your time

Please check in before you invest your time on anything besides a trivial fix, in case it conflicts with ongoing work or isn't aligned with the vision for the app.

If a feature request or issue doesn't already exist for the thing you want to work on, please create one.

Ping `@psychedelicious` on [discord] in the `#frontend-dev` channel or in the feature request / issue you want to work on - we're happy to chat.

### Code conventions

- This is a fairly complex app with a deep component tree. Please use memoization (`useCallback`, `useMemo`, `memo`) with enthusiasm.
- If you need to add some global, ephemeral state, please use [nanostores] if possible.
- Be careful with your redux selectors. If they need to be parameterized, consider creating them inside a `useMemo`.
- Feel free to use `lodash` (via `lodash-es`) to make the intent of your code clear.
- Please add comments describing the "why", not the "how" (unless it is really arcane).

### Commit format

Please use the [conventional commits] spec for the web UI, with a scope of "ui":

- `chore(ui): bump deps`
- `chore(ui): lint`
- `feat(ui): add some cool new feature`
- `fix(ui): fix some bug`

### Submitting a PR

- Ensure your branch is tidy. Use an interactive rebase to clean up the commit history and reword the commit messages if they are not descriptive.
- Run `pnpm lint`. Some issues are auto-fixable with `pnpm fix`.
- Fill out the PR form when creating the PR.
  - It doesn't need to be super detailed, but a screenshot or video is nice if you changed something visually.
  - If a section isn't relevant, delete it. There are no UI tests at this time.

## Other docs

- [Workflows - Design and Implementation]
- [State Management]

[node]: https://nodejs.org/en/download/
[pnpm]: https://github.com/pnpm/pnpm
[discord]: https://discord.gg/ZmtBAhwWhy
[i18next]: https://github.com/i18next/react-i18next
[Weblate]: https://hosted.weblate.org/engage/invokeai/
[openapi-typescript]: https://github.com/drwpow/openapi-typescript
[Type generation]: #type-generation
[schema.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/services/api/schema.ts
[conventional commits]: https://www.conventionalcommits.org/en/v1.0.0/
[Workflows - Design and Implementation]: ./WORKFLOWS.md
[State Management]: ./STATE_MGMT.md
````
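One note on `typegen`: it needs the server's OpenAPI schema to be reachable. A quick way to confirm that before running it (the `/openapi.json` path is FastAPI's default and an assumption here):

```python
# Sketch: confirm the dev server is serving an OpenAPI schema before
# running `pnpm typegen`. /openapi.json is FastAPI's default schema path.
import requests

schema = requests.get("http://localhost:9090/openapi.json", timeout=5).json()
print(schema["info"]["title"], "-", len(schema.get("paths", {})), "routes")
```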
`docs/contributing/frontend/WORKFLOWS.md`:

````diff
@@ -1,40 +1,5 @@
 # Workflows - Design and Implementation
 
-<!-- @import "[TOC]" {cmd="toc" depthFrom=1 depthTo=6 orderedList=false} -->
-
-<!-- code_chunk_output -->
-
-- [Workflows - Design and Implementation](#workflows---design-and-implementation)
-- [Design](#design)
-- [Linear UI](#linear-ui)
-- [Workflow Editor](#workflow-editor)
-- [Workflows](#workflows)
-- [Workflow -> reactflow state -> InvokeAI graph](#workflow---reactflow-state---invokeai-graph)
-- [Nodes vs Invocations](#nodes-vs-invocations)
-- [Workflow Linear View](#workflow-linear-view)
-- [OpenAPI Schema](#openapi-schema)
-- [Field Instances and Templates](#field-instances-and-templates)
-- [Stateful vs Stateless Fields](#stateful-vs-stateless-fields)
-- [Collection and Polymorphic Fields](#collection-and-polymorphic-fields)
-- [Implementation](#implementation)
-- [zod Schemas and Types](#zod-schemas-and-types)
-- [OpenAPI Schema Parsing](#openapi-schema-parsing)
-- [Parsing Field Types](#parsing-field-types)
-- [Primitive Types](#primitive-types)
-- [Complex Types](#complex-types)
-- [Collection Types](#collection-types)
-- [Collection or Scalar Types](#collection-or-scalar-types)
-- [Optional Fields](#optional-fields)
-- [Building Field Input Templates](#building-field-input-templates)
-- [Building Field Output Templates](#building-field-output-templates)
-- [Managing reactflow State](#managing-reactflow-state)
-- [Building Nodes and Edges](#building-nodes-and-edges)
-- [Building a Workflow](#building-a-workflow)
-- [Loading a Workflow](#loading-a-workflow)
-- [Workflow Migrations](#workflow-migrations)
-
-<!-- /code_chunk_output -->
-
 > This document describes, at a high level, the design and implementation of workflows in the InvokeAI frontend. There are a substantial number of implementation details not included, but which are hopefully clear from the code.
 
 InvokeAI's backend uses graphs, composed of **nodes** and **edges**, to process data and generate images.
@@ -152,13 +117,13 @@ Stateless fields do not store their value in the node, so their field instances
 
 "Custom" fields will always be treated as stateless fields.
 
-##### Collection and Polymorphic Fields
+##### Collection and Scalar Fields
 
-Field types have a name and two flags which may identify it as a **collection** or **polymorphic** field.
+Field types have a name and two flags which may identify it as a **collection** or **collection or scalar** field.
 
-If a field is annotated in python as a list, its field type is parsed and flagged as a collection type (e.g. `list[int]`).
+If a field is annotated in python as a list, its field type is parsed and flagged as a **collection** type (e.g. `list[int]`).
 
-If it is annotated as a union of a type and list, the type will be flagged as a polymorphic type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
+If it is annotated as a union of a type and list, the type will be flagged as a **collection or scalar** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
 
 ## Implementation
 
@@ -338,13 +303,13 @@ Migration logic is in [migrations.ts].
 [reactflow]: https://github.com/xyflow/xyflow 'reactflow'
 [reactflow-concepts]: https://reactflow.dev/learn/concepts/terms-and-definitions
 [reactflow-events]: https://reactflow.dev/api-reference/react-flow#event-handlers
-[buildWorkflow.ts]: ../src/features/nodes/util/workflow/buildWorkflow.ts
-[nodesSlice.ts]: ../src/features/nodes/store/nodesSlice.ts
-[buildLinearTextToImageGraph.ts]: ../src/features/nodes/util/graph/buildLinearTextToImageGraph.ts
-[buildNodesGraph.ts]: ../src/features/nodes/util/graph/buildNodesGraph.ts
-[buildInvocationNode.ts]: ../src/features/nodes/util/node/buildInvocationNode.ts
-[validateWorkflow.ts]: ../src/features/nodes/util/workflow/validateWorkflow.ts
-[migrations.ts]: ../src/features/nodes/util/workflow/migrations.ts
-[parseSchema.ts]: ../src/features/nodes/util/schema/parseSchema.ts
-[buildFieldInputTemplate.ts]: ../src/features/nodes/util/schema/buildFieldInputTemplate.ts
-[buildFieldOutputTemplate.ts]: ../src/features/nodes/util/schema/buildFieldOutputTemplate.ts
+[buildWorkflow.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts
+[nodesSlice.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
+[buildLinearTextToImageGraph.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts
+[buildNodesGraph.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts
+[buildInvocationNode.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts
+[validateWorkflow.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts
+[migrations.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts
+[parseSchema.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts
+[buildFieldInputTemplate.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts
+[buildFieldOutputTemplate.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldOutputTemplate.ts
````
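The renamed collection / collection-or-scalar distinction comes straight from the python annotations on invocation fields. A minimal sketch of the three kinds, using an illustrative class rather than a real invocation:

```python
# Illustrative only: how field annotations map to the frontend's field kinds.
from typing import Union

from pydantic import BaseModel, Field


class ExampleFields(BaseModel):
    scalar: int = Field(default=0)                       # plain scalar field
    collection: list[int] = Field(default_factory=list)  # "collection" field
    either: Union[int, list[int]] = Field(default=0)     # "collection or scalar" field
    # Union[int, list[str]] or Union[int, str] would be rejected by the parser
```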
@@ -31,18 +31,18 @@ be referred to as ROOT.
|
||||
To find its root directory, InvokeAI uses the following recipe:
|
||||
|
||||
1. It first looks for the argument `--root <path>` on the command line
|
||||
it was launched from, and uses the indicated path if present.
|
||||
it was launched from, and uses the indicated path if present.
|
||||
|
||||
2. Next it looks for the environment variable INVOKEAI_ROOT, and uses
|
||||
the directory path found there if present.
|
||||
the directory path found there if present.
|
||||
|
||||
3. If neither of these are present, then InvokeAI looks for the
|
||||
folder containing the `.venv` Python virtual environment directory for
|
||||
the currently active environment. This directory is checked for files
|
||||
expected inside the InvokeAI root before it is used.
|
||||
folder containing the `.venv` Python virtual environment directory for
|
||||
the currently active environment. This directory is checked for files
|
||||
expected inside the InvokeAI root before it is used.
|
||||
|
||||
4. Finally, InvokeAI looks for a directory in the current user's home
|
||||
directory named `invokeai`.
|
||||
directory named `invokeai`.
|
||||
|
||||
#### Reading the InvokeAI Configuration File
|
||||
|
||||
@@ -149,104 +149,75 @@ usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS
|
||||
|
||||
## The Configuration Settings
|
||||
|
||||
The configuration settings are divided into several distinct
|
||||
groups in `invokeia.yaml`:
|
||||
The config is managed by the `InvokeAIAppConfig` class, which is a pydantic model. The below docs are autogenerated from the class.
|
||||
|
||||
### Web Server
|
||||
When editing your `invokeai.yaml` file, you'll need to put settings under their appropriate group. The group for each setting is denoted in the table below.
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|---------------------|---------------|----------------------------------------------------------------------------------------------------------------------------|
|
||||
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
|
||||
| `port` | `9090` | Network port number that the web server will listen on |
|
||||
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
|
||||
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
|
||||
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
|
||||
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
|
||||
| `ssl_certfile` | null | Path to an SSL certificate file, used to enable HTTPS. |
|
||||
| `ssl_keyfile` | null | Path to an SSL keyfile, if the key is not included in the certificate file. |
|
||||
Following the table are additional explanations for certain settings.
|
||||
|
||||
The documentation for InvokeAI's API can be accessed by browsing to the following URL: [http://localhost:9090/docs].
|
||||
<!-- prettier-ignore-start -->
|
||||
::: invokeai.app.services.config.config_default.InvokeAIAppConfig
|
||||
options:
|
||||
heading_level: 3
|
||||
members: false
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Features
|
||||
### Model Marketplace API Keys
|
||||
|
||||
These configuration settings allow you to enable and disable various InvokeAI features:
|
||||
Some model marketplaces require an API key to download models. You can provide a URL pattern and appropriate token in your `invokeai.yaml` file to provide that API key.
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|----------|----------------|--------------|
|
||||
| `esrgan` | `true` | Activate the ESRGAN upscaling options|
|
||||
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
|
||||
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
|
||||
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
|
||||
The pattern can be any valid regex (you may need to surround the pattern with quotes):
|
||||
|
||||
### Generation
|
||||
```yaml
|
||||
InvokeAI:
|
||||
Model Install:
|
||||
remote_api_tokens:
|
||||
# Any URL containing `models.com` will automatically use `your_models_com_token`
|
||||
- url_regex: models.com
|
||||
token: your_models_com_token
|
||||
# Any URL matching this contrived regex will use `some_other_token`
|
||||
- url_regex: '^[a-z]{3}whatever.*\.com$'
|
||||
token: some_other_token
|
||||
```
|
||||
|
||||
These options tune InvokeAI's memory and performance characteristics.
|
||||
The provided token will be added as a `Bearer` token to the network requests to download the model files. As far as we know, this works for all model marketplaces that require authorization.
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|-----------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss |
|
||||
| `attention_type` | `auto` | Select the type of attention to use. One of `auto`,`normal`,`xformers`,`sliced`, or `torch-sdp` |
|
||||
| `attention_slice_size` | `auto` | When "sliced" attention is selected, set the slice size. One of `auto`, `balanced`, `max` or the integers 1-8|
|
||||
| `force_tiled_decode` | `false` | Force the VAE step to decode in tiles, reducing memory consumption at the cost of performance |
|
||||
### Model Hashing
|
||||
|
||||
### Device
|
||||
Models are hashed during installation, providing a stable identifier for models across all platforms. The default algorithm is `blake3`, with a multi-threaded implementation.
|
||||
|
||||
These options configure the generation execution device.
|
||||
If your models are stored on a spinning hard drive, we suggest using `blake3_single`, the single-threaded implementation. The hashes are the same, but it's much faster on spinning disks.
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|-----------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `device` | `auto` | Preferred execution device. One of `auto`, `cpu`, `cuda`, `cuda:1`, `mps`. `auto` will choose the device depending on the hardware platform and the installed torch capabilities. |
|
||||
| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system |
|
||||
```yaml
|
||||
InvokeAI:
|
||||
Model Install:
|
||||
hashing_algorithm: blake3_single
|
||||
```
|
||||
|
||||
Model hashing is a one-time operation, but it may take a couple minutes to hash a large model collection. You may opt out of model hashing entirely by setting the algorithm to `random`.
|
||||
|
||||
```yaml
|
||||
InvokeAI:
|
||||
Model Install:
|
||||
hashing_algorithm: random
|
||||
```
|
||||
|
||||
Most common algorithms are supported, like `md5`, `sha256`, and `sha512`. These are typically much, much slower than `blake3`.
|
||||
|
||||
### Paths
|
||||
|
||||
These options set the paths of various directories and files used by
|
||||
InvokeAI. Relative paths are interpreted relative to INVOKEAI_ROOT, so
|
||||
if INVOKEAI_ROOT is `/home/fred/invokeai` and the path is
|
||||
InvokeAI. Relative paths are interpreted relative to the root directory, so
|
||||
if root is `/home/fred/invokeai` and the path is
|
||||
`autoimport/main`, then the corresponding directory will be located at
|
||||
`/home/fred/invokeai/autoimport/main`.
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|----------|----------------|--------------|
|
||||
| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory |
|
||||
| `lora_dir` | `autoimport/lora` | At startup time, read and import any LoRA/LyCORIS models found in this directory |
|
||||
| `embedding_dir` | `autoimport/embedding` | At startup time, read and import any textual inversion (embedding) models found in this directory |
|
||||
| `controlnet_dir` | `autoimport/controlnet` | At startup time, read and import any ControlNet models found in this directory |
|
||||
| `conf_path` | `configs/models.yaml` | Location of the `models.yaml` model configuration file |
|
||||
| `models_dir` | `models` | Location of the directory containing models installed by InvokeAI's model manager |
|
||||
| `legacy_conf_dir` | `configs/stable-diffusion` | Location of the directory containing the .yaml configuration files for legacy checkpoint models |
|
||||
| `db_dir` | `databases` | Location of the directory containing InvokeAI's image, schema and session database |
|
||||
| `outdir` | `outputs` | Location of the directory in which the gallery of generated and uploaded images will be stored |
|
||||
| `use_memory_db` | `false` | Keep database information in memory rather than on disk; this will not preserve image gallery information across restarts |
|
||||
|
||||
Note that the autoimport directories will be searched recursively,
|
||||
Note that the autoimport directory will be searched recursively,
|
||||
allowing you to organize the models into folders and subfolders in any
|
||||
way you wish. In addition, while we have split up autoimport
|
||||
directories by the type of model they contain, this isn't
|
||||
necessary. You can combine different model types in the same folder
|
||||
and InvokeAI will figure out what they are. So you can easily use just
|
||||
one autoimport directory by commenting out the unneeded paths:
|
||||
|
||||
```
|
||||
Paths:
|
||||
autoimport_dir: autoimport
|
||||
# lora_dir: null
|
||||
# embedding_dir: null
|
||||
# controlnet_dir: null
|
||||
```
|
||||
way you wish.
|
||||
|
||||
### Logging
|
||||
|
||||
These settings control the information, warning, and debugging
|
||||
messages printed to the console log while InvokeAI is running:
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|----------|----------------|--------------|
|
||||
| `log_handlers` | `console` | This controls where log messages are sent, and can be a list of one or more destinations. Values include `console`, `file`, `syslog` and `http`. These are described in more detail below |
|
||||
| `log_format` | `color` | This controls the formatting of the log messages. Values are `plain`, `color`, `legacy` and `syslog` |
|
||||
| `log_level` | `debug` | This filters messages according to the level of severity and can be one of `debug`, `info`, `warning`, `error` and `critical`. For example, setting to `warning` will display all messages at the warning level or higher, but won't display "debug" or "info" messages |
|
||||
|
||||
Several different log handler destinations are available, and multiple destinations are supported by providing a list:
|
||||
|
||||
```
@@ -256,9 +227,9 @@ Several different log handler destinations are available, and multiple destinati
- file=/var/log/invokeai.log
```

* `console` is the default. It prints log messages to the command-line window from which InvokeAI was launched.
- `console` is the default. It prints log messages to the command-line window from which InvokeAI was launched.

* `syslog` is only available on Linux and Macintosh systems. It uses
- `syslog` is only available on Linux and Macintosh systems. It uses
  the operating system's "syslog" facility to write log file entries
  locally or to a remote logging machine. `syslog` offers a variety
  of configuration options:
@@ -271,7 +242,7 @@ Several different log handler destinations are available, and multiple destinati
- Log to LAN-connected server "fredserver" using the facility LOG_USER and datagram packets.
```
* `http` can be used to log to a remote web server. The server must be
- `http` can be used to log to a remote web server. The server must be
  properly configured to receive and act on log messages. The option
  accepts the URL to the web server, and a `method` argument
  indicating whether the message should be submitted using the GET or
@@ -283,7 +254,7 @@ Several different log handler destinations are available, and multiple destinati

The `log_format` option provides several alternative formats:

* `color` - default format providing time, date and a message, using text colors to distinguish different log severities
* `plain` - same as above, but monochrome text only
* `syslog` - the log level and error message only, allowing the syslog system to attach the time and date
* `legacy` - a format similar to the one used by the legacy 2.3 InvokeAI releases.
- `color` - default format providing time, date and a message, using text colors to distinguish different log severities
- `plain` - same as above, but monochrome text only
- `syslog` - the log level and error message only, allowing the syslog system to attach the time and date
- `legacy` - a format similar to the one used by the legacy 2.3 InvokeAI releases.
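For orientation, here is a rough sketch of what these destinations correspond to, using only Python's standard `logging` module. InvokeAI builds its handlers internally from `log_handlers`, so the classes and arguments below are illustrative stdlib analogues, not InvokeAI's actual wiring:

```python
# Rough stdlib analogues of the handler destinations described above.
# InvokeAI's own handler construction may differ; this is only a sketch.
import logging
import logging.handlers

logger = logging.getLogger("InvokeAI-example")
logger.setLevel(logging.DEBUG)

logger.addHandler(logging.StreamHandler())              # "console"
logger.addHandler(logging.FileHandler("invokeai.log"))  # "file=<path>"
# logger.addHandler(logging.handlers.SysLogHandler())   # "syslog" (Linux/macOS only)
# logger.addHandler(logging.handlers.HTTPHandler("fredserver:8080", "/log", method="GET"))  # "http=<url>"

logger.warning("generation complete")  # shown when log_level is "warning" or lower
```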
docs/features/DATABASE.md (new file, 35 lines)
@@ -0,0 +1,35 @@

---
title: Database
---

# Invoke's SQLite Database

Invoke uses a SQLite database to store image, workflow, model, and execution data.

We take great care to ensure your data is safe, by utilizing transactions and a database migration system.

Even so, when testing a prerelease version of the app, we strongly suggest either backing up your database or using an in-memory database. This ensures any prerelease hiccups or database schema changes will not cause problems for your data.

## Database Backup

Backing up your database is very simple. Invoke's data is stored in an `$INVOKEAI_ROOT` directory - where your `invoke.sh`/`invoke.bat` and `invokeai.yaml` files live.

To back up your database, copy the `invokeai.db` file from `$INVOKEAI_ROOT/databases/invokeai.db` to somewhere safe.

If anything comes up during prerelease testing, you can simply copy your backup back into `$INVOKEAI_ROOT/databases/`.
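The copy can also be scripted. A minimal sketch, assuming the `INVOKEAI_ROOT` environment variable points at your install (the `.bak` suffix is just an illustrative choice):

```python
# Minimal backup sketch; INVOKEAI_ROOT and the .bak suffix are assumptions.
import os
import shutil
from pathlib import Path

root = Path(os.environ.get("INVOKEAI_ROOT", "."))
db = root / "databases" / "invokeai.db"
backup = db.with_name(db.name + ".bak")  # invokeai.db.bak, next to the original

shutil.copy2(db, backup)  # copy2 preserves timestamps
print(f"Backed up {db} -> {backup}")
```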
## In-Memory Database

SQLite can run on an in-memory database. Your existing database is untouched when this mode is enabled, but your existing data won't be accessible.

This is very useful for testing, as there is no chance of a database change modifying your "physical" database.

To run Invoke with a memory database, edit your `invokeai.yaml` file, and add `use_memory_db: true` to the `Development:` stanza:

```yaml
InvokeAI:
  Development:
    use_memory_db: true
```

Delete this line (or set it to `false`) to use your main database.
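The in-memory behavior itself is standard SQLite rather than anything InvokeAI-specific; the snippet below illustrates it with Python's built-in `sqlite3` module:

```python
# Standard SQLite in-memory mode: the database lives only in RAM
# and is discarded when the connection closes.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE images (name TEXT)")
conn.execute("INSERT INTO images VALUES ('test.png')")
print(conn.execute("SELECT name FROM images").fetchone())  # ('test.png',)
conn.close()  # everything vanishes here
```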
@@ -22,6 +22,24 @@ class MyInvocation(BaseInvocation):
    ...
```

The full API is documented below.

## Invocation Mixins

Two important mixins are provided to facilitate working with metadata and gallery boards.

### `WithMetadata`

Inherit from this class (in addition to `BaseInvocation`) to add a `metadata` input to your node. When you do this, you can access the metadata dict from `self.metadata` in the `invoke()` function.

The dict will be populated via the node's input, and you can add any metadata you'd like to it. When you call `context.images.save()`, if the metadata dict has any data, it will be automatically embedded in the image.

### `WithBoard`

Inherit from this class (in addition to `BaseInvocation`) to add a `board` input to your node. This renders as a drop-down to select a board. The user's selection will be accessible from `self.board` in the `invoke()` function.

When you call `context.images.save()`, if a board was selected, the image will be added to that board as it is saved. A minimal example of a node using both mixins follows.
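This sketch relies only on the API described above; the node name, title, and generated image are illustrative, and `ImageOutput.build` is assumed as the standard way to construct the output:

```python
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation("my_blank_image", title="My Blank Image", version="1.0.0")
class MyBlankImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Saves a blank image; metadata and board are handled by the mixins."""

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = Image.new("RGB", (64, 64))
        # `self.metadata` and `self.board` come from the mixins; save() embeds
        # the metadata (if any) and adds the image to the selected board.
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)
```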
<!-- prettier-ignore-start -->
::: invokeai.app.services.shared.invocation_context.InvocationContext
    options:
@@ -11,7 +11,7 @@ from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, ConfigDict, Field
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException
from typing_extensions import Annotated

@@ -29,6 +29,8 @@ from invokeai.backend.model_manager.config import (
    ModelType,
    SubModelType,
)
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch

from ..dependencies import ApiDependencies

@@ -246,6 +248,40 @@ async def scan_for_models(
    return scan_results


class HuggingFaceModels(BaseModel):
    urls: List[AnyHttpUrl] | None = Field(description="URLs for all checkpoint format models in the metadata")
    is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model")


@model_manager_router.get(
    "/hugging_face",
    operation_id="get_hugging_face_models",
    responses={
        200: {"description": "Hugging Face repo scanned successfully"},
        400: {"description": "Invalid hugging face repo"},
    },
    status_code=200,
    response_model=HuggingFaceModels,
)
async def get_hugging_face_models(
    hugging_face_repo: str = Query(description="Hugging face repo to search for models", default=None),
) -> HuggingFaceModels:
    try:
        metadata = HuggingFaceMetadataFetch().from_id(hugging_face_repo)
    except UnknownMetadataException:
        raise HTTPException(
            status_code=400,
            detail="No HuggingFace repository found",
        )

    assert isinstance(metadata, ModelMetadataWithFiles)

    return HuggingFaceModels(
        urls=metadata.ckpt_urls,
        is_diffusers=metadata.is_diffusers,
    )


@model_manager_router.patch(
    "/i/{key}",
    operation_id="update_model_record",
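The same fetch can be exercised outside the route. A hedged sketch reusing only the calls shown above (the helper name is illustrative):

```python
from typing import Optional

from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import (
    ModelMetadataWithFiles,
    UnknownMetadataException,
)


def list_ckpt_urls(repo_id: str) -> Optional[list]:
    """Return checkpoint URLs for a Hugging Face repo, or None if it has no metadata."""
    try:
        metadata = HuggingFaceMetadataFetch().from_id(repo_id)
    except UnknownMetadataException:
        return None
    assert isinstance(metadata, ModelMetadataWithFiles)
    return metadata.ckpt_urls
```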
@@ -3,6 +3,8 @@
# values from the command line or config file.
import sys

from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.version.invokeai_version import __version__

from .services.config import InvokeAIAppConfig
@@ -156,17 +158,19 @@ def custom_openapi() -> dict[str, Any]:
        openapi_schema["components"]["schemas"][schema_key] = output_schema
        openapi_schema["components"]["schemas"][schema_key]["class"] = "output"

    # Add Node Editor UI helper schemas
    ui_config_schemas = models_json_schema(
    # Some models don't end up in the schemas as standalone definitions
    additional_schemas = models_json_schema(
        [
            (UIConfigBase, "serialization"),
            (InputFieldJSONSchemaExtra, "serialization"),
            (OutputFieldJSONSchemaExtra, "serialization"),
            (ModelIdentifierField, "serialization"),
            (ProgressImage, "serialization"),
        ],
        ref_template="#/components/schemas/{model}",
    )
    for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items():
        openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
    for schema_key, schema_json in additional_schemas[1]["$defs"].items():
        openapi_schema["components"]["schemas"][schema_key] = schema_json

    # Add a reference to the output type to additionalProperties of the invoker schema
    for invoker in all_invocations:
@@ -31,10 +31,11 @@ from invokeai.app.invocations.fields import (
    Input,
    InputField,
    OutputField,
    UIType,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -54,7 +55,7 @@ CONTROLNET_RESIZE_VALUES = Literal[

class ControlField(BaseModel):
    image: ImageField = Field(description="The control image")
    control_model: ModelField = Field(description="The ControlNet model to use")
    control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
    control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
    begin_step_percent: float = Field(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@@ -90,7 +91,9 @@ class ControlNetInvocation(BaseInvocation):
    """Collects ControlNet info to pass to other nodes"""

    image: ImageField = InputField(description="The control image")
    control_model: ModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
    control_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel
    )
    control_weight: Union[float, List[float]] = InputField(
        default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
    )
@@ -571,7 +574,7 @@ DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
    title="Depth Anything Processor",
    tags=["controlnet", "depth", "depth anything"],
    category="controlnet",
    version="1.0.0",
    version="1.0.1",
)
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
    """Generates a depth map based on the Depth Anything algorithm"""
@@ -580,13 +583,12 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
        default="small", description="The size of the depth model to use"
    )
    resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
    offload: bool = InputField(default=False)

    def run_processor(self, image: Image.Image):
        depth_anything_detector = DepthAnythingDetector()
        depth_anything_detector.load_model(model_size=self.model_size)

        processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload)
        processed_image = depth_anything_detector(image=image, resolution=self.resolution)
        return processed_image
@@ -39,13 +39,15 @@ class UIType(str, Enum, metaclass=MetaEnum):
    """

    # region Model Field Types
    MainModel = "MainModelField"
    SDXLMainModel = "SDXLMainModelField"
    SDXLRefinerModel = "SDXLRefinerModelField"
    ONNXModel = "ONNXModelField"
    VaeModel = "VAEModelField"
    VAEModel = "VAEModelField"
    LoRAModel = "LoRAModelField"
    ControlNetModel = "ControlNetModelField"
    IPAdapterModel = "IPAdapterModelField"
    T2IAdapterModel = "T2IAdapterModelField"
    # endregion

    # region Misc Field Types
@@ -86,7 +88,6 @@ class UIType(str, Enum, metaclass=MetaEnum):
    IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
    LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
    StringPolymorphic = "DEPRECATED_StringPolymorphic"
    MainModel = "DEPRECATED_MainModel"
    UNet = "DEPRECATED_UNet"
    Vae = "DEPRECATED_Vae"
    CLIP = "DEPRECATED_CLIP"
@@ -10,8 +10,8 @@ from invokeai.app.invocations.baseinvocation import (
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -20,8 +20,8 @@ from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig

class IPAdapterField(BaseModel):
    image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
    ip_adapter_model: ModelField = Field(description="The IP-Adapter model to use.")
    image_encoder_model: ModelField = Field(description="The name of the CLIP image encoder model.")
    ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
    image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
    weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
    begin_step_percent: float = Field(
        default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
@@ -54,8 +54,12 @@ class IPAdapterInvocation(BaseInvocation):

    # Inputs
    image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
    ip_adapter_model: ModelField = InputField(
        description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
    ip_adapter_model: ModelIdentifierField = InputField(
        description="The IP-Adapter model.",
        title="IP-Adapter Model",
        input=Input.Direct,
        ui_order=-1,
        ui_type=UIType.IPAdapterModel,
    )

    weight: Union[float, List[float]] = InputField(
@@ -93,7 +97,7 @@ class IPAdapterInvocation(BaseInvocation):
            ip_adapter=IPAdapterField(
                image=self.image,
                ip_adapter_model=self.ip_adapter_model,
                image_encoder_model=ModelField(key=image_encoder_models[0].key),
                image_encoder_model=ModelIdentifierField.from_config(image_encoder_models[0]),
                weight=self.weight,
                begin_step_percent=self.begin_step_percent,
                end_step_percent=self.end_step_percent,
@@ -66,7 +66,6 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
    T2IAdapterData,
    image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device
from .baseinvocation import (
@@ -76,7 +75,7 @@ from .baseinvocation import (
    invocation_output,
)
from .controlnet_image_processors import ControlField
from .model import ModelField, UNetField, VAEField
from .model import ModelIdentifierField, UNetField, VAEField

if choose_torch_device() == torch.device("mps"):
    from torch import mps
@@ -245,7 +244,7 @@ class CreateGradientMaskInvocation(BaseInvocation):

def get_scheduler(
    context: InvocationContext,
    scheduler_info: ModelField,
    scheduler_info: ModelIdentifierField,
    scheduler_name: str,
    seed: int,
) -> Scheduler:
@@ -384,12 +383,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
            text_embeddings=c,
            guidance_scale=self.cfg_scale,
            guidance_rescale_multiplier=self.cfg_rescale_multiplier,
            postprocessing_settings=PostprocessingSettings(
                threshold=0.0,  # threshold,
                warmup=0.2,  # warmup,
                h_symmetry_time_pct=None,  # h_symmetry_time_pct,
                v_symmetry_time_pct=None,  # v_symmetry_time_pct,
            ),
        )

        conditioning_data = conditioning_data.add_scheduler_args_if_applicable(  # FIXME
@@ -684,7 +677,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
        if self.denoise_mask.masked_latents_name is not None:
            masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
        else:
            masked_latents = None
            masked_latents = torch.where(mask < 0.5, 0.0, latents)

        return 1 - mask, masked_latents, self.denoise_mask.gradient
@@ -844,14 +837,14 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
        latents = context.tensors.load(self.latents.latents_name)

        vae_info = context.models.load(self.vae.vae)
        assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL))
        assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
        with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
            assert isinstance(vae, torch.nn.Module)
            latents = latents.to(vae.device)
            if self.fp32:
                vae.to(dtype=torch.float32)

                use_torch_2_0_or_xformers = isinstance(
                use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
                    vae.decoder.mid_block.attentions[0].processor,
                    (
                        AttnProcessor2_0,
@@ -1025,7 +1018,7 @@ class ImageToLatentsInvocation(BaseInvocation):
        if upcast:
            vae.to(dtype=torch.float32)

            use_torch_2_0_or_xformers = isinstance(
            use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
                vae.decoder.mid_block.attentions[0].processor,
                (
                    AttnProcessor2_0,
@@ -20,8 +20,8 @@ from invokeai.app.invocations.fields import (
    OutputField,
    UIType,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, ModelType

from ...version import __version__
@@ -31,20 +31,10 @@ class MetadataItemField(BaseModel):
    value: Any = Field(description=FieldDescriptions.metadata_item_value)


class ModelMetadataField(BaseModel):
    """Model Metadata Field"""

    key: str
    hash: str
    name: str
    base: BaseModelType
    type: ModelType


class LoRAMetadataField(BaseModel):
    """LoRA Metadata Field"""

    model: ModelMetadataField = Field(description=FieldDescriptions.lora_model)
    model: ModelIdentifierField = Field(description=FieldDescriptions.lora_model)
    weight: float = Field(description=FieldDescriptions.lora_weight)
@@ -52,19 +42,16 @@ class IPAdapterMetadataField(BaseModel):
    """IP Adapter Field, minus the CLIP Vision Encoder model"""

    image: ImageField = Field(description="The IP-Adapter image prompt.")
    ip_adapter_model: ModelMetadataField = Field(
        description="The IP-Adapter model.",
    )
    weight: Union[float, list[float]] = Field(
        description="The weight given to the IP-Adapter",
    )
    ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
    weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
    begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
    end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")


class T2IAdapterMetadataField(BaseModel):
    image: ImageField = Field(description="The T2I-Adapter image prompt.")
    t2i_adapter_model: ModelMetadataField = Field(description="The T2I-Adapter model to use.")
    image: ImageField = Field(description="The control image.")
    processed_image: Optional[ImageField] = Field(default=None, description="The control image, after processing.")
    t2i_adapter_model: ModelIdentifierField = Field(description="The T2I-Adapter model to use.")
    weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
    begin_step_percent: float = Field(
        default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
@@ -77,7 +64,8 @@ class T2IAdapterMetadataField(BaseModel):

class ControlNetMetadataField(BaseModel):
    image: ImageField = Field(description="The control image")
    control_model: ModelMetadataField = Field(description="The ControlNet model to use")
    processed_image: Optional[ImageField] = Field(default=None, description="The control image, after processing.")
    control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
    control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
    begin_step_percent: float = Field(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@@ -178,7 +166,7 @@ class CoreMetadataInvocation(BaseInvocation):
        default=None,
        description="The number of skipped CLIP layers",
    )
    model: Optional[ModelMetadataField] = InputField(default=None, description="The main model used for inference")
    model: Optional[ModelIdentifierField] = InputField(default=None, description="The main model used for inference")
    controlnets: Optional[list[ControlNetMetadataField]] = InputField(
        default=None, description="The ControlNets used for inference"
    )
@@ -197,7 +185,7 @@ class CoreMetadataInvocation(BaseInvocation):
        default=None,
        description="The name of the initial image",
    )
    vae: Optional[ModelMetadataField] = InputField(
    vae: Optional[ModelIdentifierField] = InputField(
        default=None,
        description="The VAE used for decoding, if the main model's default was not used",
    )
@@ -228,7 +216,7 @@ class CoreMetadataInvocation(BaseInvocation):
    )

    # SDXL Refiner
    refiner_model: Optional[ModelMetadataField] = InputField(
    refiner_model: Optional[ModelIdentifierField] = InputField(
        default=None,
        description="The SDXL Refiner model used",
    )
@@ -3,10 +3,10 @@ from typing import List, Optional

from pydantic import BaseModel, Field

from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import SubModelType
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType

from .baseinvocation import (
    BaseInvocation,
@@ -16,33 +16,52 @@ from .baseinvocation import (
)

class ModelField(BaseModel):
    key: str = Field(description="Key of the model")
    submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None)
class ModelIdentifierField(BaseModel):
    key: str = Field(description="The model's unique key")
    hash: str = Field(description="The model's BLAKE3 hash")
    name: str = Field(description="The model's name")
    base: BaseModelType = Field(description="The model's base model type")
    type: ModelType = Field(description="The model's type")
    submodel_type: Optional[SubModelType] = Field(
        description="The submodel to load, if this is a main model", default=None
    )

    @classmethod
    def from_config(
        cls, config: "AnyModelConfig", submodel_type: Optional[SubModelType] = None
    ) -> "ModelIdentifierField":
        return cls(
            key=config.key,
            hash=config.hash,
            name=config.name,
            base=config.base,
            type=config.type,
            submodel_type=submodel_type,
        )


class LoRAField(BaseModel):
    lora: ModelField = Field(description="Info to load lora model")
    lora: ModelIdentifierField = Field(description="Info to load lora model")
    weight: float = Field(description="Weight to apply to lora model")


class UNetField(BaseModel):
    unet: ModelField = Field(description="Info to load unet submodel")
    scheduler: ModelField = Field(description="Info to load scheduler submodel")
    unet: ModelIdentifierField = Field(description="Info to load unet submodel")
    scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
    loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
    seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
    freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")


class CLIPField(BaseModel):
    tokenizer: ModelField = Field(description="Info to load tokenizer submodel")
    text_encoder: ModelField = Field(description="Info to load text_encoder submodel")
    tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
    text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
    skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
    loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


class VAEField(BaseModel):
    vae: ModelField = Field(description="Info to load vae submodel")
    vae: ModelIdentifierField = Field(description="Info to load vae submodel")
    seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@@ -84,7 +103,9 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
class MainModelLoaderInvocation(BaseInvocation):
    """Loads a main model, outputting its submodels."""

    model: ModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
    )
    # TODO: precision?

    def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
@@ -117,7 +138,9 @@ class LoRALoaderOutput(BaseInvocationOutput):
class LoRALoaderInvocation(BaseInvocation):
    """Apply selected lora to unet and text_encoder."""

    lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
    lora: ModelIdentifierField = InputField(
        description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
    )
    weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
    unet: Optional[UNetField] = InputField(
        default=None,
@@ -186,7 +209,9 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
class SDXLLoRALoaderInvocation(BaseInvocation):
    """Apply selected lora to unet and text_encoder."""

    lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
    lora: ModelIdentifierField = InputField(
        description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
    )
    weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
    unet: Optional[UNetField] = InputField(
        default=None,
@@ -258,10 +283,8 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
class VAELoaderInvocation(BaseInvocation):
    """Loads a VAE model, outputting a VaeLoaderOutput"""

    vae_model: ModelField = InputField(
        description=FieldDescriptions.vae_model,
        input=Input.Direct,
        title="VAE",
    vae_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel
    )

    def invoke(self, context: InvocationContext) -> VAEOutput:
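For reference, the new `ModelIdentifierField.from_config()` classmethod above replaces hand-built `ModelField(key=...)` values. A hedged usage sketch, assuming `config` is an `AnyModelConfig` record obtained from the model manager (the wrapper function name is illustrative):

```python
from typing import Optional

from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType


def identifier_for(config: AnyModelConfig, submodel: Optional[SubModelType] = None) -> ModelIdentifierField:
    # Mirrors from_config(): copies key/hash/name/base/type from the config
    # record and optionally selects a submodel to load.
    return ModelIdentifierField.from_config(config, submodel_type=submodel)
```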
@@ -8,7 +8,7 @@ from .baseinvocation import (
    invocation,
    invocation_output,
)
from .model import CLIPField, ModelField, UNetField, VAEField
from .model import CLIPField, ModelIdentifierField, UNetField, VAEField


@invocation_output("sdxl_model_loader_output")
@@ -34,7 +34,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
class SDXLModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl base model, outputting its submodels."""

    model: ModelField = InputField(
    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
    )
    # TODO: precision?
@@ -72,7 +72,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl refiner model, outputting its submodels."""

    model: ModelField = InputField(
    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
    )
    # TODO: precision?
@@ -9,15 +9,15 @@ from invokeai.app.invocations.baseinvocation import (
    invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext


class T2IAdapterField(BaseModel):
    image: ImageField = Field(description="The T2I-Adapter image prompt.")
    t2i_adapter_model: ModelField = Field(description="The T2I-Adapter model to use.")
    t2i_adapter_model: ModelIdentifierField = Field(description="The T2I-Adapter model to use.")
    weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
    begin_step_percent: float = Field(
        default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
@@ -52,11 +52,12 @@ class T2IAdapterInvocation(BaseInvocation):

    # Inputs
    image: ImageField = InputField(description="The IP-Adapter image prompt.")
    t2i_adapter_model: ModelField = InputField(
    t2i_adapter_model: ModelIdentifierField = InputField(
        description="The T2I-Adapter model.",
        title="T2I-Adapter Model",
        input=Input.Direct,
        ui_order=-1,
        ui_type=UIType.T2IAdapterModel,
    )
    weight: Union[float, list[float]] = InputField(
        default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"
@@ -17,7 +17,8 @@ from argparse import ArgumentParser
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints

from omegaconf import DictConfig, ListConfig, OmegaConf
from omegaconf import DictConfig, DictKeyType, ListConfig, OmegaConf
from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict

from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
@@ -62,6 +63,22 @@ class InvokeAISettings(BaseSettings):
            assert isinstance(category, str)
            if category not in field_dict[type]:
                field_dict[type][category] = {}
            if isinstance(value, BaseModel):
                dump = value.model_dump(exclude_defaults=True, exclude_unset=True, exclude_none=True)
                field_dict[type][category][name] = dump
                continue
            if isinstance(value, list):
                if not value or len(value) == 0:
                    continue
                primitive = isinstance(value[0], get_args(DictKeyType))
                if not primitive:
                    val_list: List[Dict[str, Any]] = []
                    for list_val in value:
                        if isinstance(list_val, BaseModel):
                            dump = list_val.model_dump(exclude_defaults=True, exclude_unset=True, exclude_none=True)
                            val_list.append(dump)
                    field_dict[type][category][name] = val_list
                    continue
            # keep paths as strings to make it easier to read
            field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
        conf = OmegaConf.create(field_dict)
@@ -170,14 +170,17 @@ two configs are kept in separate sections of the config file:
from __future__ import annotations

import os
import re
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional

from omegaconf import DictConfig, OmegaConf
from pydantic import Field
from pydantic import BaseModel, Field, field_validator
from pydantic.config import JsonDict
from pydantic_settings import SettingsConfigDict

from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS

from .config_base import InvokeAISettings

INIT_FILE = Path("invokeai.yaml")
@@ -196,17 +199,87 @@ class Categories(object):
    Paths: JsonDict = {"category": "Paths"}
    Logging: JsonDict = {"category": "Logging"}
    Development: JsonDict = {"category": "Development"}
    Other: JsonDict = {"category": "Other"}
    CLIArgs: JsonDict = {"category": "CLIArgs"}
    ModelInstall: JsonDict = {"category": "Model Install"}
    ModelCache: JsonDict = {"category": "Model Cache"}
    Device: JsonDict = {"category": "Device"}
    Generation: JsonDict = {"category": "Generation"}
    Queue: JsonDict = {"category": "Queue"}
    Nodes: JsonDict = {"category": "Nodes"}
    MemoryPerformance: JsonDict = {"category": "Memory/Performance"}
    Deprecated: JsonDict = {"category": "Deprecated"}


class URLRegexToken(BaseModel):
    url_regex: str = Field(description="Regular expression to match against the URL")
    token: str = Field(description="Token to use when the URL matches the regex")

    @field_validator("url_regex")
    @classmethod
    def validate_url_regex(cls, v: str) -> str:
        """Validate that the value is a valid regex."""
        try:
            re.compile(v)
        except re.error as e:
            raise ValueError(f"Invalid regex: {e}")
        return v
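A quick sketch of how this validator behaves, assuming `URLRegexToken` from the module above is in scope (the regex and token values are illustrative):

```python
from pydantic import ValidationError

# A well-formed entry: the token is sent when a download URL matches the regex.
ok = URLRegexToken(url_regex=r"civitai\.com", token="my-secret-token")

# A malformed regex is rejected at construction time by validate_url_regex().
try:
    URLRegexToken(url_regex="([unclosed", token="x")
except ValidationError as err:
    print(err)  # message includes "Invalid regex: ..."
```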
class InvokeAIAppConfig(InvokeAISettings):
    """Configuration object for InvokeAI App."""
    """Invoke App Configuration

    Attributes:
        host: **Web Server**: IP address to bind to. Use `0.0.0.0` to serve to your local network.
        port: **Web Server**: Port to bind to.
        allow_origins: **Web Server**: Allowed CORS origins.
        allow_credentials: **Web Server**: Allow CORS credentials.
        allow_methods: **Web Server**: Methods allowed for CORS.
        allow_headers: **Web Server**: Headers allowed for CORS.
        ssl_certfile: **Web Server**: SSL certificate file for HTTPS.
        ssl_keyfile: **Web Server**: SSL key file for HTTPS.
        esrgan: **Features**: Enables or disables the upscaling code.
        internet_available: **Features**: If true, attempt to download models on the fly; otherwise only use local models.
        log_tokenization: **Features**: Enable logging of parsed prompt tokens.
        patchmatch: **Features**: Enable patchmatch inpaint code.
        ignore_missing_core_models: **Features**: Ignore missing core models on startup. If `True`, the app will attempt to download missing models on startup.
        root: **Paths**: The InvokeAI runtime root directory.
        autoimport_dir: **Paths**: Path to a directory of models files to be imported on startup.
        models_dir: **Paths**: Path to the models directory.
        convert_cache_dir: **Paths**: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and stored on disk at this location.
        legacy_conf_dir: **Paths**: Path to directory of legacy checkpoint config files.
        db_dir: **Paths**: Path to InvokeAI databases directory.
        outdir: **Paths**: Path to directory for outputs.
        custom_nodes_dir: **Paths**: Path to directory for custom nodes.
        from_file: **Paths**: Take command input from the indicated file (command-line client only).
        log_handlers: **Logging**: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
        log_format: **Logging**: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.
        log_level: **Logging**: Emit logging messages at this level or higher.
        log_sql: **Logging**: Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.
        use_memory_db: **Development**: Use in-memory database. Useful for development.
        dev_reload: **Development**: Automatically reload when Python sources are changed. Does not reload node definitions.
        profile_graphs: **Development**: Enable graph profiling using `cProfile`.
        profile_prefix: **Development**: An optional prefix for profile output files.
        profiles_dir: **Development**: Path to profiles output directory.
        version: **CLIArgs**: CLI arg - show InvokeAI version and exit.
        hashing_algorithm: **Model Install**: Model hashing algorithm for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.
        remote_api_tokens: **Model Install**: List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided as a Bearer token.
        ram: **Model Cache**: Maximum memory amount used by memory model cache for rapid switching (GB).
        vram: **Model Cache**: Amount of VRAM reserved for model storage (GB)
        convert_cache: **Model Cache**: Maximum size of on-disk converted models cache (GB)
        lazy_offload: **Model Cache**: Keep models in VRAM until their space is needed.
        log_memory_usage: **Model Cache**: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
        device: **Device**: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
        precision: **Device**: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
        sequential_guidance: **Generation**: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
        attention_type: **Generation**: Attention type.
        attention_slice_size: **Generation**: Slice size, valid when attention_type=="sliced".
        force_tiled_decode: **Generation**: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
        png_compress_level: **Generation**: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
        max_queue_size: **Queue**: Maximum number of items in the session queue.
        allow_nodes: **Nodes**: List of nodes to allow. Omit to allow all.
        deny_nodes: **Nodes**: List of nodes to deny. Omit to deny none.
        node_cache_size: **Nodes**: How many cached nodes to keep in memory.
    """

    singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
    singleton_init: ClassVar[Optional[Dict[str, Any]]] = None
@@ -215,91 +288,98 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
type: Literal["InvokeAI"] = "InvokeAI"
|
||||
|
||||
# WEB
|
||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", json_schema_extra=Categories.WebServer)
|
||||
port : int = Field(default=9090, description="Port to bind to", json_schema_extra=Categories.WebServer)
|
||||
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", json_schema_extra=Categories.WebServer)
|
||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", json_schema_extra=Categories.WebServer)
|
||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", json_schema_extra=Categories.WebServer)
|
||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", json_schema_extra=Categories.WebServer)
|
||||
host : str = Field(default="127.0.0.1", description="IP address to bind to. Use `0.0.0.0` to serve to your local network.", json_schema_extra=Categories.WebServer)
|
||||
port : int = Field(default=9090, description="Port to bind to.", json_schema_extra=Categories.WebServer)
|
||||
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins.", json_schema_extra=Categories.WebServer)
|
||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials.", json_schema_extra=Categories.WebServer)
|
||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS.", json_schema_extra=Categories.WebServer)
|
||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS.", json_schema_extra=Categories.WebServer)
|
||||
# SSL options correspond to https://www.uvicorn.org/settings/#https
|
||||
ssl_certfile : Optional[Path] = Field(default=None, description="SSL certificate file (for HTTPS)", json_schema_extra=Categories.WebServer)
|
||||
ssl_keyfile : Optional[Path] = Field(default=None, description="SSL key file", json_schema_extra=Categories.WebServer)
|
||||
ssl_certfile : Optional[Path] = Field(default=None, description="SSL certificate file for HTTPS.", json_schema_extra=Categories.WebServer)
|
||||
ssl_keyfile : Optional[Path] = Field(default=None, description="SSL key file for HTTPS.", json_schema_extra=Categories.WebServer)
|
||||
|
||||
# FEATURES
|
||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", json_schema_extra=Categories.Features)
|
||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", json_schema_extra=Categories.Features)
|
||||
esrgan : bool = Field(default=True, description="Enables or disables the upscaling code.", json_schema_extra=Categories.Features)
|
||||
# TODO(psyche): This is not used anywhere.
|
||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models.", json_schema_extra=Categories.Features)
|
||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", json_schema_extra=Categories.Features)
|
||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", json_schema_extra=Categories.Features)
|
||||
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', json_schema_extra=Categories.Features)
|
||||
patchmatch : bool = Field(default=True, description="Enable patchmatch inpaint code.", json_schema_extra=Categories.Features)
|
||||
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing core models on startup. If `True`, the app will attempt to download missing models on startup.', json_schema_extra=Categories.Features)
|
||||
|
||||
# PATHS
|
||||
root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
|
||||
root : Optional[Path] = Field(default=None, description='The InvokeAI runtime root directory.', json_schema_extra=Categories.Paths)
|
||||
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
|
||||
convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths)
|
||||
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
|
||||
db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
|
||||
outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
|
||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths)
|
||||
custom_nodes_dir : Path = Field(default=Path('nodes'), description='Path to directory for custom nodes', json_schema_extra=Categories.Paths)
|
||||
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths)
|
||||
models_dir : Path = Field(default=Path('models'), description='Path to the models directory.', json_schema_extra=Categories.Paths)
|
||||
convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.', json_schema_extra=Categories.Paths)
|
||||
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files.', json_schema_extra=Categories.Paths)
|
||||
db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory.', json_schema_extra=Categories.Paths)
|
||||
outdir : Path = Field(default=Path('outputs'), description='Path to directory for outputs.', json_schema_extra=Categories.Paths)
|
||||
custom_nodes_dir : Path = Field(default=Path('nodes'), description='Path to directory for custom nodes.', json_schema_extra=Categories.Paths)
|
||||
# TODO(psyche): This is not used anywhere.
|
||||
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only).', json_schema_extra=Categories.Paths)
|
||||
|
||||
# LOGGING
|
||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', json_schema_extra=Categories.Logging)
|
||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".', json_schema_extra=Categories.Logging)
|
||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', json_schema_extra=Categories.Logging)
|
||||
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", json_schema_extra=Categories.Logging)
|
||||
log_sql : bool = Field(default=False, description="Log SQL queries", json_schema_extra=Categories.Logging)
|
||||
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.', json_schema_extra=Categories.Logging)
|
||||
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher.", json_schema_extra=Categories.Logging)
|
||||
log_sql : bool = Field(default=False, description="Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.", json_schema_extra=Categories.Logging)
|
||||
|
||||
# Development
|
||||
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", json_schema_extra=Categories.Development)
|
||||
profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
|
||||
use_memory_db : bool = Field(default=False, description='Use in-memory database. Useful for development.', json_schema_extra=Categories.Development)
|
||||
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed. Does not reload node definitions.", json_schema_extra=Categories.Development)
|
||||
profile_graphs : bool = Field(default=False, description="Enable graph profiling using `cProfile`.", json_schema_extra=Categories.Development)
|
||||
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
|
||||
profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
|
||||
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce startup time.", json_schema_extra=Categories.Development)
|
||||
profiles_dir : Path = Field(default=Path('profiles'), description="Path to profiles output directory.", json_schema_extra=Categories.Development)
|
||||
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
|
||||
version : bool = Field(default=False, description="CLI arg - show InvokeAI version and exit.", json_schema_extra=Categories.CLIArgs)
|
||||
|
||||
# CACHE
|
||||
ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||
vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||
ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).", json_schema_extra=Categories.ModelCache, )
|
||||
vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB)", json_schema_extra=Categories.ModelCache, )
|
||||
convert_cache : float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache)
|
||||
|
||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
|
||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed.", json_schema_extra=Categories.ModelCache, )
|
||||
log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)
|
||||
|
||||
# DEVICE
|
||||
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)
|
||||
precision : Literal["auto", "float16", "bfloat16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device)
|
||||
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.", json_schema_extra=Categories.Device)
|
||||
precision : Literal["auto", "float16", "bfloat16", "float32", "autocast"] = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.", json_schema_extra=Categories.Device)

    # GENERATION
    sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", json_schema_extra=Categories.Generation)
    attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", json_schema_extra=Categories.Generation)
    attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', json_schema_extra=Categories.Generation)
    force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Generation)
    png_compress_level : int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", json_schema_extra=Categories.Generation)
    sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.", json_schema_extra=Categories.Generation)
    attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type.", json_schema_extra=Categories.Generation)
    attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced".', json_schema_extra=Categories.Generation)
    force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).", json_schema_extra=Categories.Generation)
    png_compress_level : int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.", json_schema_extra=Categories.Generation)
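
The `compress_level` values map directly onto Pillow's PNG encoder. A standalone illustration (not part of this changeset):

```py
from PIL import Image

img = Image.new("RGB", (64, 64), "white")
img.save("fast.png", compress_level=1)   # the default above: fast, slightly larger file
img.save("small.png", compress_level=9)  # slowest, smallest file; every level is lossless
```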

    # QUEUE
    max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", json_schema_extra=Categories.Queue)
    max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.", json_schema_extra=Categories.Queue)

    # NODES
    allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", json_schema_extra=Categories.Nodes)
    deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", json_schema_extra=Categories.Nodes)
    node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
    node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory.", json_schema_extra=Categories.Nodes)

    # MODEL IMPORT
    civitai_api_key : Optional[str] = Field(default=os.environ.get("CIVITAI_API_KEY"), description="API key for CivitAI", json_schema_extra=Categories.Other)

    # MODEL INSTALL
    hashing_algorithm : HASHING_ALGORITHMS = Field(default="blake3", description="Model hashing algorithm for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.", json_schema_extra=Categories.ModelInstall)
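
The algorithms named here are implemented by the `ModelHash` class reworked later in this changeset; per its docstring, usage looks like:

```py
from invokeai.backend.model_manager.hash import ModelHash

ModelHash().hash("path/to/model.safetensors")                 # parallel BLAKE3, best on SSDs
ModelHash("blake3_single").hash("path/to/model.safetensors")  # single-threaded, best on HDDs
ModelHash("sha256").hash("path/to/model.safetensors")         # any hashlib algorithm also works
```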
    remote_api_tokens : Optional[list[URLRegexToken]] = Field(
        default=None,
        description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided as a Bearer token.",
        json_schema_extra=Categories.ModelInstall
    )
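
How `remote_api_tokens` is consumed is shown in `heuristic_import` further down; a minimal sketch of the matching rule (the `url_regex`/`token` field names are taken from that code, the helper function itself is hypothetical):

```py
import re

def token_for(url: str, remote_api_tokens) -> str | None:
    # The first pair whose regex matches the download URL wins;
    # the token is then sent as an "Authorization: Bearer <token>" header.
    for pair in remote_api_tokens or []:
        if re.search(pair.url_regex, url):
            return pair.token
    return None
```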

    # TODO(psyche): Can we just remove these then?
    # DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAIN VALUES FROM PRE-3.1 CONFIG FILES
    always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
    max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
    max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.MemoryPerformance)
    xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.MemoryPerformance)
    tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.MemoryPerformance)
    lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
    embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
    controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
    conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
    always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.Deprecated)
    max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.Deprecated)
    max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.Deprecated)
    xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.Deprecated)
    tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Deprecated)
    lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Deprecated)
    embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Deprecated)
    controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Deprecated)
    conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Deprecated)

    # this is not referred to in the source code and can be removed entirely
    #free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
@@ -477,6 +557,53 @@ class InvokeAIAppConfig(InvokeAISettings):
        """Choose the runtime root directory when not specified on command line or init file."""
        return _find_root()

    @staticmethod
    def generate_docstrings() -> str:
        """Helper function for mkdocs. Generates a docstring for the InvokeAIAppConfig class.

        You shouldn't run this manually. Instead, run `scripts/update-config-docstring.py` to update the docstring.
        A makefile target is also available: `make update-config-docstring`.

        See that script for more information about why this is necessary.
        """
        docstring = ' """Invoke App Configuration\n\n'
        docstring += " Attributes:"

        field_descriptions: dict[str, list[str]] = {}

        for k, v in InvokeAIAppConfig.model_fields.items():
            if not isinstance(v.json_schema_extra, dict):
                # Should never happen
                continue

            category = v.json_schema_extra.get("category", None)
            if not isinstance(category, str) or category == "Deprecated":
                continue
            if not field_descriptions.get(category):
                field_descriptions[category] = []
            field_descriptions[category].append(f" {k}: **{category}**: {v.description}")

        for c in [
            "Web Server",
            "Features",
            "Paths",
            "Logging",
            "Development",
            "CLIArgs",
            "Model Install",
            "Model Cache",
            "Device",
            "Generation",
            "Queue",
            "Nodes",
        ]:
            docstring += "\n"
            docstring += "\n".join(field_descriptions[c])

        docstring += '\n """'

        return docstring


def get_invokeai_config(**kwargs: Any) -> InvokeAIAppConfig:
    """Legacy function which returns InvokeAIAppConfig.get_config()."""
@@ -12,6 +12,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
)
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager import AnyModelConfig
from invokeai.backend.model_manager.config import SubModelType


class EventServiceBase:
@@ -80,7 +81,7 @@
                "graph_execution_state_id": graph_execution_state_id,
                "node_id": node_id,
                "source_node_id": source_node_id,
                "progress_image": progress_image.model_dump() if progress_image is not None else None,
                "progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
                "step": step,
                "order": order,
                "total_steps": total_steps,
@@ -180,6 +181,7 @@
        queue_batch_id: str,
        graph_execution_state_id: str,
        model_config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> None:
        """Emitted when a model is requested"""
        self.__emit_queue_event(
@@ -189,7 +191,8 @@
                "queue_item_id": queue_item_id,
                "queue_batch_id": queue_batch_id,
                "graph_execution_state_id": graph_execution_state_id,
                "model_config": model_config.model_dump(),
                "model_config": model_config.model_dump(mode="json"),
                "submodel_type": submodel_type,
            },
        )

@@ -200,6 +203,7 @@
        queue_batch_id: str,
        graph_execution_state_id: str,
        model_config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> None:
        """Emitted when a model is correctly loaded (returns model info)"""
        self.__emit_queue_event(
@@ -209,7 +213,8 @@
                "queue_item_id": queue_item_id,
                "queue_batch_id": queue_batch_id,
                "graph_execution_state_id": graph_execution_state_id,
                "model_config": model_config.model_dump(),
                "model_config": model_config.model_dump(mode="json"),
                "submodel_type": submodel_type,
            },
        )

@@ -254,8 +259,8 @@
                "started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
                "completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
                },
                "batch_status": batch_status.model_dump(),
                "queue_status": queue_status.model_dump(),
                "batch_status": batch_status.model_dump(mode="json"),
                "queue_status": queue_status.model_dump(mode="json"),
            },
        )

@@ -405,7 +410,7 @@
            payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
        )

    def emit_model_install_cancelled(self, source: str) -> None:
    def emit_model_install_cancelled(self, source: str, id: int) -> None:
        """
        Emit when an install job is cancelled.

@@ -413,7 +418,7 @@
        """
        self.__emit_model_event(
            event_name="model_install_cancelled",
            payload={"source": source},
            payload={"source": source, "id": id},
        )

    def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
@@ -5,6 +5,7 @@ from PIL.Image import Image as PILImageType
from send2trash import send2trash

from invokeai.app.services.invoker import Invoker
from invokeai.app.util.misc import uuid_string
from invokeai.app.util.thumbnails import make_thumbnail

from .model_images_base import ModelImageFileStorageBase
@@ -56,7 +57,12 @@ class ModelImageFileStorageDisk(ModelImageFileStorageBase):
        if not self._validate_path(path):
            return

        return self._invoker.services.urls.get_model_image_url(model_key)
        url = self._invoker.services.urls.get_model_image_url(model_key)

        # The image URL never changes, so we must add random query string to it to prevent caching
        url += f"?{uuid_string()}"

        return url

    def delete(self, model_key: str) -> None:
        try:

@@ -1,7 +1,6 @@
"""Initialization file for model install service package."""

from .model_install_base import (
    CivitaiModelSource,
    HFModelSource,
    InstallStatus,
    LocalModelSource,
@@ -23,5 +22,4 @@ __all__ = [
    "LocalModelSource",
    "HFModelSource",
    "URLModelSource",
    "CivitaiModelSource",
]
@@ -91,21 +91,6 @@ class LocalModelSource(StringLikeSource):
        return Path(self.path).as_posix()


class CivitaiModelSource(StringLikeSource):
    """A Civitai version id, with optional variant and access token."""

    version_id: int
    variant: Optional[ModelRepoVariant] = None
    access_token: Optional[str] = None
    type: Literal["civitai"] = "civitai"

    def __str__(self) -> str:
        """Return string version of repoid when string rep needed."""
        base: str = str(self.version_id)
        base += f" ({self.variant})" if self.variant else ""
        return base


class HFModelSource(StringLikeSource):
    """
    A HuggingFace repo_id with optional variant, sub-folder and access token.
@@ -146,14 +131,11 @@ class URLModelSource(StringLikeSource):
        return str(self.url)


ModelSource = Annotated[
    Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
]
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]

MODEL_SOURCE_TO_TYPE_MAP = {
    URLModelSource: ModelSourceType.Url,
    HFModelSource: ModelSourceType.HFRepoID,
    CivitaiModelSource: ModelSourceType.CivitAI,
    LocalModelSource: ModelSourceType.Path,
}
@@ -4,7 +4,6 @@ import os
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from hashlib import sha256
from pathlib import Path
from queue import Empty, Queue
@@ -13,6 +12,7 @@ from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Set, Union

from huggingface_hub import HfFolder
from omegaconf import DictConfig, OmegaConf
from pydantic.networks import AnyHttpUrl
from requests import Session

@@ -22,7 +22,6 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    BaseModelType,
@@ -34,12 +33,11 @@ from invokeai.backend.model_manager.config import (
)
from invokeai.backend.model_manager.metadata import (
    AnyModelRepoMetadata,
    CivitaiMetadataFetch,
    HuggingFaceMetadataFetch,
    ModelMetadataWithFiles,
    RemoteModelFile,
)
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import Chdir, InvokeAILogger
@@ -47,7 +45,6 @@ from invokeai.backend.util.devices import choose_precision, choose_torch_device

from .model_install_base import (
    MODEL_SOURCE_TO_TYPE_MAP,
    CivitaiModelSource,
    HFModelSource,
    InstallStatus,
    LocalModelSource,
@@ -118,6 +115,7 @@ class ModelInstallService(ModelInstallServiceBase):
            raise Exception("Attempt to start the installer service twice")
        self._start_installer_thread()
        self._remove_dangling_install_dirs()
        self._migrate_yaml()
        self.sync_to_config()

    def stop(self, invoker: Optional[Invoker] = None) -> None:
@@ -155,10 +153,7 @@ class ModelInstallService(ModelInstallServiceBase):
        model_path = Path(model_path)
        config = config or {}

        if self._app_config.skip_model_hash:
            config["hash"] = uuid_string()

        info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
        info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)

        if preferred_name := config.get("name"):
            preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
@@ -200,9 +195,16 @@ class ModelInstallService(ModelInstallServiceBase):
                access_token=access_token,
            )
        elif re.match(r"^https?://[^/]+", source):
            # Pull the token from config if it exists and matches the URL
            _token = access_token
            if _token is None:
                for pair in self.app_config.remote_api_tokens or []:
                    if re.search(pair.url_regex, source):
                        _token = pair.token
                        break
            source_obj = URLModelSource(
                url=AnyHttpUrl(source),
                access_token=access_token,
                access_token=_token,
            )
        else:
            raise ValueError(f"Unsupported model source: '{source}'")
@@ -217,8 +219,6 @@ class ModelInstallService(ModelInstallServiceBase):
        if isinstance(source, LocalModelSource):
            install_job = self._import_local_model(source, config)
            self._install_queue.put(install_job)  # synchronously install
        elif isinstance(source, CivitaiModelSource):
            install_job = self._import_from_civitai(source, config)
        elif isinstance(source, HFModelSource):
            install_job = self._import_from_hf(source, config)
        elif isinstance(source, URLModelSource):
@@ -281,20 +281,56 @@ class ModelInstallService(ModelInstallServiceBase):
        self._scan_models_directory()
        if autoimport := self._app_config.autoimport_dir:
            self._logger.info("Scanning autoimport directory for new models")
            installed: List[str] = []
            # Use ThreadPoolExecutor to scan dirs in parallel
            with ThreadPoolExecutor() as executor:
                future_models = [executor.submit(self.scan_directory, self._app_config.root_path / autoimport / cur_model_type.value) for cur_model_type in ModelType]
                [installed.extend(models.result()) for models in as_completed(future_models)]
            installed = self.scan_directory(self._app_config.root_path / autoimport)
            self._logger.info(f"{len(installed)} new models registered")
        self._logger.info("Model installer (re)initialized")

    def _migrate_yaml(self) -> None:
        db_models = self.record_store.all_models()
        try:
            yaml = self._get_yaml()
        except OSError:
            return

        yaml_metadata = yaml.pop("__metadata__")
        yaml_version = yaml_metadata.get("version")

        if yaml_version != "3.0.0":
            raise ValueError(
                f"Attempted migration of unsupported `models.yaml` v{yaml_version}. Only v3.0.0 is supported. Exiting."
            )

        self._logger.info(
            f"Starting one-time migration of {len(yaml.items())} models from `models.yaml` to database. This may take a few minutes."
        )

        if len(db_models) == 0 and len(yaml.items()) != 0:
            for model_key, stanza in yaml.items():
                _, _, model_name = str(model_key).split("/")
                model_path = Path(stanza["path"])
                if not model_path.is_absolute():
                    model_path = self._app_config.models_path / model_path
                model_path = model_path.resolve()

                config: dict[str, Any] = {}
                config["name"] = model_name
                config["description"] = stanza.get("description")
                config["config_path"] = stanza.get("config")

                try:
                    id = self.register_path(model_path=model_path, config=config)
                    self._logger.info(f"Migrated {model_name} with id {id}")
                except Exception as e:
                    self._logger.warning(f"Model at {model_path} could not be migrated: {e}")

        # Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
        yaml_path = self._app_config.model_conf_path
        yaml_path.rename(yaml_path.with_suffix(".yaml.bak"))
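
For reference, a hypothetical `models.yaml` v3.0.0 file of the shape this migration expects: keys split as `base/type/name`, and the `path`, `description` and `config` fields read above. The model entry itself is illustrative:

```yaml
__metadata__:
  version: 3.0.0
sd-1/main/stable-diffusion-v1-5:
  path: sd-1/main/stable-diffusion-v1-5   # relative paths resolve against models_path
  description: Stable Diffusion version 1.5
  config: configs/stable-diffusion/v1-inference.yaml
```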

    def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:  # noqa D102
        self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
        if len([entry for entry in os.scandir(scan_dir) if not entry.name.startswith(".")]) == 0:
            return []
        callback = self._scan_install if install else self._scan_register
        search = ModelSearch(on_model_found=callback, config=self._app_config)
        search = ModelSearch(on_model_found=callback)
        self._models_installed.clear()
        search.search(scan_dir)
        return list(self._models_installed)
@@ -306,7 +342,7 @@ class ModelInstallService(ModelInstallServiceBase):
        """Unregister the model. Delete its files only if they are within our models directory."""
        model = self.record_store.get_model(key)
        models_dir = self.app_config.models_path
        model_path = models_dir / model.path
        model_path = Path(model.path)
        if model_path.is_relative_to(models_dir):
            self.unconditionally_delete(key)
        else:
@@ -314,11 +350,11 @@ class ModelInstallService(ModelInstallServiceBase):

    def unconditionally_delete(self, key: str) -> None:  # noqa D102
        model = self.record_store.get_model(key)
        path = self.app_config.models_path / model.path
        if path.is_dir():
            rmtree(path)
        model_path = Path(model.path)
        if model_path.is_dir():
            rmtree(model_path)
        else:
            path.unlink()
            model_path.unlink()
        self.unregister(key)

    def download_and_cache(
@@ -388,10 +424,8 @@ class ModelInstallService(ModelInstallServiceBase):
        job.config_in["source"] = str(job.source)
        job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
        # enter the metadata, if there is any
        if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
        if isinstance(job.source_metadata, (HuggingFaceMetadata)):
            job.config_in["source_api_response"] = job.source_metadata.api_response
        if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_phrases:
            job.config_in["trigger_phrases"] = job.source_metadata.trigger_phrases

        if job.inplace:
            key = self.register_path(job.local_path, job.config_in)
@@ -455,10 +489,10 @@ class ModelInstallService(ModelInstallServiceBase):
                self.unregister(key)

        self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
        # Use ThreadPoolExecutor to scan dirs in parallel
        with ThreadPoolExecutor() as executor:
            future_models = [executor.submit(self.scan_directory, Path(cur_base_model.value, cur_model_type.value)) for cur_base_model in BaseModelType for cur_model_type in ModelType]
            [installed.update(models.result()) for models in as_completed(future_models)]
        for cur_base_model in BaseModelType:
            for cur_model_type in ModelType:
                models_dir = self._app_config.models_path / Path(cur_base_model.value, cur_model_type.value)
                installed.update(self.scan_directory(models_dir))
        self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
    def _sync_model_path(self, key: str) -> AnyModelConfig:
@@ -476,13 +510,20 @@ class ModelInstallService(ModelInstallServiceBase):
        old_path = Path(model.path)
        models_dir = self.app_config.models_path

        if not old_path.is_relative_to(models_dir):
        try:
            old_path.relative_to(models_dir)
            return model
        except ValueError:
            pass

        new_path = models_dir / model.base.value / model.type.value / old_path.name

        if old_path == new_path:
            return model

        new_path = models_dir / model.base.value / model.type.value / model.name
        self._logger.info(f"Moving {model.name} to {new_path}.")
        new_path = self._move_model(old_path, new_path)
        model.path = new_path.relative_to(models_dir).as_posix()
        model.path = new_path.as_posix()
        self.record_store.update_model(key, ModelRecordChanges(path=model.path))
        return model

@@ -540,22 +581,16 @@ class ModelInstallService(ModelInstallServiceBase):
    ) -> str:
        config = config or {}

        if self._app_config.skip_model_hash:
            config["hash"] = uuid_string()
        info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)

        info = info or ModelProbe.probe(model_path, config)

        model_path = model_path.absolute()
        if model_path.is_relative_to(self.app_config.models_path):
            model_path = model_path.relative_to(self.app_config.models_path)
        model_path = model_path.resolve()

        info.path = model_path.as_posix()

        # add 'main' specific fields
        if isinstance(info, CheckpointConfigBase):
            # make config relative to our root
            legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
            info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
            info.config_path = legacy_conf.as_posix()
        self.record_store.add_model(info)
        return info.key

@@ -565,6 +600,16 @@ class ModelInstallService(ModelInstallServiceBase):
        self._next_job_id += 1
        return id

    # --------------------------------------------------------------------------------------------
    # Internal functions that manage the old yaml config
    # --------------------------------------------------------------------------------------------
    def _get_yaml(self) -> DictConfig:
        """Fetch the models.yaml DictConfig for this installation."""
        yaml_path = self._app_config.model_conf_path
        omegaconf = OmegaConf.load(yaml_path)
        assert isinstance(omegaconf, DictConfig)
        return omegaconf

    @staticmethod
    def _guess_variant() -> Optional[ModelRepoVariant]:
        """Guess the best HuggingFace variant type to download."""
@@ -580,16 +625,6 @@ class ModelInstallService(ModelInstallServiceBase):
            inplace=source.inplace or False,
        )

    def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
        if not source.access_token:
            self._logger.info("No Civitai access token provided; some models may not be downloadable.")
        metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id(
            str(source.version_id)
        )
        assert isinstance(metadata, ModelMetadataWithFiles)
        remote_files = metadata.download_urls(session=self._session)
        return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)

    def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
        # Add user's cached access token to HuggingFace requests
        source.access_token = source.access_token or HfFolder.get_token()
@@ -612,7 +647,7 @@ class ModelInstallService(ModelInstallServiceBase):
        )

    def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
        # URLs from Civitai or HuggingFace will be handled specially
        # URLs from HuggingFace will be handled specially
        metadata = None
        fetcher = None
        try:
@@ -620,8 +655,6 @@ class ModelInstallService(ModelInstallServiceBase):
        except ValueError:
            pass
        kwargs: dict[str, Any] = {"session": self._session}
        if fetcher is CivitaiMetadataFetch:
            kwargs["api_key"] = self._app_config.get_config().civitai_api_key
        if fetcher is not None:
            metadata = fetcher(**kwargs).from_url(source.url)
        self._logger.debug(f"metadata={metadata}")
@@ -638,7 +671,7 @@ class ModelInstallService(ModelInstallServiceBase):

    def _import_remote_model(
        self,
        source: HFModelSource | CivitaiModelSource | URLModelSource,
        source: HFModelSource | URLModelSource,
        remote_files: List[RemoteModelFile],
        metadata: Optional[AnyModelRepoMetadata],
        config: Optional[Dict[str, Any]],
@@ -852,12 +885,10 @@ class ModelInstallService(ModelInstallServiceBase):
    def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
        self._logger.info(f"{job.source}: model installation was cancelled")
        if self._event_bus:
            self._event_bus.emit_model_install_cancelled(str(job.source))
            self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)

    @staticmethod
    def get_fetcher_from_url(url: str):
        if re.match(r"^https?://civitai.com/", url.lower()):
            return CivitaiMetadataFetch
        elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
        if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
            return HuggingFaceMetadataFetch
        raise ValueError(f"Unsupported model source: '{url}'")
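
With the Civitai branch gone, the fetcher lookup now only recognizes HuggingFace repo URLs (both URLs below are illustrative):

```py
ModelInstallService.get_fetcher_from_url("https://huggingface.co/runwayml/stable-diffusion-v1-5")
# -> HuggingFaceMetadataFetch
ModelInstallService.get_fetcher_from_url("https://civitai.com/models/12345")
# -> ValueError: Unsupported model source: 'https://civitai.com/models/12345'
```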
@@ -68,6 +68,7 @@ class ModelLoadService(ModelLoadServiceBase):
        self._emit_load_event(
            context_data=context_data,
            model_config=model_config,
            submodel_type=submodel_type,
        )

        implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type)  # type: ignore
@@ -82,6 +83,7 @@ class ModelLoadService(ModelLoadServiceBase):
        self._emit_load_event(
            context_data=context_data,
            model_config=model_config,
            submodel_type=submodel_type,
            loaded=True,
        )
        return loaded_model
@@ -91,6 +93,7 @@ class ModelLoadService(ModelLoadServiceBase):
        context_data: InvocationContextData,
        model_config: AnyModelConfig,
        loaded: Optional[bool] = False,
        submodel_type: Optional[SubModelType] = None,
    ) -> None:
        if not self._invoker:
            return
@@ -102,6 +105,7 @@ class ModelLoadService(ModelLoadServiceBase):
                queue_batch_id=context_data.queue_item.batch_id,
                graph_execution_state_id=context_data.queue_item.session_id,
                model_config=model_config,
                submodel_type=submodel_type,
            )
        else:
            self._invoker.services.events.emit_model_load_completed(
@@ -110,4 +114,5 @@ class ModelLoadService(ModelLoadServiceBase):
                queue_batch_id=context_data.queue_item.batch_id,
                graph_execution_state_id=context_data.queue_item.session_id,
                model_config=model_config,
                submodel_type=submodel_type,
            )

@@ -18,7 +18,12 @@ from invokeai.backend.model_manager import (
    ModelFormat,
    ModelType,
)
from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
from invokeai.backend.model_manager.config import (
    ControlAdapterDefaultSettings,
    MainModelDefaultSettings,
    ModelVariantType,
    SchedulerPredictionType,
)


class DuplicateModelException(Exception):
@@ -68,7 +73,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
    description: Optional[str] = Field(description="Model description", default=None)
    base: Optional[BaseModelType] = Field(description="The base model.", default=None)
    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
    default_settings: Optional[ModelDefaultSettings] = Field(
    default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
        description="Default settings for this model", default=None
    )


@@ -242,6 +242,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        base_model: Optional[BaseModelType] = None,
        model_type: Optional[ModelType] = None,
        model_format: Optional[ModelFormat] = None,
        order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
    ) -> List[AnyModelConfig]:
        """
        Return models matching name, base and/or type.
@@ -250,10 +251,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        :param base_model: Filter by base model (optional)
        :param model_type: Filter by type of model (optional)
        :param model_format: Filter by model format (e.g. "diffusers") (optional)
        :param order_by: Result order

        If none of the optional filters are passed, will return all
        models in the database.
        """

        assert isinstance(order_by, ModelRecordOrderBy)
        ordering = {
            ModelRecordOrderBy.Default: "type, base, name, format",
            ModelRecordOrderBy.Type: "type",
            ModelRecordOrderBy.Base: "base",
            ModelRecordOrderBy.Name: "name",
            ModelRecordOrderBy.Format: "format",
        }

        where_clause: list[str] = []
        bindings: list[str] = []
        if model_name:
@@ -272,8 +284,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        with self._db.lock:
            self._cursor.execute(
                f"""--sql
                SELECT config, strftime('%s',updated_at) FROM models
                {where};
                SELECT config, strftime('%s',updated_at)
                FROM models
                {where}
                ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
                """,
                tuple(bindings),
            )
@@ -319,7 +333,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        """Return a paginated summary listing of each model in the database."""
        assert isinstance(order_by, ModelRecordOrderBy)
        ordering = {
            ModelRecordOrderBy.Default: "type, base, format, name",
            ModelRecordOrderBy.Default: "type, base, name, format",
            ModelRecordOrderBy.Type: "type",
            ModelRecordOrderBy.Base: "base",
            ModelRecordOrderBy.Name: "name",
@@ -22,7 +22,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit

if TYPE_CHECKING:
    from invokeai.app.invocations.baseinvocation import BaseInvocation
    from invokeai.app.invocations.model import ModelField
    from invokeai.app.invocations.model import ModelIdentifierField
    from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem

"""
@@ -300,7 +300,7 @@ class ConditioningInterface(InvocationContextInterface):


class ModelsInterface(InvocationContextInterface):
    def exists(self, identifier: Union[str, "ModelField"]) -> bool:
    def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
        """Checks if a model exists.

        Args:
@@ -314,7 +314,9 @@ class ModelsInterface(InvocationContextInterface):

        return self._services.model_manager.store.exists(identifier.key)

    def load(self, identifier: Union[str, "ModelField"], submodel_type: Optional[SubModelType] = None) -> LoadedModel:
    def load(
        self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
    ) -> LoadedModel:
        """Loads a model.

        Args:
@@ -361,7 +363,7 @@ class ModelsInterface(InvocationContextInterface):

        return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)

    def get_config(self, identifier: Union[str, "ModelField"]) -> AnyModelConfig:
    def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
        """Gets a model's config.

        Args:

@@ -4,8 +4,6 @@ from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration

from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1


class Migration3Callback:
    def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
@@ -15,7 +13,6 @@ class Migration3Callback:
    def __call__(self, cursor: sqlite3.Cursor) -> None:
        self._drop_model_manager_metadata(cursor)
        self._recreate_model_config(cursor)
        self._migrate_model_config_records(cursor)

    def _drop_model_manager_metadata(self, cursor: sqlite3.Cursor) -> None:
        """Drops the `model_manager_metadata` table."""
@@ -55,12 +52,6 @@ class Migration3Callback:
            """
        )

    def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
        """After updating the model config table, we repopulate it."""
        self._logger.info("Migrating model config records from models.yaml to database")
        model_record_migrator = MigrateModelYamlToDb1(self._app_config, self._logger, cursor)
        model_record_migrator.migrate()


def build_migration_3(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
    """
@@ -1,163 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""

import json
import sqlite3
from logging import Logger
from pathlib import Path
from typing import Optional

from omegaconf import DictConfig, OmegaConf
from pydantic import TypeAdapter

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import (
    DuplicateModelException,
    UnknownModelException,
)
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    BaseModelType,
    ModelConfigFactory,
    ModelType,
)
from invokeai.backend.model_manager.hash import ModelHash

ModelsValidator = TypeAdapter(AnyModelConfig)


class MigrateModelYamlToDb1:
    """
    Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.5.0).

    The class has one externally useful method, migrate(), which scans the
    current models.yaml file and imports all its entries into invokeai.db.

    Use this way:

        from invokeai.backend.model_manager/migrate_to_db import MigrateModelYamlToDb
        MigrateModelYamlToDb().migrate()

    """

    config: InvokeAIAppConfig
    logger: Logger
    cursor: sqlite3.Cursor

    def __init__(self, config: InvokeAIAppConfig, logger: Logger, cursor: sqlite3.Cursor = None) -> None:
        self.config = config
        self.logger = logger
        self.cursor = cursor

    def get_yaml(self) -> DictConfig:
        """Fetch the models.yaml DictConfig for this installation."""
        yaml_path = self.config.model_conf_path
        omegaconf = OmegaConf.load(yaml_path)
        assert isinstance(omegaconf, DictConfig)
        return omegaconf

    def migrate(self) -> None:
        """Do the migration from models.yaml to invokeai.db."""
        try:
            yaml = self.get_yaml()
        except OSError:
            return

        for model_key, stanza in yaml.items():
            if model_key == "__metadata__":
                assert (
                    stanza["version"] == "3.0.0"
                ), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
                continue

            base_type, model_type, model_name = str(model_key).split("/")
            try:
                hash = ModelHash().hash(self.config.models_path / stanza.path)
            except OSError:
                self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
                continue

            stanza["base"] = BaseModelType(base_type)
            stanza["type"] = ModelType(model_type)
            stanza["name"] = model_name
            stanza["original_hash"] = hash
            stanza["current_hash"] = hash
            new_key = hash  # deterministic key assignment

            # special case for ip adapters, which need the new `image_encoder_model_id` field
            if stanza["type"] == ModelType.IPAdapter:
                try:
                    stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
                        self.config.models_path / stanza.path
                    )
                except OSError:
                    self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
                    continue

            new_config: AnyModelConfig = ModelsValidator.validate_python(stanza)  # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094

            try:
                if original_record := self._search_by_path(stanza.path):
                    key = original_record.key
                    self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
                    self._update_model(key, new_config)
                else:
                    self.logger.info(f"Adding model {model_name} with key {new_key}")
                    self._add_model(new_key, new_config)
            except DuplicateModelException:
                self.logger.warning(f"Model {model_name} is already in the database")
            except UnknownModelException:
                self.logger.warning(f"Model at {stanza.path} could not be found in database")

    def _search_by_path(self, path: Path) -> Optional[AnyModelConfig]:
        self.cursor.execute(
            """--sql
            SELECT config FROM model_config
            WHERE path=?;
            """,
            (str(path),),
        )
        results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self.cursor.fetchall()]
        return results[0] if results else None

    def _update_model(self, key: str, config: AnyModelConfig) -> None:
        record = ModelConfigFactory.make_config(config, key=key)  # ensure it is a valid config object
        json_serialized = record.model_dump_json()  # and turn it into a json string.
        self.cursor.execute(
            """--sql
            UPDATE model_config
            SET
                config=?
            WHERE id=?;
            """,
            (json_serialized, key),
        )
        if self.cursor.rowcount == 0:
            raise UnknownModelException("model not found")

    def _add_model(self, key: str, config: AnyModelConfig) -> None:
        record = ModelConfigFactory.make_config(config, key=key)  # ensure it is a valid config object.
        json_serialized = record.model_dump_json()  # and turn it into a json string.
        try:
            self.cursor.execute(
                """--sql
                INSERT INTO model_config (
                    id,
                    original_hash,
                    config
                )
                VALUES (?,?,?);
                """,
                (
                    key,
                    record.hash,
                    json_serialized,
                ),
            )
        except sqlite3.IntegrityError as exc:
            raise DuplicateModelException(f"{record.name}: model is already in database") from exc

    def _get_image_encoder_model_id(self, model_path: Path) -> str:
        with open(model_path / "image_encoder.txt") as f:
            encoder = f.read()
        return encoder.strip()
@@ -13,9 +13,11 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import download_with_progress_bar

config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger(config=config)

DEPTH_ANYTHING_MODELS = {
    "large": {
@@ -54,8 +56,9 @@ class DepthAnythingDetector:
    def __init__(self) -> None:
        self.model = None
        self.model_size: Union[Literal["large", "base", "small"], None] = None
        self.device = choose_torch_device()

    def load_model(self, model_size=Literal["large", "base", "small"]):
    def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
        DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
        if not DEPTH_ANYTHING_MODEL_PATH.exists():
            download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)
@@ -71,8 +74,6 @@ class DepthAnythingDetector:
                self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
            case "large":
                self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
            case _:
                raise TypeError("Not a supported model")

        self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
        self.model.eval()
@@ -80,20 +81,20 @@ class DepthAnythingDetector:
        self.model.to(choose_torch_device())
        return self.model

    def to(self, device):
        self.model.to(device)
        return self
    def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
        if not self.model:
            logger.warn("DepthAnything model was not loaded. Returning original image")
            return image

    def __call__(self, image, resolution=512, offload=False):
        image = np.array(image, dtype=np.uint8)
        image = image[:, :, ::-1] / 255.0
        np_image = np.array(image, dtype=np.uint8)
        np_image = np_image[:, :, ::-1] / 255.0

        image_height, image_width = image.shape[:2]
        image = transform({"image": image})["image"]
        image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
        image_height, image_width = np_image.shape[:2]
        np_image = transform({"image": np_image})["image"]
        tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())

        with torch.no_grad():
            depth = self.model(image)
            depth = self.model(tensor_image)
        depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
        depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0

@@ -103,7 +104,4 @@ class DepthAnythingDetector:
        new_height = int(image_height * (resolution / image_width))
        depth_map = depth_map.resize((resolution, new_height))

        if offload:
            del self.model

        return depth_map
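
After this refactor the detector's public surface is fully typed; typical use, with an illustrative image path:

```py
from PIL import Image

detector = DepthAnythingDetector()
detector.load_model("small")  # weights are downloaded on first use
depth_map = detector(Image.open("photo.png"), resolution=512)  # returns a PIL depth map
```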
@@ -17,7 +17,7 @@ import warnings
from argparse import Namespace
from enum import Enum
from pathlib import Path
from shutil import get_terminal_size
from shutil import copy, get_terminal_size, move
from typing import Any, Optional, Set, Tuple, Type, get_args, get_type_hints
from urllib import request

@@ -929,6 +929,10 @@ def main() -> None:

    errors = set()
    FORCE_FULL_PRECISION = opt.full_precision  # FIXME global
    new_init_file = config.root_path / "invokeai.yaml"
    backup_init_file = new_init_file.with_suffix(".bak")
    if new_init_file.exists():
        copy(new_init_file, backup_init_file)

    try:
        # if we do a root migration/upgrade, then we are keeping previous
@@ -943,7 +947,6 @@ def main() -> None:
        install_helper = InstallHelper(config, logger)

        models_to_download = default_user_selections(opt, install_helper)
        new_init_file = config.root_path / "invokeai.yaml"

        if opt.yes_to_all:
            write_default_options(opt, new_init_file)
@@ -975,8 +978,17 @@ def main() -> None:
            input("Press any key to continue...")
    except WindowTooSmallException as e:
        logger.error(str(e))
        if backup_init_file.exists():
            move(backup_init_file, new_init_file)
    except KeyboardInterrupt:
        print("\nGoodbye! Come back soon.")
        if backup_init_file.exists():
            move(backup_init_file, new_init_file)
    except Exception:
        print("An error occurred during installation.")
        if backup_init_file.exists():
            move(backup_init_file, new_init_file)
        print(traceback.format_exc(), file=sys.stderr)


# -------------------------------------
@@ -1,25 +1,15 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Fast hashing of diffusers and checkpoint-style models.

Usage:
from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""

import hashlib
import os
from pathlib import Path
from typing import Callable, Literal, Optional, Union
from concurrent.futures import ThreadPoolExecutor, as_completed

from blake3 import blake3

MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
from invokeai.app.util.misc import uuid_string

ALGORITHM = Literal[
HASHING_ALGORITHMS = Literal[
    "md5",
    "sha1",
    "sha224",
@@ -35,12 +25,15 @@ ALGORITHM = Literal[
    "shake_128",
    "shake_256",
    "blake3",
    "blake3_single",
    "random",
]
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")


class ModelHash:
    """
    Creates a hash of a model using a specified algorithm.
    Creates a hash of a model using a specified algorithm. The hash is prefixed by the algorithm used.

    Args:
        algorithm: Hashing algorithm to use. Defaults to BLAKE3.
@@ -55,20 +48,29 @@ class ModelHash:
    The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
    that directory hashes are never weaker than the file hashes.

    A convenience algorithm choice of "random" is also available, which returns a random string. This is not a hash.

    Usage:
    ```py
    # BLAKE3 hash
    ModelHash().hash("path/to/some/model.safetensors")
    ModelHash().hash("path/to/some/model.safetensors")  # "blake3:ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"
    # MD5
    ModelHash("md5").hash("path/to/model/dir/")
    ModelHash("md5").hash("path/to/model/dir/")  # "md5:a0cd925fc063f98dbf029eee315060c3"
    ```
    """

    def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
    def __init__(
        self, algorithm: HASHING_ALGORITHMS = "blake3", file_filter: Optional[Callable[[str], bool]] = None
    ) -> None:
        self.algorithm: HASHING_ALGORITHMS = algorithm
        if algorithm == "blake3":
            self._hash_file = self._blake3
        elif algorithm == "blake3_single":
            self._hash_file = self._blake3_single
        elif algorithm in hashlib.algorithms_available:
            self._hash_file = self._get_hashlib(algorithm)
        elif algorithm == "random":
            self._hash_file = self._random
        else:
            raise ValueError(f"Algorithm {algorithm} not available")

@@ -89,10 +91,12 @@ class ModelHash:
        """

        model_path = Path(model_path)
        # blake3_single is a single-threaded version of blake3, prefix should still be "blake3:"
        prefix = self._get_prefix(self.algorithm)
        if model_path.is_file():
            return self._hash_file(model_path)
            return prefix + self._hash_file(model_path)
        elif model_path.is_dir():
            return self._hash_dir(model_path)
            return prefix + self._hash_dir(model_path)
        else:
            raise OSError(f"Not a valid file or directory: {model_path}")
@@ -107,16 +111,16 @@ class ModelHash:
        """
        model_component_paths = self._get_file_paths(dir, self._file_filter)

        # Use ThreadPoolExecutor to hash files in parallel
        with ThreadPoolExecutor(min(((os.cpu_count() or 1) + 4), len(model_component_paths))) as executor:
            future_to_component = {executor.submit(self._hash_file, component): component for component in sorted(model_component_paths)}
            component_hashes = [future.result() for future in as_completed(future_to_component)]
        component_hashes: list[str] = []
        for component in sorted(model_component_paths):
            component_hashes.append(self._hash_file(component))

        # BLAKE3 to hash the hashes
        # BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
        # for the composite hash
        composite_hasher = blake3()
        component_hashes.sort()
        for h in component_hashes:
            composite_hasher.update(h.encode("utf-8"))

        return composite_hasher.hexdigest()

    @staticmethod
@@ -132,17 +136,15 @@ class ModelHash:
        """

        files: list[Path] = []
        entries = [entry for entry in os.scandir(model_path.as_posix()) if not entry.name.startswith(".")]
        dirs = [entry for entry in entries if entry.is_dir()]
        file_paths = [entry.path for entry in entries if entry.is_file() and file_filter(entry.path)]
        files.extend([Path(file) for file in file_paths])
        for dir in dirs:
            files.extend(ModelHash._get_file_paths(Path(dir.path), file_filter))
        for root, _dirs, _files in os.walk(model_path):
            for file in _files:
                if file_filter(file):
                    files.append(Path(root, file))
        return files

    @staticmethod
    def _blake3(file_path: Path) -> str:
        """Hashes a file using BLAKE3
        """Hashes a file using BLAKE3, using parallelized and memory-mapped I/O to avoid reading the entire file into memory.

        Args:
            file_path: Path to the file to hash
@@ -155,7 +157,21 @@ class ModelHash:
        return file_hasher.hexdigest()

    @staticmethod
    def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
    def _blake3_single(file_path: Path) -> str:
        """Hashes a file using BLAKE3, without parallelism. Suitable for spinning hard drives.

        Args:
            file_path: Path to the file to hash

        Returns:
            Hexdigest of the hash of the file
        """
        file_hasher = blake3()
        file_hasher.update_mmap(file_path)
        return file_hasher.hexdigest()
@staticmethod
|
||||
def _get_hashlib(algorithm: HASHING_ALGORITHMS) -> Callable[[Path], str]:
|
||||
"""Factory function that returns a function to hash a file with the given algorithm.
|
||||
|
||||
Args:
|
||||
@@ -166,15 +182,24 @@ class ModelHash:
        """

        def hashlib_hasher(file_path: Path) -> str:
            """Hashes a file using a hashlib algorithm."""
            """Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
            hasher = hashlib.new(algorithm)
            with open(file_path, "rb") as f:
                for chunk in iter(lambda: f.read(8 * 1024), b""):
                    hasher.update(chunk)
            buffer = bytearray(128 * 1024)
            mv = memoryview(buffer)
            with open(file_path, "rb", buffering=0) as f:
                while n := f.readinto(mv):
                    hasher.update(mv[:n])
            return hasher.hexdigest()

        return hashlib_hasher

    @staticmethod
    def _random(_file_path: Path) -> str:
        """Returns a random string. This is not a hash.

        The string is a UUID, hashed with BLAKE3 to ensure that it is unique."""
        return blake3(uuid_string().encode()).hexdigest()

    @staticmethod
    def _default_file_filter(file_path: str) -> bool:
        """A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
@@ -186,3 +211,9 @@ class ModelHash:
            True if the file matches the given extensions, otherwise False
        """
        return file_path.endswith(MODEL_FILE_EXTENSIONS)

    @staticmethod
    def _get_prefix(algorithm: HASHING_ALGORITHMS) -> str:
        """Return the prefix for the given algorithm, e.g. "blake3:" or "md5:"."""
        # blake3_single is a single-threaded version of blake3, prefix should still be "blake3:"
        return "blake3:" if algorithm == "blake3_single" else f"{algorithm}:"
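Taken together, these hunks make the hashing strategy selectable per call. A minimal usage sketch, assuming the `ModelHash` API shown in this diff (constructor takes `algorithm`, `hash()` takes a path and returns a prefixed digest); the sample path is hypothetical:

```python
from pathlib import Path

from invokeai.backend.model_hash.model_hash import ModelHash

model_dir = Path("/models/sd-1.5-diffusers")  # hypothetical model location

# Default: memory-mapped BLAKE3.
fast = ModelHash().hash(model_dir)

# Single-threaded BLAKE3 for spinning disks; per _get_prefix above, the
# prefix is still "blake3:".
slow_disk = ModelHash(algorithm="blake3_single").hash(model_dir)

assert fast.startswith("blake3:") and slow_disk.startswith("blake3:")
assert fast == slow_disk  # same content, same composite digest
```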
@@ -22,7 +22,7 @@ Validation errors will raise an InvalidModelConfigException error.

import time
from enum import Enum
from typing import Literal, Optional, Type, Union
from typing import Literal, Optional, Type, TypeAlias, Union

import torch
from diffusers.models.modeling_utils import ModelMixin
@@ -129,16 +129,27 @@ class ModelSourceType(str, Enum):
    Path = "path"
    Url = "url"
    HFRepoID = "hf_repo_id"
    CivitAI = "civitai"


class ModelDefaultSettings(BaseModel):
    vae: str | None
    vae_precision: str | None
    scheduler: SCHEDULER_NAME_VALUES | None
    steps: int | None
    cfg_scale: float | None
    cfg_rescale_multiplier: float | None
DEFAULTS_PRECISION = Literal["fp16", "fp32"]


class MainModelDefaultSettings(BaseModel):
    vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
    vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
    scheduler: SCHEDULER_NAME_VALUES | None = Field(default=None, description="Default scheduler for this model")
    steps: int | None = Field(default=None, gt=0, description="Default number of steps for this model")
    cfg_scale: float | None = Field(default=None, ge=1, description="Default CFG Scale for this model")
    cfg_rescale_multiplier: float | None = Field(
        default=None, ge=0, lt=1, description="Default CFG Rescale Multiplier for this model"
    )
    width: int | None = Field(default=None, multiple_of=8, ge=64, description="Default width for this model")
    height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model")


class ControlAdapterDefaultSettings(BaseModel):
    # This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
    preprocessor: str | None


class ModelConfigBase(BaseModel):
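The split into `MainModelDefaultSettings` and `ControlAdapterDefaultSettings` also adds field constraints the old `ModelDefaultSettings` lacked. A quick validation sketch (pydantic v2 assumed; the values are illustrative):

```python
from pydantic import ValidationError

# All fields default to None, so partial settings are fine.
ok = MainModelDefaultSettings(steps=30, cfg_scale=7.5, width=512, height=768)

try:
    MainModelDefaultSettings(width=500)  # rejected: width must be a multiple of 8
except ValidationError as err:
    print(err.errors()[0]["type"])  # "multiple_of"
```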
@@ -157,10 +168,6 @@ class ModelConfigBase(BaseModel):
    source_api_response: Optional[str] = Field(
        description="The original API response from the source, as stringified JSON.", default=None
    )
    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
    default_settings: Optional[ModelDefaultSettings] = Field(
        description="Default settings for this model", default=None
    )
    cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)

    @staticmethod
@@ -187,10 +194,14 @@ class DiffusersConfigBase(ModelConfigBase):
    repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default


class LoRALyCORISConfig(ModelConfigBase):
class LoRAConfigBase(ModelConfigBase):
    type: Literal[ModelType.LoRA] = ModelType.LoRA
    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)


class LoRALyCORISConfig(LoRAConfigBase):
    """Model config for LoRA/Lycoris models."""

    type: Literal[ModelType.LoRA] = ModelType.LoRA
    format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS

    @staticmethod
@@ -198,10 +209,9 @@ class LoRALyCORISConfig(ModelConfigBase):
        return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")


class LoRADiffusersConfig(ModelConfigBase):
class LoRADiffusersConfig(LoRAConfigBase):
    """Model config for LoRA/Diffusers models."""

    type: Literal[ModelType.LoRA] = ModelType.LoRA
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    @staticmethod
@@ -231,7 +241,13 @@ class VAEDiffusersConfig(ModelConfigBase):
        return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")


class ControlNetDiffusersConfig(DiffusersConfigBase):
class ControlAdapterConfigBase(BaseModel):
    default_settings: Optional[ControlAdapterDefaultSettings] = Field(
        description="Default settings for this model", default=None
    )


class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):
    """Model config for ControlNet models (diffusers version)."""

    type: Literal[ModelType.ControlNet] = ModelType.ControlNet
@@ -242,7 +258,7 @@ class ControlNetDiffusersConfig(DiffusersConfigBase):
        return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")


class ControlNetCheckpointConfig(CheckpointConfigBase):
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase):
    """Model config for ControlNet models (checkpoint version)."""

    type: Literal[ModelType.ControlNet] = ModelType.ControlNet
@@ -275,10 +291,17 @@ class TextualInversionFolderConfig(ModelConfigBase):
        return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")


class MainCheckpointConfig(CheckpointConfigBase):
class MainConfigBase(ModelConfigBase):
    type: Literal[ModelType.Main] = ModelType.Main
    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
    default_settings: Optional[MainModelDefaultSettings] = Field(
        description="Default settings for this model", default=None
    )


class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
    """Model config for main checkpoint models."""

    type: Literal[ModelType.Main] = ModelType.Main
    variant: ModelVariantType = ModelVariantType.Normal
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False
@@ -288,11 +311,9 @@ class MainCheckpointConfig(CheckpointConfigBase):
        return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")


class MainDiffusersConfig(DiffusersConfigBase):
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
    """Model config for main diffusers models."""

    type: Literal[ModelType.Main] = ModelType.Main

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
@@ -321,7 +342,7 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
        return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")


class T2IAdapterConfig(ModelConfigBase):
class T2IAdapterConfig(ModelConfigBase, ControlAdapterConfigBase):
    """Model config for T2I."""

    type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
@@ -373,6 +394,7 @@ AnyModelConfig = Annotated[
]

AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, ControlAdapterDefaultSettings]


class ModelConfigFactory(object):
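`AnyModelConfig` is a `Literal`-tagged union, so the `TypeAdapter` dispatches on the `type`/`format` tags returned by each `get_tag()` above. A round-trip sketch; `some_config` is a hypothetical stand-in for any existing config instance:

```python
# Serialize any config and re-validate it through the tagged union;
# the discriminator routes the dict back to the same concrete class.
payload = some_config.model_dump()
parsed = AnyModelConfigValidator.validate_python(payload)
assert type(parsed) is type(some_config)
```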
@@ -76,6 +76,25 @@ logger = InvokeAILogger.get_logger(__name__)
CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().models_path / "core/convert"


def install_dependencies():
    """
    Check for, and install, missing model dependencies.
    """
    conversion_models = [
        "clip-vit-large-patch14",
        "CLIP-ViT-H-14-laion2B-s32B-b79K",
        "stable-diffusion-2-clip",
        "stable-diffusion-safety-checker",
        "CLIP-ViT-bigG-14-laion2B-39B-b160k",
        "bert-base-uncased",
    ]
    if any(not (CONVERT_MODEL_ROOT / x).exists() for x in conversion_models):
        logger.warning("Installing missing core safetensor conversion models")
        from invokeai.backend.install.invokeai_configure import download_conversion_models  # noqa

        download_conversion_models()


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
@@ -1697,6 +1716,8 @@ def download_controlnet_from_original_ckpt(


def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL:
    install_dependencies()

    vae_config = create_vae_diffusers_config(vae_config, image_size=image_size)

    converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
@@ -1717,6 +1738,7 @@ def convert_ckpt_to_diffusers(
    and in addition a path-like object indicating the location of the desired diffusers
    model to be written.
    """
    install_dependencies()
    pipe = download_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs)

    # TO DO: save correct repo variant
@@ -1736,6 +1758,7 @@ def convert_controlnet_to_diffusers(
    and in addition a path-like object indicating the location of the desired diffusers
    model to be written.
    """
    install_dependencies()
    pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs)

    # TO DO: save correct repo variant
@@ -19,7 +19,6 @@ context. Use like this:
"""

import gc
import logging
import math
import sys
import time
@@ -92,8 +91,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
        self._execution_device: torch.device = execution_device
        self._storage_device: torch.device = storage_device
        self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
        self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
        # used for stats collection
        self._log_memory_usage = log_memory_usage
        self._stats: Optional[CacheStats] = None

        self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
@@ -60,7 +60,7 @@ class ModelLoaderRegistryBase(ABC):
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)


class ModelLoaderRegistry:
class ModelLoaderRegistry(ModelLoaderRegistryBase):
    """
    This class allows model loaders to register their type, base and format.
    """
@@ -8,23 +8,19 @@ from invokeai.backend.model_manager.metadata import(
        CommercialUsage,
        LicenseRestrictions,
        HuggingFaceMetadata,
        CivitaiMetadata,
    )

    from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
    from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch

    data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
    assert isinstance(data, CivitaiMetadata)
    if data.allow_commercial_use:
        print("Commercial use of this model is allowed")
    data = HuggingFaceMetadataFetch().from_id("<REPO_ID>")
    assert isinstance(data, HuggingFaceMetadata)
"""

from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch, ModelMetadataFetchBase
from .fetch import HuggingFaceMetadataFetch, ModelMetadataFetchBase
from .metadata_base import (
    AnyModelRepoMetadata,
    AnyModelRepoMetadataValidator,
    BaseMetadata,
    CivitaiMetadata,
    HuggingFaceMetadata,
    ModelMetadataWithFiles,
    RemoteModelFile,
@@ -34,8 +30,6 @@ from .metadata_base import (
__all__ = [
    "AnyModelRepoMetadata",
    "AnyModelRepoMetadataValidator",
    "CivitaiMetadata",
    "CivitaiMetadataFetch",
    "HuggingFaceMetadata",
    "HuggingFaceMetadataFetch",
    "ModelMetadataFetchBase",
@@ -3,19 +3,14 @@ Initialization file for invokeai.backend.model_manager.metadata.fetch

Usage:

    from invokeai.backend.model_manager.metadata.fetch import (
        CivitaiMetadataFetch,
        HuggingFaceMetadataFetch,
    )
    from invokeai.backend.model_manager.metadata import CivitaiMetadata

    data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
    assert isinstance(data, CivitaiMetadata)
    if data.allow_commercial_use:
        print("Commercial use of this model is allowed")
    data = HuggingFaceMetadataFetch().from_id("<repo_id>")
    assert isinstance(data, HuggingFaceMetadata)
"""

from .civitai import CivitaiMetadataFetch
from .fetch_base import ModelMetadataFetchBase
from .huggingface import HuggingFaceMetadataFetch

__all__ = ["ModelMetadataFetchBase", "CivitaiMetadataFetch", "HuggingFaceMetadataFetch"]
__all__ = ["ModelMetadataFetchBase", "HuggingFaceMetadataFetch"]
@@ -1,188 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team

"""
This module fetches model metadata objects from the Civitai model repository.
In addition to the `from_url()` and `from_id()` methods inherited from the
`ModelMetadataFetchBase` base class, it provides `from_civitai_modelid()` and
`from_civitai_versionid()`.

Civitai has two separate ID spaces: a model ID and a version ID. The
version ID corresponds to a specific model, and is the ID accepted by
`from_id()`. The model ID corresponds to a family of related models,
such as different training checkpoints or 16 vs 32-bit versions. The
`from_civitai_modelid()` method will accept a model ID and return the
metadata from the default version within this model set. The default
version is the same as what the user sees when they click on a model's
thumbnail.

Usage:

    from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch

    fetcher = CivitaiMetadataFetch()
    metadata = fetcher.from_url("https://civitai.com/models/206883/split")
    print(metadata.trained_words)
"""

import json
import re
from pathlib import Path
from typing import Any, Optional

import requests
from pydantic import TypeAdapter, ValidationError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session

from invokeai.backend.model_manager.config import ModelRepoVariant

from ..metadata_base import (
    AnyModelRepoMetadata,
    CivitaiMetadata,
    RemoteModelFile,
    UnknownMetadataException,
)
from .fetch_base import ModelMetadataFetchBase

CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)"
CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)"

CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"


StringSetAdapter = TypeAdapter(set[str])


class CivitaiMetadataFetch(ModelMetadataFetchBase):
    """Fetch model metadata from Civitai."""

    def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
        """
        Initialize the fetcher with an optional requests.sessions.Session object.

        By providing a configurable Session object, we can support unit tests on
        this module without an internet connection.
        """
        self._requests = session or requests.Session()
        self._api_key = api_key

    def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
        """
        Given a URL to a CivitAI model or version page, return a ModelMetadata object.

        In the event that the URL points to a model page without the particular version
        indicated, the default model version is returned. Otherwise, the requested version
        is returned.
        """
        if match := re.match(CIVITAI_VERSION_PAGE_RE, str(url), re.IGNORECASE):
            model_id = match.group(1)
            version_id = match.group(2)
            return self.from_civitai_versionid(int(version_id), int(model_id))
        elif match := re.match(CIVITAI_MODEL_PAGE_RE, str(url), re.IGNORECASE):
            model_id = match.group(1)
            return self.from_civitai_modelid(int(model_id))
        elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url), re.IGNORECASE):
            version_id = match.group(1)
            return self.from_civitai_versionid(int(version_id))
        raise UnknownMetadataException(f"The url '{url}' does not match any known Civitai URL patterns")

    def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
        """
        Given a Civitai model version ID, return a ModelRepoMetadata object.

        :param id: An ID.
        :param variant: A model variant from the ModelRepoVariant enum (currently ignored)

        May raise an `UnknownMetadataException`.
        """
        return self.from_civitai_versionid(int(id))

    def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
        """
        Return metadata from the default version of the indicated model.

        May raise an `UnknownMetadataException`.
        """
        model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
        model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
        return self._from_api_response(model_json)

    def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
        try:
            version_id = version_id or api_response["modelVersions"][0]["id"]
        except TypeError as excp:
            raise UnknownMetadataException from excp

        # loop till we find the section containing the version requested
        version_sections = [x for x in api_response["modelVersions"] if x["id"] == version_id]
        if not version_sections:
            raise UnknownMetadataException(f"Version {version_id} not found in model metadata")

        version_json = version_sections[0]

        # Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
        primary = [x for x in version_json["files"] if x.get("primary")]
        assert len(primary) == 1
        primary_file = primary[0]

        url = primary_file["downloadUrl"]
        if "?" not in url:  # work around apparent bug in civitai api
            metadata_string = ""
            for key, value in primary_file["metadata"].items():
                if not value:
                    continue
                metadata_string += f"&{key}={value}"
            url = url + f"?type={primary_file['type']}{metadata_string}"
        model_files = [
            RemoteModelFile(
                url=self._get_url_with_api_key(url),
                path=Path(primary_file["name"]),
                size=int(primary_file["sizeKB"] * 1024),
                sha256=primary_file["hashes"]["SHA256"],
            )
        ]

        try:
            trigger_phrases = StringSetAdapter.validate_python(version_json.get("trainedWords"))
        except ValidationError:
            trigger_phrases: set[str] = set()

        return CivitaiMetadata(
            name=version_json["name"],
            files=model_files,
            trigger_phrases=trigger_phrases,
            api_response=json.dumps(version_json),
        )

    def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
        """
        Return a CivitaiMetadata object given a model version id.

        May raise an `UnknownMetadataException`.
        """
        if model_id is None:
            version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
            version = self._requests.get(self._get_url_with_api_key(version_url)).json()
            if error := version.get("error"):
                raise UnknownMetadataException(error)
            model_id = version["modelId"]

        model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
        model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
        return self._from_api_response(model_json, version_id)

    @classmethod
    def from_json(cls, json: str) -> CivitaiMetadata:
        """Given the JSON representation of the metadata, return the corresponding Pydantic object."""
        metadata = CivitaiMetadata.model_validate_json(json)
        return metadata

    def _get_url_with_api_key(self, url: str) -> str:
        if not self._api_key:
            return url

        if "?" in url:
            return f"{url}&token={self._api_key}"

        return f"{url}?token={self._api_key}"
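For reference while this fetcher existed, the three regexes gave `from_url()` its dispatch order; the version-page pattern must be tried before the bare model-page pattern, since the latter is a prefix of the former. A small sketch (the model ID comes from the docstring above; the version ID is made up):

```python
import re

assert re.match(CIVITAI_VERSION_PAGE_RE, "https://civitai.com/models/206883?modelVersionId=233000", re.IGNORECASE)
assert re.match(CIVITAI_MODEL_PAGE_RE, "https://civitai.com/models/206883/split", re.IGNORECASE)
assert re.match(CIVITAI_DOWNLOAD_RE, "https://civitai.com/api/download/models/233000", re.IGNORECASE)
```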
@@ -5,11 +5,10 @@ This module is the base class for subclasses that fetch metadata from model repo
Usage:

    from invokeai.backend.model_manager.metadata.fetch import CivitAIMetadataFetch
    from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch

    fetcher = CivitaiMetadataFetch()
    metadata = fetcher.from_url("https://civitai.com/models/206883/split")
    print(metadata.trained_words)
    data = HuggingFaceMetadataFetch().from_id("<REPO_ID>")
    assert isinstance(data, HuggingFaceMetadata)
"""

from abc import ABC, abstractmethod
@@ -90,8 +90,35 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
            )
        )

        # diffusers models have a `model_index.json` or `config.json` file
        is_diffusers = any(str(f.url).endswith(("model_index.json", "config.json")) for f in files)

        # These URLs will be exposed to the user - I think these are the only file types we fully support
        ckpt_urls = (
            None
            if is_diffusers
            else [
                f.url
                for f in files
                if str(f.url).endswith(
                    (
                        ".safetensors",
                        ".bin",
                        ".pth",
                        ".pt",
                        ".ckpt",
                    )
                )
            ]
        )

        return HuggingFaceMetadata(
            id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
            id=model_info.id,
            name=name,
            files=files,
            api_response=json.dumps(model_info.__dict__, default=str),
            is_diffusers=is_diffusers,
            ckpt_urls=ckpt_urls,
        )

    def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
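The classification rule is purely name-based. A standalone sketch of the same logic (file names are illustrative):

```python
# A repo is "diffusers" if it ships a model_index.json or any config.json;
# otherwise, surface the checkpoint-style files to the user.
names = ["model.safetensors", "README.md", "vae/config.json"]
is_diffusers = any(n.endswith(("model_index.json", "config.json")) for n in names)
ckpt_like = None if is_diffusers else [n for n in names if n.endswith((".safetensors", ".bin", ".pth", ".pt", ".ckpt"))]
assert is_diffusers and ckpt_like is None
```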
@@ -78,20 +78,16 @@ class ModelMetadataWithFiles(ModelMetadataBase):
        return self.files


class CivitaiMetadata(ModelMetadataWithFiles):
    """Extended metadata fields provided by Civitai."""

    type: Literal["civitai"] = "civitai"
    trigger_phrases: set[str] = Field(description="Trigger phrases extracted from the API response")
    api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None)


class HuggingFaceMetadata(ModelMetadataWithFiles):
    """Extended metadata fields provided by HuggingFace."""

    type: Literal["huggingface"] = "huggingface"
    id: str = Field(description="The HF model id")
    api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
    is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model", default=False)
    ckpt_urls: Optional[List[AnyHttpUrl]] = Field(
        description="URLs for all checkpoint format models in the metadata", default=None
    )

    def download_urls(
        self,
@@ -130,5 +126,5 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
        return [x for x in self.files if x.path in paths]


AnyModelRepoMetadata = Annotated[Union[BaseMetadata, HuggingFaceMetadata, CivitaiMetadata], Field(discriminator="type")]
AnyModelRepoMetadata = Annotated[Union[BaseMetadata, HuggingFaceMetadata], Field(discriminator="type")]
AnyModelRepoMetadataValidator = TypeAdapter(AnyModelRepoMetadata)
@@ -9,12 +9,15 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.util.util import SilenceWarnings

from .config import (
    AnyModelConfig,
    BaseModelType,
    ControlAdapterDefaultSettings,
    InvalidModelConfigException,
    MainModelDefaultSettings,
    ModelConfigFactory,
    ModelFormat,
    ModelRepoVariant,
@@ -23,7 +26,6 @@ from .config import (
    ModelVariantType,
    SchedulerPredictionType,
)
from .hash import ModelHash
from .util.model_util import lora_token_vector_length, read_checkpoint_meta

CkptType = Dict[str, Any]
@@ -84,9 +86,6 @@ class ProbeBase(object):


class ModelProbe(object):

    hasher = ModelHash()

    PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
        "diffusers": {},
        "checkpoint": {},
@@ -115,9 +114,7 @@ class ModelProbe(object):

    @classmethod
    def probe(
        cls,
        model_path: Path,
        fields: Optional[Dict[str, Any]] = None,
        cls, model_path: Path, fields: Optional[Dict[str, Any]] = None, hash_algo: HASHING_ALGORITHMS = "blake3"
    ) -> AnyModelConfig:
        """
        Probe the model at model_path and return its configuration record.
@@ -131,6 +128,8 @@ class ModelProbe(object):
        if fields is None:
            fields = {}

        model_path = model_path.resolve()

        format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
        model_info = None
        model_type = None
@@ -160,7 +159,15 @@ class ModelProbe(object):
            fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
        )
        fields["format"] = fields.get("format") or probe.get_format()
        fields["hash"] = fields.get("hash") or cls.hasher.hash(model_path)
        fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)

        fields["default_settings"] = fields.get("default_settings")

        if not fields["default_settings"]:
            if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter}:
                fields["default_settings"] = get_default_settings_controlnet_t2i_adapter(fields["name"])
            elif fields["type"] is ModelType.Main:
                fields["default_settings"] = get_default_settings_main(fields["base"])

        if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
            fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
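With the new `hash_algo` parameter, callers choose the hashing algorithm per probe instead of relying on the class-level hasher. A hedged call sketch (the path is hypothetical, and `"sha256"` is assumed to be one of the `HASHING_ALGORITHMS` literals):

```python
from pathlib import Path

config = ModelProbe.probe(Path("/models/some-model.safetensors"), hash_algo="sha256")
print(config.hash)  # per _get_prefix, hashes carry an algorithm prefix, e.g. "sha256:..."
```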
@@ -332,6 +339,43 @@ class ModelProbe(object):
            raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")


# Probing utilities
MODEL_NAME_TO_PREPROCESSOR = {
    "canny": "canny_image_processor",
    "mlsd": "mlsd_image_processor",
    "depth": "depth_anything_image_processor",
    "bae": "normalbae_image_processor",
    "normal": "normalbae_image_processor",
    "sketch": "pidi_image_processor",
    "scribble": "lineart_image_processor",
    "lineart": "lineart_image_processor",
    "lineart_anime": "lineart_anime_image_processor",
    "softedge": "hed_image_processor",
    "shuffle": "content_shuffle_image_processor",
    "pose": "dw_openpose_image_processor",
    "mediapipe": "mediapipe_face_processor",
    "pidi": "pidi_image_processor",
    "zoe": "zoe_depth_image_processor",
    "color": "color_map_image_processor",
}


def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
    for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
        if k in model_name:
            return ControlAdapterDefaultSettings(preprocessor=v)
    return None


def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
    if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
        return MainModelDefaultSettings(width=512, height=512)
    elif model_base is BaseModelType.StableDiffusionXL:
        return MainModelDefaultSettings(width=1024, height=1024)
    # We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models.
    return None


# ##################################################3
# Checkpoint probing
# ##################################################3
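The lookup is a simple substring scan over the model name, as below (model names are illustrative):

```python
settings = get_default_settings_controlnet_t2i_adapter("controlnet-canny-sdxl-1.0")
assert settings is not None and settings.preprocessor == "canny_image_processor"
assert get_default_settings_controlnet_t2i_adapter("some-unrelated-model") is None
```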
@@ -4,121 +4,75 @@ Abstract base class and implementation for recursive directory search for models
Example usage:
```
from invokeai.backend.model_manager import ModelSearch, ModelProbe
from invokeai.backend.model_manager import ModelSearch, ModelProbe

def find_main_models(model: Path) -> bool:
    info = ModelProbe.probe(model)
    if info.model_type == 'main' and info.base_type == 'sd-1':
        return True
    else:
        return False
def find_main_models(model: Path) -> bool:
    info = ModelProbe.probe(model)
    if info.model_type == 'main' and info.base_type == 'sd-1':
        return True
    else:
        return False

search = ModelSearch(on_model_found=report_it)
found = search.search('/tmp/models')
print(found)  # list of matching model paths
print(search.stats)  # search stats
search = ModelSearch(on_model_found=report_it)
found = search.search('/tmp/models')
print(found)  # list of matching model paths
print(search.stats)  # search stats
```
"""

import os
from abc import ABC, abstractmethod
from logging import Logger
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional, Set, Union
from typing import Callable, Optional

from pydantic import BaseModel, Field

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger

default_logger: Logger = InvokeAILogger.get_logger()

@dataclass
class SearchStats:
    """Statistics about the search.

class SearchStats(BaseModel):
    items_scanned: int = 0
    models_found: int = 0
    models_filtered: int = 0


class ModelSearchBase(ABC, BaseModel):
    Attributes:
        items_scanned: number of items scanned
        models_found: number of models found
        models_filtered: number of models that passed the filter
    """
    Abstract directory traversal model search class

    items_scanned = 0
    models_found = 0
    models_filtered = 0


class ModelSearch:
    """Searches a directory tree for models, using a callback to filter the results.

    Usage:
        search = ModelSearchBase(
            on_search_started = search_started_callback,
            on_search_completed = search_completed_callback,
            on_model_found = model_found_callback,
        )
        models_found = search.search('/path/to/directory')
        search = ModelSearch()
        search.model_found = lambda path : 'anime' in path.as_posix()
        found = search.list_models(['/tmp/models1','/tmp/models2'])
        # returns all models that have 'anime' in the path
    """

    # fmt: off
    on_search_started   : Optional[Callable[[Path], None]]      = Field(default=None, description="Called just before the search starts.")  # noqa E221
    on_model_found      : Optional[Callable[[Path], bool]]      = Field(default=None, description="Called when a model is found.")  # noqa E221
    on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.")  # noqa E221
    stats               : SearchStats                           = Field(default_factory=SearchStats, description="Summary statistics after search")  # noqa E221
    logger              : Logger                                = Field(default=default_logger, description="Logger instance.")  # noqa E221
    # fmt: on
    def __init__(
        self,
        on_search_started: Optional[Callable[[Path], None]] = None,
        on_model_found: Optional[Callable[[Path], bool]] = None,
        on_search_completed: Optional[Callable[[set[Path]], None]] = None,
    ) -> None:
        """Create a new ModelSearch object.

    class Config:
        arbitrary_types_allowed = True

    @abstractmethod
    def search_started(self) -> None:
        Args:
            on_search_started: callback to be invoked when the search starts
            on_model_found: callback to be invoked when a model is found. The callback should return True if the model
                should be included in the results.
            on_search_completed: callback to be invoked when the search is completed
        """
        Called before the scan starts.

        Passes the root search directory to the Callable `on_search_started`.
        """
        pass

    @abstractmethod
    def model_found(self, model: Path) -> None:
        """
        Called when a model is found during search.

        :param model: Model to process - could be a directory or checkpoint.

        Passes the model's Path to the Callable `on_model_found`.
        This Callable receives the path to the model and returns a boolean
        to indicate whether the model should be returned in the search
        results.
        """
        pass

    @abstractmethod
    def search_completed(self) -> None:
        """
        Called when the scan completes.

        Passes the Set of found model Paths to the Callable `on_search_completed`.
        """
        pass

    @abstractmethod
    def search(self, directory: Union[Path, str]) -> Set[Path]:
        """
        Recursively search for models in `directory` and return a set of model paths.

        If provided, the `on_search_started`, `on_model_found` and `on_search_completed`
        Callables will be invoked during the search.
        """
        pass


class ModelSearch(ModelSearchBase):
    """
    Implementation of ModelSearch with callbacks.
    Usage:
        search = ModelSearch()
        search.model_found = lambda path : 'anime' in path.as_posix()
        found = search.list_models(['/tmp/models1','/tmp/models2'])
        # returns all models that have 'anime' in the path
    """

    models_found: Set[Path] = Field(default_factory=set)
    config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
        self.stats = SearchStats()
        self.logger = InvokeAILogger.get_logger()
        self.on_search_started = on_search_started
        self.on_model_found = on_model_found
        self.on_search_completed = on_search_completed
        self.models_found: set[Path] = set()

    def search_started(self) -> None:
        self.models_found = set()
@@ -135,17 +89,17 @@ class ModelSearch(ModelSearchBase):
        if self.on_search_completed is not None:
            self.on_search_completed(self.models_found)

    def search(self, directory: Union[Path, str]) -> Set[Path]:
    def search(self, directory: Path) -> set[Path]:
        self._directory = Path(directory)
        if not self._directory.is_absolute():
            self._directory = self.config.models_path / self._directory
        self._directory = self._directory.resolve()
        self.stats = SearchStats()  # zero out
        self.search_started()  # This will initialize _models_found to empty
        self._walk_directory(self._directory)
        self.search_completed()
        return self.models_found

    def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None:
    def _walk_directory(self, path: Path, max_depth: int = 20) -> None:
        """Recursively walk the directory tree, looking for models."""
        absolute_path = Path(path)
        if (
            len(absolute_path.parts) - len(self._directory.parts) > max_depth
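Under the rewritten class, filtering goes through the `on_model_found` constructor callback instead of subclassing. A usage sketch (the paths and filter body are illustrative):

```python
from pathlib import Path

# Include only models whose path mentions "sd-1"; on_model_found's return
# value decides whether a found model lands in the result set.
search = ModelSearch(on_model_found=lambda p: "sd-1" in p.as_posix())
found = search.search(Path("/tmp/models"))
print(len(found), search.stats.models_filtered)
```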
@@ -455,15 +455,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                ip_adapter_unet_patcher=ip_adapter_unet_patcher,
            )
            latents = step_output.prev_sample

            latents = self.invokeai_diffuser.do_latent_postprocessing(
                postprocessing_settings=conditioning_data.postprocessing_settings,
                latents=latents,
                sigma=batched_t,
                step_index=i,
                total_step_count=len(timesteps),
            )

            predicted_original = getattr(step_output, "pred_original_sample", None)

            if callback is not None:
@@ -44,14 +44,6 @@ class SDXLConditioningInfo(BasicConditioningInfo):
        return super().to(device=device, dtype=dtype)


@dataclass(frozen=True)
class PostprocessingSettings:
    threshold: float
    warmup: float
    h_symmetry_time_pct: Optional[float]
    v_symmetry_time_pct: Optional[float]


@dataclass
class IPAdapterConditioningInfo:
    cond_image_prompt_embeds: torch.Tensor
@@ -80,10 +72,6 @@ class ConditioningData:
    """
    guidance_rescale_multiplier: float = 0
    scheduler_args: dict[str, Any] = field(default_factory=dict)
    """
    Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
    """
    postprocessing_settings: Optional[PostprocessingSettings] = None

    ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
@@ -12,7 +12,6 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    ConditioningData,
    ExtraConditioningInfo,
    PostprocessingSettings,
    SDXLConditioningInfo,
)
@@ -244,19 +243,6 @@ class InvokeAIDiffuserComponent:

        return unconditioned_next_x, conditioned_next_x

    def do_latent_postprocessing(
        self,
        postprocessing_settings: PostprocessingSettings,
        latents: torch.Tensor,
        sigma,
        step_index,
        total_step_count,
    ) -> torch.Tensor:
        if postprocessing_settings is not None:
            percent_through = step_index / total_step_count
            latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
        return latents

    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
        def _pad_conditioning(cond, target_len, encoder_attention_mask):
            conditioning_attention_mask = torch.ones(
@@ -506,64 +492,3 @@ class InvokeAIDiffuserComponent:
        scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
        combined_next_x = unconditioned_next_x + scaled_delta
        return combined_next_x

    def apply_symmetry(
        self,
        postprocessing_settings: PostprocessingSettings,
        latents: torch.Tensor,
        percent_through: float,
    ) -> torch.Tensor:
        # Reset our last percent through if this is our first step.
        if percent_through == 0.0:
            self.last_percent_through = 0.0

        if postprocessing_settings is None:
            return latents

        # Check for out of bounds
        h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
        if h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0):
            h_symmetry_time_pct = None

        v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
        if v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0):
            v_symmetry_time_pct = None

        dev = latents.device.type

        latents.to(device="cpu")

        if (
            h_symmetry_time_pct is not None
            and self.last_percent_through < h_symmetry_time_pct
            and percent_through >= h_symmetry_time_pct
        ):
            # Horizontal symmetry occurs on the 3rd dimension of the latent
            width = latents.shape[3]
            x_flipped = torch.flip(latents, dims=[3])
            latents = torch.cat(
                [
                    latents[:, :, :, 0 : int(width / 2)],
                    x_flipped[:, :, :, int(width / 2) : int(width)],
                ],
                dim=3,
            )

        if (
            v_symmetry_time_pct is not None
            and self.last_percent_through < v_symmetry_time_pct
            and percent_through >= v_symmetry_time_pct
        ):
            # Vertical symmetry occurs on the 2nd dimension of the latent
            height = latents.shape[2]
            y_flipped = torch.flip(latents, dims=[2])
            latents = torch.cat(
                [
                    latents[:, :, 0 : int(height / 2)],
                    y_flipped[:, :, int(height / 2) : int(height)],
                ],
                dim=2,
            )

        self.last_percent_through = percent_through
        return latents.to(device=dev)
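The removed `apply_symmetry` mirrored one half of the latent onto the other with a flip-and-concat. A tiny self-contained demo of the horizontal case (tensor values are illustrative):

```python
import torch

latents = torch.arange(16.0).reshape(1, 1, 2, 8)  # (batch, channels, height, width)
width = latents.shape[3]
x_flipped = torch.flip(latents, dims=[3])
# Keep the left half, replace the right half with the mirrored left half.
symmetric = torch.cat(
    [latents[:, :, :, : width // 2], x_flipped[:, :, :, width // 2 : width]],
    dim=3,
)
# The result is mirror-symmetric about the vertical centerline.
assert torch.equal(symmetric[..., : width // 2], torch.flip(symmetric[..., width // 2 :], dims=[3]))
```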
@@ -5,6 +5,7 @@ from typing import Callable, List, Union

import torch.nn as nn
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
@@ -26,7 +27,7 @@ def _conv_forward_asymmetric(self, input, weight, bias):

@contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
    # Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
    to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
    try:
@@ -608,8 +608,9 @@ def main() -> None:
    config.parse_args(invoke_args)
    logger = InvokeAILogger().get_logger(config=config)

    if not config.model_conf_path.exists():
    if not config.models_path.exists():
        logger.info("Your InvokeAI root directory is not set up. Calling invokeai-configure.")
        sys.argv = ["invokeai_configure", "--yes", "--skip-sd-weights"]
        from invokeai.frontend.install.invokeai_configure import invokeai_configure

        invokeai_configure()
@@ -10,7 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
  const dispatch = useAppDispatch();
  useGlobalModifiersInit();
  useEffect(() => {
    dispatch(modelChanged({ key: 'test_model', base: 'sd-1' }));
    dispatch(modelChanged({ key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' }));
  }, []);

  return props.children;
@@ -1,150 +1,3 @@
# Invoke UI

<!-- @import "[TOC]" {cmd="toc" depthFrom=2 depthTo=3 orderedList=false} -->

<!-- code_chunk_output -->

- [Dev environment](#dev-environment)
  - [Setup](#setup)
  - [Package scripts](#package-scripts)
  - [Type generation](#type-generation)
  - [Localization](#localization)
  - [VSCode](#vscode)
- [Contributing](#contributing)
  - [Check in before investing your time](#check-in-before-investing-your-time)
  - [Commit format](#commit-format)
  - [Submitting a PR](#submitting-a-pr)
- [Other docs](#other-docs)

<!-- /code_chunk_output -->

Invoke's UI is made possible by many contributors and open-source libraries. Thank you!

## Dev environment

### Setup

1. Install [node] and [pnpm].
1. Run `pnpm i` to install all packages.

#### Run in dev mode

1. From `invokeai/frontend/web/`, run `pnpm dev`.
1. From repo root, run `python scripts/invokeai-web.py`.
1. Point your browser to the dev server address, e.g. <http://localhost:5173/>

### Package scripts

- `dev`: run the frontend in dev mode, enabling hot reloading
- `build`: run all checks (madge, eslint, prettier, tsc) and then build the frontend
- `typegen`: generate types from the OpenAPI schema (see [Type generation])
- `lint:madge`: check frontend for circular dependencies
- `lint:eslint`: check frontend for code quality
- `lint:prettier`: check frontend for code formatting
- `lint:tsc`: check frontend for type issues
- `lint`: run all checks concurrently
- `fix`: run `eslint` and `prettier`, fixing fixable issues

### Type generation

We use [openapi-typescript] to generate types from the app's OpenAPI schema.

The generated types are committed to the repo in [schema.ts].

```sh
# from the repo root, start the server
python scripts/invokeai-web.py
# from invokeai/frontend/web/, run the script
pnpm typegen
```

### Localization

We use [i18next] for localization, but translation to languages other than English happens on our [Weblate] project.

Only the English source strings should be changed on this repo.

### VSCode

#### Example debugger config

```jsonc
{
  "version": "0.2.0",
  "configurations": [
    {
      "type": "chrome",
      "request": "launch",
      "name": "Invoke UI",
      "url": "http://localhost:5173",
      "webRoot": "${workspaceFolder}/invokeai/frontend/web",
    },
  ],
}
```

#### Remote dev

We've noticed an intermittent timeout issue with the VSCode remote dev port forwarding.

We suggest disabling the editor's port forwarding feature and doing it manually via SSH:

```sh
ssh -L 9090:localhost:9090 -L 5173:localhost:5173 user@host
```

## Contributing

Thanks for your interest in contributing to the Invoke Web UI!

Please follow these guidelines when contributing.

### Check in before investing your time

Please check in before you invest your time on anything besides a trivial fix, in case it conflicts with ongoing work or isn't aligned with the vision for the app.

If a feature request or issue doesn't already exist for the thing you want to work on, please create one.

Ping `@psychedelicious` on [discord] in the `#frontend-dev` channel or in the feature request / issue you want to work on - we're happy to chat.

### Code conventions

- This is a fairly complex app with a deep component tree. Please use memoization (`useCallback`, `useMemo`, `memo`) with enthusiasm.
- If you need to add some global, ephemeral state, please use [nanostores] if possible.
- Be careful with your redux selectors. If they need to be parameterized, consider creating them inside a `useMemo`.
- Feel free to use `lodash` (via `lodash-es`) to make the intent of your code clear.
- Please add comments describing the "why", not the "how" (unless it is really arcane).

### Commit format

Please use the [conventional commits] spec for the web UI, with a scope of "ui":

- `chore(ui): bump deps`
- `chore(ui): lint`
- `feat(ui): add some cool new feature`
- `fix(ui): fix some bug`

### Submitting a PR

- Ensure your branch is tidy. Use an interactive rebase to clean up the commit history and reword the commit messages if they are not descriptive.
- Run `pnpm lint`. Some issues are auto-fixable with `pnpm fix`.
- Fill out the PR form when creating the PR.
  - It doesn't need to be super detailed, but a screenshot or video is nice if you changed something visually.
  - If a section isn't relevant, delete it. There are no UI tests at this time.

## Other docs

- [Workflows - Design and Implementation]
- [State Management]

[node]: https://nodejs.org/en/download/
[pnpm]: https://github.com/pnpm/pnpm
[discord]: https://discord.gg/ZmtBAhwWhy
[i18next]: https://github.com/i18next/react-i18next
[Weblate]: https://hosted.weblate.org/engage/invokeai/
[openapi-typescript]: https://github.com/drwpow/openapi-typescript
[Type generation]: #type-generation
[schema.ts]: ../src/services/api/schema.ts
[conventional commits]: https://www.conventionalcommits.org/en/v1.0.0/
[Workflows - Design and Implementation]: ./docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md
[State Management]: ./docs/STATE_MGMT.md

<https://invoke-ai.github.io/InvokeAI/contributing/frontend/OVERVIEW/>
File diff suppressed because it is too large.
invokeai/frontend/web/scripts/clean_translations.py (new file, 88 lines)
@@ -0,0 +1,88 @@
# Cleans translations by removing unused keys
# Usage: python clean_translations.py
# Note: Must be run from invokeai/frontend/web/scripts directory
#
# After running the script, open `en.json` and check for empty objects (`{}`) and remove them manually.

import json
import os
import re
from typing import TypeAlias, Union

from tqdm import tqdm

RecursiveDict: TypeAlias = dict[str, Union["RecursiveDict", str]]


class TranslationCleaner:
    file_cache: dict[str, str] = {}

    def _get_keys(self, obj: RecursiveDict, current_path: str = "", keys: list[str] | None = None):
        if keys is None:
            keys = []
        for key in obj:
            new_path = f"{current_path}.{key}" if current_path else key
            next_ = obj[key]
            if isinstance(next_, dict):
                self._get_keys(next_, new_path, keys)
            elif "_" in key:
                # This typically means it's a pluralized key
                continue
            else:
                keys.append(new_path)
        return keys

    def _search_codebase(self, key: str):
        for root, _dirs, files in os.walk("../src"):
            for file in files:
                if file.endswith(".ts") or file.endswith(".tsx"):
                    full_path = os.path.join(root, file)
                    if full_path in self.file_cache:
                        content = self.file_cache[full_path]
                    else:
                        with open(full_path, "r") as f:
                            content = f.read()
                            self.file_cache[full_path] = content

                    # match the whole key, surrounded by quotes
                    if re.search(r"['\"`]" + re.escape(key) + r"['\"`]", self.file_cache[full_path]):
                        return True
                    # match the stem of the key, with quotes at the end
                    if re.search(re.escape(key.split(".")[-1]) + r"['\"`]", self.file_cache[full_path]):
                        return True
        return False

    def _remove_key(self, obj: RecursiveDict, key: str):
        path = key.split(".")
        last_key = path[-1]
        for k in path[:-1]:
            obj = obj[k]
        del obj[last_key]

    def clean(self, obj: RecursiveDict) -> RecursiveDict:
        keys = self._get_keys(obj)
        pbar = tqdm(keys, desc="Checking keys")
        for key in pbar:
            if not self._search_codebase(key):
                self._remove_key(obj, key)
        return obj


def main():
    try:
        with open("../public/locales/en.json", "r") as f:
            data = json.load(f)
    except FileNotFoundError as e:
        raise FileNotFoundError(
            "Unable to find en.json file - must be run from invokeai/frontend/web/scripts directory"
        ) from e

    cleaner = TranslationCleaner()
    cleaned_data = cleaner.clean(data)

    with open("../public/locales/en.json", "w") as f:
        json.dump(cleaned_data, f, indent=4)


if __name__ == "__main__":
    main()
@@ -1,10 +1,10 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import type { JSONObject } from 'common/types';
import {
  controlAdapterModelCleared,
  selectAllControlNets,
  selectAllIPAdapters,
  selectAllT2IAdapters,
  selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
@@ -12,212 +12,162 @@ import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models';
import type { TypeGuardFor } from 'services/api/types';
import { forEach } from 'lodash-es';
import type { Logger } from 'roarr';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import { isNonRefinerMainModelConfig, isRefinerMainModelModelConfig, isVAEModelConfig } from 'services/api/types';

export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
  startAppListening({
    predicate: (action): action is TypeGuardFor<typeof modelsApi.endpoints.getMainModels.matchFulfilled> =>
      modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
      !action.meta.arg.originalArgs.includes('sdxl-refiner'),
    predicate: modelsApi.endpoints.getModelConfigs.matchFulfilled,
    effect: async (action, { getState, dispatch }) => {
      // models loaded, we need to ensure the selected model is available and if not, select the first one
      const log = logger('models');
      log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`);
      log.info({ models: action.payload.entities }, `Models loaded (${action.payload.ids.length})`);

      const state = getState();

      const currentModel = state.generation.model;
      const models = mainModelsAdapterSelectors.selectAll(action.payload);
      const models = modelConfigsAdapterSelectors.selectAll(action.payload);

      if (models.length === 0) {
        // No models loaded at all
        dispatch(modelChanged(null));
        return;
      }

      const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;

      if (isCurrentModelAvailable) {
        return;
      }

      const defaultModel = state.config.sd.defaultModel;
      const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;

      if (defaultModelInList) {
        const result = zParameterModel.safeParse(defaultModelInList);
        if (result.success) {
          dispatch(modelChanged(defaultModelInList, currentModel));

          const optimalDimension = getOptimalDimension(defaultModelInList);
          if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
            return;
          }
          const { width, height } = calculateNewSize(
            state.generation.aspectRatio.value,
            optimalDimension * optimalDimension
          );

          dispatch(widthChanged(width));
          dispatch(heightChanged(height));
          return;
        }
      }

      const result = zParameterModel.safeParse(models[0]);

      if (!result.success) {
        log.error({ error: result.error.format() }, 'Failed to parse main model');
        return;
      }

      dispatch(modelChanged(result.data, currentModel));
    },
  });
  startAppListening({
    predicate: (action): action is TypeGuardFor<typeof modelsApi.endpoints.getMainModels.matchFulfilled> =>
      modelsApi.endpoints.getMainModels.matchFulfilled(action) && action.meta.arg.originalArgs.includes('sdxl-refiner'),
    effect: async (action, { getState, dispatch }) => {
      // models loaded, we need to ensure the selected model is available and if not, select the first one
      const log = logger('models');
      log.info({ models: action.payload.entities }, `SDXL Refiner models loaded (${action.payload.ids.length})`);

      const currentModel = getState().sdxl.refinerModel;
      const models = mainModelsAdapterSelectors.selectAll(action.payload);

      if (models.length === 0) {
        // No models loaded at all
        dispatch(refinerModelChanged(null));
        return;
      }

      const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;

      if (!isCurrentModelAvailable) {
        dispatch(refinerModelChanged(null));
        return;
      }
    },
  });
  startAppListening({
    matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
    effect: async (action, { getState, dispatch }) => {
      // VAEs loaded, need to reset the VAE if it's no longer available
      const log = logger('models');
      log.info({ models: action.payload.entities }, `VAEs loaded (${action.payload.ids.length})`);

      const currentVae = getState().generation.vae;

      if (currentVae === null) {
        // null is a valid VAE! it means "use the default with the main model"
        return;
      }

      const isCurrentVAEAvailable = some(action.payload.entities, (m) => m?.key === currentVae?.key);

      if (isCurrentVAEAvailable) {
        return;
      }

      const firstModel = vaeModelsAdapterSelectors.selectAll(action.payload)[0];

      if (!firstModel) {
        // No custom VAEs loaded at all; use the default
        dispatch(vaeSelected(null));
        return;
      }

      const result = zParameterVAEModel.safeParse(firstModel);

      if (!result.success) {
        log.error({ error: result.error.format() }, 'Failed to parse VAE model');
        return;
      }

      dispatch(vaeSelected(result.data));
    },
  });
  startAppListening({
    matcher: modelsApi.endpoints.getLoRAModels.matchFulfilled,
    effect: async (action, { getState, dispatch }) => {
      // LoRA models loaded - need to remove missing LoRAs from state
      const log = logger('models');
      log.info({ models: action.payload.entities }, `LoRAs loaded (${action.payload.ids.length})`);

      const loras = getState().lora.loras;
||||
|
||||
forEach(loras, (lora, id) => {
|
||||
const isLoRAAvailable = some(action.payload.entities, (m) => m?.key === lora?.model.key);
|
||||
|
||||
if (isLoRAAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(loraRemoved(id));
|
||||
});
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// ControlNet models loaded - need to remove missing ControlNets from state
|
||||
const log = logger('models');
|
||||
log.info({ models: action.payload.entities }, `ControlNet models loaded (${action.payload.ids.length})`);
|
||||
|
||||
selectAllControlNets(getState().controlAdapters).forEach((ca) => {
|
||||
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
|
||||
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
||||
});
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getT2IAdapterModels.matchFulfilled,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// ControlNet models loaded - need to remove missing ControlNets from state
|
||||
const log = logger('models');
|
||||
log.info({ models: action.payload.entities }, `T2I Adapter models loaded (${action.payload.ids.length})`);
|
||||
|
||||
selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => {
|
||||
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
|
||||
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
||||
});
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getIPAdapterModels.matchFulfilled,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// ControlNet models loaded - need to remove missing ControlNets from state
|
||||
const log = logger('models');
|
||||
log.info({ models: action.payload.entities }, `IP Adapter models loaded (${action.payload.ids.length})`);
|
||||
|
||||
selectAllIPAdapters(getState().controlAdapters).forEach((ca) => {
|
||||
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
|
||||
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
||||
});
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getTextualInversionModels.matchFulfilled,
|
||||
effect: async (action) => {
|
||||
const log = logger('models');
|
||||
log.info({ models: action.payload.entities }, `Embeddings loaded (${action.payload.ids.length})`);
|
||||
handleMainModels(models, state, dispatch, log);
|
||||
handleRefinerModels(models, state, dispatch, log);
|
||||
handleVAEModels(models, state, dispatch, log);
|
||||
handleLoRAModels(models, state, dispatch, log);
|
||||
handleControlAdapterModels(models, state, dispatch, log);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
type ModelHandler = (
|
||||
models: AnyModelConfig[],
|
||||
state: RootState,
|
||||
dispatch: AppDispatch,
|
||||
log: Logger<JSONObject>
|
||||
) => undefined;
|
||||
|
||||
const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const currentModel = state.generation.model;
|
||||
const mainModels = models.filter(isNonRefinerMainModelConfig);
|
||||
if (mainModels.length === 0) {
|
||||
// No models loaded at all
|
||||
dispatch(modelChanged(null));
|
||||
return;
|
||||
}
|
||||
|
||||
const isCurrentMainModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
|
||||
|
||||
if (isCurrentMainModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
const defaultModel = state.config.sd.defaultModel;
|
||||
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
|
||||
|
||||
if (defaultModelInList) {
|
||||
const result = zParameterModel.safeParse(defaultModelInList);
|
||||
if (result.success) {
|
||||
dispatch(modelChanged(defaultModelInList, currentModel));
|
||||
|
||||
const optimalDimension = getOptimalDimension(defaultModelInList);
|
||||
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
|
||||
return;
|
||||
}
|
||||
const { width, height } = calculateNewSize(
|
||||
state.generation.aspectRatio.value,
|
||||
optimalDimension * optimalDimension
|
||||
);
|
||||
|
||||
dispatch(widthChanged(width));
|
||||
dispatch(heightChanged(height));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const result = zParameterModel.safeParse(models[0]);
|
||||
|
||||
if (!result.success) {
|
||||
log.error({ error: result.error.format() }, 'Failed to parse main model');
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(modelChanged(result.data, currentModel));
|
||||
};
|
||||
|
||||
const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||
const currentRefinerModel = state.sdxl.refinerModel;
|
||||
const refinerModels = models.filter(isRefinerMainModelModelConfig);
|
||||
if (models.length === 0) {
|
||||
// No models loaded at all
|
||||
dispatch(refinerModelChanged(null));
|
||||
return;
|
||||
}
|
||||
|
||||
const isCurrentRefinerModelAvailable = currentRefinerModel
|
||||
? refinerModels.some((m) => m.key === currentRefinerModel.key)
|
||||
: false;
|
||||
|
||||
if (!isCurrentRefinerModelAvailable) {
|
||||
dispatch(refinerModelChanged(null));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const currentVae = state.generation.vae;
|
||||
|
||||
if (currentVae === null) {
|
||||
// null is a valid VAE! it means "use the default with the main model"
|
||||
return;
|
||||
}
|
||||
const vaeModels = models.filter(isVAEModelConfig);
|
||||
|
||||
const isCurrentVAEAvailable = vaeModels.some((m) => m.key === currentVae.key);
|
||||
|
||||
if (isCurrentVAEAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModel = vaeModels[0];
|
||||
|
||||
if (!firstModel) {
|
||||
// No custom VAEs loaded at all; use the default
|
||||
dispatch(vaeSelected(null));
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zParameterVAEModel.safeParse(firstModel);
|
||||
|
||||
if (!result.success) {
|
||||
log.error({ error: result.error.format() }, 'Failed to parse VAE model');
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(vaeSelected(result.data));
|
||||
};
|
||||
|
||||
const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||
const loras = state.lora.loras;
|
||||
|
||||
forEach(loras, (lora, id) => {
|
||||
const isLoRAAvailable = models.some((m) => m.key === lora.model.key);
|
||||
|
||||
if (isLoRAAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(loraRemoved(id));
|
||||
});
|
||||
};
|
||||
|
||||
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
|
||||
const isModelAvailable = models.some((m) => m.key === ca.model?.key);
|
||||
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,26 +1,30 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setDefaultSettings } from 'features/parameters/store/actions';
import {
  heightChanged,
  setCfgRescaleMultiplier,
  setCfgScale,
  setScheduler,
  setSteps,
  vaePrecisionChanged,
  vaeSelected,
  widthChanged,
} from 'features/parameters/store/generationSlice';
import {
  isParameterCFGRescaleMultiplier,
  isParameterCFGScale,
  isParameterHeight,
  isParameterPrecision,
  isParameterScheduler,
  isParameterSteps,
  isParameterWidth,
  zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { map } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';

export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
  startAppListening({
@@ -34,63 +38,80 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
        return;
      }

      const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
      const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate());
      const data = await request.unwrap();
      request.unsubscribe();
      const models = modelConfigsAdapterSelectors.selectAll(data);

      if (!modelConfig || !modelConfig.default_settings) {
      const modelConfig = models.find((model) => model.key === currentModel.key);

      if (!modelConfig) {
        return;
      }

      const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = modelConfig.default_settings;
      if (isNonRefinerMainModelConfig(modelConfig) && modelConfig.default_settings) {
        const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler, width, height } =
          modelConfig.default_settings;

      if (vae) {
        // we store this as "default" within default settings
        // to distinguish it from no default set
        if (vae === 'default') {
          dispatch(vaeSelected(null));
        } else {
          const { data } = modelsApi.endpoints.getVaeModels.select()(state);
          const vaeArray = map(data?.entities);
          const validVae = vaeArray.find((model) => model.key === vae);

          const result = zParameterVAEModel.safeParse(validVae);
          if (!result.success) {
            return;
        if (vae) {
          // we store this as "default" within default settings
          // to distinguish it from no default set
          if (vae === 'default') {
            dispatch(vaeSelected(null));
          } else {
            const vaeModel = models.find((model) => model.key === vae);
            const result = zParameterVAEModel.safeParse(vaeModel);
            if (!result.success) {
              return;
            }
            dispatch(vaeSelected(result.data));
          }
          dispatch(vaeSelected(result.data));
        }
      }

      if (vae_precision) {
        if (isParameterPrecision(vae_precision)) {
          dispatch(vaePrecisionChanged(vae_precision));
        if (vae_precision) {
          if (isParameterPrecision(vae_precision)) {
            dispatch(vaePrecisionChanged(vae_precision));
          }
        }
      }

      if (cfg_scale) {
        if (isParameterCFGScale(cfg_scale)) {
          dispatch(setCfgScale(cfg_scale));
        if (cfg_scale) {
          if (isParameterCFGScale(cfg_scale)) {
            dispatch(setCfgScale(cfg_scale));
          }
        }
      }

      if (cfg_rescale_multiplier) {
        if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
          dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
        if (cfg_rescale_multiplier) {
          if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
            dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
          }
        }
      }

      if (steps) {
        if (isParameterSteps(steps)) {
          dispatch(setSteps(steps));
        if (steps) {
          if (isParameterSteps(steps)) {
            dispatch(setSteps(steps));
          }
        }
      }

      if (scheduler) {
        if (isParameterScheduler(scheduler)) {
          dispatch(setScheduler(scheduler));
        if (scheduler) {
          if (isParameterScheduler(scheduler)) {
            dispatch(setScheduler(scheduler));
          }
        }
      }

      dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
        if (width) {
          if (isParameterWidth(width)) {
            dispatch(widthChanged(width));
          }
        }

        if (height) {
          if (isParameterHeight(height)) {
            dispatch(heightChanged(height));
          }
        }

        dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
      }
    },
  });
};

@@ -4,6 +4,7 @@ import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { isEqual } from 'lodash-es';
import { atom } from 'nanostores';
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
import { socketConnected } from 'services/events/actions';

@@ -29,6 +30,11 @@ export const addSocketConnectedEventListener = (startAppListe

    // Bail on the recovery logic if this is the first connection - we don't need to recover anything
    if ($isFirstConnection.get()) {
      // Populate the model configs on first connection. This query cache has a 24hr timeout, so we can immediately
      // unsubscribe.
      const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate());
      request.unsubscribe();

      $isFirstConnection.set(false);
      return;
    }

@@ -2,6 +2,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import {
  socketModelInstallCancelled,
  socketModelInstallCompleted,
  socketModelInstallDownloading,
  socketModelInstallError,
@@ -63,4 +64,21 @@ export const addModelInstallEventListener = (startAppListenin
      );
    },
  });

  startAppListening({
    actionCreator: socketModelInstallCancelled,
    effect: (action, { dispatch }) => {
      const { id } = action.payload.data;

      dispatch(
        modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
          const modelImport = draft.find((m) => m.id === id);
          if (modelImport) {
            modelImport.status = 'cancelled';
          }
          return draft;
        })
      );
    },
  });
};

@@ -8,14 +8,16 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
  startAppListening({
    actionCreator: socketModelLoadStarted,
    effect: (action) => {
      const { base_model, model_name, model_type, submodel } = action.payload.data;
      const { model_config, submodel_type } = action.payload.data;
      const { name, base, type } = model_config;

      let message = `Model load started: ${base_model}/${model_type}/${model_name}`;

      if (submodel) {
        message = message.concat(`/${submodel}`);
      const extras: string[] = [base, type];
      if (submodel_type) {
        extras.push(submodel_type);
      }

      const message = `Model load started: ${name} (${extras.join(', ')})`;

      log.debug(action.payload, message);
    },
  });
@@ -23,14 +25,16 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
  startAppListening({
    actionCreator: socketModelLoadCompleted,
    effect: (action) => {
      const { base_model, model_name, model_type, submodel } = action.payload.data;
      const { model_config, submodel_type } = action.payload.data;
      const { name, base, type } = model_config;

      let message = `Model load complete: ${base_model}/${model_type}/${model_name}`;

      if (submodel) {
        message = message.concat(`/${submodel}`);
      const extras: string[] = [base, type];
      if (submodel_type) {
        extras.push(submodel_type);
      }

      const message = `Model load complete: ${name} (${extras.join(', ')})`;

      log.debug(action.payload, message);
    },
  });

@@ -1,16 +1,15 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { groupBy, map, reduce } from 'lodash-es';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { groupBy, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';

type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
  modelEntities: EntityState<T, string> | undefined;
  selectedModel?: ModelIdentifierWithBase | null;
  modelConfigs: T[];
  selectedModel?: ModelIdentifierField | null;
  onChange: (value: T | null) => void;
  getIsDisabled?: (model: T) => boolean;
  isLoading?: boolean;
@@ -29,13 +28,12 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
): UseGroupedModelComboboxReturn => {
  const { t } = useTranslation();
  const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
  const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg;
  const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading } = arg;
  const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
    if (!modelEntities) {
    if (!modelConfigs) {
      return [];
    }
    const modelEntitiesArray = map(modelEntities.entities);
    const groupedModels = groupBy(modelEntitiesArray, 'base');
    const groupedModels = groupBy(modelConfigs, 'base');
    const _options = reduce(
      groupedModels,
      (acc, val, label) => {
@@ -53,7 +51,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
    );
    _options.sort((a) => (a.label === base_model ? -1 : 1));
    return _options;
  }, [getIsDisabled, modelEntities, base_model]);
  }, [getIsDisabled, modelConfigs, base_model]);

  const value = useMemo(
    () =>
@@ -67,14 +65,14 @@
        onChange(null);
        return;
      }
      const model = modelEntities?.entities[v.value];
      const model = modelConfigs.find((m) => m.key === v.value);
      if (!model) {
        onChange(null);
        return;
      }
      onChange(model);
    },
    [modelEntities?.entities, onChange]
    [modelConfigs, onChange]
  );

  const placeholder = useMemo(() => {

@@ -1,14 +1,12 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { map } from 'lodash-es';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';

type UseModelComboboxArg<T extends AnyModelConfig> = {
  modelEntities: EntityState<T, string> | undefined;
  selectedModel?: ModelIdentifierWithBase | null;
  modelConfigs: T[];
  selectedModel?: ModelIdentifierField | null;
  onChange: (value: T | null) => void;
  getIsDisabled?: (model: T) => boolean;
  optionsFilter?: (model: T) => boolean;
@@ -25,19 +23,14 @@ type UseModelComboboxReturn = {

export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelComboboxArg<T>): UseModelComboboxReturn => {
  const { t } = useTranslation();
  const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
  const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
  const options = useMemo<ComboboxOption[]>(() => {
    if (!modelEntities) {
      return [];
    }
    return map(modelEntities.entities)
      .filter(optionsFilter)
      .map((model) => ({
        label: model.name,
        value: model.key,
        isDisabled: getIsDisabled ? getIsDisabled(model) : false,
      }));
  }, [optionsFilter, getIsDisabled, modelEntities]);
    return modelConfigs.filter(optionsFilter).map((model) => ({
      label: model.name,
      value: model.key,
      isDisabled: getIsDisabled ? getIsDisabled(model) : false,
    }));
  }, [optionsFilter, getIsDisabled, modelConfigs]);

  const value = useMemo(
    () => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
@@ -50,14 +43,14 @@ export const useModelCombobox = <T extends AnyModelCombobox
        onChange(null);
        return;
      }
      const model = modelEntities?.entities[v.value];
      const model = modelConfigs.find((m) => m.key === v.value);
      if (!model) {
        onChange(null);
        return;
      }
      onChange(model);
    },
    [modelEntities?.entities, onChange]
    [modelConfigs, onChange]
  );

  const placeholder = useMemo(() => {

@@ -1,17 +1,14 @@
import type { Item } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { filter } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';

type UseModelCustomSelectArg<T extends AnyModelConfig> = {
  data: EntityState<T, string> | undefined;
  modelConfigs: T[];
  isLoading: boolean;
  selectedModel?: ModelIdentifierWithBase | null;
  selectedModel?: ModelIdentifierField | null;
  onChange: (value: T | null) => void;
  modelFilter?: (model: T) => boolean;
  isModelDisabled?: (model: T) => boolean;
@@ -28,7 +25,7 @@ const modelFilterDefault = () => true;
const isModelDisabledDefault = () => false;

export const useModelCustomSelect = <T extends AnyModelConfig>({
  data,
  modelConfigs,
  isLoading,
  selectedModel,
  onChange,
@@ -39,30 +36,28 @@ export const useModelCustomSelect = <T extends AnyModelConfig>({

  const items: Item[] = useMemo(
    () =>
      data
        ? filter(data.entities, modelFilter).map<Item>((m) => ({
            label: m.name,
            value: m.key,
            description: m.description,
            group: MODEL_TYPE_SHORT_MAP[m.base],
            isDisabled: isModelDisabled(m),
          }))
        : EMPTY_ARRAY,
    [data, isModelDisabled, modelFilter]
      modelConfigs.filter(modelFilter).map<Item>((m) => ({
        label: m.name,
        value: m.key,
        description: m.description,
        group: MODEL_TYPE_SHORT_MAP[m.base],
        isDisabled: isModelDisabled(m),
      })),
    [modelConfigs, isModelDisabled, modelFilter]
  );

  const _onChange = useCallback(
    (item: Item | null) => {
      if (!item || !data) {
      if (!item || !modelConfigs) {
        return;
      }
      const model = data.entities[item.value];
      const model = modelConfigs.find((m) => m.key === item.value);
      if (!model) {
        return;
      }
      onChange(model);
    },
    [data, onChange]
    [modelConfigs, onChange]
  );

  const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]);

@@ -1,6 +1,7 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterShouldAutoConfig } from 'features/controlAdapters/hooks/useControlAdapterShouldAutoConfig';
import { controlAdapterAutoConfigToggled } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isNil } from 'lodash-es';
@@ -14,12 +15,13 @@ type Props = {
const ControlAdapterShouldAutoConfig = ({ id }: Props) => {
  const isEnabled = useControlAdapterIsEnabled(id);
  const shouldAutoConfig = useControlAdapterShouldAutoConfig(id);
  const { modelConfig } = useControlAdapterModel(id);
  const dispatch = useAppDispatch();
  const { t } = useTranslation();

  const handleShouldAutoConfigChanged = useCallback(() => {
    dispatch(controlAdapterAutoConfigToggled({ id }));
  }, [id, dispatch]);
    dispatch(controlAdapterAutoConfigToggled({ id, modelConfig }));
  }, [id, dispatch, modelConfig]);

  if (isNil(shouldAutoConfig)) {
    return null;

@@ -3,10 +3,9 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
import { memo, useCallback, useMemo } from 'react';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';

@@ -17,21 +16,21 @@ type ParamControlAdapterModelProps = {
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
  const isEnabled = useControlAdapterIsEnabled(id);
  const controlAdapterType = useControlAdapterType(id);
  const model = useControlAdapterModel(id);
  const { modelConfig } = useControlAdapterModel(id);
  const dispatch = useAppDispatch();
  const currentBaseModel = useAppSelector((s) => s.generation.model?.base);

  const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
  const [modelConfigs, { isLoading }] = useControlAdapterModels(controlAdapterType);

  const _onChange = useCallback(
    (model: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
      if (!model) {
    (modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
      if (!modelConfig) {
        return;
      }
      dispatch(
        controlAdapterModelChanged({
          id,
          model: getModelKeyAndBase(model),
          modelConfig,
        })
      );
    },
@@ -39,12 +38,12 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
  );

  const selectedModel = useMemo(
    () => (model && controlAdapterType ? { ...model, model_type: controlAdapterType } : null),
    [controlAdapterType, model]
    () => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
    [controlAdapterType, modelConfig]
  );

  const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
    data,
    modelConfigs,
    isLoading,
    selectedModel,
    onChange: _onChange,
@@ -53,7 +52,13 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {

  return (
    <FormControl isDisabled={!items.length || !isEnabled} isInvalid={!selectedItem || !items.length}>
      <CustomSelect selectedItem={selectedItem} placeholder={placeholder} items={items} onChange={onChange} />
      <CustomSelect
        key={items.length}
        selectedItem={selectedItem}
        placeholder={placeholder}
        items={items}
        onChange={onChange}
      />
    </FormControl>
  );
};

@@ -1,17 +1,18 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import { type ControlAdapterType, isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
import { useCallback, useMemo } from 'react';

import { useControlAdapterModels } from './useControlAdapterModels';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';

export const useAddControlAdapter = (type: ControlAdapterType) => {
  const baseModel = useAppSelector((s) => s.generation.model?.base);
  const dispatch = useAppDispatch();

  const models = useControlAdapterModels(type);
  const [models] = useControlAdapterModels(type);

  const firstModel = useMemo(() => {
  const firstModel: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig | undefined = useMemo(() => {
    // prefer to use a model that matches the base model
    const firstCompatibleModel = models.filter((m) => (baseModel ? m.base === baseModel : true))[0];

@@ -28,6 +29,26 @@ export const useAddControlAdapter = (type: ControlAdapterType) => {
    if (isDisabled) {
      return;
    }

    if (
      (type === 'controlnet' || type === 't2i_adapter') &&
      (firstModel?.type === 'controlnet' || firstModel?.type === 't2i_adapter')
    ) {
      const defaultPreprocessor = firstModel.default_settings?.preprocessor;
      const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
      const processorNode = CONTROLNET_PROCESSORS[processorType].default;
      dispatch(
        controlAdapterAdded({
          type,
          overrides: {
            model: firstModel,
            processorType,
            processorNode,
          },
        })
      );
      return;
    }
    dispatch(
      controlAdapterAdded({
        type,

@@ -1,3 +1,4 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
@@ -5,18 +6,22 @@ import {
  selectControlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { useMemo } from 'react';
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
import { isControlAdapterModelConfig } from 'services/api/types';

export const useControlAdapterModel = (id: string) => {
  const selector = useMemo(
    () =>
      createMemoizedSelector(
        selectControlAdaptersSlice,
        (controlAdapters) => selectControlAdapterById(controlAdapters, id)?.model
        (controlAdapters) => selectControlAdapterById(controlAdapters, id)?.model?.key
      ),
    [id]
  );

  const model = useAppSelector(selector);
  const key = useAppSelector(selector);

  return model;
  const result = useGetModelConfigWithTypeGuard(key ?? skipToken, isControlAdapterModelConfig);

  return result;
};

@@ -1,26 +0,0 @@
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import {
  useGetControlNetModelsQuery,
  useGetIPAdapterModelsQuery,
  useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';

export const useControlAdapterModelQuery = (type: ControlAdapterType) => {
  const controlNetModelsQuery = useGetControlNetModelsQuery();
  const t2iAdapterModelsQuery = useGetT2IAdapterModelsQuery();
  const ipAdapterModelsQuery = useGetIPAdapterModelsQuery();

  if (type === 'controlnet') {
    return controlNetModelsQuery;
  }
  if (type === 't2i_adapter') {
    return t2iAdapterModelsQuery;
  }
  if (type === 'ip_adapter') {
    return ipAdapterModelsQuery;
  }

  // Assert that the end of the function is not reachable.
  const exhaustiveCheck: never = type;
  return exhaustiveCheck;
};
@@ -1,31 +1,10 @@
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import { useMemo } from 'react';
import {
  controlNetModelsAdapterSelectors,
  ipAdapterModelsAdapterSelectors,
  t2iAdapterModelsAdapterSelectors,
  useGetControlNetModelsQuery,
  useGetIPAdapterModelsQuery,
  useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';
import { useControlNetModels, useIPAdapterModels, useT2IAdapterModels } from 'services/api/hooks/modelsByType';

export const useControlAdapterModels = (type?: ControlAdapterType) => {
  const { data: controlNetModelsData } = useGetControlNetModelsQuery();
  const controlNetModels = useMemo(
    () => (controlNetModelsData ? controlNetModelsAdapterSelectors.selectAll(controlNetModelsData) : []),
    [controlNetModelsData]
  );

  const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery();
  const t2iAdapterModels = useMemo(
    () => (t2iAdapterModelsData ? t2iAdapterModelsAdapterSelectors.selectAll(t2iAdapterModelsData) : []),
    [t2iAdapterModelsData]
  );
  const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery();
  const ipAdapterModels = useMemo(
    () => (ipAdapterModelsData ? ipAdapterModelsAdapterSelectors.selectAll(ipAdapterModelsData) : []),
    [ipAdapterModelsData]
  );
export const useControlAdapterModels = (type: ControlAdapterType) => {
  const controlNetModels = useControlNetModels();
  const t2iAdapterModels = useT2IAdapterModels();
  const ipAdapterModels = useIPAdapterModels();

  if (type === 'controlnet') {
    return controlNetModels;
@@ -36,5 +15,8 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => {
  if (type === 'ip_adapter') {
    return ipAdapterModels;
  }
  return [];

  // Assert that the end of the function is not reachable.
  const exhaustiveCheck: never = type;
  return exhaustiveCheck;
};

@@ -93,7 +93,6 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
      type: 'depth_anything_image_processor',
      model_size: 'small',
      resolution: 512,
      offload: false,
    },
  },
  hed_image_processor: {
@@ -253,23 +252,3 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
    },
  },
};

export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
  [key: string]: ControlAdapterProcessorType;
} = {
  canny: 'canny_image_processor',
  mlsd: 'mlsd_image_processor',
  depth: 'depth_anything_image_processor',
  bae: 'normalbae_image_processor',
  sketch: 'pidi_image_processor',
  scribble: 'lineart_image_processor',
  lineart: 'lineart_image_processor',
  lineart_anime: 'lineart_anime_image_processor',
  softedge: 'hed_image_processor',
  shuffle: 'content_shuffle_image_processor',
  openpose: 'dw_openpose_image_processor',
  mediapipe: 'mediapipe_face_processor',
  pidi: 'pidi_image_processor',
  zoe: 'zoe_depth_image_processor',
  color: 'color_map_image_processor',
};

@@ -3,20 +3,15 @@ import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import type { PersistConfig, RootState } from 'app/store/store';
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
import type {
  ParameterControlNetModel,
  ParameterIPAdapterModel,
  ParameterT2IAdapterModel,
} from 'features/parameters/types/parameterSchemas';
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { cloneDeep, merge, uniq } from 'lodash-es';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { socketInvocationError } from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';

import { controlAdapterImageProcessed } from './actions';
import {
  CONTROLNET_MODEL_DEFAULT_PROCESSORS as CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS,
  CONTROLNET_PROCESSORS,
} from './constants';
import { CONTROLNET_PROCESSORS } from './constants';
import type {
  ControlAdapterConfig,
  ControlAdapterProcessorType,
@@ -194,15 +189,17 @@ export const controlAdaptersSlice = createSlice({
      state,
      action: PayloadAction<{
        id: string;
        model: ParameterControlNetModel | ParameterT2IAdapterModel | ParameterIPAdapterModel;
        modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig;
      }>
    ) => {
      const { id, model } = action.payload;
      const { id, modelConfig } = action.payload;
      const cn = selectControlAdapterById(state, id);
      if (!cn) {
        return;
      }

      const model = zModelIdentifierField.parse(modelConfig);

      if (!isControlNetOrT2IAdapter(cn)) {
        caAdapter.updateOne(state, { id, changes: { model } });
        return;
@@ -215,24 +212,14 @@ export const controlAdaptersSlice = createSlice({

      update.changes.processedControlImage = null;

      let processorType: ControlAdapterProcessorType | undefined = undefined;

      for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
        // TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
        if (model.key.includes(modelSubstring)) {
          processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
          break;
        }
      if (modelConfig.type === 'ip_adapter') {
        // should never happen...
        return;
      }

      if (processorType) {
        update.changes.processorType = processorType;
        update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
          .default as RequiredControlAdapterProcessorNode;
      } else {
        update.changes.processorType = 'none';
        update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode;
      }
      const processor = buildControlAdapterProcessor(modelConfig);
      update.changes.processorType = processor.processorType;
      update.changes.processorNode = processor.processorNode;

      caAdapter.updateOne(state, update);
    },
@@ -324,39 +311,23 @@ export const controlAdaptersSlice = createSlice({
      state,
      action: PayloadAction<{
        id: string;
        modelConfig?: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig;
      }>
    ) => {
      const { id } = action.payload;
      const { id, modelConfig } = action.payload;
      const cn = selectControlAdapterById(state, id);
      if (!cn || !isControlNetOrT2IAdapter(cn)) {
      if (!cn || !isControlNetOrT2IAdapter(cn) || modelConfig?.type === 'ip_adapter') {
        return;
      }

      const update: Update<ControlNetConfig | T2IAdapterConfig, string> = {
        id,
        changes: { shouldAutoConfig: !cn.shouldAutoConfig },
      };

      if (update.changes.shouldAutoConfig) {
        // manage the processor for the user
        let processorType: ControlAdapterProcessorType | undefined = undefined;

        for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
          // TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
          if (cn.model?.key.includes(modelSubstring)) {
            processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
            break;
          }
        }

        if (processorType) {
          update.changes.processorType = processorType;
          update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
            .default as RequiredControlAdapterProcessorNode;
        } else {
          update.changes.processorType = 'none';
          update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode;
        }
      if (update.changes.shouldAutoConfig && modelConfig) {
        const processor = buildControlAdapterProcessor(modelConfig);
        update.changes.processorType = processor.processorType;
        update.changes.processorNode = processor.processorNode;
      }

      caAdapter.updateOne(state, update);
@@ -367,6 +338,21 @@ export const controlAdaptersSlice = createSlice({
    pendingControlImagesCleared: (state) => {
      state.pendingControlImages = [];
    },
    ipAdaptersReset: (state) => {
      selectAllIPAdapters(state).forEach((ca) => {
        caAdapter.removeOne(state, ca.id);
      });
    },
    controlNetsReset: (state) => {
      selectAllControlNets(state).forEach((ca) => {
        caAdapter.removeOne(state, ca.id);
      });
    },
    t2iAdaptersReset: (state) => {
      selectAllT2IAdapters(state).forEach((ca) => {
        caAdapter.removeOne(state, ca.id);
      });
    },
  },
  extraReducers: (builder) => {
    builder.addCase(controlAdapterImageProcessed, (state, action) => {
@@ -405,6 +391,9 @@ export const {
  controlAdapterAutoConfigToggled,
  pendingControlImagesCleared,
  controlAdapterModelCleared,
  ipAdaptersReset,
  controlNetsReset,
  t2iAdaptersReset,
} = controlAdaptersSlice.actions;

export const isAnyControlAdapterAdded = isAnyOf(controlAdapterAdded, controlAdapterRecalled);

@@ -0,0 +1,10 @@
import type { ControlAdapterProcessorType, zControlAdapterProcessorType } from 'features/controlAdapters/store/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
import type { z } from 'zod';

describe('Control Adapter Types', () => {
  test('ControlAdapterProcessorType', () =>
    assert<Equals<ControlAdapterProcessorType, z.infer<typeof zControlAdapterProcessorType>>>());
});
@@ -47,6 +47,25 @@ export type ControlAdapterProcessorNode =
 * Any ControlNet processor type
 */
export type ControlAdapterProcessorType = NonNullable<ControlAdapterProcessorNode['type'] | 'none'>;
export const zControlAdapterProcessorType = z.enum([
  'canny_image_processor',
  'color_map_image_processor',
  'content_shuffle_image_processor',
  'depth_anything_image_processor',
  'hed_image_processor',
  'lineart_anime_image_processor',
  'lineart_image_processor',
  'mediapipe_face_processor',
  'midas_depth_image_processor',
  'mlsd_image_processor',
  'normalbae_image_processor',
  'dw_openpose_image_processor',
  'pidi_image_processor',
  'zoe_depth_image_processor',
  'none',
]);
export const isControlAdapterProcessorType = (v: unknown): v is ControlAdapterProcessorType =>
  zControlAdapterProcessorType.safeParse(v).success;

/**
 * The Canny processor node, with parameters flagged as required

@@ -0,0 +1,11 @@
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';

export const buildControlAdapterProcessor = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
  const defaultPreprocessor = modelConfig.default_settings?.preprocessor;
  const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
  const processorNode = CONTROLNET_PROCESSORS[processorType].default;

  return { processorType, processorNode };
};
@@ -15,6 +15,7 @@ import {
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
  alwaysShowImageSizeBadgeChanged,
  autoAssignBoardOnClickChanged,
  setGalleryImageMinimumWidth,
  shouldAutoSwitchChanged,
@@ -36,6 +37,7 @@ const GallerySettingsPopover = () => {
  const galleryImageMinimumWidth = useAppSelector((s) => s.gallery.galleryImageMinimumWidth);
  const shouldAutoSwitch = useAppSelector((s) => s.gallery.shouldAutoSwitch);
  const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
  const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);

  const handleChangeGalleryImageMinimumWidth = useCallback(
    (v: number) => {
@@ -56,6 +58,11 @@ const GallerySettingsPopover = () => {
    [dispatch]
  );

  const handleChangeAlwaysShowImageSizeBadgeChanged = useCallback(
    (e: ChangeEvent<HTMLInputElement>) => dispatch(alwaysShowImageSizeBadgeChanged(e.target.checked)),
    [dispatch]
  );

  return (
    <Popover isLazy>
      <PopoverTrigger>
@@ -88,6 +95,10 @@ const GallerySettingsPopover = () => {
          <FormLabel>{t('gallery.autoAssignBoardOnClick')}</FormLabel>
          <Checkbox isChecked={autoAssignBoardOnClick} onChange={handleChangeAutoAssignBoardOnClick} />
        </FormControl>
        <FormControl>
          <FormLabel>{t('gallery.alwaysShowImageSizeBadge')}</FormLabel>
          <Checkbox isChecked={alwaysShowImageSizeBadge} onChange={handleChangeAlwaysShowImageSizeBadgeChanged} />
        </FormControl>
      </FormControlGroup>
      <BoardAutoAddSelect />
    </Flex>

@@ -1,5 +1,5 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, useShiftModifier } from '@invoke-ai/ui-library';
import { Box, Flex, Text, useShiftModifier } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $customStarUI } from 'app/store/nanostores/customStarUI';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -22,6 +22,16 @@ const imageIconStyleOverrides: SystemStyleObject = {
  bottom: 2,
  top: 'auto',
};
const boxSx: SystemStyleObject = {
  containerType: 'inline-size',
};

const badgeSx: SystemStyleObject = {
  '@container (max-width: 80px)': {
    '&': { display: 'none' },
  },
};

interface HoverableImageProps {
  imageName: string;
  index: number;
@@ -34,6 +44,7 @@ const GalleryImage = (props: HoverableImageProps) => {
  const shift = useShiftModifier();
  const { t } = useTranslation();
  const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
  const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
  const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);

  const customStarUi = useStore($customStarUI);
@@ -121,7 +132,7 @@ const GalleryImage = (props: HoverableImageProps) => {
  }

  return (
    <Box w="full" h="full" className="gallerygrid-image" data-testid={dataTestId}>
    <Box w="full" h="full" className="gallerygrid-image" data-testid={dataTestId} sx={boxSx}>
      <Flex
        ref={imageContainerRef}
        userSelect="none"
@@ -145,6 +156,23 @@ const GalleryImage = (props: HoverableImageProps) => {
          onMouseOut={handleMouseOut}
        >
          <>
            {(isHovered || alwaysShowImageSizeBadge) && (
              <Text
                position="absolute"
                background="base.900"
                color="base.50"
                fontSize="sm"
                fontWeight="semibold"
                bottom={0}
                left={0}
                opacity={0.7}
                px={2}
                lineHeight={1.25}
                borderTopEndRadius="base"
                borderBottomStartRadius="base"
                sx={badgeSx}
              >{`${imageDTO.width}x${imageDTO.height}`}</Text>
            )}
            <IAIDndImageIcon onClick={toggleStarredState} icon={starIcon} tooltip={starTooltip} />

            {isHovered && shift && (

@@ -15,6 +15,7 @@ const initialGalleryState: GalleryState = {
  autoAssignBoardOnClick: true,
  autoAddBoardId: 'none',
  galleryImageMinimumWidth: 90,
  alwaysShowImageSizeBadge: false,
  selectedBoardId: 'none',
  galleryView: 'images',
  boardSearchText: '',
@@ -71,6 +72,9 @@ export const gallerySlice = createSlice({
        state.limit += IMAGE_LIMIT;
      }
    },
    alwaysShowImageSizeBadgeChanged: (state, action: PayloadAction<boolean>) => {
      state.alwaysShowImageSizeBadge = action.payload;
    },
  },
  extraReducers: (builder) => {
    builder.addMatcher(isAnyBoardDeleted, (state, action) => {
@@ -107,6 +111,7 @@ export const {
  selectionChanged,
  boardSearchTextChanged,
  moreImagesLoaded,
  alwaysShowImageSizeBadgeChanged,
} = gallerySlice.actions;

const isAnyBoardDeleted = isAnyOf(

@@ -19,4 +19,5 @@ export type GalleryState = {
  boardSearchText: string;
  offset: number;
  limit: number;
  alwaysShowImageSizeBadge: boolean;
};

@@ -7,14 +7,14 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { useLoRAModels } from 'services/api/hooks/modelsByType';
import type { LoRAModelConfig } from 'services/api/types';

const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);

const LoRASelect = () => {
  const dispatch = useAppDispatch();
  const { data, isLoading } = useGetLoRAModelsQuery();
  const [modelConfigs, { isLoading }] = useLoRAModels();
  const { t } = useTranslation();
  const addedLoRAs = useAppSelector(selectAddedLoRAs);
  const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
@@ -37,7 +37,7 @@ const LoRASelect = () => {
  );

  const { options, onChange } = useGroupedModelCombobox({
    modelEntities: data,
    modelConfigs,
    getIsDisabled,
    onChange: _onChange,
  });

@@ -1,8 +1,9 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import { cloneDeep } from 'lodash-es';
import type { LoRAModelConfig } from 'services/api/types';

export type LoRA = {
@@ -31,7 +32,7 @@ export const loraSlice = createSlice({
initialState: initialLoraState,
reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
const model = getModelKeyAndBase(action.payload);
const model = zModelIdentifierField.parse(action.payload);
state.loras[model.key] = { ...defaultLoRAConfig, model };
},
loraRecalled: (state, action: PayloadAction<LoRA>) => {
@@ -57,10 +58,12 @@ export const loraSlice = createSlice({
}
lora.isEnabled = isEnabled;
},
lorasReset: () => cloneDeep(initialLoraState),
},
});

export const { loraAdded, loraRemoved, loraWeightChanged, loraIsEnabledChanged, loraRecalled } = loraSlice.actions;
export const { loraAdded, loraRemoved, loraWeightChanged, loraIsEnabledChanged, loraRecalled, lorasReset } =
loraSlice.actions;

export const selectLoraSlice = (state: RootState) => state.lora;

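`loraAdded` now derives the stored identifier by parsing the full config through `zModelIdentifierField` instead of hand-picking fields with `getModelKeyAndBase` (which this diff deletes further down). A sketch of what such a schema might look like; the field list and enum variants are assumptions, as the schema's definition is not part of this diff:

    import { z } from 'zod';

    // Assumed shape: strips a full model config down to a stable identifier.
    const zModelIdentifierField = z.object({
      key: z.string(),
      name: z.string(),
      base: z.enum(['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']), // assumed variants
      type: z.string(),
    });

    // z.object(...).parse() drops unknown keys by default, so parsing a whole
    // LoRAModelConfig yields just the identifier fields:
    // const model = zModelIdentifierField.parse(action.payload);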
@@ -13,13 +13,13 @@ import type {
} from 'features/metadata/types';
import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers';
import { validators } from 'features/metadata/util/validators';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { t } from 'i18next';

import { parsers } from './parsers';
import { recallers } from './recallers';

const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierWithBase> = async (value) => {
const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierField> = async (value) => {
try {
const modelConfig = await fetchModelConfig(value.key);
return `${modelConfig.name} (${modelConfig.base.toUpperCase()})`;

@@ -1,5 +1,4 @@
import { getStore } from 'app/store/nanostores/store';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
@@ -105,8 +104,3 @@ export const getModelKey = async (modelIdentifier: unknown, type: ModelType, mes
}
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
};

export const getModelKeyAndBase = (modelConfig: AnyModelConfig): ModelIdentifierWithBase => ({
key: modelConfig.key,
base: modelConfig.base,
});

@@ -1,9 +1,9 @@
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import {
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlAdapters/util/buildControlAdapter';
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
import type { LoRA } from 'features/lora/store/loraSlice';
import { defaultLoRAConfig } from 'features/lora/store/loraSlice';
import type {
@@ -13,12 +13,7 @@ import type {
T2IAdapterConfigMetadata,
} from 'features/metadata/types';
import { fetchModelConfigWithTypeGuard, getModelKey } from 'features/metadata/util/modelFetchingHelpers';
import {
zControlField,
zIPAdapterField,
zModelIdentifierWithBase,
zT2IAdapterField,
} from 'features/nodes/types/common';
import { zControlField, zIPAdapterField, zModelIdentifierField, zT2IAdapterField } from 'features/nodes/types/common';
import type {
ParameterCFGRescaleMultiplier,
ParameterCFGScale,
@@ -181,7 +176,7 @@ const parseMainModel: MetadataParseFunc<ParameterModel> = async (metadata) => {
const model = await getProperty(metadata, 'model', undefined);
const key = await getModelKey(model, 'main');
const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
const modelIdentifier = zModelIdentifierWithBase.parse(mainModelConfig);
const modelIdentifier = zModelIdentifierField.parse(mainModelConfig);
return modelIdentifier;
};

@@ -189,7 +184,7 @@ const parseRefinerModel: MetadataParseFunc<ParameterSDXLRefinerModel> = async (m
const refiner_model = await getProperty(metadata, 'refiner_model', undefined);
const key = await getModelKey(refiner_model, 'main');
const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
const modelIdentifier = zModelIdentifierWithBase.parse(refinerModelConfig);
const modelIdentifier = zModelIdentifierField.parse(refinerModelConfig);
return modelIdentifier;
};

@@ -197,7 +192,7 @@ const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) =>
const vae = await getProperty(metadata, 'vae', undefined);
const key = await getModelKey(vae, 'vae');
const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
const modelIdentifier = zModelIdentifierWithBase.parse(vaeModelConfig);
const modelIdentifier = zModelIdentifierField.parse(vaeModelConfig);
return modelIdentifier;
};

@@ -211,7 +206,7 @@ const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);

return {
model: zModelIdentifierWithBase.parse(loraModelConfig),
model: zModelIdentifierField.parse(loraModelConfig),
weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
isEnabled: true,
};
@@ -230,43 +225,48 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
const control_model = await getProperty(metadataItem, 'control_model');
const key = await getModelKey(control_model, 'controlnet');
const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);

const image = zControlField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
const image = zControlField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'image'));
const processedImage = zControlField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'processed_image'));
const control_weight = zControlField.shape.control_weight
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'control_weight'));
.parse(await getProperty(metadataItem, 'control_weight'));
const begin_step_percent = zControlField.shape.begin_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'begin_step_percent'));
.parse(await getProperty(metadataItem, 'begin_step_percent'));
const end_step_percent = zControlField.shape.end_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'end_step_percent'));
.parse(await getProperty(metadataItem, 'end_step_percent'));
const control_mode = zControlField.shape.control_mode
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'control_mode'));
.parse(await getProperty(metadataItem, 'control_mode'));
const resize_mode = zControlField.shape.resize_mode
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'resize_mode'));
.parse(await getProperty(metadataItem, 'resize_mode'));

const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const { processorType, processorNode } = buildControlAdapterProcessor(controlNetModel);

const controlNet: ControlNetConfigMetadata = {
type: 'controlnet',
isEnabled: true,
model: zModelIdentifierWithBase.parse(controlNetModel),
model: zModelIdentifierField.parse(controlNetModel),
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct,
endStepPct: end_step_percent ?? initialControlNet.endStepPct,
controlMode: control_mode ?? initialControlNet.controlMode,
resizeMode: resize_mode ?? initialControlNet.resizeMode,
controlImage: image?.image_name ?? null,
processedControlImage: image?.image_name ?? null,
processedControlImage: processedImage?.image_name ?? null,
processorType,
processorNode,
shouldAutoConfig: true,
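The key fix in parseControlNet, and in the T2I and IP adapter parsers below, is awaiting getProperty before the zod parse. getProperty is async (it is awaited everywhere else in this file), so the old code handed the pending Promise itself to .parse(); a Promise never matches the schema, so .catch(null) swallowed the failure and every one of these fields silently came back null. A distilled reproduction of the before/after behavior:

    import { z } from 'zod';

    const weightSchema = z.number().nullish().catch(null);

    // Stand-in for the real async metadata getter:
    const getProperty = async (): Promise<number> => 0.75;

    const demo = async () => {
      // Before the fix: the Promise object is validated, which can never be a
      // number, so .catch(null) silently yields null for every field.
      const broken = weightSchema.parse(getProperty()); // null
      // After the fix: resolve first, then validate.
      const fixed = weightSchema.parse(await getProperty()); // 0.75
      return { broken, fixed };
    };

Two related fixes ride along in the same hunk: processed_image is now parsed separately, so processedControlImage no longer just reuses the raw image, and the hardcoded 'none' processor is replaced by buildControlAdapterProcessor(controlNetModel), which picks a processor matching the model.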
@@ -290,34 +290,43 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
const key = await getModelKey(t2i_adapter_model, 't2i_adapter');
const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);

const image = zT2IAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
const weight = zT2IAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight'));
const image = zT2IAdapterField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'image'));
const processedImage = zT2IAdapterField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'processed_image'));
const weight = zT2IAdapterField.shape.weight
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'weight'));
const begin_step_percent = zT2IAdapterField.shape.begin_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'begin_step_percent'));
.parse(await getProperty(metadataItem, 'begin_step_percent'));
const end_step_percent = zT2IAdapterField.shape.end_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'end_step_percent'));
.parse(await getProperty(metadataItem, 'end_step_percent'));
const resize_mode = zT2IAdapterField.shape.resize_mode
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'resize_mode'));
.parse(await getProperty(metadataItem, 'resize_mode'));

const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const { processorType, processorNode } = buildControlAdapterProcessor(t2iAdapterModel);

const t2iAdapter: T2IAdapterConfigMetadata = {
type: 't2i_adapter',
isEnabled: true,
model: zModelIdentifierWithBase.parse(t2iAdapterModel),
model: zModelIdentifierField.parse(t2iAdapterModel),
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct,
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
resizeMode: resize_mode ?? initialT2IAdapter.resizeMode,
controlImage: image?.image_name ?? null,
processedControlImage: image?.image_name ?? null,
processedControlImage: processedImage?.image_name ?? null,
processorType,
processorNode,
shouldAutoConfig: true,
@@ -341,22 +350,28 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
const key = await getModelKey(ip_adapter_model, 'ip_adapter');
const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);

const image = zIPAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
const weight = zIPAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight'));
const image = zIPAdapterField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'image'));
const weight = zIPAdapterField.shape.weight
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'weight'));
const begin_step_percent = zIPAdapterField.shape.begin_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'begin_step_percent'));
.parse(await getProperty(metadataItem, 'begin_step_percent'));
const end_step_percent = zIPAdapterField.shape.end_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'end_step_percent'));
.parse(await getProperty(metadataItem, 'end_step_percent'));

const ipAdapter: IPAdapterConfigMetadata = {
id: uuidv4(),
type: 'ip_adapter',
isEnabled: true,
model: zModelIdentifierWithBase.parse(ipAdapterModel),
model: zModelIdentifierField.parse(ipAdapterModel),
controlImage: image?.image_name ?? null,
weight: weight ?? initialIPAdapter.weight,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,

@@ -1,8 +1,13 @@
import { getStore } from 'app/store/nanostores/store';
import { controlAdapterRecalled } from 'features/controlAdapters/store/controlAdaptersSlice';
import {
controlAdapterRecalled,
controlNetsReset,
ipAdaptersReset,
t2iAdaptersReset,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
import type { LoRA } from 'features/lora/store/loraSlice';
import { loraRecalled } from 'features/lora/store/loraSlice';
import { loraRecalled, lorasReset } from 'features/lora/store/loraSlice';
import type {
ControlNetConfigMetadata,
IPAdapterConfigMetadata,
@@ -166,7 +171,11 @@ const recallLoRA: MetadataRecallFunc<LoRA> = (lora) => {
};

const recallAllLoRAs: MetadataRecallFunc<LoRA[]> = (loras) => {
if (!loras.length) {
return;
}
const { dispatch } = getStore();
dispatch(lorasReset());
loras.forEach((lora) => {
dispatch(loraRecalled(lora));
});
@@ -177,7 +186,11 @@ const recallControlNet: MetadataRecallFunc<ControlNetConfigMetadata> = (controlN
};

const recallControlNets: MetadataRecallFunc<ControlNetConfigMetadata[]> = (controlNets) => {
if (!controlNets.length) {
return;
}
const { dispatch } = getStore();
dispatch(controlNetsReset());
controlNets.forEach((controlNet) => {
dispatch(controlAdapterRecalled(controlNet));
});
@@ -188,7 +201,11 @@ const recallT2IAdapter: MetadataRecallFunc<T2IAdapterConfigMetadata> = (t2iAdapt
};

const recallT2IAdapters: MetadataRecallFunc<T2IAdapterConfigMetadata[]> = (t2iAdapters) => {
if (!t2iAdapters.length) {
return;
}
const { dispatch } = getStore();
dispatch(t2iAdaptersReset());
t2iAdapters.forEach((t2iAdapter) => {
dispatch(controlAdapterRecalled(t2iAdapter));
});
@@ -199,7 +216,11 @@ const recallIPAdapter: MetadataRecallFunc<IPAdapterConfigMetadata> = (ipAdapter)
};

const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapters) => {
if (!ipAdapters.length) {
return;
}
const { dispatch } = getStore();
dispatch(ipAdaptersReset());
ipAdapters.forEach((ipAdapter) => {
dispatch(controlAdapterRecalled(ipAdapter));
});

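All four recallers now share the same shape: bail on empty input, dispatch the matching reset so stale entries are cleared, then recall each item. If that pattern kept growing, it could be factored into one helper; a hypothetical sketch, not part of the diff, which keeps four explicit functions:

    import type { AnyAction } from '@reduxjs/toolkit';
    import { getStore } from 'app/store/nanostores/store';

    // Hypothetical factory for the reset-then-recall pattern.
    const buildRecallAll =
      <T>(resetAction: AnyAction, recallAction: (item: T) => AnyAction) =>
      (items: T[]) => {
        if (!items.length) {
          return;
        }
        const { dispatch } = getStore();
        dispatch(resetAction); // clear stale entries before recalling
        items.forEach((item) => dispatch(recallAction(item)));
      };

    // e.g. const recallAllLoRAs = buildRecallAll(lorasReset(), loraRecalled);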
@@ -0,0 +1,33 @@
import type { ButtonProps } from '@invoke-ai/ui-library';
import { Button } from '@invoke-ai/ui-library';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsClockwiseBold } from 'react-icons/pi';

import { useSyncModels } from './useSyncModels';

export const SyncModelsButton = memo((props: Omit<ButtonProps, 'aria-label'>) => {
const { t } = useTranslation();
const { syncModels, isLoading } = useSyncModels();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;

if (!isSyncModelEnabled) {
return null;
}

return (
<Button
leftIcon={<PiArrowsClockwiseBold />}
isLoading={isLoading}
onClick={syncModels}
size="sm"
variant="ghost"
{...props}
>
{t('modelManager.syncModels')}
</Button>
);
});

SyncModelsButton.displayName = 'SyncModelsButton';
@@ -0,0 +1,23 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { isNil } from 'lodash-es';
import { useMemo } from 'react';
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types';

export const useControlNetOrT2IAdapterDefaultSettings = (modelKey?: string | null) => {
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(
modelKey ?? skipToken,
isControlNetOrT2IAdapterModelConfig
);

const defaultSettingsDefaults = useMemo(() => {
return {
preprocessor: {
isEnabled: !isNil(modelConfig?.default_settings?.preprocessor),
value: modelConfig?.default_settings?.preprocessor || 'none',
},
};
}, [modelConfig?.default_settings]);

return { defaultSettingsDefaults, isLoading };
};
@@ -0,0 +1,85 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { isNil } from 'lodash-es';
import { useMemo } from 'react';
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
import { isNonRefinerMainModelConfig } from 'services/api/types';

const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision, width, height } = config.sd;

return {
initialSteps: steps.initial,
initialCfg: guidance.initial,
initialScheduler: scheduler,
initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial,
initialVaePrecision: vaePrecision,
initialWidth: width.initial,
initialHeight: height.initial,
};
});

export const useMainModelDefaultSettings = (modelKey?: string | null) => {
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(modelKey ?? skipToken, isNonRefinerMainModelConfig);

const {
initialSteps,
initialCfg,
initialScheduler,
initialCfgRescaleMultiplier,
initialVaePrecision,
initialWidth,
initialHeight,
} = useAppSelector(initialStatesSelector);

const defaultSettingsDefaults = useMemo(() => {
return {
vae: {
isEnabled: !isNil(modelConfig?.default_settings?.vae),
value: modelConfig?.default_settings?.vae || 'default',
},
vaePrecision: {
isEnabled: !isNil(modelConfig?.default_settings?.vae_precision),
value: modelConfig?.default_settings?.vae_precision || initialVaePrecision || 'fp32',
},
scheduler: {
isEnabled: !isNil(modelConfig?.default_settings?.scheduler),
value: modelConfig?.default_settings?.scheduler || initialScheduler || 'euler',
},
steps: {
isEnabled: !isNil(modelConfig?.default_settings?.steps),
value: modelConfig?.default_settings?.steps || initialSteps,
},
cfgScale: {
isEnabled: !isNil(modelConfig?.default_settings?.cfg_scale),
value: modelConfig?.default_settings?.cfg_scale || initialCfg,
},
cfgRescaleMultiplier: {
isEnabled: !isNil(modelConfig?.default_settings?.cfg_rescale_multiplier),
value: modelConfig?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier,
},
width: {
isEnabled: !isNil(modelConfig?.default_settings?.width),
value: modelConfig?.default_settings?.width || initialWidth,
},
height: {
isEnabled: !isNil(modelConfig?.default_settings?.height),
value: modelConfig?.default_settings?.height || initialHeight,
},
};
}, [
modelConfig,
initialVaePrecision,
initialScheduler,
initialSteps,
initialCfg,
initialCfgRescaleMultiplier,
initialWidth,
initialHeight,
]);

return { defaultSettingsDefaults, isLoading, optimalDimension: getOptimalDimension(modelConfig) };
};
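One thing worth noting about both new default-settings hooks: the values are filled in with ||, so a falsy stored default (for example a cfg_rescale_multiplier of 0) falls through to the initial value even though isEnabled is computed with isNil and would report it as set. If zero is a legitimate stored default, nullish coalescing expresses the intent more precisely:

    // `||` replaces any falsy value; `??` replaces only null/undefined.
    const value =
      modelConfig?.default_settings?.cfg_rescale_multiplier ?? initialCfgRescaleMultiplier;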
@@ -1,13 +1,17 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig } from 'app/store/store';
import type { ModelType } from 'services/api/types';

export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'>;

type ModelManagerState = {
_version: 1;
selectedModelKey: string | null;
selectedModelMode: 'edit' | 'view';
searchTerm: string;
filteredModelType: string | null;
filteredModelType: FilterableModelType | null;
scanPath: string | undefined;
};

const initialModelManagerState: ModelManagerState = {
@@ -16,6 +20,7 @@ const initialModelManagerState: ModelManagerState = {
selectedModelMode: 'view',
filteredModelType: null,
searchTerm: '',
scanPath: undefined,
};

export const modelManagerV2Slice = createSlice({
@@ -33,13 +38,16 @@ export const modelManagerV2Slice = createSlice({
state.searchTerm = action.payload;
},

setFilteredModelType: (state, action: PayloadAction<string | null>) => {
setFilteredModelType: (state, action: PayloadAction<FilterableModelType | null>) => {
state.filteredModelType = action.payload;
},
setScanPath: (state, action: PayloadAction<string | undefined>) => {
state.scanPath = action.payload;
},
},
});

export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode } =
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode, setScanPath } =
modelManagerV2Slice.actions;

/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
@@ -54,5 +62,5 @@ export const modelManagerV2PersistConfig: PersistConfig<ModelManagerState> = {
name: modelManagerV2Slice.name,
initialState: initialModelManagerState,
migrate: migrateModelManagerState,
persistDenylist: [],
persistDenylist: ['selectedModelKey', 'selectedModelMode', 'filteredModelType', 'searchTerm'],
};

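Two small but deliberate changes here. First, FilterableModelType narrows the API's ModelType union at compile time, so setFilteredModelType can no longer receive the excluded variants. The mechanics in isolation; the union members below are illustrative, the real union lives in services/api/types:

    type ModelType =
      | 'main' | 'vae' | 'lora' | 'controlnet' | 't2i_adapter'
      | 'ip_adapter' | 'embedding' | 'onnx' | 'clip_vision'; // illustrative
    type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'>;

    const ok: FilterableModelType = 'lora';
    // const bad: FilterableModelType = 'onnx'; // compile-time error

Second, the persistDenylist now excludes selection, filter, and search state from persistence, since those are per-session UI state rather than settings worth restoring.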
@@ -0,0 +1,102 @@
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useInstallModelMutation, useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';

import { HuggingFaceResults } from './HuggingFaceResults';

export const HuggingFaceForm = () => {
const [huggingFaceRepo, setHuggingFaceRepo] = useState('');
const [displayResults, setDisplayResults] = useState(false);
const [errorMessage, setErrorMessage] = useState('');
const { t } = useTranslation();
const dispatch = useAppDispatch();

const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery();
const [installModel] = useInstallModelMutation();

const handleInstallModel = useCallback(
(source: string) => {
installModel({ source })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
},
[installModel, dispatch, t]
);

const getModels = useCallback(async () => {
_getHuggingFaceModels(huggingFaceRepo)
.unwrap()
.then((response) => {
if (response.is_diffusers) {
handleInstallModel(huggingFaceRepo);
setDisplayResults(false);
} else if (response.urls?.length === 1 && response.urls[0]) {
handleInstallModel(response.urls[0]);
setDisplayResults(false);
} else {
setDisplayResults(true);
}
})
.catch((error) => {
setErrorMessage(error.data.detail || '');
});
}, [_getHuggingFaceModels, handleInstallModel, huggingFaceRepo]);

const handleSetHuggingFaceRepo: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setHuggingFaceRepo(e.target.value);
setErrorMessage('');
}, []);

return (
<Flex flexDir="column" height="100%" gap={3}>
<FormControl isInvalid={!!errorMessage.length} w="full" orientation="vertical" flexShrink={0}>
<FormLabel>{t('modelManager.huggingFaceRepoID')}</FormLabel>
<Flex gap={3} alignItems="center" w="full">
<Input
placeholder={t('modelManager.huggingFacePlaceholder')}
value={huggingFaceRepo}
onChange={handleSetHuggingFaceRepo}
/>
<Button
onClick={getModels}
isLoading={isLoading}
isDisabled={huggingFaceRepo.length === 0}
size="sm"
flexShrink={0}
>
{t('modelManager.installRepo')}
</Button>
</Flex>
<FormHelperText>{t('modelManager.huggingFaceHelper')}</FormHelperText>
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
</FormControl>
{data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
</Flex>
);
};
@@ -0,0 +1,57 @@
import { Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
import { useInstallModelMutation } from 'services/api/endpoints/models';

type Props = {
result: string;
};
export const HuggingFaceResultItem = ({ result }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();

const [installModel] = useInstallModelMutation();

const handleInstall = useCallback(() => {
installModel({ source: result })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}, [installModel, result, dispatch, t]);

return (
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
<Flex fontSize="sm" flexDir="column">
<Text fontWeight="semibold">{result.split('/').slice(-1)[0]}</Text>
<Text variant="subtext" noOfLines={1} wordBreak="break-all">
{result}
</Text>
</Flex>
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleInstall} size="sm" />
</Flex>
);
};
@@ -0,0 +1,123 @@
import {
Button,
Divider,
Flex,
Heading,
IconButton,
Input,
InputGroup,
InputRightElement,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { ChangeEventHandler } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import { useInstallModelMutation } from 'services/api/endpoints/models';

import { HuggingFaceResultItem } from './HuggingFaceResultItem';

type HuggingFaceResultsProps = {
results: string[];
};

export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
const { t } = useTranslation();
const [searchTerm, setSearchTerm] = useState('');
const dispatch = useAppDispatch();

const [installModel] = useInstallModelMutation();

const filteredResults = useMemo(() => {
return results.filter((result) => {
const modelName = result.split('/').slice(-1)[0];
return modelName?.toLowerCase().includes(searchTerm.toLowerCase());
});
}, [results, searchTerm]);

const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setSearchTerm(e.target.value.trim());
}, []);

const clearSearch = useCallback(() => {
setSearchTerm('');
}, []);

const handleAddAll = useCallback(() => {
for (const result of filteredResults) {
installModel({ source: result })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}
}, [filteredResults, installModel, dispatch, t]);

return (
<>
<Divider />
<Flex flexDir="column" gap={3} height="100%">
<Flex justifyContent="space-between" alignItems="center">
<Heading size="sm">{t('modelManager.availableModels')}</Heading>
<Flex alignItems="center" gap={3}>
<Button size="sm" onClick={handleAddAll} isDisabled={results.length === 0} flexShrink={0}>
{t('modelManager.installAll')}
</Button>
<InputGroup w={64} size="xs">
<Input
placeholder={t('modelManager.search')}
value={searchTerm}
data-testid="board-search-input"
onChange={handleSearch}
size="xs"
/>

{searchTerm && (
<InputRightElement h="full" pe={2}>
<IconButton
size="sm"
variant="link"
aria-label={t('boards.clearSearch')}
icon={<PiXBold />}
onClick={clearSearch}
/>
</InputRightElement>
)}
</InputGroup>
</Flex>
</Flex>
<Flex height="100%" layerStyle="third" borderRadius="base" p={3}>
<ScrollableContent>
<Flex flexDir="column" gap={3}>
{filteredResults.map((result) => (
<HuggingFaceResultItem key={result} result={result} />
))}
</Flex>
</ScrollableContent>
</Flex>
</Flex>
</>
);
};
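The install-then-toast sequence is repeated nearly verbatim in HuggingFaceForm, HuggingFaceResultItem, and HuggingFaceResults. A small helper could centralize it; this is a hypothetical refactor, and the AppDispatch export is an assumed name, while the diff itself keeps the three copies:

    import type { AppDispatch } from 'app/store/store'; // assumed export
    import { addToast } from 'features/system/store/systemSlice';
    import { makeToast } from 'features/system/util/makeToast';

    // `install` is the trigger returned by useInstallModelMutation().
    export const installWithToast = (
      install: (arg: { source: string }) => { unwrap: () => Promise<unknown> },
      dispatch: AppDispatch,
      source: string,
      successTitle: string
    ) =>
      install({ source })
        .unwrap()
        .then(() => dispatch(addToast(makeToast({ title: successTitle, status: 'success' }))))
        .catch((error) => {
          if (error) {
            dispatch(addToast(makeToast({ title: `${error.data.detail} `, status: 'error' })));
          }
        });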
@@ -1,4 +1,4 @@
import { Button, Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library';
import { Button, Checkbox, Flex, FormControl, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
@@ -10,6 +10,7 @@ import { useInstallModelMutation } from 'services/api/endpoints/models';

type SimpleImportModelConfig = {
location: string;
inplace: boolean;
};

export const InstallModelForm = () => {
@@ -20,6 +21,7 @@ export const InstallModelForm = () => {
const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({
defaultValues: {
location: '',
inplace: true,
},
mode: 'onChange',
});
@@ -30,7 +32,7 @@ export const InstallModelForm = () => {
return;
}

installModel({ source: values.location })
installModel({ source: values.location, inplace: values.inplace })
.unwrap()
.then((_) => {
dispatch(
@@ -41,10 +43,10 @@ export const InstallModelForm = () => {
})
)
);
reset();
reset(undefined, { keepValues: true });
})
.catch((error) => {
reset();
reset(undefined, { keepValues: true });
if (error) {
dispatch(
addToast(
@@ -62,16 +64,37 @@ export const InstallModelForm = () => {

return (
<form onSubmit={handleSubmit(onSubmit)}>
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
<Flex flexDir="column" gap={4}>
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
<FormControl orientation="vertical">
<FormLabel>{t('modelManager.urlOrLocalPath')}</FormLabel>
<Flex alignItems="center" gap={3} w="full">
<Input placeholder={t('modelManager.simpleModelPlaceholder')} {...register('location')} />
<Button
onClick={handleSubmit(onSubmit)}
isDisabled={!formState.dirtyFields.location}
isLoading={isLoading}
type="submit"
size="sm"
>
{t('modelManager.install')}
</Button>
</Flex>
<FormHelperText>{t('modelManager.urlOrLocalPathHelper')}</FormHelperText>
</FormControl>
</Flex>

<FormControl>
<Flex direction="column" w="full">
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
<Input {...register('location')} />
<Flex flexDir="column" gap={2}>
<Flex gap={4}>
<Checkbox {...register('inplace')} />
<FormLabel>
{t('modelManager.inplaceInstall')} ({t('modelManager.localOnly')})
</FormLabel>
</Flex>
<FormHelperText>{t('modelManager.inplaceInstallDesc')}</FormHelperText>
</Flex>
</FormControl>
<Button onClick={handleSubmit(onSubmit)} isDisabled={!formState.isDirty} isLoading={isLoading} type="submit">
{t('modelManager.addModel')}
</Button>
</Flex>
</form>
);

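The switch from reset() to reset(undefined, { keepValues: true }) is a react-hook-form idiom: it clears the dirty/touched flags without blanking the inputs, so the submitted location stays visible while formState.dirtyFields.location goes false again and the Install button disables until the user edits the field. In isolation:

    import { useForm } from 'react-hook-form';

    const useInstallFormReset = () => {
      const { reset, formState } = useForm<{ location: string }>({ defaultValues: { location: '' } });
      const afterSubmit = () => {
        // Clear dirty/touched state but keep what the user typed; a bare
        // reset() would also blank the input back to its defaultValue.
        reset(undefined, { keepValues: true });
      };
      return { afterSubmit, isDirty: formState.isDirty };
    };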
@@ -1,4 +1,4 @@
import { Box, Button, Flex, Text } from '@invoke-ai/ui-library';
import { Box, Button, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { addToast } from 'features/system/store/systemSlice';
@@ -50,9 +50,9 @@ export const ModelInstallQueue = () => {
}, [data]);

return (
<Flex flexDir="column" p={3} h="full">
<Flex flexDir="column" p={3} h="full" gap={3}>
<Flex justifyContent="space-between" alignItems="center">
<Text>{t('modelManager.importQueue')}</Text>
<Heading size="sm">{t('modelManager.installQueue')}</Heading>
<Button
size="sm"
isDisabled={!pruneAvailable}
@@ -62,9 +62,9 @@ export const ModelInstallQueue = () => {
{t('modelManager.prune')}
</Button>
</Flex>
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<Box layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<ScrollableContent>
<Flex flexDir="column-reverse" gap="2">
<Flex flexDir="column-reverse" gap="2" w="full">
{data?.map((model) => <ModelInstallQueueItem key={model.id} installJob={model} />)}
</Flex>
</ScrollableContent>

@@ -1,4 +1,4 @@
import { Badge, Tooltip } from '@invoke-ai/ui-library';
import { Badge } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import type { ModelInstallStatus } from 'services/api/types';
@@ -13,13 +13,7 @@ const STATUSES = {
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
};

const ModelInstallQueueBadge = ({
status,
errorReason,
}: {
status?: ModelInstallStatus;
errorReason?: string | null;
}) => {
const ModelInstallQueueBadge = ({ status }: { status?: ModelInstallStatus }) => {
const { t } = useTranslation();

if (!status || !Object.keys(STATUSES).includes(status)) {
@@ -27,9 +21,9 @@ const ModelInstallQueueBadge = ({
}

return (
<Tooltip label={errorReason}>
<Badge colorScheme={STATUSES[status].colorScheme}>{t(STATUSES[status].translationKey)}</Badge>
</Tooltip>
<Badge textAlign="center" w="134px" colorScheme={STATUSES[status].colorScheme}>
{t(STATUSES[status].translationKey)}
</Badge>
);
};
export default memo(ModelInstallQueueBadge);

@@ -1,4 +1,4 @@
import { Box, Flex, IconButton, Progress, Text, Tooltip } from '@invoke-ai/ui-library';
import { Flex, IconButton, Progress, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
@@ -7,7 +7,7 @@ import { isNil } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi';
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';
import type { ModelInstallJob } from 'services/api/types';

import ModelInstallQueueBadge from './ModelInstallQueueBadge';

@@ -16,7 +16,7 @@ type ModelListItemProps = {
};

const formatBytes = (bytes: number) => {
const units = ['b', 'kb', 'mb', 'gb', 'tb'];
const units = ['B', 'KB', 'MB', 'GB', 'TB'];

let i = 0;

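The hunk only corrects the unit labels to uppercase; the rest of formatBytes lies outside the diff context. A plausible completion for reference; the loop body is an assumption, not taken from the source:

    const formatBytes = (bytes: number) => {
      const units = ['B', 'KB', 'MB', 'GB', 'TB'];
      let i = 0;
      // Walk up the unit list until the value fits, or we run out of units.
      for (; bytes >= 1024 && i < units.length - 1; i++) {
        bytes /= 1024;
      }
      return `${bytes.toFixed(2)} ${units[i]}`;
    };
    // formatBytes(1536) -> "1.50 KB"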
@@ -33,18 +33,6 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {

const [deleteImportModel] = useCancelModelInstallMutation();

const source = useMemo(() => {
if (installJob.source.type === 'hf') {
return installJob.source as HFModelSource;
} else if (installJob.source.type === 'local') {
return installJob.source as LocalModelSource;
} else if (installJob.source.type === 'url') {
return installJob.source as URLModelSource;
} else {
return installJob.source as LocalModelSource;
}
}, [installJob.source]);

const handleDeleteModelImport = useCallback(() => {
deleteImportModel(installJob.id)
.unwrap()
@@ -72,18 +60,31 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
});
}, [deleteImportModel, installJob, dispatch]);

const modelName = useMemo(() => {
switch (source.type) {
const sourceLocation = useMemo(() => {
switch (installJob.source.type) {
case 'hf':
return source.repo_id;
return installJob.source.repo_id;
case 'url':
return source.url;
return installJob.source.url;
case 'local':
return source.path.split('\\').slice(-1)[0];
return installJob.source.path;
default:
return '';
return t('common.unknown');
}
}, [source]);
}, [installJob.source]);

const modelName = useMemo(() => {
switch (installJob.source.type) {
case 'hf':
return installJob.source.repo_id;
case 'url':
return installJob.source.url.split('/').slice(-1)[0] ?? t('common.unknown');
case 'local':
return installJob.source.path.split('\\').slice(-1)[0] ?? t('common.unknown');
default:
return t('common.unknown');
}
}, [installJob.source]);

const progressValue = useMemo(() => {
if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
@@ -97,48 +98,67 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
return (installJob.bytes / installJob.total_bytes) * 100;
}, [installJob.bytes, installJob.total_bytes]);

return (
<Flex gap={3} w="full" alignItems="center">
<Tooltip maxW={600} label={<TooltipLabel name={modelName} source={sourceLocation} installJob={installJob} />}>
<Flex gap={3} w="full" alignItems="center">
<Text w={96} whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
{modelName}
</Text>
<Progress
w="full"
flexGrow={1}
value={progressValue ?? 0}
isIndeterminate={progressValue === null}
aria-label={t('accessibility.invokeProgressBar')}
h={2}
/>
<ModelInstallQueueBadge status={installJob.status} />
</Flex>
</Tooltip>
<IconButton
isDisabled={
installJob.status !== 'downloading' && installJob.status !== 'waiting' && installJob.status !== 'running'
}
size="xs"
tooltip={t('modelManager.cancel')}
aria-label={t('modelManager.cancel')}
icon={<PiXBold />}
onClick={handleDeleteModelImport}
variant="ghost"
/>
</Flex>
);
};

type TooltipLabelProps = {
installJob: ModelInstallJob;
name: string;
source: string;
};

const TooltipLabel = ({ name, source, installJob }: TooltipLabelProps) => {
const progressString = useMemo(() => {
if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
if (installJob.status === 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
return '';
}
return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`;
}, [installJob.bytes, installJob.total_bytes, installJob.status]);

return (
<Flex gap="2" w="full" alignItems="center">
<Tooltip label={modelName}>
<Text width="30%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
{modelName}
</Text>
</Tooltip>
<Flex flexDir="column" flex={1}>
<Tooltip label={progressString}>
<Progress
value={progressValue ?? 0}
isIndeterminate={progressValue === null}
aria-label={t('accessibility.invokeProgressBar')}
h={2}
/>
</Tooltip>
<>
<Flex gap={3} justifyContent="space-between">
<Text fontWeight="semibold">{name}</Text>
{progressString && <Text>{progressString}</Text>}
</Flex>
<Box minW="100px" textAlign="center">
<ModelInstallQueueBadge status={installJob.status} errorReason={installJob.error_reason} />
</Box>

<Box minW="20px">
{(installJob.status === 'downloading' ||
installJob.status === 'waiting' ||
installJob.status === 'running') && (
<IconButton
isRound={true}
size="xs"
tooltip={t('modelManager.cancel')}
aria-label={t('modelManager.cancel')}
icon={<PiXBold />}
onClick={handleDeleteModelImport}
/>
)}
</Box>
</Flex>
<Text fontStyle="italic" wordBreak="break-all">
{source}
</Text>
{installJob.error_reason && (
<Text color="error.500">
{t('queue.failed')}: {installJob.error}
</Text>
)}
</>
);
};

Some files were not shown because too many files have changed in this diff.