mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-21 18:18:12 -05:00
Compare commits
1 Commits
feat/api/i
...
kyle0654/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
510ae34bff |
6
.coveragerc
Normal file
6
.coveragerc
Normal file
@@ -0,0 +1,6 @@
|
||||
[run]
|
||||
omit='.env/*'
|
||||
source='.'
|
||||
|
||||
[report]
|
||||
show_missing = true
|
||||
24
.github/CODEOWNERS
vendored
24
.github/CODEOWNERS
vendored
@@ -1,16 +1,16 @@
|
||||
# continuous integration
|
||||
/.github/workflows/ @lstein @blessedcoolant
|
||||
/.github/workflows/ @mauwii @lstein
|
||||
|
||||
# documentation
|
||||
/docs/ @lstein @blessedcoolant @hipsterusername
|
||||
/mkdocs.yml @lstein @blessedcoolant
|
||||
/docs/ @lstein @mauwii @tildebyte
|
||||
/mkdocs.yml @lstein @mauwii
|
||||
|
||||
# nodes
|
||||
/invokeai/app/ @Kyle0654 @blessedcoolant
|
||||
|
||||
# installation and configuration
|
||||
/pyproject.toml @lstein @blessedcoolant
|
||||
/docker/ @lstein @blessedcoolant
|
||||
/pyproject.toml @mauwii @lstein @blessedcoolant
|
||||
/docker/ @mauwii @lstein
|
||||
/scripts/ @ebr @lstein
|
||||
/installer/ @lstein @ebr
|
||||
/invokeai/assets @lstein @ebr
|
||||
@@ -18,17 +18,17 @@
|
||||
/invokeai/version @lstein @blessedcoolant
|
||||
|
||||
# web ui
|
||||
/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp
|
||||
/invokeai/frontend @blessedcoolant @psychedelicious @lstein
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein
|
||||
|
||||
# generation, model management, postprocessing
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 @StAlKeR7779
|
||||
/invokeai/backend @keturn @damian0815 @lstein @blessedcoolant @jpphoto
|
||||
|
||||
# front ends
|
||||
/invokeai/frontend/CLI @lstein
|
||||
/invokeai/frontend/install @lstein @ebr
|
||||
/invokeai/frontend/merge @lstein @blessedcoolant
|
||||
/invokeai/frontend/training @lstein @blessedcoolant
|
||||
/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp
|
||||
/invokeai/frontend/install @lstein @ebr @mauwii
|
||||
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/web @psychedelicious @blessedcoolant
|
||||
|
||||
|
||||
|
||||
19
.github/stale.yaml
vendored
19
.github/stale.yaml
vendored
@@ -1,19 +0,0 @@
|
||||
# Number of days of inactivity before an issue becomes stale
|
||||
daysUntilStale: 28
|
||||
# Number of days of inactivity before a stale issue is closed
|
||||
daysUntilClose: 14
|
||||
# Issues with these labels will never be considered stale
|
||||
exemptLabels:
|
||||
- pinned
|
||||
- security
|
||||
# Label to use when marking an issue as stale
|
||||
staleLabel: stale
|
||||
# Comment to post when marking an issue as stale. Set to `false` to disable
|
||||
markComment: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity. It will be closed if no further activity occurs. Please
|
||||
update the ticket if this is still a problem on the latest release.
|
||||
# Comment to post when closing a stale issue. Set to `false` to disable
|
||||
closeComment: >
|
||||
Due to inactivity, this issue has been automatically closed. If this is
|
||||
still a problem on the latest release, please recreate the issue.
|
||||
4
.github/workflows/build-container.yml
vendored
4
.github/workflows/build-container.yml
vendored
@@ -16,10 +16,6 @@ on:
|
||||
- 'v*.*.*'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
docker:
|
||||
if: github.event.pull_request.draft == false
|
||||
|
||||
18
.github/workflows/mkdocs-material.yml
vendored
18
.github/workflows/mkdocs-material.yml
vendored
@@ -2,19 +2,13 @@ name: mkdocs-material
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'refs/heads/v2.3'
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
- 'main'
|
||||
- 'development'
|
||||
|
||||
jobs:
|
||||
mkdocs-material:
|
||||
if: github.event.pull_request.draft == false
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
|
||||
REPO_NAME: '${{ github.repository }}'
|
||||
SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
|
||||
steps:
|
||||
- name: checkout sources
|
||||
uses: actions/checkout@v3
|
||||
@@ -25,15 +19,11 @@ jobs:
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install requirements
|
||||
env:
|
||||
PIP_USE_PEP517: 1
|
||||
run: |
|
||||
python -m \
|
||||
pip install ".[docs]"
|
||||
pip install -r docs/requirements-mkdocs.txt
|
||||
|
||||
- name: confirm buildability
|
||||
run: |
|
||||
@@ -43,7 +33,7 @@ jobs:
|
||||
--verbose
|
||||
|
||||
- name: deploy to gh-pages
|
||||
if: ${{ github.ref == 'refs/heads/v2.3' }}
|
||||
if: ${{ github.ref == 'refs/heads/main' }}
|
||||
run: |
|
||||
python -m \
|
||||
mkdocs gh-deploy \
|
||||
|
||||
1
.github/workflows/test-invoke-pip-skip.yml
vendored
1
.github/workflows/test-invoke-pip-skip.yml
vendored
@@ -6,6 +6,7 @@ on:
|
||||
- '!pyproject.toml'
|
||||
- '!invokeai/**'
|
||||
- 'invokeai/frontend/web/**'
|
||||
- '!invokeai/frontend/web/dist/**'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
|
||||
21
.github/workflows/test-invoke-pip.yml
vendored
21
.github/workflows/test-invoke-pip.yml
vendored
@@ -7,11 +7,13 @@ on:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'invokeai/frontend/web/dist/**'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'invokeai/frontend/web/dist/**'
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
@@ -80,6 +82,11 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: set test prompt to main branch validation
|
||||
if: ${{ github.ref == 'refs/heads/main' }}
|
||||
run: echo "TEST_PROMPTS=tests/preflight_prompts.txt" >> ${{ matrix.github-env }}
|
||||
|
||||
- name: set test prompt to Pull Request validation
|
||||
if: ${{ github.ref != 'refs/heads/main' }}
|
||||
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
||||
|
||||
- name: setup python
|
||||
@@ -100,6 +107,12 @@ jobs:
|
||||
id: run-pytest
|
||||
run: pytest
|
||||
|
||||
- name: set INVOKEAI_OUTDIR
|
||||
run: >
|
||||
python -c
|
||||
"import os;from invokeai.backend.globals import Globals;OUTDIR=os.path.join(Globals.root,str('outputs'));print(f'INVOKEAI_OUTDIR={OUTDIR}')"
|
||||
>> ${{ matrix.github-env }}
|
||||
|
||||
- name: run invokeai-configure
|
||||
id: run-preload-models
|
||||
env:
|
||||
@@ -118,21 +131,15 @@ jobs:
|
||||
HF_HUB_OFFLINE: 1
|
||||
HF_DATASETS_OFFLINE: 1
|
||||
TRANSFORMERS_OFFLINE: 1
|
||||
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
|
||||
run: >
|
||||
invokeai
|
||||
--no-patchmatch
|
||||
--no-nsfw_checker
|
||||
--precision=float32
|
||||
--always_use_cpu
|
||||
--use_memory_db
|
||||
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
|
||||
--from_file ${{ env.TEST_PROMPTS }}
|
||||
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
|
||||
|
||||
- name: Archive results
|
||||
id: archive-results
|
||||
env:
|
||||
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: results
|
||||
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -9,8 +9,6 @@ models/ldm/stable-diffusion-v1/model.ckpt
|
||||
configs/models.user.yaml
|
||||
config/models.user.yml
|
||||
invokeai.init
|
||||
.version
|
||||
.last_model
|
||||
|
||||
# ignore the Anaconda/Miniconda installer used while building Docker image
|
||||
anaconda.sh
|
||||
@@ -65,7 +63,6 @@ pip-delete-this-directory.txt
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coveragerc
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
@@ -76,7 +73,6 @@ cov.xml
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
.pytest.ini
|
||||
cover/
|
||||
junit/
|
||||
|
||||
@@ -201,8 +197,6 @@ checkpoints
|
||||
# If it's a Mac
|
||||
.DS_Store
|
||||
|
||||
invokeai/frontend/web/dist/*
|
||||
|
||||
# Let the frontend manage its own gitignore
|
||||
!invokeai/frontend/web/*
|
||||
|
||||
|
||||
5
.pytest.ini
Normal file
5
.pytest.ini
Normal file
@@ -0,0 +1,5 @@
|
||||
[pytest]
|
||||
DJANGO_SETTINGS_MODULE = webtas.settings
|
||||
; python_files = tests.py test_*.py *_tests.py
|
||||
|
||||
addopts = --cov=. --cov-config=.coveragerc --cov-report xml:cov.xml
|
||||
13
README.md
13
README.md
@@ -33,8 +33,6 @@
|
||||
|
||||
</div>
|
||||
|
||||
_**Note: The UI is not fully functional on `main`. If you need a stable UI based on `main`, use the `pre-nodes` tag while we [migrate to a new backend](https://github.com/invoke-ai/InvokeAI/discussions/3246).**_
|
||||
|
||||
InvokeAI is a leading creative engine built to empower professionals and enthusiasts alike. Generate and create stunning visual media using the latest AI-driven technologies. InvokeAI offers an industry leading Web Interface, interactive Command Line Interface, and also serves as the foundation for multiple commercial products.
|
||||
|
||||
**Quick links**: [[How to Install](https://invoke-ai.github.io/InvokeAI/#installation)] [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>] [<a href="https://invoke-ai.github.io/InvokeAI/">Documentation and Tutorials</a>] [<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a href="https://github.com/invoke-ai/InvokeAI/issues">Bug Reports</a>] [<a href="https://github.com/invoke-ai/InvokeAI/discussions">Discussion, Ideas & Q&A</a>]
|
||||
@@ -86,7 +84,7 @@ installing lots of models.
|
||||
|
||||
6. Wait while the installer does its thing. After installing the software,
|
||||
the installer will launch a script that lets you configure InvokeAI and
|
||||
select a set of starting image generation models.
|
||||
select a set of starting image generaiton models.
|
||||
|
||||
7. Find the folder that InvokeAI was installed into (it is not the
|
||||
same as the unpacked zip file directory!) The default location of this
|
||||
@@ -141,20 +139,15 @@ not supported.
|
||||
_For Windows/Linux with an NVIDIA GPU:_
|
||||
|
||||
```terminal
|
||||
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
```
|
||||
|
||||
_For Linux with an AMD GPU:_
|
||||
|
||||
```sh
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
|
||||
```
|
||||
|
||||
_For non-GPU systems:_
|
||||
```terminal
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
```
|
||||
|
||||
_For Macintoshes, either Intel or M1/M2:_
|
||||
|
||||
```sh
|
||||
|
||||
BIN
binary_installer/WinLongPathsEnabled.reg
Normal file
BIN
binary_installer/WinLongPathsEnabled.reg
Normal file
Binary file not shown.
164
binary_installer/install.bat.in
Normal file
164
binary_installer/install.bat.in
Normal file
@@ -0,0 +1,164 @@
|
||||
@echo off
|
||||
|
||||
@rem This script will install git (if not found on the PATH variable)
|
||||
@rem using micromamba (an 8mb static-linked single-file binary, conda replacement).
|
||||
@rem For users who already have git, this step will be skipped.
|
||||
|
||||
@rem Next, it'll download the project's source code.
|
||||
@rem Then it will download a self-contained, standalone Python and unpack it.
|
||||
@rem Finally, it'll create the Python virtual environment and preload the models.
|
||||
|
||||
@rem This enables a user to install this project without manually installing git or Python
|
||||
|
||||
@rem change to the script's directory
|
||||
PUSHD "%~dp0"
|
||||
|
||||
set "no_cache_dir=--no-cache-dir"
|
||||
if "%1" == "use-cache" (
|
||||
set "no_cache_dir="
|
||||
)
|
||||
|
||||
echo ***** Installing InvokeAI.. *****
|
||||
@rem Config
|
||||
set INSTALL_ENV_DIR=%cd%\installer_files\env
|
||||
@rem https://mamba.readthedocs.io/en/latest/installation.html
|
||||
set MICROMAMBA_DOWNLOAD_URL=https://github.com/cmdr2/stable-diffusion-ui/releases/download/v1.1/micromamba.exe
|
||||
set RELEASE_URL=https://github.com/invoke-ai/InvokeAI
|
||||
set RELEASE_SOURCEBALL=/archive/refs/heads/main.tar.gz
|
||||
set PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
|
||||
set PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-x86_64-pc-windows-msvc-shared-install_only.tar.gz
|
||||
|
||||
set PACKAGES_TO_INSTALL=
|
||||
|
||||
call git --version >.tmp1 2>.tmp2
|
||||
if "%ERRORLEVEL%" NEQ "0" set PACKAGES_TO_INSTALL=%PACKAGES_TO_INSTALL% git
|
||||
|
||||
@rem Cleanup
|
||||
del /q .tmp1 .tmp2
|
||||
|
||||
@rem (if necessary) install git into a contained environment
|
||||
if "%PACKAGES_TO_INSTALL%" NEQ "" (
|
||||
@rem download micromamba
|
||||
echo ***** Downloading micromamba from %MICROMAMBA_DOWNLOAD_URL% to micromamba.exe *****
|
||||
|
||||
call curl -L "%MICROMAMBA_DOWNLOAD_URL%" > micromamba.exe
|
||||
|
||||
@rem test the mamba binary
|
||||
echo ***** Micromamba version: *****
|
||||
call micromamba.exe --version
|
||||
|
||||
@rem create the installer env
|
||||
if not exist "%INSTALL_ENV_DIR%" (
|
||||
call micromamba.exe create -y --prefix "%INSTALL_ENV_DIR%"
|
||||
)
|
||||
|
||||
echo ***** Packages to install:%PACKAGES_TO_INSTALL% *****
|
||||
|
||||
call micromamba.exe install -y --prefix "%INSTALL_ENV_DIR%" -c conda-forge %PACKAGES_TO_INSTALL%
|
||||
|
||||
if not exist "%INSTALL_ENV_DIR%" (
|
||||
echo ----- There was a problem while installing "%PACKAGES_TO_INSTALL%" using micromamba. Cannot continue. -----
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
)
|
||||
|
||||
del /q micromamba.exe
|
||||
|
||||
@rem For 'git' only
|
||||
set PATH=%INSTALL_ENV_DIR%\Library\bin;%PATH%
|
||||
|
||||
@rem Download/unpack/clean up InvokeAI release sourceball
|
||||
set err_msg=----- InvokeAI source download failed -----
|
||||
echo Trying to download "%RELEASE_URL%%RELEASE_SOURCEBALL%"
|
||||
curl -L %RELEASE_URL%%RELEASE_SOURCEBALL% --output InvokeAI.tgz
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
set err_msg=----- InvokeAI source unpack failed -----
|
||||
tar -zxf InvokeAI.tgz
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
del /q InvokeAI.tgz
|
||||
|
||||
set err_msg=----- InvokeAI source copy failed -----
|
||||
cd InvokeAI-*
|
||||
xcopy . .. /e /h
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
cd ..
|
||||
|
||||
@rem cleanup
|
||||
for /f %%i in ('dir /b InvokeAI-*') do rd /s /q %%i
|
||||
rd /s /q .dev_scripts .github docker-build tests
|
||||
del /q requirements.in requirements-mkdocs.txt shell.nix
|
||||
|
||||
echo ***** Unpacked InvokeAI source *****
|
||||
|
||||
@rem Download/unpack/clean up python-build-standalone
|
||||
set err_msg=----- Python download failed -----
|
||||
curl -L %PYTHON_BUILD_STANDALONE_URL%/%PYTHON_BUILD_STANDALONE% --output python.tgz
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
set err_msg=----- Python unpack failed -----
|
||||
tar -zxf python.tgz
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
del /q python.tgz
|
||||
|
||||
echo ***** Unpacked python-build-standalone *****
|
||||
|
||||
@rem create venv
|
||||
set err_msg=----- problem creating venv -----
|
||||
.\python\python -E -s -m venv .venv
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
call .venv\Scripts\activate.bat
|
||||
|
||||
echo ***** Created Python virtual environment *****
|
||||
|
||||
@rem Print venv's Python version
|
||||
set err_msg=----- problem calling venv's python -----
|
||||
echo We're running under
|
||||
.venv\Scripts\python --version
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
set err_msg=----- pip update failed -----
|
||||
.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location --upgrade pip wheel
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
echo ***** Updated pip and wheel *****
|
||||
|
||||
set err_msg=----- requirements file copy failed -----
|
||||
copy binary_installer\py3.10-windows-x86_64-cuda-reqs.txt requirements.txt
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
set err_msg=----- main pip install failed -----
|
||||
.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location -r requirements.txt
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
echo ***** Installed Python dependencies *****
|
||||
|
||||
set err_msg=----- InvokeAI setup failed -----
|
||||
.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location -e .
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
copy binary_installer\invoke.bat.in .\invoke.bat
|
||||
echo ***** Installed invoke launcher script ******
|
||||
|
||||
@rem more cleanup
|
||||
rd /s /q binary_installer installer_files
|
||||
|
||||
@rem preload the models
|
||||
call .venv\Scripts\python ldm\invoke\config\invokeai_configure.py
|
||||
set err_msg=----- model download clone failed -----
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
deactivate
|
||||
|
||||
echo ***** Finished downloading models *****
|
||||
|
||||
echo All done! Execute the file invoke.bat in this directory to start InvokeAI
|
||||
pause
|
||||
exit
|
||||
|
||||
:err_exit
|
||||
echo %err_msg%
|
||||
pause
|
||||
exit
|
||||
235
binary_installer/install.sh.in
Normal file
235
binary_installer/install.sh.in
Normal file
@@ -0,0 +1,235 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
scriptdir=$(dirname "$0")
|
||||
cd "$scriptdir"
|
||||
|
||||
set -euo pipefail
|
||||
IFS=$'\n\t'
|
||||
|
||||
function _err_exit {
|
||||
if test "$1" -ne 0
|
||||
then
|
||||
echo -e "Error code $1; Error caught was '$2'"
|
||||
read -p "Press any key to exit..."
|
||||
exit
|
||||
fi
|
||||
}
|
||||
|
||||
# This script will install git (if not found on the PATH variable)
|
||||
# using micromamba (an 8mb static-linked single-file binary, conda replacement).
|
||||
# For users who already have git, this step will be skipped.
|
||||
|
||||
# Next, it'll download the project's source code.
|
||||
# Then it will download a self-contained, standalone Python and unpack it.
|
||||
# Finally, it'll create the Python virtual environment and preload the models.
|
||||
|
||||
# This enables a user to install this project without manually installing git or Python
|
||||
|
||||
echo -e "\n***** Installing InvokeAI into $(pwd)... *****\n"
|
||||
|
||||
export no_cache_dir="--no-cache-dir"
|
||||
if [ $# -ge 1 ]; then
|
||||
if [ "$1" = "use-cache" ]; then
|
||||
export no_cache_dir=""
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
OS_NAME=$(uname -s)
|
||||
case "${OS_NAME}" in
|
||||
Linux*) OS_NAME="linux";;
|
||||
Darwin*) OS_NAME="darwin";;
|
||||
*) echo -e "\n----- Unknown OS: $OS_NAME! This script runs only on Linux or macOS -----\n" && exit
|
||||
esac
|
||||
|
||||
OS_ARCH=$(uname -m)
|
||||
case "${OS_ARCH}" in
|
||||
x86_64*) ;;
|
||||
arm64*) ;;
|
||||
*) echo -e "\n----- Unknown system architecture: $OS_ARCH! This script runs only on x86_64 or arm64 -----\n" && exit
|
||||
esac
|
||||
|
||||
# https://mamba.readthedocs.io/en/latest/installation.html
|
||||
MAMBA_OS_NAME=$OS_NAME
|
||||
MAMBA_ARCH=$OS_ARCH
|
||||
if [ "$OS_NAME" == "darwin" ]; then
|
||||
MAMBA_OS_NAME="osx"
|
||||
fi
|
||||
|
||||
if [ "$OS_ARCH" == "linux" ]; then
|
||||
MAMBA_ARCH="aarch64"
|
||||
fi
|
||||
|
||||
if [ "$OS_ARCH" == "x86_64" ]; then
|
||||
MAMBA_ARCH="64"
|
||||
fi
|
||||
|
||||
PY_ARCH=$OS_ARCH
|
||||
if [ "$OS_ARCH" == "arm64" ]; then
|
||||
PY_ARCH="aarch64"
|
||||
fi
|
||||
|
||||
# Compute device ('cd' segment of reqs files) detect goes here
|
||||
# This needs a ton of work
|
||||
# Suggestions:
|
||||
# - lspci
|
||||
# - check $PATH for nvidia-smi, gtt CUDA/GPU version from output
|
||||
# - Surely there's a similar utility for AMD?
|
||||
CD="cuda"
|
||||
if [ "$OS_NAME" == "darwin" ] && [ "$OS_ARCH" == "arm64" ]; then
|
||||
CD="mps"
|
||||
fi
|
||||
|
||||
# config
|
||||
INSTALL_ENV_DIR="$(pwd)/installer_files/env"
|
||||
MICROMAMBA_DOWNLOAD_URL="https://micro.mamba.pm/api/micromamba/${MAMBA_OS_NAME}-${MAMBA_ARCH}/latest"
|
||||
RELEASE_URL=https://github.com/invoke-ai/InvokeAI
|
||||
RELEASE_SOURCEBALL=/archive/refs/heads/main.tar.gz
|
||||
PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
|
||||
if [ "$OS_NAME" == "darwin" ]; then
|
||||
PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-${PY_ARCH}-apple-darwin-install_only.tar.gz
|
||||
elif [ "$OS_NAME" == "linux" ]; then
|
||||
PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-${PY_ARCH}-unknown-linux-gnu-install_only.tar.gz
|
||||
fi
|
||||
echo "INSTALLING $RELEASE_SOURCEBALL FROM $RELEASE_URL"
|
||||
|
||||
PACKAGES_TO_INSTALL=""
|
||||
|
||||
if ! hash "git" &>/dev/null; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL git"; fi
|
||||
|
||||
# (if necessary) install git and conda into a contained environment
|
||||
if [ "$PACKAGES_TO_INSTALL" != "" ]; then
|
||||
# download micromamba
|
||||
echo -e "\n***** Downloading micromamba from $MICROMAMBA_DOWNLOAD_URL to micromamba *****\n"
|
||||
|
||||
curl -L "$MICROMAMBA_DOWNLOAD_URL" | tar -xvjO bin/micromamba > micromamba
|
||||
|
||||
chmod u+x ./micromamba
|
||||
|
||||
# test the mamba binary
|
||||
echo -e "\n***** Micromamba version: *****\n"
|
||||
./micromamba --version
|
||||
|
||||
# create the installer env
|
||||
if [ ! -e "$INSTALL_ENV_DIR" ]; then
|
||||
./micromamba create -y --prefix "$INSTALL_ENV_DIR"
|
||||
fi
|
||||
|
||||
echo -e "\n***** Packages to install:$PACKAGES_TO_INSTALL *****\n"
|
||||
|
||||
./micromamba install -y --prefix "$INSTALL_ENV_DIR" -c conda-forge "$PACKAGES_TO_INSTALL"
|
||||
|
||||
if [ ! -e "$INSTALL_ENV_DIR" ]; then
|
||||
echo -e "\n----- There was a problem while initializing micromamba. Cannot continue. -----\n"
|
||||
exit
|
||||
fi
|
||||
fi
|
||||
|
||||
rm -f micromamba.exe
|
||||
|
||||
export PATH="$INSTALL_ENV_DIR/bin:$PATH"
|
||||
|
||||
# Download/unpack/clean up InvokeAI release sourceball
|
||||
_err_msg="\n----- InvokeAI source download failed -----\n"
|
||||
curl -L $RELEASE_URL/$RELEASE_SOURCEBALL --output InvokeAI.tgz
|
||||
_err_exit $? _err_msg
|
||||
_err_msg="\n----- InvokeAI source unpack failed -----\n"
|
||||
tar -zxf InvokeAI.tgz
|
||||
_err_exit $? _err_msg
|
||||
|
||||
rm -f InvokeAI.tgz
|
||||
|
||||
_err_msg="\n----- InvokeAI source copy failed -----\n"
|
||||
cd InvokeAI-*
|
||||
cp -r . ..
|
||||
_err_exit $? _err_msg
|
||||
cd ..
|
||||
|
||||
# cleanup
|
||||
rm -rf InvokeAI-*/
|
||||
rm -rf .dev_scripts/ .github/ docker-build/ tests/ requirements.in requirements-mkdocs.txt shell.nix
|
||||
|
||||
echo -e "\n***** Unpacked InvokeAI source *****\n"
|
||||
|
||||
# Download/unpack/clean up python-build-standalone
|
||||
_err_msg="\n----- Python download failed -----\n"
|
||||
curl -L $PYTHON_BUILD_STANDALONE_URL/$PYTHON_BUILD_STANDALONE --output python.tgz
|
||||
_err_exit $? _err_msg
|
||||
_err_msg="\n----- Python unpack failed -----\n"
|
||||
tar -zxf python.tgz
|
||||
_err_exit $? _err_msg
|
||||
|
||||
rm -f python.tgz
|
||||
|
||||
echo -e "\n***** Unpacked python-build-standalone *****\n"
|
||||
|
||||
# create venv
|
||||
_err_msg="\n----- problem creating venv -----\n"
|
||||
|
||||
if [ "$OS_NAME" == "darwin" ]; then
|
||||
# patch sysconfig so that extensions can build properly
|
||||
# adapted from https://github.com/cashapp/hermit-packages/commit/fcba384663892f4d9cfb35e8639ff7a28166ee43
|
||||
PYTHON_INSTALL_DIR="$(pwd)/python"
|
||||
SYSCONFIG="$(echo python/lib/python*/_sysconfigdata_*.py)"
|
||||
TMPFILE="$(mktemp)"
|
||||
chmod +w "${SYSCONFIG}"
|
||||
cp "${SYSCONFIG}" "${TMPFILE}"
|
||||
sed "s,'/install,'${PYTHON_INSTALL_DIR},g" "${TMPFILE}" > "${SYSCONFIG}"
|
||||
rm -f "${TMPFILE}"
|
||||
fi
|
||||
|
||||
./python/bin/python3 -E -s -m venv .venv
|
||||
_err_exit $? _err_msg
|
||||
source .venv/bin/activate
|
||||
|
||||
echo -e "\n***** Created Python virtual environment *****\n"
|
||||
|
||||
# Print venv's Python version
|
||||
_err_msg="\n----- problem calling venv's python -----\n"
|
||||
echo -e "We're running under"
|
||||
.venv/bin/python3 --version
|
||||
_err_exit $? _err_msg
|
||||
|
||||
_err_msg="\n----- pip update failed -----\n"
|
||||
.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location --upgrade pip
|
||||
_err_exit $? _err_msg
|
||||
|
||||
echo -e "\n***** Updated pip *****\n"
|
||||
|
||||
_err_msg="\n----- requirements file copy failed -----\n"
|
||||
cp binary_installer/py3.10-${OS_NAME}-"${OS_ARCH}"-${CD}-reqs.txt requirements.txt
|
||||
_err_exit $? _err_msg
|
||||
|
||||
_err_msg="\n----- main pip install failed -----\n"
|
||||
.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location -r requirements.txt
|
||||
_err_exit $? _err_msg
|
||||
|
||||
echo -e "\n***** Installed Python dependencies *****\n"
|
||||
|
||||
_err_msg="\n----- InvokeAI setup failed -----\n"
|
||||
.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location -e .
|
||||
_err_exit $? _err_msg
|
||||
|
||||
echo -e "\n***** Installed InvokeAI *****\n"
|
||||
|
||||
cp binary_installer/invoke.sh.in ./invoke.sh
|
||||
chmod a+rx ./invoke.sh
|
||||
echo -e "\n***** Installed invoke launcher script ******\n"
|
||||
|
||||
# more cleanup
|
||||
rm -rf binary_installer/ installer_files/
|
||||
|
||||
# preload the models
|
||||
.venv/bin/python3 scripts/configure_invokeai.py
|
||||
_err_msg="\n----- model download clone failed -----\n"
|
||||
_err_exit $? _err_msg
|
||||
deactivate
|
||||
|
||||
echo -e "\n***** Finished downloading models *****\n"
|
||||
|
||||
echo "All done! Run the command"
|
||||
echo " $scriptdir/invoke.sh"
|
||||
echo "to start InvokeAI."
|
||||
read -p "Press any key to exit..."
|
||||
exit
|
||||
36
binary_installer/invoke.bat.in
Normal file
36
binary_installer/invoke.bat.in
Normal file
@@ -0,0 +1,36 @@
|
||||
@echo off
|
||||
|
||||
PUSHD "%~dp0"
|
||||
call .venv\Scripts\activate.bat
|
||||
|
||||
echo Do you want to generate images using the
|
||||
echo 1. command-line
|
||||
echo 2. browser-based UI
|
||||
echo OR
|
||||
echo 3. open the developer console
|
||||
set /p choice="Please enter 1, 2 or 3: "
|
||||
if /i "%choice%" == "1" (
|
||||
echo Starting the InvokeAI command-line.
|
||||
.venv\Scripts\python scripts\invoke.py %*
|
||||
) else if /i "%choice%" == "2" (
|
||||
echo Starting the InvokeAI browser-based UI.
|
||||
.venv\Scripts\python scripts\invoke.py --web %*
|
||||
) else if /i "%choice%" == "3" (
|
||||
echo Developer Console
|
||||
echo Python command is:
|
||||
where python
|
||||
echo Python version is:
|
||||
python --version
|
||||
echo *************************
|
||||
echo You are now in the system shell, with the local InvokeAI Python virtual environment activated,
|
||||
echo so that you can troubleshoot this InvokeAI installation as necessary.
|
||||
echo *************************
|
||||
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
||||
call cmd /k
|
||||
) else (
|
||||
echo Invalid selection
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
|
||||
deactivate
|
||||
46
binary_installer/invoke.sh.in
Normal file
46
binary_installer/invoke.sh.in
Normal file
@@ -0,0 +1,46 @@
|
||||
#!/usr/bin/env sh
|
||||
|
||||
set -eu
|
||||
|
||||
. .venv/bin/activate
|
||||
|
||||
# set required env var for torch on mac MPS
|
||||
if [ "$(uname -s)" == "Darwin" ]; then
|
||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||
fi
|
||||
|
||||
echo "Do you want to generate images using the"
|
||||
echo "1. command-line"
|
||||
echo "2. browser-based UI"
|
||||
echo "OR"
|
||||
echo "3. open the developer console"
|
||||
echo "Please enter 1, 2, or 3:"
|
||||
read choice
|
||||
|
||||
case $choice in
|
||||
1)
|
||||
printf "\nStarting the InvokeAI command-line..\n";
|
||||
.venv/bin/python scripts/invoke.py $*;
|
||||
;;
|
||||
2)
|
||||
printf "\nStarting the InvokeAI browser-based UI..\n";
|
||||
.venv/bin/python scripts/invoke.py --web $*;
|
||||
;;
|
||||
3)
|
||||
printf "\nDeveloper Console:\n";
|
||||
printf "Python command is:\n\t";
|
||||
which python;
|
||||
printf "Python version is:\n\t";
|
||||
python --version;
|
||||
echo "*************************"
|
||||
echo "You are now in your user shell ($SHELL) with the local InvokeAI Python virtual environment activated,";
|
||||
echo "so that you can troubleshoot this InvokeAI installation as necessary.";
|
||||
printf "*************************\n"
|
||||
echo "*** Type \`exit\` to quit this shell and deactivate the Python virtual environment *** ";
|
||||
/usr/bin/env "$SHELL";
|
||||
;;
|
||||
*)
|
||||
echo "Invalid selection";
|
||||
exit
|
||||
;;
|
||||
esac
|
||||
2097
binary_installer/py3.10-darwin-arm64-mps-reqs.txt
Normal file
2097
binary_installer/py3.10-darwin-arm64-mps-reqs.txt
Normal file
File diff suppressed because it is too large
Load Diff
2077
binary_installer/py3.10-darwin-x86_64-cpu-reqs.txt
Normal file
2077
binary_installer/py3.10-darwin-x86_64-cpu-reqs.txt
Normal file
File diff suppressed because it is too large
Load Diff
2103
binary_installer/py3.10-linux-x86_64-cuda-reqs.txt
Normal file
2103
binary_installer/py3.10-linux-x86_64-cuda-reqs.txt
Normal file
File diff suppressed because it is too large
Load Diff
2109
binary_installer/py3.10-windows-x86_64-cuda-reqs.txt
Normal file
2109
binary_installer/py3.10-windows-x86_64-cuda-reqs.txt
Normal file
File diff suppressed because it is too large
Load Diff
17
binary_installer/readme.txt
Normal file
17
binary_installer/readme.txt
Normal file
@@ -0,0 +1,17 @@
|
||||
InvokeAI
|
||||
|
||||
Project homepage: https://github.com/invoke-ai/InvokeAI
|
||||
|
||||
Installation on Windows:
|
||||
NOTE: You might need to enable Windows Long Paths. If you're not sure,
|
||||
then you almost certainly need to. Simply double-click the 'WinLongPathsEnabled.reg'
|
||||
file. Note that you will need to have admin privileges in order to
|
||||
do this.
|
||||
|
||||
Please double-click the 'install.bat' file (while keeping it inside the invokeAI folder).
|
||||
|
||||
Installation on Linux and Mac:
|
||||
Please open the terminal, and run './install.sh' (while keeping it inside the invokeAI folder).
|
||||
|
||||
After installation, please run the 'invoke.bat' file (on Windows) or 'invoke.sh'
|
||||
file (on Linux/Mac) to start InvokeAI.
|
||||
33
binary_installer/requirements.in
Normal file
33
binary_installer/requirements.in
Normal file
@@ -0,0 +1,33 @@
|
||||
--prefer-binary
|
||||
--extra-index-url https://download.pytorch.org/whl/torch_stable.html
|
||||
--extra-index-url https://download.pytorch.org/whl/cu116
|
||||
--trusted-host https://download.pytorch.org
|
||||
accelerate~=0.15
|
||||
albumentations
|
||||
diffusers[torch]~=0.11
|
||||
einops
|
||||
eventlet
|
||||
flask_cors
|
||||
flask_socketio
|
||||
flaskwebgui==1.0.3
|
||||
getpass_asterisk
|
||||
imageio-ffmpeg
|
||||
pyreadline3
|
||||
realesrgan
|
||||
send2trash
|
||||
streamlit
|
||||
taming-transformers-rom1504
|
||||
test-tube
|
||||
torch-fidelity
|
||||
torch==1.12.1 ; platform_system == 'Darwin'
|
||||
torch==1.12.0+cu116 ; platform_system == 'Linux' or platform_system == 'Windows'
|
||||
torchvision==0.13.1 ; platform_system == 'Darwin'
|
||||
torchvision==0.13.0+cu116 ; platform_system == 'Linux' or platform_system == 'Windows'
|
||||
transformers
|
||||
picklescan
|
||||
https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip
|
||||
https://github.com/invoke-ai/clipseg/archive/1f754751c85d7d4255fa681f4491ff5711c1c288.zip
|
||||
https://github.com/invoke-ai/GFPGAN/archive/3f5d2397361199bc4a91c08bb7d80f04d7805615.zip ; platform_system=='Windows'
|
||||
https://github.com/invoke-ai/GFPGAN/archive/c796277a1cf77954e5fc0b288d7062d162894248.zip ; platform_system=='Linux' or platform_system=='Darwin'
|
||||
https://github.com/Birch-san/k-diffusion/archive/363386981fee88620709cf8f6f2eea167bd6cd74.zip
|
||||
https://github.com/invoke-ai/PyPatchMatch/archive/129863937a8ab37f6bbcec327c994c0f932abdbc.zip
|
||||
4
coverage/.gitignore
vendored
4
coverage/.gitignore
vendored
@@ -1,4 +0,0 @@
|
||||
# Ignore everything in this directory
|
||||
*
|
||||
# Except this file
|
||||
!.gitignore
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 470 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 457 KiB |
@@ -1,277 +1,105 @@
|
||||
# Invocations
|
||||
|
||||
Invocations represent a single operation, its inputs, and its outputs. These
|
||||
operations and their outputs can be chained together to generate and modify
|
||||
images.
|
||||
Invocations represent a single operation, its inputs, and its outputs. These operations and their outputs can be chained together to generate and modify images.
|
||||
|
||||
## Creating a new invocation
|
||||
|
||||
To create a new invocation, either find the appropriate module file in
|
||||
`/ldm/invoke/app/invocations` to add your invocation to, or create a new one in
|
||||
that folder. All invocations in that folder will be discovered and made
|
||||
available to the CLI and API automatically. Invocations make use of
|
||||
[typing](https://docs.python.org/3/library/typing.html) and
|
||||
[pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration
|
||||
into the CLI and API.
|
||||
To create a new invocation, either find the appropriate module file in `/ldm/invoke/app/invocations` to add your invocation to, or create a new one in that folder. All invocations in that folder will be discovered and made available to the CLI and API automatically. Invocations make use of [typing](https://docs.python.org/3/library/typing.html) and [pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration into the CLI and API.
|
||||
|
||||
An invocation looks like this:
|
||||
|
||||
```py
|
||||
class UpscaleInvocation(BaseInvocation):
|
||||
"""Upscales an image."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["upscale"] = "upscale"
|
||||
type: Literal['upscale'] = 'upscale'
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
||||
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["upscaling", "image"],
|
||||
},
|
||||
}
|
||||
image: Union[ImageField,None] = Field(description="The input image")
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
||||
level: Literal[2,4] = Field(default=2, description = "The upscale level")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
upscale=(self.level, self.strength),
|
||||
strength=0.0, # GFPGAN strength
|
||||
save_original=False,
|
||||
image_callback=None,
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
results = context.services.generate.upscale_and_reconstruct(
|
||||
image_list = [[image, 0]],
|
||||
upscale = (self.level, self.strength),
|
||||
strength = 0.0, # GFPGAN strength
|
||||
save_original = False,
|
||||
image_callback = None,
|
||||
)
|
||||
|
||||
# Results are image and seed, unwrap for now
|
||||
# TODO: can this return multiple results?
|
||||
image_dto = context.services.images.create(
|
||||
image=results[0][0],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, results[0][0])
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
Each portion is important to implement correctly.
|
||||
|
||||
### Class definition and type
|
||||
|
||||
```py
|
||||
class UpscaleInvocation(BaseInvocation):
|
||||
"""Upscales an image."""
|
||||
type: Literal['upscale'] = 'upscale'
|
||||
```
|
||||
|
||||
All invocations must derive from `BaseInvocation`. They should have a docstring
|
||||
that declares what they do in a single, short line. They should also have a
|
||||
`type` with a type hint that's `Literal["command_name"]`, where `command_name`
|
||||
is what the user will type on the CLI or use in the API to create this
|
||||
invocation. The `command_name` must be unique. The `type` must be assigned to
|
||||
the value of the literal in the type hint.
|
||||
All invocations must derive from `BaseInvocation`. They should have a docstring that declares what they do in a single, short line. They should also have a `type` with a type hint that's `Literal["command_name"]`, where `command_name` is what the user will type on the CLI or use in the API to create this invocation. The `command_name` must be unique. The `type` must be assigned to the value of the literal in the type hint.
|
||||
|
||||
### Inputs
|
||||
|
||||
```py
|
||||
# Inputs
|
||||
image: Union[ImageField,None] = Field(description="The input image")
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
||||
level: Literal[2,4] = Field(default=2, description="The upscale level")
|
||||
```
|
||||
Inputs consist of three parts: a name, a type hint, and a `Field` with default, description, and validation information. For example:
|
||||
| Part | Value | Description |
|
||||
| ---- | ----- | ----------- |
|
||||
| Name | `strength` | This field is referred to as `strength` |
|
||||
| Type Hint | `float` | This field must be of type `float` |
|
||||
| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
|
||||
|
||||
Inputs consist of three parts: a name, a type hint, and a `Field` with default,
|
||||
description, and validation information. For example:
|
||||
Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this field to be parsed with `None` as a value, which enables linking to previous invocations. All fields should either provide a default value or allow `None` as a value, so that they can be overwritten with a linked output from another invocation.
|
||||
|
||||
| Part | Value | Description |
|
||||
| --------- | ------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Name | `strength` | This field is referred to as `strength` |
|
||||
| Type Hint | `float` | This field must be of type `float` |
|
||||
| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
|
||||
The special type `ImageField` is also used here. All images are passed as `ImageField`, which protects them from pydantic validation errors (since images only ever come from links).
|
||||
|
||||
Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this
|
||||
field to be parsed with `None` as a value, which enables linking to previous
|
||||
invocations. All fields should either provide a default value or allow `None` as
|
||||
a value, so that they can be overwritten with a linked output from another
|
||||
invocation.
|
||||
|
||||
The special type `ImageField` is also used here. All images are passed as
|
||||
`ImageField`, which protects them from pydantic validation errors (since images
|
||||
only ever come from links).
|
||||
|
||||
Finally, note that for all linking, the `type` of the linked fields must match.
|
||||
If the `name` also matches, then the field can be **automatically linked** to a
|
||||
previous invocation by name and matching.
|
||||
|
||||
### Config
|
||||
|
||||
```py
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["upscaling", "image"],
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
This is an optional configuration for the invocation. It inherits from
|
||||
pydantic's model `Config` class, and it used primarily to customize the
|
||||
autogenerated OpenAPI schema.
|
||||
|
||||
The UI relies on the OpenAPI schema in two ways:
|
||||
|
||||
- An API client & Typescript types are generated from it. This happens at build
|
||||
time.
|
||||
- The node editor parses the schema into a template used by the UI to create the
|
||||
node editor UI. This parsing happens at runtime.
|
||||
|
||||
In this example, a `ui` key has been added to the `schema_extra` dict to provide
|
||||
some tags for the UI, to facilitate filtering nodes.
|
||||
|
||||
See the Schema Generation section below for more information.
|
||||
Finally, note that for all linking, the `type` of the linked fields must match. If the `name` also matches, then the field can be **automatically linked** to a previous invocation by name and matching.
|
||||
|
||||
### Invoke Function
|
||||
|
||||
```py
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
upscale=(self.level, self.strength),
|
||||
strength=0.0, # GFPGAN strength
|
||||
save_original=False,
|
||||
image_callback=None,
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
results = context.services.generate.upscale_and_reconstruct(
|
||||
image_list = [[image, 0]],
|
||||
upscale = (self.level, self.strength),
|
||||
strength = 0.0, # GFPGAN strength
|
||||
save_original = False,
|
||||
image_callback = None,
|
||||
)
|
||||
|
||||
# Results are image and seed, unwrap for now
|
||||
# TODO: can this return multiple results?
|
||||
image_dto = context.services.images.create(
|
||||
image=results[0][0],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, results[0][0])
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
)
|
||||
```
|
||||
The `invoke` function is the last portion of an invocation. It is provided an `InvocationContext` which contains services to perform work as well as a `session_id` for use as needed. It should return a class with output values that derives from `BaseInvocationOutput`.
|
||||
|
||||
The `invoke` function is the last portion of an invocation. It is provided an
|
||||
`InvocationContext` which contains services to perform work as well as a
|
||||
`session_id` for use as needed. It should return a class with output values that
|
||||
derives from `BaseInvocationOutput`.
|
||||
Before being called, the invocation will have all of its fields set from defaults, inputs, and finally links (overriding in that order).
|
||||
|
||||
Before being called, the invocation will have all of its fields set from
|
||||
defaults, inputs, and finally links (overriding in that order).
|
||||
|
||||
Assume that this invocation may be running simultaneously with other
|
||||
invocations, may be running on another machine, or in other interesting
|
||||
scenarios. If you need functionality, please provide it as a service in the
|
||||
`InvocationServices` class, and make sure it can be overridden.
|
||||
Assume that this invocation may be running simultaneously with other invocations, may be running on another machine, or in other interesting scenarios. If you need functionality, please provide it as a service in the `InvocationServices` class, and make sure it can be overridden.
|
||||
|
||||
### Outputs
|
||||
|
||||
```py
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
type: Literal['image'] = 'image'
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
```
|
||||
|
||||
Output classes look like an invocation class without the invoke method. Prefer
|
||||
to use an existing output class if available, and prefer to name inputs the same
|
||||
as outputs when possible, to promote automatic invocation linking.
|
||||
|
||||
## Schema Generation
|
||||
|
||||
Invocation, output and related classes are used to generate an OpenAPI schema.
|
||||
|
||||
### Required Properties
|
||||
|
||||
The schema generation treat all properties with default values as optional. This
|
||||
makes sense internally, but when when using these classes via the generated
|
||||
schema, we end up with e.g. the `ImageOutput` class having its `image` property
|
||||
marked as optional.
|
||||
|
||||
We know that this property will always be present, so the additional logic
|
||||
needed to always check if the property exists adds a lot of extraneous cruft.
|
||||
|
||||
To fix this, we can leverage `pydantic`'s
|
||||
[schema customisation](https://docs.pydantic.dev/usage/schema/#schema-customization)
|
||||
to mark properties that we know will always be present as required.
|
||||
|
||||
Here's that `ImageOutput` class, without the needed schema customisation:
|
||||
|
||||
```python
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
```
|
||||
|
||||
The OpenAPI schema that results from this `ImageOutput` will have the `type`,
|
||||
`image`, `width` and `height` properties marked as optional, even though we know
|
||||
they will always have a value.
|
||||
|
||||
```python
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
# Add schema customization
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
```
|
||||
|
||||
With the customization in place, the schema will now show these properties as
|
||||
required, obviating the need for extensive null checks in client code.
|
||||
|
||||
See this `pydantic` issue for discussion on this solution:
|
||||
<https://github.com/pydantic/pydantic/discussions/4577>
|
||||
Output classes look like an invocation class without the invoke method. Prefer to use an existing output class if available, and prefer to name inputs the same as outputs when possible, to promote automatic invocation linking.
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
# Local Development
|
||||
|
||||
If you are looking to contribute you will need to have a local development
|
||||
environment. See the
|
||||
[Developer Install](../installation/020_INSTALL_MANUAL.md#developer-install) for
|
||||
full details.
|
||||
|
||||
Broadly this involves cloning the repository, installing the pre-reqs, and
|
||||
InvokeAI (in editable form). Assuming this is working, choose your area of
|
||||
focus.
|
||||
|
||||
## Documentation
|
||||
|
||||
We use [mkdocs](https://www.mkdocs.org) for our documentation with the
|
||||
[material theme](https://squidfunk.github.io/mkdocs-material/). Documentation is
|
||||
written in markdown files under the `./docs` folder and then built into a static
|
||||
website for hosting with GitHub Pages at
|
||||
[invoke-ai.github.io/InvokeAI](https://invoke-ai.github.io/InvokeAI).
|
||||
|
||||
To contribute to the documentation you'll need to install the dependencies. Note
|
||||
the use of `"`.
|
||||
|
||||
```zsh
|
||||
pip install ".[docs]"
|
||||
```
|
||||
|
||||
Now, to run the documentation locally with hot-reloading for changes made.
|
||||
|
||||
```zsh
|
||||
mkdocs serve
|
||||
```
|
||||
|
||||
You'll then be prompted to connect to `http://127.0.0.1:8080` in order to
|
||||
access.
|
||||
|
||||
## Backend
|
||||
|
||||
The backend is contained within the `./invokeai/backend` folder structure. To
|
||||
get started however please install the development dependencies.
|
||||
|
||||
From the root of the repository run the following command. Note the use of `"`.
|
||||
|
||||
```zsh
|
||||
pip install ".[test]"
|
||||
```
|
||||
|
||||
This in an optional group of packages which is defined within the
|
||||
`pyproject.toml` and will be required for testing the changes you make the the
|
||||
code.
|
||||
|
||||
### Running Tests
|
||||
|
||||
We use [pytest](https://docs.pytest.org/en/7.2.x/) for our test suite. Tests can
|
||||
be found under the `./tests` folder and can be run with a single `pytest`
|
||||
command. Optionally, to review test coverage you can append `--cov`.
|
||||
|
||||
```zsh
|
||||
pytest --cov
|
||||
```
|
||||
|
||||
Test outcomes and coverage will be reported in the terminal. In addition a more
|
||||
detailed report is created in both XML and HTML format in the `./coverage`
|
||||
folder. The HTML one in particular can help identify missing statements
|
||||
requiring tests to ensure coverage. This can be run by opening
|
||||
`./coverage/html/index.html`.
|
||||
|
||||
For example.
|
||||
|
||||
```zsh
|
||||
pytest --cov; open ./coverage/html/index.html
|
||||
```
|
||||
|
||||
??? info "HTML coverage report output"
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
## Front End
|
||||
|
||||
<!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.-->
|
||||
|
||||
--8<-- "invokeai/frontend/web/README.md"
|
||||
@@ -168,15 +168,11 @@ used by Stable Diffusion 1.4 and 1.5.
|
||||
After installation, your `models.yaml` should contain an entry that looks like
|
||||
this one:
|
||||
|
||||
```yml
|
||||
inpainting-1.5:
|
||||
weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
|
||||
description: SD inpainting v1.5
|
||||
config: configs/stable-diffusion/v1-inpainting-inference.yaml
|
||||
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
||||
width: 512
|
||||
height: 512
|
||||
```
|
||||
inpainting-1.5: weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
|
||||
description: SD inpainting v1.5 config:
|
||||
configs/stable-diffusion/v1-inpainting-inference.yaml vae:
|
||||
models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt width: 512
|
||||
height: 512
|
||||
|
||||
As shown in the example, you may include a VAE fine-tuning weights file as well.
|
||||
This is strongly recommended.
|
||||
|
||||
@@ -1,171 +0,0 @@
|
||||
---
|
||||
title: Controlling Logging
|
||||
---
|
||||
|
||||
# :material-image-off: Controlling Logging
|
||||
|
||||
## Controlling How InvokeAI Logs Status Messages
|
||||
|
||||
InvokeAI logs status messages using a configurable logging system. You
|
||||
can log to the terminal window, to a designated file on the local
|
||||
machine, to the syslog facility on a Linux or Mac, or to a properly
|
||||
configured web server. You can configure several logs at the same
|
||||
time, and control the level of message logged and the logging format
|
||||
(to a limited extent).
|
||||
|
||||
Three command-line options control logging:
|
||||
|
||||
### `--log_handlers <handler1> <handler2> ...`
|
||||
|
||||
This option activates one or more log handlers. Options are "console",
|
||||
"file", "syslog" and "http". To specify more than one, separate them
|
||||
by spaces:
|
||||
|
||||
```bash
|
||||
invokeai-web --log_handlers console syslog=/dev/log file=C:\Users\fred\invokeai.log
|
||||
```
|
||||
|
||||
The format of these options is described below.
|
||||
|
||||
### `--log_format {plain|color|legacy|syslog}`
|
||||
|
||||
This controls the format of log messages written to the console. Only
|
||||
the "console" log handler is currently affected by this setting.
|
||||
|
||||
* "plain" provides formatted messages like this:
|
||||
|
||||
```bash
|
||||
|
||||
[2023-05-24 23:18:2[2023-05-24 23:18:50,352]::[InvokeAI]::DEBUG --> this is a debug message
|
||||
[2023-05-24 23:18:50,352]::[InvokeAI]::INFO --> this is an informational messages
|
||||
[2023-05-24 23:18:50,352]::[InvokeAI]::WARNING --> this is a warning
|
||||
[2023-05-24 23:18:50,352]::[InvokeAI]::ERROR --> this is an error
|
||||
[2023-05-24 23:18:50,352]::[InvokeAI]::CRITICAL --> this is a critical error
|
||||
```
|
||||
|
||||
* "color" produces similar output, but the text will be color coded to
|
||||
indicate the severity of the message.
|
||||
|
||||
* "legacy" produces output similar to InvokeAI versions 2.3 and earlier:
|
||||
|
||||
```bash
|
||||
### this is a critical error
|
||||
*** this is an error
|
||||
** this is a warning
|
||||
>> this is an informational messages
|
||||
| this is a debug message
|
||||
```
|
||||
|
||||
* "syslog" produces messages suitable for syslog entries:
|
||||
|
||||
```bash
|
||||
InvokeAI [2691178] <CRITICAL> this is a critical error
|
||||
InvokeAI [2691178] <ERROR> this is an error
|
||||
InvokeAI [2691178] <WARNING> this is a warning
|
||||
InvokeAI [2691178] <INFO> this is an informational messages
|
||||
InvokeAI [2691178] <DEBUG> this is a debug message
|
||||
```
|
||||
|
||||
(note that the date, time and hostname will be added by the syslog
|
||||
system)
|
||||
|
||||
### `--log_level {debug|info|warning|error|critical}`
|
||||
|
||||
Providing this command-line option will cause only messages at the
|
||||
specified level or above to be emitted.
|
||||
|
||||
## Console logging
|
||||
|
||||
When "console" is provided to `--log_handlers`, messages will be
|
||||
written to the command line window in which InvokeAI was launched. By
|
||||
default, the color formatter will be used unless overridden by
|
||||
`--log_format`.
|
||||
|
||||
## File logging
|
||||
|
||||
When "file" is provided to `--log_handlers`, entries will be written
|
||||
to the file indicated in the path argument. By default, the "plain"
|
||||
format will be used:
|
||||
|
||||
```bash
|
||||
invokeai-web --log_handlers file=/var/log/invokeai.log
|
||||
```
|
||||
|
||||
## Syslog logging
|
||||
|
||||
When "syslog" is requested, entries will be sent to the syslog
|
||||
system. There are a variety of ways to control where the log message
|
||||
is sent:
|
||||
|
||||
* Send to the local machine using the `/dev/log` socket:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=/dev/log
|
||||
```
|
||||
|
||||
* Send to the local machine using a UDP message:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=localhost
|
||||
```
|
||||
|
||||
* Send to the local machine using a UDP message on a nonstandard
|
||||
port:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=localhost:512
|
||||
```
|
||||
|
||||
* Send to a remote machine named "loghost" on the local LAN using
|
||||
facility LOG_USER and UDP packets:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=loghost,facility=LOG_USER,socktype=SOCK_DGRAM
|
||||
```
|
||||
|
||||
This can be abbreviated `syslog=loghost`, as LOG_USER and SOCK_DGRAM
|
||||
are defaults.
|
||||
|
||||
* Send to a remote machine named "loghost" using the facility LOCAL0
|
||||
and using a TCP socket:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=loghost,facility=LOG_LOCAL0,socktype=SOCK_STREAM
|
||||
```
|
||||
|
||||
If no arguments are specified (just a bare "syslog"), then the logging
|
||||
system will look for a UNIX socket named `/dev/log`, and if not found
|
||||
try to send a UDP message to `localhost`. The Macintosh OS used to
|
||||
support logging to a socket named `/var/run/syslog`, but this feature
|
||||
has since been disabled.
|
||||
|
||||
## Web logging
|
||||
|
||||
If you have access to a web server that is configured to log messages
|
||||
when a particular URL is requested, you can log using the "http"
|
||||
method:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers http=http://my.server/path/to/logger,method=POST
|
||||
```
|
||||
|
||||
The optional [,method=] part can be used to specify whether the URL
|
||||
accepts GET (default) or POST messages.
|
||||
|
||||
Currently password authentication and SSL are not supported.
|
||||
|
||||
## Using the configuration file
|
||||
|
||||
You can set and forget logging options by adding a "Logging" section
|
||||
to `invokeai.yaml`:
|
||||
|
||||
```
|
||||
InvokeAI:
|
||||
[... other settings...]
|
||||
Logging:
|
||||
log_handlers:
|
||||
- console
|
||||
- syslog=/dev/log
|
||||
log_level: info
|
||||
log_format: color
|
||||
```
|
||||
@@ -32,7 +32,7 @@ turned on and off on the command line using `--nsfw_checker` and
|
||||
At installation time, InvokeAI will ask whether the checker should be
|
||||
activated by default (neither argument given on the command line). The
|
||||
response is stored in the InvokeAI initialization file (usually
|
||||
`invokeai.init` in your home directory). You can change the default at any
|
||||
`.invokeai` in your home directory). You can change the default at any
|
||||
time by opening this file in a text editor and commenting or
|
||||
uncommenting the line `--nsfw_checker`.
|
||||
|
||||
|
||||
@@ -268,7 +268,7 @@ model is so good at inpainting, a good substitute is to use the `clipseg` text
|
||||
masking option:
|
||||
|
||||
```bash
|
||||
invoke> a fluffy cat eating a hotdog
|
||||
invoke> a fluffy cat eating a hotdot
|
||||
Outputs:
|
||||
[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
|
||||
invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat
|
||||
|
||||
@@ -17,7 +17,7 @@ notebooks.
|
||||
|
||||
You will need a GPU to perform training in a reasonable length of
|
||||
time, and at least 12 GB of VRAM. We recommend using the [`xformers`
|
||||
library](../installation/070_INSTALL_XFORMERS.md) to accelerate the
|
||||
library](../installation/070_INSTALL_XFORMERS) to accelerate the
|
||||
training process further. During training, about ~8 GB is temporarily
|
||||
needed in order to store intermediate models, checkpoints and logs.
|
||||
|
||||
|
||||
@@ -57,9 +57,6 @@ Personalize models by adding your own style or subjects.
|
||||
## * [The NSFW Checker](NSFW.md)
|
||||
Prevent InvokeAI from displaying unwanted racy images.
|
||||
|
||||
## * [Controlling Logging](LOGGING.md)
|
||||
Control how InvokeAI logs status messages.
|
||||
|
||||
## * [Miscellaneous](OTHER.md)
|
||||
Run InvokeAI on Google Colab, generate images with repeating patterns,
|
||||
batch process a file of prompts, increase the "creativity" of image
|
||||
|
||||
@@ -89,7 +89,7 @@ experimental versions later.
|
||||
sudo apt update
|
||||
sudo apt install -y software-properties-common
|
||||
sudo add-apt-repository -y ppa:deadsnakes/ppa
|
||||
sudo apt install -y python3.10 python3-pip python3.10-venv
|
||||
sudo apt install python3.10 python3-pip python3.10-venv
|
||||
sudo update-alternatives --install /usr/local/bin/python python /usr/bin/python3.10 3
|
||||
```
|
||||
|
||||
@@ -417,7 +417,7 @@ Then type the following commands:
|
||||
|
||||
=== "AMD System"
|
||||
```bash
|
||||
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.2
|
||||
```
|
||||
|
||||
### Corrupted configuration file
|
||||
|
||||
@@ -154,7 +154,7 @@ manager, please follow these steps:
|
||||
=== "ROCm (AMD)"
|
||||
|
||||
```bash
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
|
||||
```
|
||||
|
||||
=== "CPU (Intel Macs & non-GPU systems)"
|
||||
@@ -216,7 +216,7 @@ manager, please follow these steps:
|
||||
9. Run the command-line- or the web- interface:
|
||||
|
||||
From within INVOKEAI_ROOT, activate the environment
|
||||
(with `source .venv/bin/activate` or `.venv\scripts\activate`), and then run
|
||||
(with `source .venv/bin/activate` or `.venv\scripts\activate), and then run
|
||||
the script `invokeai`. If the virtual environment you selected is NOT inside
|
||||
INVOKEAI_ROOT, then you must specify the path to the root directory by adding
|
||||
`--root_dir \path\to\invokeai` to the commands below:
|
||||
@@ -315,7 +315,7 @@ installation protocol (important!)
|
||||
|
||||
=== "ROCm (AMD)"
|
||||
```bash
|
||||
pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||
pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
|
||||
```
|
||||
|
||||
=== "CPU (Intel Macs & non-GPU systems)"
|
||||
|
||||
@@ -110,7 +110,7 @@ recipes are available
|
||||
|
||||
When installing torch and torchvision manually with `pip`, remember to provide
|
||||
the argument `--extra-index-url
|
||||
https://download.pytorch.org/whl/rocm5.4.2` as described in the [Manual
|
||||
https://download.pytorch.org/whl/rocm5.2` as described in the [Manual
|
||||
Installation Guide](020_INSTALL_MANUAL.md).
|
||||
|
||||
This will be done automatically for you if you use the installer
|
||||
|
||||
@@ -50,7 +50,7 @@ subset that are currently installed are found in
|
||||
|stable-diffusion-1.5|runwayml/stable-diffusion-v1-5|Stable Diffusion version 1.5 diffusers model (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-v1-5 |
|
||||
|sd-inpainting-1.5|runwayml/stable-diffusion-inpainting|RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-inpainting |
|
||||
|stable-diffusion-2.1|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
|
||||
|sd-inpainting-2.0|stabilityai/stable-diffusion-2-inpainting|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-inpainting |
|
||||
|sd-inpainting-2.0|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
|
||||
|analog-diffusion-1.0|wavymulder/Analog-Diffusion|An SD-1.5 model trained on diverse analog photographs (2.13 GB)|https://huggingface.co/wavymulder/Analog-Diffusion |
|
||||
|deliberate-1.0|XpucT/Deliberate|Versatile model that produces detailed images up to 768px (4.27 GB)|https://huggingface.co/XpucT/Deliberate |
|
||||
|d&d-diffusion-1.0|0xJustin/Dungeons-and-Diffusion|Dungeons & Dragons characters (2.13 GB)|https://huggingface.co/0xJustin/Dungeons-and-Diffusion |
|
||||
|
||||
@@ -24,7 +24,7 @@ You need to have opencv installed so that pypatchmatch can be built:
|
||||
brew install opencv
|
||||
```
|
||||
|
||||
The next time you start `invoke`, after successfully installing opencv, pypatchmatch will be built.
|
||||
The next time you start `invoke`, after sucesfully installing opencv, pypatchmatch will be built.
|
||||
|
||||
## Linux
|
||||
|
||||
@@ -56,7 +56,7 @@ Prior to installing PyPatchMatch, you need to take the following steps:
|
||||
|
||||
5. Confirm that pypatchmatch is installed. At the command-line prompt enter
|
||||
`python`, and then at the `>>>` line type
|
||||
`from patchmatch import patch_match`: It should look like the following:
|
||||
`from patchmatch import patch_match`: It should look like the follwing:
|
||||
|
||||
```py
|
||||
Python 3.9.5 (default, Nov 23 2021, 15:27:38)
|
||||
@@ -108,4 +108,4 @@ Prior to installing PyPatchMatch, you need to take the following steps:
|
||||
|
||||
[**Next, Follow Steps 4-6 from the Debian Section above**](#linux)
|
||||
|
||||
If you see no errors you're ready to go!
|
||||
If you see no errors, then you're ready to go!
|
||||
|
||||
@@ -247,8 +247,8 @@ class InvokeAiInstance:
|
||||
pip[
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
"torch~=2.0.0",
|
||||
"torchvision>=0.14.1",
|
||||
"torch",
|
||||
"torchvision",
|
||||
"--force-reinstall",
|
||||
"--find-links" if find_links is not None else None,
|
||||
find_links,
|
||||
@@ -456,7 +456,7 @@ def get_torch_source() -> (Union[str, None],str):
|
||||
optional_modules = None
|
||||
if OS == "Linux":
|
||||
if device == "rocm":
|
||||
url = "https://download.pytorch.org/whl/rocm5.4.2"
|
||||
url = "https://download.pytorch.org/whl/rocm5.2"
|
||||
elif device == "cpu":
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
|
||||
|
||||
@@ -7,42 +7,42 @@ call .venv\Scripts\activate.bat
|
||||
set INVOKEAI_ROOT=.
|
||||
|
||||
:start
|
||||
echo Desired action:
|
||||
echo 1. Generate images with the browser-based interface
|
||||
echo 2. Explore InvokeAI nodes using a command-line interface
|
||||
echo 3. Run textual inversion training
|
||||
echo 4. Merge models (diffusers type only)
|
||||
echo 5. Download and install models
|
||||
echo 6. Change InvokeAI startup options
|
||||
echo 7. Re-run the configure script to fix a broken install
|
||||
echo 8. Open the developer console
|
||||
echo 9. Update InvokeAI
|
||||
echo 10. Command-line help
|
||||
echo Q - Quit
|
||||
set /P choice="Please enter 1-10, Q: [2] "
|
||||
if not defined choice set choice=2
|
||||
IF /I "%choice%" == "1" (
|
||||
echo Starting the InvokeAI browser-based UI..
|
||||
python .venv\Scripts\invokeai-web.exe %*
|
||||
) ELSE IF /I "%choice%" == "2" (
|
||||
echo Do you want to generate images using the
|
||||
echo 1. command-line interface
|
||||
echo 2. browser-based UI
|
||||
echo 3. run textual inversion training
|
||||
echo 4. merge models (diffusers type only)
|
||||
echo 5. download and install models
|
||||
echo 6. change InvokeAI startup options
|
||||
echo 7. re-run the configure script to fix a broken install
|
||||
echo 8. open the developer console
|
||||
echo 9. update InvokeAI
|
||||
echo 10. command-line help
|
||||
echo Q - quit
|
||||
set /P restore="Please enter 1-10, Q: [2] "
|
||||
if not defined restore set restore=2
|
||||
IF /I "%restore%" == "1" (
|
||||
echo Starting the InvokeAI command-line..
|
||||
python .venv\Scripts\invokeai.exe %*
|
||||
) ELSE IF /I "%choice%" == "3" (
|
||||
) ELSE IF /I "%restore%" == "2" (
|
||||
echo Starting the InvokeAI browser-based UI..
|
||||
python .venv\Scripts\invokeai.exe --web %*
|
||||
) ELSE IF /I "%restore%" == "3" (
|
||||
echo Starting textual inversion training..
|
||||
python .venv\Scripts\invokeai-ti.exe --gui
|
||||
) ELSE IF /I "%choice%" == "4" (
|
||||
) ELSE IF /I "%restore%" == "4" (
|
||||
echo Starting model merging script..
|
||||
python .venv\Scripts\invokeai-merge.exe --gui
|
||||
) ELSE IF /I "%choice%" == "5" (
|
||||
) ELSE IF /I "%restore%" == "5" (
|
||||
echo Running invokeai-model-install...
|
||||
python .venv\Scripts\invokeai-model-install.exe
|
||||
) ELSE IF /I "%choice%" == "6" (
|
||||
) ELSE IF /I "%restore%" == "6" (
|
||||
echo Running invokeai-configure...
|
||||
python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models
|
||||
) ELSE IF /I "%choice%" == "7" (
|
||||
) ELSE IF /I "%restore%" == "7" (
|
||||
echo Running invokeai-configure...
|
||||
python .venv\Scripts\invokeai-configure.exe --yes --default_only
|
||||
) ELSE IF /I "%choice%" == "8" (
|
||||
) ELSE IF /I "%restore%" == "8" (
|
||||
echo Developer Console
|
||||
echo Python command is:
|
||||
where python
|
||||
@@ -54,15 +54,15 @@ IF /I "%choice%" == "1" (
|
||||
echo *************************
|
||||
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
||||
call cmd /k
|
||||
) ELSE IF /I "%choice%" == "9" (
|
||||
) ELSE IF /I "%restore%" == "9" (
|
||||
echo Running invokeai-update...
|
||||
python .venv\Scripts\invokeai-update.exe %*
|
||||
) ELSE IF /I "%choice%" == "10" (
|
||||
) ELSE IF /I "%restore%" == "10" (
|
||||
echo Displaying command line help...
|
||||
python .venv\Scripts\invokeai.exe --help %*
|
||||
pause
|
||||
exit /b
|
||||
) ELSE IF /I "%choice%" == "q" (
|
||||
) ELSE IF /I "%restore%" == "q" (
|
||||
echo Goodbye!
|
||||
goto ending
|
||||
) ELSE (
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
#!/bin/bash
|
||||
|
||||
# MIT License
|
||||
|
||||
# Coauthored by Lincoln Stein, Eugene Brodsky and Joshua Kimsey
|
||||
# Copyright 2023, The InvokeAI Development Team
|
||||
|
||||
####
|
||||
# This launch script assumes that:
|
||||
# 1. it is located in the runtime directory,
|
||||
@@ -16,168 +11,85 @@
|
||||
|
||||
set -eu
|
||||
|
||||
# Ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
# ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
scriptdir=$(dirname "$0")
|
||||
cd "$scriptdir"
|
||||
|
||||
. .venv/bin/activate
|
||||
|
||||
export INVOKEAI_ROOT="$scriptdir"
|
||||
PARAMS=$@
|
||||
|
||||
# Check to see if dialog is installed (it seems to be fairly standard, but good to check regardless) and if the user has passed the --no-tui argument to disable the dialog TUI
|
||||
tui=true
|
||||
if command -v dialog &>/dev/null; then
|
||||
# This must use $@ to properly loop through the arguments passed by the user
|
||||
for arg in "$@"; do
|
||||
if [ "$arg" == "--no-tui" ]; then
|
||||
tui=false
|
||||
# Remove the --no-tui argument to avoid errors later on when passing arguments to InvokeAI
|
||||
PARAMS=$(echo "$PARAMS" | sed 's/--no-tui//')
|
||||
break
|
||||
fi
|
||||
done
|
||||
else
|
||||
tui=false
|
||||
fi
|
||||
|
||||
# Set required env var for torch on mac MPS
|
||||
# set required env var for torch on mac MPS
|
||||
if [ "$(uname -s)" == "Darwin" ]; then
|
||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||
fi
|
||||
|
||||
# Primary function for the case statement to determine user input
|
||||
do_choice() {
|
||||
case $1 in
|
||||
1)
|
||||
clear
|
||||
printf "Generate images with a browser-based interface\n"
|
||||
invokeai-web $PARAMS
|
||||
;;
|
||||
2)
|
||||
clear
|
||||
printf "Explore InvokeAI nodes using a command-line interface\n"
|
||||
invokeai $PARAMS
|
||||
;;
|
||||
3)
|
||||
clear
|
||||
printf "Textual inversion training\n"
|
||||
invokeai-ti --gui $PARAMS
|
||||
;;
|
||||
4)
|
||||
clear
|
||||
printf "Merge models (diffusers type only)\n"
|
||||
invokeai-merge --gui $PARAMS
|
||||
;;
|
||||
5)
|
||||
clear
|
||||
printf "Download and install models\n"
|
||||
invokeai-model-install --root ${INVOKEAI_ROOT}
|
||||
;;
|
||||
6)
|
||||
clear
|
||||
printf "Change InvokeAI startup options\n"
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --skip-sd-weights --skip-support-models
|
||||
;;
|
||||
7)
|
||||
clear
|
||||
printf "Re-run the configure script to fix a broken install\n"
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
|
||||
;;
|
||||
8)
|
||||
clear
|
||||
printf "Open the developer console\n"
|
||||
file_name=$(basename "${BASH_SOURCE[0]}")
|
||||
bash --init-file "$file_name"
|
||||
;;
|
||||
9)
|
||||
clear
|
||||
printf "Update InvokeAI\n"
|
||||
invokeai-update
|
||||
;;
|
||||
10)
|
||||
clear
|
||||
printf "Command-line help\n"
|
||||
invokeai --help
|
||||
;;
|
||||
"HELP 1")
|
||||
clear
|
||||
printf "Command-line help\n"
|
||||
invokeai --help
|
||||
;;
|
||||
*)
|
||||
clear
|
||||
printf "Exiting...\n"
|
||||
exit
|
||||
;;
|
||||
esac
|
||||
clear
|
||||
}
|
||||
|
||||
# Dialog-based TUI for launcing Invoke functions
|
||||
do_dialog() {
|
||||
options=(
|
||||
1 "Generate images with a browser-based interface"
|
||||
2 "Generate images using a command-line interface"
|
||||
3 "Textual inversion training"
|
||||
4 "Merge models (diffusers type only)"
|
||||
5 "Download and install models"
|
||||
6 "Change InvokeAI startup options"
|
||||
7 "Re-run the configure script to fix a broken install"
|
||||
8 "Open the developer console"
|
||||
9 "Update InvokeAI")
|
||||
|
||||
choice=$(dialog --clear \
|
||||
--backtitle "\Zb\Zu\Z3InvokeAI" \
|
||||
--colors \
|
||||
--title "What would you like to do?" \
|
||||
--ok-label "Run" \
|
||||
--cancel-label "Exit" \
|
||||
--help-button \
|
||||
--help-label "CLI Help" \
|
||||
--menu "Select an option:" \
|
||||
0 0 0 \
|
||||
"${options[@]}" \
|
||||
2>&1 >/dev/tty) || clear
|
||||
do_choice "$choice"
|
||||
clear
|
||||
}
|
||||
|
||||
# Command-line interface for launching Invoke functions
|
||||
do_line_input() {
|
||||
clear
|
||||
printf " ** For a more attractive experience, please install the 'dialog' utility using your package manager. **\n\n"
|
||||
printf "What would you like to do?\n"
|
||||
printf "1: Generate images using the browser-based interface\n"
|
||||
printf "2: Explore InvokeAI nodes using the command-line interface\n"
|
||||
printf "3: Run textual inversion training\n"
|
||||
printf "4: Merge models (diffusers type only)\n"
|
||||
printf "5: Download and install models\n"
|
||||
printf "6: Change InvokeAI startup options\n"
|
||||
printf "7: Re-run the configure script to fix a broken install\n"
|
||||
printf "8: Open the developer console\n"
|
||||
printf "9: Update InvokeAI\n"
|
||||
printf "10: Command-line help\n"
|
||||
printf "Q: Quit\n\n"
|
||||
read -p "Please enter 1-10, Q: [1] " yn
|
||||
choice=${yn:='1'}
|
||||
do_choice $choice
|
||||
clear
|
||||
}
|
||||
|
||||
# Main IF statement for launching Invoke with either the TUI or CLI, and for checking if the user is in the developer console
|
||||
while true
|
||||
do
|
||||
if [ "$0" != "bash" ]; then
|
||||
while true; do
|
||||
if $tui; then
|
||||
# .dialogrc must be located in the same directory as the invoke.sh script
|
||||
export DIALOGRC="./.dialogrc"
|
||||
do_dialog
|
||||
else
|
||||
do_line_input
|
||||
fi
|
||||
done
|
||||
echo "Do you want to generate images using the"
|
||||
echo "1. command-line interface"
|
||||
echo "2. browser-based UI"
|
||||
echo "3. run textual inversion training"
|
||||
echo "4. merge models (diffusers type only)"
|
||||
echo "5. download and install models"
|
||||
echo "6. change InvokeAI startup options"
|
||||
echo "7. re-run the configure script to fix a broken install"
|
||||
echo "8. open the developer console"
|
||||
echo "9. update InvokeAI"
|
||||
echo "10. command-line help"
|
||||
echo "Q - Quit"
|
||||
echo ""
|
||||
read -p "Please enter 1-10, Q: [2] " yn
|
||||
choice=${yn:='2'}
|
||||
case $choice in
|
||||
1)
|
||||
echo "Starting the InvokeAI command-line..."
|
||||
invokeai $@
|
||||
;;
|
||||
2)
|
||||
echo "Starting the InvokeAI browser-based UI..."
|
||||
invokeai --web $@
|
||||
;;
|
||||
3)
|
||||
echo "Starting Textual Inversion:"
|
||||
invokeai-ti --gui $@
|
||||
;;
|
||||
4)
|
||||
echo "Merging Models:"
|
||||
invokeai-merge --gui $@
|
||||
;;
|
||||
5)
|
||||
invokeai-model-install --root ${INVOKEAI_ROOT}
|
||||
;;
|
||||
6)
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --skip-sd-weights --skip-support-models
|
||||
;;
|
||||
7)
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
|
||||
;;
|
||||
8)
|
||||
echo "Developer Console:"
|
||||
file_name=$(basename "${BASH_SOURCE[0]}")
|
||||
bash --init-file "$file_name"
|
||||
;;
|
||||
9)
|
||||
echo "Update:"
|
||||
invokeai-update
|
||||
;;
|
||||
10)
|
||||
invokeai --help
|
||||
;;
|
||||
[qQ])
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Invalid selection"
|
||||
exit;;
|
||||
esac
|
||||
else # in developer console
|
||||
python --version
|
||||
printf "Press ^D to exit\n"
|
||||
echo "Press ^D to exit"
|
||||
export PS1="(InvokeAI) \u@\h \w> "
|
||||
fi
|
||||
done
|
||||
|
||||
@@ -1,20 +1,13 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from logging import Logger
|
||||
import os
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from argparse import Namespace
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
from ...backend import Globals
|
||||
from ..services.model_manager_initializer import get_model_manager
|
||||
from ..services.restoration_services import RestorationServices
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_file_storage import DiskImageFileStorage
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..services.image_storage import DiskImageStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invoker import Invoker
|
||||
@@ -39,67 +32,45 @@ def check_internet() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
logger = InvokeAILogger.getLogger()
|
||||
|
||||
|
||||
class ApiDependencies:
|
||||
"""Contains and initializes all dependencies for the API"""
|
||||
|
||||
invoker: Invoker = None
|
||||
|
||||
@staticmethod
|
||||
def initialize(config, event_handler_id: int, logger: Logger = logger):
|
||||
logger.info(f"Internet connectivity is {config.internet_available}")
|
||||
def initialize(config, event_handler_id: int):
|
||||
Globals.try_patchmatch = config.patchmatch
|
||||
Globals.always_use_cpu = config.always_use_cpu
|
||||
Globals.internet_available = config.internet_available and check_internet()
|
||||
Globals.disable_xformers = not config.xformers
|
||||
Globals.ckpt_convert = config.ckpt_convert
|
||||
|
||||
# TODO: Use a logger
|
||||
print(f">> Internet connectivity is {Globals.internet_available}")
|
||||
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
|
||||
output_folder = config.output_path
|
||||
output_folder = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "../../../../outputs")
|
||||
)
|
||||
|
||||
images = DiskImageStorage(output_folder)
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True,exist_ok=True)
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
metadata = CoreMetadataService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
latents = ForwardCacheLatentsStorage(
|
||||
DiskLatentsStorage(f"{output_folder}/latents")
|
||||
)
|
||||
|
||||
images = ImageService(
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=get_model_manager(config, logger),
|
||||
model_manager=get_model_manager(config),
|
||||
events=events,
|
||||
latents=latents,
|
||||
images=images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config, logger),
|
||||
configuration=config,
|
||||
logger=logger,
|
||||
restoration=RestorationServices(config),
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -45,7 +45,7 @@ class FastAPIEventService(EventServiceBase):
|
||||
)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.1)
|
||||
await asyncio.sleep(0.001)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
|
||||
@@ -1,250 +1,56 @@
|
||||
import io
|
||||
from typing import Optional
|
||||
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.responses import FileResponse
|
||||
from PIL import Image
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageDTO,
|
||||
ImageRecordChanges,
|
||||
ImageUrlsDTO,
|
||||
)
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import Path, Request, UploadFile
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
|
||||
from ...services.image_storage import ImageType
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
|
||||
|
||||
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
||||
async def get_image(
|
||||
image_type: ImageType = Path(description="The type of image to get"),
|
||||
image_name: str = Path(description="The name of the image to get"),
|
||||
):
|
||||
"""Gets a result"""
|
||||
# TODO: This is not really secure at all. At least make sure only output results are served
|
||||
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
|
||||
return FileResponse(filename)
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/",
|
||||
"/uploads/",
|
||||
operation_id="upload_image",
|
||||
responses={
|
||||
201: {"description": "The image was uploaded successfully"},
|
||||
415: {"description": "Image upload failed"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def upload_image(
|
||||
file: UploadFile,
|
||||
request: Request,
|
||||
response: Response,
|
||||
image_category: ImageCategory = Query(description="The category of the image"),
|
||||
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
|
||||
session_id: Optional[str] = Query(
|
||||
default=None, description="The session ID associated with this upload, if any"
|
||||
),
|
||||
) -> ImageDTO:
|
||||
"""Uploads an image"""
|
||||
async def upload_image(file: UploadFile, request: Request):
|
||||
if not file.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
return Response(status_code=415)
|
||||
|
||||
contents = await file.read()
|
||||
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
im = Image.open(contents)
|
||||
except:
|
||||
# Error opening the image
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
return Response(status_code=415)
|
||||
|
||||
try:
|
||||
image_dto = ApiDependencies.invoker.services.images.create(
|
||||
image=pil_image,
|
||||
image_origin=ResourceOrigin.EXTERNAL,
|
||||
image_category=image_category,
|
||||
session_id=session_id,
|
||||
is_intermediate=is_intermediate,
|
||||
)
|
||||
filename = f"{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
||||
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
|
||||
|
||||
response.status_code = 201
|
||||
response.headers["Location"] = image_dto.image_url
|
||||
|
||||
return image_dto
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||
|
||||
|
||||
@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
|
||||
async def delete_image(
|
||||
image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> None:
|
||||
"""Deletes an image"""
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.delete(image_origin, image_name)
|
||||
except Exception as e:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
|
||||
|
||||
@images_router.patch(
|
||||
"/{image_origin}/{image_name}",
|
||||
operation_id="update_image",
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def update_image(
|
||||
image_origin: ResourceOrigin = Path(description="The origin of image to update"),
|
||||
image_name: str = Path(description="The name of the image to update"),
|
||||
image_changes: ImageRecordChanges = Body(
|
||||
description="The changes to apply to the image"
|
||||
),
|
||||
) -> ImageDTO:
|
||||
"""Updates an image"""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.update(
|
||||
image_origin, image_name, image_changes
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail="Failed to update image")
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_origin}/{image_name}/metadata",
|
||||
operation_id="get_image_metadata",
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def get_image_metadata(
|
||||
image_origin: ResourceOrigin = Path(description="The origin of image to get"),
|
||||
image_name: str = Path(description="The name of image to get"),
|
||||
) -> ImageDTO:
|
||||
"""Gets an image's metadata"""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_origin}/{image_name}",
|
||||
operation_id="get_image_full",
|
||||
response_class=Response,
|
||||
responses={
|
||||
200: {
|
||||
"description": "Return the full-resolution image",
|
||||
"content": {"image/png": {}},
|
||||
return Response(
|
||||
status_code=201,
|
||||
headers={
|
||||
"Location": request.url_for(
|
||||
"get_image", image_type=ImageType.UPLOAD, image_name=filename
|
||||
)
|
||||
},
|
||||
404: {"description": "Image not found"},
|
||||
},
|
||||
)
|
||||
async def get_image_full(
|
||||
image_origin: ResourceOrigin = Path(
|
||||
description="The type of full-resolution image file to get"
|
||||
),
|
||||
image_name: str = Path(description="The name of full-resolution image file to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a full-resolution image file"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
|
||||
|
||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
return FileResponse(
|
||||
path,
|
||||
media_type="image/png",
|
||||
filename=image_name,
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_origin}/{image_name}/thumbnail",
|
||||
operation_id="get_image_thumbnail",
|
||||
response_class=Response,
|
||||
responses={
|
||||
200: {
|
||||
"description": "Return the image thumbnail",
|
||||
"content": {"image/webp": {}},
|
||||
},
|
||||
404: {"description": "Image not found"},
|
||||
},
|
||||
)
|
||||
async def get_image_thumbnail(
|
||||
image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
|
||||
image_name: str = Path(description="The name of thumbnail image file to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a thumbnail image file"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_origin, image_name, thumbnail=True
|
||||
)
|
||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
return FileResponse(
|
||||
path, media_type="image/webp", content_disposition_type="inline"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_origin}/{image_name}/urls",
|
||||
operation_id="get_image_urls",
|
||||
response_model=ImageUrlsDTO,
|
||||
)
|
||||
async def get_image_urls(
|
||||
image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
|
||||
image_name: str = Path(description="The name of the image whose URL to get"),
|
||||
) -> ImageUrlsDTO:
|
||||
"""Gets an image and thumbnail URL"""
|
||||
|
||||
try:
|
||||
image_url = ApiDependencies.invoker.services.images.get_url(
|
||||
image_origin, image_name
|
||||
)
|
||||
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
|
||||
image_origin, image_name, thumbnail=True
|
||||
)
|
||||
return ImageUrlsDTO(
|
||||
image_origin=image_origin,
|
||||
image_name=image_name,
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/",
|
||||
operation_id="list_images_with_metadata",
|
||||
response_model=OffsetPaginatedResults[ImageDTO],
|
||||
)
|
||||
async def list_images_with_metadata(
|
||||
image_origin: Optional[ResourceOrigin] = Query(
|
||||
default=None, description="The origin of images to list"
|
||||
),
|
||||
categories: Optional[list[ImageCategory]] = Query(
|
||||
default=None, description="The categories of image to include"
|
||||
),
|
||||
is_intermediate: Optional[bool] = Query(
|
||||
default=None, description="Whether to list intermediate images"
|
||||
),
|
||||
offset: int = Query(default=0, description="The page offset"),
|
||||
limit: int = Query(default=10, description="The number of images per page"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a list of images"""
|
||||
|
||||
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
||||
offset,
|
||||
limit,
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
)
|
||||
|
||||
return image_dtos
|
||||
|
||||
@@ -1,248 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
||||
|
||||
import shutil
|
||||
import asyncio
|
||||
from typing import Annotated, Any, List, Literal, Optional, Union
|
||||
|
||||
from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from pathlib import Path
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
|
||||
class VaeRepo(BaseModel):
|
||||
repo_id: str = Field(description="The repo ID to use for this VAE")
|
||||
path: Optional[str] = Field(description="The path to the VAE")
|
||||
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
description: Optional[str] = Field(description="A description of the model")
|
||||
|
||||
class CkptModelInfo(ModelInfo):
|
||||
format: Literal['ckpt'] = 'ckpt'
|
||||
|
||||
config: str = Field(description="The path to the model config")
|
||||
weights: str = Field(description="The path to the model weights")
|
||||
vae: str = Field(description="The path to the model VAE")
|
||||
width: Optional[int] = Field(description="The width of the model")
|
||||
height: Optional[int] = Field(description="The height of the model")
|
||||
|
||||
class DiffusersModelInfo(ModelInfo):
|
||||
format: Literal['diffusers'] = 'diffusers'
|
||||
|
||||
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
||||
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
||||
path: Optional[str] = Field(description="The path to the model")
|
||||
|
||||
class CreateModelRequest(BaseModel):
|
||||
name: str = Field(description="The name of the model")
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
|
||||
class CreateModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
status: str = Field(description="The status of the API response")
|
||||
|
||||
class ConversionRequest(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: CkptModelInfo = Field(description="The converted model info")
|
||||
save_location: str = Field(description="The path to save the converted model weights")
|
||||
|
||||
|
||||
class ConvertedModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/",
|
||||
operation_id="list_models",
|
||||
responses={200: {"model": ModelsList }},
|
||||
)
|
||||
async def list_models() -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models()
|
||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||
return models
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/",
|
||||
operation_id="update_model",
|
||||
responses={200: {"status": "success"}},
|
||||
)
|
||||
async def update_model(
|
||||
model_request: CreateModelRequest
|
||||
) -> CreateModelResponse:
|
||||
""" Add Model """
|
||||
model_request_info = model_request.info
|
||||
info_dict = model_request_info.dict()
|
||||
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
|
||||
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
model_name=model_request.name,
|
||||
model_attributes=info_dict,
|
||||
clobber=True,
|
||||
)
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
@models_router.delete(
|
||||
"/{model_name}",
|
||||
operation_id="del_model",
|
||||
responses={
|
||||
204: {
|
||||
"description": "Model deleted successfully"
|
||||
},
|
||||
404: {
|
||||
"description": "Model not found"
|
||||
}
|
||||
},
|
||||
)
|
||||
async def delete_model(model_name: str) -> None:
|
||||
"""Delete Model"""
|
||||
model_names = ApiDependencies.invoker.services.model_manager.model_names()
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_exists = model_name in model_names
|
||||
|
||||
# check if model exists
|
||||
logger.info(f"Checking for model {model_name}...")
|
||||
|
||||
if model_exists:
|
||||
logger.info(f"Deleting Model: {model_name}")
|
||||
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
|
||||
logger.info(f"Model Deleted: {model_name}")
|
||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
||||
|
||||
else:
|
||||
logger.error(f"Model not found")
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
|
||||
# @socketio.on("convertToDiffusers")
|
||||
# def convert_to_diffusers(model_to_convert: dict):
|
||||
# try:
|
||||
# if model_info := self.generate.model_manager.model_info(
|
||||
# model_name=model_to_convert["model_name"]
|
||||
# ):
|
||||
# if "weights" in model_info:
|
||||
# ckpt_path = Path(model_info["weights"])
|
||||
# original_config_file = Path(model_info["config"])
|
||||
# model_name = model_to_convert["model_name"]
|
||||
# model_description = model_info["description"]
|
||||
# else:
|
||||
# self.socketio.emit(
|
||||
# "error", {"message": "Model is not a valid checkpoint file"}
|
||||
# )
|
||||
# else:
|
||||
# self.socketio.emit(
|
||||
# "error", {"message": "Could not retrieve model info."}
|
||||
# )
|
||||
|
||||
# if not ckpt_path.is_absolute():
|
||||
# ckpt_path = Path(Globals.root, ckpt_path)
|
||||
|
||||
# if original_config_file and not original_config_file.is_absolute():
|
||||
# original_config_file = Path(Globals.root, original_config_file)
|
||||
|
||||
# diffusers_path = Path(
|
||||
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if model_to_convert["save_location"] == "root":
|
||||
# diffusers_path = Path(
|
||||
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if (
|
||||
# model_to_convert["save_location"] == "custom"
|
||||
# and model_to_convert["custom_location"] is not None
|
||||
# ):
|
||||
# diffusers_path = Path(
|
||||
# model_to_convert["custom_location"], f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if diffusers_path.exists():
|
||||
# shutil.rmtree(diffusers_path)
|
||||
|
||||
# self.generate.model_manager.convert_and_import(
|
||||
# ckpt_path,
|
||||
# diffusers_path,
|
||||
# model_name=model_name,
|
||||
# model_description=model_description,
|
||||
# vae=None,
|
||||
# original_config_file=original_config_file,
|
||||
# commit_to_conf=opt.conf,
|
||||
# )
|
||||
|
||||
# new_model_list = self.generate.model_manager.list_models()
|
||||
# socketio.emit(
|
||||
# "modelConverted",
|
||||
# {
|
||||
# "new_model_name": model_name,
|
||||
# "model_list": new_model_list,
|
||||
# "update": True,
|
||||
# },
|
||||
# )
|
||||
# print(f">> Model Converted: {model_name}")
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
|
||||
# @socketio.on("mergeDiffusersModels")
|
||||
# def merge_diffusers_models(model_merge_info: dict):
|
||||
# try:
|
||||
# models_to_merge = model_merge_info["models_to_merge"]
|
||||
# model_ids_or_paths = [
|
||||
# self.generate.model_manager.model_name_or_path(x)
|
||||
# for x in models_to_merge
|
||||
# ]
|
||||
# merged_pipe = merge_diffusion_models(
|
||||
# model_ids_or_paths,
|
||||
# model_merge_info["alpha"],
|
||||
# model_merge_info["interp"],
|
||||
# model_merge_info["force"],
|
||||
# )
|
||||
|
||||
# dump_path = global_models_dir() / "merged_models"
|
||||
# if model_merge_info["model_merge_save_path"] is not None:
|
||||
# dump_path = Path(model_merge_info["model_merge_save_path"])
|
||||
|
||||
# os.makedirs(dump_path, exist_ok=True)
|
||||
# dump_path = dump_path / model_merge_info["merged_model_name"]
|
||||
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
|
||||
# merged_model_config = dict(
|
||||
# model_name=model_merge_info["merged_model_name"],
|
||||
# description=f'Merge of models {", ".join(models_to_merge)}',
|
||||
# commit_to_conf=opt.conf,
|
||||
# )
|
||||
|
||||
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
||||
# "vae", None
|
||||
# ):
|
||||
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
|
||||
# merged_model_config.update(vae=vae)
|
||||
|
||||
# self.generate.model_manager.import_diffuser_model(
|
||||
# dump_path, **merged_model_config
|
||||
# )
|
||||
# new_model_list = self.generate.model_manager.list_models()
|
||||
|
||||
# socketio.emit(
|
||||
# "modelsMerged",
|
||||
# {
|
||||
# "merged_models": models_to_merge,
|
||||
# "merged_model_name": model_merge_info["merged_model_name"],
|
||||
# "model_list": new_model_list,
|
||||
# "update": True,
|
||||
# },
|
||||
# )
|
||||
# print(f">> Models Merged: {models_to_merge}")
|
||||
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
||||
# except Exception as e:
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
from typing import Annotated, List, Optional, Union
|
||||
|
||||
from fastapi import Body, HTTPException, Path, Query, Response
|
||||
from fastapi import Body, Path, Query
|
||||
from fastapi.responses import Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic.fields import Field
|
||||
|
||||
@@ -50,7 +51,7 @@ async def list_sessions(
|
||||
query: str = Query(default="", description="The query string to search for"),
|
||||
) -> PaginatedResults[GraphExecutionState]:
|
||||
"""Gets a list of sessions, optionally searching"""
|
||||
if query == "":
|
||||
if filter == "":
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(
|
||||
page, per_page
|
||||
)
|
||||
@@ -75,7 +76,7 @@ async def get_session(
|
||||
"""Gets a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
return Response(status_code=404)
|
||||
else:
|
||||
return session
|
||||
|
||||
@@ -98,7 +99,7 @@ async def add_node(
|
||||
"""Adds a node to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
session.add_node(node)
|
||||
@@ -107,9 +108,9 @@ async def add_node(
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session.id
|
||||
except NodeAlreadyExecutedError:
|
||||
raise HTTPException(status_code=400)
|
||||
return Response(status_code=400)
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=400)
|
||||
return Response(status_code=400)
|
||||
|
||||
|
||||
@session_router.put(
|
||||
@@ -131,7 +132,7 @@ async def update_node(
|
||||
"""Updates a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
session.update_node(node_path, node)
|
||||
@@ -140,9 +141,9 @@ async def update_node(
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
raise HTTPException(status_code=400)
|
||||
return Response(status_code=400)
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=400)
|
||||
return Response(status_code=400)
|
||||
|
||||
|
||||
@session_router.delete(
|
||||
@@ -161,7 +162,7 @@ async def delete_node(
|
||||
"""Deletes a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
session.delete_node(node_path)
|
||||
@@ -170,9 +171,9 @@ async def delete_node(
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
raise HTTPException(status_code=400)
|
||||
return Response(status_code=400)
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=400)
|
||||
return Response(status_code=400)
|
||||
|
||||
|
||||
@session_router.post(
|
||||
@@ -191,7 +192,7 @@ async def add_edge(
|
||||
"""Adds an edge to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
session.add_edge(edge)
|
||||
@@ -200,9 +201,9 @@ async def add_edge(
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
raise HTTPException(status_code=400)
|
||||
return Response(status_code=400)
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=400)
|
||||
return Response(status_code=400)
|
||||
|
||||
|
||||
# TODO: the edge being in the path here is really ugly, find a better solution
|
||||
@@ -225,7 +226,7 @@ async def delete_edge(
|
||||
"""Deletes an edge from the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
edge = Edge(
|
||||
@@ -238,9 +239,9 @@ async def delete_edge(
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
raise HTTPException(status_code=400)
|
||||
return Response(status_code=400)
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=400)
|
||||
return Response(status_code=400)
|
||||
|
||||
|
||||
@session_router.put(
|
||||
@@ -258,14 +259,14 @@ async def invoke_session(
|
||||
all: bool = Query(
|
||||
default=False, description="Whether or not to invoke all remaining invocations"
|
||||
),
|
||||
) -> Response:
|
||||
) -> None:
|
||||
"""Invokes a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
return Response(status_code=404)
|
||||
|
||||
if session.is_complete():
|
||||
raise HTTPException(status_code=400)
|
||||
return Response(status_code=400)
|
||||
|
||||
ApiDependencies.invoker.invoke(session, invoke_all=all)
|
||||
return Response(status_code=202)
|
||||
@@ -280,7 +281,7 @@ async def invoke_session(
|
||||
)
|
||||
async def cancel_session_invoke(
|
||||
session_id: str = Path(description="The id of the session to cancel"),
|
||||
) -> Response:
|
||||
) -> None:
|
||||
"""Invokes a session"""
|
||||
ApiDependencies.invoker.cancel(session_id)
|
||||
return Response(status_code=202)
|
||||
|
||||
@@ -3,7 +3,6 @@ import asyncio
|
||||
from inspect import signature
|
||||
|
||||
import uvicorn
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
@@ -11,21 +10,13 @@ from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pathlib import Path
|
||||
from pydantic.schema import schema
|
||||
|
||||
#This should come early so that modules can log their initialization properly
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
app_config.parse_args()
|
||||
logger = InvokeAILogger.getLogger(config=app_config)
|
||||
|
||||
import invokeai.frontend.web as web_dir
|
||||
|
||||
from ..backend import Args
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import sessions, models, images
|
||||
from .api.routers import images, sessions
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
|
||||
# Create the app
|
||||
@@ -42,21 +33,30 @@ app.add_middleware(
|
||||
middleware_id=event_handler_id,
|
||||
)
|
||||
|
||||
# Add CORS
|
||||
# TODO: use configuration for this
|
||||
origins = []
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
socket_io = SocketIO(app)
|
||||
|
||||
config = {}
|
||||
|
||||
|
||||
# Add startup event to load dependencies
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=app_config.allow_origins,
|
||||
allow_credentials=app_config.allow_credentials,
|
||||
allow_methods=app_config.allow_methods,
|
||||
allow_headers=app_config.allow_headers,
|
||||
)
|
||||
config = Args()
|
||||
config.parse_args()
|
||||
|
||||
ApiDependencies.initialize(
|
||||
config=app_config, event_handler_id=event_handler_id, logger=logger
|
||||
config=config, event_handler_id=event_handler_id
|
||||
)
|
||||
|
||||
|
||||
@@ -74,10 +74,9 @@ async def shutdown_event():
|
||||
|
||||
app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(models.models_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
def custom_openapi():
|
||||
@@ -123,7 +122,8 @@ def custom_openapi():
|
||||
app.openapi = custom_openapi
|
||||
|
||||
# Override API doc favicons
|
||||
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], 'static/dream_web')), name="static")
|
||||
app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
|
||||
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
def overridden_swagger():
|
||||
@@ -143,20 +143,16 @@ def overridden_redoc():
|
||||
)
|
||||
|
||||
|
||||
# Must mount *after* the other routes else it borks em
|
||||
app.mount("/",
|
||||
StaticFiles(directory=Path(web_dir.__path__[0],"dist"),
|
||||
html=True
|
||||
), name="ui"
|
||||
)
|
||||
|
||||
def invoke_api():
|
||||
# Start our own event loop for eventing usage
|
||||
# TODO: determine if there's a better way to do this
|
||||
loop = asyncio.new_event_loop()
|
||||
config = uvicorn.Config(app=app, host=app_config.host, port=app_config.port, loop=loop)
|
||||
config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop)
|
||||
# Use access_log to turn off logging
|
||||
|
||||
server = uvicorn.Server(config)
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_api()
|
||||
|
||||
@@ -2,46 +2,14 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import argparse
|
||||
from typing import Any, Callable, Iterable, Literal, Union, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
|
||||
from pydantic import BaseModel, Field
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from ..invocations.image import ImageField
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph, Edge
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..services.invoker import Invoker
|
||||
|
||||
|
||||
def add_field_argument(command_parser, name: str, field, default_override = None):
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
|
||||
def add_parsers(
|
||||
subparsers,
|
||||
commands: list[type],
|
||||
@@ -66,26 +34,30 @@ def add_parsers(
|
||||
if name in exclude_fields:
|
||||
continue
|
||||
|
||||
add_field_argument(command_parser, name, field)
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
|
||||
def add_graph_parsers(
|
||||
subparsers,
|
||||
graphs: list[LibraryGraph],
|
||||
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
|
||||
):
|
||||
for graph in graphs:
|
||||
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
||||
|
||||
if add_arguments is not None:
|
||||
add_arguments(command_parser)
|
||||
|
||||
# Add arguments for inputs
|
||||
for exposed_input in graph.exposed_inputs:
|
||||
node = graph.graph.get_node(exposed_input.node_path)
|
||||
field = node.__fields__[exposed_input.field]
|
||||
default_override = getattr(node, exposed_input.field)
|
||||
add_field_argument(command_parser, exposed_input.alias, field, default_override)
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=field.default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=field.default,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
|
||||
class CliContext:
|
||||
@@ -93,38 +65,17 @@ class CliContext:
|
||||
session: GraphExecutionState
|
||||
parser: argparse.ArgumentParser
|
||||
defaults: dict[str, Any]
|
||||
graph_nodes: dict[str, str]
|
||||
nodes_added: list[str]
|
||||
|
||||
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
|
||||
self.invoker = invoker
|
||||
self.session = session
|
||||
self.parser = parser
|
||||
self.defaults = dict()
|
||||
self.graph_nodes = dict()
|
||||
self.nodes_added = list()
|
||||
|
||||
def get_session(self):
|
||||
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
|
||||
return self.session
|
||||
|
||||
def reset(self):
|
||||
self.session = self.invoker.create_execution_state()
|
||||
self.graph_nodes = dict()
|
||||
self.nodes_added = list()
|
||||
# Leave defaults unchanged
|
||||
|
||||
def add_node(self, node: BaseInvocation):
|
||||
self.get_session()
|
||||
self.session.graph.add_node(node)
|
||||
self.nodes_added.append(node.id)
|
||||
self.invoker.services.graph_execution_manager.set(self.session)
|
||||
|
||||
def add_edge(self, edge: Edge):
|
||||
self.get_session()
|
||||
self.session.add_edge(edge)
|
||||
self.invoker.services.graph_execution_manager.set(self.session)
|
||||
|
||||
|
||||
class ExitCli(Exception):
|
||||
"""Exception to exit the CLI"""
|
||||
@@ -230,7 +181,7 @@ class HistoryCommand(BaseCommand):
|
||||
for i in range(min(self.count, len(history))):
|
||||
entry_id = history[-1 - i]
|
||||
entry = context.get_session().graph.get_node(entry_id)
|
||||
logger.info(f"{entry_id}: {get_invocation_command(entry)}")
|
||||
print(f"{entry_id}: {get_invocation_command(entry)}")
|
||||
|
||||
|
||||
class SetDefaultCommand(BaseCommand):
|
||||
@@ -249,55 +200,3 @@ class SetDefaultCommand(BaseCommand):
|
||||
del context.defaults[self.field]
|
||||
else:
|
||||
context.defaults[self.field] = self.value
|
||||
|
||||
|
||||
class DrawGraphCommand(BaseCommand):
|
||||
"""Debugs a graph"""
|
||||
type: Literal['draw_graph'] = 'draw_graph'
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
||||
nxgraph = session.graph.nx_graph_flat()
|
||||
|
||||
# Draw the networkx graph
|
||||
plt.figure(figsize=(20, 20))
|
||||
pos = nx.spectral_layout(nxgraph)
|
||||
nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
|
||||
nx.draw_networkx_edges(nxgraph, pos, width=2)
|
||||
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
|
||||
|
||||
class DrawExecutionGraphCommand(BaseCommand):
|
||||
"""Debugs an execution graph"""
|
||||
type: Literal['draw_xgraph'] = 'draw_xgraph'
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
||||
nxgraph = session.execution_graph.nx_graph_flat()
|
||||
|
||||
# Draw the networkx graph
|
||||
plt.figure(figsize=(20, 20))
|
||||
pos = nx.spectral_layout(nxgraph)
|
||||
nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
|
||||
nx.draw_networkx_edges(nxgraph, pos, width=2)
|
||||
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
|
||||
class SortedHelpFormatter(argparse.HelpFormatter):
|
||||
def _iter_indented_subactions(self, action):
|
||||
try:
|
||||
get_subactions = action._get_subactions
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
self._indent()
|
||||
if isinstance(action, argparse._SubParsersAction):
|
||||
for subaction in sorted(get_subactions(), key=lambda x: x.dest):
|
||||
yield subaction
|
||||
else:
|
||||
for subaction in get_subactions():
|
||||
yield subaction
|
||||
self._dedent()
|
||||
|
||||
@@ -1,169 +0,0 @@
|
||||
"""
|
||||
Readline helper functions for cli_app.py
|
||||
You may import the global singleton `completer` to get access to the
|
||||
completer object.
|
||||
"""
|
||||
import atexit
|
||||
import readline
|
||||
import shlex
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ...backend import ModelManager
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from .commands import BaseCommand
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
# singleton object, class variable
|
||||
completer = None
|
||||
|
||||
class Completer(object):
|
||||
|
||||
def __init__(self, model_manager: ModelManager):
|
||||
self.commands = self.get_commands()
|
||||
self.matches = None
|
||||
self.linebuffer = None
|
||||
self.manager = model_manager
|
||||
return
|
||||
|
||||
def complete(self, text, state):
|
||||
"""
|
||||
Complete commands and switches fromm the node CLI command line.
|
||||
Switches are determined in a context-specific manner.
|
||||
"""
|
||||
|
||||
buffer = readline.get_line_buffer()
|
||||
if state == 0:
|
||||
options = None
|
||||
try:
|
||||
current_command, current_switch = self.get_current_command(buffer)
|
||||
options = self.get_command_options(current_command, current_switch)
|
||||
except IndexError:
|
||||
pass
|
||||
options = options or list(self.parse_commands().keys())
|
||||
|
||||
if not text: # first time
|
||||
self.matches = options
|
||||
else:
|
||||
self.matches = [s for s in options if s and s.startswith(text)]
|
||||
|
||||
try:
|
||||
match = self.matches[state]
|
||||
except IndexError:
|
||||
match = None
|
||||
return match
|
||||
|
||||
@classmethod
|
||||
def get_commands(self)->List[object]:
|
||||
"""
|
||||
Return a list of all the client commands and invocations.
|
||||
"""
|
||||
return BaseCommand.get_commands() + BaseInvocation.get_invocations()
|
||||
|
||||
def get_current_command(self, buffer: str)->tuple[str, str]:
|
||||
"""
|
||||
Parse the readline buffer to find the most recent command and its switch.
|
||||
"""
|
||||
if len(buffer)==0:
|
||||
return None, None
|
||||
tokens = shlex.split(buffer)
|
||||
command = None
|
||||
switch = None
|
||||
for t in tokens:
|
||||
if t[0].isalpha():
|
||||
if switch is None:
|
||||
command = t
|
||||
else:
|
||||
switch = t
|
||||
# don't try to autocomplete switches that are already complete
|
||||
if switch and buffer.endswith(' '):
|
||||
switch=None
|
||||
return command or '', switch or ''
|
||||
|
||||
def parse_commands(self)->Dict[str, List[str]]:
|
||||
"""
|
||||
Return a dict in which the keys are the command name
|
||||
and the values are the parameters the command takes.
|
||||
"""
|
||||
result = dict()
|
||||
for command in self.commands:
|
||||
hints = get_type_hints(command)
|
||||
name = get_args(hints['type'])[0]
|
||||
result.update({name:hints})
|
||||
return result
|
||||
|
||||
def get_command_options(self, command: str, switch: str)->List[str]:
|
||||
"""
|
||||
Return all the parameters that can be passed to the command as
|
||||
command-line switches. Returns None if the command is unrecognized.
|
||||
"""
|
||||
parsed_commands = self.parse_commands()
|
||||
if command not in parsed_commands:
|
||||
return None
|
||||
|
||||
# handle switches in the format "-foo=bar"
|
||||
argument = None
|
||||
if switch and '=' in switch:
|
||||
switch, argument = switch.split('=')
|
||||
|
||||
parameter = switch.strip('-')
|
||||
if parameter in parsed_commands[command]:
|
||||
if argument is None:
|
||||
return self.get_parameter_options(parameter, parsed_commands[command][parameter])
|
||||
else:
|
||||
return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])]
|
||||
else:
|
||||
return [f"--{x}" for x in parsed_commands[command].keys()]
|
||||
|
||||
def get_parameter_options(self, parameter: str, typehint)->List[str]:
|
||||
"""
|
||||
Given a parameter type (such as Literal), offers autocompletions.
|
||||
"""
|
||||
if get_origin(typehint) == Literal:
|
||||
return get_args(typehint)
|
||||
if parameter == 'model':
|
||||
return self.manager.model_names()
|
||||
|
||||
def _pre_input_hook(self):
|
||||
if self.linebuffer:
|
||||
readline.insert_text(self.linebuffer)
|
||||
readline.redisplay()
|
||||
self.linebuffer = None
|
||||
|
||||
def set_autocompleter(services: InvocationServices) -> Completer:
|
||||
global completer
|
||||
|
||||
if completer:
|
||||
return completer
|
||||
|
||||
completer = Completer(services.model_manager)
|
||||
|
||||
readline.set_completer(completer.complete)
|
||||
# pyreadline3 does not have a set_auto_history() method
|
||||
try:
|
||||
readline.set_auto_history(True)
|
||||
except:
|
||||
pass
|
||||
readline.set_pre_input_hook(completer._pre_input_hook)
|
||||
readline.set_completer_delims(" ")
|
||||
readline.parse_and_bind("tab: complete")
|
||||
readline.parse_and_bind("set print-completions-horizontally off")
|
||||
readline.parse_and_bind("set page-completions on")
|
||||
readline.parse_and_bind("set skip-completed-text on")
|
||||
readline.parse_and_bind("set show-all-if-ambiguous on")
|
||||
|
||||
histfile = Path(services.configuration.root_dir / ".invoke_history")
|
||||
try:
|
||||
readline.read_history_file(histfile)
|
||||
readline.set_history_length(1000)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except OSError: # file likely corrupted
|
||||
newname = f"{histfile}.old"
|
||||
logger.error(
|
||||
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||
)
|
||||
histfile.replace(Path(newname))
|
||||
atexit.register(readline.write_history_file, histfile)
|
||||
@@ -2,43 +2,25 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import sys
|
||||
import time
|
||||
from typing import (
|
||||
Union,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
# This should come early so that the logger can pick up its configuration options
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
|
||||
from .services.default_graphs import create_system_graphs
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, SortedHelpFormatter
|
||||
from .cli.completer import set_autocompleter
|
||||
from ..backend import Args
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.model_manager_initializer import get_model_manager
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||
from .services.default_graphs import default_text_to_image_graph_id
|
||||
from .services.image_file_storage import DiskImageFileStorage
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
from .services.invoker import Invoker
|
||||
@@ -53,6 +35,7 @@ class CliCommand(BaseModel):
|
||||
class InvalidArgs(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def add_invocation_args(command_parser):
|
||||
# Add linking capability
|
||||
command_parser.add_argument(
|
||||
@@ -60,7 +43,7 @@ def add_invocation_args(command_parser):
|
||||
"-l",
|
||||
action="append",
|
||||
nargs=3,
|
||||
help="A link in the format 'source_node source_field dest_field'. source_node can be relative to history (e.g. -1)",
|
||||
help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
|
||||
)
|
||||
|
||||
command_parser.add_argument(
|
||||
@@ -71,9 +54,9 @@ def add_invocation_args(command_parser):
|
||||
)
|
||||
|
||||
|
||||
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
||||
def get_command_parser() -> argparse.ArgumentParser:
|
||||
# Create invocation parser
|
||||
parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter)
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
def exit(*args, **kwargs):
|
||||
raise InvalidArgs
|
||||
@@ -89,72 +72,20 @@ def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
||||
commands = BaseCommand.get_all_subclasses()
|
||||
add_parsers(subparsers, commands, exclude_fields=["type"])
|
||||
|
||||
# Create subparsers for exposed CLI graphs
|
||||
# TODO: add a way to identify these graphs
|
||||
text_to_image = services.graph_library.get(default_text_to_image_graph_id)
|
||||
add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class NodeField():
|
||||
alias: str
|
||||
node_path: str
|
||||
field: str
|
||||
field_type: type
|
||||
|
||||
def __init__(self, alias: str, node_path: str, field: str, field_type: type):
|
||||
self.alias = alias
|
||||
self.node_path = node_path
|
||||
self.field = field
|
||||
self.field_type = field_type
|
||||
|
||||
|
||||
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
|
||||
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
|
||||
|
||||
|
||||
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_input.node_path))
|
||||
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
|
||||
|
||||
|
||||
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_output.node_path))
|
||||
node_output_type = node_type.get_output_type()
|
||||
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
|
||||
|
||||
|
||||
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
"""Gets the inputs for the specified invocation from the context"""
|
||||
node_type = type(invocation)
|
||||
if node_type is not GraphInvocation:
|
||||
return fields_from_type_hints(get_type_hints(node_type), invocation.id)
|
||||
else:
|
||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||
return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs}
|
||||
|
||||
|
||||
def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
"""Gets the outputs for the specified invocation from the context"""
|
||||
node_type = type(invocation)
|
||||
if node_type is not GraphInvocation:
|
||||
return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id)
|
||||
else:
|
||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
|
||||
|
||||
|
||||
def generate_matching_edges(
|
||||
a: BaseInvocation, b: BaseInvocation, context: CliContext
|
||||
a: BaseInvocation, b: BaseInvocation
|
||||
) -> list[Edge]:
|
||||
"""Generates all possible edges between two invocations"""
|
||||
afields = get_node_outputs(a, context)
|
||||
bfields = get_node_inputs(b, context)
|
||||
atype = type(a)
|
||||
btype = type(b)
|
||||
|
||||
aoutputtype = atype.get_output_type()
|
||||
|
||||
afields = get_type_hints(aoutputtype)
|
||||
bfields = get_type_hints(btype)
|
||||
|
||||
matching_fields = set(afields.keys()).intersection(bfields.keys())
|
||||
|
||||
@@ -162,15 +93,12 @@ def generate_matching_edges(
|
||||
invalid_fields = set(["type", "id"])
|
||||
matching_fields = matching_fields.difference(invalid_fields)
|
||||
|
||||
# Validate types
|
||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
|
||||
|
||||
edges = [
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
|
||||
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
|
||||
source=EdgeConnection(node_id=a.id, field=field),
|
||||
destination=EdgeConnection(node_id=b.id, field=field)
|
||||
)
|
||||
for alias in matching_fields
|
||||
for field in matching_fields
|
||||
]
|
||||
return edges
|
||||
|
||||
@@ -190,111 +118,62 @@ def invoke_all(context: CliContext):
|
||||
# Print any errors
|
||||
if context.session.has_error():
|
||||
for n in context.session.errors:
|
||||
context.invoker.services.logger.error(
|
||||
print(
|
||||
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
|
||||
)
|
||||
|
||||
raise SessionError()
|
||||
|
||||
def invoke_cli():
|
||||
|
||||
# get the optional list of invocations to execute on the command line
|
||||
parser = config.get_parser()
|
||||
parser.add_argument('commands',nargs='*')
|
||||
invocation_commands = parser.parse_args().commands
|
||||
|
||||
# get the optional file to read commands from.
|
||||
# Simplest is to use it for STDIN
|
||||
if infile := config.from_file:
|
||||
sys.stdin = open(infile,"r")
|
||||
|
||||
model_manager = get_model_manager(config,logger=logger)
|
||||
|
||||
def invoke_cli():
|
||||
config = Args()
|
||||
config.parse_args()
|
||||
model_manager = get_model_manager(config)
|
||||
|
||||
events = EventServiceBase()
|
||||
output_folder = config.output_path
|
||||
|
||||
output_folder = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "../../../outputs")
|
||||
)
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
if config.use_memory_db:
|
||||
db_location = ":memory:"
|
||||
else:
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True,exist_ok=True)
|
||||
|
||||
logger.info(f'InvokeAI database location is "{db_location}"')
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
metadata = CoreMetadataService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
|
||||
images = ImageService(
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||
images=images,
|
||||
images=DiskImageStorage(output_folder),
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config,logger=logger),
|
||||
logger=logger,
|
||||
configuration=config,
|
||||
restoration=RestorationServices(config),
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
system_graph_names = set([g.name for g in system_graphs])
|
||||
|
||||
invoker = Invoker(services)
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
parser = get_command_parser(services)
|
||||
|
||||
re_negid = re.compile('^-[0-9]+$')
|
||||
parser = get_command_parser()
|
||||
|
||||
# Uncomment to print out previous sessions at startup
|
||||
# print(services.session_manager.list())
|
||||
|
||||
context = CliContext(invoker, session, parser)
|
||||
set_autocompleter(services)
|
||||
|
||||
command_line_args_exist = len(invocation_commands) > 0
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
while True:
|
||||
try:
|
||||
if command_line_args_exist:
|
||||
cmd_input = invocation_commands.pop(0)
|
||||
done = len(invocation_commands) == 0
|
||||
else:
|
||||
cmd_input = input("invoke> ")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
cmd_input = input("> ")
|
||||
except KeyboardInterrupt:
|
||||
# Ctrl-c exits
|
||||
break
|
||||
|
||||
try:
|
||||
# Refresh the state of the session
|
||||
#history = list(get_graph_execution_history(context.session))
|
||||
history = list(reversed(context.nodes_added))
|
||||
history = list(get_graph_execution_history(context.session))
|
||||
|
||||
# Split the command for piping
|
||||
cmds = cmd_input.split("|")
|
||||
start_id = len(context.nodes_added)
|
||||
start_id = len(history)
|
||||
current_id = start_id
|
||||
new_invocations = list()
|
||||
for cmd in cmds:
|
||||
@@ -310,24 +189,8 @@ def invoke_cli():
|
||||
args[field_name] = field_default
|
||||
|
||||
# Parse invocation
|
||||
command: CliCommand = None # type:ignore
|
||||
system_graph: LibraryGraph|None = None
|
||||
if args['type'] in system_graph_names:
|
||||
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
|
||||
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
||||
for exposed_input in system_graph.exposed_inputs:
|
||||
if exposed_input.alias in args:
|
||||
node = invocation.graph.get_node(exposed_input.node_path)
|
||||
field = exposed_input.field
|
||||
setattr(node, field, args[exposed_input.alias])
|
||||
command = CliCommand(command = invocation)
|
||||
context.graph_nodes[invocation.id] = system_graph.id
|
||||
else:
|
||||
args["id"] = current_id
|
||||
command = CliCommand(command=args)
|
||||
|
||||
if command is None:
|
||||
continue
|
||||
args["id"] = current_id
|
||||
command = CliCommand(command=args)
|
||||
|
||||
# Run any CLI commands immediately
|
||||
if isinstance(command.command, BaseCommand):
|
||||
@@ -338,7 +201,6 @@ def invoke_cli():
|
||||
command.command.run(context)
|
||||
continue
|
||||
|
||||
# TODO: handle linking with library graphs
|
||||
# Pipe previous command output (if there was a previous command)
|
||||
edges: list[Edge] = list()
|
||||
if len(history) > 0 or current_id != start_id:
|
||||
@@ -351,20 +213,16 @@ def invoke_cli():
|
||||
else context.session.graph.get_node(from_id)
|
||||
)
|
||||
matching_edges = generate_matching_edges(
|
||||
from_node, command.command, context
|
||||
from_node, command.command
|
||||
)
|
||||
edges.extend(matching_edges)
|
||||
|
||||
# Parse provided links
|
||||
if "link_node" in args and args["link_node"]:
|
||||
for link in args["link_node"]:
|
||||
node_id = link
|
||||
if re_negid.match(node_id):
|
||||
node_id = str(current_id + int(node_id))
|
||||
|
||||
link_node = context.session.graph.get_node(node_id)
|
||||
link_node = context.session.graph.get_node(link)
|
||||
matching_edges = generate_matching_edges(
|
||||
link_node, command.command, context
|
||||
link_node, command.command
|
||||
)
|
||||
matching_destinations = [e.destination for e in matching_edges]
|
||||
edges = [e for e in edges if e.destination not in matching_destinations]
|
||||
@@ -372,20 +230,13 @@ def invoke_cli():
|
||||
|
||||
if "link" in args and args["link"]:
|
||||
for link in args["link"]:
|
||||
edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]
|
||||
|
||||
node_id = link[0]
|
||||
if re_negid.match(node_id):
|
||||
node_id = str(current_id + int(node_id))
|
||||
|
||||
# TODO: handle missing input/output
|
||||
node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]]
|
||||
node_input = get_node_inputs(command.command, context)[link[2]]
|
||||
|
||||
edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
|
||||
edges.append(
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
|
||||
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
|
||||
source=EdgeConnection(node_id=link[1], field=link[0]),
|
||||
destination=EdgeConnection(
|
||||
node_id=command.command.id, field=link[2]
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -394,25 +245,22 @@ def invoke_cli():
|
||||
current_id = current_id + 1
|
||||
|
||||
# Add the node to the session
|
||||
context.add_node(command.command)
|
||||
context.session.add_node(command.command)
|
||||
for edge in edges:
|
||||
print(edge)
|
||||
context.add_edge(edge)
|
||||
context.session.add_edge(edge)
|
||||
|
||||
# Execute all remaining nodes
|
||||
invoke_all(context)
|
||||
|
||||
except InvalidArgs:
|
||||
invoker.services.logger.warning('Invalid command, use "help" to list commands')
|
||||
print('Invalid command, use "help" to list commands')
|
||||
continue
|
||||
|
||||
except ValidationError:
|
||||
invoker.services.logger.warning('Invalid command arguments, run "<command> --help" for summary')
|
||||
|
||||
except SessionError:
|
||||
# Start a new session
|
||||
invoker.services.logger.warning("Session error: creating a new session")
|
||||
context.reset()
|
||||
print("Session error: creating a new session")
|
||||
context.session = context.invoker.create_execution_state()
|
||||
|
||||
except ExitCli:
|
||||
break
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING
|
||||
from typing import get_args, get_type_hints
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
@@ -78,59 +75,4 @@ class BaseInvocation(ABC, BaseModel):
|
||||
|
||||
#fmt: off
|
||||
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
||||
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
|
||||
#fmt: on
|
||||
|
||||
|
||||
# TODO: figure out a better way to provide these hints
|
||||
# TODO: when we can upgrade to python 3.11, we can use the`NotRequired` type instead of `total=False`
|
||||
class UIConfig(TypedDict, total=False):
|
||||
type_hints: Dict[
|
||||
str,
|
||||
Literal[
|
||||
"integer",
|
||||
"float",
|
||||
"boolean",
|
||||
"string",
|
||||
"enum",
|
||||
"image",
|
||||
"latents",
|
||||
"model",
|
||||
"control",
|
||||
],
|
||||
]
|
||||
tags: List[str]
|
||||
title: str
|
||||
|
||||
class CustomisedSchemaExtra(TypedDict):
|
||||
ui: UIConfig
|
||||
|
||||
|
||||
class InvocationConfig(BaseModel.Config):
|
||||
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
|
||||
|
||||
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
|
||||
|
||||
`tags`
|
||||
- A list of strings, used to categorise invocations.
|
||||
|
||||
`type_hints`
|
||||
- A dict of field types which override the types in the invocation definition.
|
||||
- Each key should be the name of one of the invocation's fields.
|
||||
- Each value should be one of the valid types:
|
||||
- `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model`
|
||||
|
||||
```python
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
"type_hints": {
|
||||
"initial_image": "image",
|
||||
},
|
||||
},
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
schema_extra: CustomisedSchemaExtra
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field, validator
|
||||
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationContext,
|
||||
BaseInvocationOutput,
|
||||
)
|
||||
|
||||
|
||||
class IntCollectionOutput(BaseInvocationOutput):
|
||||
"""A collection of integers"""
|
||||
|
||||
type: Literal["int_collection"] = "int_collection"
|
||||
|
||||
# Outputs
|
||||
collection: list[int] = Field(default=[], description="The int collection")
|
||||
|
||||
class FloatCollectionOutput(BaseInvocationOutput):
|
||||
"""A collection of floats"""
|
||||
|
||||
type: Literal["float_collection"] = "float_collection"
|
||||
|
||||
# Outputs
|
||||
collection: list[float] = Field(default=[], description="The float collection")
|
||||
|
||||
|
||||
class RangeInvocation(BaseInvocation):
|
||||
"""Creates a range of numbers from start to stop with step"""
|
||||
|
||||
type: Literal["range"] = "range"
|
||||
|
||||
# Inputs
|
||||
start: int = Field(default=0, description="The start of the range")
|
||||
stop: int = Field(default=10, description="The stop of the range")
|
||||
step: int = Field(default=1, description="The step of the range")
|
||||
|
||||
@validator("stop")
|
||||
def stop_gt_start(cls, v, values):
|
||||
if "start" in values and v <= values["start"]:
|
||||
raise ValueError("stop must be greater than start")
|
||||
return v
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(
|
||||
collection=list(range(self.start, self.stop, self.step))
|
||||
)
|
||||
|
||||
|
||||
class RangeOfSizeInvocation(BaseInvocation):
|
||||
"""Creates a range from start to start + size with step"""
|
||||
|
||||
type: Literal["range_of_size"] = "range_of_size"
|
||||
|
||||
# Inputs
|
||||
start: int = Field(default=0, description="The start of the range")
|
||||
size: int = Field(default=1, description="The number of values")
|
||||
step: int = Field(default=1, description="The step of the range")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(
|
||||
collection=list(range(self.start, self.start + self.size, self.step))
|
||||
)
|
||||
|
||||
|
||||
class RandomRangeInvocation(BaseInvocation):
|
||||
"""Creates a collection of random numbers"""
|
||||
|
||||
type: Literal["random_range"] = "random_range"
|
||||
|
||||
# Inputs
|
||||
low: int = Field(default=0, description="The inclusive low value")
|
||||
high: int = Field(
|
||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||
)
|
||||
size: int = Field(default=1, description="The number of values to generate")
|
||||
seed: int = Field(
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description="The seed for the RNG (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
rng = np.random.default_rng(self.seed)
|
||||
return IntCollectionOutput(
|
||||
collection=list(rng.integers(low=self.low, high=self.high, size=self.size))
|
||||
)
|
||||
@@ -1,266 +0,0 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
from ...backend.prompting.conditioning import try_parse_legacy_blend
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
|
||||
|
||||
from compel import Compel
|
||||
from compel.prompt_parser import (
|
||||
Blend,
|
||||
CrossAttentionControlSubstitute,
|
||||
FlattenedPrompt,
|
||||
Fragment, Conjunction,
|
||||
)
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
|
||||
class Config:
|
||||
schema_extra = {"required": ["conditioning_name"]}
|
||||
|
||||
|
||||
class CompelOutput(BaseInvocationOutput):
|
||||
"""Compel parser output"""
|
||||
|
||||
#fmt: off
|
||||
type: Literal["compel_output"] = "compel_output"
|
||||
|
||||
conditioning: ConditioningField = Field(default=None, description="Conditioning")
|
||||
#fmt: on
|
||||
|
||||
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
type: Literal["compel"] = "compel"
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
model: str = Field(default="", description="Model to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Prompt (Compel)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
|
||||
# TODO: load without model
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
pipeline = model["model"]
|
||||
tokenizer = pipeline.tokenizer
|
||||
text_encoder = pipeline.text_encoder
|
||||
|
||||
# TODO: global? input?
|
||||
#use_full_precision = precision == "float32" or precision == "autocast"
|
||||
#use_full_precision = False
|
||||
|
||||
# TODO: redo TI when separate model loding implemented
|
||||
#textual_inversion_manager = TextualInversionManager(
|
||||
# tokenizer=tokenizer,
|
||||
# text_encoder=text_encoder,
|
||||
# full_precision=use_full_precision,
|
||||
#)
|
||||
|
||||
def load_huggingface_concepts(concepts: list[str]):
|
||||
pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)
|
||||
|
||||
# apply the concepts library to the prompt
|
||||
prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
|
||||
self.prompt,
|
||||
lambda concepts: load_huggingface_concepts(concepts),
|
||||
pipeline.textual_inversion_manager.get_all_trigger_strings(),
|
||||
)
|
||||
|
||||
# lazy-load any deferred textual inversions.
|
||||
# this might take a couple of seconds the first time a textual inversion is used.
|
||||
pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
|
||||
prompt_str
|
||||
)
|
||||
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=pipeline.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
|
||||
legacy_blend = try_parse_legacy_blend(prompt_str, skip_normalize=False)
|
||||
if legacy_blend is not None:
|
||||
conjunction = legacy_blend
|
||||
else:
|
||||
conjunction = Compel.parse_prompt_string(prompt_str)
|
||||
|
||||
if context.services.configuration.log_tokenization:
|
||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.save(conditioning_name, (c, ec))
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
|
||||
) -> int:
|
||||
if type(prompt) is Blend:
|
||||
blend: Blend = prompt
|
||||
return max(
|
||||
[
|
||||
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
||||
for p in blend.prompts
|
||||
]
|
||||
)
|
||||
elif type(prompt) is Conjunction:
|
||||
conjunction: Conjunction = prompt
|
||||
return sum(
|
||||
[
|
||||
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
||||
for p in conjunction.prompts
|
||||
]
|
||||
)
|
||||
else:
|
||||
return len(
|
||||
get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
|
||||
)
|
||||
|
||||
|
||||
def get_tokens_for_prompt_object(
|
||||
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
|
||||
) -> [str]:
|
||||
if type(parsed_prompt) is Blend:
|
||||
raise ValueError(
|
||||
"Blend is not supported here - you need to get tokens for each of its .children"
|
||||
)
|
||||
|
||||
text_fragments = [
|
||||
x.text
|
||||
if type(x) is Fragment
|
||||
else (
|
||||
" ".join([f.text for f in x.original])
|
||||
if type(x) is CrossAttentionControlSubstitute
|
||||
else str(x)
|
||||
)
|
||||
for x in parsed_prompt.children
|
||||
]
|
||||
text = " ".join(text_fragments)
|
||||
tokens = tokenizer.tokenize(text)
|
||||
if truncate_if_too_long:
|
||||
max_tokens_length = tokenizer.model_max_length - 2 # typically 75
|
||||
tokens = tokens[0:max_tokens_length]
|
||||
return tokens
|
||||
|
||||
|
||||
def log_tokenization_for_conjunction(
|
||||
c: Conjunction, tokenizer, display_label_prefix=None
|
||||
):
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
for i, p in enumerate(c.prompts):
|
||||
if len(c.prompts)>1:
|
||||
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
|
||||
else:
|
||||
this_display_label_prefix = display_label_prefix
|
||||
log_tokenization_for_prompt_object(
|
||||
p,
|
||||
tokenizer,
|
||||
display_label_prefix=this_display_label_prefix
|
||||
)
|
||||
|
||||
|
||||
def log_tokenization_for_prompt_object(
|
||||
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
|
||||
):
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
if type(p) is Blend:
|
||||
blend: Blend = p
|
||||
for i, c in enumerate(blend.prompts):
|
||||
log_tokenization_for_prompt_object(
|
||||
c,
|
||||
tokenizer,
|
||||
display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})",
|
||||
)
|
||||
elif type(p) is FlattenedPrompt:
|
||||
flattened_prompt: FlattenedPrompt = p
|
||||
if flattened_prompt.wants_cross_attention_control:
|
||||
original_fragments = []
|
||||
edited_fragments = []
|
||||
for f in flattened_prompt.children:
|
||||
if type(f) is CrossAttentionControlSubstitute:
|
||||
original_fragments += f.original
|
||||
edited_fragments += f.edited
|
||||
else:
|
||||
original_fragments.append(f)
|
||||
edited_fragments.append(f)
|
||||
|
||||
original_text = " ".join([x.text for x in original_fragments])
|
||||
log_tokenization_for_text(
|
||||
original_text,
|
||||
tokenizer,
|
||||
display_label=f"{display_label_prefix}(.swap originals)",
|
||||
)
|
||||
edited_text = " ".join([x.text for x in edited_fragments])
|
||||
log_tokenization_for_text(
|
||||
edited_text,
|
||||
tokenizer,
|
||||
display_label=f"{display_label_prefix}(.swap replacements)",
|
||||
)
|
||||
else:
|
||||
text = " ".join([x.text for x in flattened_prompt.children])
|
||||
log_tokenization_for_text(
|
||||
text, tokenizer, display_label=display_label_prefix
|
||||
)
|
||||
|
||||
|
||||
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
|
||||
"""shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
"""
|
||||
tokens = tokenizer.tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
usedTokens = 0
|
||||
totalTokens = len(tokens)
|
||||
|
||||
for i in range(0, totalTokens):
|
||||
token = tokens[i].replace("</w>", " ")
|
||||
# alternate color
|
||||
s = (usedTokens % 6) + 1
|
||||
if truncate_if_too_long and i >= tokenizer.model_max_length:
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
else:
|
||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||
usedTokens += 1
|
||||
|
||||
if usedTokens > 0:
|
||||
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
||||
print(f"{tokenized}\x1b[0m")
|
||||
|
||||
if discarded != "":
|
||||
print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||
print(f"{discarded}\x1b[0m")
|
||||
@@ -1,459 +0,0 @@
|
||||
# InvokeAI nodes for ControlNet image preprocessors
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||
from builtins import float
|
||||
|
||||
import numpy as np
|
||||
from typing import Literal, Optional, Union, List
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from ..models.image import ImageField, ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
from controlnet_aux import (
|
||||
CannyDetector,
|
||||
HEDdetector,
|
||||
LineartDetector,
|
||||
LineartAnimeDetector,
|
||||
MidasDetector,
|
||||
MLSDdetector,
|
||||
NormalBaeDetector,
|
||||
OpenposeDetector,
|
||||
PidiNetDetector,
|
||||
ContentShuffleDetector,
|
||||
ZoeDetector,
|
||||
MediapipeFaceDetector,
|
||||
)
|
||||
|
||||
from .image import ImageOutput, PILInvocationConfig
|
||||
|
||||
CONTROLNET_DEFAULT_MODELS = [
|
||||
###########################################
|
||||
# lllyasviel sd v1.5, ControlNet v1.0 models
|
||||
##############################################
|
||||
"lllyasviel/sd-controlnet-canny",
|
||||
"lllyasviel/sd-controlnet-depth",
|
||||
"lllyasviel/sd-controlnet-hed",
|
||||
"lllyasviel/sd-controlnet-seg",
|
||||
"lllyasviel/sd-controlnet-openpose",
|
||||
"lllyasviel/sd-controlnet-scribble",
|
||||
"lllyasviel/sd-controlnet-normal",
|
||||
"lllyasviel/sd-controlnet-mlsd",
|
||||
|
||||
#############################################
|
||||
# lllyasviel sd v1.5, ControlNet v1.1 models
|
||||
#############################################
|
||||
"lllyasviel/control_v11p_sd15_canny",
|
||||
"lllyasviel/control_v11p_sd15_openpose",
|
||||
"lllyasviel/control_v11p_sd15_seg",
|
||||
# "lllyasviel/control_v11p_sd15_depth", # broken
|
||||
"lllyasviel/control_v11f1p_sd15_depth",
|
||||
"lllyasviel/control_v11p_sd15_normalbae",
|
||||
"lllyasviel/control_v11p_sd15_scribble",
|
||||
"lllyasviel/control_v11p_sd15_mlsd",
|
||||
"lllyasviel/control_v11p_sd15_softedge",
|
||||
"lllyasviel/control_v11p_sd15s2_lineart_anime",
|
||||
"lllyasviel/control_v11p_sd15_lineart",
|
||||
"lllyasviel/control_v11p_sd15_inpaint",
|
||||
# "lllyasviel/control_v11u_sd15_tile",
|
||||
# problem (temporary?) with huffingface "lllyasviel/control_v11u_sd15_tile",
|
||||
# so for now replace "lllyasviel/control_v11f1e_sd15_tile",
|
||||
"lllyasviel/control_v11e_sd15_shuffle",
|
||||
"lllyasviel/control_v11e_sd15_ip2p",
|
||||
"lllyasviel/control_v11f1e_sd15_tile",
|
||||
|
||||
#################################################
|
||||
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
|
||||
##################################################
|
||||
"thibaud/controlnet-sd21-openpose-diffusers",
|
||||
"thibaud/controlnet-sd21-canny-diffusers",
|
||||
"thibaud/controlnet-sd21-depth-diffusers",
|
||||
"thibaud/controlnet-sd21-scribble-diffusers",
|
||||
"thibaud/controlnet-sd21-hed-diffusers",
|
||||
"thibaud/controlnet-sd21-zoedepth-diffusers",
|
||||
"thibaud/controlnet-sd21-color-diffusers",
|
||||
"thibaud/controlnet-sd21-openposev2-diffusers",
|
||||
"thibaud/controlnet-sd21-lineart-diffusers",
|
||||
"thibaud/controlnet-sd21-normalbae-diffusers",
|
||||
"thibaud/controlnet-sd21-ade20k-diffusers",
|
||||
|
||||
##############################################
|
||||
# ControlNetMediaPipeface, ControlNet v1.1
|
||||
##############################################
|
||||
# ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5
|
||||
# diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
|
||||
# hacked t2l to split to model & subfolder if format is "model,subfolder"
|
||||
"CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5
|
||||
"CrucibleAI/ControlNetMediaPipeFace", # SD 2.1?
|
||||
]
|
||||
|
||||
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
|
||||
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
@validator("control_weight")
|
||||
def abs_le_one(cls, v):
|
||||
"""validate that all abs(values) are <=1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if abs(i) > 1:
|
||||
raise ValueError('all abs(control_weight) must be <= 1')
|
||||
else:
|
||||
if abs(v) > 1:
|
||||
raise ValueError('abs(control_weight) must be <= 1')
|
||||
return v
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
|
||||
"ui": {
|
||||
"type_hints": {
|
||||
"control_weight": "float",
|
||||
# "control_weight": "number",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
# fmt: off
|
||||
type: Literal["control_output"] = "control_output"
|
||||
control: ControlField = Field(default=None, description="The control info")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
# fmt: off
|
||||
type: Literal["controlnet"] = "controlnet"
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
|
||||
description="control model used")
|
||||
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
|
||||
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
|
||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
"control_weight": "float",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
|
||||
return ControlOutput(
|
||||
control=ControlField(
|
||||
image=self.image,
|
||||
control_model=self.control_model,
|
||||
control_weight=self.control_weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
),
|
||||
)
|
||||
|
||||
# TODO: move image processors to separate file (image_analysis.py
|
||||
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_processor"] = "image_processor"
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to process")
|
||||
# fmt: on
|
||||
|
||||
|
||||
def run_processor(self, image):
|
||||
# superclass just passes through image without processing
|
||||
return image
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raw_image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||
processed_image = self.run_processor(raw_image)
|
||||
|
||||
# FIXME: what happened to image metadata?
|
||||
# metadata = context.services.metadata.build_metadata(
|
||||
# session_id=context.graph_execution_state_id, node=self
|
||||
# )
|
||||
|
||||
# currently can't see processed image in node UI without a showImage node,
|
||||
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
||||
image_dto = context.services.images.create(
|
||||
image=processed_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.CONTROL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate
|
||||
)
|
||||
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
processed_image_field = ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
)
|
||||
return ImageOutput(
|
||||
image=processed_image_field,
|
||||
# width=processed_image.width,
|
||||
width = image_dto.width,
|
||||
# height=processed_image.height,
|
||||
height = image_dto.height,
|
||||
# mode=processed_image.mode,
|
||||
)
|
||||
|
||||
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
# fmt: off
|
||||
type: Literal["canny_image_processor"] = "canny_image_processor"
|
||||
# Input
|
||||
low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)")
|
||||
high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
canny_processor = CannyDetector()
|
||||
processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
|
||||
return processed_image
|
||||
|
||||
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies HED edge detection to image"""
|
||||
# fmt: off
|
||||
type: Literal["hed_image_processor"] = "hed_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe: bool = Field(default=False, description="whether to use safe mode")
|
||||
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = hed_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies line art processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
coarse: bool = Field(default=False, description="Whether to use coarse mode")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = lineart_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
coarse=self.coarse)
|
||||
return processed_image
|
||||
|
||||
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies line art anime processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies Openpose processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
||||
# Inputs
|
||||
hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode")
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = openpose_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
hand_and_face=self.hand_and_face,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies Midas depth processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
||||
# Inputs
|
||||
a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||
bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = midas_processor(image,
|
||||
a=np.pi * self.a_mult,
|
||||
bg_th=self.bg_th,
|
||||
# dept_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal=self.depth_and_normal,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies NormalBae processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = normalbae_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution)
|
||||
return processed_image
|
||||
|
||||
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies MLSD processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||
thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = mlsd_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
thr_v=self.thr_v,
|
||||
thr_d=self.thr_d)
|
||||
return processed_image
|
||||
|
||||
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies PIDI processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
safe: bool = Field(default=False, description="Whether to use safe mode")
|
||||
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = pidi_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
safe=self.safe,
|
||||
scribble=self.scribble)
|
||||
return processed_image
|
||||
|
||||
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies content shuffle processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
h: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||
w: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: Union[int, None] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
content_shuffle_processor = ContentShuffleDetector()
|
||||
processed_image = content_shuffle_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
h=self.h,
|
||||
w=self.w,
|
||||
f=self.f
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = zoe_depth_processor(image)
|
||||
return processed_image
|
||||
|
||||
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
||||
# Inputs
|
||||
max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect")
|
||||
min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
mediapipe_face_processor = MediapipeFaceDetector()
|
||||
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
|
||||
return processed_image
|
||||
@@ -5,48 +5,33 @@ from typing import Literal
|
||||
import cv2 as cv
|
||||
import numpy
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
|
||||
|
||||
class CvInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all OpenCV invocations with additional config"""
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["cv", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
class CvInpaintInvocation(BaseInvocation):
|
||||
"""Simple inpaint using opencv."""
|
||||
|
||||
# fmt: off
|
||||
#fmt: off
|
||||
type: Literal["cv_inpaint"] = "cv_inpaint"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to inpaint")
|
||||
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
|
||||
# fmt: on
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
mask = context.services.images.get_pil_image(
|
||||
self.mask.image_origin, self.mask.image_name
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
mask = context.services.images.get(self.mask.image_type, self.mask.image_name)
|
||||
|
||||
# Convert to cv image/mask
|
||||
# TODO: consider making these utility functions
|
||||
cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
|
||||
cv_mask = numpy.array(ImageOps.invert(mask.convert("L")))
|
||||
cv_mask = numpy.array(ImageOps.invert(mask))
|
||||
|
||||
# Inpaint
|
||||
cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)
|
||||
@@ -55,20 +40,11 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
# TODO: consider making a utility function
|
||||
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image_inpainted,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, image_inpainted)
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
@@ -1,50 +1,29 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from functools import partial
|
||||
from typing import Literal, Optional, Union, get_args
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from diffusers import ControlNetModel
|
||||
|
||||
from torch import Tensor
|
||||
import torch
|
||||
from PIL import Image
|
||||
from pydantic import Field
|
||||
from skimage.exposure.histogram_matching import match_histograms
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ColorField, ImageField, ResourceOrigin
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.generator.inpaint import infill_methods
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ..util.step_callback import stable_diffusion_step_callback
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = (
|
||||
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
)
|
||||
|
||||
|
||||
class SDImageInvocation(BaseModel):
|
||||
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
},
|
||||
},
|
||||
}
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(InvokeAIGenerator.schedulers())
|
||||
]
|
||||
|
||||
# Text to image
|
||||
class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
class TextToImageInvocation(BaseInvocation):
|
||||
"""Generates an image using text2img."""
|
||||
|
||||
type: Literal["txt2img"] = "txt2img"
|
||||
@@ -53,87 +32,72 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||
# fmt: off
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
|
||||
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
control_model: Optional[str] = Field(default=None, description="The control model to use")
|
||||
control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
|
||||
# fmt: on
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
self, context: InvocationContext, sample: Tensor, step: int
|
||||
) -> None:
|
||||
# TODO: only output a preview image when requested
|
||||
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
context.graph_execution_state_id,
|
||||
self.id,
|
||||
{
|
||||
"width": width,
|
||||
"height": height,
|
||||
"dataURL": dataURL
|
||||
},
|
||||
step,
|
||||
self.steps,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, state.latents, state.step)
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
# loading controlnet image (currently requires pre-processed image)
|
||||
control_image = (
|
||||
None if self.control_image is None
|
||||
else context.services.images.get_pil_image(
|
||||
self.control_image.image_origin, self.control_image.image_name
|
||||
)
|
||||
)
|
||||
# loading controlnet model
|
||||
if (self.control_model is None or self.control_model==''):
|
||||
control_model = None
|
||||
else:
|
||||
# FIXME: change this to dropdown menu?
|
||||
# FIXME: generalize so don't have to hardcode torch_dtype and device
|
||||
control_model = ControlNetModel.from_pretrained(self.control_model,
|
||||
torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
txt2img = Txt2Img(model, control_model=control_model)
|
||||
outputs = txt2img.generate(
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
# (right now uses whatever current model is set in model manager)
|
||||
model= context.services.model_manager.get_model()
|
||||
outputs = Txt2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
control_image=control_image,
|
||||
step_callback=step_callback,
|
||||
**self.dict(
|
||||
exclude={"prompt", "control_image" }
|
||||
exclude={"prompt"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generate_output = next(outputs)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=generate_output.image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
# TODO: can this return multiple results? Should it?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, generate_output.image)
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
|
||||
@@ -152,72 +116,49 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
description="Whether or not the result should be fit to the aspect ratio of the input image",
|
||||
)
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
else context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
)
|
||||
mask = None
|
||||
|
||||
def step_callback(sample, step=0):
|
||||
self.dispatch_progress(context, sample, step)
|
||||
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
model = context.services.model_manager.get_model()
|
||||
generator_output = next(
|
||||
Img2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
init_mask=mask,
|
||||
step_callback=step_callback,
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
)
|
||||
|
||||
if self.fit:
|
||||
image = image.resize((self.width, self.height))
|
||||
result_image = generator_output.image
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
# TODO: can this return multiple results? Should it?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Img2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generator_output = next(outputs)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=generator_output.image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, result_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
|
||||
class InpaintInvocation(ImageToImageInvocation):
|
||||
"""Generates an image using inpaint."""
|
||||
|
||||
@@ -225,39 +166,6 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
|
||||
# Inputs
|
||||
mask: Union[ImageField, None] = Field(description="The mask")
|
||||
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||
seam_blur: int = Field(
|
||||
default=16, ge=0, description="The seam inpaint blur radius (px)"
|
||||
)
|
||||
seam_strength: float = Field(
|
||||
default=0.75, gt=0, le=1, description="The seam inpaint strength"
|
||||
)
|
||||
seam_steps: int = Field(
|
||||
default=30, ge=1, description="The number of steps to use for seam inpaint"
|
||||
)
|
||||
tile_size: int = Field(
|
||||
default=32, ge=1, description="The tile infill method size (px)"
|
||||
)
|
||||
infill_method: INFILL_METHODS = Field(
|
||||
default=DEFAULT_INFILL_METHOD,
|
||||
description="The method used to infill empty regions (px)",
|
||||
)
|
||||
inpaint_width: Optional[int] = Field(
|
||||
default=None,
|
||||
multiple_of=8,
|
||||
gt=0,
|
||||
description="The width of the inpaint region (px)",
|
||||
)
|
||||
inpaint_height: Optional[int] = Field(
|
||||
default=None,
|
||||
multiple_of=8,
|
||||
gt=0,
|
||||
description="The height of the inpaint region (px)",
|
||||
)
|
||||
inpaint_fill: Optional[ColorField] = Field(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The solid infill method color",
|
||||
)
|
||||
inpaint_replace: float = Field(
|
||||
default=0.0,
|
||||
ge=0.0,
|
||||
@@ -265,70 +173,49 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
description="The amount by which to replace masked areas with latent noise",
|
||||
)
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
else context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
|
||||
else context.services.images.get(self.mask.image_type, self.mask.image_name)
|
||||
)
|
||||
|
||||
def step_callback(sample, step=0):
|
||||
self.dispatch_progress(context, sample, step)
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Inpaint(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
mask_image=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
manager = context.services.model_manager.get_model()
|
||||
generator_output = next(
|
||||
Inpaint(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
mask_image=mask,
|
||||
step_callback=step_callback,
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
)
|
||||
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generator_output = next(outputs)
|
||||
result_image = generator_output.image
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=generator_output.image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
# TODO: can this return multiple results? Should it?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, result_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
@@ -1,86 +1,54 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import io
|
||||
from typing import Literal, Optional, Union
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional
|
||||
|
||||
import numpy
|
||||
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
|
||||
|
||||
class PILInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all PIL invocations with additional config"""
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["PIL", "image"],
|
||||
},
|
||||
}
|
||||
image_type: str = Field(
|
||||
default=ImageType.RESULT, description="The type of the image"
|
||||
)
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
#fmt: off
|
||||
type: Literal["image"] = "image"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
|
||||
#fmt: on
|
||||
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a mask"""
|
||||
|
||||
# fmt: off
|
||||
#fmt: off
|
||||
type: Literal["mask"] = "mask"
|
||||
mask: ImageField = Field(default=None, description="The output mask")
|
||||
width: int = Field(description="The width of the mask in pixels")
|
||||
height: int = Field(description="The height of the mask in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"mask",
|
||||
]
|
||||
}
|
||||
|
||||
#fomt: on
|
||||
|
||||
# TODO: this isn't really necessary anymore
|
||||
class LoadImageInvocation(BaseInvocation):
|
||||
"""Load an image and provide it as output."""
|
||||
|
||||
# fmt: off
|
||||
"""Load an image from a filename and provide it as output."""
|
||||
#fmt: off
|
||||
type: Literal["load_image"] = "load_image"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(
|
||||
default=None, description="The image to load"
|
||||
)
|
||||
# fmt: on
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
|
||||
image_type: ImageType = Field(description="The type of the image")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=self.image.image_name,
|
||||
image_origin=self.image.image_origin,
|
||||
),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
image=ImageField(image_type=self.image_type, image_name=self.image_name)
|
||||
)
|
||||
|
||||
|
||||
@@ -90,13 +58,11 @@ class ShowImageInvocation(BaseInvocation):
|
||||
type: Literal["show_image"] = "show_image"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(
|
||||
default=None, description="The image to show"
|
||||
)
|
||||
image: ImageField = Field(default=None, description="The image to show")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
if image:
|
||||
image.show()
|
||||
@@ -105,31 +71,27 @@ class ShowImageInvocation(BaseInvocation):
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=self.image.image_name,
|
||||
image_origin=self.image.image_origin,
|
||||
),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
image_type=self.image.image_type, image_name=self.image.image_name
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
class CropImageInvocation(BaseInvocation):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_crop"] = "img_crop"
|
||||
#fmt: off
|
||||
type: Literal["crop"] = "crop"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to crop")
|
||||
image: ImageField = Field(default=None, description="The image to crop")
|
||||
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
|
||||
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
||||
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
||||
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
|
||||
# fmt: on
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
image_crop = Image.new(
|
||||
@@ -137,53 +99,41 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
image_crop.paste(image, (-self.x, -self.y))
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image_crop,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, image_crop)
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
|
||||
class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
class PasteImageInvocation(BaseInvocation):
|
||||
"""Pastes an image into another image."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_paste"] = "img_paste"
|
||||
#fmt: off
|
||||
type: Literal["paste"] = "paste"
|
||||
|
||||
# Inputs
|
||||
base_image: Union[ImageField, None] = Field(default=None, description="The base image")
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to paste")
|
||||
base_image: ImageField = Field(default=None, description="The base image")
|
||||
image: ImageField = Field(default=None, description="The image to paste")
|
||||
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
||||
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
||||
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
||||
# fmt: on
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.services.images.get_pil_image(
|
||||
self.base_image.image_origin, self.base_image.image_name
|
||||
base_image = context.services.images.get(
|
||||
self.base_image.image_type, self.base_image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else ImageOps.invert(
|
||||
context.services.images.get_pil_image(
|
||||
self.mask.image_origin, self.mask.image_name
|
||||
)
|
||||
services.images.get(self.mask.image_type, self.mask.image_name)
|
||||
)
|
||||
)
|
||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||
@@ -199,195 +149,58 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
new_image.paste(base_image, (abs(min_x), abs(min_y)))
|
||||
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=new_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, new_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
|
||||
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
class MaskFromAlphaInvocation(BaseInvocation):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
|
||||
# fmt: off
|
||||
#fmt: off
|
||||
type: Literal["tomask"] = "tomask"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to create the mask from")
|
||||
image: ImageField = Field(default=None, description="The image to create the mask from")
|
||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
||||
# fmt: on
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
image_mask = image.split()[-1]
|
||||
if self.invert:
|
||||
image_mask = ImageOps.invert(image_mask)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image_mask,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.MASK,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return MaskOutput(
|
||||
mask=ImageField(
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image_mask)
|
||||
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
|
||||
|
||||
|
||||
class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_mul"] = "img_mul"
|
||||
|
||||
# Inputs
|
||||
image1: Union[ImageField, None] = Field(default=None, description="The first image to multiply")
|
||||
image2: Union[ImageField, None] = Field(default=None, description="The second image to multiply")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image1 = context.services.images.get_pil_image(
|
||||
self.image1.image_origin, self.image1.image_name
|
||||
)
|
||||
image2 = context.services.images.get_pil_image(
|
||||
self.image2.image_origin, self.image2.image_name
|
||||
)
|
||||
|
||||
multiply_image = ImageChops.multiply(image1, image2)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=multiply_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||
|
||||
|
||||
class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Gets a channel from an image."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_chan"] = "img_chan"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to get the channel from")
|
||||
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
channel_image = image.getchannel(self.channel)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=channel_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||
|
||||
|
||||
class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Converts an image to a different mode."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_conv"] = "img_conv"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to convert")
|
||||
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
converted_image = image.convert(self.mode)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=converted_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
class BlurInvocation(BaseInvocation):
|
||||
"""Blurs an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_blur"] = "img_blur"
|
||||
#fmt: off
|
||||
type: Literal["blur"] = "blur"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to blur")
|
||||
image: ImageField = Field(default=None, description="The image to blur")
|
||||
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
||||
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
||||
# fmt: on
|
||||
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
blur = (
|
||||
@@ -397,149 +210,30 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
blur_image = image.filter(blur)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=blur_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, blur_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
|
||||
PIL_RESAMPLING_MODES = Literal[
|
||||
"nearest",
|
||||
"box",
|
||||
"bilinear",
|
||||
"hamming",
|
||||
"bicubic",
|
||||
"lanczos",
|
||||
]
|
||||
|
||||
|
||||
PIL_RESAMPLING_MAP = {
|
||||
"nearest": Image.Resampling.NEAREST,
|
||||
"box": Image.Resampling.BOX,
|
||||
"bilinear": Image.Resampling.BILINEAR,
|
||||
"hamming": Image.Resampling.HAMMING,
|
||||
"bicubic": Image.Resampling.BICUBIC,
|
||||
"lanczos": Image.Resampling.LANCZOS,
|
||||
}
|
||||
|
||||
|
||||
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Resizes an image to specific dimensions"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_resize"] = "img_resize"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to resize")
|
||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
|
||||
resize_image = image.resize(
|
||||
(self.width, self.height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=resize_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Scales an image by a factor"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_scale"] = "img_scale"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to scale")
|
||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
width = int(image.width * self.scale_factor)
|
||||
height = int(image.height * self.scale_factor)
|
||||
|
||||
resize_image = image.resize(
|
||||
(width, height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=resize_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
class LerpInvocation(BaseInvocation):
|
||||
"""Linear interpolation of all pixels of an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_lerp"] = "img_lerp"
|
||||
#fmt: off
|
||||
type: Literal["lerp"] = "lerp"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
|
||||
image: ImageField = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
||||
# fmt: on
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
||||
@@ -547,40 +241,30 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
lerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=lerp_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, lerp_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
|
||||
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
class InverseLerpInvocation(BaseInvocation):
|
||||
"""Inverse linear interpolation of all pixels of an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_ilerp"] = "img_ilerp"
|
||||
#fmt: off
|
||||
type: Literal["ilerp"] = "ilerp"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
|
||||
image: ImageField = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
||||
# fmt: on
|
||||
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||
@@ -593,20 +277,11 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=ilerp_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, ilerp_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
@@ -1,245 +0,0 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
from typing import Literal, Union, get_args
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.invocations.image import ImageOutput
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationContext,
|
||||
)
|
||||
|
||||
|
||||
def infill_methods() -> list[str]:
|
||||
methods = [
|
||||
"tile",
|
||||
"solid",
|
||||
]
|
||||
if PatchMatch.patchmatch_available():
|
||||
methods.insert(0, "patchmatch")
|
||||
return methods
|
||||
|
||||
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = (
|
||||
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
)
|
||||
|
||||
|
||||
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
# Skip patchmatch if patchmatch isn't available
|
||||
if not PatchMatch.patchmatch_available():
|
||||
return im
|
||||
|
||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||
im_patched_np = PatchMatch.inpaint(
|
||||
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
|
||||
)
|
||||
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
||||
return im_patched
|
||||
|
||||
|
||||
def get_tile_images(image: np.ndarray, width=8, height=8):
|
||||
_nrows, _ncols, depth = image.shape
|
||||
_strides = image.strides
|
||||
|
||||
nrows, _m = divmod(_nrows, height)
|
||||
ncols, _n = divmod(_ncols, width)
|
||||
if _m != 0 or _n != 0:
|
||||
return None
|
||||
|
||||
return np.lib.stride_tricks.as_strided(
|
||||
np.ravel(image),
|
||||
shape=(nrows, ncols, height, width, depth),
|
||||
strides=(height * _strides[0], width * _strides[1], *_strides),
|
||||
writeable=False,
|
||||
)
|
||||
|
||||
|
||||
def tile_fill_missing(
|
||||
im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
|
||||
) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size_tuple = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = get_tile_images(a, *tile_size_tuple).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:, :, :, :, 3]
|
||||
|
||||
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
|
||||
tmask_shape = tiles_mask.shape
|
||||
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
|
||||
n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
|
||||
tiles_mask = tiles_mask > 0
|
||||
tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)
|
||||
|
||||
# Get RGB tiles in single array and filter by the mask
|
||||
tshape = tiles.shape
|
||||
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
|
||||
filtered_tiles = tiles_all[tiles_mask]
|
||||
|
||||
if len(filtered_tiles) == 0:
|
||||
return im
|
||||
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask == False).sum()
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
|
||||
rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
|
||||
]
|
||||
|
||||
# Convert back to an image
|
||||
tiles_all = tiles_all.reshape(tshape)
|
||||
tiles_all = tiles_all.swapaxes(1, 2)
|
||||
st = tiles_all.reshape(
|
||||
(
|
||||
math.prod(tiles_all.shape[0:2]),
|
||||
math.prod(tiles_all.shape[2:4]),
|
||||
tiles_all.shape[4],
|
||||
)
|
||||
)
|
||||
si = Image.fromarray(st, mode="RGBA")
|
||||
|
||||
return si
|
||||
|
||||
|
||||
class InfillColorInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
type: Literal["infill_rgba"] = "infill_rgba"
|
||||
image: Union[ImageField, None] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
color: ColorField = Field(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The color to use to infill",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
||||
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class InfillTileInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
type: Literal["infill_tile"] = "infill_tile"
|
||||
|
||||
image: Union[ImageField, None] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = Field(
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description="The seed to use for tile generation (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
infilled = tile_fill_missing(
|
||||
image.copy(), seed=self.seed, tile_size=self.tile_size
|
||||
)
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class InfillPatchMatchInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
||||
|
||||
image: Union[ImageField, None] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
if PatchMatch.patchmatch_available():
|
||||
infilled = infill_patchmatch(image.copy())
|
||||
else:
|
||||
raise ValueError("PatchMatch is not available on this system")
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -1,626 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import random
|
||||
import einops
|
||||
from typing import Literal, Optional, Union, List
|
||||
|
||||
from compel import Compel
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.models.image import ImageCategory
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from .controlnet_image_processors import ControlField
|
||||
|
||||
from ...backend.model_management.model_manager import ModelManager
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.image_util.seamless import configure_model_padding
|
||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
import numpy as np
|
||||
from ..services.image_file_storage import ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from .compel import ConditioningField
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
import diffusers
|
||||
from diffusers import DiffusionPipeline, ControlNetModel
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents field used for passing latents between invocations"""
|
||||
|
||||
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["latents_name"]}
|
||||
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output latents"""
|
||||
#fmt: off
|
||||
type: Literal["latents_output"] = "latents_output"
|
||||
|
||||
# Inputs
|
||||
latents: LatentsField = Field(default=None, description="The output latents")
|
||||
width: int = Field(description="The width of the latents in pixels")
|
||||
height: int = Field(description="The height of the latents in pixels")
|
||||
#fmt: on
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor):
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
class NoiseOutput(BaseInvocationOutput):
|
||||
"""Invocation noise output"""
|
||||
#fmt: off
|
||||
type: Literal["noise_output"] = "noise_output"
|
||||
|
||||
# Inputs
|
||||
noise: LatentsField = Field(default=None, description="The output noise")
|
||||
width: int = Field(description="The width of the noise in pixels")
|
||||
height: int = Field(description="The height of the noise in pixels")
|
||||
#fmt: on
|
||||
|
||||
def build_noise_output(latents_name: str, latents: torch.Tensor):
|
||||
return NoiseOutput(
|
||||
noise=LatentsField(latents_name=latents_name),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(list(SCHEDULER_MAP.keys()))
|
||||
]
|
||||
|
||||
|
||||
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
|
||||
scheduler_config = model.scheduler.config
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
return scheduler
|
||||
|
||||
|
||||
def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8):
|
||||
# limit noise to only the diffusion image channels, not the mask channels
|
||||
input_channels = min(latent_channels, 4)
|
||||
use_device = "cpu" if (use_mps_noise or device.type == "mps") else device
|
||||
generator = torch.Generator(device=use_device).manual_seed(seed)
|
||||
x = torch.randn(
|
||||
[
|
||||
1,
|
||||
input_channels,
|
||||
height // downsampling_factor,
|
||||
width // downsampling_factor,
|
||||
],
|
||||
dtype=torch_dtype(device),
|
||||
device=use_device,
|
||||
generator=generator,
|
||||
).to(device)
|
||||
# if self.perlin > 0.0:
|
||||
# perlin_noise = self.get_perlin_noise(
|
||||
# width // self.downsampling_factor, height // self.downsampling_factor
|
||||
# )
|
||||
# x = (1 - self.perlin) * x + self.perlin * perlin_noise
|
||||
return x
|
||||
|
||||
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
type: Literal["noise"] = "noise"
|
||||
|
||||
# Inputs
|
||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
|
||||
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "noise"],
|
||||
},
|
||||
}
|
||||
|
||||
@validator("seed", pre=True)
|
||||
def modulo_seed(cls, v):
|
||||
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
|
||||
return v % SEED_MAX
|
||||
|
||||
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||
device = torch.device(choose_torch_device())
|
||||
noise = get_noise(self.width, self.height, device, self.seed)
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.save(name, noise)
|
||||
return build_noise_output(latents_name=name, latents=noise)
|
||||
|
||||
|
||||
# Text to image
|
||||
class TextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
|
||||
type: Literal["t2l"] = "t2l"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
control: Union[ControlField, List[ControlField]] = Field(default=None, description="The control to use")
|
||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
|
||||
model_info = choose_model(model_manager, self.model)
|
||||
model_name = model_info['model_name']
|
||||
model_hash = model_info['hash']
|
||||
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||
model.scheduler = get_scheduler(
|
||||
model=model,
|
||||
scheduler_name=self.scheduler
|
||||
)
|
||||
|
||||
# if isinstance(model, DiffusionPipeline):
|
||||
# for component in [model.unet, model.vae]:
|
||||
# configure_model_padding(component,
|
||||
# self.seamless,
|
||||
# self.seamless_axes
|
||||
# )
|
||||
# else:
|
||||
# configure_model_padding(model,
|
||||
# self.seamless,
|
||||
# self.seamless_axes
|
||||
# )
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
|
||||
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
|
||||
compel = Compel(
|
||||
tokenizer=model.tokenizer,
|
||||
text_encoder=model.text_encoder,
|
||||
textual_inversion_manager=model.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
unconditioned_embeddings=uc,
|
||||
text_embeddings=c,
|
||||
guidance_scale=self.cfg_scale,
|
||||
extra=extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=0.0,#threshold,
|
||||
warmup=0.2,#warmup,
|
||||
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
|
||||
return conditioning_data
|
||||
|
||||
def prep_control_data(self,
|
||||
context: InvocationContext,
|
||||
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
|
||||
control_input: List[ControlField],
|
||||
latents_shape: List[int],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> List[ControlNetData]:
|
||||
# assuming fixed dimensional scaling of 8:1 for image:latents
|
||||
control_height_resize = latents_shape[2] * 8
|
||||
control_width_resize = latents_shape[3] * 8
|
||||
if control_input is None:
|
||||
# print("control input is None")
|
||||
control_list = None
|
||||
elif isinstance(control_input, list) and len(control_input) == 0:
|
||||
# print("control input is empty list")
|
||||
control_list = None
|
||||
elif isinstance(control_input, ControlField):
|
||||
# print("control input is ControlField")
|
||||
control_list = [control_input]
|
||||
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
|
||||
# print("control input is list[ControlField]")
|
||||
control_list = control_input
|
||||
else:
|
||||
# print("input control is unrecognized:", type(self.control))
|
||||
control_list = None
|
||||
if (control_list is None):
|
||||
control_data = None
|
||||
# from above handling, any control that is not None should now be of type list[ControlField]
|
||||
else:
|
||||
# FIXME: add checks to skip entry if model or image is None
|
||||
# and if weight is None, populate with default 1.0?
|
||||
control_data = []
|
||||
control_models = []
|
||||
for control_info in control_list:
|
||||
# handle control models
|
||||
if ("," in control_info.control_model):
|
||||
control_model_split = control_info.control_model.split(",")
|
||||
control_name = control_model_split[0]
|
||||
control_subfolder = control_model_split[1]
|
||||
print("Using HF model subfolders")
|
||||
print(" control_name: ", control_name)
|
||||
print(" control_subfolder: ", control_subfolder)
|
||||
control_model = ControlNetModel.from_pretrained(control_name,
|
||||
subfolder=control_subfolder,
|
||||
torch_dtype=model.unet.dtype).to(model.device)
|
||||
else:
|
||||
control_model = ControlNetModel.from_pretrained(control_info.control_model,
|
||||
torch_dtype=model.unet.dtype).to(model.device)
|
||||
control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_origin,
|
||||
control_image_field.image_name)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
# and do real check for classifier_free_guidance?
|
||||
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
|
||||
control_image = model.prepare_control_image(
|
||||
image=input_image,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=control_width_resize,
|
||||
height=control_height_resize,
|
||||
# batch_size=batch_size * num_images_per_prompt,
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=control_model.device,
|
||||
dtype=control_model.dtype,
|
||||
)
|
||||
control_item = ControlNetData(model=control_model,
|
||||
image_tensor=control_image,
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
end_step_percent=control_info.end_step_percent)
|
||||
control_data.append(control_item)
|
||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||
return control_data
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(context, model)
|
||||
|
||||
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
"""Generates latents using latents as base image."""
|
||||
|
||||
type: Literal["l2l"] = "l2l"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
"cfg_scale": "number",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latent = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(context, model)
|
||||
|
||||
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
latent, device=model.device, dtype=latent.dtype
|
||||
)
|
||||
|
||||
timesteps, _ = model.get_img2img_timesteps(self.steps, self.strength)
|
||||
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
timesteps=timesteps,
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
# Latent to image
|
||||
class LatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
type: Literal["l2i"] = "l2i"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||
model: str = Field(default="", description="The model to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# TODO: this only really needs the vae
|
||||
model_info = choose_model(context.services.model_manager, self.model)
|
||||
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||
|
||||
with torch.inference_mode():
|
||||
np_image = model.decode_latents(latents)
|
||||
image = model.numpy_to_pil(np_image)[0]
|
||||
|
||||
# what happened to metadata?
|
||||
# metadata = context.services.metadata.build_metadata(
|
||||
# session_id=context.graph_execution_state_id, node=self
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# new (post Image service refactor) way of using services to save image
|
||||
# and gnenerate unique image_name
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
LATENTS_INTERPOLATION_MODE = Literal[
|
||||
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
|
||||
]
|
||||
|
||||
|
||||
class ResizeLatentsInvocation(BaseInvocation):
|
||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||
|
||||
type: Literal["lresize"] = "lresize"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents,
|
||||
size=(self.height // 8, self.width // 8),
|
||||
mode=self.mode,
|
||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||
|
||||
|
||||
class ScaleLatentsInvocation(BaseInvocation):
|
||||
"""Scales latents by a given factor."""
|
||||
|
||||
type: Literal["lscale"] = "lscale"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# resizing
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents,
|
||||
scale_factor=self.scale_factor,
|
||||
mode=self.mode,
|
||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||
|
||||
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
type: Literal["i2l"] = "i2l"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The image to encode")
|
||||
model: str = Field(default="", description="The model to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
# image = context.services.images.get(
|
||||
# self.image.image_type, self.image.image_name
|
||||
# )
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
# TODO: this only really needs the vae
|
||||
model_info = choose_model(context.services.model_manager, self.model)
|
||||
model: StableDiffusionGeneratorPipeline = model_info["model"]
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
latents = model.non_noised_latents_from_image(
|
||||
image_tensor,
|
||||
device=model._model_group.device_for(model.unet),
|
||||
dtype=model.unet.dtype,
|
||||
)
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, latents)
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents)
|
||||
@@ -1,109 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
|
||||
class MathInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all math invocations with additional config"""
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["math"],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class IntOutput(BaseInvocationOutput):
|
||||
"""An integer output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["int_output"] = "int_output"
|
||||
a: int = Field(default=None, description="The output integer")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class FloatOutput(BaseInvocationOutput):
|
||||
"""A float output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["float_output"] = "float_output"
|
||||
param: float = Field(default=None, description="The output float")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class AddInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Adds two numbers"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["add"] = "add"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a + self.b)
|
||||
|
||||
|
||||
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Subtracts two numbers"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sub"] = "sub"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a - self.b)
|
||||
|
||||
|
||||
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Multiplies two numbers"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mul"] = "mul"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a * self.b)
|
||||
|
||||
|
||||
class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Divides two numbers"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["div"] = "div"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=int(self.a / self.b))
|
||||
|
||||
|
||||
class RandomIntInvocation(BaseInvocation):
|
||||
"""Outputs a single random integer."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["rand_int"] = "rand_int"
|
||||
low: int = Field(default=0, description="The inclusive low value")
|
||||
high: int = Field(
|
||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||
)
|
||||
# fmt: on
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=np.random.randint(self.low, self.high))
|
||||
@@ -1,237 +0,0 @@
|
||||
import io
|
||||
from typing import Literal, Optional, Any
|
||||
|
||||
# from PIL.Image import Image
|
||||
import PIL.Image
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from easing_functions import (
|
||||
LinearInOut,
|
||||
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
|
||||
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
|
||||
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
|
||||
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
|
||||
SineEaseInOut, SineEaseIn, SineEaseOut,
|
||||
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
|
||||
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
|
||||
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
|
||||
BackEaseIn, BackEaseInOut, BackEaseOut,
|
||||
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
from ...backend.util.logging import InvokeAILogger
|
||||
from .collections import FloatCollectionOutput
|
||||
|
||||
|
||||
class FloatLinearRangeInvocation(BaseInvocation):
|
||||
"""Creates a range"""
|
||||
|
||||
type: Literal["float_range"] = "float_range"
|
||||
|
||||
# Inputs
|
||||
start: float = Field(default=5, description="The first value of the range")
|
||||
stop: float = Field(default=10, description="The last value of the range")
|
||||
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||
return FloatCollectionOutput(
|
||||
collection=param_list
|
||||
)
|
||||
|
||||
|
||||
EASING_FUNCTIONS_MAP = {
|
||||
"Linear": LinearInOut,
|
||||
"QuadIn": QuadEaseIn,
|
||||
"QuadOut": QuadEaseOut,
|
||||
"QuadInOut": QuadEaseInOut,
|
||||
"CubicIn": CubicEaseIn,
|
||||
"CubicOut": CubicEaseOut,
|
||||
"CubicInOut": CubicEaseInOut,
|
||||
"QuarticIn": QuarticEaseIn,
|
||||
"QuarticOut": QuarticEaseOut,
|
||||
"QuarticInOut": QuarticEaseInOut,
|
||||
"QuinticIn": QuinticEaseIn,
|
||||
"QuinticOut": QuinticEaseOut,
|
||||
"QuinticInOut": QuinticEaseInOut,
|
||||
"SineIn": SineEaseIn,
|
||||
"SineOut": SineEaseOut,
|
||||
"SineInOut": SineEaseInOut,
|
||||
"CircularIn": CircularEaseIn,
|
||||
"CircularOut": CircularEaseOut,
|
||||
"CircularInOut": CircularEaseInOut,
|
||||
"ExponentialIn": ExponentialEaseIn,
|
||||
"ExponentialOut": ExponentialEaseOut,
|
||||
"ExponentialInOut": ExponentialEaseInOut,
|
||||
"ElasticIn": ElasticEaseIn,
|
||||
"ElasticOut": ElasticEaseOut,
|
||||
"ElasticInOut": ElasticEaseInOut,
|
||||
"BackIn": BackEaseIn,
|
||||
"BackOut": BackEaseOut,
|
||||
"BackInOut": BackEaseInOut,
|
||||
"BounceIn": BounceEaseIn,
|
||||
"BounceOut": BounceEaseOut,
|
||||
"BounceInOut": BounceEaseInOut,
|
||||
}
|
||||
|
||||
EASING_FUNCTION_KEYS: Any = Literal[
|
||||
tuple(list(EASING_FUNCTIONS_MAP.keys()))
|
||||
]
|
||||
|
||||
|
||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||
class StepParamEasingInvocation(BaseInvocation):
|
||||
"""Experimental per-step parameter easing for denoising steps"""
|
||||
|
||||
type: Literal["step_param_easing"] = "step_param_easing"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
|
||||
num_steps: int = Field(default=20, description="number of denoising steps")
|
||||
start_value: float = Field(default=0.0, description="easing starting value")
|
||||
end_value: float = Field(default=1.0, description="easing ending value")
|
||||
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
|
||||
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
|
||||
# if None, then start_value is used prior to easing start
|
||||
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
|
||||
# if None, then end value is used prior to easing end
|
||||
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
|
||||
mirror: bool = Field(default=False, description="include mirror of easing function")
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
|
||||
show_easing_plot: bool = Field(default=False, description="show easing plot")
|
||||
# fmt: on
|
||||
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
log_diagnostics = False
|
||||
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
|
||||
# start_step = int(np.floor(self.num_steps * self.start_step_percent))
|
||||
start_step = int(np.round(self.num_steps * self.start_step_percent))
|
||||
# convert from end_step_percent to nearest step >= (steps * end_step_percent)
|
||||
# end_step = int(np.ceil((self.num_steps - 1) * self.end_step_percent))
|
||||
end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))
|
||||
|
||||
# end_step = int(np.ceil(self.num_steps * self.end_step_percent))
|
||||
num_easing_steps = end_step - start_step + 1
|
||||
|
||||
# num_presteps = max(start_step - 1, 0)
|
||||
num_presteps = start_step
|
||||
num_poststeps = self.num_steps - (num_presteps + num_easing_steps)
|
||||
prelist = list(num_presteps * [self.pre_start_value])
|
||||
postlist = list(num_poststeps * [self.post_end_value])
|
||||
|
||||
if log_diagnostics:
|
||||
logger = InvokeAILogger.getLogger(name="StepParamEasing")
|
||||
logger.debug("start_step: " + str(start_step))
|
||||
logger.debug("end_step: " + str(end_step))
|
||||
logger.debug("num_easing_steps: " + str(num_easing_steps))
|
||||
logger.debug("num_presteps: " + str(num_presteps))
|
||||
logger.debug("num_poststeps: " + str(num_poststeps))
|
||||
logger.debug("prelist size: " + str(len(prelist)))
|
||||
logger.debug("postlist size: " + str(len(postlist)))
|
||||
logger.debug("prelist: " + str(prelist))
|
||||
logger.debug("postlist: " + str(postlist))
|
||||
|
||||
easing_class = EASING_FUNCTIONS_MAP[self.easing]
|
||||
if log_diagnostics:
|
||||
logger.debug("easing class: " + str(easing_class))
|
||||
easing_list = list()
|
||||
if self.mirror: # "expected" mirroring
|
||||
# if number of steps is even, squeeze duration down to (number_of_steps)/2
|
||||
# and create reverse copy of list to append
|
||||
# if number of steps is odd, squeeze duration down to ceil(number_of_steps/2)
|
||||
# and create reverse copy of list[1:end-1]
|
||||
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
|
||||
|
||||
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
|
||||
if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration))
|
||||
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
|
||||
easing_function = easing_class(start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=base_easing_duration - 1)
|
||||
base_easing_vals = list()
|
||||
for step_index in range(base_easing_duration):
|
||||
easing_val = easing_function.ease(step_index)
|
||||
base_easing_vals.append(easing_val)
|
||||
if log_diagnostics:
|
||||
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
|
||||
if even_num_steps:
|
||||
mirror_easing_vals = list(reversed(base_easing_vals))
|
||||
else:
|
||||
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
|
||||
if log_diagnostics:
|
||||
logger.debug("base easing vals: " + str(base_easing_vals))
|
||||
logger.debug("mirror easing vals: " + str(mirror_easing_vals))
|
||||
easing_list = base_easing_vals + mirror_easing_vals
|
||||
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# elif self.alt_mirror: # function mirroring (unintuitive behavior (at least to me))
|
||||
# # half_ease_duration = round(num_easing_steps - 1 / 2)
|
||||
# half_ease_duration = round((num_easing_steps - 1) / 2)
|
||||
# easing_function = easing_class(start=self.start_value,
|
||||
# end=self.end_value,
|
||||
# duration=half_ease_duration,
|
||||
# )
|
||||
#
|
||||
# mirror_function = easing_class(start=self.end_value,
|
||||
# end=self.start_value,
|
||||
# duration=half_ease_duration,
|
||||
# )
|
||||
# for step_index in range(num_easing_steps):
|
||||
# if step_index <= half_ease_duration:
|
||||
# step_val = easing_function.ease(step_index)
|
||||
# else:
|
||||
# step_val = mirror_function.ease(step_index - half_ease_duration)
|
||||
# easing_list.append(step_val)
|
||||
# if log_diagnostics: logger.debug(step_index, step_val)
|
||||
#
|
||||
|
||||
else: # no mirroring (default)
|
||||
easing_function = easing_class(start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=num_easing_steps - 1)
|
||||
for step_index in range(num_easing_steps):
|
||||
step_val = easing_function.ease(step_index)
|
||||
easing_list.append(step_val)
|
||||
if log_diagnostics:
|
||||
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
|
||||
|
||||
if log_diagnostics:
|
||||
logger.debug("prelist size: " + str(len(prelist)))
|
||||
logger.debug("easing_list size: " + str(len(easing_list)))
|
||||
logger.debug("postlist size: " + str(len(postlist)))
|
||||
|
||||
param_list = prelist + easing_list + postlist
|
||||
|
||||
if self.show_easing_plot:
|
||||
plt.figure()
|
||||
plt.xlabel("Step")
|
||||
plt.ylabel("Param Value")
|
||||
plt.title("Per-Step Values Based On Easing: " + self.easing)
|
||||
plt.bar(range(len(param_list)), param_list)
|
||||
# plt.plot(param_list)
|
||||
ax = plt.gca()
|
||||
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format='png')
|
||||
buf.seek(0)
|
||||
im = PIL.Image.open(buf)
|
||||
im.show()
|
||||
buf.close()
|
||||
|
||||
# output array of size steps, each entry list[i] is param value for step i
|
||||
return FloatCollectionOutput(
|
||||
collection=param_list
|
||||
)
|
||||
@@ -1,28 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal
|
||||
from pydantic import Field
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from .math import IntOutput, FloatOutput
|
||||
|
||||
# Pass-through parameter nodes - used by subgraphs
|
||||
|
||||
class ParamIntInvocation(BaseInvocation):
|
||||
"""An integer parameter"""
|
||||
#fmt: off
|
||||
type: Literal["param_int"] = "param_int"
|
||||
a: int = Field(default=0, description="The integer value")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a)
|
||||
|
||||
class ParamFloatInvocation(BaseInvocation):
|
||||
"""A float parameter"""
|
||||
#fmt: off
|
||||
type: Literal["param_float"] = "param_float"
|
||||
param: float = Field(default=0.0, description="The float value")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
return FloatOutput(param=self.param)
|
||||
@@ -12,11 +12,3 @@ class PromptOutput(BaseInvocationOutput):
|
||||
|
||||
prompt: str = Field(default=None, description="The output prompt")
|
||||
#fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'prompt',
|
||||
]
|
||||
}
|
||||
|
||||
@@ -1,35 +1,26 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
|
||||
class RestoreFaceInvocation(BaseInvocation):
|
||||
"""Restores faces in an image."""
|
||||
|
||||
# fmt: off
|
||||
#fmt: off
|
||||
type: Literal["restore_face"] = "restore_face"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The input image")
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["restoration", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
@@ -41,20 +32,11 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
|
||||
# Results are image and seed, unwrap for now
|
||||
# TODO: can this return multiple results?
|
||||
image_dto = context.services.images.create(
|
||||
image=results[0][0],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, results[0][0])
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
@@ -1,37 +1,30 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
|
||||
|
||||
class UpscaleInvocation(BaseInvocation):
|
||||
"""Upscales an image."""
|
||||
|
||||
# fmt: off
|
||||
#fmt: off
|
||||
type: Literal["upscale"] = "upscale"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
||||
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["upscaling", "image"],
|
||||
},
|
||||
}
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
@@ -43,20 +36,11 @@ class UpscaleInvocation(BaseInvocation):
|
||||
|
||||
# Results are image and seed, unwrap for now
|
||||
# TODO: can this return multiple results?
|
||||
image_dto = context.services.images.create(
|
||||
image=results[0][0],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, results[0][0])
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
from invokeai.backend.model_management.model_manager import ModelManager
|
||||
|
||||
|
||||
def choose_model(model_manager: ModelManager, model_name: str):
|
||||
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
||||
logger = model_manager.logger
|
||||
if model_name and not model_manager.valid_model(model_name):
|
||||
default_model_name = model_manager.default_model()
|
||||
logger.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{default_model_name}\' instead.")
|
||||
model = model_manager.get_model()
|
||||
else:
|
||||
model = model_manager.get_model(model_name)
|
||||
|
||||
return model
|
||||
@@ -1,3 +0,0 @@
|
||||
class CanceledException(Exception):
|
||||
"""Execution canceled by user."""
|
||||
pass
|
||||
@@ -1,93 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
|
||||
|
||||
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||
"""The origin of a resource (eg image).
|
||||
|
||||
- INTERNAL: The resource was created by the application.
|
||||
- EXTERNAL: The resource was not created by the application.
|
||||
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||
"""
|
||||
|
||||
INTERNAL = "internal"
|
||||
"""The resource was created by the application."""
|
||||
EXTERNAL = "external"
|
||||
"""The resource was not created by the application.
|
||||
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||
"""
|
||||
|
||||
|
||||
class InvalidOriginException(ValueError):
|
||||
"""Raised when a provided value is not a valid ResourceOrigin.
|
||||
|
||||
Subclasses `ValueError`.
|
||||
"""
|
||||
|
||||
def __init__(self, message="Invalid resource origin."):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
||||
"""The category of an image.
|
||||
|
||||
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
|
||||
- MASK: The image is a mask image.
|
||||
- CONTROL: The image is a ControlNet control image.
|
||||
- USER: The image is a user-provide image.
|
||||
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
|
||||
"""
|
||||
|
||||
GENERAL = "general"
|
||||
"""GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
|
||||
MASK = "mask"
|
||||
"""MASK: The image is a mask image."""
|
||||
CONTROL = "control"
|
||||
"""CONTROL: The image is a ControlNet control image."""
|
||||
USER = "user"
|
||||
"""USER: The image is a user-provide image."""
|
||||
OTHER = "other"
|
||||
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
|
||||
|
||||
|
||||
class InvalidImageCategoryException(ValueError):
|
||||
"""Raised when a provided value is not a valid ImageCategory.
|
||||
|
||||
Subclasses `ValueError`.
|
||||
"""
|
||||
|
||||
def __init__(self, message="Invalid image category."):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_origin: ResourceOrigin = Field(
|
||||
default=ResourceOrigin.INTERNAL, description="The type of the image"
|
||||
)
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["image_origin", "image_name"]}
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
@@ -1,93 +0,0 @@
|
||||
from typing import Optional, Union, List
|
||||
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
|
||||
|
||||
|
||||
class ImageMetadata(BaseModel):
|
||||
"""
|
||||
Core generation metadata for an image/tensor generated in InvokeAI.
|
||||
|
||||
Also includes any metadata from the image's PNG tEXt chunks.
|
||||
|
||||
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
|
||||
of a given node.
|
||||
|
||||
Full metadata may be accessed by querying for the session in the `graph_executions` table.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
extra = Extra.allow
|
||||
"""
|
||||
This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService
|
||||
won't add any fields that are not already defined, but other a different metadata service
|
||||
implementation might.
|
||||
"""
|
||||
|
||||
type: Optional[StrictStr] = Field(
|
||||
default=None,
|
||||
description="The type of the ancestor node of the image output node.",
|
||||
)
|
||||
"""The type of the ancestor node of the image output node."""
|
||||
positive_conditioning: Optional[StrictStr] = Field(
|
||||
default=None, description="The positive conditioning."
|
||||
)
|
||||
"""The positive conditioning"""
|
||||
negative_conditioning: Optional[StrictStr] = Field(
|
||||
default=None, description="The negative conditioning."
|
||||
)
|
||||
"""The negative conditioning"""
|
||||
width: Optional[StrictInt] = Field(
|
||||
default=None, description="Width of the image/latents in pixels."
|
||||
)
|
||||
"""Width of the image/latents in pixels"""
|
||||
height: Optional[StrictInt] = Field(
|
||||
default=None, description="Height of the image/latents in pixels."
|
||||
)
|
||||
"""Height of the image/latents in pixels"""
|
||||
seed: Optional[StrictInt] = Field(
|
||||
default=None, description="The seed used for noise generation."
|
||||
)
|
||||
"""The seed used for noise generation"""
|
||||
# cfg_scale: Optional[StrictFloat] = Field(
|
||||
# cfg_scale: Union[float, list[float]] = Field(
|
||||
cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field(
|
||||
default=None, description="The classifier-free guidance scale."
|
||||
)
|
||||
"""The classifier-free guidance scale"""
|
||||
steps: Optional[StrictInt] = Field(
|
||||
default=None, description="The number of steps used for inference."
|
||||
)
|
||||
"""The number of steps used for inference"""
|
||||
scheduler: Optional[StrictStr] = Field(
|
||||
default=None, description="The scheduler used for inference."
|
||||
)
|
||||
"""The scheduler used for inference"""
|
||||
model: Optional[StrictStr] = Field(
|
||||
default=None, description="The model used for inference."
|
||||
)
|
||||
"""The model used for inference"""
|
||||
strength: Optional[StrictFloat] = Field(
|
||||
default=None,
|
||||
description="The strength used for image-to-image/latents-to-latents.",
|
||||
)
|
||||
"""The strength used for image-to-image/latents-to-latents."""
|
||||
latents: Optional[StrictStr] = Field(
|
||||
default=None, description="The ID of the initial latents."
|
||||
)
|
||||
"""The ID of the initial latents"""
|
||||
vae: Optional[StrictStr] = Field(
|
||||
default=None, description="The VAE used for decoding."
|
||||
)
|
||||
"""The VAE used for decoding"""
|
||||
unet: Optional[StrictStr] = Field(
|
||||
default=None, description="The UNet used dor inference."
|
||||
)
|
||||
"""The UNet used dor inference"""
|
||||
clip: Optional[StrictStr] = Field(
|
||||
default=None, description="The CLIP Encoder used for conditioning."
|
||||
)
|
||||
"""The CLIP Encoder used for conditioning"""
|
||||
extra: Optional[StrictStr] = Field(
|
||||
default=None,
|
||||
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
|
||||
)
|
||||
"""Uploaded image metadata, extracted from the PNG tEXt chunk."""
|
||||
@@ -1,581 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
||||
|
||||
'''Invokeai configuration system.
|
||||
|
||||
Arguments and fields are taken from the pydantic definition of the
|
||||
model. Defaults can be set by creating a yaml configuration file that
|
||||
has a top-level key of "InvokeAI" and subheadings for each of the
|
||||
categories returned by `invokeai --help`. The file looks like this:
|
||||
|
||||
[file: invokeai.yaml]
|
||||
|
||||
InvokeAI:
|
||||
Paths:
|
||||
root: /home/lstein/invokeai-main
|
||||
conf_path: configs/models.yaml
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
outdir: outputs
|
||||
embedding_dir: embeddings
|
||||
lora_dir: loras
|
||||
autoconvert_dir: null
|
||||
gfpgan_model_dir: models/gfpgan/GFPGANv1.4.pth
|
||||
Models:
|
||||
model: stable-diffusion-1.5
|
||||
embeddings: true
|
||||
Memory/Performance:
|
||||
xformers_enabled: false
|
||||
sequential_guidance: false
|
||||
precision: float16
|
||||
max_loaded_models: 4
|
||||
always_use_cpu: false
|
||||
free_gpu_mem: false
|
||||
Features:
|
||||
nsfw_checker: true
|
||||
restore: true
|
||||
esrgan: true
|
||||
patchmatch: true
|
||||
internet_available: true
|
||||
log_tokenization: false
|
||||
Web Server:
|
||||
host: 127.0.0.1
|
||||
port: 8081
|
||||
allow_origins: []
|
||||
allow_credentials: true
|
||||
allow_methods:
|
||||
- '*'
|
||||
allow_headers:
|
||||
- '*'
|
||||
|
||||
The default name of the configuration file is `invokeai.yaml`, located
|
||||
in INVOKEAI_ROOT. You can replace supersede this by providing any
|
||||
OmegaConf dictionary object initialization time:
|
||||
|
||||
omegaconf = OmegaConf.load('/tmp/init.yaml')
|
||||
conf = InvokeAIAppConfig()
|
||||
conf.parse_args(conf=omegaconf)
|
||||
|
||||
InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv`
|
||||
at initialization time. You may pass a list of strings in the optional
|
||||
`argv` argument to use instead of the system argv:
|
||||
|
||||
conf.parse_args(argv=['--xformers_enabled'])
|
||||
|
||||
It is also possible to set a value at initialization time. However, if
|
||||
you call parse_args() it may be overwritten.
|
||||
|
||||
conf = InvokeAIAppConfig(xformers_enabled=True)
|
||||
conf.parse_args(argv=['--no-xformers'])
|
||||
conf.xformers_enabled
|
||||
# False
|
||||
|
||||
|
||||
To avoid this, use `get_config()` to retrieve the application-wide
|
||||
configuration object. This will retain any properties set at object
|
||||
creation time:
|
||||
|
||||
conf = InvokeAIAppConfig.get_config(xformers_enabled=True)
|
||||
conf.parse_args(argv=['--no-xformers'])
|
||||
conf.xformers_enabled
|
||||
# True
|
||||
|
||||
Any setting can be overwritten by setting an environment variable of
|
||||
form: "INVOKEAI_<setting>", as in:
|
||||
|
||||
export INVOKEAI_port=8080
|
||||
|
||||
Order of precedence (from highest):
|
||||
1) initialization options
|
||||
2) command line options
|
||||
3) environment variable options
|
||||
4) config file options
|
||||
5) pydantic defaults
|
||||
|
||||
Typical usage at the top level file:
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
# get global configuration and print its nsfw_checker value
|
||||
conf = InvokeAIAppConfig.get_config()
|
||||
conf.parse_args()
|
||||
print(conf.nsfw_checker)
|
||||
|
||||
Typical usage in a backend module:
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
# get global configuration and print its nsfw_checker value
|
||||
conf = InvokeAIAppConfig.get_config()
|
||||
print(conf.nsfw_checker)
|
||||
|
||||
|
||||
Computed properties:
|
||||
|
||||
The InvokeAIAppConfig object has a series of properties that
|
||||
resolve paths relative to the runtime root directory. They each return
|
||||
a Path object:
|
||||
|
||||
root_path - path to InvokeAI root
|
||||
output_path - path to default outputs directory
|
||||
model_conf_path - path to models.yaml
|
||||
conf - alias for the above
|
||||
embedding_path - path to the embeddings directory
|
||||
lora_path - path to the LoRA directory
|
||||
|
||||
In most cases, you will want to create a single InvokeAIAppConfig
|
||||
object for the entire application. The InvokeAIAppConfig.get_config() function
|
||||
does this:
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args() # read values from the command line/config file
|
||||
print(config.root)
|
||||
|
||||
# Subclassing
|
||||
|
||||
If you wish to create a similar class, please subclass the
|
||||
`InvokeAISettings` class and define a Literal field named "type",
|
||||
which is set to the desired top-level name. For example, to create a
|
||||
"InvokeBatch" configuration, define like this:
|
||||
|
||||
class InvokeBatch(InvokeAISettings):
|
||||
type: Literal["InvokeBatch"] = "InvokeBatch"
|
||||
node_count : int = Field(default=1, description="Number of nodes to run on", category='Resources')
|
||||
cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", category='Resources')
|
||||
|
||||
This will now read and write from the "InvokeBatch" section of the
|
||||
config file, look for environment variables named INVOKEBATCH_*, and
|
||||
accept the command-line arguments `--node_count` and `--cpu_count`. The
|
||||
two configs are kept in separate sections of the config file:
|
||||
|
||||
# invokeai.yaml
|
||||
|
||||
InvokeBatch:
|
||||
Resources:
|
||||
node_count: 1
|
||||
cpu_count: 8
|
||||
|
||||
InvokeAI:
|
||||
Paths:
|
||||
root: /home/lstein/invokeai-main
|
||||
conf_path: configs/models.yaml
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
outdir: outputs
|
||||
...
|
||||
|
||||
'''
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import pydoc
|
||||
import os
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
from pathlib import Path
|
||||
from pydantic import BaseSettings, Field, parse_obj_as
|
||||
from typing import ClassVar, Dict, List, Literal, Type, Union, get_origin, get_type_hints, get_args
|
||||
|
||||
INIT_FILE = Path('invokeai.yaml')
|
||||
DB_FILE = Path('invokeai.db')
|
||||
LEGACY_INIT_FILE = Path('invokeai.init')
|
||||
|
||||
class InvokeAISettings(BaseSettings):
|
||||
'''
|
||||
Runtime configuration settings in which default values are
|
||||
read from an omegaconf .yaml file.
|
||||
'''
|
||||
initconf : ClassVar[DictConfig] = None
|
||||
argparse_groups : ClassVar[Dict] = {}
|
||||
|
||||
def parse_args(self, argv: list=sys.argv[1:]):
|
||||
parser = self.get_parser()
|
||||
opt = parser.parse_args(argv)
|
||||
for name in self.__fields__:
|
||||
if name not in self._excluded():
|
||||
setattr(self, name, getattr(opt,name))
|
||||
|
||||
def to_yaml(self)->str:
|
||||
"""
|
||||
Return a YAML string representing our settings. This can be used
|
||||
as the contents of `invokeai.yaml` to restore settings later.
|
||||
"""
|
||||
cls = self.__class__
|
||||
type = get_args(get_type_hints(cls)['type'])[0]
|
||||
field_dict = dict({type:dict()})
|
||||
for name,field in self.__fields__.items():
|
||||
if name in cls._excluded():
|
||||
continue
|
||||
category = field.field_info.extra.get("category") or "Uncategorized"
|
||||
value = getattr(self,name)
|
||||
if category not in field_dict[type]:
|
||||
field_dict[type][category] = dict()
|
||||
# keep paths as strings to make it easier to read
|
||||
field_dict[type][category][name] = str(value) if isinstance(value,Path) else value
|
||||
conf = OmegaConf.create(field_dict)
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser):
|
||||
if 'type' in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)['type'])[0]
|
||||
else:
|
||||
settings_stanza = "Uncategorized"
|
||||
|
||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper()
|
||||
|
||||
initconf = cls.initconf.get(settings_stanza) \
|
||||
if cls.initconf and settings_stanza in cls.initconf \
|
||||
else OmegaConf.create()
|
||||
|
||||
# create an upcase version of the environment in
|
||||
# order to achieve case-insensitive environment
|
||||
# variables (the way Windows does)
|
||||
upcase_environ = dict()
|
||||
for key,value in os.environ.items():
|
||||
upcase_environ[key.upper()] = value
|
||||
|
||||
fields = cls.__fields__
|
||||
cls.argparse_groups = {}
|
||||
|
||||
for name, field in fields.items():
|
||||
if name not in cls._excluded():
|
||||
current_default = field.default
|
||||
|
||||
category = field.field_info.extra.get("category","Uncategorized")
|
||||
env_name = env_prefix + '_' + name
|
||||
if category in initconf and name in initconf.get(category):
|
||||
field.default = initconf.get(category).get(name)
|
||||
if env_name.upper() in upcase_environ:
|
||||
field.default = upcase_environ[env_name.upper()]
|
||||
cls.add_field_argument(parser, name, field)
|
||||
|
||||
field.default = current_default
|
||||
|
||||
@classmethod
|
||||
def cmd_name(self, command_field: str='type')->str:
|
||||
hints = get_type_hints(self)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
else:
|
||||
return 'Uncategorized'
|
||||
|
||||
@classmethod
|
||||
def get_parser(cls)->ArgumentParser:
|
||||
parser = PagingArgumentParser(
|
||||
prog=cls.cmd_name(),
|
||||
description=cls.__doc__,
|
||||
)
|
||||
cls.add_parser_arguments(parser)
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def add_subparser(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
||||
|
||||
@classmethod
|
||||
def _excluded(self)->List[str]:
|
||||
return ['type','initconf']
|
||||
|
||||
class Config:
|
||||
env_file_encoding = 'utf-8'
|
||||
arbitrary_types_allowed = True
|
||||
case_sensitive = True
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override = None):
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
if category := field.field_info.extra.get("category"):
|
||||
if category not in cls.argparse_groups:
|
||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||
argparse_group = cls.argparse_groups[category]
|
||||
else:
|
||||
argparse_group = command_parser
|
||||
|
||||
if get_origin(field_type) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
elif get_origin(field_type) == list:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
nargs='*',
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
|
||||
help=field.field_info.description,
|
||||
)
|
||||
def _find_root()->Path:
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
|
||||
elif (
|
||||
os.environ.get("VIRTUAL_ENV")
|
||||
and (Path(os.environ.get("VIRTUAL_ENV"), "..", INIT_FILE).exists()
|
||||
or
|
||||
Path(os.environ.get("VIRTUAL_ENV"), "..", LEGACY_INIT_FILE).exists()
|
||||
)
|
||||
):
|
||||
root = Path(os.environ.get("VIRTUAL_ENV"), "..").resolve()
|
||||
else:
|
||||
root = Path("~/invokeai").expanduser().resolve()
|
||||
return root
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
'''
|
||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
||||
the command-line client (recommended for experts only), or
|
||||
"invokeai-web" to launch the web server. Global options
|
||||
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
|
||||
setting environment variables INVOKEAI_<setting>.
|
||||
'''
|
||||
singleton_config: ClassVar[InvokeAIAppConfig] = None
|
||||
singleton_init: ClassVar[Dict] = None
|
||||
|
||||
#fmt: off
|
||||
type: Literal["InvokeAI"] = "InvokeAI"
|
||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
||||
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server')
|
||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Web Server')
|
||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server')
|
||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server')
|
||||
|
||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
|
||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
||||
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
|
||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
||||
restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features')
|
||||
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||
max_loaded_models : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
|
||||
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||
|
||||
|
||||
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
||||
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
|
||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
||||
embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI textual inversion aembeddings directory', category='Paths')
|
||||
gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.', category='Paths')
|
||||
controlnet_dir : Path = Field(default="controlnets", description='Path to directory of ControlNet models.', category='Paths')
|
||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
||||
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
|
||||
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
||||
|
||||
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
||||
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
|
||||
|
||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||
log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
||||
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="debug", description="Emit logging messages at this level or higher", category="Logging")
|
||||
#fmt: on
|
||||
|
||||
def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
|
||||
'''
|
||||
Update settings with contents of init file, environment, and
|
||||
command-line settings.
|
||||
:param conf: alternate Omegaconf dictionary object
|
||||
:param argv: aternate sys.argv list
|
||||
:param clobber: ovewrite any initialization parameters passed during initialization
|
||||
'''
|
||||
# Set the runtime root directory. We parse command-line switches here
|
||||
# in order to pick up the --root_dir option.
|
||||
super().parse_args(argv)
|
||||
if conf is None:
|
||||
try:
|
||||
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
||||
except:
|
||||
pass
|
||||
InvokeAISettings.initconf = conf
|
||||
|
||||
# parse args again in order to pick up settings in configuration file
|
||||
super().parse_args(argv)
|
||||
|
||||
if self.singleton_init and not clobber:
|
||||
hints = get_type_hints(self.__class__)
|
||||
for k in self.singleton_init:
|
||||
setattr(self,k,parse_obj_as(hints[k],self.singleton_init[k]))
|
||||
|
||||
@classmethod
|
||||
def get_config(cls,**kwargs)->InvokeAIAppConfig:
|
||||
'''
|
||||
This returns a singleton InvokeAIAppConfig configuration object.
|
||||
'''
|
||||
if cls.singleton_config is None \
|
||||
or type(cls.singleton_config)!=cls \
|
||||
or (kwargs and cls.singleton_init != kwargs):
|
||||
cls.singleton_config = cls(**kwargs)
|
||||
cls.singleton_init = kwargs
|
||||
return cls.singleton_config
|
||||
|
||||
@property
|
||||
def root_path(self)->Path:
|
||||
'''
|
||||
Path to the runtime root directory
|
||||
'''
|
||||
if self.root:
|
||||
return Path(self.root).expanduser()
|
||||
else:
|
||||
return self.find_root()
|
||||
|
||||
@property
|
||||
def root_dir(self)->Path:
|
||||
'''
|
||||
Alias for above.
|
||||
'''
|
||||
return self.root_path
|
||||
|
||||
def _resolve(self,partial_path:Path)->Path:
|
||||
return (self.root_path / partial_path).resolve()
|
||||
|
||||
@property
|
||||
def init_file_path(self)->Path:
|
||||
'''
|
||||
Path to invokeai.yaml
|
||||
'''
|
||||
return self._resolve(INIT_FILE)
|
||||
|
||||
@property
|
||||
def output_path(self)->Path:
|
||||
'''
|
||||
Path to defaults outputs directory.
|
||||
'''
|
||||
return self._resolve(self.outdir)
|
||||
|
||||
@property
|
||||
def db_path(self)->Path:
|
||||
'''
|
||||
Path to the invokeai.db file.
|
||||
'''
|
||||
return self._resolve(self.db_dir) / DB_FILE
|
||||
|
||||
@property
|
||||
def model_conf_path(self)->Path:
|
||||
'''
|
||||
Path to models configuration file.
|
||||
'''
|
||||
return self._resolve(self.conf_path)
|
||||
|
||||
@property
|
||||
def legacy_conf_path(self)->Path:
|
||||
'''
|
||||
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
|
||||
'''
|
||||
return self._resolve(self.legacy_conf_dir)
|
||||
|
||||
@property
|
||||
def cache_dir(self)->Path:
|
||||
'''
|
||||
Path to the global cache directory for HuggingFace hub-managed models
|
||||
'''
|
||||
return self.models_dir / "hub"
|
||||
|
||||
@property
|
||||
def models_dir(self)->Path:
|
||||
'''
|
||||
Path to the models directory
|
||||
'''
|
||||
return self._resolve("models")
|
||||
|
||||
@property
|
||||
def embedding_path(self)->Path:
|
||||
'''
|
||||
Path to the textual inversion embeddings directory.
|
||||
'''
|
||||
return self._resolve(self.embedding_dir) if self.embedding_dir else None
|
||||
|
||||
@property
|
||||
def lora_path(self)->Path:
|
||||
'''
|
||||
Path to the LoRA models directory.
|
||||
'''
|
||||
return self._resolve(self.lora_dir) if self.lora_dir else None
|
||||
|
||||
@property
|
||||
def controlnet_path(self)->Path:
|
||||
'''
|
||||
Path to the controlnet models directory.
|
||||
'''
|
||||
return self._resolve(self.controlnet_dir) if self.controlnet_dir else None
|
||||
|
||||
@property
|
||||
def autoconvert_path(self)->Path:
|
||||
'''
|
||||
Path to the directory containing models to be imported automatically at startup.
|
||||
'''
|
||||
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
||||
|
||||
@property
|
||||
def gfpgan_model_path(self)->Path:
|
||||
'''
|
||||
Path to the GFPGAN model.
|
||||
'''
|
||||
return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None
|
||||
|
||||
# the following methods support legacy calls leftover from the Globals era
|
||||
@property
|
||||
def full_precision(self)->bool:
|
||||
"""Return true if precision set to float32"""
|
||||
return self.precision=='float32'
|
||||
|
||||
@property
|
||||
def disable_xformers(self)->bool:
|
||||
"""Return true if xformers_enabled is false"""
|
||||
return not self.xformers_enabled
|
||||
|
||||
@property
|
||||
def try_patchmatch(self)->bool:
|
||||
"""Return true if patchmatch true"""
|
||||
return self.patchmatch
|
||||
|
||||
@staticmethod
|
||||
def find_root()->Path:
|
||||
'''
|
||||
Choose the runtime root directory when not specified on command line or
|
||||
init file.
|
||||
'''
|
||||
return _find_root()
|
||||
|
||||
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
'''
|
||||
A custom ArgumentParser that uses pydoc to page its output.
|
||||
It also supports reading defaults from an init file.
|
||||
'''
|
||||
def print_help(self, file=None):
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
def get_invokeai_config(**kwargs)->InvokeAIAppConfig:
|
||||
'''
|
||||
Legacy function which returns InvokeAIAppConfig.get_config()
|
||||
'''
|
||||
return InvokeAIAppConfig.get_config(**kwargs)
|
||||
@@ -1,64 +0,0 @@
|
||||
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
|
||||
from ..invocations.compel import CompelInvocation
|
||||
from ..invocations.params import ParamIntInvocation
|
||||
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
||||
from .item_storage import ItemStorageABC
|
||||
|
||||
|
||||
default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
|
||||
|
||||
|
||||
def create_text_to_image() -> LibraryGraph:
|
||||
return LibraryGraph(
|
||||
id=default_text_to_image_graph_id,
|
||||
name='t2i',
|
||||
description='Converts text to an image',
|
||||
graph=Graph(
|
||||
nodes={
|
||||
'width': ParamIntInvocation(id='width', a=512),
|
||||
'height': ParamIntInvocation(id='height', a=512),
|
||||
'seed': ParamIntInvocation(id='seed', a=-1),
|
||||
'3': NoiseInvocation(id='3'),
|
||||
'4': CompelInvocation(id='4'),
|
||||
'5': CompelInvocation(id='5'),
|
||||
'6': TextToLatentsInvocation(id='6'),
|
||||
'7': LatentsToImageInvocation(id='7'),
|
||||
},
|
||||
edges=[
|
||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
|
||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
|
||||
Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
|
||||
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
|
||||
Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
|
||||
Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
|
||||
Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
|
||||
]
|
||||
),
|
||||
exposed_inputs=[
|
||||
ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
|
||||
ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
|
||||
ExposedNodeInput(node_path='width', field='a', alias='width'),
|
||||
ExposedNodeInput(node_path='height', field='a', alias='height'),
|
||||
ExposedNodeInput(node_path='seed', field='a', alias='seed'),
|
||||
],
|
||||
exposed_outputs=[
|
||||
ExposedNodeOutput(node_path='7', field='image', alias='image')
|
||||
])
|
||||
|
||||
|
||||
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
||||
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||
|
||||
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
|
||||
graphs: list[LibraryGraph] = list()
|
||||
|
||||
# text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||
|
||||
# # TODO: Check if the graph is the same as the default one, and if not, update it
|
||||
# #if text_to_image is None:
|
||||
text_to_image = create_text_to_image()
|
||||
graph_library.set(text_to_image)
|
||||
|
||||
graphs.append(text_to_image)
|
||||
|
||||
return graphs
|
||||
@@ -1,9 +1,10 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from typing import Any, Dict, TypedDict
|
||||
|
||||
ProgressImage = TypedDict(
|
||||
"ProgressImage", {"dataURL": str, "width": int, "height": int}
|
||||
)
|
||||
|
||||
class EventServiceBase:
|
||||
session_event: str = "session_event"
|
||||
@@ -13,8 +14,7 @@ class EventServiceBase:
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
pass
|
||||
|
||||
def __emit_session_event(self, event_name: str, payload: dict) -> None:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
def __emit_session_event(self, event_name: str, payload: Dict) -> None:
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.session_event,
|
||||
payload=dict(event=event_name, data=payload),
|
||||
@@ -25,8 +25,7 @@ class EventServiceBase:
|
||||
def emit_generator_progress(
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
invocation_id: str,
|
||||
progress_image: ProgressImage | None,
|
||||
step: int,
|
||||
total_steps: int,
|
||||
@@ -36,60 +35,48 @@ class EventServiceBase:
|
||||
event_name="generator_progress",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
progress_image=progress_image.dict() if progress_image is not None else None,
|
||||
invocation_id=invocation_id,
|
||||
progress_image=progress_image,
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_complete(
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
result: dict,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
self, graph_execution_state_id: str, invocation_id: str, result: Dict
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
event_name="invocation_complete",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
invocation_id=invocation_id,
|
||||
result=result,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_error(
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
error: str,
|
||||
self, graph_execution_state_id: str, invocation_id: str, error: str
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
event_name="invocation_error",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
invocation_id=invocation_id,
|
||||
error=error,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_started(
|
||||
self, graph_execution_state_id: str, node: dict, source_node_id: str
|
||||
self, graph_execution_state_id: str, invocation_id: str
|
||||
) -> None:
|
||||
"""Emitted when an invocation has started"""
|
||||
self.__emit_session_event(
|
||||
event_name="invocation_started",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
invocation_id=invocation_id,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -97,7 +84,5 @@ class EventServiceBase:
|
||||
"""Emitted when a session has completed all invocations"""
|
||||
self.__emit_session_event(
|
||||
event_name="graph_execution_state_complete",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
),
|
||||
payload=dict(graph_execution_state_id=graph_execution_state_id),
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import traceback
|
||||
import uuid
|
||||
from types import NoneType
|
||||
from typing import (
|
||||
@@ -16,7 +17,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, root_validator, validator
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic.fields import Field
|
||||
|
||||
from ..invocations import *
|
||||
@@ -25,6 +26,7 @@ from ..invocations.baseinvocation import (
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
)
|
||||
from .invocation_services import InvocationServices
|
||||
|
||||
|
||||
class EdgeConnection(BaseModel):
|
||||
@@ -60,33 +62,6 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
|
||||
node_input_field = node_inputs.get(field) or None
|
||||
return node_input_field
|
||||
|
||||
from typing import Optional, Union, List, get_args
|
||||
|
||||
def is_union_subtype(t1, t2):
|
||||
t1_args = get_args(t1)
|
||||
t2_args = get_args(t2)
|
||||
if not t1_args:
|
||||
# t1 is a single type
|
||||
return t1 in t2_args
|
||||
else:
|
||||
# t1 is a Union, check that all of its types are in t2_args
|
||||
return all(arg in t2_args for arg in t1_args)
|
||||
|
||||
def is_list_or_contains_list(t):
|
||||
t_args = get_args(t)
|
||||
|
||||
# If the type is a List
|
||||
if get_origin(t) is list:
|
||||
return True
|
||||
|
||||
# If the type is a Union
|
||||
elif t_args:
|
||||
# Check if any of the types in the Union is a List
|
||||
for arg in t_args:
|
||||
if get_origin(arg) is list:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
||||
if not from_type:
|
||||
@@ -112,8 +87,7 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
||||
if to_type in get_args(from_type):
|
||||
return True
|
||||
|
||||
# if not issubclass(from_type, to_type):
|
||||
if not is_union_subtype(from_type, to_type):
|
||||
if not issubclass(from_type, to_type):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
@@ -153,17 +127,9 @@ class NodeAlreadyExecutedError(Exception):
|
||||
class GraphInvocationOutput(BaseInvocationOutput):
|
||||
type: Literal["graph_output"] = "graph_output"
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'image',
|
||||
]
|
||||
}
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
class GraphInvocation(BaseInvocation):
|
||||
"""Execute a graph"""
|
||||
type: Literal["graph"] = "graph"
|
||||
|
||||
# TODO: figure out how to create a default here
|
||||
@@ -181,17 +147,9 @@ class IterateInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
item: Any = Field(description="The item being iterated over")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'item',
|
||||
]
|
||||
}
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
class IterateInvocation(BaseInvocation):
|
||||
"""Iterates over a list of items"""
|
||||
type: Literal["iterate"] = "iterate"
|
||||
|
||||
collection: list[Any] = Field(
|
||||
@@ -211,13 +169,6 @@ class CollectInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
collection: list[Any] = Field(description="The collection of input items")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'collection',
|
||||
]
|
||||
}
|
||||
|
||||
class CollectInvocation(BaseInvocation):
|
||||
"""Collects values into a collection"""
|
||||
@@ -243,7 +194,7 @@ InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
|
||||
id: str = Field(description="The id of this graph", default_factory=uuid.uuid4)
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The nodes in this graph", default_factory=dict
|
||||
@@ -311,8 +262,7 @@ class Graph(BaseModel):
|
||||
:raises InvalidEdgeError: the provided edge is invalid.
|
||||
"""
|
||||
|
||||
self._validate_edge(edge)
|
||||
if edge not in self.edges:
|
||||
if self._is_edge_valid(edge) and edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
else:
|
||||
raise InvalidEdgeError()
|
||||
@@ -383,7 +333,7 @@ class Graph(BaseModel):
|
||||
|
||||
return True
|
||||
|
||||
def _validate_edge(self, edge: Edge):
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||
|
||||
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
||||
@@ -391,53 +341,54 @@ class Graph(BaseModel):
|
||||
from_node = self.get_node(edge.source.node_id)
|
||||
to_node = self.get_node(edge.destination.node_id)
|
||||
except NodeNotFoundError:
|
||||
raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}")
|
||||
return False
|
||||
|
||||
# Validate that an edge to this node+field doesn't already exist
|
||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
||||
raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
|
||||
return False
|
||||
|
||||
# Validate that no cycles would be created
|
||||
g = self.nx_graph_flat()
|
||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}')
|
||||
return False
|
||||
|
||||
# Validate that the field types are compatible
|
||||
if not are_connections_compatible(
|
||||
from_node, edge.source.field, to_node, edge.destination.field
|
||||
):
|
||||
raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
return False
|
||||
|
||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
return False
|
||||
|
||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
return False
|
||||
|
||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
return False
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def has_node(self, node_path: str) -> bool:
|
||||
"""Determines whether or not a node exists in the graph."""
|
||||
@@ -722,11 +673,7 @@ class Graph(BaseModel):
|
||||
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
|
||||
|
||||
# Verify that all outputs are lists
|
||||
# if not all((get_origin(f) == list for f in output_fields)):
|
||||
# return False
|
||||
|
||||
# Verify that all outputs are lists
|
||||
if not all(is_list_or_contains_list(f) for f in output_fields):
|
||||
if not all((get_origin(f) == list for f in output_fields)):
|
||||
return False
|
||||
|
||||
# Verify that all outputs match the input type (are a base class or the same class)
|
||||
@@ -745,13 +692,6 @@ class Graph(BaseModel):
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
return g
|
||||
|
||||
def nx_graph_with_data(self) -> nx.DiGraph:
|
||||
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
|
||||
g = nx.DiGraph()
|
||||
g.add_nodes_from([n for n in self.nodes.items()])
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
return g
|
||||
|
||||
def nx_graph_flat(
|
||||
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
|
||||
) -> nx.DiGraph:
|
||||
@@ -772,7 +712,7 @@ class Graph(BaseModel):
|
||||
for sgn in (
|
||||
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
|
||||
):
|
||||
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
@@ -789,7 +729,9 @@ class Graph(BaseModel):
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Tracks the state of a graph execution"""
|
||||
|
||||
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
|
||||
id: str = Field(
|
||||
description="The id of the execution state", default_factory=uuid.uuid4
|
||||
)
|
||||
|
||||
# TODO: Store a reference to the graph instead of the actual graph?
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
@@ -831,6 +773,9 @@ class GraphExecutionState(BaseModel):
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
# Declare all fields as required; necessary for OpenAPI schema generation build.
|
||||
# Technically only fields without a `default_factory` need to be listed here.
|
||||
# See: https://github.com/pydantic/pydantic/discussions/4577
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
@@ -857,9 +802,11 @@ class GraphExecutionState(BaseModel):
|
||||
if next_node is None:
|
||||
prepared_id = self._prepare()
|
||||
|
||||
# Prepare as many nodes as we can
|
||||
while prepared_id is not None:
|
||||
prepared_id = self._prepare()
|
||||
# TODO: prepare multiple nodes at once?
|
||||
# while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation):
|
||||
# prepared_id = self._prepare()
|
||||
|
||||
if prepared_id is not None:
|
||||
next_node = self._get_next_node()
|
||||
|
||||
# Get values from edges
|
||||
@@ -893,8 +840,7 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Returns true if the graph is complete"""
|
||||
node_ids = set(self.graph.nx_graph_flat().nodes)
|
||||
return self.has_error() or all((k in self.executed for k in node_ids))
|
||||
return self.has_error() or all((k in self.executed for k in self.graph.nodes))
|
||||
|
||||
def has_error(self) -> bool:
|
||||
"""Returns true if the graph has any errors"""
|
||||
@@ -982,11 +928,11 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
def _iterator_graph(self) -> nx.DiGraph:
|
||||
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
|
||||
g = self.graph.nx_graph_flat()
|
||||
g = self.graph.nx_graph()
|
||||
collectors = (
|
||||
n
|
||||
for n in self.graph.nodes
|
||||
if isinstance(self.graph.get_node(n), CollectInvocation)
|
||||
if isinstance(self.graph.nodes[n], CollectInvocation)
|
||||
)
|
||||
for c in collectors:
|
||||
g.remove_edges_from(list(g.in_edges(c)))
|
||||
@@ -998,7 +944,7 @@ class GraphExecutionState(BaseModel):
|
||||
iterators = [
|
||||
n
|
||||
for n in nx.ancestors(g, node_id)
|
||||
if isinstance(self.graph.get_node(n), IterateInvocation)
|
||||
if isinstance(self.graph.nodes[n], IterateInvocation)
|
||||
]
|
||||
return iterators
|
||||
|
||||
@@ -1006,30 +952,14 @@ class GraphExecutionState(BaseModel):
|
||||
# Get flattened source graph
|
||||
g = self.graph.nx_graph_flat()
|
||||
|
||||
# Find next node that:
|
||||
# - was not already prepared
|
||||
# - is not an iterate node whose inputs have not been executed
|
||||
# - does not have an unexecuted iterate ancestor
|
||||
# Find next unprepared node where all source nodes are executed
|
||||
sorted_nodes = nx.topological_sort(g)
|
||||
next_node_id = next(
|
||||
(
|
||||
n
|
||||
for n in sorted_nodes
|
||||
# exclude nodes that have already been prepared
|
||||
if n not in self.source_prepared_mapping
|
||||
# exclude iterate nodes whose inputs have not been executed
|
||||
and not (
|
||||
isinstance(self.graph.get_node(n), IterateInvocation) # `n` is an iterate node...
|
||||
and not all((e[0] in self.executed for e in g.in_edges(n))) # ...that has unexecuted inputs
|
||||
)
|
||||
# exclude nodes who have unexecuted iterate ancestors
|
||||
and not any(
|
||||
(
|
||||
isinstance(self.graph.get_node(a), IterateInvocation) # `a` is an iterate ancestor of `n`...
|
||||
and a not in self.executed # ...that is not executed
|
||||
for a in nx.ancestors(g, n) # for all ancestors `a` of node `n`
|
||||
)
|
||||
)
|
||||
and all((e[0] in self.executed for e in g.in_edges(n)))
|
||||
),
|
||||
None,
|
||||
)
|
||||
@@ -1118,30 +1048,18 @@ class GraphExecutionState(BaseModel):
|
||||
n
|
||||
for n in prepared_nodes
|
||||
if all(
|
||||
nx.has_path(execution_graph, pit[0], n)
|
||||
pit
|
||||
for pit in parent_iterators
|
||||
if nx.has_path(execution_graph, pit[0], n)
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
def _get_next_node(self) -> Optional[BaseInvocation]:
|
||||
"""Gets the deepest node that is ready to be executed"""
|
||||
g = self.execution_graph.nx_graph()
|
||||
|
||||
# Depth-first search with pre-order traversal is a depth-first topological sort
|
||||
sorted_nodes = nx.dfs_preorder_nodes(g)
|
||||
|
||||
next_node = next(
|
||||
(
|
||||
n
|
||||
for n in sorted_nodes
|
||||
if n not in self.executed # the node must not already be executed...
|
||||
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
sorted_nodes = nx.topological_sort(g)
|
||||
next_node = next((n for n in sorted_nodes if n not in self.executed), None)
|
||||
if next_node is None:
|
||||
return None
|
||||
|
||||
@@ -1163,9 +1081,7 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
try:
|
||||
self.graph._validate_edge(edge)
|
||||
except InvalidEdgeError:
|
||||
if not self._is_edge_valid(edge):
|
||||
return False
|
||||
|
||||
# Invalid if destination has already been prepared or executed
|
||||
@@ -1211,52 +1127,4 @@ class GraphExecutionState(BaseModel):
|
||||
self.graph.delete_edge(edge)
|
||||
|
||||
|
||||
class ExposedNodeInput(BaseModel):
|
||||
node_path: str = Field(description="The node path to the node with the input")
|
||||
field: str = Field(description="The field name of the input")
|
||||
alias: str = Field(description="The alias of the input")
|
||||
|
||||
|
||||
class ExposedNodeOutput(BaseModel):
|
||||
node_path: str = Field(description="The node path to the node with the output")
|
||||
field: str = Field(description="The field name of the output")
|
||||
alias: str = Field(description="The alias of the output")
|
||||
|
||||
class LibraryGraph(BaseModel):
|
||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
|
||||
graph: Graph = Field(description="The graph")
|
||||
name: str = Field(description="The name of the graph")
|
||||
description: str = Field(description="The description of the graph")
|
||||
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
|
||||
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
|
||||
|
||||
@validator('exposed_inputs', 'exposed_outputs')
|
||||
def validate_exposed_aliases(cls, v):
|
||||
if len(v) != len(set(i.alias for i in v)):
|
||||
raise ValueError("Duplicate exposed alias")
|
||||
return v
|
||||
|
||||
@root_validator
|
||||
def validate_exposed_nodes(cls, values):
|
||||
graph = values['graph']
|
||||
|
||||
# Validate exposed inputs
|
||||
for exposed_input in values['exposed_inputs']:
|
||||
if not graph.has_node(exposed_input.node_path):
|
||||
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
||||
node = graph.get_node(exposed_input.node_path)
|
||||
if get_input_field(node, exposed_input.field) is None:
|
||||
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
|
||||
|
||||
# Validate exposed outputs
|
||||
for exposed_output in values['exposed_outputs']:
|
||||
if not graph.has_node(exposed_output.node_path):
|
||||
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
||||
node = graph.get_node(exposed_output.node_path)
|
||||
if get_output_field(node, exposed_output.field) is None:
|
||||
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
GraphInvocation.update_forward_refs()
|
||||
|
||||
@@ -1,204 +0,0 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
from PIL import Image, PngImagePlugin
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.models.image import ResourceOrigin
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
class ImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message="Image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message="Image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image file not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
"""Retrieves an image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the internal path to an image or thumbnail."""
|
||||
pass
|
||||
|
||||
# TODO: We need to validate paths before starlette makes the FileResponse, else we get a
|
||||
# 500 internal server error. I don't like having this method on the service.
|
||||
@abstractmethod
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates the path given for an image or thumbnail."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
"""Deletes an image and its thumbnail (if one exists)."""
|
||||
pass
|
||||
|
||||
|
||||
class DiskImageFileStorage(ImageFileStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
__output_folder: str
|
||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||
__cache: Dict[str, PILImageType]
|
||||
__max_cache_size: int
|
||||
|
||||
def __init__(self, output_folder: str):
|
||||
self.__output_folder = output_folder
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
||||
for image_origin in ResourceOrigin:
|
||||
Path(os.path.join(output_folder, image_origin)).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
try:
|
||||
image_path = self.get_path(image_origin, image_name)
|
||||
cache_item = self.__get_cache(image_path)
|
||||
if cache_item:
|
||||
return cache_item
|
||||
|
||||
image = Image.open(image_path)
|
||||
self.__set_cache(image_path, image)
|
||||
return image
|
||||
except FileNotFoundError as e:
|
||||
raise ImageFileNotFoundException from e
|
||||
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
try:
|
||||
image_path = self.get_path(image_origin, image_name)
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
pnginfo.add_text("invokeai", metadata.json())
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||
else:
|
||||
image.save(image_path, "PNG")
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
|
||||
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
||||
thumbnail_image.save(thumbnail_path)
|
||||
|
||||
self.__set_cache(image_path, image)
|
||||
self.__set_cache(thumbnail_path, thumbnail_image)
|
||||
except Exception as e:
|
||||
raise ImageFileSaveException from e
|
||||
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
try:
|
||||
basename = os.path.basename(image_name)
|
||||
image_path = self.get_path(image_origin, basename)
|
||||
|
||||
if os.path.exists(image_path):
|
||||
send2trash(image_path)
|
||||
if image_path in self.__cache:
|
||||
del self.__cache[image_path]
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
send2trash(thumbnail_path)
|
||||
if thumbnail_path in self.__cache:
|
||||
del self.__cache[thumbnail_path]
|
||||
except Exception as e:
|
||||
raise ImageFileDeleteException from e
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
# strip out any relative path shenanigans
|
||||
basename = os.path.basename(image_name)
|
||||
|
||||
if thumbnail:
|
||||
thumbnail_name = get_thumbnail_name(basename)
|
||||
path = os.path.join(
|
||||
self.__output_folder, image_origin, "thumbnails", thumbnail_name
|
||||
)
|
||||
else:
|
||||
path = os.path.join(self.__output_folder, image_origin, basename)
|
||||
|
||||
abspath = os.path.abspath(path)
|
||||
|
||||
return abspath
|
||||
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates the path given for an image or thumbnail."""
|
||||
try:
|
||||
os.stat(path)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def __get_cache(self, image_name: str) -> PILImageType | None:
|
||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||
|
||||
def __set_cache(self, image_name: str, image: PILImageType):
|
||||
if not image_name in self.__cache:
|
||||
self.__cache[image_name] = image
|
||||
self.__cache_ids.put(
|
||||
image_name
|
||||
) # TODO: this should refresh position for LRU cache
|
||||
if len(self.__cache) > self.__max_cache_size:
|
||||
cache_id = self.__cache_ids.get()
|
||||
if cache_id in self.__cache:
|
||||
del self.__cache[cache_id]
|
||||
@@ -1,419 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Generic, Optional, TypeVar, cast
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
deserialize_image_record,
|
||||
)
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
class OffsetPaginatedResults(GenericModel, Generic[T]):
|
||||
"""Offset-paginated results"""
|
||||
|
||||
# fmt: off
|
||||
items: list[T] = Field(description="Items")
|
||||
offset: int = Field(description="Offset from which to retrieve items")
|
||||
limit: int = Field(description="Limit of items to get")
|
||||
total: int = Field(description="Total number of items in result")
|
||||
# fmt: on
|
||||
|
||||
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
class ImageRecordNotFoundException(Exception):
|
||||
"""Raised when an image record is not found."""
|
||||
|
||||
def __init__(self, message="Image record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageRecordSaveException(Exception):
|
||||
"""Raised when an image record cannot be saved."""
|
||||
|
||||
def __init__(self, message="Image record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageRecordDeleteException(Exception):
|
||||
"""Raised when an image record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageRecordStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the image record store."""
|
||||
|
||||
# TODO: Implement an `update()` method
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
"""Updates an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets a page of image records."""
|
||||
pass
|
||||
|
||||
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||
@abstractmethod
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
"""Deletes an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
width: int,
|
||||
height: int,
|
||||
session_id: Optional[str],
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[ImageMetadata],
|
||||
is_intermediate: bool = False,
|
||||
) -> datetime:
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the tables for the `images` database."""
|
||||
|
||||
# Create the `images` table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS images (
|
||||
image_name TEXT NOT NULL PRIMARY KEY,
|
||||
-- This is an enum in python, unrestricted string here for flexibility
|
||||
image_origin TEXT NOT NULL,
|
||||
-- This is an enum in python, unrestricted string here for flexibility
|
||||
image_category TEXT NOT NULL,
|
||||
width INTEGER NOT NULL,
|
||||
height INTEGER NOT NULL,
|
||||
session_id TEXT,
|
||||
node_id TEXT,
|
||||
metadata TEXT,
|
||||
is_intermediate BOOLEAN DEFAULT FALSE,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the `images` table indices.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
|
||||
AFTER UPDATE
|
||||
ON images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE images SET updated_at = current_timestamp
|
||||
WHERE image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def get(
|
||||
self, image_origin: ResourceOrigin, image_name: str
|
||||
) -> Union[ImageRecord, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT * FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
raise ImageRecordNotFoundException
|
||||
|
||||
return deserialize_image_record(dict(result))
|
||||
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Change the category of the image
|
||||
if changes.image_category is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE images
|
||||
SET image_category = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.image_category, image_name),
|
||||
)
|
||||
|
||||
# Change the session associated with the image
|
||||
if changes.session_id is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE images
|
||||
SET session_id = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.session_id, image_name),
|
||||
)
|
||||
|
||||
# Change the image's `is_intermediate`` flag
|
||||
if changes.is_intermediate is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE images
|
||||
SET is_intermediate = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
|
||||
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
|
||||
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += f"""AND image_origin = ?\n"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
## Convert the enum values to unique list of strings
|
||||
category_strings = list(
|
||||
map(lambda c: c.value, set(categories))
|
||||
)
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"AND image_category IN ( {placeholders} )\n"
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += f"""AND is_intermediate = ?\n"""
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
|
||||
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
# Add all the parameters
|
||||
images_params = query_params.copy()
|
||||
images_params.append(limit)
|
||||
images_params.append(offset)
|
||||
# Build the list of images, deserializing each row
|
||||
self._cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
self._cursor.execute(count_query, count_params)
|
||||
count = self._cursor.fetchone()[0]
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
return OffsetPaginatedResults(
|
||||
items=images, offset=offset, limit=limit, total=count
|
||||
)
|
||||
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def save(
|
||||
self,
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
session_id: Optional[str],
|
||||
width: int,
|
||||
height: int,
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[ImageMetadata],
|
||||
is_intermediate: bool = False,
|
||||
) -> datetime:
|
||||
try:
|
||||
metadata_json = (
|
||||
None if metadata is None else metadata.json(exclude_none=True)
|
||||
)
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO images (
|
||||
image_name,
|
||||
image_origin,
|
||||
image_category,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
image_origin.value,
|
||||
image_category.value,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata_json,
|
||||
is_intermediate,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT created_at
|
||||
FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
|
||||
|
||||
return created_at
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
113
invokeai/app/services/image_storage.py
Normal file
113
invokeai/app/services/image_storage.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import datetime
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict
|
||||
|
||||
from PIL.Image import Image
|
||||
|
||||
from invokeai.backend.image_util import PngWriter
|
||||
|
||||
|
||||
class ImageType(str, Enum):
|
||||
RESULT = "results"
|
||||
INTERMEDIATE = "intermediates"
|
||||
UPLOAD = "uploads"
|
||||
|
||||
|
||||
class ImageStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving images."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
pass
|
||||
|
||||
def create_name(self, context_id: str, node_id: str) -> str:
|
||||
return f"{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png"
|
||||
|
||||
|
||||
class DiskImageStorage(ImageStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
__output_folder: str
|
||||
__pngWriter: PngWriter
|
||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||
__cache: Dict[str, Image]
|
||||
__max_cache_size: int
|
||||
|
||||
def __init__(self, output_folder: str):
|
||||
self.__output_folder = output_folder
|
||||
self.__pngWriter = PngWriter(output_folder)
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
||||
for image_type in ImageType:
|
||||
Path(os.path.join(output_folder, image_type)).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
cache_item = self.__get_cache(image_path)
|
||||
if cache_item:
|
||||
return cache_item
|
||||
|
||||
image = Image.open(image_path)
|
||||
self.__set_cache(image_path, image)
|
||||
return image
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
path = os.path.join(self.__output_folder, image_type, image_name)
|
||||
return path
|
||||
|
||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
||||
image_subpath = os.path.join(image_type, image_name)
|
||||
self.__pngWriter.save_image_and_prompt_to_png(
|
||||
image, "", image_subpath, None
|
||||
) # TODO: just pass full path to png writer
|
||||
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
self.__set_cache(image_path, image)
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
if os.path.exists(image_path):
|
||||
os.remove(image_path)
|
||||
|
||||
if image_path in self.__cache:
|
||||
del self.__cache[image_path]
|
||||
|
||||
def __get_cache(self, image_name: str) -> Image:
|
||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||
|
||||
def __set_cache(self, image_name: str, image: Image):
|
||||
if not image_name in self.__cache:
|
||||
self.__cache[image_name] = image
|
||||
self.__cache_ids.put(
|
||||
image_name
|
||||
) # TODO: this should refresh position for LRU cache
|
||||
if len(self.__cache) > self.__max_cache_size:
|
||||
cache_id = self.__cache_ids.get()
|
||||
del self.__cache[cache_id]
|
||||
@@ -1,393 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ResourceOrigin,
|
||||
InvalidImageCategoryException,
|
||||
InvalidOriginException,
|
||||
)
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordDeleteException,
|
||||
ImageRecordNotFoundException,
|
||||
ImageRecordSaveException,
|
||||
ImageRecordStorageBase,
|
||||
OffsetPaginatedResults,
|
||||
)
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
ImageDTO,
|
||||
ImageRecordChanges,
|
||||
image_record_to_dto,
|
||||
)
|
||||
from invokeai.app.services.image_file_storage import (
|
||||
ImageFileDeleteException,
|
||||
ImageFileNotFoundException,
|
||||
ImageFileSaveException,
|
||||
ImageFileStorageBase,
|
||||
)
|
||||
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
from invokeai.app.services.resource_name import NameServiceBase
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.graph import GraphExecutionState
|
||||
|
||||
|
||||
class ImageServiceABC(ABC):
|
||||
"""High-level service for image management."""
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
intermediate: bool = False,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
"""Updates an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
"""Gets an image as a PIL image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
||||
"""Gets an image DTO."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
|
||||
"""Gets an image's path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates an image's path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets an image's or thumbnail's URL."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
||||
"""Deletes an image."""
|
||||
pass
|
||||
|
||||
|
||||
class ImageServiceDependencies:
|
||||
"""Service dependencies for the ImageService."""
|
||||
|
||||
records: ImageRecordStorageBase
|
||||
files: ImageFileStorageBase
|
||||
metadata: MetadataServiceBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
names: NameServiceBase
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
names: NameServiceBase,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
):
|
||||
self.records = image_record_storage
|
||||
self.files = image_file_storage
|
||||
self.metadata = metadata
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
self.names = names
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
|
||||
|
||||
class ImageService(ImageServiceABC):
|
||||
_services: ImageServiceDependencies
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
names: NameServiceBase,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
):
|
||||
self._services = ImageServiceDependencies(
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=url,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_intermediate: bool = False,
|
||||
) -> ImageDTO:
|
||||
if image_origin not in ResourceOrigin:
|
||||
raise InvalidOriginException
|
||||
|
||||
if image_category not in ImageCategory:
|
||||
raise InvalidImageCategoryException
|
||||
|
||||
image_name = self._services.names.create_image_name()
|
||||
|
||||
metadata = self._get_metadata(session_id, node_id)
|
||||
|
||||
(width, height) = image.size
|
||||
|
||||
try:
|
||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||
created_at = self._services.records.save(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
# Meta fields
|
||||
is_intermediate=is_intermediate,
|
||||
# Nullable fields
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
self._services.files.save(
|
||||
image_origin=image_origin,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
image_url = self._services.urls.get_image_url(image_origin, image_name)
|
||||
thumbnail_url = self._services.urls.get_image_url(
|
||||
image_origin, image_name, True
|
||||
)
|
||||
|
||||
return ImageDTO(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
# Nullable fields
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata,
|
||||
# Meta fields
|
||||
created_at=created_at,
|
||||
updated_at=created_at, # this is always the same as the created_at at this time
|
||||
deleted_at=None,
|
||||
is_intermediate=is_intermediate,
|
||||
# Extra non-nullable fields for DTO
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to save image record")
|
||||
raise
|
||||
except ImageFileSaveException:
|
||||
self._services.logger.error("Failed to save image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem saving image record and file")
|
||||
raise e
|
||||
|
||||
def update(
|
||||
self,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
try:
|
||||
self._services.records.update(image_name, image_origin, changes)
|
||||
return self.get_dto(image_origin, image_name)
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to update image record")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem updating image record")
|
||||
raise e
|
||||
|
||||
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
try:
|
||||
return self._services.files.get(image_origin, image_name)
|
||||
except ImageFileNotFoundException:
|
||||
self._services.logger.error("Failed to get image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image file")
|
||||
raise e
|
||||
|
||||
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
return self._services.records.get(image_origin, image_name)
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image record")
|
||||
raise e
|
||||
|
||||
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
||||
try:
|
||||
image_record = self._services.records.get(image_origin, image_name)
|
||||
|
||||
image_dto = image_record_to_dto(
|
||||
image_record,
|
||||
self._services.urls.get_image_url(image_origin, image_name),
|
||||
self._services.urls.get_image_url(image_origin, image_name, True),
|
||||
)
|
||||
|
||||
return image_dto
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_path(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
try:
|
||||
return self._services.files.get_path(image_origin, image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def validate_path(self, path: str) -> bool:
|
||||
try:
|
||||
return self._services.files.validate_path(path)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem validating image path")
|
||||
raise e
|
||||
|
||||
def get_url(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
try:
|
||||
return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self._services.records.get_many(
|
||||
offset,
|
||||
limit,
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
)
|
||||
|
||||
image_dtos = list(
|
||||
map(
|
||||
lambda r: image_record_to_dto(
|
||||
r,
|
||||
self._services.urls.get_image_url(r.image_origin, r.image_name),
|
||||
self._services.urls.get_image_url(
|
||||
r.image_origin, r.image_name, True
|
||||
),
|
||||
),
|
||||
results.items,
|
||||
)
|
||||
)
|
||||
|
||||
return OffsetPaginatedResults[ImageDTO](
|
||||
items=image_dtos,
|
||||
offset=results.offset,
|
||||
limit=results.limit,
|
||||
total=results.total,
|
||||
)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting paginated image DTOs")
|
||||
raise e
|
||||
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
||||
try:
|
||||
self._services.files.delete(image_origin, image_name)
|
||||
self._services.records.delete(image_origin, image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image record")
|
||||
raise
|
||||
except ImageFileDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem deleting image record and file")
|
||||
raise e
|
||||
|
||||
def _get_metadata(
|
||||
self, session_id: Optional[str] = None, node_id: Optional[str] = None
|
||||
) -> Union[ImageMetadata, None]:
|
||||
"""Get the metadata for a node."""
|
||||
metadata = None
|
||||
|
||||
if node_id is not None and session_id is not None:
|
||||
session = self._services.graph_execution_manager.get(session_id)
|
||||
metadata = self._services.metadata.create_image_metadata(session, node_id)
|
||||
|
||||
return metadata
|
||||
@@ -1,17 +1,30 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from queue import Queue
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import time
|
||||
|
||||
|
||||
class InvocationQueueItem(BaseModel):
|
||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||
invoke_all: bool = Field(default=False)
|
||||
timestamp: float = Field(default_factory=time.time)
|
||||
# TODO: make this serializable
|
||||
class InvocationQueueItem:
|
||||
# session_id: str
|
||||
graph_execution_state_id: str
|
||||
invocation_id: str
|
||||
invoke_all: bool
|
||||
timestamp: float
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# session_id: str,
|
||||
graph_execution_state_id: str,
|
||||
invocation_id: str,
|
||||
invoke_all: bool = False,
|
||||
):
|
||||
# self.session_id = session_id
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
self.invocation_id = invocation_id
|
||||
self.invoke_all = invoke_all
|
||||
self.timestamp = time.time()
|
||||
|
||||
|
||||
class InvocationQueueABC(ABC):
|
||||
|
||||
@@ -1,60 +1,39 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.backend import ModelManager
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||
from invokeai.app.services.restoration_services import RestorationServices
|
||||
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
||||
from invokeai.app.services.item_storage import ItemStorageABC
|
||||
from invokeai.app.services.config import InvokeAISettings
|
||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
from invokeai.backend import ModelManager
|
||||
|
||||
from .events import EventServiceBase
|
||||
from .image_storage import ImageStorageBase
|
||||
from .restoration_services import RestorationServices
|
||||
from .invocation_queue import InvocationQueueABC
|
||||
from .item_storage import ItemStorageABC
|
||||
|
||||
class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
|
||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||
events: "EventServiceBase"
|
||||
latents: "LatentsStorageBase"
|
||||
queue: "InvocationQueueABC"
|
||||
model_manager: "ModelManager"
|
||||
restoration: "RestorationServices"
|
||||
configuration: "InvokeAISettings"
|
||||
images: "ImageService"
|
||||
events: EventServiceBase
|
||||
images: ImageStorageBase
|
||||
queue: InvocationQueueABC
|
||||
model_manager: ModelManager
|
||||
restoration: RestorationServices
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
processor: "InvocationProcessorABC"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_manager: "ModelManager",
|
||||
events: "EventServiceBase",
|
||||
logger: "Logger",
|
||||
latents: "LatentsStorageBase",
|
||||
images: "ImageService",
|
||||
queue: "InvocationQueueABC",
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: "RestorationServices",
|
||||
configuration: "InvokeAISettings",
|
||||
self,
|
||||
model_manager: ModelManager,
|
||||
events: EventServiceBase,
|
||||
images: ImageStorageBase,
|
||||
queue: InvocationQueueABC,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: RestorationServices,
|
||||
):
|
||||
self.model_manager = model_manager
|
||||
self.events = events
|
||||
self.logger = logger
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.queue = queue
|
||||
self.graph_library = graph_library
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
self.restoration = restoration
|
||||
self.configuration = configuration
|
||||
|
||||
@@ -22,8 +22,7 @@ class Invoker:
|
||||
def invoke(
|
||||
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
|
||||
) -> str | None:
|
||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""
|
||||
|
||||
# Get the next invocation
|
||||
invocation = graph_execution_state.next()
|
||||
@@ -34,6 +33,7 @@ class Invoker:
|
||||
self.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
# Queue the invocation
|
||||
print(f"queueing item {invocation.id}")
|
||||
self.services.queue.put(
|
||||
InvocationQueueItem(
|
||||
# session_id = session.id,
|
||||
@@ -50,7 +50,7 @@ class Invoker:
|
||||
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
|
||||
self.services.graph_execution_manager.set(new_state)
|
||||
return new_state
|
||||
|
||||
|
||||
def cancel(self, graph_execution_state_id: str) -> None:
|
||||
"""Cancels the given execution state"""
|
||||
self.services.queue.cancel(graph_execution_state_id)
|
||||
@@ -72,12 +72,18 @@ class Invoker:
|
||||
for service in vars(self.services):
|
||||
self.__start_service(getattr(self.services, service))
|
||||
|
||||
for service in vars(self.services):
|
||||
self.__start_service(getattr(self.services, service))
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stops the invoker. A new invoker will have to be created to execute further."""
|
||||
# First stop all services
|
||||
for service in vars(self.services):
|
||||
self.__stop_service(getattr(self.services, service))
|
||||
|
||||
for service in vars(self.services):
|
||||
self.__stop_service(getattr(self.services, service))
|
||||
|
||||
self.services.queue.put(None)
|
||||
|
||||
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
class LatentsStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving latents."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, name: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
||||
|
||||
__cache: Dict[str, torch.Tensor]
|
||||
__cache_ids: Queue
|
||||
__max_cache_size: int
|
||||
__underlying_storage: LatentsStorageBase
|
||||
|
||||
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
||||
self.__underlying_storage = underlying_storage
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = max_cache_size
|
||||
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
cache_item = self.__get_cache(name)
|
||||
if cache_item is not None:
|
||||
return cache_item
|
||||
|
||||
latent = self.__underlying_storage.get(name)
|
||||
self.__set_cache(name, latent)
|
||||
return latent
|
||||
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__underlying_storage.save(name, data)
|
||||
self.__set_cache(name, data)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
self.__underlying_storage.delete(name)
|
||||
if name in self.__cache:
|
||||
del self.__cache[name]
|
||||
|
||||
def __get_cache(self, name: str) -> torch.Tensor|None:
|
||||
return None if name not in self.__cache else self.__cache[name]
|
||||
|
||||
def __set_cache(self, name: str, data: torch.Tensor):
|
||||
if not name in self.__cache:
|
||||
self.__cache[name] = data
|
||||
self.__cache_ids.put(name)
|
||||
if self.__cache_ids.qsize() > self.__max_cache_size:
|
||||
self.__cache.pop(self.__cache_ids.get())
|
||||
|
||||
|
||||
class DiskLatentsStorage(LatentsStorageBase):
|
||||
"""Stores latents in a folder on disk without caching"""
|
||||
|
||||
__output_folder: str
|
||||
|
||||
def __init__(self, output_folder: str):
|
||||
self.__output_folder = output_folder
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
latent_path = self.get_path(name)
|
||||
return torch.load(latent_path)
|
||||
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
latent_path = self.get_path(name)
|
||||
torch.save(data, latent_path)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
latent_path = self.get_path(name)
|
||||
os.remove(latent_path)
|
||||
|
||||
def get_path(self, name: str) -> str:
|
||||
return os.path.join(self.__output_folder, name)
|
||||
|
||||
@@ -1,142 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Union
|
||||
import networkx as nx
|
||||
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||||
|
||||
|
||||
class MetadataServiceBase(ABC):
|
||||
"""Handles building metadata for nodes, images, and outputs."""
|
||||
|
||||
@abstractmethod
|
||||
def create_image_metadata(
|
||||
self, session: GraphExecutionState, node_id: str
|
||||
) -> ImageMetadata:
|
||||
"""Builds an ImageMetadata object for a node."""
|
||||
pass
|
||||
|
||||
|
||||
class CoreMetadataService(MetadataServiceBase):
|
||||
_ANCESTOR_TYPES = ["t2l", "l2l"]
|
||||
"""The ancestor types that contain the core metadata"""
|
||||
|
||||
_ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"]
|
||||
"""The core metadata parameters in the ancestor types"""
|
||||
|
||||
_NOISE_FIELDS = ["seed", "width", "height"]
|
||||
"""The core metadata parameters in the noise node"""
|
||||
|
||||
def create_image_metadata(
|
||||
self, session: GraphExecutionState, node_id: str
|
||||
) -> ImageMetadata:
|
||||
metadata = self._build_metadata_from_graph(session, node_id)
|
||||
|
||||
return metadata
|
||||
|
||||
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Union[str, None]:
|
||||
"""
|
||||
Finds the id of the nearest ancestor (of a valid type) of a given node.
|
||||
|
||||
Parameters:
|
||||
G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must
|
||||
have the same data as the execution graph.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
Returns:
|
||||
str | None: The ID of the nearest ancestor, or None if there are no valid ancestors.
|
||||
"""
|
||||
|
||||
# Retrieve the node from the graph
|
||||
node = G.nodes[node_id]
|
||||
|
||||
# If the node type is one of the core metadata node types, return its id
|
||||
if node.get("type") in self._ANCESTOR_TYPES:
|
||||
return node.get("id")
|
||||
|
||||
# Else, look for the ancestor in the predecessor nodes
|
||||
for predecessor in G.predecessors(node_id):
|
||||
result = self._find_nearest_ancestor(G, predecessor)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# If there are no valid ancestors, return None
|
||||
return None
|
||||
|
||||
def _get_additional_metadata(
|
||||
self, graph: Graph, node_id: str
|
||||
) -> Union[dict[str, Any], None]:
|
||||
"""
|
||||
Returns additional metadata for a given node.
|
||||
|
||||
Parameters:
|
||||
graph (Graph): The execution graph.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
Returns:
|
||||
dict[str, Any] | None: A dictionary of additional metadata.
|
||||
"""
|
||||
|
||||
metadata = {}
|
||||
|
||||
# Iterate over all edges in the graph
|
||||
for edge in graph.edges:
|
||||
dest_node_id = edge.destination.node_id
|
||||
dest_field = edge.destination.field
|
||||
source_node_dict = graph.nodes[edge.source.node_id].dict()
|
||||
|
||||
# If the destination node ID matches the given node ID, gather necessary metadata
|
||||
if dest_node_id == node_id:
|
||||
# Prompt
|
||||
if dest_field == "positive_conditioning":
|
||||
metadata["positive_conditioning"] = source_node_dict.get("prompt")
|
||||
# Negative prompt
|
||||
if dest_field == "negative_conditioning":
|
||||
metadata["negative_conditioning"] = source_node_dict.get("prompt")
|
||||
# Seed, width and height
|
||||
if dest_field == "noise":
|
||||
for field in self._NOISE_FIELDS:
|
||||
metadata[field] = source_node_dict.get(field)
|
||||
return metadata
|
||||
|
||||
def _build_metadata_from_graph(
|
||||
self, session: GraphExecutionState, node_id: str
|
||||
) -> ImageMetadata:
|
||||
"""
|
||||
Builds an ImageMetadata object for a node.
|
||||
|
||||
Parameters:
|
||||
session (GraphExecutionState): The session.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
Returns:
|
||||
ImageMetadata: The metadata for the node.
|
||||
"""
|
||||
|
||||
# We need to do all the traversal on the execution graph
|
||||
graph = session.execution_graph
|
||||
|
||||
# Find the nearest `t2l`/`l2l` ancestor of the given node
|
||||
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
|
||||
|
||||
# If no ancestor was found, return an empty ImageMetadata object
|
||||
if ancestor_id is None:
|
||||
return ImageMetadata()
|
||||
|
||||
ancestor_node = graph.get_node(ancestor_id)
|
||||
|
||||
# Grab all the core metadata from the ancestor node
|
||||
ancestor_metadata = {
|
||||
param: val
|
||||
for param, val in ancestor_node.dict().items()
|
||||
if param in self._ANCESTOR_PARAMS
|
||||
}
|
||||
|
||||
# Get this image's prompts and noise parameters
|
||||
addl_metadata = self._get_additional_metadata(graph, ancestor_id)
|
||||
|
||||
# If additional metadata was found, add it to the main metadata
|
||||
if addl_metadata is not None:
|
||||
ancestor_metadata.update(addl_metadata)
|
||||
|
||||
return ImageMetadata(**ancestor_metadata)
|
||||
@@ -2,25 +2,26 @@ import os
|
||||
import sys
|
||||
import torch
|
||||
from argparse import Namespace
|
||||
from invokeai.backend import Args
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
from typing import types
|
||||
|
||||
import invokeai.version
|
||||
from .config import InvokeAISettings
|
||||
from ...backend import ModelManager
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from ...backend import Globals
|
||||
|
||||
# TODO: Replace with an abstract class base ModelManagerBase
|
||||
def get_model_manager(config: InvokeAISettings, logger: types.ModuleType) -> ModelManager:
|
||||
model_config = config.model_conf_path
|
||||
if not model_config.exists():
|
||||
report_model_error(
|
||||
config, FileNotFoundError(f"The file {model_config} could not be found."), logger
|
||||
)
|
||||
def get_model_manager(config: Args) -> ModelManager:
|
||||
if not config.conf:
|
||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||
if not os.path.exists(config_file):
|
||||
report_model_error(
|
||||
config, FileNotFoundError(f"The file {config_file} could not be found.")
|
||||
)
|
||||
|
||||
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
||||
logger.info(f'InvokeAI runtime directory is "{config.root}"')
|
||||
print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
||||
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
||||
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
# when the frozen CLIP tokenizer is imported
|
||||
@@ -30,7 +31,20 @@ def get_model_manager(config: InvokeAISettings, logger: types.ModuleType) -> Mod
|
||||
import diffusers
|
||||
|
||||
diffusers.logging.set_verbosity_error()
|
||||
embedding_path = config.embedding_path
|
||||
|
||||
# normalize the config directory relative to root
|
||||
if not os.path.isabs(config.conf):
|
||||
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
|
||||
|
||||
if config.embeddings:
|
||||
if not os.path.isabs(config.embedding_path):
|
||||
embedding_path = os.path.normpath(
|
||||
os.path.join(Globals.root, config.embedding_path)
|
||||
)
|
||||
else:
|
||||
embedding_path = config.embedding_path
|
||||
else:
|
||||
embedding_path = None
|
||||
|
||||
# migrate legacy models
|
||||
ModelManager.migrate_models()
|
||||
@@ -43,36 +57,37 @@ def get_model_manager(config: InvokeAISettings, logger: types.ModuleType) -> Mod
|
||||
else choose_precision(device)
|
||||
|
||||
model_manager = ModelManager(
|
||||
OmegaConf.load(config.model_conf_path),
|
||||
OmegaConf.load(config.conf),
|
||||
precision=precision,
|
||||
device_type=device,
|
||||
max_loaded_models=config.max_loaded_models,
|
||||
embedding_path = embedding_path,
|
||||
logger = logger,
|
||||
embedding_path = Path(embedding_path),
|
||||
)
|
||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||
report_model_error(config, e, logger)
|
||||
report_model_error(config, e)
|
||||
except (IOError, KeyError) as e:
|
||||
logger.error(f"{e}. Aborting.")
|
||||
print(f"{e}. Aborting.")
|
||||
sys.exit(-1)
|
||||
|
||||
# try to autoconvert new models
|
||||
# autoimport new .ckpt files
|
||||
if config.autoconvert_path:
|
||||
model_manager.heuristic_import(
|
||||
config.autoconvert_path,
|
||||
if path := config.autoconvert:
|
||||
model_manager.autoconvert_weights(
|
||||
conf_path=config.conf,
|
||||
weights_directory=path,
|
||||
)
|
||||
|
||||
return model_manager
|
||||
|
||||
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
|
||||
logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
logger.error(
|
||||
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||
def report_model_error(opt: Namespace, e: Exception):
|
||||
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
print(
|
||||
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||
)
|
||||
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
||||
if yes_to_all:
|
||||
logger.warning(
|
||||
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||
print(
|
||||
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||
)
|
||||
else:
|
||||
response = input(
|
||||
@@ -81,12 +96,13 @@ def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
|
||||
if response.startswith(("n", "N")):
|
||||
return
|
||||
|
||||
logger.info("invokeai-configure is launching....\n")
|
||||
print("invokeai-configure is launching....\n")
|
||||
|
||||
# Match arguments that were set on the CLI
|
||||
# only the arguments accepted by the configuration script are parsed
|
||||
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
|
||||
config = ["--config", opt.conf] if opt.conf is not None else []
|
||||
previous_config = sys.argv
|
||||
sys.argv = ["invokeai-configure"]
|
||||
sys.argv.extend(root_dir)
|
||||
sys.argv.extend(config.to_dict())
|
||||
|
||||
@@ -1,148 +0,0 @@
|
||||
import datetime
|
||||
from typing import Optional, Union
|
||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
|
||||
|
||||
class ImageRecord(BaseModel):
|
||||
"""Deserialized image record."""
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
"""The unique name of the image."""
|
||||
image_origin: ResourceOrigin = Field(description="The type of the image.")
|
||||
"""The origin of the image."""
|
||||
image_category: ImageCategory = Field(description="The category of the image.")
|
||||
"""The category of the image."""
|
||||
width: int = Field(description="The width of the image in px.")
|
||||
"""The actual width of the image in px. This may be different from the width in metadata."""
|
||||
height: int = Field(description="The height of the image in px.")
|
||||
"""The actual height of the image in px. This may be different from the height in metadata."""
|
||||
created_at: Union[datetime.datetime, str] = Field(
|
||||
description="The created timestamp of the image."
|
||||
)
|
||||
"""The created timestamp of the image."""
|
||||
updated_at: Union[datetime.datetime, str] = Field(
|
||||
description="The updated timestamp of the image."
|
||||
)
|
||||
"""The updated timestamp of the image."""
|
||||
deleted_at: Union[datetime.datetime, str, None] = Field(
|
||||
description="The deleted timestamp of the image."
|
||||
)
|
||||
"""The deleted timestamp of the image."""
|
||||
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
||||
"""Whether this is an intermediate image."""
|
||||
session_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The session ID that generated this image, if it is a generated image.",
|
||||
)
|
||||
"""The session ID that generated this image, if it is a generated image."""
|
||||
node_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The node ID that generated this image, if it is a generated image.",
|
||||
)
|
||||
"""The node ID that generated this image, if it is a generated image."""
|
||||
metadata: Optional[ImageMetadata] = Field(
|
||||
default=None,
|
||||
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
|
||||
)
|
||||
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
|
||||
|
||||
|
||||
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
||||
"""A set of changes to apply to an image record.
|
||||
|
||||
Only limited changes are valid:
|
||||
- `image_category`: change the category of an image
|
||||
- `session_id`: change the session associated with an image
|
||||
- `is_intermediate`: change the image's `is_intermediate` flag
|
||||
"""
|
||||
|
||||
image_category: Optional[ImageCategory] = Field(
|
||||
description="The image's new category."
|
||||
)
|
||||
"""The image's new category."""
|
||||
session_id: Optional[StrictStr] = Field(
|
||||
default=None,
|
||||
description="The image's new session ID.",
|
||||
)
|
||||
"""The image's new session ID."""
|
||||
is_intermediate: Optional[StrictBool] = Field(
|
||||
default=None, description="The image's new `is_intermediate` flag."
|
||||
)
|
||||
"""The image's new `is_intermediate` flag."""
|
||||
|
||||
|
||||
class ImageUrlsDTO(BaseModel):
|
||||
"""The URLs for an image and its thumbnail."""
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
"""The unique name of the image."""
|
||||
image_origin: ResourceOrigin = Field(description="The type of the image.")
|
||||
"""The origin of the image."""
|
||||
image_url: str = Field(description="The URL of the image.")
|
||||
"""The URL of the image."""
|
||||
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
||||
"""The URL of the image's thumbnail."""
|
||||
|
||||
|
||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||
"""Deserialized image record, enriched for the frontend with URLs."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def image_record_to_dto(
|
||||
image_record: ImageRecord, image_url: str, thumbnail_url: str
|
||||
) -> ImageDTO:
|
||||
"""Converts an image record to an image DTO."""
|
||||
return ImageDTO(
|
||||
**image_record.dict(),
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
|
||||
|
||||
def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
"""Deserializes an image record."""
|
||||
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
image_name = image_dict.get("image_name", "unknown")
|
||||
image_origin = ResourceOrigin(
|
||||
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
|
||||
)
|
||||
image_category = ImageCategory(
|
||||
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
||||
)
|
||||
width = image_dict.get("width", 0)
|
||||
height = image_dict.get("height", 0)
|
||||
session_id = image_dict.get("session_id", None)
|
||||
node_id = image_dict.get("node_id", None)
|
||||
created_at = image_dict.get("created_at", get_iso_timestamp())
|
||||
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
||||
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||
is_intermediate = image_dict.get("is_intermediate", False)
|
||||
|
||||
raw_metadata = image_dict.get("metadata")
|
||||
|
||||
if raw_metadata is not None:
|
||||
metadata = ImageMetadata.parse_raw(raw_metadata)
|
||||
else:
|
||||
metadata = None
|
||||
|
||||
return ImageRecord(
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
session_id=session_id,
|
||||
node_id=node_id,
|
||||
metadata=metadata,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
deleted_at=deleted_at,
|
||||
is_intermediate=is_intermediate,
|
||||
)
|
||||
@@ -1,22 +1,17 @@
|
||||
import time
|
||||
import traceback
|
||||
from threading import Event, Thread, BoundedSemaphore
|
||||
from threading import Event, Thread
|
||||
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invoker import InvocationProcessorABC, Invoker
|
||||
from ..models.exceptions import CanceledException
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker_thread: Thread
|
||||
__stop_event: Event
|
||||
__invoker: Invoker
|
||||
__threadLimit: BoundedSemaphore
|
||||
|
||||
def start(self, invoker) -> None:
|
||||
# if we do want multithreading at some point, we could make this configurable
|
||||
self.__threadLimit = BoundedSemaphore(1)
|
||||
self.__invoker = invoker
|
||||
self.__stop_event = Event()
|
||||
self.__invoker_thread = Thread(
|
||||
@@ -25,7 +20,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
kwargs=dict(stop_event=self.__stop_event),
|
||||
)
|
||||
self.__invoker_thread.daemon = (
|
||||
True # TODO: make async and do not use threads
|
||||
True # TODO: probably better to just not use threads?
|
||||
)
|
||||
self.__invoker_thread.start()
|
||||
|
||||
@@ -34,16 +29,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
|
||||
def __process(self, stop_event: Event):
|
||||
try:
|
||||
self.__threadLimit.acquire()
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
except Exception as e:
|
||||
logger.debug("Exception while getting from queue: %s" % e)
|
||||
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
if not queue_item: # Probably stopping
|
||||
# do not hammer the queue
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
||||
graph_execution_state = (
|
||||
@@ -55,14 +43,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
queue_item.invocation_id
|
||||
)
|
||||
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
|
||||
|
||||
# Send starting event
|
||||
self.__invoker.services.events.emit_invocation_started(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id
|
||||
invocation_id=invocation.id,
|
||||
)
|
||||
|
||||
# Invoke
|
||||
@@ -91,17 +75,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
invocation_id=invocation.id,
|
||||
result=outputs.dict(),
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
@@ -116,13 +96,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Send error event
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
invocation_id=invocation.id,
|
||||
error=error,
|
||||
)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(
|
||||
graph_execution_state.id
|
||||
@@ -132,22 +111,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Queue any further commands if invoking all
|
||||
is_complete = graph_execution_state.is_complete()
|
||||
if queue_item.invoke_all and not is_complete:
|
||||
try:
|
||||
self.__invoker.invoke(graph_execution_state, invoke_all=True)
|
||||
except Exception as e:
|
||||
logger.error("Error while invoking: %s" % e)
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
error=traceback.format_exc()
|
||||
)
|
||||
self.__invoker.invoke(graph_execution_state, invoke_all=True)
|
||||
elif is_complete:
|
||||
self.__invoker.services.events.emit_graph_execution_complete(
|
||||
graph_execution_state.id
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
||||
finally:
|
||||
self.__threadLimit.release()
|
||||
... # Log something?
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum, EnumMeta
|
||||
import uuid
|
||||
|
||||
|
||||
class ResourceType(str, Enum, metaclass=EnumMeta):
|
||||
"""Enum for resource types."""
|
||||
|
||||
IMAGE = "image"
|
||||
LATENT = "latent"
|
||||
|
||||
|
||||
class NameServiceBase(ABC):
|
||||
"""Low-level service responsible for naming resources (images, latents, etc)."""
|
||||
|
||||
# TODO: Add customizable naming schemes
|
||||
@abstractmethod
|
||||
def create_image_name(self) -> str:
|
||||
"""Creates a name for an image."""
|
||||
pass
|
||||
|
||||
|
||||
class SimpleNameService(NameServiceBase):
|
||||
"""Creates image names from UUIDs."""
|
||||
|
||||
# TODO: Add customizable naming schemes
|
||||
def create_image_name(self) -> str:
|
||||
uuid_str = str(uuid.uuid4())
|
||||
filename = f"{uuid_str}.png"
|
||||
return filename
|
||||
@@ -1,7 +1,6 @@
|
||||
import sys
|
||||
import traceback
|
||||
import torch
|
||||
from typing import types
|
||||
from ...backend.restoration import Restoration
|
||||
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||
|
||||
@@ -11,7 +10,7 @@ from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||
class RestorationServices:
|
||||
'''Face restoration and upscaling'''
|
||||
|
||||
def __init__(self,args,logger:types.ModuleType):
|
||||
def __init__(self,args):
|
||||
try:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
if args.restore or args.esrgan:
|
||||
@@ -21,22 +20,20 @@ class RestorationServices:
|
||||
args.gfpgan_model_path
|
||||
)
|
||||
else:
|
||||
logger.info("Face restoration disabled")
|
||||
print(">> Face restoration disabled")
|
||||
if args.esrgan:
|
||||
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
||||
else:
|
||||
logger.info("Upscaling disabled")
|
||||
print(">> Upscaling disabled")
|
||||
else:
|
||||
logger.info("Face restoration and upscaling disabled")
|
||||
print(">> Face restoration and upscaling disabled")
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
self.device = torch.device(choose_torch_device())
|
||||
self.gfpgan = gfpgan
|
||||
self.codeformer = codeformer
|
||||
self.esrgan = esrgan
|
||||
self.logger = logger
|
||||
self.logger.info('Face restoration initialized')
|
||||
|
||||
# note that this one method does gfpgan and codepath reconstruction, as well as
|
||||
# esrgan upscaling
|
||||
@@ -61,15 +58,15 @@ class RestorationServices:
|
||||
if self.gfpgan is not None or self.codeformer is not None:
|
||||
if facetool == "gfpgan":
|
||||
if self.gfpgan is None:
|
||||
self.logger.info(
|
||||
"GFPGAN not found. Face restoration is disabled."
|
||||
print(
|
||||
">> GFPGAN not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
image = self.gfpgan.process(image, strength, seed)
|
||||
if facetool == "codeformer":
|
||||
if self.codeformer is None:
|
||||
self.logger.info(
|
||||
"CodeFormer not found. Face restoration is disabled."
|
||||
print(
|
||||
">> CodeFormer not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
cf_device = (
|
||||
@@ -83,7 +80,7 @@ class RestorationServices:
|
||||
fidelity=codeformer_fidelity,
|
||||
)
|
||||
else:
|
||||
self.logger.info("Face Restoration is disabled.")
|
||||
print(">> Face Restoration is disabled.")
|
||||
if upscale is not None:
|
||||
if self.esrgan is not None:
|
||||
if len(upscale) < 2:
|
||||
@@ -96,10 +93,10 @@ class RestorationServices:
|
||||
denoise_str=upscale_denoise_str,
|
||||
)
|
||||
else:
|
||||
self.logger.info("ESRGAN is disabled. Image not upscaled.")
|
||||
print(">> ESRGAN is disabled. Image not upscaled.")
|
||||
except Exception as e:
|
||||
self.logger.info(
|
||||
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||
print(
|
||||
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||
)
|
||||
|
||||
if image_callback is not None:
|
||||
|
||||
@@ -26,6 +26,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._table_name = table_name
|
||||
self._id_field = id_field # TODO: validate that T has this field
|
||||
self._lock = Lock()
|
||||
|
||||
self._conn = sqlite3.connect(
|
||||
self._filename, check_same_thread=False
|
||||
) # TODO: figure out a better threading solution
|
||||
@@ -58,7 +59,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
||||
(item.json(),),
|
||||
)
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_changed(item)
|
||||
@@ -84,7 +84,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._cursor.execute(
|
||||
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_deleted(id)
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.models.image import ResourceOrigin
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name
|
||||
|
||||
|
||||
class UrlServiceBase(ABC):
|
||||
"""Responsible for building URLs for resources."""
|
||||
|
||||
@abstractmethod
|
||||
def get_image_url(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the URL for an image or thumbnail."""
|
||||
pass
|
||||
|
||||
|
||||
class LocalUrlService(UrlServiceBase):
|
||||
def __init__(self, base_url: str = "api/v1"):
|
||||
self._base_url = base_url
|
||||
|
||||
def get_image_url(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
image_basename = os.path.basename(image_name)
|
||||
|
||||
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
||||
if thumbnail:
|
||||
return (
|
||||
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
|
||||
)
|
||||
|
||||
return f"{self._base_url}/images/{image_origin.value}/{image_basename}"
|
||||
@@ -1,15 +0,0 @@
|
||||
from enum import EnumMeta
|
||||
|
||||
|
||||
class MetaEnum(EnumMeta):
|
||||
"""Metaclass to support additional features in Enums.
|
||||
|
||||
- `in` operator support: `'value' in MyEnum -> bool`
|
||||
"""
|
||||
|
||||
def __contains__(cls, item):
|
||||
try:
|
||||
cls(item)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
@@ -1,21 +0,0 @@
|
||||
import datetime
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_timestamp():
|
||||
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||
|
||||
|
||||
def get_iso_timestamp() -> str:
|
||||
return datetime.datetime.utcnow().isoformat()
|
||||
|
||||
|
||||
def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
|
||||
return datetime.datetime.fromisoformat(iso_timestamp)
|
||||
|
||||
|
||||
SEED_MAX = np.iinfo(np.int32).max
|
||||
|
||||
|
||||
def get_random_seed():
|
||||
return np.random.randint(0, SEED_MAX)
|
||||
@@ -1,55 +0,0 @@
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ...backend.generator.base import Generator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
|
||||
|
||||
def stable_diffusion_step_callback(
|
||||
context: InvocationContext,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException
|
||||
|
||||
# Some schedulers report not only the noisy latents at the current timestep,
|
||||
# but also their estimate so far of what the de-noised latents will be. Use
|
||||
# that estimate if it is available.
|
||||
if intermediate_state.predicted_original is not None:
|
||||
sample = intermediate_state.predicted_original
|
||||
else:
|
||||
sample = intermediate_state.latents
|
||||
|
||||
# TODO: This does not seem to be needed any more?
|
||||
# # txt2img provides a Tensor in the step_callback
|
||||
# # img2img provides a PipelineIntermediateState
|
||||
# if isinstance(sample, PipelineIntermediateState):
|
||||
# # this was an img2img
|
||||
# print('img2img')
|
||||
# latents = sample.latents
|
||||
# step = sample.step
|
||||
# else:
|
||||
# print('txt2img')
|
||||
# latents = sample
|
||||
# step = intermediate_state.step
|
||||
|
||||
# TODO: only output a preview image when requested
|
||||
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=intermediate_state.step,
|
||||
total_steps=node["steps"],
|
||||
)
|
||||
@@ -1,15 +0,0 @@
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_thumbnail_name(image_name: str) -> str:
|
||||
"""Formats given an image name, returns the appropriate thumbnail image name"""
|
||||
thumbnail_name = os.path.splitext(image_name)[0] + ".webp"
|
||||
return thumbnail_name
|
||||
|
||||
|
||||
def make_thumbnail(image: Image.Image, size: int = 256) -> Image.Image:
|
||||
"""Makes a thumbnail from a PIL Image"""
|
||||
thumbnail = image.copy()
|
||||
thumbnail.thumbnail(size=(size, size))
|
||||
return thumbnail
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend
|
||||
"""
|
||||
from .generate import Generate
|
||||
from .generator import (
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGenerator,
|
||||
@@ -9,5 +10,7 @@ from .generator import (
|
||||
Img2Img,
|
||||
Inpaint
|
||||
)
|
||||
from .model_management import ModelManager, SDModelComponent
|
||||
from .model_management import ModelManager
|
||||
from .safety_checker import SafetyChecker
|
||||
from .args import Args
|
||||
from .globals import Globals
|
||||
|
||||
1387
invokeai/backend/args.py
Normal file
1387
invokeai/backend/args.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -12,17 +12,17 @@ print("Loading Python libraries...\n",file=sys.stderr)
|
||||
import argparse
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import textwrap
|
||||
import traceback
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
from typing import get_type_hints
|
||||
from urllib import request
|
||||
|
||||
import npyscreen
|
||||
import torch
|
||||
import transformers
|
||||
from diffusers import AutoencoderKL
|
||||
from huggingface_hub import HfFolder
|
||||
@@ -35,53 +35,57 @@ from transformers import (
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
)
|
||||
|
||||
import invokeai.configs as configs
|
||||
|
||||
from invokeai.app.services.config import (
|
||||
InvokeAIAppConfig,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from invokeai.frontend.install.widgets import (
|
||||
from ...frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from ...frontend.install.widgets import (
|
||||
CenteredButtonPress,
|
||||
IntTitleSlider,
|
||||
set_min_terminal_size,
|
||||
CyclingForm,
|
||||
MIN_COLS,
|
||||
MIN_LINES,
|
||||
)
|
||||
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
|
||||
from invokeai.backend.install.model_install_backend import (
|
||||
from ..args import PRECISION_CHOICES, Args
|
||||
from ..globals import Globals, global_cache_dir, global_config_dir, global_config_file
|
||||
from .model_install_backend import (
|
||||
default_dataset,
|
||||
download_from_hf,
|
||||
hf_download_with_resume,
|
||||
recommended_datasets,
|
||||
UserSelections,
|
||||
)
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
Model_dir = "models"
|
||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||
|
||||
Default_config_file = config.model_conf_path
|
||||
SD_Configs = config.legacy_conf_path
|
||||
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
||||
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
|
||||
PRECISION_CHOICES = ['auto','float16','float32','autocast']
|
||||
Default_config_file = Path(global_config_dir()) / "models.yaml"
|
||||
SD_Configs = Path(global_config_dir()) / "stable-diffusion"
|
||||
|
||||
Datasets = OmegaConf.load(Dataset_path)
|
||||
|
||||
# minimum size for the UI
|
||||
MIN_COLS = 135
|
||||
MIN_LINES = 45
|
||||
|
||||
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||
# This is the InvokeAI initialization file, which contains command-line default values.
|
||||
# Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting
|
||||
# or renaming it and then running invokeai-configure again.
|
||||
# Place frequently-used startup commands here, one or more per line.
|
||||
# Examples:
|
||||
# --outdir=D:\data\images
|
||||
# --no-nsfw_checker
|
||||
# --web --host=0.0.0.0
|
||||
# --steps=20
|
||||
# -Ak_euler_a -C10.0
|
||||
"""
|
||||
|
||||
logger=None
|
||||
|
||||
# --------------------------------------------
|
||||
def postscript(errors: None):
|
||||
@@ -92,13 +96,14 @@ If you installed manually from source or with 'pip install': activate the virtua
|
||||
then run one of the following commands to start InvokeAI.
|
||||
|
||||
Web UI:
|
||||
invokeai-web
|
||||
invokeai --web # (connect to http://localhost:9090)
|
||||
invokeai --web --host 0.0.0.0 # (connect to http://your-lan-ip:9090 from another computer on the local network)
|
||||
|
||||
Command-line client:
|
||||
Command-line interface:
|
||||
invokeai
|
||||
|
||||
If you installed using an installation script, run:
|
||||
{config.root_path}/invoke.{"bat" if sys.platform == "win32" else "sh"}
|
||||
{Globals.root}/invoke.{"bat" if sys.platform == "win32" else "sh"}
|
||||
|
||||
Add the '--help' argument to see all of the command-line switches available for use.
|
||||
"""
|
||||
@@ -210,11 +215,16 @@ def download_realesrgan():
|
||||
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
|
||||
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
|
||||
|
||||
model_dest = config.root_path / "models/realesrgan/realesr-general-x4v3.pth"
|
||||
wdn_model_dest = config.root_path / "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
model_dest = os.path.join(
|
||||
Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
|
||||
)
|
||||
|
||||
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
|
||||
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
|
||||
wdn_model_dest = os.path.join(
|
||||
Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
)
|
||||
|
||||
download_with_progress_bar(model_url, model_dest, "RealESRGAN")
|
||||
download_with_progress_bar(wdn_model_url, wdn_model_dest, "RealESRGANwdn")
|
||||
|
||||
|
||||
def download_gfpgan():
|
||||
@@ -233,8 +243,8 @@ def download_gfpgan():
|
||||
"./models/gfpgan/weights/parsing_parsenet.pth",
|
||||
],
|
||||
):
|
||||
model_url, model_dest = model[0], config.root_path / model[1]
|
||||
download_with_progress_bar(model_url, str(model_dest), "GFPGAN weights")
|
||||
model_url, model_dest = model[0], os.path.join(Globals.root, model[1])
|
||||
download_with_progress_bar(model_url, model_dest, "GFPGAN weights")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
@@ -243,8 +253,8 @@ def download_codeformer():
|
||||
model_url = (
|
||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
||||
)
|
||||
model_dest = config.root_path / "models/codeformer/codeformer.pth"
|
||||
download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
|
||||
model_dest = os.path.join(Globals.root, "models/codeformer/codeformer.pth")
|
||||
download_with_progress_bar(model_url, model_dest, "CodeFormer")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
@@ -285,7 +295,7 @@ def download_vaes():
|
||||
# first the diffusers version
|
||||
repo_id = "stabilityai/sd-vae-ft-mse"
|
||||
args = dict(
|
||||
cache_dir=config.cache_dir,
|
||||
cache_dir=global_cache_dir("hub"),
|
||||
)
|
||||
if not AutoencoderKL.from_pretrained(repo_id, **args):
|
||||
raise Exception(f"download of {repo_id} failed")
|
||||
@@ -296,7 +306,7 @@ def download_vaes():
|
||||
if not hf_download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_name=model_name,
|
||||
model_dir=str(config.root_path / Model_dir / Weights_dir),
|
||||
model_dir=str(Globals.root / Model_dir / Weights_dir),
|
||||
):
|
||||
raise Exception(f"download of {model_name} failed")
|
||||
except Exception as e:
|
||||
@@ -311,24 +321,25 @@ def get_root(root: str = None) -> str:
|
||||
elif os.environ.get("INVOKEAI_ROOT"):
|
||||
return os.environ.get("INVOKEAI_ROOT")
|
||||
else:
|
||||
return str(config.root_path)
|
||||
return Globals.root
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
class editOptsForm(npyscreen.FormMultiPage):
|
||||
# for responsive resizing - disabled
|
||||
# FIX_MINIMUM_SIZE_WHEN_CREATED = False
|
||||
|
||||
def create(self):
|
||||
program_opts = self.parentApp.program_opts
|
||||
old_opts = self.parentApp.invokeai_opts
|
||||
first_time = not (config.root_path / 'invokeai.yaml').exists()
|
||||
first_time = not (Globals.root / Globals.initfile).exists()
|
||||
access_token = HfFolder.get_token()
|
||||
window_width, window_height = get_terminal_size()
|
||||
label = """Configure startup settings. You can come back and change these later.
|
||||
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
|
||||
Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
"""
|
||||
for i in textwrap.wrap(label,width=window_width-6):
|
||||
for i in [
|
||||
"Configure startup settings. You can come back and change these later.",
|
||||
"Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.",
|
||||
"Use cursor arrows to make a checkbox selection, and space to toggle.",
|
||||
]:
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=i,
|
||||
@@ -355,7 +366,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
self.outdir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name="(<tab> autocompletes, ctrl-N advances):",
|
||||
value=str(default_output_dir()),
|
||||
value=old_opts.outdir or str(default_output_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
@@ -370,21 +381,22 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.nsfw_checker = self.add_widget_intelligent(
|
||||
self.safety_checker = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="NSFW checker",
|
||||
value=old_opts.nsfw_checker,
|
||||
value=old_opts.safety_checker,
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
label = """If you have an account at HuggingFace you may optionally paste your access token here
|
||||
to allow InvokeAI to download restricted styles & subjects from the "Concept Library". See https://huggingface.co/settings/tokens.
|
||||
"""
|
||||
for line in textwrap.wrap(label,width=window_width-6):
|
||||
for i in [
|
||||
"If you have an account at HuggingFace you may paste your access token here",
|
||||
'to allow InvokeAI to download styles & subjects from the "Concept Library".',
|
||||
"See https://huggingface.co/settings/tokens",
|
||||
]:
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=line,
|
||||
value=i,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
@@ -423,10 +435,17 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.xformers_enabled = self.add_widget_intelligent(
|
||||
self.xformers = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Enable xformers support if available",
|
||||
value=old_opts.xformers_enabled,
|
||||
value=old_opts.xformers,
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.ckpt_convert = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Load legacy checkpoint models into memory as diffusers models",
|
||||
value=old_opts.ckpt_convert,
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
@@ -461,41 +480,19 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Directories containing textual inversion, controlnet and LoRA models (<tab> autocompletes, ctrl-N advances):",
|
||||
value="Directory containing embedding/textual inversion files:",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.embedding_dir = self.add_widget_intelligent(
|
||||
self.embedding_path = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name=" Textual Inversion Embeddings:",
|
||||
name="(<tab> autocompletes, ctrl-N advances):",
|
||||
value=str(default_embedding_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.lora_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name=" LoRA and LyCORIS:",
|
||||
value=str(default_lora_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.controlnet_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name=" ControlNets:",
|
||||
value=str(default_controlnet_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
begin_entry_at=40,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
@@ -508,11 +505,11 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ
|
||||
AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSE LOCATED AT
|
||||
https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
||||
"""
|
||||
for i in textwrap.wrap(label,width=window_width-6):
|
||||
for i in [
|
||||
"BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ",
|
||||
"AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSE LOCATED AT",
|
||||
"https://huggingface.co/spaces/CompVis/stable-diffusion-license",
|
||||
]:
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=i,
|
||||
@@ -551,7 +548,7 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
||||
self.editing = False
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
|
||||
def validate_field_values(self, opt: Namespace) -> bool:
|
||||
bad_fields = []
|
||||
if not opt.license_acceptance:
|
||||
@@ -562,9 +559,9 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
||||
bad_fields.append(
|
||||
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
|
||||
)
|
||||
if not Path(opt.embedding_dir).parent.exists():
|
||||
if not Path(opt.embedding_path).parent.exists():
|
||||
bad_fields.append(
|
||||
f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_dir).parent)} is an existing directory."
|
||||
f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_path).parent)} is an existing directory."
|
||||
)
|
||||
if len(bad_fields) > 0:
|
||||
message = "The following problems were detected and must be corrected:\n"
|
||||
@@ -579,24 +576,20 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
||||
new_opts = Namespace()
|
||||
|
||||
for attr in [
|
||||
"outdir",
|
||||
"nsfw_checker",
|
||||
"free_gpu_mem",
|
||||
"max_loaded_models",
|
||||
"xformers_enabled",
|
||||
"always_use_cpu",
|
||||
"embedding_dir",
|
||||
"lora_dir",
|
||||
"controlnet_dir",
|
||||
"outdir",
|
||||
"safety_checker",
|
||||
"free_gpu_mem",
|
||||
"max_loaded_models",
|
||||
"xformers",
|
||||
"always_use_cpu",
|
||||
"embedding_path",
|
||||
"ckpt_convert",
|
||||
]:
|
||||
setattr(new_opts, attr, getattr(self, attr).value)
|
||||
|
||||
new_opts.hf_token = self.hf_token.value
|
||||
new_opts.license_acceptance = self.license_acceptance.value
|
||||
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
||||
|
||||
# widget library workaround to make max_loaded_models an int rather than a float
|
||||
new_opts.max_loaded_models = int(new_opts.max_loaded_models)
|
||||
|
||||
return new_opts
|
||||
|
||||
@@ -615,7 +608,6 @@ class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
"MAIN",
|
||||
editOptsForm,
|
||||
name="InvokeAI Startup Options",
|
||||
cycle_widgets=True,
|
||||
)
|
||||
if not (self.program_opts.skip_sd_weights or self.program_opts.default_only):
|
||||
self.model_select = self.addForm(
|
||||
@@ -623,7 +615,6 @@ class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
addModelsForm,
|
||||
name="Install Stable Diffusion Models",
|
||||
multipage=True,
|
||||
cycle_widgets=True,
|
||||
)
|
||||
|
||||
def new_opts(self):
|
||||
@@ -637,14 +628,18 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
|
||||
|
||||
|
||||
def default_startup_options(init_file: Path) -> Namespace:
|
||||
opts = InvokeAIAppConfig.get_config()
|
||||
opts = Args().parse_args([])
|
||||
outdir = Path(opts.outdir)
|
||||
if not outdir.is_absolute():
|
||||
opts.outdir = str(Globals.root / opts.outdir)
|
||||
if not init_file.exists():
|
||||
opts.nsfw_checker = True
|
||||
opts.safety_checker = True
|
||||
return opts
|
||||
|
||||
def default_user_selections(program_opts: Namespace) -> UserSelections:
|
||||
return UserSelections(
|
||||
install_models=default_dataset()
|
||||
|
||||
def default_user_selections(program_opts: Namespace) -> Namespace:
|
||||
return Namespace(
|
||||
starter_models=default_dataset()
|
||||
if program_opts.default_only
|
||||
else recommended_datasets()
|
||||
if program_opts.yes_to_all
|
||||
@@ -652,27 +647,26 @@ def default_user_selections(program_opts: Namespace) -> UserSelections:
|
||||
purge_deleted_models=False,
|
||||
scan_directory=None,
|
||||
autoscan_on_startup=None,
|
||||
import_model_paths=None,
|
||||
convert_to_diffusers=None,
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
||||
def initialize_rootdir(root: str, yes_to_all: bool = False):
|
||||
print("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
|
||||
|
||||
for name in (
|
||||
"models",
|
||||
"configs",
|
||||
"embeddings",
|
||||
"databases",
|
||||
"loras",
|
||||
"controlnets",
|
||||
"text-inversion-output",
|
||||
"text-inversion-training-data",
|
||||
"models",
|
||||
"configs",
|
||||
"embeddings",
|
||||
"text-inversion-output",
|
||||
"text-inversion-training-data",
|
||||
):
|
||||
os.makedirs(os.path.join(root, name), exist_ok=True)
|
||||
|
||||
configs_src = Path(configs.__path__[0])
|
||||
configs_dest = root / "configs"
|
||||
configs_dest = Path(root) / "configs"
|
||||
if not os.path.samefile(configs_src, configs_dest):
|
||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
||||
|
||||
@@ -683,17 +677,8 @@ def run_console_ui(
|
||||
) -> (Namespace, Namespace):
|
||||
# parse_args() will read from init file if present
|
||||
invokeai_opts = default_startup_options(initfile)
|
||||
invokeai_opts.root = program_opts.root
|
||||
|
||||
# The third argument is needed in the Windows 11 environment to
|
||||
# launch a console window running this program.
|
||||
set_min_terminal_size(MIN_COLS, MIN_LINES,'invokeai-configure')
|
||||
|
||||
# the install-models application spawns a subprocess to install
|
||||
# models, and will crash unless this is set before running.
|
||||
import torch
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
set_min_terminal_size(MIN_COLS, MIN_LINES)
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||
editApp.run()
|
||||
if editApp.user_cancelled:
|
||||
@@ -705,66 +690,70 @@ def run_console_ui(
|
||||
# -------------------------------------
|
||||
def write_opts(opts: Namespace, init_file: Path):
|
||||
"""
|
||||
Update the invokeai.yaml file with values from current settings.
|
||||
Update the invokeai.init file with values from opts Namespace
|
||||
"""
|
||||
# this will load current settings
|
||||
new_config = InvokeAIAppConfig.get_config()
|
||||
new_config.root = config.root
|
||||
|
||||
for key,value in opts.__dict__.items():
|
||||
if hasattr(new_config,key):
|
||||
setattr(new_config,key,value)
|
||||
# touch file if it doesn't exist
|
||||
if not init_file.exists():
|
||||
with open(init_file, "w") as f:
|
||||
f.write(INIT_FILE_PREAMBLE)
|
||||
|
||||
# We want to write in the changed arguments without clobbering
|
||||
# any other initialization values the user has entered. There is
|
||||
# no good way to do this because of the one-way nature of
|
||||
# argparse: i.e. --outdir could be --outdir, --out, or -o
|
||||
# initfile needs to be replaced with a fully structured format
|
||||
# such as yaml; this is a hack that will work much of the time
|
||||
args_to_skip = re.compile(
|
||||
"^--?(o|out|no-xformer|xformer|no-ckpt|ckpt|free|no-nsfw|nsfw|prec|max_load|embed|always|ckpt|free_gpu)"
|
||||
)
|
||||
# fix windows paths
|
||||
opts.outdir = opts.outdir.replace("\\", "/")
|
||||
opts.embedding_path = opts.embedding_path.replace("\\", "/")
|
||||
new_file = f"{init_file}.new"
|
||||
try:
|
||||
lines = [x.strip() for x in open(init_file, "r").readlines()]
|
||||
with open(new_file, "w") as out_file:
|
||||
for line in lines:
|
||||
if len(line) > 0 and not args_to_skip.match(line):
|
||||
out_file.write(line + "\n")
|
||||
out_file.write(
|
||||
f"""
|
||||
--outdir={opts.outdir}
|
||||
--embedding_path={opts.embedding_path}
|
||||
--precision={opts.precision}
|
||||
--max_loaded_models={int(opts.max_loaded_models)}
|
||||
--{'no-' if not opts.safety_checker else ''}nsfw_checker
|
||||
--{'no-' if not opts.xformers else ''}xformers
|
||||
--{'no-' if not opts.ckpt_convert else ''}ckpt_convert
|
||||
{'--free_gpu_mem' if opts.free_gpu_mem else ''}
|
||||
{'--always_use_cpu' if opts.always_use_cpu else ''}
|
||||
"""
|
||||
)
|
||||
except OSError as e:
|
||||
print(f"** An error occurred while writing the init file: {str(e)}")
|
||||
|
||||
os.replace(new_file, init_file)
|
||||
|
||||
if opts.hf_token:
|
||||
HfLogin(opts.hf_token)
|
||||
|
||||
with open(init_file,'w', encoding='utf-8') as file:
|
||||
file.write(new_config.to_yaml())
|
||||
|
||||
# -------------------------------------
|
||||
def default_output_dir() -> Path:
|
||||
return config.root_path / "outputs"
|
||||
return Globals.root / "outputs"
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def default_embedding_dir() -> Path:
|
||||
return config.root_path / "embeddings"
|
||||
return Globals.root / "embeddings"
|
||||
|
||||
# -------------------------------------
|
||||
def default_lora_dir() -> Path:
|
||||
return config.root_path / "loras"
|
||||
|
||||
# -------------------------------------
|
||||
def default_controlnet_dir() -> Path:
|
||||
return config.root_path / "controlnets"
|
||||
|
||||
# -------------------------------------
|
||||
def write_default_options(program_opts: Namespace, initfile: Path):
|
||||
opt = default_startup_options(initfile)
|
||||
opt.hf_token = HfFolder.get_token()
|
||||
write_opts(opt, initfile)
|
||||
|
||||
# -------------------------------------
|
||||
# Here we bring in
|
||||
# the legacy Args object in order to parse
|
||||
# the old init file and write out the new
|
||||
# yaml format.
|
||||
def migrate_init_file(legacy_format:Path):
|
||||
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
|
||||
new = InvokeAIAppConfig.get_config()
|
||||
|
||||
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
||||
for attr in fields:
|
||||
if hasattr(old,attr):
|
||||
setattr(new,attr,getattr(old,attr))
|
||||
|
||||
# a few places where the field names have changed and we have to
|
||||
# manually add in the new names/values
|
||||
new.nsfw_checker = old.safety_checker
|
||||
new.xformers_enabled = old.xformers
|
||||
new.conf_path = old.conf
|
||||
new.embedding_dir = old.embedding_path
|
||||
|
||||
invokeai_yaml = legacy_format.parent / 'invokeai.yaml'
|
||||
with open(invokeai_yaml,"w", encoding="utf-8") as outfile:
|
||||
outfile.write(new.to_yaml())
|
||||
|
||||
legacy_format.replace(legacy_format.parent / 'invokeai.init.old')
|
||||
|
||||
# -------------------------------------
|
||||
def main():
|
||||
@@ -820,13 +809,8 @@ def main():
|
||||
)
|
||||
opt = parser.parse_args()
|
||||
|
||||
invoke_args = []
|
||||
if opt.root:
|
||||
invoke_args.extend(['--root',opt.root])
|
||||
if opt.full_precision:
|
||||
invoke_args.extend(['--precision','float32'])
|
||||
config.parse_args(invoke_args)
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
# setting a global here
|
||||
Globals.root = Path(os.path.expanduser(get_root(opt.root) or ""))
|
||||
|
||||
errors = set()
|
||||
|
||||
@@ -834,26 +818,19 @@ def main():
|
||||
models_to_download = default_user_selections(opt)
|
||||
|
||||
# We check for to see if the runtime directory is correctly initialized.
|
||||
old_init_file = config.root_path / 'invokeai.init'
|
||||
new_init_file = config.root_path / 'invokeai.yaml'
|
||||
if old_init_file.exists() and not new_init_file.exists():
|
||||
print('** Migrating invokeai.init to invokeai.yaml')
|
||||
migrate_init_file(old_init_file)
|
||||
# Load new init file into config
|
||||
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
|
||||
|
||||
if not config.model_conf_path.exists():
|
||||
initialize_rootdir(config.root_path, opt.yes_to_all)
|
||||
init_file = Path(Globals.root, Globals.initfile)
|
||||
if not init_file.exists() or not global_config_file().exists():
|
||||
initialize_rootdir(Globals.root, opt.yes_to_all)
|
||||
|
||||
if opt.yes_to_all:
|
||||
write_default_options(opt, new_init_file)
|
||||
write_default_options(opt, init_file)
|
||||
init_options = Namespace(
|
||||
precision="float32" if opt.full_precision else "float16"
|
||||
)
|
||||
else:
|
||||
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
||||
init_options, models_to_download = run_console_ui(opt, init_file)
|
||||
if init_options:
|
||||
write_opts(init_options, new_init_file)
|
||||
write_opts(init_options, init_file)
|
||||
else:
|
||||
print(
|
||||
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
|
||||
@@ -863,7 +840,7 @@ def main():
|
||||
if opt.skip_support_models:
|
||||
print("\n** SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST **")
|
||||
else:
|
||||
print("\n** CHECKING/UPDATING SUPPORT MODELS **")
|
||||
print("\n** DOWNLOADING SUPPORT MODELS **")
|
||||
download_bert()
|
||||
download_sd1_clip()
|
||||
download_sd2_clip()
|
||||
@@ -881,8 +858,6 @@ def main():
|
||||
process_and_execute(opt, models_to_download)
|
||||
|
||||
postscript(errors=errors)
|
||||
if not opt.yes_to_all:
|
||||
input('Press any key to continue...')
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye! Come back soon.")
|
||||
|
||||
@@ -6,30 +6,26 @@ import re
|
||||
import shutil
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass,field
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryFile
|
||||
from typing import List, Dict, Callable
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from diffusers import AutoencoderKL
|
||||
from huggingface_hub import hf_hub_url, HfFolder
|
||||
from huggingface_hub import hf_hub_url
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from ..globals import Globals, global_cache_dir, global_config_dir
|
||||
from ..model_management import ModelManager
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
from ..util.logging import InvokeAILogger
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
Model_dir = "models"
|
||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||
|
||||
@@ -39,9 +35,6 @@ Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
# initial models omegaconf
|
||||
Datasets = None
|
||||
|
||||
# logger
|
||||
logger = InvokeAILogger.getLogger(name='InvokeAI')
|
||||
|
||||
Config_preamble = """
|
||||
# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
@@ -52,98 +45,59 @@ Config_preamble = """
|
||||
# was trained on.
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class ModelInstallList:
|
||||
'''Class for listing models to be installed/removed'''
|
||||
install_models: List[str] = field(default_factory=list)
|
||||
remove_models: List[str] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class UserSelections():
|
||||
install_models: List[str]= field(default_factory=list)
|
||||
remove_models: List[str]=field(default_factory=list)
|
||||
purge_deleted_models: bool=field(default_factory=list)
|
||||
install_cn_models: List[str] = field(default_factory=list)
|
||||
remove_cn_models: List[str] = field(default_factory=list)
|
||||
install_lora_models: List[str] = field(default_factory=list)
|
||||
remove_lora_models: List[str] = field(default_factory=list)
|
||||
install_ti_models: List[str] = field(default_factory=list)
|
||||
remove_ti_models: List[str] = field(default_factory=list)
|
||||
scan_directory: Path = None
|
||||
autoscan_on_startup: bool=False
|
||||
import_model_paths: str=None
|
||||
|
||||
def default_config_file():
|
||||
return config.model_conf_path
|
||||
return Path(global_config_dir()) / "models.yaml"
|
||||
|
||||
|
||||
def sd_configs():
|
||||
return config.legacy_conf_path
|
||||
return Path(global_config_dir()) / "stable-diffusion"
|
||||
|
||||
|
||||
def initial_models():
|
||||
global Datasets
|
||||
if Datasets:
|
||||
return Datasets
|
||||
return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
|
||||
return (Datasets := OmegaConf.load(Dataset_path))
|
||||
|
||||
|
||||
def install_requested_models(
|
||||
diffusers: ModelInstallList = None,
|
||||
controlnet: ModelInstallList = None,
|
||||
lora: ModelInstallList = None,
|
||||
ti: ModelInstallList = None,
|
||||
cn_model_map: Dict[str,str] = None, # temporary - move to model manager
|
||||
scan_directory: Path = None,
|
||||
external_models: List[str] = None,
|
||||
scan_at_startup: bool = False,
|
||||
precision: str = "float16",
|
||||
purge_deleted: bool = False,
|
||||
config_file_path: Path = None,
|
||||
model_config_file_callback: Callable[[Path],Path] = None
|
||||
install_initial_models: List[str] = None,
|
||||
remove_models: List[str] = None,
|
||||
scan_directory: Path = None,
|
||||
external_models: List[str] = None,
|
||||
scan_at_startup: bool = False,
|
||||
convert_to_diffusers: bool = False,
|
||||
precision: str = "float16",
|
||||
purge_deleted: bool = False,
|
||||
config_file_path: Path = None,
|
||||
):
|
||||
"""
|
||||
Entry point for installing/deleting starter models, or installing external models.
|
||||
"""
|
||||
access_token = HfFolder.get_token()
|
||||
config_file_path = config_file_path or default_config_file()
|
||||
if not config_file_path.exists():
|
||||
open(config_file_path, "w")
|
||||
|
||||
# prevent circular import here
|
||||
from ..model_management import ModelManager
|
||||
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
|
||||
if controlnet:
|
||||
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
|
||||
model_manager.delete_controlnet_models(controlnet.remove_models)
|
||||
|
||||
if lora:
|
||||
model_manager.install_lora_models(lora.install_models, access_token=access_token)
|
||||
model_manager.delete_lora_models(lora.remove_models)
|
||||
if remove_models and len(remove_models) > 0:
|
||||
print("== DELETING UNCHECKED STARTER MODELS ==")
|
||||
for model in remove_models:
|
||||
print(f"{model}...")
|
||||
model_manager.del_model(model, delete_files=purge_deleted)
|
||||
model_manager.commit(config_file_path)
|
||||
|
||||
if ti:
|
||||
model_manager.install_ti_models(ti.install_models, access_token=access_token)
|
||||
model_manager.delete_ti_models(ti.remove_models)
|
||||
|
||||
if diffusers:
|
||||
# TODO: Replace next three paragraphs with calls into new model manager
|
||||
if diffusers.remove_models and len(diffusers.remove_models) > 0:
|
||||
logger.info("Processing requested deletions")
|
||||
for model in diffusers.remove_models:
|
||||
logger.info(f"{model}...")
|
||||
model_manager.del_model(model, delete_files=purge_deleted)
|
||||
model_manager.commit(config_file_path)
|
||||
|
||||
if diffusers.install_models and len(diffusers.install_models) > 0:
|
||||
logger.info("Installing requested models")
|
||||
downloaded_paths = download_weight_datasets(
|
||||
models=diffusers.install_models,
|
||||
access_token=None,
|
||||
precision=precision,
|
||||
)
|
||||
successful = {x:v for x,v in downloaded_paths.items() if v is not None}
|
||||
if len(successful) > 0:
|
||||
update_config_file(successful, config_file_path)
|
||||
if len(successful) < len(diffusers.install_models):
|
||||
unsuccessful = [x for x in downloaded_paths if downloaded_paths[x] is None]
|
||||
logger.warning(f"Some of the model downloads were not successful: {unsuccessful}")
|
||||
if install_initial_models and len(install_initial_models) > 0:
|
||||
print("== INSTALLING SELECTED STARTER MODELS ==")
|
||||
successfully_downloaded = download_weight_datasets(
|
||||
models=install_initial_models,
|
||||
access_token=None,
|
||||
precision=precision,
|
||||
) # FIX: for historical reasons, we don't use model manager here
|
||||
update_config_file(successfully_downloaded, config_file_path)
|
||||
if len(successfully_downloaded) < len(install_initial_models):
|
||||
print("** Some of the model downloads were not successful")
|
||||
|
||||
# due to above, we have to reload the model manager because conf file
|
||||
# was changed behind its back
|
||||
@@ -154,14 +108,13 @@ def install_requested_models(
|
||||
external_models.append(str(scan_directory))
|
||||
|
||||
if len(external_models) > 0:
|
||||
logger.info("INSTALLING EXTERNAL MODELS")
|
||||
print("== INSTALLING EXTERNAL MODELS ==")
|
||||
for path_url_or_repo in external_models:
|
||||
try:
|
||||
logger.debug(f'In install_requested_models; callback = {model_config_file_callback}')
|
||||
model_manager.heuristic_import(
|
||||
path_url_or_repo,
|
||||
convert=convert_to_diffusers,
|
||||
commit_to_conf=config_file_path,
|
||||
config_file_callback = model_config_file_callback,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(-1)
|
||||
@@ -169,22 +122,17 @@ def install_requested_models(
|
||||
pass
|
||||
|
||||
if scan_at_startup and scan_directory.is_dir():
|
||||
update_autoconvert_dir(scan_directory)
|
||||
else:
|
||||
update_autoconvert_dir(None)
|
||||
|
||||
def update_autoconvert_dir(autodir: Path):
|
||||
'''
|
||||
Update the "autoconvert_dir" option in invokeai.yaml
|
||||
'''
|
||||
invokeai_config_path = config.init_file_path
|
||||
conf = OmegaConf.load(invokeai_config_path)
|
||||
conf.InvokeAI.Paths.autoconvert_dir = str(autodir) if autodir else None
|
||||
yaml = OmegaConf.to_yaml(conf)
|
||||
tmpfile = invokeai_config_path.parent / "new_config.tmp"
|
||||
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
||||
outfile.write(yaml)
|
||||
tmpfile.replace(invokeai_config_path)
|
||||
argument = "--autoconvert" if convert_to_diffusers else "--autoimport"
|
||||
initfile = Path(Globals.root, Globals.initfile)
|
||||
replacement = Path(Globals.root, f"{Globals.initfile}.new")
|
||||
directory = str(scan_directory).replace("\\", "/")
|
||||
with open(initfile, "r") as input:
|
||||
with open(replacement, "w") as output:
|
||||
while line := input.readline():
|
||||
if not line.startswith(argument):
|
||||
output.writelines([line])
|
||||
output.writelines([f"{argument} {directory}"])
|
||||
os.replace(replacement, initfile)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
@@ -196,21 +144,33 @@ def yes_or_no(prompt: str, default_yes=True):
|
||||
else:
|
||||
return response[0] in ("y", "Y")
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def get_root(root: str = None) -> str:
|
||||
if root:
|
||||
return root
|
||||
elif os.environ.get("INVOKEAI_ROOT"):
|
||||
return os.environ.get("INVOKEAI_ROOT")
|
||||
else:
|
||||
return Globals.root
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def recommended_datasets() -> List['str']:
|
||||
datasets = set()
|
||||
def recommended_datasets() -> dict:
|
||||
datasets = dict()
|
||||
for ds in initial_models().keys():
|
||||
if initial_models()[ds].get("recommended", False):
|
||||
datasets.add(ds)
|
||||
return list(datasets)
|
||||
datasets[ds] = True
|
||||
return datasets
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def default_dataset() -> dict:
|
||||
datasets = set()
|
||||
datasets = dict()
|
||||
for ds in initial_models().keys():
|
||||
if initial_models()[ds].get("default", False):
|
||||
datasets.add(ds)
|
||||
return list(datasets)
|
||||
datasets[ds] = True
|
||||
return datasets
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
@@ -225,14 +185,14 @@ def all_datasets() -> dict:
|
||||
# look for legacy model.ckpt in models directory and offer to
|
||||
# normalize its name
|
||||
def migrate_models_ckpt():
|
||||
model_path = os.path.join(config.root_dir, Model_dir, Weights_dir)
|
||||
model_path = os.path.join(Globals.root, Model_dir, Weights_dir)
|
||||
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
|
||||
return
|
||||
new_name = initial_models()["stable-diffusion-1.4"]["file"]
|
||||
logger.warning(
|
||||
print(
|
||||
'The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
|
||||
)
|
||||
logger.warning(f"model.ckpt => {new_name}")
|
||||
print(f"model.ckpt => {new_name}")
|
||||
os.replace(
|
||||
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
|
||||
)
|
||||
@@ -245,7 +205,7 @@ def download_weight_datasets(
|
||||
migrate_models_ckpt()
|
||||
successful = dict()
|
||||
for mod in models:
|
||||
logger.info(f"Downloading {mod}:")
|
||||
print(f"Downloading {mod}:")
|
||||
successful[mod] = _download_repo_or_file(
|
||||
initial_models()[mod], access_token, precision=precision
|
||||
)
|
||||
@@ -266,10 +226,11 @@ def _download_repo_or_file(
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
||||
repo_id = mconfig["repo_id"]
|
||||
filename = mconfig["file"]
|
||||
cache_dir = os.path.join(config.root_dir, Model_dir, Weights_dir)
|
||||
cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
|
||||
return hf_download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_dir=cache_dir,
|
||||
@@ -280,12 +241,9 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_from_hf(
|
||||
model_class: object, model_name: str, **kwargs
|
||||
model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs
|
||||
):
|
||||
logger = InvokeAILogger.getLogger('InvokeAI')
|
||||
logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage())
|
||||
|
||||
path = config.cache_dir
|
||||
path = global_cache_dir(cache_subdir)
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
cache_dir=path,
|
||||
@@ -316,10 +274,10 @@ def _download_diffusion_weights(
|
||||
**extra_args,
|
||||
)
|
||||
except OSError as e:
|
||||
if 'Revision Not Found' in str(e):
|
||||
if str(e).startswith("fp16 is not a valid"):
|
||||
pass
|
||||
else:
|
||||
logger.error(str(e))
|
||||
print(f"An unexpected error occurred while downloading the model: {e})")
|
||||
if path:
|
||||
break
|
||||
return path
|
||||
@@ -327,13 +285,9 @@ def _download_diffusion_weights(
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_with_resume(
|
||||
repo_id: str,
|
||||
model_dir: str,
|
||||
model_name: str,
|
||||
model_dest: Path = None,
|
||||
access_token: str = None,
|
||||
repo_id: str, model_dir: str, model_name: str, access_token: str = None
|
||||
) -> Path:
|
||||
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
||||
model_dest = Path(os.path.join(model_dir, model_name))
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
url = hf_hub_url(repo_id, model_name)
|
||||
@@ -353,19 +307,20 @@ def hf_download_with_resume(
|
||||
if (
|
||||
resp.status_code == 416
|
||||
): # "range not satisfiable", which means nothing to return
|
||||
logger.info(f"{model_name}: complete file found. Skipping.")
|
||||
print(f"* {model_name}: complete file found. Skipping.")
|
||||
return model_dest
|
||||
elif resp.status_code == 404:
|
||||
logger.warning("File not found")
|
||||
return None
|
||||
elif resp.status_code != 200:
|
||||
logger.warning(f"{model_name}: {resp.reason}")
|
||||
print(f"** An error occurred during downloading {model_name}: {resp.reason}")
|
||||
elif exist_size > 0:
|
||||
logger.info(f"{model_name}: partial file found. Resuming...")
|
||||
print(f"* {model_name}: partial file found. Resuming...")
|
||||
else:
|
||||
logger.info(f"{model_name}: Downloading...")
|
||||
print(f"* {model_name}: Downloading...")
|
||||
|
||||
try:
|
||||
if total < 2000:
|
||||
print(f"*** ERROR DOWNLOADING {model_name}: {resp.text}")
|
||||
return None
|
||||
|
||||
with open(model_dest, open_mode) as file, tqdm(
|
||||
desc=model_name,
|
||||
initial=exist_size,
|
||||
@@ -378,7 +333,7 @@ def hf_download_with_resume(
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
print(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
return None
|
||||
return model_dest
|
||||
|
||||
@@ -403,8 +358,8 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
|
||||
try:
|
||||
backup = None
|
||||
if os.path.exists(config_file):
|
||||
logger.warning(
|
||||
f"{config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
||||
print(
|
||||
f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
||||
)
|
||||
backup = config_file.with_suffix(".yaml.orig")
|
||||
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
|
||||
@@ -421,16 +376,16 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
|
||||
new_config.write(tmp.read())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating config file {config_file}: {str(e)}")
|
||||
print(f"**Error creating config file {config_file}: {str(e)} **")
|
||||
if backup is not None:
|
||||
logger.info("restoring previous config file")
|
||||
print("restoring previous config file")
|
||||
## workaround, for WinError 183, see above
|
||||
if sys.platform == "win32" and config_file.is_file():
|
||||
config_file.unlink()
|
||||
backup.rename(config_file)
|
||||
return
|
||||
|
||||
logger.info(f"Successfully created new configuration file {config_file}")
|
||||
|
||||
print(f"Successfully created new configuration file {config_file}")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
@@ -464,7 +419,7 @@ def new_config_file_contents(
|
||||
stanza["height"] = mod["height"]
|
||||
if "file" in mod:
|
||||
stanza["weights"] = os.path.relpath(
|
||||
successfully_downloaded[model], start=config.root_dir
|
||||
successfully_downloaded[model], start=Globals.root
|
||||
)
|
||||
stanza["config"] = os.path.normpath(
|
||||
os.path.join(sd_configs(), mod["config"])
|
||||
@@ -497,14 +452,14 @@ def delete_weights(model_name: str, conf_stanza: dict):
|
||||
if re.match("/VAE/", conf_stanza.get("config")):
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
f"\nThe checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
|
||||
print(
|
||||
f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
|
||||
)
|
||||
|
||||
weights = Path(weights)
|
||||
if not weights.is_absolute():
|
||||
weights = config.root_dir / weights
|
||||
weights = Path(Globals.root) / weights
|
||||
try:
|
||||
weights.unlink()
|
||||
except OSError as e:
|
||||
logger.error(str(e))
|
||||
print(str(e))
|
||||
1249
invokeai/backend/generate.py
Normal file
1249
invokeai/backend/generate.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -21,39 +21,37 @@ from PIL import Image, ImageChops, ImageFilter
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DiffusionPipeline
|
||||
from tqdm import trange
|
||||
from typing import Callable, List, Iterator, Optional, Type
|
||||
from typing import List, Iterator, Type
|
||||
from dataclasses import dataclass, field
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ..image_util import configure_model_padding
|
||||
from ..util.util import rand_perlin_2d
|
||||
from ..safety_checker import SafetyChecker
|
||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
|
||||
downsampling = 8
|
||||
|
||||
@dataclass
|
||||
class InvokeAIGeneratorBasicParams:
|
||||
seed: Optional[int]=None
|
||||
seed: int=None
|
||||
width: int=512
|
||||
height: int=512
|
||||
cfg_scale: float=7.5
|
||||
cfg_scale: int=7.5
|
||||
steps: int=20
|
||||
ddim_eta: float=0.0
|
||||
scheduler: str='ddim'
|
||||
scheduler: int='ddim'
|
||||
precision: str='float16'
|
||||
perlin: float=0.0
|
||||
threshold: float=0.0
|
||||
threshold: int=0.0
|
||||
seamless: bool=False
|
||||
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
|
||||
h_symmetry_time_pct: Optional[float]=None
|
||||
v_symmetry_time_pct: Optional[float]=None
|
||||
h_symmetry_time_pct: float=None
|
||||
v_symmetry_time_pct: float=None
|
||||
variation_amount: float = 0.0
|
||||
with_variations: list=field(default_factory=list)
|
||||
safety_checker: Optional[SafetyChecker]=None
|
||||
safety_checker: SafetyChecker=None
|
||||
|
||||
@dataclass
|
||||
class InvokeAIGeneratorOutput:
|
||||
@@ -63,28 +61,39 @@ class InvokeAIGeneratorOutput:
|
||||
and the model hash, as well as all the generate() parameters that went into
|
||||
generating the image (in .params, also available as attributes)
|
||||
'''
|
||||
image: Image.Image
|
||||
image: Image
|
||||
seed: int
|
||||
model_hash: str
|
||||
attention_maps_images: List[Image.Image]
|
||||
attention_maps_images: List[Image]
|
||||
params: Namespace
|
||||
|
||||
# we are interposing a wrapper around the original Generator classes so that
|
||||
# old code that calls Generate will continue to work.
|
||||
class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
scheduler_map = dict(
|
||||
ddim=diffusers.DDIMScheduler,
|
||||
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_euler=diffusers.EulerDiscreteScheduler,
|
||||
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||
k_heun=diffusers.HeunDiscreteScheduler,
|
||||
k_lms=diffusers.LMSDiscreteScheduler,
|
||||
plms=diffusers.PNDMScheduler,
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
model_info: dict,
|
||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||
**kwargs,
|
||||
):
|
||||
self.model_info=model_info
|
||||
self.params=params
|
||||
self.kwargs = kwargs
|
||||
|
||||
def generate(self,
|
||||
prompt: str='',
|
||||
callback: Optional[Callable]=None,
|
||||
step_callback: Optional[Callable]=None,
|
||||
callback: callable=None,
|
||||
step_callback: callable=None,
|
||||
iterations: int=1,
|
||||
**keyword_args,
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
@@ -120,12 +129,9 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
model=model,
|
||||
scheduler_name=generator_args.get('scheduler')
|
||||
)
|
||||
|
||||
# get conditioning from prompt via Compel package
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt, model=model)
|
||||
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
||||
gen_class = self._generator_class()
|
||||
generator = gen_class(model, self.params.precision, **self.kwargs)
|
||||
generator = gen_class(model, self.params.precision)
|
||||
if self.params.variation_amount > 0:
|
||||
generator.set_variation(generator_args.get('seed'),
|
||||
generator_args.get('variation_amount'),
|
||||
@@ -148,7 +154,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
for i in iteration_count:
|
||||
results = generator.generate(prompt,
|
||||
conditioning=(uc, c, extra_conditioning_info),
|
||||
step_callback=step_callback,
|
||||
sampler=scheduler,
|
||||
**generator_args,
|
||||
)
|
||||
@@ -168,20 +173,14 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
'''
|
||||
Return list of all the schedulers that we currently handle.
|
||||
'''
|
||||
return list(SCHEDULER_MAP.keys())
|
||||
return list(self.scheduler_map.keys())
|
||||
|
||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||
return generator_class(model, self.params.precision)
|
||||
|
||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
|
||||
scheduler_config = model.scheduler.config
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
|
||||
scheduler = scheduler_class.from_config(model.scheduler.config)
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
@@ -206,10 +205,10 @@ class Txt2Img(InvokeAIGenerator):
|
||||
# ------------------------------------
|
||||
class Img2Img(InvokeAIGenerator):
|
||||
def generate(self,
|
||||
init_image: Image.Image | torch.FloatTensor,
|
||||
init_image: Image | torch.FloatTensor,
|
||||
strength: float=0.75,
|
||||
**keyword_args
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
)->List[InvokeAIGeneratorOutput]:
|
||||
return super().generate(init_image=init_image,
|
||||
strength=strength,
|
||||
**keyword_args
|
||||
@@ -223,12 +222,12 @@ class Img2Img(InvokeAIGenerator):
|
||||
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
||||
class Inpaint(Img2Img):
|
||||
def generate(self,
|
||||
mask_image: Image.Image | torch.FloatTensor,
|
||||
mask_image: Image | torch.FloatTensor,
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 96,
|
||||
seam_blur: int = 16,
|
||||
seam_size: int = 0,
|
||||
seam_blur: int = 0,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 30,
|
||||
seam_steps: int = 10,
|
||||
tile_size: int = 32,
|
||||
inpaint_replace=False,
|
||||
infill_method=None,
|
||||
@@ -236,7 +235,7 @@ class Inpaint(Img2Img):
|
||||
inpaint_height=None,
|
||||
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||
**keyword_args
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
)->List[InvokeAIGeneratorOutput]:
|
||||
return super().generate(
|
||||
mask_image=mask_image,
|
||||
seam_size=seam_size,
|
||||
@@ -263,7 +262,7 @@ class Embiggen(Txt2Img):
|
||||
embiggen: list=None,
|
||||
embiggen_tiles: list = None,
|
||||
strength: float=0.75,
|
||||
**kwargs)->Iterator[InvokeAIGeneratorOutput]:
|
||||
**kwargs)->List[InvokeAIGeneratorOutput]:
|
||||
return super().generate(embiggen=embiggen,
|
||||
embiggen_tiles=embiggen_tiles,
|
||||
strength=strength,
|
||||
@@ -281,7 +280,7 @@ class Generator:
|
||||
precision: str
|
||||
model: DiffusionPipeline
|
||||
|
||||
def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
|
||||
def __init__(self, model: DiffusionPipeline, precision: str):
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
@@ -372,7 +371,7 @@ class Generator:
|
||||
try:
|
||||
x_T = self.get_noise(width, height)
|
||||
except:
|
||||
logger.error("An error occurred while getting initial noise")
|
||||
print("** An error occurred while getting initial noise **")
|
||||
print(traceback.format_exc())
|
||||
|
||||
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
|
||||
@@ -607,7 +606,7 @@ class Generator:
|
||||
image = self.sample_to_image(sample)
|
||||
dirname = os.path.dirname(filepath) or "."
|
||||
if not os.path.exists(dirname):
|
||||
logger.info(f"creating directory {dirname}")
|
||||
print(f"** creating directory {dirname}")
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
image.save(filepath, "PNG")
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user