Compare commits — 22 Commits
kyle0654/c...
packaging/

| Author | SHA1 | Date |
|---|---|---|
| | 76adcc122b | |
| | 650f4bb58c | |
| | 7b92b27ceb | |
| | 8f1b301d01 | |
| | e3a19d4f3e | |
| | ecbb385447 | |
| | 741464b053 | |
| | 33f832e6ab | |
| | 281c788489 | |
| | 3858bef185 | |
| | f9a1afd09c | |
| | 251e9c0294 | |
| | d8bf2e3c10 | |
| | 218f30b7d0 | |
| | da983c7773 | |
| | 7012e16c43 | |
| | b1050abf7f | |
| | 210998081a | |
| | 604acb9d91 | |
| | 5beeb1a897 | |
| | de6304b729 | |
| | d0be79c33d | |
```diff
@@ -1,6 +0,0 @@
-[run]
-omit='.env/*'
-source='.'
-
-[report]
-show_missing = true
```
```diff
@@ -4,22 +4,22 @@
 !ldm
 !pyproject.toml

-# ignore frontend/web but whitelist dist
-invokeai/frontend/web/
-!invokeai/frontend/web/dist/
-# Guard against pulling in any models that might exist in the directory tree
-**/*.pt*
-**/*.ckpt
+# ignore frontend but whitelist dist
+invokeai/frontend/
+!invokeai/frontend/dist/
+
+# ignore invokeai/assets but whitelist invokeai/assets/web
+invokeai/assets/
+!invokeai/assets/web/
+
+# Guard against pulling in any models that might exist in the directory tree
+**/*.pt*
+**/*.ckpt

 # Byte-compiled / optimized / DLL files
 **/__pycache__/
 **/*.py[cod]

 # Distribution / packaging
-**/*.egg-info/
-**/*.egg
+*.egg-info/
+*.egg
```
```diff
@@ -1 +0,0 @@
-b3dccfaeb636599c02effc377cdd8a87d658256c
```
.github/CODEOWNERS (59 changes, vendored)

```diff
@@ -1,34 +1,51 @@
 # continuous integration
-/.github/workflows/ @mauwii @lstein
+/.github/workflows/ @mauwii @lstein @blessedcoolant

 # documentation
-/docs/ @lstein @mauwii @tildebyte
-/mkdocs.yml @lstein @mauwii
-
-# nodes
-/invokeai/app/ @Kyle0654 @blessedcoolant
+/docs/ @lstein @mauwii @tildebyte @blessedcoolant
+mkdocs.yml @lstein @mauwii @blessedcoolant

 # installation and configuration
-/pyproject.toml @mauwii @lstein @blessedcoolant
-/docker/ @mauwii @lstein
-/scripts/ @ebr @lstein
-/installer/ @lstein @ebr
-/invokeai/assets @lstein @ebr
-/invokeai/configs @lstein
-/invokeai/version @lstein @blessedcoolant
+/pyproject.toml @mauwii @lstein @ebr @blessedcoolant
+/docker/ @mauwii @lstein @blessedcoolant
+/scripts/ @ebr @lstein @blessedcoolant
+/installer/ @ebr @lstein @tildebyte @blessedcoolant
+ldm/invoke/config @lstein @ebr @blessedcoolant
+invokeai/assets @lstein @ebr @blessedcoolant
+invokeai/configs @lstein @ebr @blessedcoolant
+/ldm/invoke/_version.py @lstein @blessedcoolant

-# web ui
-/invokeai/frontend @blessedcoolant @psychedelicious @lstein
-/invokeai/backend @blessedcoolant @psychedelicious @lstein
-
-# generation, model management, postprocessing
-/invokeai/backend @keturn @damian0815 @lstein @blessedcoolant @jpphoto
+# generation and model management
+/ldm/*.py @lstein @blessedcoolant
+/ldm/generate.py @lstein @keturn @blessedcoolant
+/ldm/invoke/args.py @lstein @blessedcoolant
+/ldm/invoke/ckpt* @lstein @blessedcoolant
+/ldm/invoke/ckpt_generator @lstein @blessedcoolant
+/ldm/invoke/CLI.py @lstein @blessedcoolant
+/ldm/invoke/config @lstein @ebr @mauwii @blessedcoolant
+/ldm/invoke/generator @keturn @damian0815 @blessedcoolant
+/ldm/invoke/globals.py @lstein @blessedcoolant
+/ldm/invoke/merge_diffusers.py @lstein @blessedcoolant
+/ldm/invoke/model_manager.py @lstein @blessedcoolant
+/ldm/invoke/txt2mask.py @lstein @blessedcoolant
+/ldm/invoke/patchmatch.py @Kyle0654 @blessedcoolant @lstein
+/ldm/invoke/restoration @lstein @blessedcoolant

-# front ends
-/invokeai/frontend/CLI @lstein
-/invokeai/frontend/install @lstein @ebr @mauwii
-/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
-/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
-/invokeai/frontend/web @psychedelicious @blessedcoolant
+# attention, textual inversion, model configuration
+/ldm/models @damian0815 @keturn @lstein @blessedcoolant
+/ldm/modules @damian0815 @keturn @lstein @blessedcoolant
+
+# Nodes
+apps/ @Kyle0654 @lstein @blessedcoolant
+
+# legacy REST API
+# is CapableWeb still engaged?
+/ldm/invoke/pngwriter.py @CapableWeb @lstein @blessedcoolant
+/ldm/invoke/server_legacy.py @CapableWeb @lstein @blessedcoolant
+/scripts/legacy_api.py @CapableWeb @lstein @blessedcoolant
+/tests/legacy_tests.sh @CapableWeb @lstein @blessedcoolant
```
.github/ISSUE_TEMPLATE/BUG_REPORT.yml (10 changes, vendored)

```diff
@@ -65,16 +65,6 @@ body:
       placeholder: 8GB
     validations:
       required: false

-  - type: input
-    id: version-number
-    attributes:
-      label: What version did you experience this issue on?
-      description: |
-        Please share the version of Invoke AI that you experienced the issue on. If this is not the latest version, please update first to confirm the issue still exists. If you are testing main, please include the commit hash instead.
-      placeholder: X.X.X
-    validations:
-      required: true
-
   - type: textarea
     id: what-happened
```
.github/workflows/build-container.yml (21 changes, vendored)

```diff
@@ -5,17 +5,18 @@ on:
       - 'main'
-      - 'update/ci/docker/*'
-      - 'update/docker/*'
+      - 'dev/ci/docker/*'
+      - 'dev/docker/*'
     paths:
       - 'pyproject.toml'
-      - '.dockerignore'
-      - 'invokeai/**'
+      - 'ldm/**'
+      - 'invokeai/backend/**'
+      - 'invokeai/configs/**'
+      - 'invokeai/frontend/dist/**'
       - 'docker/Dockerfile'
     tags:
       - 'v*.*.*'
   workflow_dispatch:

 jobs:
   docker:
     if: github.event.pull_request.draft == false
@@ -23,11 +24,11 @@ jobs:
       fail-fast: false
       matrix:
         flavor:
-          - rocm
+          - amd
           - cuda
           - cpu
         include:
-          - flavor: rocm
+          - flavor: amd
            pip-extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
          - flavor: cuda
            pip-extra-index-url: ''
@@ -53,9 +54,9 @@ jobs:
           tags: |
             type=ref,event=branch
             type=ref,event=tag
-            type=pep440,pattern={{version}}
-            type=pep440,pattern={{major}}.{{minor}}
-            type=pep440,pattern={{major}}
+            type=semver,pattern={{version}}
+            type=semver,pattern={{major}}.{{minor}}
+            type=semver,pattern={{major}}
             type=sha,enable=true,prefix=sha-,format=short
           flavor: |
             latest=${{ matrix.flavor == 'cuda' && github.ref == 'refs/heads/main' }}
@@ -91,7 +92,7 @@ jobs:
           context: .
           file: ${{ env.DOCKERFILE }}
           platforms: ${{ env.PLATFORMS }}
-          push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' }}
+          push: ${{ github.ref == 'refs/heads/main' || github.ref == 'refs/tags/*' }}
           tags: ${{ steps.meta.outputs.tags }}
           labels: ${{ steps.meta.outputs.labels }}
           build-args: PIP_EXTRA_INDEX_URL=${{ matrix.pip-extra-index-url }}
```
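A note on the tag rules above: for docker/metadata-action, `type=pep440` and `type=semver` differ mainly on pre-release tags. As a rough illustration (assumed behavior, not taken from this diff), a Python-style tag such as `v2.3.0rc1` is valid PEP 440 but not valid semver, so only the `pep440` rule would produce an image tag for it:

```yaml
# Hypothetical push of tag "v2.3.0rc1":
tags: |
  type=pep440,pattern={{version}}   # -> image tagged 2.3.0rc1
  type=semver,pattern={{version}}   # -> skipped: the ref is not valid semver
```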
.github/workflows/close-inactive-issues.yml (27 changes, vendored)

```diff
@@ -1,27 +0,0 @@
-name: Close inactive issues
-on:
-  schedule:
-    - cron: "00 6 * * *"
-
-env:
-  DAYS_BEFORE_ISSUE_STALE: 14
-  DAYS_BEFORE_ISSUE_CLOSE: 28
-
-jobs:
-  close-issues:
-    runs-on: ubuntu-latest
-    permissions:
-      issues: write
-      pull-requests: write
-    steps:
-      - uses: actions/stale@v5
-        with:
-          days-before-issue-stale: ${{ env.DAYS_BEFORE_ISSUE_STALE }}
-          days-before-issue-close: ${{ env.DAYS_BEFORE_ISSUE_CLOSE }}
-          stale-issue-label: "Inactive Issue"
-          stale-issue-message: "There has been no activity in this issue for ${{ env.DAYS_BEFORE_ISSUE_STALE }} days. If this issue is still being experienced, please reply with an updated confirmation that the issue is still being experienced with the latest release."
-          close-issue-message: "Due to inactivity, this issue was automatically closed. If you are still experiencing the issue, please recreate the issue."
-          days-before-pr-stale: -1
-          days-before-pr-close: -1
-          repo-token: ${{ secrets.GITHUB_TOKEN }}
-          operations-per-run: 500
```
.github/workflows/lint-frontend.yml (22 changes, vendored)

```diff
@@ -3,22 +3,14 @@ name: Lint frontend
 on:
   pull_request:
     paths:
-      - 'invokeai/frontend/web/**'
-    types:
-      - 'ready_for_review'
-      - 'opened'
-      - 'synchronize'
+      - 'invokeai/frontend/**'
   push:
     branches:
       - 'main'
     paths:
-      - 'invokeai/frontend/web/**'
-  merge_group:
-  workflow_dispatch:
+      - 'invokeai/frontend/**'

 defaults:
   run:
-    working-directory: invokeai/frontend/web
+    working-directory: invokeai/frontend

 jobs:
   lint-frontend:
@@ -31,7 +23,7 @@ jobs:
       node-version: '18'
     - uses: actions/checkout@v3
     - run: 'yarn install --frozen-lockfile'
-    - run: 'yarn run lint:tsc'
-    - run: 'yarn run lint:madge'
-    - run: 'yarn run lint:eslint'
-    - run: 'yarn run lint:prettier'
+    - run: 'yarn tsc'
+    - run: 'yarn run madge'
+    - run: 'yarn run lint --max-warnings=0'
+    - run: 'yarn run prettier --check'
```
.github/workflows/pypi-release.yml (2 changes, vendored)

```diff
@@ -3,7 +3,7 @@ name: PyPI Release
 on:
   push:
     paths:
-      - 'invokeai/version/invokeai_version.py'
+      - 'ldm/invoke/_version.py'
   workflow_dispatch:

 jobs:
```
.github/workflows/test-invoke-pip-skip.yml (12 changes, vendored)

```diff
@@ -1,12 +1,12 @@
 name: Test invoke.py pip
 on:
   pull_request:
-    paths:
-      - '**'
-      - '!pyproject.toml'
-      - '!invokeai/**'
-      - 'invokeai/frontend/web/**'
-      - '!invokeai/frontend/web/dist/**'
-  merge_group:
+    paths-ignore:
+      - 'pyproject.toml'
+      - 'ldm/**'
+      - 'invokeai/backend/**'
+      - 'invokeai/configs/**'
+      - 'invokeai/frontend/dist/**'
   workflow_dispatch:

```
.github/workflows/test-invoke-pip.yml (16 changes, vendored)

```diff
@@ -5,15 +5,17 @@ on:
       - 'main'
     paths:
       - 'pyproject.toml'
-      - 'invokeai/**'
-      - '!invokeai/frontend/web/**'
-      - 'invokeai/frontend/web/dist/**'
+      - 'ldm/**'
+      - 'invokeai/backend/**'
+      - 'invokeai/configs/**'
+      - 'invokeai/frontend/dist/**'
   pull_request:
     paths:
       - 'pyproject.toml'
-      - 'invokeai/**'
-      - '!invokeai/frontend/web/**'
-      - 'invokeai/frontend/web/dist/**'
+      - 'ldm/**'
+      - 'invokeai/backend/**'
+      - 'invokeai/configs/**'
+      - 'invokeai/frontend/dist/**'
     types:
       - 'ready_for_review'
       - 'opened'
@@ -110,7 +112,7 @@ jobs:
       - name: set INVOKEAI_OUTDIR
         run: >
           python -c
-          "import os;from invokeai.backend.globals import Globals;OUTDIR=os.path.join(Globals.root,str('outputs'));print(f'INVOKEAI_OUTDIR={OUTDIR}')"
+          "import os;from ldm.invoke.globals import Globals;OUTDIR=os.path.join(Globals.root,str('outputs'));print(f'INVOKEAI_OUTDIR={OUTDIR}')"
           >> ${{ matrix.github-env }}

       - name: run invokeai-configure
```
.gitignore (11 changes, vendored)

```diff
@@ -68,7 +68,6 @@ htmlcov/
 .cache
 nosetests.xml
 coverage.xml
-cov.xml
 *.cover
 *.py,cover
 .hypothesis/
@@ -198,7 +197,7 @@ checkpoints
 .DS_Store

 # Let the frontend manage its own gitignore
-!invokeai/frontend/web/*
+!invokeai/frontend/*

 # Scratch folder
 .scratch/
@@ -213,6 +212,11 @@ gfpgan/
 # config file (will be created by installer)
 configs/models.yaml

+# weights (will be created by installer)
+models/ldm/stable-diffusion-v1/*.ckpt
+models/clipseg
+models/gfpgan
+
 # ignore initfile
 .invokeai

@@ -227,3 +231,6 @@ installer/install.bat
 installer/install.sh
 installer/update.bat
 installer/update.sh
+
+# no longer stored in source directory
+models
```
```diff
@@ -1,5 +0,0 @@
-[pytest]
-DJANGO_SETTINGS_MODULE = webtas.settings
-; python_files = tests.py test_*.py *_tests.py
-
-addopts = --cov=. --cov-config=.coveragerc --cov-report xml:cov.xml
```
```diff
@@ -1,164 +0,0 @@
-@echo off
-
-@rem This script will install git (if not found on the PATH variable)
-@rem using micromamba (an 8mb static-linked single-file binary, conda replacement).
-@rem For users who already have git, this step will be skipped.
-
-@rem Next, it'll download the project's source code.
-@rem Then it will download a self-contained, standalone Python and unpack it.
-@rem Finally, it'll create the Python virtual environment and preload the models.
-
-@rem This enables a user to install this project without manually installing git or Python
-
-@rem change to the script's directory
-PUSHD "%~dp0"
-
-set "no_cache_dir=--no-cache-dir"
-if "%1" == "use-cache" (
-    set "no_cache_dir="
-)
-
-echo ***** Installing InvokeAI.. *****
-@rem Config
-set INSTALL_ENV_DIR=%cd%\installer_files\env
-@rem https://mamba.readthedocs.io/en/latest/installation.html
-set MICROMAMBA_DOWNLOAD_URL=https://github.com/cmdr2/stable-diffusion-ui/releases/download/v1.1/micromamba.exe
-set RELEASE_URL=https://github.com/invoke-ai/InvokeAI
-set RELEASE_SOURCEBALL=/archive/refs/heads/main.tar.gz
-set PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
-set PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-x86_64-pc-windows-msvc-shared-install_only.tar.gz
-
-set PACKAGES_TO_INSTALL=
-
-call git --version >.tmp1 2>.tmp2
-if "%ERRORLEVEL%" NEQ "0" set PACKAGES_TO_INSTALL=%PACKAGES_TO_INSTALL% git
-
-@rem Cleanup
-del /q .tmp1 .tmp2
-
-@rem (if necessary) install git into a contained environment
-if "%PACKAGES_TO_INSTALL%" NEQ "" (
-    @rem download micromamba
-    echo ***** Downloading micromamba from %MICROMAMBA_DOWNLOAD_URL% to micromamba.exe *****
-
-    call curl -L "%MICROMAMBA_DOWNLOAD_URL%" > micromamba.exe
-
-    @rem test the mamba binary
-    echo ***** Micromamba version: *****
-    call micromamba.exe --version
-
-    @rem create the installer env
-    if not exist "%INSTALL_ENV_DIR%" (
-        call micromamba.exe create -y --prefix "%INSTALL_ENV_DIR%"
-    )
-
-    echo ***** Packages to install:%PACKAGES_TO_INSTALL% *****
-
-    call micromamba.exe install -y --prefix "%INSTALL_ENV_DIR%" -c conda-forge %PACKAGES_TO_INSTALL%
-
-    if not exist "%INSTALL_ENV_DIR%" (
-        echo ----- There was a problem while installing "%PACKAGES_TO_INSTALL%" using micromamba. Cannot continue. -----
-        pause
-        exit /b
-    )
-)
-
-del /q micromamba.exe
-
-@rem For 'git' only
-set PATH=%INSTALL_ENV_DIR%\Library\bin;%PATH%
-
-@rem Download/unpack/clean up InvokeAI release sourceball
-set err_msg=----- InvokeAI source download failed -----
-echo Trying to download "%RELEASE_URL%%RELEASE_SOURCEBALL%"
-curl -L %RELEASE_URL%%RELEASE_SOURCEBALL% --output InvokeAI.tgz
-if %errorlevel% neq 0 goto err_exit
-
-set err_msg=----- InvokeAI source unpack failed -----
-tar -zxf InvokeAI.tgz
-if %errorlevel% neq 0 goto err_exit
-
-del /q InvokeAI.tgz
-
-set err_msg=----- InvokeAI source copy failed -----
-cd InvokeAI-*
-xcopy . .. /e /h
-if %errorlevel% neq 0 goto err_exit
-cd ..
-
-@rem cleanup
-for /f %%i in ('dir /b InvokeAI-*') do rd /s /q %%i
-rd /s /q .dev_scripts .github docker-build tests
-del /q requirements.in requirements-mkdocs.txt shell.nix
-
-echo ***** Unpacked InvokeAI source *****
-
-@rem Download/unpack/clean up python-build-standalone
-set err_msg=----- Python download failed -----
-curl -L %PYTHON_BUILD_STANDALONE_URL%/%PYTHON_BUILD_STANDALONE% --output python.tgz
-if %errorlevel% neq 0 goto err_exit
-
-set err_msg=----- Python unpack failed -----
-tar -zxf python.tgz
-if %errorlevel% neq 0 goto err_exit
-
-del /q python.tgz
-
-echo ***** Unpacked python-build-standalone *****
-
-@rem create venv
-set err_msg=----- problem creating venv -----
-.\python\python -E -s -m venv .venv
-if %errorlevel% neq 0 goto err_exit
-call .venv\Scripts\activate.bat
-
-echo ***** Created Python virtual environment *****
-
-@rem Print venv's Python version
-set err_msg=----- problem calling venv's python -----
-echo We're running under
-.venv\Scripts\python --version
-if %errorlevel% neq 0 goto err_exit
-
-set err_msg=----- pip update failed -----
-.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location --upgrade pip wheel
-if %errorlevel% neq 0 goto err_exit
-
-echo ***** Updated pip and wheel *****
-
-set err_msg=----- requirements file copy failed -----
-copy binary_installer\py3.10-windows-x86_64-cuda-reqs.txt requirements.txt
-if %errorlevel% neq 0 goto err_exit
-
-set err_msg=----- main pip install failed -----
-.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location -r requirements.txt
-if %errorlevel% neq 0 goto err_exit
-
-echo ***** Installed Python dependencies *****
-
-set err_msg=----- InvokeAI setup failed -----
-.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location -e .
-if %errorlevel% neq 0 goto err_exit
-
-copy binary_installer\invoke.bat.in .\invoke.bat
-echo ***** Installed invoke launcher script ******
-
-@rem more cleanup
-rd /s /q binary_installer installer_files
-
-@rem preload the models
-call .venv\Scripts\python ldm\invoke\config\invokeai_configure.py
-set err_msg=----- model download clone failed -----
-if %errorlevel% neq 0 goto err_exit
-deactivate
-
-echo ***** Finished downloading models *****
-
-echo All done! Execute the file invoke.bat in this directory to start InvokeAI
-pause
-exit
-
-:err_exit
-echo %err_msg%
-pause
-exit
```
```diff
@@ -1,235 +0,0 @@
-#!/usr/bin/env bash
-
-# ensure we're in the correct folder in case user's CWD is somewhere else
-scriptdir=$(dirname "$0")
-cd "$scriptdir"
-
-set -euo pipefail
-IFS=$'\n\t'
-
-function _err_exit {
-    if test "$1" -ne 0
-    then
-        echo -e "Error code $1; Error caught was '$2'"
-        read -p "Press any key to exit..."
-        exit
-    fi
-}
-
-# This script will install git (if not found on the PATH variable)
-# using micromamba (an 8mb static-linked single-file binary, conda replacement).
-# For users who already have git, this step will be skipped.
-
-# Next, it'll download the project's source code.
-# Then it will download a self-contained, standalone Python and unpack it.
-# Finally, it'll create the Python virtual environment and preload the models.
-
-# This enables a user to install this project without manually installing git or Python
-
-echo -e "\n***** Installing InvokeAI into $(pwd)... *****\n"
-
-export no_cache_dir="--no-cache-dir"
-if [ $# -ge 1 ]; then
-    if [ "$1" = "use-cache" ]; then
-        export no_cache_dir=""
-    fi
-fi
-
-
-OS_NAME=$(uname -s)
-case "${OS_NAME}" in
-    Linux*) OS_NAME="linux";;
-    Darwin*) OS_NAME="darwin";;
-    *) echo -e "\n----- Unknown OS: $OS_NAME! This script runs only on Linux or macOS -----\n" && exit
-esac
-
-OS_ARCH=$(uname -m)
-case "${OS_ARCH}" in
-    x86_64*) ;;
-    arm64*) ;;
-    *) echo -e "\n----- Unknown system architecture: $OS_ARCH! This script runs only on x86_64 or arm64 -----\n" && exit
-esac
-
-# https://mamba.readthedocs.io/en/latest/installation.html
-MAMBA_OS_NAME=$OS_NAME
-MAMBA_ARCH=$OS_ARCH
-if [ "$OS_NAME" == "darwin" ]; then
-    MAMBA_OS_NAME="osx"
-fi
-
-if [ "$OS_ARCH" == "linux" ]; then
-    MAMBA_ARCH="aarch64"
-fi
-
-if [ "$OS_ARCH" == "x86_64" ]; then
-    MAMBA_ARCH="64"
-fi
-
-PY_ARCH=$OS_ARCH
-if [ "$OS_ARCH" == "arm64" ]; then
-    PY_ARCH="aarch64"
-fi
-
-# Compute device ('cd' segment of reqs files) detect goes here
-# This needs a ton of work
-# Suggestions:
-# - lspci
-# - check $PATH for nvidia-smi, gtt CUDA/GPU version from output
-# - Surely there's a similar utility for AMD?
-CD="cuda"
-if [ "$OS_NAME" == "darwin" ] && [ "$OS_ARCH" == "arm64" ]; then
-    CD="mps"
-fi
-
-# config
-INSTALL_ENV_DIR="$(pwd)/installer_files/env"
-MICROMAMBA_DOWNLOAD_URL="https://micro.mamba.pm/api/micromamba/${MAMBA_OS_NAME}-${MAMBA_ARCH}/latest"
-RELEASE_URL=https://github.com/invoke-ai/InvokeAI
-RELEASE_SOURCEBALL=/archive/refs/heads/main.tar.gz
-PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
-if [ "$OS_NAME" == "darwin" ]; then
-    PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-${PY_ARCH}-apple-darwin-install_only.tar.gz
-elif [ "$OS_NAME" == "linux" ]; then
-    PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-${PY_ARCH}-unknown-linux-gnu-install_only.tar.gz
-fi
-echo "INSTALLING $RELEASE_SOURCEBALL FROM $RELEASE_URL"
-
-PACKAGES_TO_INSTALL=""
-
-if ! hash "git" &>/dev/null; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL git"; fi
-
-# (if necessary) install git and conda into a contained environment
-if [ "$PACKAGES_TO_INSTALL" != "" ]; then
-    # download micromamba
-    echo -e "\n***** Downloading micromamba from $MICROMAMBA_DOWNLOAD_URL to micromamba *****\n"
-
-    curl -L "$MICROMAMBA_DOWNLOAD_URL" | tar -xvjO bin/micromamba > micromamba
-
-    chmod u+x ./micromamba
-
-    # test the mamba binary
-    echo -e "\n***** Micromamba version: *****\n"
-    ./micromamba --version
-
-    # create the installer env
-    if [ ! -e "$INSTALL_ENV_DIR" ]; then
-        ./micromamba create -y --prefix "$INSTALL_ENV_DIR"
-    fi
-
-    echo -e "\n***** Packages to install:$PACKAGES_TO_INSTALL *****\n"
-
-    ./micromamba install -y --prefix "$INSTALL_ENV_DIR" -c conda-forge "$PACKAGES_TO_INSTALL"
-
-    if [ ! -e "$INSTALL_ENV_DIR" ]; then
-        echo -e "\n----- There was a problem while initializing micromamba. Cannot continue. -----\n"
-        exit
-    fi
-fi
-
-rm -f micromamba.exe
-
-export PATH="$INSTALL_ENV_DIR/bin:$PATH"
-
-# Download/unpack/clean up InvokeAI release sourceball
-_err_msg="\n----- InvokeAI source download failed -----\n"
-curl -L $RELEASE_URL/$RELEASE_SOURCEBALL --output InvokeAI.tgz
-_err_exit $? _err_msg
-_err_msg="\n----- InvokeAI source unpack failed -----\n"
-tar -zxf InvokeAI.tgz
-_err_exit $? _err_msg
-
-rm -f InvokeAI.tgz
-
-_err_msg="\n----- InvokeAI source copy failed -----\n"
-cd InvokeAI-*
-cp -r . ..
-_err_exit $? _err_msg
-cd ..
-
-# cleanup
-rm -rf InvokeAI-*/
-rm -rf .dev_scripts/ .github/ docker-build/ tests/ requirements.in requirements-mkdocs.txt shell.nix
-
-echo -e "\n***** Unpacked InvokeAI source *****\n"
-
-# Download/unpack/clean up python-build-standalone
-_err_msg="\n----- Python download failed -----\n"
-curl -L $PYTHON_BUILD_STANDALONE_URL/$PYTHON_BUILD_STANDALONE --output python.tgz
-_err_exit $? _err_msg
-_err_msg="\n----- Python unpack failed -----\n"
-tar -zxf python.tgz
-_err_exit $? _err_msg
-
-rm -f python.tgz
-
-echo -e "\n***** Unpacked python-build-standalone *****\n"
-
-# create venv
-_err_msg="\n----- problem creating venv -----\n"
-
-if [ "$OS_NAME" == "darwin" ]; then
-    # patch sysconfig so that extensions can build properly
-    # adapted from https://github.com/cashapp/hermit-packages/commit/fcba384663892f4d9cfb35e8639ff7a28166ee43
-    PYTHON_INSTALL_DIR="$(pwd)/python"
-    SYSCONFIG="$(echo python/lib/python*/_sysconfigdata_*.py)"
-    TMPFILE="$(mktemp)"
-    chmod +w "${SYSCONFIG}"
-    cp "${SYSCONFIG}" "${TMPFILE}"
-    sed "s,'/install,'${PYTHON_INSTALL_DIR},g" "${TMPFILE}" > "${SYSCONFIG}"
-    rm -f "${TMPFILE}"
-fi
-
-./python/bin/python3 -E -s -m venv .venv
-_err_exit $? _err_msg
-source .venv/bin/activate
-
-echo -e "\n***** Created Python virtual environment *****\n"
-
-# Print venv's Python version
-_err_msg="\n----- problem calling venv's python -----\n"
-echo -e "We're running under"
-.venv/bin/python3 --version
-_err_exit $? _err_msg
-
-_err_msg="\n----- pip update failed -----\n"
-.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location --upgrade pip
-_err_exit $? _err_msg
-
-echo -e "\n***** Updated pip *****\n"
-
-_err_msg="\n----- requirements file copy failed -----\n"
-cp binary_installer/py3.10-${OS_NAME}-"${OS_ARCH}"-${CD}-reqs.txt requirements.txt
-_err_exit $? _err_msg
-
-_err_msg="\n----- main pip install failed -----\n"
-.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location -r requirements.txt
-_err_exit $? _err_msg
-
-echo -e "\n***** Installed Python dependencies *****\n"
-
-_err_msg="\n----- InvokeAI setup failed -----\n"
-.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location -e .
-_err_exit $? _err_msg
-
-echo -e "\n***** Installed InvokeAI *****\n"
-
-cp binary_installer/invoke.sh.in ./invoke.sh
-chmod a+rx ./invoke.sh
-echo -e "\n***** Installed invoke launcher script ******\n"
-
-# more cleanup
-rm -rf binary_installer/ installer_files/
-
-# preload the models
-.venv/bin/python3 scripts/configure_invokeai.py
-_err_msg="\n----- model download clone failed -----\n"
-_err_exit $? _err_msg
-deactivate
-
-echo -e "\n***** Finished downloading models *****\n"
-
-echo "All done! Run the command"
-echo "  $scriptdir/invoke.sh"
-echo "to start InvokeAI."
-read -p "Press any key to exit..."
-exit
```
```diff
@@ -1,36 +0,0 @@
-@echo off
-
-PUSHD "%~dp0"
-call .venv\Scripts\activate.bat
-
-echo Do you want to generate images using the
-echo 1. command-line
-echo 2. browser-based UI
-echo OR
-echo 3. open the developer console
-set /p choice="Please enter 1, 2 or 3: "
-if /i "%choice%" == "1" (
-    echo Starting the InvokeAI command-line.
-    .venv\Scripts\python scripts\invoke.py %*
-) else if /i "%choice%" == "2" (
-    echo Starting the InvokeAI browser-based UI.
-    .venv\Scripts\python scripts\invoke.py --web %*
-) else if /i "%choice%" == "3" (
-    echo Developer Console
-    echo Python command is:
-    where python
-    echo Python version is:
-    python --version
-    echo *************************
-    echo You are now in the system shell, with the local InvokeAI Python virtual environment activated,
-    echo so that you can troubleshoot this InvokeAI installation as necessary.
-    echo *************************
-    echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
-    call cmd /k
-) else (
-    echo Invalid selection
-    pause
-    exit /b
-)
-
-deactivate
```
```diff
@@ -1,46 +0,0 @@
-#!/usr/bin/env sh
-
-set -eu
-
-. .venv/bin/activate
-
-# set required env var for torch on mac MPS
-if [ "$(uname -s)" == "Darwin" ]; then
-    export PYTORCH_ENABLE_MPS_FALLBACK=1
-fi
-
-echo "Do you want to generate images using the"
-echo "1. command-line"
-echo "2. browser-based UI"
-echo "OR"
-echo "3. open the developer console"
-echo "Please enter 1, 2, or 3:"
-read choice
-
-case $choice in
-    1)
-        printf "\nStarting the InvokeAI command-line..\n";
-        .venv/bin/python scripts/invoke.py $*;
-        ;;
-    2)
-        printf "\nStarting the InvokeAI browser-based UI..\n";
-        .venv/bin/python scripts/invoke.py --web $*;
-        ;;
-    3)
-        printf "\nDeveloper Console:\n";
-        printf "Python command is:\n\t";
-        which python;
-        printf "Python version is:\n\t";
-        python --version;
-        echo "*************************"
-        echo "You are now in your user shell ($SHELL) with the local InvokeAI Python virtual environment activated,";
-        echo "so that you can troubleshoot this InvokeAI installation as necessary.";
-        printf "*************************\n"
-        echo "*** Type \`exit\` to quit this shell and deactivate the Python virtual environment *** ";
-        /usr/bin/env "$SHELL";
-        ;;
-    *)
-        echo "Invalid selection";
-        exit
-        ;;
-esac
```
```diff
@@ -1,17 +0,0 @@
-InvokeAI
-
-Project homepage: https://github.com/invoke-ai/InvokeAI
-
-Installation on Windows:
-    NOTE: You might need to enable Windows Long Paths. If you're not sure,
-    then you almost certainly need to. Simply double-click the 'WinLongPathsEnabled.reg'
-    file. Note that you will need to have admin privileges in order to
-    do this.
-
-    Please double-click the 'install.bat' file (while keeping it inside the invokeAI folder).
-
-Installation on Linux and Mac:
-    Please open the terminal, and run './install.sh' (while keeping it inside the invokeAI folder).
-
-After installation, please run the 'invoke.bat' file (on Windows) or 'invoke.sh'
-file (on Linux/Mac) to start InvokeAI.
```
```diff
@@ -1,33 +0,0 @@
---prefer-binary
---extra-index-url https://download.pytorch.org/whl/torch_stable.html
---extra-index-url https://download.pytorch.org/whl/cu116
---trusted-host https://download.pytorch.org
-accelerate~=0.15
-albumentations
-diffusers[torch]~=0.11
-einops
-eventlet
-flask_cors
-flask_socketio
-flaskwebgui==1.0.3
-getpass_asterisk
-imageio-ffmpeg
-pyreadline3
-realesrgan
-send2trash
-streamlit
-taming-transformers-rom1504
-test-tube
-torch-fidelity
-torch==1.12.1 ; platform_system == 'Darwin'
-torch==1.12.0+cu116 ; platform_system == 'Linux' or platform_system == 'Windows'
-torchvision==0.13.1 ; platform_system == 'Darwin'
-torchvision==0.13.0+cu116 ; platform_system == 'Linux' or platform_system == 'Windows'
-transformers
-picklescan
-https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip
-https://github.com/invoke-ai/clipseg/archive/1f754751c85d7d4255fa681f4491ff5711c1c288.zip
-https://github.com/invoke-ai/GFPGAN/archive/3f5d2397361199bc4a91c08bb7d80f04d7805615.zip ; platform_system=='Windows'
-https://github.com/invoke-ai/GFPGAN/archive/c796277a1cf77954e5fc0b288d7062d162894248.zip ; platform_system=='Linux' or platform_system=='Darwin'
-https://github.com/Birch-san/k-diffusion/archive/363386981fee88620709cf8f6f2eea167bd6cd74.zip
-https://github.com/invoke-ai/PyPatchMatch/archive/129863937a8ab37f6bbcec327c994c0f932abdbc.zip
```
```diff
@@ -4,15 +4,15 @@ ARG PYTHON_VERSION=3.9
 ##################
 ## base image ##
 ##################
-FROM --platform=${TARGETPLATFORM} python:${PYTHON_VERSION}-slim AS python-base
+FROM python:${PYTHON_VERSION}-slim AS python-base

 LABEL org.opencontainers.image.authors="mauwii@outlook.de"

-# Prepare apt for buildkit cache
+# prepare for buildkit cache
 RUN rm -f /etc/apt/apt.conf.d/docker-clean \
     && echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' >/etc/apt/apt.conf.d/keep-cache

-# Install dependencies
+# Install necessary packages
 RUN \
     --mount=type=cache,target=/var/cache/apt,sharing=locked \
     --mount=type=cache,target=/var/lib/apt,sharing=locked \
@@ -23,7 +23,7 @@ RUN \
     libglib2.0-0=2.66.* \
     libopencv-dev=4.5.*

-# Set working directory and env
+# set working directory and env
 ARG APPDIR=/usr/src
 ARG APPNAME=InvokeAI
 WORKDIR ${APPDIR}
@@ -32,7 +32,7 @@ ENV PATH ${APPDIR}/${APPNAME}/bin:$PATH
 ENV PYTHONDONTWRITEBYTECODE 1
 # Turns off buffering for easier container logging
 ENV PYTHONUNBUFFERED 1
-# Don't fall back to legacy build system
+# don't fall back to legacy build system
 ENV PIP_USE_PEP517=1

 #######################
@@ -40,7 +40,7 @@ ENV PIP_USE_PEP517=1
 #######################
 FROM python-base AS pyproject-builder

-# Install build dependencies
+# Install dependencies
 RUN \
     --mount=type=cache,target=/var/cache/apt,sharing=locked \
     --mount=type=cache,target=/var/lib/apt,sharing=locked \
@@ -51,30 +51,26 @@ RUN \
     gcc=4:10.2.* \
     python3-dev=3.9.*

-# Prepare pip for buildkit cache
+# prepare pip for buildkit cache
 ARG PIP_CACHE_DIR=/var/cache/buildkit/pip
 ENV PIP_CACHE_DIR ${PIP_CACHE_DIR}
 RUN mkdir -p ${PIP_CACHE_DIR}

-# Create virtual environment
-RUN --mount=type=cache,target=${PIP_CACHE_DIR} \
+# create virtual environment
+RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
     python3 -m venv "${APPNAME}" \
         --upgrade-deps

-# Install requirements
-COPY --link pyproject.toml .
-COPY --link invokeai/version/invokeai_version.py invokeai/version/__init__.py invokeai/version/
+# copy sources
+COPY --link . .

+# install pyproject.toml
 ARG PIP_EXTRA_INDEX_URL
 ENV PIP_EXTRA_INDEX_URL ${PIP_EXTRA_INDEX_URL}
-RUN --mount=type=cache,target=${PIP_CACHE_DIR} \
-    "${APPNAME}"/bin/pip install .
-
-# Install pyproject.toml
-COPY --link . .
-RUN --mount=type=cache,target=${PIP_CACHE_DIR} \
+RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
     "${APPNAME}/bin/pip" install .

-# Build patchmatch
+# build patchmatch
 RUN python3 -c "from patchmatch import patch_match"
@@ -90,14 +86,14 @@ RUN useradd \
     -U \
     "${UNAME}"

-# Create volume directory
+# create volume directory
 ARG VOLUME_DIR=/data
 RUN mkdir -p "${VOLUME_DIR}" \
-    && chown -hR "${UNAME}:${UNAME}" "${VOLUME_DIR}"
+    && chown -R "${UNAME}" "${VOLUME_DIR}"

-# Setup runtime environment
-USER ${UNAME}:${UNAME}
-COPY --chown=${UNAME}:${UNAME} --from=pyproject-builder ${APPDIR}/${APPNAME} ${APPNAME}
+# setup runtime environment
+USER ${UNAME}
+COPY --chown=${UNAME} --from=pyproject-builder ${APPDIR}/${APPNAME} ${APPNAME}
 ENV INVOKEAI_ROOT ${VOLUME_DIR}
 ENV TRANSFORMERS_CACHE ${VOLUME_DIR}/.cache
 ENV INVOKE_MODEL_RECONFIGURE "--yes --default_only"
```
```diff
@@ -41,7 +41,7 @@ else
 fi

 # Build Container
-docker build \
+DOCKER_BUILDKIT=1 docker build \
   --platform="${PLATFORM:-linux/amd64}" \
   --tag="${CONTAINER_IMAGE:-invokeai}" \
   ${CONTAINER_FLAVOR:+--build-arg="CONTAINER_FLAVOR=${CONTAINER_FLAVOR}"} \
@@ -49,6 +49,3 @@ CONTAINER_FLAVOR="${CONTAINER_FLAVOR-cuda}"
 CONTAINER_TAG="${CONTAINER_TAG-"${INVOKEAI_BRANCH##*/}-${CONTAINER_FLAVOR}"}"
 CONTAINER_IMAGE="${CONTAINER_REGISTRY}/${CONTAINER_REPOSITORY}:${CONTAINER_TAG}"
 CONTAINER_IMAGE="${CONTAINER_IMAGE,,}"
-
-# enable docker buildkit
-export DOCKER_BUILDKIT=1
```
```diff
@@ -21,10 +21,10 @@ docker run \
   --tty \
   --rm \
   --platform="${PLATFORM}" \
-  --name="${REPOSITORY_NAME}" \
-  --hostname="${REPOSITORY_NAME}" \
-  --mount type=volume,volume-driver=local,source="${VOLUMENAME}",target=/data \
-  --mount type=bind,source="$(pwd)"/outputs/,target=/data/outputs/ \
+  --name="${REPOSITORY_NAME,,}" \
+  --hostname="${REPOSITORY_NAME,,}" \
+  --mount=source="${VOLUMENAME}",target=/data \
+  --mount type=bind,source="$(pwd)"/outputs,target=/data/outputs \
   ${MODELSPATH:+--mount="type=bind,source=${MODELSPATH},target=/data/models"} \
   ${HUGGING_FACE_HUB_TOKEN:+--env="HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN}"} \
   --publish=9090:9090 \
@@ -32,7 +32,7 @@ docker run \
   ${GPU_FLAGS:+--gpus="${GPU_FLAGS}"} \
   "${CONTAINER_IMAGE}" ${@:+$@}

-echo -e "\nCleaning trash folder ..."
+# Remove Trash folder
 for f in outputs/.Trash*; do
   if [ -e "$f" ]; then
     rm -Rf "$f"
```
````diff
@@ -1,93 +0,0 @@
-# Invoke.AI Architecture
-
-```mermaid
-flowchart TB
-
-    subgraph apps[Applications]
-        webui[WebUI]
-        cli[CLI]
-
-        subgraph webapi[Web API]
-            api[HTTP API]
-            sio[Socket.IO]
-        end
-
-    end
-
-    subgraph invoke[Invoke]
-        direction LR
-        invoker
-        services
-        sessions
-        invocations
-    end
-
-    subgraph core[AI Core]
-        Generate
-    end
-
-    webui --> webapi
-    webapi --> invoke
-    cli --> invoke
-
-    invoker --> services & sessions
-    invocations --> services
-    sessions --> invocations
-
-    services --> core
-
-    %% Styles
-    classDef sg fill:#5028C8,font-weight:bold,stroke-width:2,color:#fff,stroke:#14141A
-    classDef default stroke-width:2px,stroke:#F6B314,color:#fff,fill:#14141A
-
-    class apps,webapi,invoke,core sg
-
-```
-
-## Applications
-
-Applications are built on top of the invoke framework. They should construct `invoker` and then interact through it. They should avoid interacting directly with core code in order to support a variety of configurations.
-
-### Web UI
-
-The Web UI is built on top of an HTTP API built with [FastAPI](https://fastapi.tiangolo.com/) and [Socket.IO](https://socket.io/). The frontend code is found in `/frontend` and the backend code is found in `/ldm/invoke/app/api_app.py` and `/ldm/invoke/app/api/`. The code is further organized as such:
-
-| Component | Description |
-| --- | --- |
-| api_app.py | Sets up the API app, annotates the OpenAPI spec with additional data, and runs the API |
-| dependencies | Creates all invoker services and the invoker, and provides them to the API |
-| events | An eventing system that could in the future be adapted to support horizontal scale-out |
-| sockets | The Socket.IO interface - handles listening to and emitting session events (events are defined in the events service module) |
-| routers | API definitions for different areas of API functionality |
-
-### CLI
-
-The CLI is built automatically from invocation metadata, and also supports invocation piping and auto-linking. Code is available in `/ldm/invoke/app/cli_app.py`.
-
-## Invoke
-
-The Invoke framework provides the interface to the underlying AI systems and is built with flexibility and extensibility in mind. There are four major concepts: invoker, sessions, invocations, and services.
-
-### Invoker
-
-The invoker (`/ldm/invoke/app/services/invoker.py`) is the primary interface through which applications interact with the framework. Its primary purpose is to create, manage, and invoke sessions. It also maintains two sets of services:
-- **invocation services**, which are used by invocations to interact with core functionality.
-- **invoker services**, which are used by the invoker to manage sessions and manage the invocation queue.
-
-### Sessions
-
-Invocations and links between them form a graph, which is maintained in a session. Sessions can be queued for invocation, which will execute their graph (either the next ready invocation, or all invocations). Sessions also maintain execution history for the graph (including storage of any outputs). An invocation may be added to a session at any time, and there is capability to add and entire graph at once, as well as to automatically link new invocations to previous invocations. Invocations can not be deleted or modified once added.
-
-The session graph does not support looping. This is left as an application problem to prevent additional complexity in the graph.
-
-### Invocations
-
-Invocations represent individual units of execution, with inputs and outputs. All invocations are located in `/ldm/invoke/app/invocations`, and are all automatically discovered and made available in the applications. These are the primary way to expose new functionality in Invoke.AI, and the [implementation guide](INVOCATIONS.md) explains how to add new invocations.
-
-### Services
-
-Services provide invocations access AI Core functionality and other necessary functionality (e.g. image storage). These are available in `/ldm/invoke/app/services`. As a general rule, new services should provide an interface as an abstract base class, and may provide a lightweight local implementation by default in their module. The goal for all services should be to enable the usage of different implementations (e.g. using cloud storage for image storage), but should not load any module dependencies unless that implementation has been used (i.e. don't import anything that won't be used, especially if it's expensive to import).
-
-## AI Core
-
-The AI Core is represented by the rest of the code base (i.e. the code outside of `/ldm/invoke/app/`).
````
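The deleted architecture notes say applications should "construct `invoker` and then interact through it." A minimal sketch of that flow, pieced together from this page: the `Invoker(services)` construction and `invoker.stop()` call appear verbatim in the `ApiDependencies` code further down, while the `create_execution_state`/`invoke` method names are assumptions for illustration, not confirmed API:

```py
# Illustrative only: drive one invocation through the framework the way an
# application (CLI or web API) would. Assumes `services` was assembled as in
# ApiDependencies.initialize() below.
from ldm.invoke.app.services.invoker import Invoker

def run_single(services, invocation):
    invoker = Invoker(services)                 # wrap the service container (as in dependencies.py)
    session = invoker.create_execution_state()  # assumed: new session (graph + execution history)
    session.graph.add_node(invocation)          # assumed: append one invocation node to the graph
    invoker.invoke(session, invoke_all=True)    # assumed: queue the whole graph for execution
    invoker.stop()                              # shut down, as ApiDependencies.shutdown() does
```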
````diff
@@ -1,105 +0,0 @@
-# Invocations
-
-Invocations represent a single operation, its inputs, and its outputs. These operations and their outputs can be chained together to generate and modify images.
-
-## Creating a new invocation
-
-To create a new invocation, either find the appropriate module file in `/ldm/invoke/app/invocations` to add your invocation to, or create a new one in that folder. All invocations in that folder will be discovered and made available to the CLI and API automatically. Invocations make use of [typing](https://docs.python.org/3/library/typing.html) and [pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration into the CLI and API.
-
-An invocation looks like this:
-
-```py
-class UpscaleInvocation(BaseInvocation):
-    """Upscales an image."""
-    type: Literal['upscale'] = 'upscale'
-
-    # Inputs
-    image: Union[ImageField,None] = Field(description="The input image")
-    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
-    level: Literal[2,4] = Field(default=2, description = "The upscale level")
-
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get(self.image.image_type, self.image.image_name)
-        results = context.services.generate.upscale_and_reconstruct(
-            image_list = [[image, 0]],
-            upscale = (self.level, self.strength),
-            strength = 0.0, # GFPGAN strength
-            save_original = False,
-            image_callback = None,
-        )
-
-        # Results are image and seed, unwrap for now
-        # TODO: can this return multiple results?
-        image_type = ImageType.RESULT
-        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
-        context.services.images.save(image_type, image_name, results[0][0])
-        return ImageOutput(
-            image = ImageField(image_type = image_type, image_name = image_name)
-        )
-```
-
-Each portion is important to implement correctly.
-
-### Class definition and type
-```py
-class UpscaleInvocation(BaseInvocation):
-    """Upscales an image."""
-    type: Literal['upscale'] = 'upscale'
-```
-All invocations must derive from `BaseInvocation`. They should have a docstring that declares what they do in a single, short line. They should also have a `type` with a type hint that's `Literal["command_name"]`, where `command_name` is what the user will type on the CLI or use in the API to create this invocation. The `command_name` must be unique. The `type` must be assigned to the value of the literal in the type hint.
-
-### Inputs
-```py
-    # Inputs
-    image: Union[ImageField,None] = Field(description="The input image")
-    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
-    level: Literal[2,4] = Field(default=2, description="The upscale level")
-```
-Inputs consist of three parts: a name, a type hint, and a `Field` with default, description, and validation information. For example:
-| Part | Value | Description |
-| ---- | ----- | ----------- |
-| Name | `strength` | This field is referred to as `strength` |
-| Type Hint | `float` | This field must be of type `float` |
-| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
-
-Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this field to be parsed with `None` as a value, which enables linking to previous invocations. All fields should either provide a default value or allow `None` as a value, so that they can be overwritten with a linked output from another invocation.
-
-The special type `ImageField` is also used here. All images are passed as `ImageField`, which protects them from pydantic validation errors (since images only ever come from links).
-
-Finally, note that for all linking, the `type` of the linked fields must match. If the `name` also matches, then the field can be **automatically linked** to a previous invocation by name and matching.
-
-### Invoke Function
-```py
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get(self.image.image_type, self.image.image_name)
-        results = context.services.generate.upscale_and_reconstruct(
-            image_list = [[image, 0]],
-            upscale = (self.level, self.strength),
-            strength = 0.0, # GFPGAN strength
-            save_original = False,
-            image_callback = None,
-        )
-
-        # Results are image and seed, unwrap for now
-        image_type = ImageType.RESULT
-        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
-        context.services.images.save(image_type, image_name, results[0][0])
-        return ImageOutput(
-            image = ImageField(image_type = image_type, image_name = image_name)
-        )
-```
-The `invoke` function is the last portion of an invocation. It is provided an `InvocationContext` which contains services to perform work as well as a `session_id` for use as needed. It should return a class with output values that derives from `BaseInvocationOutput`.
-
-Before being called, the invocation will have all of its fields set from defaults, inputs, and finally links (overriding in that order).
-
-Assume that this invocation may be running simultaneously with other invocations, may be running on another machine, or in other interesting scenarios. If you need functionality, please provide it as a service in the `InvocationServices` class, and make sure it can be overridden.
-
-### Outputs
-```py
-class ImageOutput(BaseInvocationOutput):
-    """Base class for invocations that output an image"""
-    type: Literal['image'] = 'image'
-
-    image: ImageField = Field(default=None, description="The output image")
-```
-Output classes look like an invocation class without the invoke method. Prefer to use an existing output class if available, and prefer to name inputs the same as outputs when possible, to promote automatic invocation linking.
````
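Following the conventions this deleted doc describes, a new invocation is just another `BaseInvocation` subclass dropped into the invocations folder. A hypothetical minimal example — the class itself is invented, but every API it touches (`ImageField`, `ImageOutput`, `context.services.images`) is used exactly as in the doc's own `UpscaleInvocation`:

```py
from typing import Literal, Union
from pydantic import Field
from PIL import ImageOps

class InvertInvocation(BaseInvocation):
    """Inverts an image."""
    type: Literal['invert'] = 'invert'

    # Inputs: a None-able ImageField so the value can arrive via a link
    image: Union[ImageField, None] = Field(description="The input image")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(self.image.image_type, self.image.image_name)
        result = ImageOps.invert(image.convert("RGB"))  # the actual work

        # Save under a session-scoped name and return the standard image output
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, result)
        return ImageOutput(image=ImageField(image_type=image_type, image_name=image_name))
```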
````diff
@@ -148,7 +148,7 @@ manager, please follow these steps:
 === "CUDA (NVidia)"

     ```bash
-    pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
+    pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
     ```

 === "ROCm (AMD)"
````
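Incidentally, the quotes being dropped from `"InvokeAI[xformers]"` are not cosmetic: in zsh, unquoted square brackets are glob characters, so the unquoted form can abort with `no matches found: InvokeAI[xformers]`. The quoted spelling works in both bash and zsh:

```bash
pip install InvokeAI[xformers]     # bash: ok; zsh: may fail with "no matches found"
pip install "InvokeAI[xformers]"   # portable across shells
```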
```diff
@@ -11,10 +11,10 @@ if [[ -v "VIRTUAL_ENV" ]]; then
     exit -1
 fi

-VERSION=$(cd ..; python -c "from invokeai.version import __version__ as version; print(version)")
+VERSION=$(cd ..; python -c "from ldm.invoke import __version__ as version; print(version)")
 PATCH=""
 VERSION="v${VERSION}${PATCH}"
-LATEST_TAG="v3.0-latest"
+LATEST_TAG="v2.3-latest"

 echo Building installer for version $VERSION
 echo "Be certain that you're in the 'installer' directory before continuing."
```
```diff
@@ -291,7 +291,7 @@ class InvokeAiInstance:
             src = Path(__file__).parents[1].expanduser().resolve()
             # if the above directory contains one of these files, we'll do a source install
             next(src.glob("pyproject.toml"))
-            next(src.glob("invokeai"))
+            next(src.glob("ldm"))
         except StopIteration:
             print("Unable to find a wheel or perform a source install. Giving up.")
@@ -342,14 +342,14 @@ class InvokeAiInstance:

         introduction()

-        from invokeai.frontend.install import invokeai_configure
+        from ldm.invoke.config import invokeai_configure

         # NOTE: currently the config script does its own arg parsing! this means the command-line switches
         # from the installer will also automatically propagate down to the config script.
         # this may change in the future with config refactoring!
         succeeded = False
         try:
-            invokeai_configure()
+            invokeai_configure.main()
             succeeded = True
         except requests.exceptions.ConnectionError as e:
             print(f'\nA network error was encountered during configuration and download: {str(e)}')
```
```diff
@@ -1,11 +1,3 @@
-Organization of the source tree:
-
-app -- Home of nodes invocations and services
-assets -- Images and other data files used by InvokeAI
-backend -- Non-user facing libraries, including the rendering
-    core.
-configs -- Configuration files used at install and run times
-frontend -- User-facing scripts, including the CLI and the WebUI
-version -- Current InvokeAI version string, stored
-    in version/invokeai_version.py
-
+After version 2.3 is released, the ldm/invoke modules will be migrated to this location
+so that we have a proper invokeai distribution. Currently it is only being used for
+data files.
```
```diff
@@ -1,79 +0,0 @@
-# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
-
-import os
-from argparse import Namespace
-
-from ...backend import Globals
-from ..services.model_manager_initializer import get_model_manager
-from ..services.restoration_services import RestorationServices
-from ..services.graph import GraphExecutionState
-from ..services.image_storage import DiskImageStorage
-from ..services.invocation_queue import MemoryInvocationQueue
-from ..services.invocation_services import InvocationServices
-from ..services.invoker import Invoker
-from ..services.processor import DefaultInvocationProcessor
-from ..services.sqlite import SqliteItemStorage
-from .events import FastAPIEventService
-
-
-# TODO: is there a better way to achieve this?
-def check_internet() -> bool:
-    """
-    Return true if the internet is reachable.
-    It does this by pinging huggingface.co.
-    """
-    import urllib.request
-
-    host = "http://huggingface.co"
-    try:
-        urllib.request.urlopen(host, timeout=1)
-        return True
-    except:
-        return False
-
-
-class ApiDependencies:
-    """Contains and initializes all dependencies for the API"""
-
-    invoker: Invoker = None
-
-    @staticmethod
-    def initialize(config, event_handler_id: int):
-        Globals.try_patchmatch = config.patchmatch
-        Globals.always_use_cpu = config.always_use_cpu
-        Globals.internet_available = config.internet_available and check_internet()
-        Globals.disable_xformers = not config.xformers
-        Globals.ckpt_convert = config.ckpt_convert
-
-        # TODO: Use a logger
-        print(f">> Internet connectivity is {Globals.internet_available}")
-
-        events = FastAPIEventService(event_handler_id)
-
-        output_folder = os.path.abspath(
-            os.path.join(os.path.dirname(__file__), "../../../../outputs")
-        )
-
-        images = DiskImageStorage(output_folder)
-
-        # TODO: build a file/path manager?
-        db_location = os.path.join(output_folder, "invokeai.db")
-
-        services = InvocationServices(
-            model_manager=get_model_manager(config),
-            events=events,
-            images=images,
-            queue=MemoryInvocationQueue(),
-            graph_execution_manager=SqliteItemStorage[GraphExecutionState](
-                filename=db_location, table_name="graph_executions"
-            ),
-            processor=DefaultInvocationProcessor(),
-            restoration=RestorationServices(config),
-        )
-
-        ApiDependencies.invoker = Invoker(services)
-
-    @staticmethod
-    def shutdown():
-        if ApiDependencies.invoker:
-            ApiDependencies.invoker.stop()
```
@@ -1,52 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import asyncio
import threading
from queue import Empty, Queue
from typing import Any

from fastapi_events.dispatcher import dispatch

from ..services.events import EventServiceBase


class FastAPIEventService(EventServiceBase):
    event_handler_id: int
    __queue: Queue
    __stop_event: threading.Event

    def __init__(self, event_handler_id: int) -> None:
        self.event_handler_id = event_handler_id
        self.__queue = Queue()
        self.__stop_event = threading.Event()
        asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event))

        super().__init__()

    def stop(self, *args, **kwargs):
        self.__stop_event.set()
        self.__queue.put(None)

    def dispatch(self, event_name: str, payload: Any) -> None:
        self.__queue.put(dict(event_name=event_name, payload=payload))

    async def __dispatch_from_queue(self, stop_event: threading.Event):
        """Get events from the queue and dispatch them, from the correct thread"""
        while not stop_event.is_set():
            try:
                event = self.__queue.get(block=False)
                if not event:  # Probably stopping
                    continue

                dispatch(
                    event.get("event_name"),
                    payload=event.get("payload"),
                    middleware_id=self.event_handler_id,
                )

            except Empty:
                await asyncio.sleep(0.001)

            except asyncio.CancelledError as e:
                raise e  # Raise a proper error
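The Queue plus asyncio-task pairing exists because dispatch() may be called from the invocation worker thread, while fastapi_events must dispatch on the event loop. A rough usage sketch, assuming construction happens inside the running loop (as it does via ApiDependencies.initialize at startup) and an `app` object is in scope:

svc = FastAPIEventService(event_handler_id=id(app))
svc.dispatch("graph_execution_state", {"step": 1})  # safe from any thread
svc.stop()  # unblocks the polling task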
@@ -1,56 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import io
from datetime import datetime, timezone

from fastapi import Path, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image

from ...services.image_storage import ImageType
from ..dependencies import ApiDependencies

images_router = APIRouter(prefix="/v1/images", tags=["images"])


@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
    image_type: ImageType = Path(description="The type of image to get"),
    image_name: str = Path(description="The name of the image to get"),
):
    """Gets a result"""
    # TODO: This is not really secure at all. At least make sure only output results are served
    filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
    return FileResponse(filename)


@images_router.post(
    "/uploads/",
    operation_id="upload_image",
    responses={
        201: {"description": "The image was uploaded successfully"},
        404: {"description": "Session not found"},
    },
)
async def upload_image(file: UploadFile, request: Request):
    if not file.content_type.startswith("image"):
        return Response(status_code=415)

    contents = await file.read()
    try:
        # Image.open() needs a file-like object, not raw bytes
        im = Image.open(io.BytesIO(contents))
    except Exception:
        # Error opening the image
        return Response(status_code=415)

    filename = f"{str(int(datetime.now(timezone.utc).timestamp()))}.png"
    ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)

    return Response(
        status_code=201,
        headers={
            "Location": request.url_for(
                "get_image", image_type=ImageType.UPLOAD, image_name=filename
            )
        },
    )
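A rough client-side sketch of the upload route (httpx assumed; host, port, and the /api prefix are taken from invoke_api() and the router includes below):

import httpx

with open("input.png", "rb") as f:
    r = httpx.post(
        "http://localhost:9090/api/v1/images/uploads/",
        files={"file": ("input.png", f, "image/png")},
    )
print(r.status_code, r.headers.get("location"))  # 201 plus a get_image URL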
@@ -1,287 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Annotated, List, Optional, Union

from fastapi import Body, Path, Query
from fastapi.responses import Response
from fastapi.routing import APIRouter
from pydantic.fields import Field

from ...invocations import *
from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import (
    Edge,
    EdgeConnection,
    Graph,
    GraphExecutionState,
    NodeAlreadyExecutedError,
)
from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies

session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])


@session_router.post(
    "/",
    operation_id="create_session",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid json"},
    },
)
async def create_session(
    graph: Optional[Graph] = Body(
        default=None, description="The graph to initialize the session with"
    )
) -> GraphExecutionState:
    """Creates a new session, optionally initializing it with an invocation graph"""
    session = ApiDependencies.invoker.create_execution_state(graph)
    return session


@session_router.get(
    "/",
    operation_id="list_sessions",
    responses={200: {"model": PaginatedResults[GraphExecutionState]}},
)
async def list_sessions(
    page: int = Query(default=0, description="The page of results to get"),
    per_page: int = Query(default=10, description="The number of results per page"),
    query: str = Query(default="", description="The query string to search for"),
) -> PaginatedResults[GraphExecutionState]:
    """Gets a list of sessions, optionally searching"""
    if query == "":
        result = ApiDependencies.invoker.services.graph_execution_manager.list(
            page, per_page
        )
    else:
        result = ApiDependencies.invoker.services.graph_execution_manager.search(
            query, page, per_page
        )
    return result


@session_router.get(
    "/{session_id}",
    operation_id="get_session",
    responses={
        200: {"model": GraphExecutionState},
        404: {"description": "Session not found"},
    },
)
async def get_session(
    session_id: str = Path(description="The id of the session to get"),
) -> GraphExecutionState:
    """Gets a session"""
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        return Response(status_code=404)
    else:
        return session


@session_router.post(
    "/{session_id}/nodes",
    operation_id="add_node",
    responses={
        200: {"model": str},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def add_node(
    session_id: str = Path(description="The id of the session"),
    node: Annotated[
        Union[BaseInvocation.get_invocations()], Field(discriminator="type")  # type: ignore
    ] = Body(description="The node to add"),
) -> str:
    """Adds a node to the graph"""
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        return Response(status_code=404)

    try:
        session.add_node(node)
        ApiDependencies.invoker.services.graph_execution_manager.set(
            session
        )  # TODO: can this be done automatically, or add node through an API?
        return session.id
    except NodeAlreadyExecutedError:
        return Response(status_code=400)
    except IndexError:
        return Response(status_code=400)


@session_router.put(
    "/{session_id}/nodes/{node_path}",
    operation_id="update_node",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def update_node(
    session_id: str = Path(description="The id of the session"),
    node_path: str = Path(description="The path to the node in the graph"),
    node: Annotated[
        Union[BaseInvocation.get_invocations()], Field(discriminator="type")  # type: ignore
    ] = Body(description="The new node"),
) -> GraphExecutionState:
    """Updates a node in the graph and removes all linked edges"""
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        return Response(status_code=404)

    try:
        session.update_node(node_path, node)
        ApiDependencies.invoker.services.graph_execution_manager.set(
            session
        )  # TODO: can this be done automatically, or add node through an API?
        return session
    except NodeAlreadyExecutedError:
        return Response(status_code=400)
    except IndexError:
        return Response(status_code=400)


@session_router.delete(
    "/{session_id}/nodes/{node_path}",
    operation_id="delete_node",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def delete_node(
    session_id: str = Path(description="The id of the session"),
    node_path: str = Path(description="The path to the node to delete"),
) -> GraphExecutionState:
    """Deletes a node in the graph and removes all linked edges"""
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        return Response(status_code=404)

    try:
        session.delete_node(node_path)
        ApiDependencies.invoker.services.graph_execution_manager.set(
            session
        )  # TODO: can this be done automatically, or add node through an API?
        return session
    except NodeAlreadyExecutedError:
        return Response(status_code=400)
    except IndexError:
        return Response(status_code=400)


@session_router.post(
    "/{session_id}/edges",
    operation_id="add_edge",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def add_edge(
    session_id: str = Path(description="The id of the session"),
    edge: Edge = Body(description="The edge to add"),
) -> GraphExecutionState:
    """Adds an edge to the graph"""
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        return Response(status_code=404)

    try:
        session.add_edge(edge)
        ApiDependencies.invoker.services.graph_execution_manager.set(
            session
        )  # TODO: can this be done automatically, or add node through an API?
        return session
    except NodeAlreadyExecutedError:
        return Response(status_code=400)
    except IndexError:
        return Response(status_code=400)


# TODO: the edge being in the path here is really ugly, find a better solution
@session_router.delete(
    "/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
    operation_id="delete_edge",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def delete_edge(
    session_id: str = Path(description="The id of the session"),
    from_node_id: str = Path(description="The id of the node the edge is coming from"),
    from_field: str = Path(description="The field of the node the edge is coming from"),
    to_node_id: str = Path(description="The id of the node the edge is going to"),
    to_field: str = Path(description="The field of the node the edge is going to"),
) -> GraphExecutionState:
    """Deletes an edge from the graph"""
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        return Response(status_code=404)

    try:
        edge = Edge(
            source=EdgeConnection(node_id=from_node_id, field=from_field),
            destination=EdgeConnection(node_id=to_node_id, field=to_field)
        )
        session.delete_edge(edge)
        ApiDependencies.invoker.services.graph_execution_manager.set(
            session
        )  # TODO: can this be done automatically, or add node through an API?
        return session
    except NodeAlreadyExecutedError:
        return Response(status_code=400)
    except IndexError:
        return Response(status_code=400)


@session_router.put(
    "/{session_id}/invoke",
    operation_id="invoke_session",
    responses={
        200: {"model": None},
        202: {"description": "The invocation is queued"},
        400: {"description": "The session has no invocations ready to invoke"},
        404: {"description": "Session not found"},
    },
)
async def invoke_session(
    session_id: str = Path(description="The id of the session to invoke"),
    all: bool = Query(
        default=False, description="Whether or not to invoke all remaining invocations"
    ),
) -> None:
    """Invokes a session"""
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        return Response(status_code=404)

    if session.is_complete():
        return Response(status_code=400)

    ApiDependencies.invoker.invoke(session, invoke_all=all)
    return Response(status_code=202)


@session_router.delete(
    "/{session_id}/invoke",
    operation_id="cancel_session_invoke",
    responses={
        202: {"description": "The invocation is canceled"}
    },
)
async def cancel_session_invoke(
    session_id: str = Path(description="The id of the session to cancel"),
) -> None:
    """Cancels a session"""
    ApiDependencies.invoker.cancel(session_id)
    return Response(status_code=202)
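A rough httpx sketch of the workflow these routes define: create a session, add a node, invoke it. Node fields are taken from the TextToImageInvocation defined further down; a running server on the default port is assumed:

import httpx

base = "http://localhost:9090/api/v1/sessions"
session = httpx.post(f"{base}/").json()          # empty body -> empty session
httpx.post(
    f"{base}/{session['id']}/nodes",
    json={"id": "0", "type": "txt2img", "prompt": "a cat"},
)
httpx.put(f"{base}/{session['id']}/invoke", params={"all": True})  # 202 Accepted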
@@ -1,38 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from fastapi import FastAPI
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from fastapi_socketio import SocketManager

from ..services.events import EventServiceBase


class SocketIO:
    __sio: SocketManager

    def __init__(self, app: FastAPI):
        self.__sio = SocketManager(app=app)
        self.__sio.on("subscribe", handler=self._handle_sub)
        self.__sio.on("unsubscribe", handler=self._handle_unsub)

        local_handler.register(
            event_name=EventServiceBase.session_event, _func=self._handle_session_event
        )

    async def _handle_session_event(self, event: Event):
        await self.__sio.emit(
            event=event[1]["event"],
            data=event[1]["data"],
            room=event[1]["data"]["graph_execution_state_id"],
        )

    async def _handle_sub(self, sid, data, *args, **kwargs):
        if "session" in data:
            self.__sio.enter_room(sid, data["session"])

    async def _handle_unsub(self, sid, data, *args, **kwargs):
        if "session" in data:
            self.__sio.leave_room(sid, data["session"])
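A hypothetical client-side counterpart to the subscribe handshake above, assuming the python-socketio client package. The exact mount path depends on SocketManager's defaults, so treat this strictly as a sketch:

import socketio

sio = socketio.Client()
sio.connect("http://localhost:9090", socketio_path="/ws/socket.io")  # path assumed
sio.emit("subscribe", {"session": graph_execution_state_id})  # join the session room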
@@ -1,158 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
from inspect import signature

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.staticfiles import StaticFiles
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema

from ..backend import Args
from .api.dependencies import ApiDependencies
from .api.routers import images, sessions
from .api.sockets import SocketIO
from .invocations import *
from .invocations.baseinvocation import BaseInvocation

# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)

# Add event handler
event_handler_id: int = id(app)
app.add_middleware(
    EventHandlerASGIMiddleware,
    handlers=[
        local_handler
    ],  # TODO: consider doing this in services to support different configurations
    middleware_id=event_handler_id,
)

# Add CORS
# TODO: use configuration for this
origins = []
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

socket_io = SocketIO(app)

config = {}


# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event():
    config = Args()
    config.parse_args()

    ApiDependencies.initialize(
        config=config, event_handler_id=event_handler_id
    )


# Shut down threads
@app.on_event("shutdown")
async def shutdown_event():
    ApiDependencies.shutdown()


# Include all routers
# TODO: REMOVE
# app.include_router(
#     invocation.invocation_router,
#     prefix = '/api')

app.include_router(sessions.session_router, prefix="/api")

app.include_router(images.images_router, prefix="/api")


# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi():
    if app.openapi_schema:
        return app.openapi_schema
    openapi_schema = get_openapi(
        title=app.title,
        description="An API for invoking AI image operations",
        version="1.0.0",
        routes=app.routes,
    )

    # Add all outputs
    all_invocations = BaseInvocation.get_invocations()
    output_types = set()
    output_type_titles = dict()
    for invoker in all_invocations:
        output_type = signature(invoker.invoke).return_annotation
        output_types.add(output_type)

    output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
    for schema_key, output_schema in output_schemas["definitions"].items():
        openapi_schema["components"]["schemas"][schema_key] = output_schema

        # TODO: note that we assume the schema_key here is the TYPE.__name__
        # This could break in some cases, figure out a better way to do it
        output_type_titles[schema_key] = output_schema["title"]

    # Add a reference to the output type to additionalProperties of the invoker schema
    for invoker in all_invocations:
        invoker_name = invoker.__name__
        output_type = signature(invoker.invoke).return_annotation
        output_type_title = output_type_titles[output_type.__name__]
        invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
        outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}

        invoker_schema["output"] = outputs_ref

    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi

# Override API doc favicons
app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")


@app.get("/docs", include_in_schema=False)
def overridden_swagger():
    return get_swagger_ui_html(
        openapi_url=app.openapi_url,
        title=app.title,
        swagger_favicon_url="/static/favicon.ico",
    )


@app.get("/redoc", include_in_schema=False)
def overridden_redoc():
    return get_redoc_html(
        openapi_url=app.openapi_url,
        title=app.title,
        redoc_favicon_url="/static/favicon.ico",
    )


def invoke_api():
    # Start our own event loop for eventing usage
    # TODO: determine if there's a better way to do this
    loop = asyncio.new_event_loop()
    config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop)
    # Pass access_log=False to the Config above to turn off request logging

    server = uvicorn.Server(config)
    loop.run_until_complete(server.serve())


if __name__ == "__main__":
    invoke_api()
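Since custom_openapi() writes the output refs into the schema, they can be inspected from the standard /openapi.json endpoint; a quick sketch against a running server:

import httpx

schema = httpx.get("http://localhost:9090/openapi.json").json()
print(schema["components"]["schemas"]["TextToImageInvocation"]["output"])
# -> {'$ref': '#/components/schemas/ImageOutput'}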
@@ -1,202 +0,0 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

import argparse
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, Literal, Union, get_args, get_origin, get_type_hints

from pydantic import BaseModel, Field

from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState
from ..services.invoker import Invoker


def add_parsers(
    subparsers,
    commands: list[type],
    command_field: str = "type",
    exclude_fields: list[str] = ["id", "type"],
    add_arguments: Callable[[argparse.ArgumentParser], None] | None = None
):
    """Adds parsers for each command to the subparsers"""

    # Create subparsers for each command
    for command in commands:
        hints = get_type_hints(command)
        cmd_name = get_args(hints[command_field])[0]
        command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)

        if add_arguments is not None:
            add_arguments(command_parser)

        # Convert all fields to arguments
        fields = command.__fields__  # type: ignore
        for name, field in fields.items():
            if name in exclude_fields:
                continue

            if get_origin(field.type_) == Literal:
                allowed_values = get_args(field.type_)
                allowed_types = set()
                for val in allowed_values:
                    allowed_types.add(type(val))
                allowed_types_list = list(allowed_types)
                field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list]  # type: ignore

                command_parser.add_argument(
                    f"--{name}",
                    dest=name,
                    type=field_type,
                    default=field.default,
                    choices=allowed_values,
                    help=field.field_info.description,
                )
            else:
                command_parser.add_argument(
                    f"--{name}",
                    dest=name,
                    type=field.type_,
                    default=field.default,
                    help=field.field_info.description,
                )


class CliContext:
    invoker: Invoker
    session: GraphExecutionState
    parser: argparse.ArgumentParser
    defaults: dict[str, Any]

    def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
        self.invoker = invoker
        self.session = session
        self.parser = parser
        self.defaults = dict()

    def get_session(self):
        self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
        return self.session


class ExitCli(Exception):
    """Exception to exit the CLI"""
    pass


class BaseCommand(ABC, BaseModel):
    """A CLI command"""

    # All commands must include a type name like this:
    # type: Literal['your_command_name'] = 'your_command_name'

    @classmethod
    def get_all_subclasses(cls):
        subclasses = []
        toprocess = [cls]
        while len(toprocess) > 0:
            next = toprocess.pop(0)
            next_subclasses = next.__subclasses__()
            subclasses.extend(next_subclasses)
            toprocess.extend(next_subclasses)
        return subclasses

    @classmethod
    def get_commands(cls):
        return tuple(BaseCommand.get_all_subclasses())

    @classmethod
    def get_commands_map(cls):
        # Get the type strings out of the literals and into a dictionary
        return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t), BaseCommand.get_all_subclasses()))

    @abstractmethod
    def run(self, context: CliContext) -> None:
        """Run the command. Raise ExitCli to exit."""
        pass


class ExitCommand(BaseCommand):
    """Exits the CLI"""
    type: Literal['exit'] = 'exit'

    def run(self, context: CliContext) -> None:
        raise ExitCli()


class HelpCommand(BaseCommand):
    """Shows help"""
    type: Literal['help'] = 'help'

    def run(self, context: CliContext) -> None:
        context.parser.print_help()


def get_graph_execution_history(
    graph_execution_state: GraphExecutionState,
) -> Iterable[str]:
    """Gets the history of fully-executed invocations for a graph execution"""
    return (
        n
        for n in reversed(graph_execution_state.executed_history)
        if n in graph_execution_state.graph.nodes
    )


def get_invocation_command(invocation) -> str:
    fields = invocation.__fields__.items()
    type_hints = get_type_hints(type(invocation))
    command = [invocation.type]
    for name, field in fields:
        if name in ["id", "type"]:
            continue

        # TODO: add links

        # Skip image fields when serializing command
        type_hint = type_hints.get(name) or None
        if type_hint is ImageField or ImageField in get_args(type_hint):
            continue

        field_value = getattr(invocation, name)
        field_default = field.default
        if field_value != field_default:
            if type_hint is str or str in get_args(type_hint):
                command.append(f'--{name} "{field_value}"')
            else:
                command.append(f"--{name} {field_value}")

    return " ".join(command)


class HistoryCommand(BaseCommand):
    """Shows the invocation history"""
    type: Literal['history'] = 'history'

    # Inputs
    # fmt: off
    count: int = Field(default=5, gt=0, description="The number of history entries to show")
    # fmt: on

    def run(self, context: CliContext) -> None:
        history = list(get_graph_execution_history(context.get_session()))
        for i in range(min(self.count, len(history))):
            entry_id = history[-1 - i]
            entry = context.get_session().graph.get_node(entry_id)
            print(f"{entry_id}: {get_invocation_command(entry)}")


class SetDefaultCommand(BaseCommand):
    """Sets a default value for a field"""
    type: Literal['default'] = 'default'

    # Inputs
    # fmt: off
    field: str = Field(description="The field to set the default for")
    value: str = Field(description="The value to set the default to, or None to clear the default")
    # fmt: on

    def run(self, context: CliContext) -> None:
        if self.value is None:
            if self.field in context.defaults:
                del context.defaults[self.field]
        else:
            context.defaults[self.field] = self.value
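A small sketch of what add_parsers() generates, using the HistoryCommand defined above (the count field becomes an --count flag typed int via field.type_):

import argparse

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="type")
add_parsers(subparsers, [HistoryCommand])

args = parser.parse_args(["history", "--count", "3"])
# -> Namespace(type='history', count=3)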
@@ -1,275 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import argparse
import os
import shlex
import time
from typing import (
    Union,
    get_type_hints,
)

from pydantic import BaseModel
from pydantic.fields import Field

from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, GraphExecutionState
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage


class CliCommand(BaseModel):
    command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type")  # type: ignore


class InvalidArgs(Exception):
    pass


def add_invocation_args(command_parser):
    # Add linking capability
    command_parser.add_argument(
        "--link",
        "-l",
        action="append",
        nargs=3,
        help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
    )

    command_parser.add_argument(
        "--link_node",
        "-ln",
        action="append",
        help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
    )


def get_command_parser() -> argparse.ArgumentParser:
    # Create invocation parser
    parser = argparse.ArgumentParser()

    def exit(*args, **kwargs):
        raise InvalidArgs

    parser.exit = exit
    subparsers = parser.add_subparsers(dest="type")

    # Create subparsers for each invocation
    invocations = BaseInvocation.get_all_subclasses()
    add_parsers(subparsers, invocations, add_arguments=add_invocation_args)

    # Create subparsers for each command
    commands = BaseCommand.get_all_subclasses()
    add_parsers(subparsers, commands, exclude_fields=["type"])

    return parser


def generate_matching_edges(
    a: BaseInvocation, b: BaseInvocation
) -> list[Edge]:
    """Generates all possible edges between two invocations"""
    atype = type(a)
    btype = type(b)

    aoutputtype = atype.get_output_type()

    afields = get_type_hints(aoutputtype)
    bfields = get_type_hints(btype)

    matching_fields = set(afields.keys()).intersection(bfields.keys())

    # Remove invalid fields
    invalid_fields = set(["type", "id"])
    matching_fields = matching_fields.difference(invalid_fields)

    edges = [
        Edge(
            source=EdgeConnection(node_id=a.id, field=field),
            destination=EdgeConnection(node_id=b.id, field=field)
        )
        for field in matching_fields
    ]
    return edges


class SessionError(Exception):
    """Raised when a session error has occurred"""
    pass


def invoke_all(context: CliContext):
    """Runs all invocations in the specified session"""
    context.invoker.invoke(context.session, invoke_all=True)
    while not context.get_session().is_complete():
        # Wait some time
        time.sleep(0.1)

    # Print any errors
    if context.session.has_error():
        for n in context.session.errors:
            print(
                f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
            )

        raise SessionError()


def invoke_cli():
    config = Args()
    config.parse_args()
    model_manager = get_model_manager(config)

    events = EventServiceBase()

    output_folder = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "../../../outputs")
    )

    # TODO: build a file/path manager?
    db_location = os.path.join(output_folder, "invokeai.db")

    services = InvocationServices(
        model_manager=model_manager,
        events=events,
        images=DiskImageStorage(output_folder),
        queue=MemoryInvocationQueue(),
        graph_execution_manager=SqliteItemStorage[GraphExecutionState](
            filename=db_location, table_name="graph_executions"
        ),
        processor=DefaultInvocationProcessor(),
        restoration=RestorationServices(config),
    )

    invoker = Invoker(services)
    session: GraphExecutionState = invoker.create_execution_state()
    parser = get_command_parser()

    # Uncomment to print out previous sessions at startup
    # print(services.session_manager.list())

    context = CliContext(invoker, session, parser)

    while True:
        try:
            cmd_input = input("> ")
        except KeyboardInterrupt:
            # Ctrl-c exits
            break

        try:
            # Refresh the state of the session
            history = list(get_graph_execution_history(context.session))

            # Split the command for piping
            cmds = cmd_input.split("|")
            start_id = len(history)
            current_id = start_id
            new_invocations = list()
            for cmd in cmds:
                if cmd is None or cmd.strip() == "":
                    raise InvalidArgs("Empty command")

                # Parse args to create invocation
                args = vars(context.parser.parse_args(shlex.split(cmd.strip())))

                # Override defaults
                for field_name, field_default in context.defaults.items():
                    if field_name in args:
                        args[field_name] = field_default

                # Parse invocation
                args["id"] = current_id
                command = CliCommand(command=args)

                # Run any CLI commands immediately
                if isinstance(command.command, BaseCommand):
                    # Invoke all current nodes to preserve operation order
                    invoke_all(context)

                    # Run the command
                    command.command.run(context)
                    continue

                # Pipe previous command output (if there was a previous command)
                edges: list[Edge] = list()
                if len(history) > 0 or current_id != start_id:
                    from_id = (
                        history[0] if current_id == start_id else str(current_id - 1)
                    )
                    from_node = (
                        next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
                        if current_id != start_id
                        else context.session.graph.get_node(from_id)
                    )
                    matching_edges = generate_matching_edges(
                        from_node, command.command
                    )
                    edges.extend(matching_edges)

                # Parse provided links
                if "link_node" in args and args["link_node"]:
                    for link in args["link_node"]:
                        link_node = context.session.graph.get_node(link)
                        matching_edges = generate_matching_edges(
                            link_node, command.command
                        )
                        matching_destinations = [e.destination for e in matching_edges]
                        edges = [e for e in edges if e.destination not in matching_destinations]
                        edges.extend(matching_edges)

                if "link" in args and args["link"]:
                    for link in args["link"]:
                        edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
                        edges.append(
                            Edge(
                                source=EdgeConnection(node_id=link[1], field=link[0]),
                                destination=EdgeConnection(
                                    node_id=command.command.id, field=link[2]
                                )
                            )
                        )

                new_invocations.append((command.command, edges))

                current_id = current_id + 1

                # Add the node and its edges to the session
                context.session.add_node(command.command)
                for edge in edges:
                    context.session.add_edge(edge)

            # Execute all remaining nodes
            invoke_all(context)

        except InvalidArgs:
            print('Invalid command, use "help" to list commands')
            continue

        except SessionError:
            # Start a new session
            print("Session error: creating a new session")
            context.session = context.invoker.create_execution_state()

        except ExitCli:
            break

        except SystemExit:
            continue

    invoker.stop()


if __name__ == "__main__":
    invoke_cli()
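A hypothetical transcript of this loop, illustrating the pipe syntax split on "|" above (matching output fields are auto-linked between the piped nodes):

> txt2img --prompt "a photo of a cat" --steps 20 | show_image
> history --count 2
> exit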
@@ -1,12 +0,0 @@
import os

__all__ = []

dirname = os.path.dirname(os.path.abspath(__file__))
for f in os.listdir(dirname):
    if (
        f != "__init__.py"
        and os.path.isfile(os.path.join(dirname, f))
        and f.endswith(".py")
    ):
        __all__.append(f[:-3])
@@ -1,78 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from abc import ABC, abstractmethod
from inspect import signature
from typing import get_args, get_type_hints

from pydantic import BaseModel, Field

from ..services.invocation_services import InvocationServices


class InvocationContext:
    services: InvocationServices
    graph_execution_state_id: str

    def __init__(self, services: InvocationServices, graph_execution_state_id: str):
        self.services = services
        self.graph_execution_state_id = graph_execution_state_id


class BaseInvocationOutput(BaseModel):
    """Base class for all invocation outputs"""

    # All outputs must include a type name like this:
    # type: Literal['your_output_name']

    @classmethod
    def get_all_subclasses_tuple(cls):
        subclasses = []
        toprocess = [cls]
        while len(toprocess) > 0:
            next = toprocess.pop(0)
            next_subclasses = next.__subclasses__()
            subclasses.extend(next_subclasses)
            toprocess.extend(next_subclasses)
        return tuple(subclasses)


class BaseInvocation(ABC, BaseModel):
    """A node to process inputs and produce outputs.
    May use dependency injection in __init__ to receive providers.
    """

    # All invocations must include a type name like this:
    # type: Literal['your_invocation_name']

    @classmethod
    def get_all_subclasses(cls):
        subclasses = []
        toprocess = [cls]
        while len(toprocess) > 0:
            next = toprocess.pop(0)
            next_subclasses = next.__subclasses__()
            subclasses.extend(next_subclasses)
            toprocess.extend(next_subclasses)
        return subclasses

    @classmethod
    def get_invocations(cls):
        return tuple(BaseInvocation.get_all_subclasses())

    @classmethod
    def get_invocations_map(cls):
        # Get the type strings out of the literals and into a dictionary
        return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t), BaseInvocation.get_all_subclasses()))

    @classmethod
    def get_output_type(cls):
        return signature(cls.invoke).return_annotation

    @abstractmethod
    def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
        """Invoke with provided context and return outputs."""
        pass

    #fmt: off
    id: str = Field(description="The id of this node. Must be unique among all nodes.")
    #fmt: on
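A sketch of the minimal shape a new node takes under these base classes; the names here are hypothetical, not part of the diff:

from typing import Literal
from pydantic import Field

class EchoOutput(BaseInvocationOutput):
    type: Literal["echo_output"] = "echo_output"
    value: str = Field(default="", description="The echoed value")

class EchoInvocation(BaseInvocation):
    """(hypothetical) Echoes its input string."""
    type: Literal["echo"] = "echo"
    value: str = Field(default="", description="The value to echo")

    def invoke(self, context: InvocationContext) -> EchoOutput:
        return EchoOutput(value=self.value)

Because get_all_subclasses() walks __subclasses__, merely defining the class registers it with the API, CLI, and OpenAPI machinery above.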
@@ -1,50 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal

import cv2 as cv
import numpy
from PIL import Image, ImageOps
from pydantic import Field

from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput


class CvInpaintInvocation(BaseInvocation):
    """Simple inpaint using opencv."""
    #fmt: off
    type: Literal["cv_inpaint"] = "cv_inpaint"

    # Inputs
    image: ImageField = Field(default=None, description="The image to inpaint")
    mask: ImageField = Field(default=None, description="The mask to use when inpainting")
    #fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )
        mask = context.services.images.get(self.mask.image_type, self.mask.image_name)

        # Convert to cv image/mask
        # TODO: consider making these utility functions
        cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
        cv_mask = numpy.array(ImageOps.invert(mask))

        # Inpaint
        cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)

        # Convert back to Pillow
        # TODO: consider making a utility function
        image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, image_inpainted)
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )
@@ -1,221 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal, Optional, Union

import numpy as np
from torch import Tensor
from pydantic import Field

from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL

SAMPLER_NAME_VALUES = Literal[
    tuple(InvokeAIGenerator.schedulers())
]


# Text to image
class TextToImageInvocation(BaseInvocation):
    """Generates an image using text2img."""

    type: Literal["txt2img"] = "txt2img"

    # Inputs
    # TODO: consider making prompt optional to enable providing prompt through a link
    # fmt: off
    prompt: Optional[str] = Field(description="The prompt to generate an image from")
    seed: int = Field(default=-1, ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
    steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
    width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
    height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
    cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance scale; higher values keep the result closer to the prompt", )
    sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
    seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    model: str = Field(default="", description="The model to use (currently ignored)")
    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
    # fmt: on

    # TODO: pass this an emitter method or something? or a session for dispatching?
    def dispatch_progress(
        self, context: InvocationContext, sample: Tensor, step: int
    ) -> None:
        # TODO: only output a preview image when requested
        image = Generator.sample_to_lowres_estimated_image(sample)

        (width, height) = image.size
        width *= 8
        height *= 8

        dataURL = image_to_dataURL(image, image_format="JPEG")

        context.services.events.emit_generator_progress(
            context.graph_execution_state_id,
            self.id,
            {
                "width": width,
                "height": height,
                "dataURL": dataURL
            },
            step,
            self.steps,
        )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        def step_callback(state: PipelineIntermediateState):
            self.dispatch_progress(context, state.latents, state.step)

        # Handle invalid model parameter
        # TODO: figure out if this can be done via a validator that uses the model_cache
        # TODO: How to get the default model name now?
        # (right now uses whatever current model is set in model manager)
        model = context.services.model_manager.get_model()
        outputs = Txt2Img(model).generate(
            prompt=self.prompt,
            step_callback=step_callback,
            **self.dict(
                exclude={"prompt"}
            ),  # Shorthand for passing all of the parameters above manually
        )
        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
        # each time it is called. We only need the first one.
        generate_output = next(outputs)

        # Results are image and seed, unwrap for now and ignore the seed
        # TODO: pre-seed?
        # TODO: can this return multiple results? Should it?
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, generate_output.image)
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )


class ImageToImageInvocation(TextToImageInvocation):
    """Generates an image using img2img."""

    type: Literal["img2img"] = "img2img"

    # Inputs
    image: Union[ImageField, None] = Field(description="The input image")
    strength: float = Field(
        default=0.75, gt=0, le=1, description="The strength of the original image"
    )
    fit: bool = Field(
        default=True,
        description="Whether or not the result should be fit to the aspect ratio of the input image",
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = (
            None
            if self.image is None
            else context.services.images.get(
                self.image.image_type, self.image.image_name
            )
        )
        mask = None

        def step_callback(sample, step=0):
            self.dispatch_progress(context, sample, step)

        # Handle invalid model parameter
        # TODO: figure out if this can be done via a validator that uses the model_cache
        # TODO: How to get the default model name now?
        model = context.services.model_manager.get_model()
        generator_output = next(
            Img2Img(model).generate(
                prompt=self.prompt,
                init_image=image,
                init_mask=mask,
                step_callback=step_callback,
                **self.dict(
                    exclude={"prompt", "image", "mask"}
                ),  # Shorthand for passing all of the parameters above manually
            )
        )

        result_image = generator_output.image

        # Results are image and seed, unwrap for now and ignore the seed
        # TODO: pre-seed?
        # TODO: can this return multiple results? Should it?
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, result_image)
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )


class InpaintInvocation(ImageToImageInvocation):
    """Generates an image using inpaint."""

    type: Literal["inpaint"] = "inpaint"

    # Inputs
    mask: Union[ImageField, None] = Field(description="The mask")
    inpaint_replace: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="The amount by which to replace masked areas with latent noise",
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = (
            None
            if self.image is None
            else context.services.images.get(
                self.image.image_type, self.image.image_name
            )
        )
        mask = (
            None
            if self.mask is None
            else context.services.images.get(self.mask.image_type, self.mask.image_name)
        )

        def step_callback(sample, step=0):
            self.dispatch_progress(context, sample, step)

        # Handle invalid model parameter
        # TODO: figure out if this can be done via a validator that uses the model_cache
        # TODO: How to get the default model name now?
        model = context.services.model_manager.get_model()
        generator_output = next(
            Inpaint(model).generate(
                prompt=self.prompt,
                init_image=image,
                mask_image=mask,
                step_callback=step_callback,
                **self.dict(
                    exclude={"prompt", "image", "mask"}
                ),  # Shorthand for passing all of the parameters above manually
            )
        )

        result_image = generator_output.image

        # Results are image and seed, unwrap for now and ignore the seed
        # TODO: pre-seed?
        # TODO: can this return multiple results? Should it?
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, result_image)
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )
@@ -1,287 +0,0 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional
|
||||
|
||||
import numpy
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_type: str = Field(
|
||||
default=ImageType.RESULT, description="The type of the image"
|
||||
)
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
#fmt: off
|
||||
type: Literal["image"] = "image"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
#fmt: on
|
||||
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a mask"""
|
||||
#fmt: off
|
||||
type: Literal["mask"] = "mask"
|
||||
mask: ImageField = Field(default=None, description="The output mask")
|
||||
#fomt: on
|
||||
|
||||
# TODO: this isn't really necessary anymore
|
||||
class LoadImageInvocation(BaseInvocation):
|
||||
"""Load an image from a filename and provide it as output."""
|
||||
#fmt: off
|
||||
type: Literal["load_image"] = "load_image"
|
||||
|
||||
# Inputs
|
||||
image_type: ImageType = Field(description="The type of the image")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=self.image_type, image_name=self.image_name)
|
||||
)
|
||||
|
||||
|
||||
class ShowImageInvocation(BaseInvocation):
|
||||
"""Displays a provided image, and passes it forward in the pipeline."""
|
||||
|
||||
type: Literal["show_image"] = "show_image"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to show")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
if image:
|
||||
image.show()
|
||||
|
||||
# TODO: how to handle failure?
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_type=self.image.image_type, image_name=self.image.image_name
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class CropImageInvocation(BaseInvocation):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
#fmt: off
|
||||
type: Literal["crop"] = "crop"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to crop")
|
||||
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
|
||||
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
||||
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
||||
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
image_crop = Image.new(
|
||||
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
|
||||
)
|
||||
image_crop.paste(image, (-self.x, -self.y))
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image_crop)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
|
||||
class PasteImageInvocation(BaseInvocation):
|
||||
"""Pastes an image into another image."""
|
||||
#fmt: off
|
||||
type: Literal["paste"] = "paste"
|
||||
|
||||
# Inputs
|
||||
base_image: ImageField = Field(default=None, description="The base image")
|
||||
image: ImageField = Field(default=None, description="The image to paste")
|
||||
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
||||
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
||||
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.services.images.get(
|
||||
self.base_image.image_type, self.base_image.image_name
|
||||
)
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else ImageOps.invert(
|
||||
services.images.get(self.mask.image_type, self.mask.image_name)
|
||||
)
|
||||
)
|
||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||
|
||||
min_x = min(0, self.x)
|
||||
min_y = min(0, self.y)
|
||||
max_x = max(base_image.width, image.width + self.x)
|
||||
max_y = max(base_image.height, image.height + self.y)
|
||||
|
||||
new_image = Image.new(
|
||||
mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0)
|
||||
)
|
||||
new_image.paste(base_image, (abs(min_x), abs(min_y)))
|
||||
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, new_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
|
||||
class MaskFromAlphaInvocation(BaseInvocation):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
#fmt: off
|
||||
type: Literal["tomask"] = "tomask"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to create the mask from")
|
||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
image_mask = image.split()[-1]
|
||||
if self.invert:
|
||||
image_mask = ImageOps.invert(image_mask)
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image_mask)
|
||||
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
|
||||
|
||||
|
||||


class BlurInvocation(BaseInvocation):
    """Blurs an image"""

    #fmt: off
    type: Literal["blur"] = "blur"

    # Inputs
    image: ImageField = Field(default=None, description="The image to blur")
    radius: float = Field(default=8.0, ge=0, description="The blur radius")
    blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
    #fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        blur = (
            ImageFilter.GaussianBlur(self.radius)
            if self.blur_type == "gaussian"
            else ImageFilter.BoxBlur(self.radius)
        )
        blur_image = image.filter(blur)

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, blur_image)
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )


class LerpInvocation(BaseInvocation):
    """Linear interpolation of all pixels of an image"""
    #fmt: off
    type: Literal["lerp"] = "lerp"

    # Inputs
    image: ImageField = Field(default=None, description="The image to lerp")
    min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
    max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
    #fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
        image_arr = image_arr * (self.max - self.min) + self.min

        lerp_image = Image.fromarray(numpy.uint8(image_arr))

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, lerp_image)
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )


class InverseLerpInvocation(BaseInvocation):
    """Inverse linear interpolation of all pixels of an image"""
    #fmt: off
    type: Literal["ilerp"] = "ilerp"

    # Inputs
    image: ImageField = Field(default=None, description="The image to lerp")
    min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
    max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
    #fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        image_arr = numpy.asarray(image, dtype=numpy.float32)
        image_arr = (
            numpy.minimum(
                numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1
            )
            * 255
        )

        ilerp_image = Image.fromarray(numpy.uint8(image_arr))

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, ilerp_image)
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )
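
Numerically, "lerp" maps normalized pixel values into [min, max], and "ilerp" is its inverse, stretching [min, max] back across 0-255. A quick standalone numpy check of the round trip (values are illustrative):

import numpy

min_v, max_v = 64, 192
pixels = numpy.array([0.0, 128.0, 255.0], dtype=numpy.float32)

lerped = (pixels / 255) * (max_v - min_v) + min_v  # [64.0, 128.25, 192.0]
restored = numpy.minimum(
    numpy.maximum(lerped - min_v, 0) / float(max_v - min_v), 1
) * 255                                            # [0.0, ~128.0, 255.0]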
@@ -1,14 +0,0 @@
from typing import Literal

from pydantic.fields import Field

from .baseinvocation import BaseInvocationOutput


class PromptOutput(BaseInvocationOutput):
    """Base class for invocations that output a prompt"""
    #fmt: off
    type: Literal["prompt"] = "prompt"

    prompt: str = Field(default=None, description="The output prompt")
    #fmt: on
@@ -1,42 +0,0 @@
from datetime import datetime, timezone
from typing import Literal, Union

from pydantic import Field

from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput


class RestoreFaceInvocation(BaseInvocation):
    """Restores faces in an image."""
    #fmt: off
    type: Literal["restore_face"] = "restore_face"

    # Inputs
    image: Union[ImageField, None] = Field(description="The input image")
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration")
    #fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )
        results = context.services.restoration.upscale_and_reconstruct(
            image_list=[[image, 0]],
            upscale=None,
            strength=self.strength,  # GFPGAN strength
            save_original=False,
            image_callback=None,
        )

        # Results are image and seed, unwrap for now
        # TODO: can this return multiple results?
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, results[0][0])
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )
@@ -1,46 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from datetime import datetime, timezone
from typing import Literal, Union

from pydantic import Field

from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput


class UpscaleInvocation(BaseInvocation):
    """Upscales an image."""
    #fmt: off
    type: Literal["upscale"] = "upscale"

    # Inputs
    image: Union[ImageField, None] = Field(description="The input image", default=None)
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
    level: Literal[2, 4] = Field(default=2, description="The upscale level")
    #fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )
        results = context.services.restoration.upscale_and_reconstruct(
            image_list=[[image, 0]],
            upscale=(self.level, self.strength),
            strength=0.0,  # GFPGAN strength
            save_original=False,
            image_callback=None,
        )

        # Results are image and seed, unwrap for now
        # TODO: can this return multiple results?
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, results[0][0])
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )
@@ -1,88 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Any, Dict, TypedDict

ProgressImage = TypedDict(
    "ProgressImage", {"dataURL": str, "width": int, "height": int}
)


class EventServiceBase:
    """Basic event bus, to have an empty stand-in when not needed"""

    session_event: str = "session_event"

    def dispatch(self, event_name: str, payload: Any) -> None:
        pass

    def __emit_session_event(self, event_name: str, payload: Dict) -> None:
        self.dispatch(
            event_name=EventServiceBase.session_event,
            payload=dict(event=event_name, data=payload),
        )

    # Define events here for every event in the system.
    # This will make them easier to integrate until we find a schema generator.
    def emit_generator_progress(
        self,
        graph_execution_state_id: str,
        invocation_id: str,
        progress_image: ProgressImage | None,
        step: int,
        total_steps: int,
    ) -> None:
        """Emitted when there is generation progress"""
        self.__emit_session_event(
            event_name="generator_progress",
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
                invocation_id=invocation_id,
                progress_image=progress_image,
                step=step,
                total_steps=total_steps,
            ),
        )

    def emit_invocation_complete(
        self, graph_execution_state_id: str, invocation_id: str, result: Dict
    ) -> None:
        """Emitted when an invocation has completed"""
        self.__emit_session_event(
            event_name="invocation_complete",
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
                invocation_id=invocation_id,
                result=result,
            ),
        )

    def emit_invocation_error(
        self, graph_execution_state_id: str, invocation_id: str, error: str
    ) -> None:
        """Emitted when an invocation encounters an error"""
        self.__emit_session_event(
            event_name="invocation_error",
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
                invocation_id=invocation_id,
                error=error,
            ),
        )

    def emit_invocation_started(
        self, graph_execution_state_id: str, invocation_id: str
    ) -> None:
        """Emitted when an invocation has started"""
        self.__emit_session_event(
            event_name="invocation_started",
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
                invocation_id=invocation_id,
            ),
        )

    def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
        """Emitted when a session has completed all invocations"""
        self.__emit_session_event(
            event_name="graph_execution_state_complete",
            payload=dict(graph_execution_state_id=graph_execution_state_id),
        )
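
Since `dispatch` is deliberately a no-op, a concrete bus only needs to override that one method; every emit funnels through it. A minimal sketch of such a subclass (illustrative only, not part of this changeset):

class PrintingEventService(EventServiceBase):
    """Event bus that prints every session event instead of forwarding it."""

    def dispatch(self, event_name: str, payload: Any) -> None:
        print(f"[{event_name}] {payload}")

bus = PrintingEventService()
bus.emit_invocation_started(graph_execution_state_id="g1", invocation_id="n1")
# -> [session_event] {'event': 'invocation_started', 'data': {...}}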
@@ -1,113 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import datetime
import os
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from queue import Queue
from typing import Dict

from PIL import Image as PILImage  # the PIL.Image *module*, needed for PILImage.open()
from PIL.Image import Image

from invokeai.backend.image_util import PngWriter


class ImageType(str, Enum):
    RESULT = "results"
    INTERMEDIATE = "intermediates"
    UPLOAD = "uploads"


class ImageStorageBase(ABC):
    """Responsible for storing and retrieving images."""

    @abstractmethod
    def get(self, image_type: ImageType, image_name: str) -> Image:
        pass

    # TODO: make this a bit more flexible for e.g. cloud storage
    @abstractmethod
    def get_path(self, image_type: ImageType, image_name: str) -> str:
        pass

    @abstractmethod
    def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
        pass

    @abstractmethod
    def delete(self, image_type: ImageType, image_name: str) -> None:
        pass

    def create_name(self, context_id: str, node_id: str) -> str:
        return f"{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png"


class DiskImageStorage(ImageStorageBase):
    """Stores images on disk"""

    __output_folder: str
    __pngWriter: PngWriter
    __cache_ids: Queue  # TODO: this is an incredibly naive cache
    __cache: Dict[str, Image]
    __max_cache_size: int

    def __init__(self, output_folder: str):
        self.__output_folder = output_folder
        self.__pngWriter = PngWriter(output_folder)
        self.__cache = dict()
        self.__cache_ids = Queue()
        self.__max_cache_size = 10  # TODO: get this from config

        Path(output_folder).mkdir(parents=True, exist_ok=True)

        # TODO: don't hard-code. get/save/delete should maybe take subpath?
        for image_type in ImageType:
            Path(os.path.join(output_folder, image_type)).mkdir(
                parents=True, exist_ok=True
            )

    def get(self, image_type: ImageType, image_name: str) -> Image:
        image_path = self.get_path(image_type, image_name)
        cache_item = self.__get_cache(image_path)
        if cache_item:
            return cache_item

        image = PILImage.open(image_path)  # open() lives on the module, not the Image class
        self.__set_cache(image_path, image)
        return image

    # TODO: make this a bit more flexible for e.g. cloud storage
    def get_path(self, image_type: ImageType, image_name: str) -> str:
        path = os.path.join(self.__output_folder, image_type, image_name)
        return path

    def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
        image_subpath = os.path.join(image_type, image_name)
        self.__pngWriter.save_image_and_prompt_to_png(
            image, "", image_subpath, None
        )  # TODO: just pass full path to png writer

        image_path = self.get_path(image_type, image_name)
        self.__set_cache(image_path, image)

    def delete(self, image_type: ImageType, image_name: str) -> None:
        image_path = self.get_path(image_type, image_name)
        if os.path.exists(image_path):
            os.remove(image_path)

        if image_path in self.__cache:
            del self.__cache[image_path]

    def __get_cache(self, image_name: str) -> Image:
        return None if image_name not in self.__cache else self.__cache[image_name]

    def __set_cache(self, image_name: str, image: Image):
        if image_name not in self.__cache:
            self.__cache[image_name] = image
            self.__cache_ids.put(
                image_name
            )  # TODO: this should refresh position for LRU cache
            if len(self.__cache) > self.__max_cache_size:
                cache_id = self.__cache_ids.get()
                del self.__cache[cache_id]
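
The cache above is FIFO, which is what the TODO in `__set_cache` is about: a repeatedly read image is still evicted once ten newer images arrive. One way the TODO could be addressed (a sketch using collections.OrderedDict, an assumption rather than code from this changeset):

from collections import OrderedDict

class LRUImageCache:
    """Evicts the least-recently-used entry once the cache is full."""

    def __init__(self, max_size: int = 10):
        self._cache: OrderedDict = OrderedDict()
        self._max_size = max_size

    def get(self, key: str):
        if key not in self._cache:
            return None
        self._cache.move_to_end(key)  # refresh recency on every hit
        return self._cache[key]

    def set(self, key: str, value) -> None:
        self._cache[key] = value
        self._cache.move_to_end(key)
        if len(self._cache) > self._max_size:
            self._cache.popitem(last=False)  # drop the oldest entry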
@@ -1,81 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from abc import ABC, abstractmethod
from queue import Queue
import time


# TODO: make this serializable
class InvocationQueueItem:
    # session_id: str
    graph_execution_state_id: str
    invocation_id: str
    invoke_all: bool
    timestamp: float

    def __init__(
        self,
        # session_id: str,
        graph_execution_state_id: str,
        invocation_id: str,
        invoke_all: bool = False,
    ):
        # self.session_id = session_id
        self.graph_execution_state_id = graph_execution_state_id
        self.invocation_id = invocation_id
        self.invoke_all = invoke_all
        self.timestamp = time.time()


class InvocationQueueABC(ABC):
    """Abstract base class for all invocation queues"""

    @abstractmethod
    def get(self) -> InvocationQueueItem:
        pass

    @abstractmethod
    def put(self, item: InvocationQueueItem | None) -> None:
        pass

    @abstractmethod
    def cancel(self, graph_execution_state_id: str) -> None:
        pass

    @abstractmethod
    def is_canceled(self, graph_execution_state_id: str) -> bool:
        pass


class MemoryInvocationQueue(InvocationQueueABC):
    __queue: Queue
    __cancellations: dict[str, float]

    def __init__(self):
        self.__queue = Queue()
        self.__cancellations = dict()

    def get(self) -> InvocationQueueItem:
        item = self.__queue.get()

        while isinstance(item, InvocationQueueItem) \
                and item.graph_execution_state_id in self.__cancellations \
                and self.__cancellations[item.graph_execution_state_id] > item.timestamp:
            item = self.__queue.get()

        # Clear old items (guard against the None stop sentinel, which has no timestamp)
        if isinstance(item, InvocationQueueItem):
            for graph_execution_state_id in list(self.__cancellations.keys()):
                if self.__cancellations[graph_execution_state_id] < item.timestamp:
                    del self.__cancellations[graph_execution_state_id]

        return item

    def put(self, item: InvocationQueueItem | None) -> None:
        self.__queue.put(item)

    def cancel(self, graph_execution_state_id: str) -> None:
        if graph_execution_state_id not in self.__cancellations:
            self.__cancellations[graph_execution_state_id] = time.time()

    def is_canceled(self, graph_execution_state_id: str) -> bool:
        return graph_execution_state_id in self.__cancellations
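
Cancellation is timestamp-based: `cancel()` records a time, and `get()` drops any queued item for that session whose timestamp predates it. A short usage sketch (IDs are made up; the sleeps just guarantee distinct timestamps):

import time

queue = MemoryInvocationQueue()
queue.put(InvocationQueueItem(graph_execution_state_id="g1", invocation_id="n1"))
time.sleep(0.01)
queue.cancel("g1")  # recorded after the g1 item was enqueued
time.sleep(0.01)
queue.put(InvocationQueueItem(graph_execution_state_id="g2", invocation_id="n2"))

item = queue.get()  # silently skips the canceled g1 item
assert item.graph_execution_state_id == "g2"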
@@ -1,39 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from invokeai.backend import ModelManager

from .events import EventServiceBase
from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC


class InvocationServices:
    """Services that can be used by invocations"""

    events: EventServiceBase
    images: ImageStorageBase
    queue: InvocationQueueABC
    model_manager: ModelManager
    restoration: RestorationServices

    # NOTE: we must forward-declare any types that include invocations, since invocations can use services
    graph_execution_manager: ItemStorageABC["GraphExecutionState"]
    processor: "InvocationProcessorABC"

    def __init__(
        self,
        model_manager: ModelManager,
        events: EventServiceBase,
        images: ImageStorageBase,
        queue: InvocationQueueABC,
        graph_execution_manager: ItemStorageABC["GraphExecutionState"],
        processor: "InvocationProcessorABC",
        restoration: RestorationServices,
    ):
        self.model_manager = model_manager
        self.events = events
        self.images = images
        self.queue = queue
        self.graph_execution_manager = graph_execution_manager
        self.processor = processor
        self.restoration = restoration
@@ -1,91 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from abc import ABC
from threading import Event, Thread

from ..invocations.baseinvocation import InvocationContext
from .graph import Graph, GraphExecutionState
from .invocation_queue import InvocationQueueABC, InvocationQueueItem
from .invocation_services import InvocationServices
from .item_storage import ItemStorageABC


class Invoker:
    """The invoker, used to execute invocations"""

    services: InvocationServices

    def __init__(self, services: InvocationServices):
        self.services = services
        self._start()

    def invoke(
        self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
    ) -> str | None:
        """Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""

        # Get the next invocation
        invocation = graph_execution_state.next()
        if not invocation:
            return None

        # Save the execution state
        self.services.graph_execution_manager.set(graph_execution_state)

        # Queue the invocation
        print(f"queueing item {invocation.id}")
        self.services.queue.put(
            InvocationQueueItem(
                # session_id = session.id,
                graph_execution_state_id=graph_execution_state.id,
                invocation_id=invocation.id,
                invoke_all=invoke_all,
            )
        )

        return invocation.id

    def create_execution_state(self, graph: Graph | None = None) -> GraphExecutionState:
        """Creates a new execution state for the given graph"""
        new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
        self.services.graph_execution_manager.set(new_state)
        return new_state

    def cancel(self, graph_execution_state_id: str) -> None:
        """Cancels the given execution state"""
        self.services.queue.cancel(graph_execution_state_id)

    def __start_service(self, service) -> None:
        # Call start() method on any services that have it
        start_op = getattr(service, "start", None)
        if callable(start_op):
            start_op(self)

    def __stop_service(self, service) -> None:
        # Call stop() method on any services that have it
        stop_op = getattr(service, "stop", None)
        if callable(stop_op):
            stop_op(self)

    def _start(self) -> None:
        """Starts the invoker. This is called automatically when the invoker is created."""
        for service in vars(self.services):
            self.__start_service(getattr(self.services, service))

    def stop(self) -> None:
        """Stops the invoker. A new invoker will have to be created to execute further."""
        # First stop all services
        for service in vars(self.services):
            self.__stop_service(getattr(self.services, service))

        self.services.queue.put(None)


class InvocationProcessorABC(ABC):
    pass
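
Taken together, the expected call pattern is: assemble an `InvocationServices` bundle, wrap it in an `Invoker`, seed a session, and let `invoke_all` chain the rest through the queue. A hedged sketch (construction of `services` is elided; node wiring is application-specific):

invoker = Invoker(services)  # _start() also starts the processor service

state = invoker.create_execution_state()  # empty Graph by default
# ... add nodes and edges to state.graph here ...

node_id = invoker.invoke(state, invoke_all=True)  # queue the first ready node
if node_id is None:
    print("nothing to execute")

invoker.stop()  # stop all services and unblock the queue with a None sentinel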
@@ -1,62 +0,0 @@
from abc import ABC, abstractmethod
from typing import Callable, Generic, TypeVar

from pydantic import BaseModel, Field
from pydantic.generics import GenericModel

T = TypeVar("T", bound=BaseModel)


class PaginatedResults(GenericModel, Generic[T]):
    """Paginated results"""
    #fmt: off
    items: list[T] = Field(description="Items")
    page: int = Field(description="Current Page")
    pages: int = Field(description="Total number of pages")
    per_page: int = Field(description="Number of items per page")
    total: int = Field(description="Total number of items in result")
    #fmt: on


class ItemStorageABC(ABC, Generic[T]):
    """Base item storage class"""

    _on_changed_callbacks: list[Callable[[T], None]]
    _on_deleted_callbacks: list[Callable[[str], None]]

    def __init__(self) -> None:
        self._on_changed_callbacks = list()
        self._on_deleted_callbacks = list()

    @abstractmethod
    def get(self, item_id: str) -> T:
        pass

    @abstractmethod
    def set(self, item: T) -> None:
        pass

    @abstractmethod
    def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
        pass

    @abstractmethod
    def search(
        self, query: str, page: int = 0, per_page: int = 10
    ) -> PaginatedResults[T]:
        pass

    def on_changed(self, on_changed: Callable[[T], None]) -> None:
        """Register a callback for when an item is changed"""
        self._on_changed_callbacks.append(on_changed)

    def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
        """Register a callback for when an item is deleted"""
        self._on_deleted_callbacks.append(on_deleted)

    def _on_changed(self, item: T) -> None:
        for callback in self._on_changed_callbacks:
            callback(item)

    def _on_deleted(self, item_id: str) -> None:
        for callback in self._on_deleted_callbacks:
            callback(item_id)
@@ -1,120 +0,0 @@
import os
import sys
import torch
from argparse import Namespace
from invokeai.backend import Args
from omegaconf import OmegaConf
from pathlib import Path

import invokeai.version
from ...backend import ModelManager
from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals

# TODO: Replace with an abstract class base ModelManagerBase
def get_model_manager(config: Args) -> ModelManager:
    if not config.conf:
        config_file = os.path.join(Globals.root, "configs", "models.yaml")
        if not os.path.exists(config_file):
            report_model_error(
                config, FileNotFoundError(f"The file {config_file} could not be found.")
            )

    print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
    print(f'>> InvokeAI runtime directory is "{Globals.root}"')

    # these two lines prevent a horrible warning message from appearing
    # when the frozen CLIP tokenizer is imported
    import transformers  # type: ignore

    transformers.logging.set_verbosity_error()
    import diffusers

    diffusers.logging.set_verbosity_error()

    # normalize the config directory relative to root
    if not os.path.isabs(config.conf):
        config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))

    if config.embeddings:
        if not os.path.isabs(config.embedding_path):
            embedding_path = os.path.normpath(
                os.path.join(Globals.root, config.embedding_path)
            )
        else:
            embedding_path = config.embedding_path
    else:
        embedding_path = None

    # migrate legacy models
    ModelManager.migrate_models()

    # creating the model manager
    try:
        device = torch.device(choose_torch_device())
        precision = 'float16' if config.precision == 'float16' \
            else 'float32' if config.precision == 'float32' \
            else choose_precision(device)

        model_manager = ModelManager(
            OmegaConf.load(config.conf),
            precision=precision,
            device_type=device,
            max_loaded_models=config.max_loaded_models,
            embedding_path=Path(embedding_path) if embedding_path else None,
        )
    except (FileNotFoundError, TypeError, AssertionError) as e:
        report_model_error(config, e)
    except (IOError, KeyError) as e:
        print(f"{e}. Aborting.")
        sys.exit(-1)

    # try to autoconvert new models
    # autoimport new .ckpt files
    if path := config.autoconvert:
        model_manager.autoconvert_weights(
            conf_path=config.conf,
            weights_directory=path,
        )

    return model_manager


def report_model_error(opt: Namespace, e: Exception):
    print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
    print(
        "** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
    )
    yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
    if yes_to_all:
        print(
            "** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
        )
    else:
        response = input(
            "Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
        )
        if response.startswith(("n", "N")):
            return

    print("invokeai-configure is launching....\n")

    # Match arguments that were set on the CLI
    # only the arguments accepted by the configuration script are parsed
    root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
    config = ["--config", opt.conf] if opt.conf is not None else []
    previous_config = sys.argv
    sys.argv = ["invokeai-configure"]
    sys.argv.extend(root_dir)
    sys.argv.extend(config)
    if yes_to_all is not None:
        for arg in yes_to_all.split():
            sys.argv.append(arg)

    from invokeai.frontend.install import invokeai_configure

    invokeai_configure()
    # TODO: Figure out how to restart
    # print('** InvokeAI will now restart')
    # sys.argv = previous_config
    # main()  # would rather do a os.exec(), but doesn't exist?
    # sys.exit(0)
@@ -1,121 +0,0 @@
import traceback
from threading import Event, Thread

from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker


class DefaultInvocationProcessor(InvocationProcessorABC):
    __invoker_thread: Thread
    __stop_event: Event
    __invoker: Invoker

    def start(self, invoker) -> None:
        self.__invoker = invoker
        self.__stop_event = Event()
        self.__invoker_thread = Thread(
            name="invoker_processor",
            target=self.__process,
            kwargs=dict(stop_event=self.__stop_event),
        )
        self.__invoker_thread.daemon = (
            True  # TODO: probably better to just not use threads?
        )
        self.__invoker_thread.start()

    def stop(self, *args, **kwargs) -> None:
        self.__stop_event.set()

    def __process(self, stop_event: Event):
        try:
            while not stop_event.is_set():
                queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
                if not queue_item:  # Probably stopping
                    continue

                graph_execution_state = (
                    self.__invoker.services.graph_execution_manager.get(
                        queue_item.graph_execution_state_id
                    )
                )
                invocation = graph_execution_state.execution_graph.get_node(
                    queue_item.invocation_id
                )

                # Send starting event
                self.__invoker.services.events.emit_invocation_started(
                    graph_execution_state_id=graph_execution_state.id,
                    invocation_id=invocation.id,
                )

                # Invoke
                try:
                    outputs = invocation.invoke(
                        InvocationContext(
                            services=self.__invoker.services,
                            graph_execution_state_id=graph_execution_state.id,
                        )
                    )

                    # Check queue to see if this is canceled, and skip if so
                    if self.__invoker.services.queue.is_canceled(
                        graph_execution_state.id
                    ):
                        continue

                    # Save outputs and history
                    graph_execution_state.complete(invocation.id, outputs)

                    # Save the state changes
                    self.__invoker.services.graph_execution_manager.set(
                        graph_execution_state
                    )

                    # Send complete event
                    self.__invoker.services.events.emit_invocation_complete(
                        graph_execution_state_id=graph_execution_state.id,
                        invocation_id=invocation.id,
                        result=outputs.dict(),
                    )

                except KeyboardInterrupt:
                    pass

                except Exception as e:
                    error = traceback.format_exc()

                    # Save error
                    graph_execution_state.set_node_error(invocation.id, error)

                    # Save the state changes
                    self.__invoker.services.graph_execution_manager.set(
                        graph_execution_state
                    )

                    # Send error event
                    self.__invoker.services.events.emit_invocation_error(
                        graph_execution_state_id=graph_execution_state.id,
                        invocation_id=invocation.id,
                        error=error,
                    )

                # Check queue to see if this is canceled, and skip if so
                if self.__invoker.services.queue.is_canceled(
                    graph_execution_state.id
                ):
                    continue

                # Queue any further commands if invoking all
                is_complete = graph_execution_state.is_complete()
                if queue_item.invoke_all and not is_complete:
                    self.__invoker.invoke(graph_execution_state, invoke_all=True)
                elif is_complete:
                    self.__invoker.services.events.emit_graph_execution_complete(
                        graph_execution_state.id
                    )

        except KeyboardInterrupt:
            ...  # Log something?
@@ -1,109 +0,0 @@
import sys
import traceback

import torch

from ...backend.restoration import Restoration
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE


# This should be a real base class for postprocessing functions,
# but right now we just instantiate the existing gfpgan, esrgan
# and codeformer functions.
class RestorationServices:
    '''Face restoration and upscaling'''

    def __init__(self, args):
        try:
            gfpgan, codeformer, esrgan = None, None, None
            if args.restore or args.esrgan:
                restoration = Restoration()
                if args.restore:
                    gfpgan, codeformer = restoration.load_face_restore_models(
                        args.gfpgan_model_path
                    )
                else:
                    print(">> Face restoration disabled")
                if args.esrgan:
                    esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
                else:
                    print(">> Upscaling disabled")
            else:
                print(">> Face restoration and upscaling disabled")
        except (ModuleNotFoundError, ImportError):
            print(traceback.format_exc(), file=sys.stderr)
            print(">> You may need to install the ESRGAN and/or GFPGAN modules")
        self.device = torch.device(choose_torch_device())
        self.gfpgan = gfpgan
        self.codeformer = codeformer
        self.esrgan = esrgan

    # note that this one method does gfpgan and codeformer reconstruction, as well as
    # esrgan upscaling
    # TO DO: refactor into separate methods
    def upscale_and_reconstruct(
        self,
        image_list,
        facetool="gfpgan",
        upscale=None,
        upscale_denoise_str=0.75,
        strength=0.0,
        codeformer_fidelity=0.75,
        save_original=False,
        image_callback=None,
        prefix=None,
    ):
        results = []
        for r in image_list:
            image, seed = r
            try:
                if strength > 0:
                    if self.gfpgan is not None or self.codeformer is not None:
                        if facetool == "gfpgan":
                            if self.gfpgan is None:
                                print(
                                    ">> GFPGAN not found. Face restoration is disabled."
                                )
                            else:
                                image = self.gfpgan.process(image, strength, seed)
                        if facetool == "codeformer":
                            if self.codeformer is None:
                                print(
                                    ">> CodeFormer not found. Face restoration is disabled."
                                )
                            else:
                                cf_device = (
                                    CPU_DEVICE if self.device == MPS_DEVICE else self.device
                                )
                                image = self.codeformer.process(
                                    image=image,
                                    strength=strength,
                                    device=cf_device,
                                    seed=seed,
                                    fidelity=codeformer_fidelity,
                                )
                    else:
                        print(">> Face Restoration is disabled.")
                if upscale is not None:
                    if self.esrgan is not None:
                        if len(upscale) < 2:
                            upscale.append(0.75)
                        image = self.esrgan.process(
                            image,
                            upscale[1],
                            seed,
                            int(upscale[0]),
                            denoise_str=upscale_denoise_str,
                        )
                    else:
                        print(">> ESRGAN is disabled. Image not upscaled.")
            except Exception as e:
                print(
                    f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
                )

            if image_callback is not None:
                image_callback(image, seed, upscaled=True, use_prefix=prefix)
            else:
                r[0] = image

            results.append([image, seed])

        return results
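
For orientation, this is how the two invocations shown earlier drive this method (a sketch; `restoration` is an assumed RestorationServices instance and `image` a PIL image):

# Face restoration only, as in RestoreFaceInvocation:
results = restoration.upscale_and_reconstruct(
    image_list=[[image, 0]],  # [image, seed] pairs
    upscale=None,
    strength=0.75,            # GFPGAN strength
)

# 2x ESRGAN upscale only, as in UpscaleInvocation:
results = restoration.upscale_and_reconstruct(
    image_list=[[image, 0]],
    upscale=(2, 0.75),        # (level, strength)
    strength=0.0,             # face restoration skipped
)
restored = results[0][0]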
@@ -1,138 +0,0 @@
import sqlite3
from threading import Lock
from typing import Generic, TypeVar, Union, get_args

from pydantic import BaseModel, parse_raw_as

from .item_storage import ItemStorageABC, PaginatedResults

T = TypeVar("T", bound=BaseModel)

sqlite_memory = ":memory:"


class SqliteItemStorage(ItemStorageABC, Generic[T]):
    _filename: str
    _table_name: str
    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _id_field: str
    _lock: Lock

    def __init__(self, filename: str, table_name: str, id_field: str = "id"):
        super().__init__()

        self._filename = filename
        self._table_name = table_name
        self._id_field = id_field  # TODO: validate that T has this field
        self._lock = Lock()

        self._conn = sqlite3.connect(
            self._filename, check_same_thread=False
        )  # TODO: figure out a better threading solution
        self._cursor = self._conn.cursor()

        self._create_table()

    def _create_table(self):
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
                item TEXT,
                id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);"""
            )
            self._cursor.execute(
                f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
            )
        finally:
            self._lock.release()

    def _parse_item(self, item: str) -> T:
        item_type = get_args(self.__orig_class__)[0]
        return parse_raw_as(item_type, item)

    def set(self, item: T):
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
                (item.json(),),
            )
        finally:
            self._lock.release()
        self._on_changed(item)

    def get(self, id: str) -> Union[T, None]:
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
            )
            result = self._cursor.fetchone()
        finally:
            self._lock.release()

        if not result:
            return None

        return self._parse_item(result[0])

    def delete(self, id: str):
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
            )
        finally:
            self._lock.release()
        self._on_deleted(id)

    def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
                (per_page, page * per_page),
            )
            result = self._cursor.fetchall()

            items = list(map(lambda r: self._parse_item(r[0]), result))

            self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
            count = self._cursor.fetchone()[0]
        finally:
            self._lock.release()

        pageCount = (count + per_page - 1) // per_page  # ceiling division

        return PaginatedResults[T](
            items=items, page=page, pages=pageCount, per_page=per_page, total=count
        )

    def search(
        self, query: str, page: int = 0, per_page: int = 10
    ) -> PaginatedResults[T]:
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
                (f"%{query}%", per_page, page * per_page),
            )
            result = self._cursor.fetchall()

            items = list(map(lambda r: self._parse_item(r[0]), result))

            self._cursor.execute(
                f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
                (f"%{query}%",),
            )
            count = self._cursor.fetchone()[0]
        finally:
            self._lock.release()

        pageCount = (count + per_page - 1) // per_page  # ceiling division

        return PaginatedResults[T](
            items=items, page=page, pages=pageCount, per_page=per_page, total=count
        )
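
A minimal usage sketch (the `Session` model is hypothetical; any pydantic v1 BaseModel with an `id` field works, since the table's id column is extracted from the serialized JSON):

from pydantic import BaseModel

class Session(BaseModel):
    id: str
    prompt: str

storage = SqliteItemStorage[Session](filename=sqlite_memory, table_name="sessions")
storage.set(Session(id="s1", prompt="banana sushi"))

page = storage.list(page=0, per_page=10)
print(page.total, storage.get("s1").prompt)  # 1 banana sushi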
@@ -1,16 +1,5 @@
"""
'''
Initialization file for invokeai.backend
"""
from .generate import Generate
from .generator import (
    InvokeAIGeneratorBasicParams,
    InvokeAIGenerator,
    InvokeAIGeneratorOutput,
    Txt2Img,
    Img2Img,
    Inpaint
)
from .model_management import ModelManager
from .safety_checker import SafetyChecker
from .args import Args
from .globals import Globals
'''
from .invoke_ai_web_server import InvokeAIWebServer
@@ -1,13 +0,0 @@
"""
Initialization file for the invokeai.generator package
"""
from .base import (
    InvokeAIGenerator,
    InvokeAIGeneratorBasicParams,
    InvokeAIGeneratorOutput,
    Txt2Img,
    Img2Img,
    Inpaint,
    Generator,
)
from .inpaint import infill_methods
@@ -1,648 +0,0 @@
"""
Base class for invokeai.backend.generator.*
including img2img, txt2img, and inpaint
"""
from __future__ import annotations

import itertools
import dataclasses
import diffusers
import os
import random
import traceback
from abc import ABCMeta
from argparse import Namespace
from contextlib import nullcontext

import cv2
import numpy as np
import torch
from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed
from diffusers import DiffusionPipeline
from tqdm import trange
from typing import List, Iterator, Type
from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler

from ..image_util import configure_model_padding
from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..prompting.conditioning import get_uc_and_c_and_ec
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline

downsampling = 8


@dataclass
class InvokeAIGeneratorBasicParams:
    seed: int = None
    width: int = 512
    height: int = 512
    cfg_scale: float = 7.5
    steps: int = 20
    ddim_eta: float = 0.0
    scheduler: str = 'ddim'
    precision: str = 'float16'
    perlin: float = 0.0
    threshold: float = 0.0
    seamless: bool = False
    seamless_axes: List[str] = field(default_factory=lambda: ['x', 'y'])
    h_symmetry_time_pct: float = None
    v_symmetry_time_pct: float = None
    variation_amount: float = 0.0
    with_variations: list = field(default_factory=list)
    safety_checker: SafetyChecker = None

@dataclass
class InvokeAIGeneratorOutput:
    '''
    InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
    operation, including the image, its seed, the model name used to generate the image
    and the model hash, as well as all the generate() parameters that went into
    generating the image (in .params, also available as attributes)
    '''
    image: Image
    seed: int
    model_hash: str
    attention_maps_images: List[Image]
    params: Namespace


# we are interposing a wrapper around the original Generator classes so that
# old code that calls Generate will continue to work.
class InvokeAIGenerator(metaclass=ABCMeta):
    scheduler_map = dict(
        ddim=diffusers.DDIMScheduler,
        dpmpp_2=diffusers.DPMSolverMultistepScheduler,
        k_dpm_2=diffusers.KDPM2DiscreteScheduler,
        k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
        k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
        k_euler=diffusers.EulerDiscreteScheduler,
        k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
        k_heun=diffusers.HeunDiscreteScheduler,
        k_lms=diffusers.LMSDiscreteScheduler,
        plms=diffusers.PNDMScheduler,
    )

    def __init__(self,
                 model_info: dict,
                 params: InvokeAIGeneratorBasicParams = InvokeAIGeneratorBasicParams(),
                 ):
        self.model_info = model_info
        self.params = params

    def generate(self,
                 prompt: str = '',
                 callback: callable = None,
                 step_callback: callable = None,
                 iterations: int = 1,
                 **keyword_args,
                 ) -> Iterator[InvokeAIGeneratorOutput]:
        '''
        Return an iterator across the indicated number of generations.
        Each time the iterator is called it will return an InvokeAIGeneratorOutput
        object. Use like this:

           outputs = txt2img.generate(prompt='banana sushi', iterations=5)
           for result in outputs:
               print(result.image, result.seed)

        In the typical case of wanting just a single image, let iterations
        default to 1 and do:

           output = next(txt2img.generate(prompt='banana sushi'))

        Pass iterations=None to get an infinite iterator:

           outputs = txt2img.generate(prompt='banana sushi', iterations=None)
           for o in outputs:
               print(o.image, o.seed)

        '''
        generator_args = dataclasses.asdict(self.params)
        generator_args.update(keyword_args)

        model_info = self.model_info
        model_name = model_info['model_name']
        model: StableDiffusionGeneratorPipeline = model_info['model']
        model_hash = model_info['hash']
        scheduler: Scheduler = self.get_scheduler(
            model=model,
            scheduler_name=generator_args.get('scheduler')
        )
        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt, model=model)
        gen_class = self._generator_class()
        generator = gen_class(model, self.params.precision)
        if self.params.variation_amount > 0:
            generator.set_variation(generator_args.get('seed'),
                                    generator_args.get('variation_amount'),
                                    generator_args.get('with_variations')
                                    )

        if isinstance(model, DiffusionPipeline):
            for component in [model.unet, model.vae]:
                configure_model_padding(component,
                                        generator_args.get('seamless', False),
                                        generator_args.get('seamless_axes')
                                        )
        else:
            configure_model_padding(model,
                                    generator_args.get('seamless', False),
                                    generator_args.get('seamless_axes')
                                    )

        iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
        for i in iteration_count:
            results = generator.generate(prompt,
                                         conditioning=(uc, c, extra_conditioning_info),
                                         sampler=scheduler,
                                         **generator_args,
                                         )
            output = InvokeAIGeneratorOutput(
                image=results[0][0],
                seed=results[0][1],
                attention_maps_images=results[0][2],
                model_hash=model_hash,
                params=Namespace(model_name=model_name, **generator_args),
            )
            if callback:
                callback(output)
            yield output

    @classmethod
    def schedulers(cls) -> List[str]:
        '''
        Return list of all the schedulers that we currently handle.
        '''
        return list(cls.scheduler_map.keys())

    def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
        return generator_class(model, self.params.precision)

    def get_scheduler(self, scheduler_name: str, model: StableDiffusionGeneratorPipeline) -> Scheduler:
        scheduler_class = self.scheduler_map.get(scheduler_name, diffusers.DDIMScheduler)
        scheduler = scheduler_class.from_config(model.scheduler.config)
        # hack copied over from generate.py
        if not hasattr(scheduler, 'uses_inpainting_model'):
            scheduler.uses_inpainting_model = lambda: False
        return scheduler

    @classmethod
    def _generator_class(cls) -> Type[Generator]:
        '''
        In derived classes return the name of the generator to apply.
        If you don't override will return the name of the derived
        class, which nicely parallels the generator class names.
        '''
        return Generator

# ------------------------------------
class Txt2Img(InvokeAIGenerator):
    @classmethod
    def _generator_class(cls):
        from .txt2img import Txt2Img
        return Txt2Img


# ------------------------------------
class Img2Img(InvokeAIGenerator):
    def generate(self,
                 init_image: Image | torch.FloatTensor,
                 strength: float = 0.75,
                 **keyword_args
                 ) -> Iterator[InvokeAIGeneratorOutput]:
        return super().generate(init_image=init_image,
                                strength=strength,
                                **keyword_args
                                )

    @classmethod
    def _generator_class(cls):
        from .img2img import Img2Img
        return Img2Img


# ------------------------------------
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
class Inpaint(Img2Img):
    def generate(self,
                 mask_image: Image | torch.FloatTensor,
                 # Seam settings - when 0, doesn't fill seam
                 seam_size: int = 0,
                 seam_blur: int = 0,
                 seam_strength: float = 0.7,
                 seam_steps: int = 10,
                 tile_size: int = 32,
                 inpaint_replace=False,
                 infill_method=None,
                 inpaint_width=None,
                 inpaint_height=None,
                 inpaint_fill: tuple[int, int, int, int] = (0x7F, 0x7F, 0x7F, 0xFF),
                 **keyword_args
                 ) -> Iterator[InvokeAIGeneratorOutput]:
        return super().generate(
            mask_image=mask_image,
            seam_size=seam_size,
            seam_blur=seam_blur,
            seam_strength=seam_strength,
            seam_steps=seam_steps,
            tile_size=tile_size,
            inpaint_replace=inpaint_replace,
            infill_method=infill_method,
            inpaint_width=inpaint_width,
            inpaint_height=inpaint_height,
            inpaint_fill=inpaint_fill,
            **keyword_args
        )

    @classmethod
    def _generator_class(cls):
        from .inpaint import Inpaint
        return Inpaint


# ------------------------------------
class Embiggen(Txt2Img):
    def generate(
            self,
            embiggen: list = None,
            embiggen_tiles: list = None,
            strength: float = 0.75,
            **kwargs) -> Iterator[InvokeAIGeneratorOutput]:
        return super().generate(embiggen=embiggen,
                                embiggen_tiles=embiggen_tiles,
                                strength=strength,
                                **kwargs)

    @classmethod
    def _generator_class(cls):
        from .embiggen import Embiggen
        return Embiggen

class Generator:
|
||||
downsampling_factor: int
|
||||
latent_channels: int
|
||||
precision: str
|
||||
model: DiffusionPipeline
|
||||
|
||||
def __init__(self, model: DiffusionPipeline, precision: str):
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
self.latent_channels = model.channels
|
||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||
self.safety_checker = None
|
||||
self.perlin = 0.0
|
||||
self.threshold = 0
|
||||
self.variation_amount = 0
|
||||
self.with_variations = []
|
||||
self.use_mps_noise = False
|
||||
self.free_gpu_mem = None
|
||||
|
||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||
def get_make_image(self, prompt, **kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"image_iterator() must be implemented in a descendent class"
|
||||
)
|
||||
|
||||
def set_variation(self, seed, variation_amount, with_variations):
|
||||
self.seed = seed
|
||||
self.variation_amount = variation_amount
|
||||
self.with_variations = with_variations
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt,
|
||||
width,
|
||||
height,
|
||||
sampler,
|
||||
init_image=None,
|
||||
iterations=1,
|
||||
seed=None,
|
||||
image_callback=None,
|
||||
step_callback=None,
|
||||
threshold=0.0,
|
||||
perlin=0.0,
|
||||
h_symmetry_time_pct=None,
|
||||
v_symmetry_time_pct=None,
|
||||
safety_checker: SafetyChecker=None,
|
||||
free_gpu_mem: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
scope = nullcontext
|
||||
self.safety_checker = safety_checker
|
||||
self.free_gpu_mem = free_gpu_mem
|
||||
attention_maps_images = []
|
||||
attention_maps_callback = lambda saver: attention_maps_images.append(
|
||||
saver.get_stacked_maps_image()
|
||||
)
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler=sampler,
|
||||
init_image=init_image,
|
||||
width=width,
|
||||
height=height,
|
||||
step_callback=step_callback,
|
||||
threshold=threshold,
|
||||
perlin=perlin,
|
||||
h_symmetry_time_pct=h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=v_symmetry_time_pct,
|
||||
attention_maps_callback=attention_maps_callback,
|
||||
**kwargs,
|
||||
)
|
||||
results = []
|
||||
seed = seed if seed is not None and seed >= 0 else self.new_seed()
|
||||
first_seed = seed
|
||||
seed, initial_noise = self.generate_initial_noise(seed, width, height)
|
||||
|
||||
# There used to be an additional self.model.ema_scope() here, but it breaks
|
||||
# the inpaint-1.5 model. Not sure what it did.... ?
|
||||
with scope(self.model.device.type):
|
||||
for n in trange(iterations, desc="Generating"):
|
||||
x_T = None
|
||||
if self.variation_amount > 0:
|
||||
set_seed(seed)
|
||||
target_noise = self.get_noise(width, height)
|
||||
x_T = self.slerp(self.variation_amount, initial_noise, target_noise)
|
||||
elif initial_noise is not None:
|
||||
# i.e. we specified particular variations
|
||||
x_T = initial_noise
|
||||
else:
|
||||
set_seed(seed)
|
||||
try:
|
||||
x_T = self.get_noise(width, height)
|
||||
except:
|
||||
print("** An error occurred while getting initial noise **")
|
||||
print(traceback.format_exc())
|
||||
|
||||
            # Pass on the seed in case a layer beneath us needs to generate noise on its own.
            image = make_image(x_T, seed)

            if self.safety_checker is not None:
                image = self.safety_checker.check(image)

            results.append([image, seed, attention_maps_images])

            if image_callback is not None:
                attention_maps_image = (
                    None
                    if len(attention_maps_images) == 0
                    else attention_maps_images[-1]
                )
                image_callback(
                    image,
                    seed,
                    first_seed=first_seed,
                    attention_maps_image=attention_maps_image,
                )

            seed = self.new_seed()

            # Free up memory from the last generation.
            clear_cuda_cache = (
                kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
            )
            if clear_cuda_cache is not None:
                clear_cuda_cache()

        return results

    def sample_to_image(self, samples) -> Image.Image:
        """
        Given samples returned from a sampler, converts
        it into a PIL Image
        """
        with torch.inference_mode():
            image = self.model.decode_latents(samples)
        return self.model.numpy_to_pil(image)[0]

    def repaste_and_color_correct(
        self,
        result: Image.Image,
        init_image: Image.Image,
        init_mask: Image.Image,
        mask_blur_radius: int = 8,
    ) -> Image.Image:
        if init_image is None or init_mask is None:
            return result

        # Get the original alpha channel of the mask if there is one.
        # Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
        pil_init_mask = (
            init_mask.getchannel("A")
            if init_mask.mode == "RGBA"
            else init_mask.convert("L")
        )
        pil_init_image = init_image.convert(
            "RGBA"
        )  # Add an alpha channel if one doesn't exist

        # Build an image with only visible pixels from source to use as reference for color-matching.
        init_rgb_pixels = np.asarray(init_image.convert("RGB"), dtype=np.uint8)
        init_a_pixels = np.asarray(pil_init_image.getchannel("A"), dtype=np.uint8)
        init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)

        # Get numpy version of result
        np_image = np.asarray(result, dtype=np.uint8)

        # Mask and calculate mean and standard deviation
        mask_pixels = init_a_pixels * init_mask_pixels > 0
        np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
        np_image_masked = np_image[mask_pixels, :]

        if np_init_rgb_pixels_masked.size > 0:
            init_means = np_init_rgb_pixels_masked.mean(axis=0)
            init_std = np_init_rgb_pixels_masked.std(axis=0)
            gen_means = np_image_masked.mean(axis=0)
            gen_std = np_image_masked.std(axis=0)

            # Color correct
            np_matched_result = np_image.copy()
            np_matched_result[:, :, :] = (
                (
                    (
                        (
                            np_matched_result[:, :, :].astype(np.float32)
                            - gen_means[None, None, :]
                        )
                        / gen_std[None, None, :]
                    )
                    * init_std[None, None, :]
                    + init_means[None, None, :]
                )
                .clip(0, 255)
                .astype(np.uint8)
            )
            matched_result = Image.fromarray(np_matched_result, mode="RGB")
        else:
            matched_result = Image.fromarray(np_image, mode="RGB")

        # Blur the mask out (into init image) by specified amount
        if mask_blur_radius > 0:
            nm = np.asarray(pil_init_mask, dtype=np.uint8)
            nmd = cv2.erode(
                nm,
                kernel=np.ones((3, 3), dtype=np.uint8),
                iterations=int(mask_blur_radius / 2),
            )
            pmd = Image.fromarray(nmd, mode="L")
            blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
        else:
            blurred_init_mask = pil_init_mask

        multiplied_blurred_init_mask = ImageChops.multiply(
            blurred_init_mask, self.pil_image.split()[-1]
        )

        # Paste original on color-corrected generation (using blurred mask)
        matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
        return matched_result

    @staticmethod
    def sample_to_lowres_estimated_image(samples):
        # originally adapted from code by @erucipe and @keturn here:
        # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7

        # these updated numbers for v1.5 are from @torridgristle
        v1_5_latent_rgb_factors = torch.tensor(
            [
                #    R        G        B
                [0.3444, 0.1385, 0.0670],  # L1
                [0.1247, 0.4027, 0.1494],  # L2
                [-0.3192, 0.2513, 0.2103],  # L3
                [-0.1307, -0.1874, -0.7445],  # L4
            ],
            dtype=samples.dtype,
            device=samples.device,
        )

        latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
        latents_ubyte = (
            ((latent_image + 1) / 2)
            .clamp(0, 1)  # change scale from -1..1 to 0..1
            .mul(0xFF)  # to 0..255
            .byte()
        ).cpu()

        return Image.fromarray(latents_ubyte.numpy())
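
This preview trick is just a linear map from latent space to RGB: each of the four latent channels contributes a fixed weight to each color channel, and the result is rescaled from [-1, 1] to [0, 255]. A minimal standalone sketch of the same projection (the factor matrix and the [1, 4, h, w] latent shape come from the method above; the function name is illustrative):

    import torch
    from PIL import Image

    def latents_to_preview(latents: torch.Tensor, factors: torch.Tensor) -> Image.Image:
        # latents: [1, 4, h, w]; factors: [4, 3] latent-channel -> RGB weights
        rgb = latents[0].permute(1, 2, 0) @ factors          # [h, w, 3]
        ubyte = ((rgb + 1) / 2).clamp(0, 1).mul(255).byte()  # [-1, 1] -> [0, 255]
        return Image.fromarray(ubyte.cpu().numpy())

Note the output is at 1/8 scale; callers that display it (like the web server's progress handler) multiply width and height by 8.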

    def generate_initial_noise(self, seed, width, height):
        initial_noise = None
        if self.variation_amount > 0 or len(self.with_variations) > 0:
            # use fixed initial noise plus random noise per iteration
            set_seed(seed)
            initial_noise = self.get_noise(width, height)
            for v_seed, v_weight in self.with_variations:
                seed = v_seed
                set_seed(seed)
                next_noise = self.get_noise(width, height)
                initial_noise = self.slerp(v_weight, initial_noise, next_noise)
            if self.variation_amount > 0:
                random.seed()  # reset RNG to an actually random state, so we can get a random seed for variations
                seed = random.randrange(0, np.iinfo(np.uint32).max)
        return (seed, initial_noise)

    def get_perlin_noise(self, width, height):
        fixdevice = "cpu" if (self.model.device.type == "mps") else self.model.device
        # limit noise to only the diffusion image channels, not the mask channels
        input_channels = min(self.latent_channels, 4)
        # round up to the nearest block of 8
        temp_width = int((width + 7) / 8) * 8
        temp_height = int((height + 7) / 8) * 8
        noise = torch.stack(
            [
                rand_perlin_2d(
                    (temp_height, temp_width), (8, 8), device=self.model.device
                ).to(fixdevice)
                for _ in range(input_channels)
            ],
            dim=0,
        ).to(self.model.device)
        return noise[0:4, 0:height, 0:width]

    def new_seed(self):
        self.seed = random.randrange(0, np.iinfo(np.uint32).max)
        return self.seed

    def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
        """
        Spherical linear interpolation
        Args:
            t (float/np.ndarray): Float value between 0.0 and 1.0
            v0 (np.ndarray): Starting vector
            v1 (np.ndarray): Final vector
            DOT_THRESHOLD (float): Threshold for considering the two vectors as
                                   collinear. Not recommended to alter this.
        Returns:
            v2 (np.ndarray): Interpolation vector between v0 and v1
        """
        inputs_are_torch = False
        if not isinstance(v0, np.ndarray):
            inputs_are_torch = True
            v0 = v0.detach().cpu().numpy()
        if not isinstance(v1, np.ndarray):
            inputs_are_torch = True
            v1 = v1.detach().cpu().numpy()

        dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
        if np.abs(dot) > DOT_THRESHOLD:
            v2 = (1 - t) * v0 + t * v1
        else:
            theta_0 = np.arccos(dot)
            sin_theta_0 = np.sin(theta_0)
            theta_t = theta_0 * t
            sin_theta_t = np.sin(theta_t)
            s0 = np.sin(theta_0 - theta_t) / sin_theta_0
            s1 = sin_theta_t / sin_theta_0
            v2 = s0 * v0 + s1 * v1

        if inputs_are_torch:
            v2 = torch.from_numpy(v2).to(self.model.device)

        return v2
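
Slerp walks along the great circle between the two noise tensors, which keeps the interpolated noise at roughly its original scale (a plain linear interpolation would shrink it toward the origin). A quick numpy check of the formula's endpoint behavior (the vectors here are hypothetical, not from the code above):

    import numpy as np

    v0 = np.random.randn(4)
    v1 = np.random.randn(4)
    theta = np.arccos(np.dot(v0, v1) / (np.linalg.norm(v0) * np.linalg.norm(v1)))

    def slerp(t):
        return (np.sin((1 - t) * theta) * v0 + np.sin(t * theta) * v1) / np.sin(theta)

    assert np.allclose(slerp(0.0), v0) and np.allclose(slerp(1.0), v1)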

    # this is a handy routine for debugging use. Given a generated sample,
    # convert it into a PNG image and store it at the indicated path
    def save_sample(self, sample, filepath):
        image = self.sample_to_image(sample)
        dirname = os.path.dirname(filepath) or "."
        if not os.path.exists(dirname):
            print(f"** creating directory {dirname}")
            os.makedirs(dirname, exist_ok=True)
        image.save(filepath, "PNG")

    def torch_dtype(self) -> torch.dtype:
        return torch.float16 if self.precision == "float16" else torch.float32

    # returns a tensor filled with random numbers from a normal distribution
    def get_noise(self, width, height):
        device = self.model.device
        # limit noise to only the diffusion image channels, not the mask channels
        input_channels = min(self.latent_channels, 4)
        if self.use_mps_noise or device.type == "mps":
            x = torch.randn(
                [
                    1,
                    input_channels,
                    height // self.downsampling_factor,
                    width // self.downsampling_factor,
                ],
                dtype=self.torch_dtype(),
                device="cpu",
            ).to(device)
        else:
            x = torch.randn(
                [
                    1,
                    input_channels,
                    height // self.downsampling_factor,
                    width // self.downsampling_factor,
                ],
                dtype=self.torch_dtype(),
                device=device,
            )
        if self.perlin > 0.0:
            perlin_noise = self.get_perlin_noise(
                width // self.downsampling_factor, height // self.downsampling_factor
            )
            x = (1 - self.perlin) * x + self.perlin * perlin_noise
        return x
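
The perlin option is a simple convex blend, x = (1 - p) * gaussian + p * perlin, so p = 0 gives pure Gaussian noise and p = 1 gives pure Perlin noise. A minimal sketch of the same idea outside the class (rand_perlin_2d is assumed to be the same helper used above; everything else is illustrative):

    import torch

    def blended_noise(shape, p: float, device="cpu") -> torch.Tensor:
        # shape: (1, channels, h, w)
        gaussian = torch.randn(shape, device=device)
        perlin = torch.stack(
            [rand_perlin_2d(shape[-2:], (8, 8), device=device) for _ in range(shape[1])],
            dim=0,
        ).unsqueeze(0)
        return (1 - p) * gaussian + p * perlin  # convex mix keeps roughly unit scale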
@@ -1,101 +0,0 @@
"""
invokeai.backend.generator.img2img descends from .generator
"""
from typing import Optional

import torch
from accelerate.utils import set_seed
from diffusers import logging

from ..stable_diffusion import (
    ConditioningData,
    PostprocessingSettings,
    StableDiffusionGeneratorPipeline,
)
from .base import Generator


class Img2Img(Generator):
    def __init__(self, model, precision):
        super().__init__(model, precision)
        self.init_latent = None  # by get_noise()

    def get_make_image(
        self,
        prompt,
        sampler,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        init_image,
        strength,
        step_callback=None,
        threshold=0.0,
        warmup=0.2,
        perlin=0.0,
        h_symmetry_time_pct=None,
        v_symmetry_time_pct=None,
        attention_maps_callback=None,
        **kwargs,
    ):
        """
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it.
        """
        self.perlin = perlin

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.scheduler = sampler

        uc, c, extra_conditioning_info = conditioning
        conditioning_data = ConditioningData(
            uc,
            c,
            cfg_scale,
            extra_conditioning_info,
            postprocessing_settings=PostprocessingSettings(
                threshold=threshold,
                warmup=warmup,
                h_symmetry_time_pct=h_symmetry_time_pct,
                v_symmetry_time_pct=v_symmetry_time_pct,
            ),
        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

        def make_image(x_T: torch.Tensor, seed: int):
            # FIXME: use x_T for initial seeded noise
            # We're not at the moment because the pipeline automatically resizes init_image if
            # necessary, which the x_T input might not match.
            # In the meantime, reset the seed prior to generating pipeline output so we at least get the same result.
            logging.set_verbosity_error()  # quench safety check warnings
            pipeline_output = pipeline.img2img_from_embeddings(
                init_image,
                strength,
                steps,
                conditioning_data,
                noise_func=self.get_noise_like,
                callback=step_callback,
                seed=seed,
            )
            if (
                pipeline_output.attention_map_saver is not None
                and attention_maps_callback is not None
            ):
                attention_maps_callback(pipeline_output.attention_map_saver)
            return pipeline.numpy_to_pil(pipeline_output.images)[0]

        return make_image

    def get_noise_like(self, like: torch.Tensor):
        device = like.device
        if device.type == "mps":
            x = torch.randn_like(like, device="cpu").to(device)
        else:
            x = torch.randn_like(like, device=device)
        if self.perlin > 0.0:
            shape = like.shape
            x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
                shape[3], shape[2]
            )
        return x
@@ -1,81 +0,0 @@
"""
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
"""
import PIL.Image
import torch

from ..stable_diffusion import (
    ConditioningData,
    PostprocessingSettings,
    StableDiffusionGeneratorPipeline,
)
from .base import Generator


class Txt2Img(Generator):
    def __init__(self, model, precision):
        super().__init__(model, precision)

    @torch.no_grad()
    def get_make_image(
        self,
        prompt,
        sampler,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        width,
        height,
        step_callback=None,
        threshold=0.0,
        warmup=0.2,
        perlin=0.0,
        h_symmetry_time_pct=None,
        v_symmetry_time_pct=None,
        attention_maps_callback=None,
        **kwargs,
    ):
        """
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it
        kwargs are 'width' and 'height'
        """
        self.perlin = perlin

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.scheduler = sampler

        uc, c, extra_conditioning_info = conditioning
        conditioning_data = ConditioningData(
            uc,
            c,
            cfg_scale,
            extra_conditioning_info,
            postprocessing_settings=PostprocessingSettings(
                threshold=threshold,
                warmup=warmup,
                h_symmetry_time_pct=h_symmetry_time_pct,
                v_symmetry_time_pct=v_symmetry_time_pct,
            ),
        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

        def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
            pipeline_output = pipeline.image_from_embeddings(
                latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
                noise=x_T,
                num_inference_steps=steps,
                conditioning_data=conditioning_data,
                callback=step_callback,
            )

            if (
                pipeline_output.attention_map_saver is not None
                and attention_maps_callback is not None
            ):
                attention_maps_callback(pipeline_output.attention_map_saver)

            return pipeline.numpy_to_pil(pipeline_output.images)[0]

        return make_image
@@ -1,24 +0,0 @@
"""
Initialization file for invokeai.backend.image_util methods.
"""
from .patchmatch import PatchMatch
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata
from .seamless import configure_model_padding
from .txt2mask import Txt2Mask
from .util import InitImageResizer, make_grid


def debug_image(
    debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False
):
    if not debug_status:
        return

    image_copy = debug_image.copy().convert("RGBA")
    ImageDraw.Draw(image_copy).text((5, 5), debug_text, (255, 0, 0))

    if debug_show:
        image_copy.show()

    if debug_result:
        return image_copy
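
Note that debug_image uses ImageDraw without importing it; as written it would raise a NameError the first time debug_status is true. Presumably the module needs a PIL import at the top (an assumption; the import does not appear in this hunk):

    from PIL import ImageDraw  # assumed missing import for debug_image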
@@ -1,59 +0,0 @@
import torch.nn as nn


def _conv_forward_asymmetric(self, input, weight, bias):
    """
    Patch for Conv2d._conv_forward that supports asymmetric padding
    """
    working = nn.functional.pad(
        input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]
    )
    working = nn.functional.pad(
        working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]
    )
    return nn.functional.conv2d(
        working,
        weight,
        bias,
        self.stride,
        nn.modules.utils._pair(0),
        self.dilation,
        self.groups,
    )


def configure_model_padding(model, seamless, seamless_axes):
    """
    Modifies the 2D convolution layers to use a circular padding mode based on the `seamless` and `seamless_axes` options.
    """
    # TODO: get an explicit interface for this in diffusers: https://github.com/huggingface/diffusers/issues/556
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            if seamless:
                m.asymmetric_padding_mode = {}
                m.asymmetric_padding = {}
                m.asymmetric_padding_mode["x"] = (
                    "circular" if ("x" in seamless_axes) else "constant"
                )
                m.asymmetric_padding["x"] = (
                    m._reversed_padding_repeated_twice[0],
                    m._reversed_padding_repeated_twice[1],
                    0,
                    0,
                )
                m.asymmetric_padding_mode["y"] = (
                    "circular" if ("y" in seamless_axes) else "constant"
                )
                m.asymmetric_padding["y"] = (
                    0,
                    0,
                    m._reversed_padding_repeated_twice[2],
                    m._reversed_padding_repeated_twice[3],
                )
                m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
            else:
                m._conv_forward = nn.Conv2d._conv_forward.__get__(m, nn.Conv2d)
                if hasattr(m, "asymmetric_padding_mode"):
                    del m.asymmetric_padding_mode
                if hasattr(m, "asymmetric_padding"):
                    del m.asymmetric_padding
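
The patch works by binding a replacement _conv_forward onto each convolution so padding is applied per axis: "circular" padding wraps activations around that axis, which is what makes the decoded image tile seamlessly along it. A minimal sketch of the same idea on a bare convolution (the layer, sizes, and axes here are hypothetical, not from the file above):

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    configure_model_padding(conv, seamless=True, seamless_axes={"x"})  # wrap left/right only

    x = torch.randn(1, 3, 16, 16)
    y = conv(x)  # columns 0 and 15 now "see" each other through the circular pad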
@@ -7,33 +7,33 @@ import mimetypes
import os
import shutil
import traceback
from pathlib import Path
from threading import Event
from uuid import uuid4

import eventlet
from compel.prompt_parser import Blend
from flask import Flask, make_response, redirect, request, send_from_directory
from flask_socketio import SocketIO
from pathlib import Path
from PIL import Image
from PIL.Image import Image as ImageType
from flask import Flask, redirect, send_from_directory, request, make_response
from flask_socketio import SocketIO
from werkzeug.utils import secure_filename

import invokeai.frontend.web.dist as frontend

from .. import Generate
from ..args import APP_ID, APP_VERSION, Args, calculate_init_img_hash
from ..generator import infill_methods
from ..globals import Globals, global_converted_ckpts_dir, global_models_dir
from ..image_util import PngWriter, retrieve_metadata
from ...frontend.merge.merge_diffusers import merge_diffusion_models
from ..prompting import (
    get_prompt_structure,
    get_tokens_for_prompt_object,
from invokeai.backend.modules.get_canvas_generation_mode import (
    get_canvas_generation_mode,
)
from ..stable_diffusion import PipelineIntermediateState
from .modules.get_canvas_generation_mode import get_canvas_generation_mode
from .modules.parameters import parameters_to_command
from invokeai.backend.modules.parameters import parameters_to_command
import invokeai.frontend.dist as frontend
from ldm.generate import Generate
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.invoke.conditioning import get_tokens_for_prompt_object, get_prompt_structure, split_weighted_subprompts, \
    get_tokenizer
from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
from ldm.invoke.generator.inpaint import infill_methods
from ldm.invoke.globals import Globals, global_converted_ckpts_dir
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
from compel.prompt_parser import Blend
from ldm.invoke.globals import global_models_dir
from ldm.invoke.merge_diffusers import merge_diffusion_models

# Loading Arguments
opt = Args()
@@ -193,7 +193,8 @@ class InvokeAIWebServer:
            (width, height) = pil_image.size

            thumbnail_path = save_thumbnail(
                pil_image, os.path.basename(file_path), self.thumbnail_image_path
                pil_image, os.path.basename(
                    file_path), self.thumbnail_image_path
            )

            response = {
@@ -223,7 +224,7 @@ class InvokeAIWebServer:
                    server="flask_socketio",
                    width=1600,
                    height=1000,
                    port=self.port,
                    port=self.port
                ).run()
            except KeyboardInterrupt:
                import sys
@@ -231,7 +232,7 @@ class InvokeAIWebServer:
                sys.exit(0)
        else:
            useSSL = args.certfile or args.keyfile
            print(">> Started Invoke AI Web Server")
            print(">> Started Invoke AI Web Server!")
            if self.host == "0.0.0.0":
                print(
                    f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
@@ -264,14 +265,16 @@ class InvokeAIWebServer:
        # location for "finished" images
        self.result_path = args.outdir
        # temporary path for intermediates
        self.intermediate_path = os.path.join(self.result_path, "intermediates/")
        self.intermediate_path = os.path.join(
            self.result_path, "intermediates/")
        # path for user-uploaded init images and masks
        self.init_image_path = os.path.join(self.result_path, "init-images/")
        self.mask_image_path = os.path.join(self.result_path, "mask-images/")
        # path for temp images e.g. gallery generations which are not committed
        self.temp_image_path = os.path.join(self.result_path, "temp-images/")
        # path for thumbnail images
        self.thumbnail_image_path = os.path.join(self.result_path, "thumbnails/")
        self.thumbnail_image_path = os.path.join(
            self.result_path, "thumbnails/")
        # txt log
        self.log_path = os.path.join(self.result_path, "invoke_log.txt")
        # make all output paths
@@ -296,22 +299,21 @@ class InvokeAIWebServer:
            config["infill_methods"] = infill_methods()
            socketio.emit("systemConfig", config)

        @socketio.on("searchForModels")
        @socketio.on('searchForModels')
        def handle_search_models(search_folder: str):
            try:
                if not search_folder:
                    socketio.emit(
                        "foundModels",
                        {"search_folder": None, "found_models": None},
                        {'search_folder': None, 'found_models': None},
                    )
                else:
                    (
                        search_folder,
                        found_models,
                    ) = self.generate.model_manager.search_models(search_folder)
                    search_folder, found_models = self.generate.model_manager.search_models(
                        search_folder)
                    socketio.emit(
                        "foundModels",
                        {"search_folder": search_folder, "found_models": found_models},
                        {'search_folder': search_folder,
                         'found_models': found_models},
                    )
            except Exception as e:
                self.handle_exceptions(e)
@@ -320,11 +322,11 @@ class InvokeAIWebServer:
        @socketio.on("addNewModel")
        def handle_add_model(new_model_config: dict):
            try:
                model_name = new_model_config["name"]
                del new_model_config["name"]
                model_name = new_model_config['name']
                del new_model_config['name']
                model_attributes = new_model_config
                if len(model_attributes["vae"]) == 0:
                    del model_attributes["vae"]
                if len(model_attributes['vae']) == 0:
                    del model_attributes['vae']
                update = False
                current_model_list = self.generate.model_manager.list_models()
                if model_name in current_model_list:
@@ -333,20 +335,14 @@ class InvokeAIWebServer:
                print(f">> Adding New Model: {model_name}")

                self.generate.model_manager.add_model(
                    model_name=model_name,
                    model_attributes=model_attributes,
                    clobber=True,
                )
                    model_name=model_name, model_attributes=model_attributes, clobber=True)
                self.generate.model_manager.commit(opt.conf)

                new_model_list = self.generate.model_manager.list_models()
                socketio.emit(
                    "newModelAdded",
                    {
                        "new_model_name": model_name,
                        "model_list": new_model_list,
                        "update": update,
                    },
                    {"new_model_name": model_name,
                     "model_list": new_model_list, 'update': update},
                )
                print(f">> New Model Added: {model_name}")
            except Exception as e:
@@ -361,10 +357,8 @@ class InvokeAIWebServer:
                updated_model_list = self.generate.model_manager.list_models()
                socketio.emit(
                    "modelDeleted",
                    {
                        "deleted_model_name": model_name,
                        "model_list": updated_model_list,
                    },
                    {"deleted_model_name": model_name,
                     "model_list": updated_model_list},
                )
                print(f">> Model Deleted: {model_name}")
            except Exception as e:
@@ -389,48 +383,41 @@ class InvokeAIWebServer:
            except Exception as e:
                self.handle_exceptions(e)

        @socketio.on("convertToDiffusers")
        @socketio.on('convertToDiffusers')
        def convert_to_diffusers(model_to_convert: dict):
            try:
                if model_info := self.generate.model_manager.model_info(
                    model_name=model_to_convert["model_name"]
                ):
                    if "weights" in model_info:
                        ckpt_path = Path(model_info["weights"])
                        original_config_file = Path(model_info["config"])
                        model_name = model_to_convert["model_name"]
                        model_description = model_info["description"]
                if (model_info := self.generate.model_manager.model_info(model_name=model_to_convert['model_name'])):
                    if 'weights' in model_info:
                        ckpt_path = Path(model_info['weights'])
                        original_config_file = Path(model_info['config'])
                        model_name = model_to_convert['model_name']
                        model_description = model_info['description']
                    else:
                        self.socketio.emit(
                            "error", {"message": "Model is not a valid checkpoint file"}
                        )
                            "error", {"message": "Model is not a valid checkpoint file"})
                else:
                    self.socketio.emit(
                        "error", {"message": "Could not retrieve model info."}
                    )
                        "error", {"message": "Could not retrieve model info."})

                if not ckpt_path.is_absolute():
                    ckpt_path = Path(Globals.root, ckpt_path)

                if original_config_file and not original_config_file.is_absolute():
                    original_config_file = Path(Globals.root, original_config_file)
                    original_config_file = Path(
                        Globals.root, original_config_file)

                diffusers_path = Path(
                    ckpt_path.parent.absolute(), f"{model_name}_diffusers"
                    ckpt_path.parent.absolute(),
                    f'{model_name}_diffusers'
                )

                if model_to_convert["save_location"] == "root":
                if model_to_convert['save_location'] == 'root':
                    diffusers_path = Path(
                        global_converted_ckpts_dir(), f"{model_name}_diffusers"
                    )
                        global_converted_ckpts_dir(), f'{model_name}_diffusers')

                if (
                    model_to_convert["save_location"] == "custom"
                    and model_to_convert["custom_location"] is not None
                ):
                if model_to_convert['save_location'] == 'custom' and model_to_convert['custom_location'] is not None:
                    diffusers_path = Path(
                        model_to_convert["custom_location"], f"{model_name}_diffusers"
                    )
                        model_to_convert['custom_location'], f'{model_name}_diffusers')

                if diffusers_path.exists():
                    shutil.rmtree(diffusers_path)
@@ -448,67 +435,54 @@ class InvokeAIWebServer:
                new_model_list = self.generate.model_manager.list_models()
                socketio.emit(
                    "modelConverted",
                    {
                        "new_model_name": model_name,
                        "model_list": new_model_list,
                        "update": True,
                    },
                    {"new_model_name": model_name,
                     "model_list": new_model_list, 'update': True},
                )
                print(f">> Model Converted: {model_name}")
            except Exception as e:
                self.handle_exceptions(e)

        @socketio.on("mergeDiffusersModels")
        @socketio.on('mergeDiffusersModels')
        def merge_diffusers_models(model_merge_info: dict):
            try:
                models_to_merge = model_merge_info["models_to_merge"]
                models_to_merge = model_merge_info['models_to_merge']
                model_ids_or_paths = [
                    self.generate.model_manager.model_name_or_path(x)
                    for x in models_to_merge
                ]
                    self.generate.model_manager.model_name_or_path(x) for x in models_to_merge]
                merged_pipe = merge_diffusion_models(
                    model_ids_or_paths,
                    model_merge_info["alpha"],
                    model_merge_info["interp"],
                    model_merge_info["force"],
                )
                    model_ids_or_paths, model_merge_info['alpha'], model_merge_info['interp'], model_merge_info['force'])

                dump_path = global_models_dir() / "merged_models"
                if model_merge_info["model_merge_save_path"] is not None:
                    dump_path = Path(model_merge_info["model_merge_save_path"])
                dump_path = global_models_dir() / 'merged_models'
                if model_merge_info['model_merge_save_path'] is not None:
                    dump_path = Path(model_merge_info['model_merge_save_path'])

                os.makedirs(dump_path, exist_ok=True)
                dump_path = dump_path / model_merge_info["merged_model_name"]
                dump_path = dump_path / model_merge_info['merged_model_name']
                merged_pipe.save_pretrained(dump_path, safe_serialization=1)

                merged_model_config = dict(
                    model_name=model_merge_info["merged_model_name"],
                    model_name=model_merge_info['merged_model_name'],
                    description=f'Merge of models {", ".join(models_to_merge)}',
                    commit_to_conf=opt.conf,
                    commit_to_conf=opt.conf
                )

                if vae := self.generate.model_manager.config[models_to_merge[0]].get(
                    "vae", None
                ):
                    print(f">> Using configured VAE assigned to {models_to_merge[0]}")
                if vae := self.generate.model_manager.config[models_to_merge[0]].get("vae", None):
                    print(
                        f">> Using configured VAE assigned to {models_to_merge[0]}")
                    merged_model_config.update(vae=vae)

                self.generate.model_manager.import_diffuser_model(
                    dump_path, **merged_model_config
                )
                    dump_path, **merged_model_config)
                new_model_list = self.generate.model_manager.list_models()

                socketio.emit(
                    "modelsMerged",
                    {
                        "merged_models": models_to_merge,
                        "merged_model_name": model_merge_info["merged_model_name"],
                        "model_list": new_model_list,
                        "update": True,
                    },
                    {"merged_models": models_to_merge,
                     "merged_model_name": model_merge_info['merged_model_name'],
                     "model_list": new_model_list, 'update': True},
                )
                print(f">> Models Merged: {models_to_merge}")
                print(f">> New Model Added: {model_merge_info['merged_model_name']}")
                print(
                    f">> New Model Added: {model_merge_info['merged_model_name']}")
            except Exception as e:
                self.handle_exceptions(e)

@@ -526,8 +500,7 @@ class InvokeAIWebServer:
                        os.remove(thumbnail_path)
                    except Exception as e:
                        socketio.emit(
                            "error", {"message": f"Unable to delete {f}: {str(e)}"}
                        )
                            "error", {"message": f"Unable to delete {f}: {str(e)}"})
                        pass

                socketio.emit("tempFolderEmptied")
@@ -538,7 +511,8 @@ class InvokeAIWebServer:
        def save_temp_image_to_gallery(url):
            try:
                image_path = self.get_image_path_from_url(url)
                new_path = os.path.join(self.result_path, os.path.basename(image_path))
                new_path = os.path.join(
                    self.result_path, os.path.basename(image_path))
                shutil.copy2(image_path, new_path)

                if os.path.splitext(new_path)[1] == ".png":
@@ -551,7 +525,8 @@ class InvokeAIWebServer:
                (width, height) = pil_image.size

                thumbnail_path = save_thumbnail(
                    pil_image, os.path.basename(new_path), self.thumbnail_image_path
                    pil_image, os.path.basename(
                        new_path), self.thumbnail_image_path
                )

                image_array = [
@@ -610,7 +585,8 @@ class InvokeAIWebServer:
                        (width, height) = pil_image.size

                        thumbnail_path = save_thumbnail(
                            pil_image, os.path.basename(path), self.thumbnail_image_path
                            pil_image, os.path.basename(
                                path), self.thumbnail_image_path
                        )

                        image_array.append(
@@ -629,8 +605,7 @@ class InvokeAIWebServer:
                        )
                    except Exception as e:
                        socketio.emit(
                            "error", {"message": f"Unable to load {path}: {str(e)}"}
                        )
                            "error", {"message": f"Unable to load {path}: {str(e)}"})
                        pass

                socketio.emit(
@@ -680,7 +655,8 @@ class InvokeAIWebServer:
                        (width, height) = pil_image.size

                        thumbnail_path = save_thumbnail(
                            pil_image, os.path.basename(path), self.thumbnail_image_path
                            pil_image, os.path.basename(
                                path), self.thumbnail_image_path
                        )

                        image_array.append(
@@ -700,8 +676,7 @@ class InvokeAIWebServer:
                    except Exception as e:
                        print(f">> Unable to load {path}")
                        socketio.emit(
                            "error", {"message": f"Unable to load {path}: {str(e)}"}
                        )
                            "error", {"message": f"Unable to load {path}: {str(e)}"})
                        pass

                socketio.emit(
@@ -735,9 +710,10 @@ class InvokeAIWebServer:
                    printable_parameters["init_mask"][:64] + "..."
                )

            print(f"\n>> Image Generation Parameters:\n\n{printable_parameters}\n")
            print(f">> ESRGAN Parameters: {esrgan_parameters}")
            print(f">> Facetool Parameters: {facetool_parameters}")
            print(
                f'\n>> Image Generation Parameters:\n\n{printable_parameters}\n')
            print(f'>> ESRGAN Parameters: {esrgan_parameters}')
            print(f'>> Facetool Parameters: {facetool_parameters}')

            self.generate_images(
                generation_parameters,
@@ -774,9 +750,11 @@ class InvokeAIWebServer:
            if postprocessing_parameters["type"] == "esrgan":
                progress.set_current_status("common.statusUpscalingESRGAN")
            elif postprocessing_parameters["type"] == "gfpgan":
                progress.set_current_status("common.statusRestoringFacesGFPGAN")
                progress.set_current_status(
                    "common.statusRestoringFacesGFPGAN")
            elif postprocessing_parameters["type"] == "codeformer":
                progress.set_current_status("common.statusRestoringFacesCodeFormer")
                progress.set_current_status(
                    "common.statusRestoringFacesCodeFormer")

            socketio.emit("progressUpdate", progress.to_formatted_dict())
            eventlet.sleep(0)
@@ -941,7 +919,8 @@ class InvokeAIWebServer:

                init_img_url = generation_parameters["init_img"]

                original_bounding_box = generation_parameters["bounding_box"].copy()
                original_bounding_box = generation_parameters["bounding_box"].copy(
                )

                initial_image = dataURL_to_image(
                    generation_parameters["init_img"]
@@ -1018,9 +997,8 @@ class InvokeAIWebServer:
            elif generation_parameters["generation_mode"] == "img2img":
                init_img_url = generation_parameters["init_img"]
                init_img_path = self.get_image_path_from_url(init_img_url)
                generation_parameters["init_img"] = Image.open(init_img_path).convert(
                    "RGB"
                )
                generation_parameters["init_img"] = Image.open(
                    init_img_path).convert('RGB')

        def image_progress(sample, step):
            if self.canceled.is_set():
@@ -1080,11 +1058,12 @@ class InvokeAIWebServer:
                )

                if generation_parameters["progress_latents"]:
                    image = self.generate.sample_to_lowres_estimated_image(sample)
                    image = self.generate.sample_to_lowres_estimated_image(
                        sample)
                    (width, height) = image.size
                    width *= 8
                    height *= 8
                    img_base64 = image_to_dataURL(image, image_format="JPEG")
                    img_base64 = image_to_dataURL(image)
                    self.socketio.emit(
                        "intermediateResult",
                        {
@@ -1099,7 +1078,8 @@ class InvokeAIWebServer:
                },
            )

            self.socketio.emit("progressUpdate", progress.to_formatted_dict())
            self.socketio.emit(
                "progressUpdate", progress.to_formatted_dict())
            eventlet.sleep(0)

        def image_done(image, seed, first_seed, attention_maps_image=None):
@@ -1126,7 +1106,8 @@ class InvokeAIWebServer:

            progress.set_current_status("common.statusGenerationComplete")

            self.socketio.emit("progressUpdate", progress.to_formatted_dict())
            self.socketio.emit(
                "progressUpdate", progress.to_formatted_dict())
            eventlet.sleep(0)

            all_parameters = generation_parameters
@@ -1137,7 +1118,8 @@ class InvokeAIWebServer:
                and all_parameters["variation_amount"] > 0
            ):
                first_seed = first_seed or seed
                this_variation = [[seed, all_parameters["variation_amount"]]]
                this_variation = [
                    [seed, all_parameters["variation_amount"]]]
                all_parameters["with_variations"] = (
                    prior_variations + this_variation
                )
@@ -1153,13 +1135,14 @@ class InvokeAIWebServer:
            if esrgan_parameters:
                progress.set_current_status("common.statusUpscaling")
                progress.set_current_status_has_steps(False)
                self.socketio.emit("progressUpdate", progress.to_formatted_dict())
                self.socketio.emit(
                    "progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                image = self.esrgan.process(
                    image=image,
                    upsampler_scale=esrgan_parameters["level"],
                    denoise_str=esrgan_parameters["denoise_str"],
                    denoise_str=esrgan_parameters['denoise_str'],
                    strength=esrgan_parameters["strength"],
                    seed=seed,
                )
@@ -1167,7 +1150,7 @@ class InvokeAIWebServer:
                postprocessing = True
                all_parameters["upscale"] = [
                    esrgan_parameters["level"],
                    esrgan_parameters["denoise_str"],
                    esrgan_parameters['denoise_str'],
                    esrgan_parameters["strength"],
                ]

@@ -1176,14 +1159,15 @@ class InvokeAIWebServer:

            if facetool_parameters:
                if facetool_parameters["type"] == "gfpgan":
                    progress.set_current_status("common.statusRestoringFacesGFPGAN")
                    progress.set_current_status(
                        "common.statusRestoringFacesGFPGAN")
                elif facetool_parameters["type"] == "codeformer":
                    progress.set_current_status(
                        "common.statusRestoringFacesCodeFormer"
                    )
                        "common.statusRestoringFacesCodeFormer")

                progress.set_current_status_has_steps(False)
                self.socketio.emit("progressUpdate", progress.to_formatted_dict())
                self.socketio.emit(
                    "progressUpdate", progress.to_formatted_dict())
                eventlet.sleep(0)

                if facetool_parameters["type"] == "gfpgan":
@@ -1213,7 +1197,8 @@ class InvokeAIWebServer:
            all_parameters["facetool_type"] = facetool_parameters["type"]

            progress.set_current_status("common.statusSavingImage")
            self.socketio.emit("progressUpdate", progress.to_formatted_dict())
            self.socketio.emit(
                "progressUpdate", progress.to_formatted_dict())
            eventlet.sleep(0)

            # restore the stashed URLS and discard the paths, we are about to send the result to client
@@ -1230,7 +1215,8 @@ class InvokeAIWebServer:
            if generation_parameters["generation_mode"] == "unifiedCanvas":
                all_parameters["bounding_box"] = original_bounding_box

            metadata = self.parameters_to_generated_image_metadata(all_parameters)
            metadata = self.parameters_to_generated_image_metadata(
                all_parameters)

            command = parameters_to_command(all_parameters)

@@ -1260,27 +1246,22 @@ class InvokeAIWebServer:

            if progress.total_iterations > progress.current_iteration:
                progress.set_current_step(1)
                progress.set_current_status("common.statusIterationComplete")
                progress.set_current_status(
                    "common.statusIterationComplete")
                progress.set_current_status_has_steps(False)
            else:
                progress.mark_complete()

            self.socketio.emit("progressUpdate", progress.to_formatted_dict())
            self.socketio.emit(
                "progressUpdate", progress.to_formatted_dict())
            eventlet.sleep(0)

            parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
            tokens = (
                None
                if type(parsed_prompt) is Blend
                else get_tokens_for_prompt_object(
                    self.generate.model.tokenizer, parsed_prompt
                )
            )
            attention_maps_image_base64_url = (
                None
                if attention_maps_image is None
            parsed_prompt, _ = get_prompt_structure(
                generation_parameters["prompt"])
            tokens = None if type(parsed_prompt) is Blend else \
                get_tokens_for_prompt_object(get_tokenizer(self.generate.model), parsed_prompt)
            attention_maps_image_base64_url = None if attention_maps_image is None \
                else image_to_dataURL(attention_maps_image)
            )

            self.socketio.emit(
                "generationResult",
@@ -1312,7 +1293,7 @@ class InvokeAIWebServer:
            self.generate.prompt2image(
                **generation_parameters,
                step_callback=diffusers_step_callback_adapter,
                image_callback=image_done,
                image_callback=image_done
            )

        except KeyboardInterrupt:
@@ -1435,7 +1416,8 @@ class InvokeAIWebServer:
        self, parameters, original_image_path
    ):
        try:
            current_metadata = retrieve_metadata(original_image_path)["sd-metadata"]
            current_metadata = retrieve_metadata(
                original_image_path)["sd-metadata"]
            postprocessing_metadata = {}

            """
@@ -1475,7 +1457,8 @@ class InvokeAIWebServer:
                    postprocessing_metadata
                )
            else:
                current_metadata["image"]["postprocessing"] = [postprocessing_metadata]
                current_metadata["image"]["postprocessing"] = [
                    postprocessing_metadata]

            return current_metadata

@@ -1571,7 +1554,8 @@ class InvokeAIWebServer:
            )
        elif "thumbnails" in url:
            return os.path.abspath(
                os.path.join(self.thumbnail_image_path, os.path.basename(url))
                os.path.join(self.thumbnail_image_path,
                             os.path.basename(url))
            )
        else:
            return os.path.abspath(
@@ -1617,7 +1601,7 @@ class InvokeAIWebServer:
        except Exception as e:
            self.handle_exceptions(e)

    def handle_exceptions(self, exception, emit_key: str = "error"):
    def handle_exceptions(self, exception, emit_key: str = 'error'):
        self.socketio.emit(emit_key, {"message": (str(exception))})
        print("\n")
        traceback.print_exc()
@@ -1701,23 +1685,27 @@ class CanceledException(Exception):
    pass


"""
Returns a copy of an image, cropped to a bounding box.
"""


def copy_image_from_bounding_box(
    image: ImageType, x: int, y: int, width: int, height: int
) -> ImageType:
    """
    Returns a copy of an image, cropped to a bounding box.
    """
    with image as im:
        bounds = (x, y, x + width, y + height)
        im_cropped = im.crop(bounds)
        return im_cropped
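
Note that the helper crops inside PIL's context manager, so the image passed in is closed on exit; callers that still need the source should pass a copy. A quick sketch (the canvas and box values are hypothetical); its counterpart paste_image_into_bounding_box appears further below:

    from PIL import Image

    canvas = Image.new("RGB", (512, 512), "white")
    # pass a copy: the helper's `with image as im:` closes its argument
    patch = copy_image_from_bounding_box(canvas.copy(), 128, 128, 256, 256)
    assert patch.size == (256, 256)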


"""
Converts a base64 image dataURL into an image.
The dataURL is split on the first commma.
"""


def dataURL_to_image(dataURL: str) -> ImageType:
    """
    Converts a base64 image dataURL into an image.
    The dataURL is split on the first comma.
    """
    image = Image.open(
        io.BytesIO(
            base64.decodebytes(
@@ -1731,24 +1719,27 @@ def dataURL_to_image(dataURL: str) -> ImageType:
    return image


def image_to_dataURL(image: ImageType, image_format: str = "PNG") -> str:
    """
    Converts an image into a base64 image dataURL.
    """
"""
Converts an image into a base64 image dataURL.
"""


def image_to_dataURL(image: ImageType) -> str:
    buffered = io.BytesIO()
    image.save(buffered, format=image_format)
    mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
    image_base64 = f"data:{mime_type};base64," + base64.b64encode(
    image.save(buffered, format="PNG")
    image_base64 = "data:image/png;base64," + base64.b64encode(
        buffered.getvalue()
    ).decode("UTF-8")
    return image_base64
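
The two dataURL helpers are inverses, which is how images travel over the socket: encode to "data:image/png;base64,..." on the way out, split on the first comma and base64-decode on the way in. A quick round-trip sketch (the image here is hypothetical):

    from PIL import Image

    img = Image.new("RGB", (8, 8), "red")
    url = image_to_dataURL(img)    # "data:image/png;base64,...."
    img2 = dataURL_to_image(url)   # decodes the part after the comma
    assert img2.size == (8, 8)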


"""
Converts a base64 image dataURL into bytes.
The dataURL is split on the first commma.
"""


def dataURL_to_bytes(dataURL: str) -> bytes:
    """
    Converts a base64 image dataURL into bytes.
    The dataURL is split on the first comma.
    """
    return base64.decodebytes(
        bytes(
            dataURL.split(",", 1)[1],
@@ -1757,6 +1748,11 @@ def dataURL_to_bytes(dataURL: str) -> bytes:
    )


"""
Pastes an image onto another with a bounding box.
"""


def paste_image_into_bounding_box(
    recipient_image: ImageType,
    donor_image: ImageType,
@@ -1765,24 +1761,23 @@ def paste_image_into_bounding_box(
    width: int,
    height: int,
) -> ImageType:
    """
    Pastes an image onto another with a bounding box.
    """
    with recipient_image as im:
        bounds = (x, y, x + width, y + height)
        im.paste(donor_image, bounds)
        return recipient_image


"""
Saves a thumbnail of an image, returning its path.
"""


def save_thumbnail(
    image: ImageType,
    filename: str,
    path: str,
    size: int = 256,
) -> str:
    """
    Saves a thumbnail of an image, returning its path.
    """
    base_filename = os.path.splitext(filename)[0]
    thumbnail_path = os.path.join(path, base_filename + ".webp")
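
The hunk cuts the function off here; based on the signature and docstring, the remainder presumably scales the image's longer edge down to size and writes the WebP file. A sketch of that assumed continuation:

    # assumed continuation of save_thumbnail (not shown in this hunk)
    image_copy = image.copy()
    image_copy.thumbnail((size, size))  # in-place, preserves aspect ratio
    image_copy.save(thumbnail_path, "WEBP")
    return thumbnail_path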

@@ -1,9 +0,0 @@
"""
Initialization file for invokeai.backend.model_management
"""
from .convert_ckpt_to_diffusers import (
    convert_ckpt_to_diffusers,
    load_pipeline_from_original_stable_diffusion_ckpt,
)
from .model_manager import ModelManager

@@ -1,7 +1,6 @@
import argparse
import os

from ...args import PRECISION_CHOICES
from ldm.invoke.args import PRECISION_CHOICES


def create_cmd_parser():
@@ -47,10 +46,10 @@ def create_cmd_parser():
        default="auto",
    )
    parser.add_argument(
        "--free_gpu_mem",
        dest="free_gpu_mem",
        action="store_true",
        help="Force free gpu memory before final decoding",
        '--free_gpu_mem',
        dest='free_gpu_mem',
        action='store_true',
        help='Force free gpu memory before final decoding',
    )

    return parser
@@ -1,8 +1,6 @@
from typing import Literal, Union

from PIL import Image, ImageChops
from PIL.Image import Image as ImageType

from typing import Union, Literal

# https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent
def check_for_any_transparency(img: Union[ImageType, str]) -> bool:
@@ -87,7 +85,9 @@ def main():

    print(
        "IMAGE WITH TRANSPARENCY, NO MASK, expect outpainting, got ",
        get_canvas_generation_mode(init_img_partial_transparency, init_mask_no_mask),
        get_canvas_generation_mode(
            init_img_partial_transparency, init_mask_no_mask
        ),
    )

    print(
@@ -102,7 +102,9 @@ def main():

    print(
        "IMAGE WITH TRANSPARENCY, WITH MASK, expect outpainting, got ",
        get_canvas_generation_mode(init_img_partial_transparency, init_mask_has_mask),
        get_canvas_generation_mode(
            init_img_partial_transparency, init_mask_has_mask
        ),
    )

    print(
@@ -1,7 +1,6 @@
from invokeai.backend.modules.parse_seed_weights import parse_seed_weights
import argparse

from .parse_seed_weights import parse_seed_weights

SAMPLER_CHOICES = [
    "ddim",
    "k_dpm_2_a",
Binary image assets (renamed/moved, no size change):
Before: 2.7 KiB → After: 2.7 KiB
Before: 292 KiB → After: 292 KiB
Before: 164 KiB → After: 164 KiB
Before: 9.5 KiB → After: 9.5 KiB
Before: 3.4 KiB → After: 3.4 KiB
@@ -1,9 +0,0 @@
"""
Initialization file for invokeai.backend.prompting
"""
from .conditioning import (
    get_prompt_structure,
    get_tokens_for_prompt_object,
    get_uc_and_c_and_ec,
    split_weighted_subprompts,
)
@@ -1,4 +0,0 @@
"""
Initialization file for the invokeai.backend.restoration package
"""
from .base import Restoration
@@ -1,118 +0,0 @@
import math

from PIL import Image


class Outcrop(object):
    def __init__(
        self,
        image,
        generate,  # current generate object
    ):
        self.image = image
        self.generate = generate

    def process(
        self,
        extents: dict,
        opt,  # current options
        orig_opt,  # ones originally used to generate the image
        image_callback=None,
        prefix=None,
    ):
        # grow and mask the image
        extended_image = self._extend_all(extents)

        # switch samplers temporarily
        curr_sampler = self.generate.sampler
        self.generate.sampler_name = opt.sampler_name
        self.generate._set_scheduler()

        def wrapped_callback(img, seed, **kwargs):
            preferred_seed = (
                orig_opt.seed
                if orig_opt.seed is not None and orig_opt.seed >= 0
                else seed
            )
            image_callback(img, preferred_seed, use_prefix=prefix, **kwargs)

        result = self.generate.prompt2image(
            opt.prompt,
            seed=opt.seed or orig_opt.seed,
            sampler=self.generate.sampler,
            steps=opt.steps,
            cfg_scale=opt.cfg_scale,
            ddim_eta=self.generate.ddim_eta,
            width=extended_image.width,
            height=extended_image.height,
            init_img=extended_image,
            strength=0.90,
            image_callback=wrapped_callback if image_callback else None,
            seam_size=opt.seam_size or 96,
            seam_blur=opt.seam_blur or 16,
            seam_strength=opt.seam_strength or 0.7,
            seam_steps=20,
            tile_size=32,
            color_match=True,
            force_outpaint=True,  # this just stops the warning about erased regions
        )

        # swap sampler back
        self.generate.sampler = curr_sampler
        return result

    def _extend_all(
        self,
        extents: dict,
    ) -> Image:
        """
        Extend the image in direction ('top','bottom','left','right') by
        the indicated value. The image canvas is extended, and the empty
        rectangular section will be filled with a blurred copy of the
        adjacent image.
        """
        image = self.image
        for direction in extents:
            assert direction in [
                "top",
                "left",
                "bottom",
                "right",
            ], 'Direction must be one of "top", "left", "bottom", "right"'
            pixels = extents[direction]
            # round pixels up to the nearest 64
            pixels = math.ceil(pixels / 64) * 64
            print(f">> extending image {direction}ward by {pixels} pixels")
            image = self._rotate(image, direction)
            image = self._extend(image, pixels)
            image = self._rotate(image, direction, reverse=True)
        return image

    def _rotate(self, image: Image, direction: str, reverse=False) -> Image:
        """
        Rotates image so that the area to extend is always at the top.
        Simplifies logic later. The reverse argument, if true, will undo the
        previous transpose.
        """
        transposes = {
            "right": ["ROTATE_90", "ROTATE_270"],
            "bottom": ["ROTATE_180", "ROTATE_180"],
            "left": ["ROTATE_270", "ROTATE_90"],
        }
        if direction not in transposes:
            return image
        transpose = transposes[direction][1 if reverse else 0]
        return image.transpose(Image.Transpose.__dict__[transpose])

    def _extend(self, image: Image, pixels: int) -> Image:
        extended_img = Image.new("RGBA", (image.width, image.height + pixels))

        extended_img.paste((0, 0, 0), [0, 0, image.width, image.height + pixels])
        extended_img.paste(image, box=(0, pixels))

        # now make the top part transparent to use as a mask
        alpha = extended_img.getchannel("A")
        alpha.paste(0, (0, 0, extended_img.width, pixels))
        extended_img.putalpha(alpha)

        return extended_img
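
The rotate/extend/rotate-back dance normalizes all four directions to a single case, "extend at the top": the target edge is rotated to face up, the canvas is padded (with the new strip left transparent to act as an inpainting mask), then the rotation is undone. A hypothetical usage sketch (generate, opt, and orig_opt are assumed to be an existing Generate instance and CLI option objects, as the comments above indicate):

    from PIL import Image

    img = Image.open("result.png")        # hypothetical source image
    outcrop = Outcrop(img, generate)
    # extend the canvas 64 px to the right; values are rounded up to multiples of 64
    result = outcrop.process({"right": 64}, opt, orig_opt)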
@@ -1,82 +0,0 @@
'''
SafetyChecker class - checks images against the StabilityAI NSFW filter
and blurs images that contain potential NSFW content.
'''
import diffusers
import numpy as np
import torch
import traceback
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
from pathlib import Path
from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor

import invokeai.assets.web as web_assets
from .globals import global_cache_dir
from .util import CPU_DEVICE


class SafetyChecker(object):
    CAUTION_IMG = "caution.png"

    def __init__(self, device: torch.device):
        path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
        caution = Image.open(path)
        self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
        self.device = device

        try:
            safety_model_id = "CompVis/stable-diffusion-safety-checker"
            safety_model_path = global_cache_dir("hub")
            self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
                safety_model_id,
                local_files_only=True,
                cache_dir=safety_model_path,
            )
            self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
                safety_model_id,
                local_files_only=True,
                cache_dir=safety_model_path,
            )
        except Exception:
            print(
                "** An error was encountered while installing the safety checker:"
            )
            print(traceback.format_exc())

    def check(self, image: Image.Image):
        """
        Check provided image against the StabilityAI safety checker and return
        either the original image or a blurred copy if NSFW content is detected.
        """

        self.safety_checker.to(self.device)
        features = self.safety_feature_extractor([image], return_tensors="pt")
        features.to(self.device)

        # unfortunately checker requires the numpy version, so we have to convert back
        x_image = np.array(image).astype(np.float32) / 255.0
        x_image = x_image[None].transpose(0, 3, 1, 2)

        diffusers.logging.set_verbosity_error()
        checked_image, has_nsfw_concept = self.safety_checker(
            images=x_image, clip_input=features.pixel_values
        )
        self.safety_checker.to(CPU_DEVICE)  # offload
        if has_nsfw_concept[0]:
            print(
                "** An image with potential non-safe content has been detected. A blurred image will be returned. **"
            )
            return self.blur(image)
        else:
            return image

    def blur(self, input):
        blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
        try:
            if caution := self.caution_img:
                blurry.paste(caution, (0, 0), caution)
        except FileNotFoundError:
            pass
        return blurry
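
A hypothetical usage sketch, assuming the checker weights are already present in the local hub cache (the class passes local_files_only=True, so nothing is downloaded at call time):

    import torch
    from PIL import Image

    checker = SafetyChecker(device=torch.device("cuda"))
    image = Image.open("result.png")        # hypothetical generated image
    safe_image = checker.check(image)       # original, or a blurred copy with the caution overlay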
@@ -1,13 +0,0 @@
"""
Initialization file for the invokeai.backend.stable_diffusion package
"""
from .concepts_lib import HuggingFaceConceptsLibrary
from .diffusers_pipeline import (
    ConditioningData,
    PipelineIntermediateState,
    StableDiffusionGeneratorPipeline,
)
from .diffusion import InvokeAIDiffuserComponent
from .diffusion.cross_attention_map_saving import AttentionMapSaver
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
from .textual_inversion_manager import TextualInversionManager
@@ -1,6 +0,0 @@
"""
Initialization file for invokeai.models.diffusion
"""
from .cross_attention_control import InvokeAICrossAttentionMixin
from .cross_attention_map_saving import AttentionMapSaver
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
@@ -1,4 +0,0 @@
"""
Initialization file for invokeai.backend.training
"""
from .textual_inversion_training import do_textual_inversion_training, parse_args
@@ -1,19 +0,0 @@
"""
Initialization file for invokeai.backend.util
"""
from .devices import (
    CPU_DEVICE,
    CUDA_DEVICE,
    MPS_DEVICE,
    choose_precision,
    choose_torch_device,
    normalize_device,
    torch_dtype,
)
from .log import write_log
from .util import (
    ask_user,
    download_with_resume,
    instantiate_from_config,
    url_attachment_name,
)
@@ -1,4 +0,0 @@
"""
Initialization file for the web backend.
"""
from .invoke_ai_web_server import InvokeAIWebServer
@@ -13,11 +13,16 @@ sd-inpainting-1.5:
  vae:
    repo_id: stabilityai/sd-vae-ft-mse
  recommended: True
stable-diffusion-2.1:
stable-diffusion-2.1-768:
  description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)
  repo_id: stabilityai/stable-diffusion-2-1
  format: diffusers
  recommended: True
stable-diffusion-2.1-base:
  description: Stable Diffusion version 2.1 diffusers model, trained on 512 pixel images (5.21 GB)
  repo_id: stabilityai/stable-diffusion-2-1-base
  format: diffusers
  recommended: False
sd-inpainting-2.0:
  description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
  repo_id: stabilityai/stable-diffusion-2-inpainting
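
Each entry in this config maps a model name to its Hugging Face repo_id, format, and a recommended flag that the installer uses to pre-select downloads. A sketch of reading it with PyYAML (the file path is an assumption; the hunk does not name the file):

    import yaml

    with open("invokeai/configs/INITIAL_MODELS.yaml") as f:  # path assumed
        models = yaml.safe_load(f)
    recommended = [name for name, cfg in models.items() if cfg.get("recommended")]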
|
||||
|
||||
@@ -1,6 +1,6 @@
model:
  base_learning_rate: 5.0e-03
  target: invokeai.backend.stable_diffusion.diffusion.ddpm.LatentDiffusion
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
@@ -19,7 +19,7 @@ model:
    embedding_reg_weight: 0.0

    personalization_config:
      target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ["sculpture"]
@@ -28,7 +28,7 @@ model:
        progressive_words: False

    unet_config:
      target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
@@ -45,7 +45,7 @@ model:
        legacy: False

    first_stage_config:
      target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
@@ -68,7 +68,7 @@ model:
          target: torch.nn.Identity

    cond_stage_config:
      target: invokeai.backend.stable_diffusion.encoders.modules.FrozenCLIPEmbedder
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
@@ -77,14 +77,14 @@ data:
    num_workers: 2
    wrap: false
    train:
      target: invokeai.backend.stable_diffusion.data.personalized.PersonalizedBase
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 512
        set: train
        per_image_tokens: false
        repeats: 100
    validation:
      target: invokeai.backend.stable_diffusion.data.personalized.PersonalizedBase
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 512
        set: val
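Every hunk in these config files touches a target: line because the configs are instantiated reflectively from the dotted path. A condensed sketch of the mechanism, modeled on the instantiate_from_config helper re-exported from invokeai.backend.util above (the real implementation may differ in details):

import importlib

def instantiate_from_config(config: dict):
    # Split "pkg.module.Class" into module path and attribute name,
    # import the module, and construct the object with its params.
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", {}))

# This is why every ldm.* -> invokeai.backend.* move must also rewrite the
# target: strings in these YAML files; a stale path fails at load time with
# ModuleNotFoundError.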
@@ -1,6 +1,6 @@
model:
  base_learning_rate: 5.0e-03
  target: invokeai.backend.models.diffusion.ddpm.LatentDiffusion
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
@@ -19,7 +19,7 @@ model:
    embedding_reg_weight: 0.0

    personalization_config:
      target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ["painting"]
@@ -27,7 +27,7 @@ model:
        num_vectors_per_token: 1

    unet_config:
      target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
@@ -44,7 +44,7 @@ model:
        legacy: False

    first_stage_config:
      target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
@@ -67,7 +67,7 @@ model:
          target: torch.nn.Identity

    cond_stage_config:
      target: invokeai.backend.stable_diffusion.encoders.modules.FrozenCLIPEmbedder
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
@@ -76,14 +76,14 @@ data:
    num_workers: 16
    wrap: false
    train:
      target: invokeai.backend.stable_diffusion.data.personalized_style.PersonalizedBase
      target: ldm.data.personalized_style.PersonalizedBase
      params:
        size: 512
        set: train
        per_image_tokens: false
        repeats: 100
    validation:
      target: invokeai.backend.stable_diffusion.data.personalized_style.PersonalizedBase
      target: ldm.data.personalized_style.PersonalizedBase
      params:
        size: 512
        set: val
@@ -1,6 +1,6 @@
model:
  base_learning_rate: 1.0e-04
  target: invokeai.backend.models.diffusion.ddpm.LatentDiffusion
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
@@ -18,7 +18,7 @@ model:
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: invokeai.backend.stable_diffusion.lr_scheduler.LambdaLinearScheduler
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
@@ -27,7 +27,7 @@ model:
        f_min: [ 1. ]

    personalization_config:
      target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ['sculpture']
@@ -36,7 +36,7 @@ model:
        progressive_words: False

    unet_config:
      target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
@@ -53,7 +53,7 @@ model:
        legacy: False

    first_stage_config:
      target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
@@ -76,4 +76,4 @@ model:
          target: torch.nn.Identity

    cond_stage_config:
      target: invokeai.backend.stable_diffusion.encoders.modules.WeightedFrozenCLIPEmbedder
      target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder
@@ -1,6 +1,6 @@
model:
  base_learning_rate: 7.5e-05
  target: invokeai.backend.models.diffusion.ddpm.LatentInpaintDiffusion
  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
@@ -18,7 +18,7 @@ model:
    finetune_keys: null

    scheduler_config: # 10000 warmup steps
      target: invokeai.backend.stable_diffusion.lr_scheduler.LambdaLinearScheduler
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
@@ -27,7 +27,7 @@ model:
        f_min: [ 1. ]

    personalization_config:
      target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ['sculpture']
@@ -36,7 +36,7 @@ model:
        progressive_words: False

    unet_config:
      target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 9 # 4 data + 4 downscaled image + 1 mask
@@ -53,7 +53,7 @@ model:
        legacy: False

    first_stage_config:
      target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
@@ -76,4 +76,4 @@ model:
          target: torch.nn.Identity

    cond_stage_config:
      target: invokeai.backend.stable_diffusion.encoders.modules.WeightedFrozenCLIPEmbedder
      target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder
@@ -1,6 +1,6 @@
model:
  base_learning_rate: 5.0e-03
  target: invokeai.backend.models.diffusion.ddpm.LatentDiffusion
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
@@ -19,7 +19,7 @@ model:
    embedding_reg_weight: 0.0

    personalization_config:
      target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ['sculpture']
@@ -28,7 +28,7 @@ model:
        progressive_words: False

    unet_config:
      target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
@@ -45,7 +45,7 @@ model:
        legacy: False

    first_stage_config:
      target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
@@ -68,7 +68,7 @@ model:
          target: torch.nn.Identity

    cond_stage_config:
      target: invokeai.backend.stable_diffusion.encoders.modules.FrozenCLIPEmbedder
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
@@ -77,14 +77,14 @@ data:
    num_workers: 2
    wrap: false
    train:
      target: invokeai.backend.stable_diffusion.data.personalized.PersonalizedBase
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 512
        set: train
        per_image_tokens: false
        repeats: 100
    validation:
      target: invokeai.backend.stable_diffusion.data.personalized.PersonalizedBase
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 512
        set: val
@@ -1,6 +1,6 @@
model:
  base_learning_rate: 1.0e-4
  target: invokeai.backend.stable_diffusion.diffusion.ddpm.LatentDiffusion
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    parameterization: "v"
    linear_start: 0.00085
@@ -19,7 +19,7 @@ model:
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
@@ -38,7 +38,7 @@ model:
        legacy: False

    first_stage_config:
      target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
@@ -62,7 +62,7 @@ model:
          target: torch.nn.Identity

    cond_stage_config:
      target: invokeai.backend.stable_diffusion.encoders.modules.FrozenOpenCLIPEmbedder
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
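For context, the parameterization: "v" line above selects v-prediction (Salimans & Ho, 2022): instead of predicting the noise epsilon directly, the UNet is trained to predict the velocity, commonly written as v_t = alpha_t * epsilon - sigma_t * x_0. This is the setting typically required by the 768-pixel SD 2.x checkpoints, as opposed to the epsilon-parameterized v1 configs.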
@@ -1,67 +0,0 @@
model:
  base_learning_rate: 1.0e-4
  target: invokeai.backend.stable_diffusion.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: invokeai.backend.stable_diffusion.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
@@ -3,6 +3,3 @@ dist/
node_modules/
patches/
stats.html
index.html
.yarn/
*.scss
@@ -30,12 +30,8 @@ module.exports = {
    radix: 'error',
    'space-before-blocks': 'error',
    'import/prefer-default-export': 'off',
    '@typescript-eslint/no-unused-vars': [
      'warn',
      { varsIgnorePattern: '^_', argsIgnorePattern: '^_' },
    ],
    '@typescript-eslint/no-unused-vars': ['warn', { varsIgnorePattern: '_+' }],
    'prettier/prettier': ['error', { endOfLine: 'auto' }],
    '@typescript-eslint/ban-ts-comment': 'warn',
  },
  settings: {
    react: {