Compare commits


1 Commit

Author           SHA1        Message                           Date
psychedelicious  b897ca18ce  feat(nodes): wip outputs_service  2023-05-17 00:42:27 +10:00
608 changed files with 24110 additions and 20788 deletions

.github/CODEOWNERS vendored

@@ -2,7 +2,7 @@
/.github/workflows/ @lstein @blessedcoolant
# documentation
/docs/ @lstein @blessedcoolant @hipsterusername
/docs/ @lstein @tildebyte @blessedcoolant
/mkdocs.yml @lstein @blessedcoolant
# nodes
@@ -18,17 +18,17 @@
/invokeai/version @lstein @blessedcoolant
# web ui
/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp
/invokeai/frontend @blessedcoolant @psychedelicious @lstein
/invokeai/backend @blessedcoolant @psychedelicious @lstein
# generation, model management, postprocessing
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 @StAlKeR7779
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2
# front ends
/invokeai/frontend/CLI @lstein
/invokeai/frontend/install @lstein @ebr
/invokeai/frontend/merge @lstein @blessedcoolant
/invokeai/frontend/training @lstein @blessedcoolant
/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/web @psychedelicious @blessedcoolant


@@ -80,6 +80,11 @@ jobs:
uses: actions/checkout@v3
- name: set test prompt to main branch validation
if: ${{ github.ref == 'refs/heads/main' }}
run: echo "TEST_PROMPTS=tests/preflight_prompts.txt" >> ${{ matrix.github-env }}
- name: set test prompt to Pull Request validation
if: ${{ github.ref != 'refs/heads/main' }}
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
- name: setup python
@@ -100,6 +105,12 @@ jobs:
id: run-pytest
run: pytest
- name: set INVOKEAI_OUTDIR
run: >
python -c
"import os;from invokeai.backend.globals import Globals;OUTDIR=os.path.join(Globals.root,str('outputs'));print(f'INVOKEAI_OUTDIR={OUTDIR}')"
>> ${{ matrix.github-env }}
- name: run invokeai-configure
id: run-preload-models
env:
@@ -118,21 +129,15 @@ jobs:
HF_HUB_OFFLINE: 1
HF_DATASETS_OFFLINE: 1
TRANSFORMERS_OFFLINE: 1
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
run: >
invokeai
--no-patchmatch
--no-nsfw_checker
--precision=float32
--always_use_cpu
--use_memory_db
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
--from_file ${{ env.TEST_PROMPTS }}
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
- name: Archive results
id: archive-results
env:
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
uses: actions/upload-artifact@v3
with:
name: results

.gitignore vendored

@@ -201,8 +201,6 @@ checkpoints
# If it's a Mac
.DS_Store
invokeai/frontend/web/dist/*
# Let the frontend manage its own gitignore
!invokeai/frontend/web/*

Binary file not shown.


@@ -0,0 +1,164 @@
@echo off
@rem This script will install git (if not found on the PATH variable)
@rem using micromamba (an 8mb static-linked single-file binary, conda replacement).
@rem For users who already have git, this step will be skipped.
@rem Next, it'll download the project's source code.
@rem Then it will download a self-contained, standalone Python and unpack it.
@rem Finally, it'll create the Python virtual environment and preload the models.
@rem This enables a user to install this project without manually installing git or Python
@rem change to the script's directory
PUSHD "%~dp0"
set "no_cache_dir=--no-cache-dir"
if "%1" == "use-cache" (
set "no_cache_dir="
)
echo ***** Installing InvokeAI.. *****
@rem Config
set INSTALL_ENV_DIR=%cd%\installer_files\env
@rem https://mamba.readthedocs.io/en/latest/installation.html
set MICROMAMBA_DOWNLOAD_URL=https://github.com/cmdr2/stable-diffusion-ui/releases/download/v1.1/micromamba.exe
set RELEASE_URL=https://github.com/invoke-ai/InvokeAI
set RELEASE_SOURCEBALL=/archive/refs/heads/main.tar.gz
set PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
set PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-x86_64-pc-windows-msvc-shared-install_only.tar.gz
set PACKAGES_TO_INSTALL=
call git --version >.tmp1 2>.tmp2
if "%ERRORLEVEL%" NEQ "0" set PACKAGES_TO_INSTALL=%PACKAGES_TO_INSTALL% git
@rem Cleanup
del /q .tmp1 .tmp2
@rem (if necessary) install git into a contained environment
if "%PACKAGES_TO_INSTALL%" NEQ "" (
@rem download micromamba
echo ***** Downloading micromamba from %MICROMAMBA_DOWNLOAD_URL% to micromamba.exe *****
call curl -L "%MICROMAMBA_DOWNLOAD_URL%" > micromamba.exe
@rem test the mamba binary
echo ***** Micromamba version: *****
call micromamba.exe --version
@rem create the installer env
if not exist "%INSTALL_ENV_DIR%" (
call micromamba.exe create -y --prefix "%INSTALL_ENV_DIR%"
)
echo ***** Packages to install:%PACKAGES_TO_INSTALL% *****
call micromamba.exe install -y --prefix "%INSTALL_ENV_DIR%" -c conda-forge %PACKAGES_TO_INSTALL%
if not exist "%INSTALL_ENV_DIR%" (
echo ----- There was a problem while installing "%PACKAGES_TO_INSTALL%" using micromamba. Cannot continue. -----
pause
exit /b
)
)
del /q micromamba.exe
@rem For 'git' only
set PATH=%INSTALL_ENV_DIR%\Library\bin;%PATH%
@rem Download/unpack/clean up InvokeAI release sourceball
set err_msg=----- InvokeAI source download failed -----
echo Trying to download "%RELEASE_URL%%RELEASE_SOURCEBALL%"
curl -L %RELEASE_URL%%RELEASE_SOURCEBALL% --output InvokeAI.tgz
if %errorlevel% neq 0 goto err_exit
set err_msg=----- InvokeAI source unpack failed -----
tar -zxf InvokeAI.tgz
if %errorlevel% neq 0 goto err_exit
del /q InvokeAI.tgz
set err_msg=----- InvokeAI source copy failed -----
cd InvokeAI-*
xcopy . .. /e /h
if %errorlevel% neq 0 goto err_exit
cd ..
@rem cleanup
for /f %%i in ('dir /b InvokeAI-*') do rd /s /q %%i
rd /s /q .dev_scripts .github docker-build tests
del /q requirements.in requirements-mkdocs.txt shell.nix
echo ***** Unpacked InvokeAI source *****
@rem Download/unpack/clean up python-build-standalone
set err_msg=----- Python download failed -----
curl -L %PYTHON_BUILD_STANDALONE_URL%/%PYTHON_BUILD_STANDALONE% --output python.tgz
if %errorlevel% neq 0 goto err_exit
set err_msg=----- Python unpack failed -----
tar -zxf python.tgz
if %errorlevel% neq 0 goto err_exit
del /q python.tgz
echo ***** Unpacked python-build-standalone *****
@rem create venv
set err_msg=----- problem creating venv -----
.\python\python -E -s -m venv .venv
if %errorlevel% neq 0 goto err_exit
call .venv\Scripts\activate.bat
echo ***** Created Python virtual environment *****
@rem Print venv's Python version
set err_msg=----- problem calling venv's python -----
echo We're running under
.venv\Scripts\python --version
if %errorlevel% neq 0 goto err_exit
set err_msg=----- pip update failed -----
.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location --upgrade pip wheel
if %errorlevel% neq 0 goto err_exit
echo ***** Updated pip and wheel *****
set err_msg=----- requirements file copy failed -----
copy binary_installer\py3.10-windows-x86_64-cuda-reqs.txt requirements.txt
if %errorlevel% neq 0 goto err_exit
set err_msg=----- main pip install failed -----
.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location -r requirements.txt
if %errorlevel% neq 0 goto err_exit
echo ***** Installed Python dependencies *****
set err_msg=----- InvokeAI setup failed -----
.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location -e .
if %errorlevel% neq 0 goto err_exit
copy binary_installer\invoke.bat.in .\invoke.bat
echo ***** Installed invoke launcher script ******
@rem more cleanup
rd /s /q binary_installer installer_files
@rem preload the models
call .venv\Scripts\python ldm\invoke\config\invokeai_configure.py
set err_msg=----- model download clone failed -----
if %errorlevel% neq 0 goto err_exit
deactivate
echo ***** Finished downloading models *****
echo All done! Execute the file invoke.bat in this directory to start InvokeAI
pause
exit
:err_exit
echo %err_msg%
pause
exit


@@ -0,0 +1,235 @@
#!/usr/bin/env bash
# ensure we're in the correct folder in case user's CWD is somewhere else
scriptdir=$(dirname "$0")
cd "$scriptdir"
set -euo pipefail
IFS=$'\n\t'
function _err_exit {
if test "$1" -ne 0
then
echo -e "Error code $1; Error caught was '$2'"
read -p "Press any key to exit..."
exit
fi
}
# This script will install git (if not found on the PATH variable)
# using micromamba (an 8mb static-linked single-file binary, conda replacement).
# For users who already have git, this step will be skipped.
# Next, it'll download the project's source code.
# Then it will download a self-contained, standalone Python and unpack it.
# Finally, it'll create the Python virtual environment and preload the models.
# This enables a user to install this project without manually installing git or Python
echo -e "\n***** Installing InvokeAI into $(pwd)... *****\n"
export no_cache_dir="--no-cache-dir"
if [ $# -ge 1 ]; then
if [ "$1" = "use-cache" ]; then
export no_cache_dir=""
fi
fi
OS_NAME=$(uname -s)
case "${OS_NAME}" in
Linux*) OS_NAME="linux";;
Darwin*) OS_NAME="darwin";;
*) echo -e "\n----- Unknown OS: $OS_NAME! This script runs only on Linux or macOS -----\n" && exit
esac
OS_ARCH=$(uname -m)
case "${OS_ARCH}" in
x86_64*) ;;
arm64*) ;;
*) echo -e "\n----- Unknown system architecture: $OS_ARCH! This script runs only on x86_64 or arm64 -----\n" && exit
esac
# https://mamba.readthedocs.io/en/latest/installation.html
MAMBA_OS_NAME=$OS_NAME
MAMBA_ARCH=$OS_ARCH
if [ "$OS_NAME" == "darwin" ]; then
MAMBA_OS_NAME="osx"
fi
if [ "$OS_ARCH" == "linux" ]; then
MAMBA_ARCH="aarch64"
fi
if [ "$OS_ARCH" == "x86_64" ]; then
MAMBA_ARCH="64"
fi
PY_ARCH=$OS_ARCH
if [ "$OS_ARCH" == "arm64" ]; then
PY_ARCH="aarch64"
fi
# Compute device ('cd' segment of reqs files) detect goes here
# This needs a ton of work
# Suggestions:
# - lspci
# - check $PATH for nvidia-smi, gtt CUDA/GPU version from output
# - Surely there's a similar utility for AMD?
CD="cuda"
if [ "$OS_NAME" == "darwin" ] && [ "$OS_ARCH" == "arm64" ]; then
CD="mps"
fi
# config
INSTALL_ENV_DIR="$(pwd)/installer_files/env"
MICROMAMBA_DOWNLOAD_URL="https://micro.mamba.pm/api/micromamba/${MAMBA_OS_NAME}-${MAMBA_ARCH}/latest"
RELEASE_URL=https://github.com/invoke-ai/InvokeAI
RELEASE_SOURCEBALL=/archive/refs/heads/main.tar.gz
PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
if [ "$OS_NAME" == "darwin" ]; then
PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-${PY_ARCH}-apple-darwin-install_only.tar.gz
elif [ "$OS_NAME" == "linux" ]; then
PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-${PY_ARCH}-unknown-linux-gnu-install_only.tar.gz
fi
echo "INSTALLING $RELEASE_SOURCEBALL FROM $RELEASE_URL"
PACKAGES_TO_INSTALL=""
if ! hash "git" &>/dev/null; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL git"; fi
# (if necessary) install git and conda into a contained environment
if [ "$PACKAGES_TO_INSTALL" != "" ]; then
# download micromamba
echo -e "\n***** Downloading micromamba from $MICROMAMBA_DOWNLOAD_URL to micromamba *****\n"
curl -L "$MICROMAMBA_DOWNLOAD_URL" | tar -xvjO bin/micromamba > micromamba
chmod u+x ./micromamba
# test the mamba binary
echo -e "\n***** Micromamba version: *****\n"
./micromamba --version
# create the installer env
if [ ! -e "$INSTALL_ENV_DIR" ]; then
./micromamba create -y --prefix "$INSTALL_ENV_DIR"
fi
echo -e "\n***** Packages to install:$PACKAGES_TO_INSTALL *****\n"
./micromamba install -y --prefix "$INSTALL_ENV_DIR" -c conda-forge "$PACKAGES_TO_INSTALL"
if [ ! -e "$INSTALL_ENV_DIR" ]; then
echo -e "\n----- There was a problem while initializing micromamba. Cannot continue. -----\n"
exit
fi
fi
rm -f micromamba.exe
export PATH="$INSTALL_ENV_DIR/bin:$PATH"
# Download/unpack/clean up InvokeAI release sourceball
_err_msg="\n----- InvokeAI source download failed -----\n"
curl -L $RELEASE_URL/$RELEASE_SOURCEBALL --output InvokeAI.tgz
_err_exit $? _err_msg
_err_msg="\n----- InvokeAI source unpack failed -----\n"
tar -zxf InvokeAI.tgz
_err_exit $? _err_msg
rm -f InvokeAI.tgz
_err_msg="\n----- InvokeAI source copy failed -----\n"
cd InvokeAI-*
cp -r . ..
_err_exit $? _err_msg
cd ..
# cleanup
rm -rf InvokeAI-*/
rm -rf .dev_scripts/ .github/ docker-build/ tests/ requirements.in requirements-mkdocs.txt shell.nix
echo -e "\n***** Unpacked InvokeAI source *****\n"
# Download/unpack/clean up python-build-standalone
_err_msg="\n----- Python download failed -----\n"
curl -L $PYTHON_BUILD_STANDALONE_URL/$PYTHON_BUILD_STANDALONE --output python.tgz
_err_exit $? _err_msg
_err_msg="\n----- Python unpack failed -----\n"
tar -zxf python.tgz
_err_exit $? _err_msg
rm -f python.tgz
echo -e "\n***** Unpacked python-build-standalone *****\n"
# create venv
_err_msg="\n----- problem creating venv -----\n"
if [ "$OS_NAME" == "darwin" ]; then
# patch sysconfig so that extensions can build properly
# adapted from https://github.com/cashapp/hermit-packages/commit/fcba384663892f4d9cfb35e8639ff7a28166ee43
PYTHON_INSTALL_DIR="$(pwd)/python"
SYSCONFIG="$(echo python/lib/python*/_sysconfigdata_*.py)"
TMPFILE="$(mktemp)"
chmod +w "${SYSCONFIG}"
cp "${SYSCONFIG}" "${TMPFILE}"
sed "s,'/install,'${PYTHON_INSTALL_DIR},g" "${TMPFILE}" > "${SYSCONFIG}"
rm -f "${TMPFILE}"
fi
./python/bin/python3 -E -s -m venv .venv
_err_exit $? _err_msg
source .venv/bin/activate
echo -e "\n***** Created Python virtual environment *****\n"
# Print venv's Python version
_err_msg="\n----- problem calling venv's python -----\n"
echo -e "We're running under"
.venv/bin/python3 --version
_err_exit $? _err_msg
_err_msg="\n----- pip update failed -----\n"
.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location --upgrade pip
_err_exit $? _err_msg
echo -e "\n***** Updated pip *****\n"
_err_msg="\n----- requirements file copy failed -----\n"
cp binary_installer/py3.10-${OS_NAME}-"${OS_ARCH}"-${CD}-reqs.txt requirements.txt
_err_exit $? _err_msg
_err_msg="\n----- main pip install failed -----\n"
.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location -r requirements.txt
_err_exit $? _err_msg
echo -e "\n***** Installed Python dependencies *****\n"
_err_msg="\n----- InvokeAI setup failed -----\n"
.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location -e .
_err_exit $? _err_msg
echo -e "\n***** Installed InvokeAI *****\n"
cp binary_installer/invoke.sh.in ./invoke.sh
chmod a+rx ./invoke.sh
echo -e "\n***** Installed invoke launcher script ******\n"
# more cleanup
rm -rf binary_installer/ installer_files/
# preload the models
.venv/bin/python3 scripts/configure_invokeai.py
_err_msg="\n----- model download clone failed -----\n"
_err_exit $? _err_msg
deactivate
echo -e "\n***** Finished downloading models *****\n"
echo "All done! Run the command"
echo " $scriptdir/invoke.sh"
echo "to start InvokeAI."
read -p "Press any key to exit..."
exit


@@ -0,0 +1,36 @@
@echo off
PUSHD "%~dp0"
call .venv\Scripts\activate.bat
echo Do you want to generate images using the
echo 1. command-line
echo 2. browser-based UI
echo OR
echo 3. open the developer console
set /p choice="Please enter 1, 2 or 3: "
if /i "%choice%" == "1" (
echo Starting the InvokeAI command-line.
.venv\Scripts\python scripts\invoke.py %*
) else if /i "%choice%" == "2" (
echo Starting the InvokeAI browser-based UI.
.venv\Scripts\python scripts\invoke.py --web %*
) else if /i "%choice%" == "3" (
echo Developer Console
echo Python command is:
where python
echo Python version is:
python --version
echo *************************
echo You are now in the system shell, with the local InvokeAI Python virtual environment activated,
echo so that you can troubleshoot this InvokeAI installation as necessary.
echo *************************
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
call cmd /k
) else (
echo Invalid selection
pause
exit /b
)
deactivate


@@ -0,0 +1,46 @@
#!/usr/bin/env sh
set -eu
. .venv/bin/activate
# set required env var for torch on mac MPS
if [ "$(uname -s)" == "Darwin" ]; then
export PYTORCH_ENABLE_MPS_FALLBACK=1
fi
echo "Do you want to generate images using the"
echo "1. command-line"
echo "2. browser-based UI"
echo "OR"
echo "3. open the developer console"
echo "Please enter 1, 2, or 3:"
read choice
case $choice in
1)
printf "\nStarting the InvokeAI command-line..\n";
.venv/bin/python scripts/invoke.py $*;
;;
2)
printf "\nStarting the InvokeAI browser-based UI..\n";
.venv/bin/python scripts/invoke.py --web $*;
;;
3)
printf "\nDeveloper Console:\n";
printf "Python command is:\n\t";
which python;
printf "Python version is:\n\t";
python --version;
echo "*************************"
echo "You are now in your user shell ($SHELL) with the local InvokeAI Python virtual environment activated,";
echo "so that you can troubleshoot this InvokeAI installation as necessary.";
printf "*************************\n"
echo "*** Type \`exit\` to quit this shell and deactivate the Python virtual environment *** ";
/usr/bin/env "$SHELL";
;;
*)
echo "Invalid selection";
exit
;;
esac

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,17 @@
InvokeAI
Project homepage: https://github.com/invoke-ai/InvokeAI
Installation on Windows:
NOTE: You might need to enable Windows Long Paths. If you're not sure,
then you almost certainly need to. Simply double-click the 'WinLongPathsEnabled.reg'
file. Note that you will need to have admin privileges in order to
do this.
Please double-click the 'install.bat' file (while keeping it inside the invokeAI folder).
Installation on Linux and Mac:
Please open the terminal, and run './install.sh' (while keeping it inside the invokeAI folder).
After installation, please run the 'invoke.bat' file (on Windows) or 'invoke.sh'
file (on Linux/Mac) to start InvokeAI.


@@ -0,0 +1,33 @@
--prefer-binary
--extra-index-url https://download.pytorch.org/whl/torch_stable.html
--extra-index-url https://download.pytorch.org/whl/cu116
--trusted-host https://download.pytorch.org
accelerate~=0.15
albumentations
diffusers[torch]~=0.11
einops
eventlet
flask_cors
flask_socketio
flaskwebgui==1.0.3
getpass_asterisk
imageio-ffmpeg
pyreadline3
realesrgan
send2trash
streamlit
taming-transformers-rom1504
test-tube
torch-fidelity
torch==1.12.1 ; platform_system == 'Darwin'
torch==1.12.0+cu116 ; platform_system == 'Linux' or platform_system == 'Windows'
torchvision==0.13.1 ; platform_system == 'Darwin'
torchvision==0.13.0+cu116 ; platform_system == 'Linux' or platform_system == 'Windows'
transformers
picklescan
https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip
https://github.com/invoke-ai/clipseg/archive/1f754751c85d7d4255fa681f4491ff5711c1c288.zip
https://github.com/invoke-ai/GFPGAN/archive/3f5d2397361199bc4a91c08bb7d80f04d7805615.zip ; platform_system=='Windows'
https://github.com/invoke-ai/GFPGAN/archive/c796277a1cf77954e5fc0b288d7062d162894248.zip ; platform_system=='Linux' or platform_system=='Darwin'
https://github.com/Birch-san/k-diffusion/archive/363386981fee88620709cf8f6f2eea167bd6cd74.zip
https://github.com/invoke-ai/PyPatchMatch/archive/129863937a8ab37f6bbcec327c994c0f932abdbc.zip


@@ -19,56 +19,31 @@ An invocation looks like this:
```py
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
# fmt: off
type: Literal["upscale"] = "upscale"
type: Literal['upscale'] = 'upscale'
# Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None)
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2, 4] = Field(default=2, description="The upscale level")
# fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["upscaling", "image"],
},
}
image: Union[ImageField,None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2,4] = Field(default=2, description = "The upscale level")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=(self.level, self.strength),
strength=0.0, # GFPGAN strength
save_original=False,
image_callback=None,
image = context.services.images.get(self.image.image_type, self.image.image_name)
results = context.services.generate.upscale_and_reconstruct(
image_list = [[image, 0]],
upscale = (self.level, self.strength),
strength = 0.0, # GFPGAN strength
save_original = False,
image_callback = None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
image = ImageField(image_type = image_type, image_name = image_name)
)
```
Each portion is important to implement correctly.
@@ -120,67 +95,25 @@ Finally, note that for all linking, the `type` of the linked fields must match.
If the `name` also matches, then the field can be **automatically linked** to a
previous invocation by name and type matching.
### Config
```py
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["upscaling", "image"],
},
}
```
This is an optional configuration for the invocation. It inherits from
pydantic's model `Config` class, and is used primarily to customize the
autogenerated OpenAPI schema.
The UI relies on the OpenAPI schema in two ways:
- An API client & Typescript types are generated from it. This happens at build
time.
- The node editor parses the schema into a template used by the UI to create the
node editor UI. This parsing happens at runtime.
In this example, a `ui` key has been added to the `schema_extra` dict to provide
some tags for the UI, to facilitate filtering nodes.
See the Schema Generation section below for more information.
### Invoke Function
```py
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=(self.level, self.strength),
strength=0.0, # GFPGAN strength
save_original=False,
image_callback=None,
image = context.services.images.get(self.image.image_type, self.image.image_name)
results = context.services.generate.upscale_and_reconstruct(
image_list = [[image, 0]],
upscale = (self.level, self.strength),
strength = 0.0, # GFPGAN strength
save_original = False,
image_callback = None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
image = ImageField(image_type = image_type, image_name = image_name)
)
```
@@ -202,16 +135,9 @@ scenarios. If you need functionality, please provide it as a service in the
```py
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
type: Literal['image'] = 'image'
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
image: ImageField = Field(default=None, description="The output image")
```
Output classes look like an invocation class without the invoke method. Prefer
@@ -242,36 +168,35 @@ Here's that `ImageOutput` class, without the needed schema customisation:
```python
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
```
The OpenAPI schema that results from this `ImageOutput` will have the `type`,
`image`, `width` and `height` properties marked as optional, even though we know
they will always have a value.
The generated OpenAPI schema, and all clients/types generated from it, will have
the `type` and `image` properties marked as optional, even though we know they
will always have a value by the time we can interact with them via the API.
Here's the same class, but with the schema customisation added:
```python
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
# Add schema customization
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
schema_extra = {
'required': [
'type',
'image',
]
}
```
With the customization in place, the schema will now show these properties as
required, obviating the need for extensive null checks in client code.
The resultant schema (and any API client or types generated from it) will now
have `type` as string literal `"image"` and `image` as an `ImageField`
object.
See this `pydantic` issue for discussion on this solution:
<https://github.com/pydantic/pydantic/discussions/4577>
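
Putting the pieces above together, here is a minimal sketch of a complete custom invocation, written against the `ResourceOrigin`/`ImageDTO` flavour of the services API shown on one side of this diff. The class, its `img_invert` type literal, and the exact import paths are illustrative assumptions rather than part of this changeset:

```python
from typing import Literal, Union

from PIL import ImageOps
from pydantic import Field

# Import paths are assumptions, inferred from modules referenced elsewhere in this diff
from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin


class InvertInvocation(BaseInvocation):
    """Inverts the colors of an image."""

    # fmt: off
    type: Literal["img_invert"] = "img_invert"
    # Inputs
    image: Union[ImageField, None] = Field(description="The input image", default=None)
    # fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Fetch the input image through the images service
        image = context.services.images.get_pil_image(
            self.image.image_origin, self.image.image_name
        )

        # Do the actual work with plain PIL
        inverted = ImageOps.invert(image.convert("RGB"))

        # Save the result; the service returns an ImageDTO carrying the
        # name, origin and dimensions needed to build the output
        image_dto = context.services.images.create(
            image=inverted,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        return ImageOutput(
            image=ImageField(
                image_name=image_dto.image_name,
                image_origin=image_dto.image_origin,
            ),
            width=image_dto.width,
            height=image_dto.height,
        )
```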


@@ -1,171 +0,0 @@
---
title: Controlling Logging
---
# :material-image-off: Controlling Logging
## Controlling How InvokeAI Logs Status Messages
InvokeAI logs status messages using a configurable logging system. You
can log to the terminal window, to a designated file on the local
machine, to the syslog facility on a Linux or Mac, or to a properly
configured web server. You can configure several logs at the same
time, and control the level of message logged and the logging format
(to a limited extent).
Three command-line options control logging:
### `--log_handlers <handler1> <handler2> ...`
This option activates one or more log handlers. Options are "console",
"file", "syslog" and "http". To specify more than one, separate them
by spaces:
```bash
invokeai-web --log_handlers console syslog=/dev/log file=C:\Users\fred\invokeai.log
```
The format of these options is described below.
The format of these options is described below.
### `--log_format {plain|color|legacy|syslog}`
This controls the format of log messages written to the console. Only
the "console" log handler is currently affected by this setting.
* "plain" provides formatted messages like this:
```bash
[2023-05-24 23:18:50,352]::[InvokeAI]::DEBUG --> this is a debug message
[2023-05-24 23:18:50,352]::[InvokeAI]::INFO --> this is an informational messages
[2023-05-24 23:18:50,352]::[InvokeAI]::WARNING --> this is a warning
[2023-05-24 23:18:50,352]::[InvokeAI]::ERROR --> this is an error
[2023-05-24 23:18:50,352]::[InvokeAI]::CRITICAL --> this is a critical error
```
* "color" produces similar output, but the text will be color coded to
indicate the severity of the message.
* "legacy" produces output similar to InvokeAI versions 2.3 and earlier:
```bash
### this is a critical error
*** this is an error
** this is a warning
>> this is an informational messages
| this is a debug message
```
* "syslog" produces messages suitable for syslog entries:
```bash
InvokeAI [2691178] <CRITICAL> this is a critical error
InvokeAI [2691178] <ERROR> this is an error
InvokeAI [2691178] <WARNING> this is a warning
InvokeAI [2691178] <INFO> this is an informational messages
InvokeAI [2691178] <DEBUG> this is a debug message
```
(note that the date, time and hostname will be added by the syslog
system)
### `--log_level {debug|info|warning|error|critical}`
Providing this command-line option will cause only messages at the
specified level or above to be emitted.
## Console logging
When "console" is provided to `--log_handlers`, messages will be
written to the command line window in which InvokeAI was launched. By
default, the color formatter will be used unless overridden by
`--log_format`.
## File logging
When "file" is provided to `--log_handlers`, entries will be written
to the file indicated in the path argument. By default, the "plain"
format will be used:
```bash
invokeai-web --log_handlers file=/var/log/invokeai.log
```
## Syslog logging
When "syslog" is requested, entries will be sent to the syslog
system. There are a variety of ways to control where the log message
is sent:
* Send to the local machine using the `/dev/log` socket:
```
invokeai-web --log_handlers syslog=/dev/log
```
* Send to the local machine using a UDP message:
```
invokeai-web --log_handlers syslog=localhost
```
* Send to the local machine using a UDP message on a nonstandard
port:
```
invokeai-web --log_handlers syslog=localhost:512
```
* Send to a remote machine named "loghost" on the local LAN using
facility LOG_USER and UDP packets:
```
invokeai-web --log_handlers syslog=loghost,facility=LOG_USER,socktype=SOCK_DGRAM
```
This can be abbreviated `syslog=loghost`, as LOG_USER and SOCK_DGRAM
are defaults.
* Send to a remote machine named "loghost" using the facility LOCAL0
and using a TCP socket:
```
invokeai-web --log_handlers syslog=loghost,facility=LOG_LOCAL0,socktype=SOCK_STREAM
```
If no arguments are specified (just a bare "syslog"), then the logging
system will look for a UNIX socket named `/dev/log`, and if not found
try to send a UDP message to `localhost`. The Macintosh OS used to
support logging to a socket named `/var/run/syslog`, but this feature
has since been disabled.
## Web logging
If you have access to a web server that is configured to log messages
when a particular URL is requested, you can log using the "http"
method:
```
invokeai-web --log_handlers http=http://my.server/path/to/logger,method=POST
```
The optional [,method=] part can be used to specify whether the URL
accepts GET (default) or POST messages.
Currently password authentication and SSL are not supported.
## Using the configuration file
You can set and forget logging options by adding a "Logging" section
to `invokeai.yaml`:
```
InvokeAI:
[... other settings...]
Logging:
log_handlers:
- console
- syslog=/dev/log
log_level: info
log_format: color
```


@@ -57,9 +57,6 @@ Personalize models by adding your own style or subjects.
## * [The NSFW Checker](NSFW.md)
Prevent InvokeAI from displaying unwanted racy images.
## * [Controlling Logging](LOGGING.md)
Control how InvokeAI logs status messages.
## * [Miscellaneous](OTHER.md)
Run InvokeAI on Google Colab, generate images with repeating patterns,
batch process a file of prompts, increase the "creativity" of image


@@ -216,7 +216,7 @@ manager, please follow these steps:
9. Run the command-line- or the web- interface:
From within INVOKEAI_ROOT, activate the environment
(with `source .venv/bin/activate` or `.venv\scripts\activate`), and then run
(with `source .venv/bin/activate` or `.venv\scripts\activate), and then run
the script `invokeai`. If the virtual environment you selected is NOT inside
INVOKEAI_ROOT, then you must specify the path to the root directory by adding
`--root_dir \path\to\invokeai` to the commands below:


@@ -7,42 +7,42 @@ call .venv\Scripts\activate.bat
set INVOKEAI_ROOT=.
:start
echo Desired action:
echo 1. Generate images with the browser-based interface
echo 2. Explore InvokeAI nodes using a command-line interface
echo 3. Run textual inversion training
echo 4. Merge models (diffusers type only)
echo 5. Download and install models
echo 6. Change InvokeAI startup options
echo 7. Re-run the configure script to fix a broken install
echo 8. Open the developer console
echo 9. Update InvokeAI
echo 10. Command-line help
echo Q - Quit
set /P choice="Please enter 1-10, Q: [2] "
if not defined choice set choice=2
IF /I "%choice%" == "1" (
echo Starting the InvokeAI browser-based UI..
python .venv\Scripts\invokeai-web.exe %*
) ELSE IF /I "%choice%" == "2" (
echo Do you want to generate images using the
echo 1. command-line interface
echo 2. browser-based UI
echo 3. run textual inversion training
echo 4. merge models (diffusers type only)
echo 5. download and install models
echo 6. change InvokeAI startup options
echo 7. re-run the configure script to fix a broken install
echo 8. open the developer console
echo 9. update InvokeAI
echo 10. command-line help
echo Q - quit
set /P restore="Please enter 1-10, Q: [2] "
if not defined restore set restore=2
IF /I "%restore%" == "1" (
echo Starting the InvokeAI command-line..
python .venv\Scripts\invokeai.exe %*
) ELSE IF /I "%choice%" == "3" (
) ELSE IF /I "%restore%" == "2" (
echo Starting the InvokeAI browser-based UI..
python .venv\Scripts\invokeai.exe --web %*
) ELSE IF /I "%restore%" == "3" (
echo Starting textual inversion training..
python .venv\Scripts\invokeai-ti.exe --gui
) ELSE IF /I "%choice%" == "4" (
) ELSE IF /I "%restore%" == "4" (
echo Starting model merging script..
python .venv\Scripts\invokeai-merge.exe --gui
) ELSE IF /I "%choice%" == "5" (
) ELSE IF /I "%restore%" == "5" (
echo Running invokeai-model-install...
python .venv\Scripts\invokeai-model-install.exe
) ELSE IF /I "%choice%" == "6" (
) ELSE IF /I "%restore%" == "6" (
echo Running invokeai-configure...
python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models
) ELSE IF /I "%choice%" == "7" (
) ELSE IF /I "%restore%" == "7" (
echo Running invokeai-configure...
python .venv\Scripts\invokeai-configure.exe --yes --default_only
) ELSE IF /I "%choice%" == "8" (
) ELSE IF /I "%restore%" == "8" (
echo Developer Console
echo Python command is:
where python
@@ -54,15 +54,15 @@ IF /I "%choice%" == "1" (
echo *************************
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
call cmd /k
) ELSE IF /I "%choice%" == "9" (
) ELSE IF /I "%restore%" == "9" (
echo Running invokeai-update...
python .venv\Scripts\invokeai-update.exe %*
) ELSE IF /I "%choice%" == "10" (
) ELSE IF /I "%restore%" == "10" (
echo Displaying command line help...
python .venv\Scripts\invokeai.exe --help %*
pause
exit /b
) ELSE IF /I "%choice%" == "q" (
) ELSE IF /I "%restore%" == "q" (
echo Goodbye!
goto ending
) ELSE (


@@ -1,10 +1,5 @@
#!/bin/bash
# MIT License
# Coauthored by Lincoln Stein, Eugene Brodsky and Joshua Kimsey
# Copyright 2023, The InvokeAI Development Team
####
# This launch script assumes that:
# 1. it is located in the runtime directory,
@@ -16,168 +11,85 @@
set -eu
# Ensure we're in the correct folder in case user's CWD is somewhere else
# ensure we're in the correct folder in case user's CWD is somewhere else
scriptdir=$(dirname "$0")
cd "$scriptdir"
. .venv/bin/activate
export INVOKEAI_ROOT="$scriptdir"
PARAMS=$@
# Check to see if dialog is installed (it seems to be fairly standard, but good to check regardless) and if the user has passed the --no-tui argument to disable the dialog TUI
tui=true
if command -v dialog &>/dev/null; then
# This must use $@ to properly loop through the arguments passed by the user
for arg in "$@"; do
if [ "$arg" == "--no-tui" ]; then
tui=false
# Remove the --no-tui argument to avoid errors later on when passing arguments to InvokeAI
PARAMS=$(echo "$PARAMS" | sed 's/--no-tui//')
break
fi
done
else
tui=false
fi
# Set required env var for torch on mac MPS
# set required env var for torch on mac MPS
if [ "$(uname -s)" == "Darwin" ]; then
export PYTORCH_ENABLE_MPS_FALLBACK=1
fi
# Primary function for the case statement to determine user input
do_choice() {
case $1 in
1)
clear
printf "Generate images with a browser-based interface\n"
invokeai-web $PARAMS
;;
2)
clear
printf "Explore InvokeAI nodes using a command-line interface\n"
invokeai $PARAMS
;;
3)
clear
printf "Textual inversion training\n"
invokeai-ti --gui $PARAMS
;;
4)
clear
printf "Merge models (diffusers type only)\n"
invokeai-merge --gui $PARAMS
;;
5)
clear
printf "Download and install models\n"
invokeai-model-install --root ${INVOKEAI_ROOT}
;;
6)
clear
printf "Change InvokeAI startup options\n"
invokeai-configure --root ${INVOKEAI_ROOT} --skip-sd-weights --skip-support-models
;;
7)
clear
printf "Re-run the configure script to fix a broken install\n"
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
;;
8)
clear
printf "Open the developer console\n"
file_name=$(basename "${BASH_SOURCE[0]}")
bash --init-file "$file_name"
;;
9)
clear
printf "Update InvokeAI\n"
invokeai-update
;;
10)
clear
printf "Command-line help\n"
invokeai --help
;;
"HELP 1")
clear
printf "Command-line help\n"
invokeai --help
;;
*)
clear
printf "Exiting...\n"
exit
;;
esac
clear
}
# Dialog-based TUI for launcing Invoke functions
do_dialog() {
options=(
1 "Generate images with a browser-based interface"
2 "Generate images using a command-line interface"
3 "Textual inversion training"
4 "Merge models (diffusers type only)"
5 "Download and install models"
6 "Change InvokeAI startup options"
7 "Re-run the configure script to fix a broken install"
8 "Open the developer console"
9 "Update InvokeAI")
choice=$(dialog --clear \
--backtitle "\Zb\Zu\Z3InvokeAI" \
--colors \
--title "What would you like to do?" \
--ok-label "Run" \
--cancel-label "Exit" \
--help-button \
--help-label "CLI Help" \
--menu "Select an option:" \
0 0 0 \
"${options[@]}" \
2>&1 >/dev/tty) || clear
do_choice "$choice"
clear
}
# Command-line interface for launching Invoke functions
do_line_input() {
clear
printf " ** For a more attractive experience, please install the 'dialog' utility using your package manager. **\n\n"
printf "What would you like to do?\n"
printf "1: Generate images using the browser-based interface\n"
printf "2: Explore InvokeAI nodes using the command-line interface\n"
printf "3: Run textual inversion training\n"
printf "4: Merge models (diffusers type only)\n"
printf "5: Download and install models\n"
printf "6: Change InvokeAI startup options\n"
printf "7: Re-run the configure script to fix a broken install\n"
printf "8: Open the developer console\n"
printf "9: Update InvokeAI\n"
printf "10: Command-line help\n"
printf "Q: Quit\n\n"
read -p "Please enter 1-10, Q: [1] " yn
choice=${yn:='1'}
do_choice $choice
clear
}
# Main IF statement for launching Invoke with either the TUI or CLI, and for checking if the user is in the developer console
if [ "$0" != "bash" ]; then
while true; do
if $tui; then
# .dialogrc must be located in the same directory as the invoke.sh script
export DIALOGRC="./.dialogrc"
do_dialog
else
do_line_input
fi
done
while true
do
echo "Do you want to generate images using the"
echo "1. command-line interface"
echo "2. browser-based UI"
echo "3. run textual inversion training"
echo "4. merge models (diffusers type only)"
echo "5. download and install models"
echo "6. change InvokeAI startup options"
echo "7. re-run the configure script to fix a broken install"
echo "8. open the developer console"
echo "9. update InvokeAI"
echo "10. command-line help"
echo "Q - Quit"
echo ""
read -p "Please enter 1-10, Q: [2] " yn
choice=${yn:='2'}
case $choice in
1)
echo "Starting the InvokeAI command-line..."
invokeai $@
;;
2)
echo "Starting the InvokeAI browser-based UI..."
invokeai --web $@
;;
3)
echo "Starting Textual Inversion:"
invokeai-ti --gui $@
;;
4)
echo "Merging Models:"
invokeai-merge --gui $@
;;
5)
invokeai-model-install --root ${INVOKEAI_ROOT}
;;
6)
invokeai-configure --root ${INVOKEAI_ROOT} --skip-sd-weights --skip-support-models
;;
7)
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
;;
8)
echo "Developer Console:"
file_name=$(basename "${BASH_SOURCE[0]}")
bash --init-file "$file_name"
;;
9)
echo "Update:"
invokeai-update
;;
10)
invokeai --help
;;
[qQ])
exit 0
;;
*)
echo "Invalid selection"
exit;;
esac
done
else # in developer console
python --version
printf "Press ^D to exit\n"
echo "Press ^D to exit"
export PS1="(InvokeAI) \u@\h \w> "
fi


@@ -1,25 +1,25 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from logging import Logger
import os
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.app.models.image import ImageField
from invokeai.app.services.outputs_sqlite import OutputsSqliteItemStorage
import invokeai.backend.util.logging as logger
from typing import types
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ...backend import Globals
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_file_storage import DiskImageFileStorage
from ..services.image_storage import DiskImageStorage
from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.metadata import PngMetadataService
from .events import FastAPIEventService
@@ -39,63 +39,55 @@ def check_internet() -> bool:
return False
logger = InvokeAILogger.getLogger()
class ApiDependencies:
"""Contains and initializes all dependencies for the API"""
invoker: Invoker = None
@staticmethod
def initialize(config, event_handler_id: int, logger: Logger = logger):
logger.info(f"Internet connectivity is {config.internet_available}")
def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
Globals.try_patchmatch = config.patchmatch
Globals.always_use_cpu = config.always_use_cpu
Globals.internet_available = config.internet_available and check_internet()
Globals.disable_xformers = not config.xformers
Globals.ckpt_convert = config.ckpt_convert
# TO DO: Use the config to select the logger rather than use the default
# invokeai logging module
logger.info(f"Internet connectivity is {Globals.internet_available}")
events = FastAPIEventService(event_handler_id)
output_folder = config.output_path
output_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../../../outputs")
)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
metadata = PngMetadataService()
images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata)
# TODO: build a file/path manager?
db_location = config.db_path
db_location.parent.mkdir(parents=True,exist_ok=True)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)
urls = LocalUrlService()
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents")
)
images = ImageService(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=urls,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
db_location = os.path.join(output_folder, "invokeai.db")
services = InvocationServices(
model_manager=get_model_manager(config, logger),
model_manager=get_model_manager(config,logger),
events=events,
logger=logger,
latents=latents,
images=images,
metadata=metadata,
outputs=OutputsSqliteItemStorage(filename=db_location),
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=graph_execution_manager,
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
),
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config, logger),
configuration=config,
logger=logger,
restoration=RestorationServices(config,logger),
)
create_system_graphs(services.graph_library)


@@ -0,0 +1,40 @@
from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageType
from invokeai.app.services.metadata import InvokeAIMetadata
class ImageResponseMetadata(BaseModel):
"""An image's metadata. Used only in HTTP responses."""
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
invokeai: Optional[InvokeAIMetadata] = Field(
description="The image's InvokeAI-specific metadata"
)
class ImageResponse(BaseModel):
"""The response type for images"""
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
image_url: str = Field(description="The url of the image")
thumbnail_url: str = Field(description="The url of the image's thumbnail")
metadata: ImageResponseMetadata = Field(description="The image's metadata")
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")
class SavedImage(BaseModel):
image_name: str = Field(description="The name of the saved image")
thumbnail_name: str = Field(description="The name of the saved thumbnail")
created: int = Field(description="The created timestamp of the saved image")


@@ -1,250 +1,148 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from typing import Optional
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from datetime import datetime, timezone
import json
import os
from typing import Any
import uuid
from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from fastapi.responses import FileResponse
from PIL import Image
from invokeai.app.models.image import (
ImageCategory,
ResourceOrigin,
)
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import (
ImageDTO,
ImageRecordChanges,
ImageUrlsDTO,
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
)
from invokeai.app.services.item_storage import PaginatedResults
from ...services.image_storage import ImageType
from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of image to get"),
image_name: str = Path(description="The name of the image to get"),
) -> FileResponse:
"""Gets an image"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=image_name
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image(
image_type: ImageType = Path(description="The type of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image and its thumbnail"""
ApiDependencies.invoker.services.images.delete(
image_type=image_type, image_name=image_name
)
@images_router.get(
"/{thumbnail_type}/thumbnails/{thumbnail_name}", operation_id="get_thumbnail"
)
async def get_thumbnail(
thumbnail_type: ImageType = Path(description="The type of thumbnail to get"),
thumbnail_name: str = Path(description="The name of the thumbnail to get"),
) -> FileResponse | Response:
"""Gets a thumbnail"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=thumbnail_type, image_name=thumbnail_name, is_thumbnail=True
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.post(
"/",
"/uploads/",
operation_id="upload_image",
responses={
201: {"description": "The image was uploaded successfully"},
201: {
"description": "The image was uploaded successfully",
"model": ImageResponse,
},
415: {"description": "Image upload failed"},
},
status_code=201,
response_model=ImageDTO,
)
async def upload_image(
file: UploadFile,
request: Request,
response: Response,
image_category: ImageCategory = Query(description="The category of the image"),
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
session_id: Optional[str] = Query(
default=None, description="The session ID associated with this upload, if any"
),
) -> ImageDTO:
"""Uploads an image"""
file: UploadFile, image_type: ImageType, request: Request, response: Response
) -> ImageResponse:
if not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await file.read()
try:
pil_image = Image.open(io.BytesIO(contents))
img = Image.open(io.BytesIO(contents))
except:
# Error opening the image
raise HTTPException(status_code=415, detail="Failed to read image")
try:
image_dto = ApiDependencies.invoker.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.EXTERNAL,
image_category=image_category,
session_id=session_id,
is_intermediate=is_intermediate,
)
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
response.status_code = 201
response.headers["Location"] = image_dto.image_url
saved_image = ApiDependencies.invoker.services.images.save(
image_type, filename, img
)
return image_dto
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to create image")
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
image_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name
)
@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
async def delete_image(
image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image"""
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name, True
)
try:
ApiDependencies.invoker.services.images.delete(image_origin, image_name)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass
res = ImageResponse(
image_type=image_type,
image_name=saved_image.image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,
metadata=ImageResponseMetadata(
created=saved_image.created,
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
response.status_code = 201
response.headers["Location"] = image_url
@images_router.patch(
"/{image_origin}/{image_name}",
operation_id="update_image",
response_model=ImageDTO,
)
async def update_image(
image_origin: ResourceOrigin = Path(description="The origin of image to update"),
image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body(
description="The changes to apply to the image"
),
) -> ImageDTO:
"""Updates an image"""
try:
return ApiDependencies.invoker.services.images.update(
image_origin, image_name, image_changes
)
except Exception as e:
raise HTTPException(status_code=400, detail="Failed to update image")
@images_router.get(
"/{image_origin}/{image_name}/metadata",
operation_id="get_image_metadata",
response_model=ImageDTO,
)
async def get_image_metadata(
image_origin: ResourceOrigin = Path(description="The origin of image to get"),
image_name: str = Path(description="The name of image to get"),
) -> ImageDTO:
"""Gets an image's metadata"""
try:
return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
"/{image_origin}/{image_name}",
operation_id="get_image_full",
response_class=Response,
responses={
200: {
"description": "Return the full-resolution image",
"content": {"image/png": {}},
},
404: {"description": "Image not found"},
},
)
async def get_image_full(
image_origin: ResourceOrigin = Path(
description="The type of full-resolution image file to get"
),
image_name: str = Path(description="The name of full-resolution image file to get"),
) -> FileResponse:
"""Gets a full-resolution image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
return FileResponse(
path,
media_type="image/png",
filename=image_name,
content_disposition_type="inline",
)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
"/{image_origin}/{image_name}/thumbnail",
operation_id="get_image_thumbnail",
response_class=Response,
responses={
200: {
"description": "Return the image thumbnail",
"content": {"image/webp": {}},
},
404: {"description": "Image not found"},
},
)
async def get_image_thumbnail(
image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
image_name: str = Path(description="The name of thumbnail image file to get"),
) -> FileResponse:
"""Gets a thumbnail image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(
image_origin, image_name, thumbnail=True
)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
return FileResponse(
path, media_type="image/webp", content_disposition_type="inline"
)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
"/{image_origin}/{image_name}/urls",
operation_id="get_image_urls",
response_model=ImageUrlsDTO,
)
async def get_image_urls(
image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL"""
try:
image_url = ApiDependencies.invoker.services.images.get_url(
image_origin, image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
image_origin, image_name, thumbnail=True
)
return ImageUrlsDTO(
image_origin=image_origin,
image_name=image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,
)
except Exception as e:
raise HTTPException(status_code=404)
return res
@images_router.get(
    "/",
    operation_id="list_images_with_metadata",
    response_model=OffsetPaginatedResults[ImageDTO],
)
async def list_images_with_metadata(
    image_origin: Optional[ResourceOrigin] = Query(
        default=None, description="The origin of images to list"
    ),
    categories: Optional[list[ImageCategory]] = Query(
        default=None, description="The categories of image to include"
    ),
    is_intermediate: Optional[bool] = Query(
        default=None, description="Whether to list intermediate images"
    ),
    offset: int = Query(default=0, description="The page offset"),
    limit: int = Query(default=10, description="The number of images per page"),
) -> OffsetPaginatedResults[ImageDTO]:
    """Gets a list of images"""
    image_dtos = ApiDependencies.invoker.services.images.get_many(
        offset,
        limit,
        image_origin,
        categories,
        is_intermediate,
    )
    return image_dtos
@images_router.get(
    "/",
    operation_id="list_images",
    responses={200: {"model": PaginatedResults[ImageResponse]}},
)
async def list_images(
    image_type: ImageType = Query(
        default=ImageType.RESULT, description="The type of images to get"
    ),
    page: int = Query(default=0, description="The page of images to get"),
    per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageResponse]:
    """Gets a list of images"""
    result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
    return result
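The two sides of this hunk paginate differently: the newer route takes an `offset`/`limit` pair and returns `OffsetPaginatedResults[ImageDTO]`, while the older one takes `page`/`per_page` and returns `PaginatedResults[ImageResponse]`. A minimal sketch of what such generic result models typically look like with pydantic v1; the real definitions live elsewhere in the codebase and may differ:

```python
from typing import Generic, TypeVar
from pydantic import BaseModel
from pydantic.generics import GenericModel

T = TypeVar("T", bound=BaseModel)

class OffsetPaginatedResults(GenericModel, Generic[T]):
    """Offset-based pagination: the client asks for a window of rows."""
    items: list[T]
    offset: int
    limit: int
    total: int

class PaginatedResults(GenericModel, Generic[T]):
    """Page-based pagination: the client asks for page N of a fixed size."""
    items: list[T]
    page: int
    pages: int
    per_page: int
    total: int
```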

View File

@@ -3,7 +3,7 @@ import asyncio
from inspect import signature
import uvicorn
import invokeai.backend.util.logging as logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
@@ -11,20 +11,11 @@ from fastapi.openapi.utils import get_openapi
from fastapi.staticfiles import StaticFiles
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pathlib import Path
from pydantic.schema import schema
#This should come early so that modules can log their initialization properly
from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
logger = InvokeAILogger.getLogger(config=app_config)
import invokeai.frontend.web as web_dir
from ..backend import Args
from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images
from .api.routers import images, sessions, models
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation
@@ -42,21 +33,30 @@ app.add_middleware(
middleware_id=event_handler_id,
)
# Add CORS
# TODO: use configuration for this
origins = []
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
socket_io = SocketIO(app)
config = {}
# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event():
app.add_middleware(
CORSMiddleware,
allow_origins=app_config.allow_origins,
allow_credentials=app_config.allow_credentials,
allow_methods=app_config.allow_methods,
allow_headers=app_config.allow_headers,
)
config = Args()
config.parse_args()
ApiDependencies.initialize(
config=app_config, event_handler_id=event_handler_id, logger=logger
config=config, event_handler_id=event_handler_id, logger=logger
)
@@ -74,9 +74,10 @@ async def shutdown_event():
app.include_router(sessions.session_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
@@ -123,7 +124,7 @@ def custom_openapi():
app.openapi = custom_openapi
# Override API doc favicons
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], 'static/dream_web')), name="static")
app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
@app.get("/docs", include_in_schema=False)
def overridden_swagger():
@@ -142,21 +143,19 @@ def overridden_redoc():
redoc_favicon_url="/static/favicon.ico",
)
# Must mount *after* the other routes else it borks em
app.mount("/",
StaticFiles(directory=Path(web_dir.__path__[0],"dist"),
html=True
), name="ui"
)
app.mount("/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui")
def invoke_api():
# Start our own event loop for eventing usage
# TODO: determine if there's a better way to do this
loop = asyncio.new_event_loop()
config = uvicorn.Config(app=app, host=app_config.host, port=app_config.port, loop=loop)
config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop)
# Use access_log to turn off logging
server = uvicorn.Server(config)
loop.run_until_complete(server.serve())
if __name__ == "__main__":
invoke_api()
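The change above moves CORS setup out of the startup event and drives it from `InvokeAIAppConfig` instead of a hard-coded empty `origins` list, and likewise replaces the hard-coded `0.0.0.0:9090` with `app_config.host`/`app_config.port`. A standalone sketch of the same pattern with plain FastAPI and a stand-in settings object (names here are illustrative, not the project's):

```python
from dataclasses import dataclass, field
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

@dataclass
class AppSettings:  # stand-in for a parsed config object
    allow_origins: list = field(default_factory=lambda: ["http://localhost:5173"])
    allow_credentials: bool = True
    allow_methods: list = field(default_factory=lambda: ["*"])
    allow_headers: list = field(default_factory=lambda: ["*"])
    host: str = "127.0.0.1"
    port: int = 9090

settings = AppSettings()
app = FastAPI()
# Middleware must be added before the app starts serving, not in a startup event.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.allow_origins,
    allow_credentials=settings.allow_credentials,
    allow_methods=settings.allow_methods,
    allow_headers=settings.allow_headers,
)

if __name__ == "__main__":
    uvicorn.run(app, host=settings.host, port=settings.port)
```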

View File

@@ -285,19 +285,3 @@ class DrawExecutionGraphCommand(BaseCommand):
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
plt.axis("off")
plt.show()
class SortedHelpFormatter(argparse.HelpFormatter):
def _iter_indented_subactions(self, action):
try:
get_subactions = action._get_subactions
except AttributeError:
pass
else:
self._indent()
if isinstance(action, argparse._SubParsersAction):
for subaction in sorted(get_subactions(), key=lambda x: x.dest):
yield subaction
else:
for subaction in get_subactions():
yield subaction
self._dedent()
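`SortedHelpFormatter` hooks argparse's private `_iter_indented_subactions` so that subcommands are listed alphabetically in `--help` output while other action groups keep their declared order. A quick sketch of how it is wired up (the reliance on a private API means this can break across Python versions):

```python
import argparse

# Uses the SortedHelpFormatter class defined just above.
parser = argparse.ArgumentParser(prog="invoke", formatter_class=SortedHelpFormatter)
sub = parser.add_subparsers(dest="command")
sub.add_parser("txt2img")   # declared out of order on purpose;
sub.add_parser("banner")    # --help lists them alphabetically:
sub.add_parser("exit")      # banner, exit, txt2img
parser.print_help()
```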

View File

@@ -11,10 +11,9 @@ from pathlib import Path
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
import invokeai.backend.util.logging as logger
from ...backend import ModelManager
from ...backend import ModelManager, Globals
from ..invocations.baseinvocation import BaseInvocation
from .commands import BaseCommand
from ..services.invocation_services import InvocationServices
# singleton object, class variable
completer = None
@@ -132,13 +131,13 @@ class Completer(object):
readline.redisplay()
self.linebuffer = None
def set_autocompleter(services: InvocationServices) -> Completer:
def set_autocompleter(model_manager: ModelManager) -> Completer:
global completer
if completer:
return completer
completer = Completer(services.model_manager)
completer = Completer(model_manager)
readline.set_completer(completer.complete)
# pyreadline3 does not have a set_auto_history() method
@@ -154,7 +153,7 @@ def set_autocompleter(services: InvocationServices) -> Completer:
readline.parse_and_bind("set skip-completed-text on")
readline.parse_and_bind("set show-all-if-ambiguous on")
histfile = Path(services.configuration.root_dir / ".invoke_history")
histfile = Path(Globals.root, ".invoke_history")
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
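The diff narrows `set_autocompleter` to take the whole `InvocationServices`, so the history file can come from `services.configuration.root_dir` instead of the old `Globals.root`. The underlying readline wiring is standard; a minimal self-contained sketch of the same pattern:

```python
import readline

COMMANDS = ["txt2img", "img2img", "inpaint", "help", "exit"]  # illustrative

def complete(text: str, state: int):
    """Return the state-th command starting with `text`, or None when done."""
    matches = [c for c in COMMANDS if c.startswith(text)]
    return matches[state] if state < len(matches) else None

readline.set_completer(complete)
readline.parse_and_bind("tab: complete")
# input() now tab-completes against COMMANDS.
```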

View File

@@ -4,33 +4,23 @@ import argparse
import os
import re
import shlex
import sys
import time
from typing import (
Union,
get_type_hints,
)
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
from pydantic.fields import Field
# This should come early so that the logger can pick up its configuration options
from .services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger().getLogger(config=config)
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
import invokeai.backend.util.logging as logger
from invokeai.app.services.metadata import PngMetadataService
from .services.default_graphs import create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, SortedHelpFormatter
from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
@@ -38,7 +28,7 @@ from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
from .services.default_graphs import default_text_to_image_graph_id
from .services.image_file_storage import DiskImageFileStorage
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
@@ -53,6 +43,7 @@ class CliCommand(BaseModel):
class InvalidArgs(Exception):
pass
def add_invocation_args(command_parser):
# Add linking capability
command_parser.add_argument(
@@ -73,7 +64,7 @@ def add_invocation_args(command_parser):
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
# Create invocation parser
parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter)
parser = argparse.ArgumentParser()
def exit(*args, **kwargs):
raise InvalidArgs
@@ -196,66 +187,45 @@ def invoke_all(context: CliContext):
raise SessionError()
def invoke_cli():
# get the optional list of invocations to execute on the command line
parser = config.get_parser()
parser.add_argument('commands',nargs='*')
invocation_commands = parser.parse_args().commands
# get the optional file to read commands from.
# Simplest is to use it for STDIN
if infile := config.from_file:
sys.stdin = open(infile,"r")
def invoke_cli():
config = Args()
config.parse_args()
model_manager = get_model_manager(config,logger=logger)
# This initializes the autocompleter and returns it.
# Currently nothing is done with the returned Completer
# object, but the object can be used to change autocompletion
# behavior on the fly, if desired.
set_autocompleter(model_manager)
events = EventServiceBase()
output_folder = config.output_path
metadata = PngMetadataService()
output_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../../outputs")
)
# TODO: build a file/path manager?
if config.use_memory_db:
db_location = ":memory:"
else:
db_location = config.db_path
db_location.parent.mkdir(parents=True,exist_ok=True)
logger.info(f'InvokeAI database location is "{db_location}"')
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)
urls = LocalUrlService()
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
images = ImageService(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=urls,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
db_location = os.path.join(output_folder, "invokeai.db")
services = InvocationServices(
model_manager=model_manager,
events=events,
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
images=images,
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
metadata=metadata,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=graph_execution_manager,
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
),
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger=logger),
logger=logger,
configuration=config,
)
system_graphs = create_system_graphs(services.graph_library)
@@ -271,18 +241,10 @@ def invoke_cli():
# print(services.session_manager.list())
context = CliContext(invoker, session, parser)
set_autocompleter(services)
command_line_args_exist = len(invocation_commands) > 0
done = False
while not done:
while True:
try:
if command_line_args_exist:
cmd_input = invocation_commands.pop(0)
done = len(invocation_commands) == 0
else:
cmd_input = input("invoke> ")
cmd_input = input("invoke> ")
except (KeyboardInterrupt, EOFError):
# Ctrl-c exits
break
@@ -406,9 +368,6 @@ def invoke_cli():
invoker.services.logger.warning('Invalid command, use "help" to list commands')
continue
except ValidationError:
invoker.services.logger.warning('Invalid command arguments, run "<command> --help" for summary')
except SessionError:
# Start a new session
invoker.services.logger.warning("Session error: creating a new session")

View File

@@ -1,15 +1,12 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from __future__ import annotations
from abc import ABC, abstractmethod
from inspect import signature
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
from pydantic import BaseModel, Field
if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
from ..services.invocation_services import InvocationServices
class InvocationContext:
@@ -78,7 +75,6 @@ class BaseInvocation(ABC, BaseModel):
#fmt: off
id: str = Field(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
#fmt: on
@@ -96,7 +92,6 @@ class UIConfig(TypedDict, total=False):
"image",
"latents",
"model",
"control",
],
]
tags: List[str]

View File

@@ -1,9 +1,9 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
from typing import Literal, Optional
import numpy as np
from pydantic import Field, validator
from pydantic import Field
from invokeai.app.util.misc import SEED_MAX, get_random_seed
@@ -22,17 +22,9 @@ class IntCollectionOutput(BaseInvocationOutput):
# Outputs
collection: list[int] = Field(default=[], description="The int collection")
class FloatCollectionOutput(BaseInvocationOutput):
"""A collection of floats"""
type: Literal["float_collection"] = "float_collection"
# Outputs
collection: list[float] = Field(default=[], description="The float collection")
class RangeInvocation(BaseInvocation):
"""Creates a range of numbers from start to stop with step"""
"""Creates a range"""
type: Literal["range"] = "range"
@@ -41,34 +33,12 @@ class RangeInvocation(BaseInvocation):
stop: int = Field(default=10, description="The stop of the range")
step: int = Field(default=1, description="The step of the range")
@validator("stop")
def stop_gt_start(cls, v, values):
if "start" in values and v <= values["start"]:
raise ValueError("stop must be greater than start")
return v
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(
collection=list(range(self.start, self.stop, self.step))
)
class RangeOfSizeInvocation(BaseInvocation):
"""Creates a range from start to start + size with step"""
type: Literal["range_of_size"] = "range_of_size"
# Inputs
start: int = Field(default=0, description="The start of the range")
size: int = Field(default=1, description="The number of values")
step: int = Field(default=1, description="The step of the range")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(
collection=list(range(self.start, self.start + self.size, self.step))
)
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""

View File

@@ -3,7 +3,6 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.util.choose_model import choose_model
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from ...backend.prompting.conditioning import try_parse_legacy_blend
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
@@ -14,9 +13,11 @@ from compel.prompt_parser import (
Blend,
CrossAttentionControlSubstitute,
FlattenedPrompt,
Fragment, Conjunction,
Fragment,
)
from invokeai.backend.globals import Globals
class ConditioningField(BaseModel):
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
@@ -94,29 +95,32 @@ class CompelInvocation(BaseInvocation):
text_encoder=text_encoder,
textual_inversion_manager=pipeline.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False,
truncate_long_prompts=True, # TODO:
)
legacy_blend = try_parse_legacy_blend(prompt_str, skip_normalize=False)
if legacy_blend is not None:
conjunction = legacy_blend
else:
conjunction = Compel.parse_prompt_string(prompt_str)
# TODO: support legacy blend?
if context.services.configuration.log_tokenization:
log_tokenization_for_conjunction(conjunction, tokenizer)
conjunction = Compel.parse_prompt_string(prompt_str)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
if getattr(Globals, "log_tokenization", False):
log_tokenization_for_prompt_object(prompt, tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
# TODO: long prompt support
#if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
cross_attention_control_args=options.get("cross_attention_control", None),
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (c, ec))
context.services.latents.set(conditioning_name, (c, ec))
return CompelOutput(
conditioning=ConditioningField(
@@ -126,22 +130,14 @@ class CompelInvocation(BaseInvocation):
def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
) -> int:
if type(prompt) is Blend:
blend: Blend = prompt
return max(
[
get_max_token_count(tokenizer, p, truncate_if_too_long)
for p in blend.prompts
]
)
elif type(prompt) is Conjunction:
conjunction: Conjunction = prompt
return sum(
[
get_max_token_count(tokenizer, p, truncate_if_too_long)
for p in conjunction.prompts
get_max_token_count(tokenizer, c, truncate_if_too_long)
for c in blend.prompts
]
)
else:
@@ -176,22 +172,6 @@ def get_tokens_for_prompt_object(
return tokens
def log_tokenization_for_conjunction(
c: Conjunction, tokenizer, display_label_prefix=None
):
display_label_prefix = display_label_prefix or ""
for i, p in enumerate(c.prompts):
if len(c.prompts)>1:
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
else:
this_display_label_prefix = display_label_prefix
log_tokenization_for_prompt_object(
p,
tokenizer,
display_label_prefix=this_display_label_prefix
)
def log_tokenization_for_prompt_object(
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):
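The `Conjunction` changes above generalize `get_max_token_count` and tokenization logging from a single prompt to a weighted set of prompts: a `Blend` contributes the max of its parts, while a `Conjunction` contributes the sum. With stand-in types (the real `Blend`, `Conjunction`, and `FlattenedPrompt` come from the compel package), the recursion looks like:

```python
from dataclasses import dataclass
from typing import Union

@dataclass
class Flat:          # stand-in for compel's FlattenedPrompt
    tokens: int

@dataclass
class Blend:         # stand-in: alternatives blended together
    prompts: list

@dataclass
class Conjunction:   # stand-in: parts applied jointly
    prompts: list

def max_token_count(p: Union[Flat, Blend, Conjunction]) -> int:
    if isinstance(p, Blend):
        return max(max_token_count(q) for q in p.prompts)
    if isinstance(p, Conjunction):
        return sum(max_token_count(q) for q in p.prompts)
    return p.tokens

print(max_token_count(Conjunction([Flat(12), Blend([Flat(30), Flat(7)])])))  # 42
```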

View File

@@ -1,459 +0,0 @@
# InvokeAI nodes for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import float
import numpy as np
from typing import Literal, Optional, Union, List
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field, validator
from ..models.image import ImageField, ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from controlnet_aux import (
CannyDetector,
HEDdetector,
LineartDetector,
LineartAnimeDetector,
MidasDetector,
MLSDdetector,
NormalBaeDetector,
OpenposeDetector,
PidiNetDetector,
ContentShuffleDetector,
ZoeDetector,
MediapipeFaceDetector,
)
from .image import ImageOutput, PILInvocationConfig
CONTROLNET_DEFAULT_MODELS = [
###########################################
# lllyasviel sd v1.5, ControlNet v1.0 models
##############################################
"lllyasviel/sd-controlnet-canny",
"lllyasviel/sd-controlnet-depth",
"lllyasviel/sd-controlnet-hed",
"lllyasviel/sd-controlnet-seg",
"lllyasviel/sd-controlnet-openpose",
"lllyasviel/sd-controlnet-scribble",
"lllyasviel/sd-controlnet-normal",
"lllyasviel/sd-controlnet-mlsd",
#############################################
# lllyasviel sd v1.5, ControlNet v1.1 models
#############################################
"lllyasviel/control_v11p_sd15_canny",
"lllyasviel/control_v11p_sd15_openpose",
"lllyasviel/control_v11p_sd15_seg",
# "lllyasviel/control_v11p_sd15_depth", # broken
"lllyasviel/control_v11f1p_sd15_depth",
"lllyasviel/control_v11p_sd15_normalbae",
"lllyasviel/control_v11p_sd15_scribble",
"lllyasviel/control_v11p_sd15_mlsd",
"lllyasviel/control_v11p_sd15_softedge",
"lllyasviel/control_v11p_sd15s2_lineart_anime",
"lllyasviel/control_v11p_sd15_lineart",
"lllyasviel/control_v11p_sd15_inpaint",
# "lllyasviel/control_v11u_sd15_tile",
# problem (temporary?) with huggingface "lllyasviel/control_v11u_sd15_tile",
# so for now replaced with "lllyasviel/control_v11f1e_sd15_tile" below
"lllyasviel/control_v11e_sd15_shuffle",
"lllyasviel/control_v11e_sd15_ip2p",
"lllyasviel/control_v11f1e_sd15_tile",
#################################################
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?)
##################################################
"thibaud/controlnet-sd21-openpose-diffusers",
"thibaud/controlnet-sd21-canny-diffusers",
"thibaud/controlnet-sd21-depth-diffusers",
"thibaud/controlnet-sd21-scribble-diffusers",
"thibaud/controlnet-sd21-hed-diffusers",
"thibaud/controlnet-sd21-zoedepth-diffusers",
"thibaud/controlnet-sd21-color-diffusers",
"thibaud/controlnet-sd21-openposev2-diffusers",
"thibaud/controlnet-sd21-lineart-diffusers",
"thibaud/controlnet-sd21-normalbae-diffusers",
"thibaud/controlnet-sd21-ade20k-diffusers",
##############################################
# ControlNetMediaPipeface, ControlNet v1.1
##############################################
# ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5
# diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
# hacked t2l to split to model & subfolder if format is "model,subfolder"
"CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5
"CrucibleAI/ControlNetMediaPipeFace", # SD 2.1?
]
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image")
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
@validator("control_weight")
def abs_le_one(cls, v):
"""validate that all abs(values) are <=1"""
if isinstance(v, list):
for i in v:
if abs(i) > 1:
raise ValueError('all abs(control_weight) must be <= 1')
else:
if abs(v) > 1:
raise ValueError('abs(control_weight) must be <= 1')
return v
class Config:
schema_extra = {
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
"ui": {
"type_hints": {
"control_weight": "float",
# "control_weight": "number",
}
}
}
class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info"""
# fmt: off
type: Literal["control_output"] = "control_output"
control: ControlField = Field(default=None, description="The control info")
# fmt: on
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
# fmt: off
type: Literal["controlnet"] = "controlnet"
# Inputs
image: ImageField = Field(default=None, description="The control image")
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
begin_step_percent: float = Field(default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number",
"control_weight": "float",
}
},
}
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
control=ControlField(
image=self.image,
control_model=self.control_model,
control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
),
)
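`begin_step_percent` and `end_step_percent` express, as fractions of the total denoising steps, the window in which the ControlNet is applied, and `control_weight` may be a scalar or a per-step list. A sketch of how a sampler might consult these fields each step; the exact boundary handling here is an assumption, not taken from the backend:

```python
def control_weight_for_step(step: int, total_steps: int,
                            weight, begin: float, end: float) -> float:
    """Return the effective ControlNet weight at `step`, or 0.0 when inactive."""
    frac = step / max(total_steps - 1, 1)  # 0.0 at the first step, 1.0 at the last
    if not (begin <= frac <= end):
        return 0.0
    return weight[step] if isinstance(weight, list) else weight

# Active only for the middle half of a 20-step run:
print([control_weight_for_step(s, 20, 1.0, 0.25, 0.75) for s in (0, 10, 19)])
# [0.0, 1.0, 0.0]
```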
# TODO: move image processors to separate file (image_analysis.py)
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
"""Base class for invocations that preprocess images for ControlNet"""
# fmt: off
type: Literal["image_processor"] = "image_processor"
# Inputs
image: ImageField = Field(default=None, description="The image to process")
# fmt: on
def run_processor(self, image):
# superclass just passes through image without processing
return image
def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
# FIXME: what happened to image metadata?
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
# )
# currently can't see processed image in node UI without a showImage node,
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
image_dto = context.services.images.create(
image=processed_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.CONTROL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate
)
"""Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
)
return ImageOutput(
image=processed_image_field,
            # width=processed_image.width,
            width=image_dto.width,
            # height=processed_image.height,
            height=image_dto.height,
            # mode=processed_image.mode,
)
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Canny edge detection for ControlNet"""
# fmt: off
type: Literal["canny_image_processor"] = "canny_image_processor"
# Input
low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)")
high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
# fmt: on
def run_processor(self, image):
canny_processor = CannyDetector()
processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
return processed_image
class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies HED edge detection to image"""
# fmt: off
type: Literal["hed_image_processor"] = "hed_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# safe not supported in controlnet_aux v0.0.3
# safe: bool = Field(default=False, description="whether to use safe mode")
scribble: bool = Field(default=False, description="Whether to use scribble mode")
# fmt: on
def run_processor(self, image):
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = hed_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
# safe not supported in controlnet_aux v0.0.3
# safe=self.safe,
scribble=self.scribble,
)
return processed_image
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies line art processing to image"""
# fmt: off
type: Literal["lineart_image_processor"] = "lineart_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
coarse: bool = Field(default=False, description="Whether to use coarse mode")
# fmt: on
def run_processor(self, image):
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
processed_image = lineart_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
coarse=self.coarse)
return processed_image
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies line art anime processing to image"""
# fmt: off
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
def run_processor(self, image):
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
)
return processed_image
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies Openpose processing to image"""
# fmt: off
type: Literal["openpose_image_processor"] = "openpose_image_processor"
# Inputs
hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode")
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
def run_processor(self, image):
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = openpose_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
hand_and_face=self.hand_and_face,
)
return processed_image
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies Midas depth processing to image"""
# fmt: off
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
# Inputs
a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`")
# depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
# fmt: on
def run_processor(self, image):
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
processed_image = midas_processor(image,
a=np.pi * self.a_mult,
bg_th=self.bg_th,
# dept_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal=self.depth_and_normal,
)
return processed_image
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies NormalBae processing to image"""
# fmt: off
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
def run_processor(self, image):
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = normalbae_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution)
return processed_image
class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies MLSD processing to image"""
# fmt: off
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`")
thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
# fmt: on
def run_processor(self, image):
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = mlsd_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
thr_v=self.thr_v,
thr_d=self.thr_d)
return processed_image
class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies PIDI processing to image"""
# fmt: off
type: Literal["pidi_image_processor"] = "pidi_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
safe: bool = Field(default=False, description="Whether to use safe mode")
scribble: bool = Field(default=False, description="Whether to use scribble mode")
# fmt: on
def run_processor(self, image):
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
processed_image = pidi_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
safe=self.safe,
scribble=self.scribble)
return processed_image
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies content shuffle processing to image"""
# fmt: off
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
h: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
w: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
f: Union[int, None] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
# fmt: on
def run_processor(self, image):
content_shuffle_processor = ContentShuffleDetector()
processed_image = content_shuffle_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
h=self.h,
w=self.w,
f=self.f
)
return processed_image
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies Zoe depth processing to image"""
# fmt: off
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
# fmt: on
def run_processor(self, image):
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image)
return processed_image
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies mediapipe face processing to image"""
# fmt: off
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
# Inputs
max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect")
min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
# fmt: on
def run_processor(self, image):
mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
return processed_image
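Each processor node above wraps a detector from `controlnet_aux` in the common `run_processor` hook. Outside of a graph, the detectors can be driven directly; a sketch (pretrained detectors download their weights on first use, and exact call signatures may vary between controlnet_aux versions):

```python
from PIL import Image
from controlnet_aux import CannyDetector, HEDdetector

image = Image.open("input.png")  # hypothetical input file

# Canny needs no pretrained weights; thresholds match the node defaults above.
canny = CannyDetector()
edges = canny(image, low_threshold=100, high_threshold=200)

# HED pulls its weights from the HF hub, exactly as the node does.
hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
soft_edges = hed(image, detect_resolution=512, image_resolution=512)

edges.save("canny.png")
soft_edges.save("hed.png")
```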

View File

@@ -7,9 +7,9 @@ import numpy
from PIL import Image, ImageOps
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from invokeai.app.models.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
class CvInvocationConfig(BaseModel):
@@ -26,27 +26,24 @@ class CvInvocationConfig(BaseModel):
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
"""Simple inpaint using opencv."""
# fmt: off
#fmt: off
type: Literal["cv_inpaint"] = "cv_inpaint"
# Inputs
image: ImageField = Field(default=None, description="The image to inpaint")
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
mask = context.services.images.get_pil_image(
self.mask.image_origin, self.mask.image_name
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
mask = context.services.images.get(self.mask.image_type, self.mask.image_name)
# Convert to cv image/mask
# TODO: consider making these utility functions
cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
cv_mask = numpy.array(ImageOps.invert(mask.convert("L")))
cv_mask = numpy.array(ImageOps.invert(mask))
# Inpaint
cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)
@@ -55,20 +52,18 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
# TODO: consider making a utility function
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
image_dto = context.services.images.create(
image=image_inpainted,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image_inpainted, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image_inpainted,
)
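The rewritten `CvInpaintInvocation` above only changes how images are fetched and persisted; the OpenCV core is unchanged. Reduced to the library calls, the inpaint itself (TELEA algorithm, 3-px radius, with the mask inverted because `cv.inpaint` fills where the mask is nonzero) is:

```python
import cv2 as cv
import numpy as np
from PIL import Image, ImageOps

image = Image.open("photo.png").convert("RGB")   # hypothetical inputs
mask = Image.open("mask.png").convert("L")       # black = area to fill, per the node

cv_image = cv.cvtColor(np.array(image), cv.COLOR_RGB2BGR)
# cv.inpaint fills where the mask is nonzero, hence the inversion.
cv_mask = np.array(ImageOps.invert(mask))

cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)
result = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
result.save("inpainted.png")
```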

View File

@@ -4,29 +4,23 @@ from functools import partial
from typing import Literal, Optional, Union, get_args
import numpy as np
from diffusers import ControlNetModel
from torch import Tensor
import torch
from pydantic import BaseModel, Field
from invokeai.app.models.image import ColorField, ImageField, ResourceOrigin
from invokeai.app.models.image import ColorField, ImageField, ImageType
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ..util.step_callback import stable_diffusion_step_callback
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = (
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
)
DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile'
class SDImageInvocation(BaseModel):
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
@@ -58,11 +52,8 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
control_model: Optional[str] = Field(default=None, description="The control model to use")
control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
# fmt: on
# TODO: pass this an emitter method or something? or a session for dispatching?
@@ -83,57 +74,47 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
# loading controlnet image (currently requires pre-processed image)
control_image = (
None if self.control_image is None
else context.services.images.get_pil_image(
self.control_image.image_origin, self.control_image.image_name
)
)
# loading controlnet model
if (self.control_model is None or self.control_model==''):
control_model = None
else:
# FIXME: change this to dropdown menu?
# FIXME: generalize so don't have to hardcode torch_dtype and device
control_model = ControlNetModel.from_pretrained(self.control_model,
torch_dtype=torch.float16).to("cuda")
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
txt2img = Txt2Img(model, control_model=control_model)
outputs = txt2img.generate(
outputs = Txt2Img(model).generate(
prompt=self.prompt,
step_callback=partial(self.dispatch_progress, context, source_node_id),
control_image=control_image,
**self.dict(
exclude={"prompt", "control_image" }
exclude={"prompt"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generate_output = next(outputs)
image_dto = context.services.images.create(
image=generate_output.image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate,
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(
image_type, image_name, generate_output.image, metadata
)
context.services.outputs.set(image_name, context.graph_execution_state_id)
s = context.services.outputs.get(image_name)
print(s)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=generate_output.image,
)
@@ -169,8 +150,8 @@ class ImageToImageInvocation(TextToImageInvocation):
image = (
None
if self.image is None
else context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
else context.services.images.get(
self.image.image_type, self.image.image_name
)
)
@@ -199,24 +180,26 @@ class ImageToImageInvocation(TextToImageInvocation):
# each time it is called. We only need the first one.
generator_output = next(outputs)
image_dto = context.services.images.create(
image=generator_output.image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate,
result_image = generator_output.image
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, result_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image,
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
@@ -226,38 +209,16 @@ class InpaintInvocation(ImageToImageInvocation):
# Inputs
mask: Union[ImageField, None] = Field(description="The mask")
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
seam_blur: int = Field(
default=16, ge=0, description="The seam inpaint blur radius (px)"
)
seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
seam_strength: float = Field(
default=0.75, gt=0, le=1, description="The seam inpaint strength"
)
seam_steps: int = Field(
default=30, ge=1, description="The number of steps to use for seam inpaint"
)
tile_size: int = Field(
default=32, ge=1, description="The tile infill method size (px)"
)
infill_method: INFILL_METHODS = Field(
default=DEFAULT_INFILL_METHOD,
description="The method used to infill empty regions (px)",
)
inpaint_width: Optional[int] = Field(
default=None,
multiple_of=8,
gt=0,
description="The width of the inpaint region (px)",
)
inpaint_height: Optional[int] = Field(
default=None,
multiple_of=8,
gt=0,
description="The height of the inpaint region (px)",
)
inpaint_fill: Optional[ColorField] = Field(
default=ColorField(r=127, g=127, b=127, a=255),
description="The solid infill method color",
)
seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)")
inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)")
inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)")
inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color")
inpaint_replace: float = Field(
default=0.0,
ge=0.0,
@@ -282,14 +243,14 @@ class InpaintInvocation(ImageToImageInvocation):
image = (
None
if self.image is None
else context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
else context.services.images.get(
self.image.image_type, self.image.image_name
)
)
mask = (
None
if self.mask is None
else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
else context.services.images.get(self.mask.image_type, self.mask.image_name)
)
# Handle invalid model parameter
@@ -315,20 +276,23 @@ class InpaintInvocation(ImageToImageInvocation):
# each time it is called. We only need the first one.
generator_output = next(outputs)
image_dto = context.services.images.create(
image=generator_output.image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate,
result_image = generator_output.image
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, result_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image,
)

View File

@@ -1,13 +1,13 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from typing import Literal, Optional, Union
from typing import Literal, Optional
import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ..models.image import ImageField, ImageType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@@ -31,7 +31,7 @@ class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
@@ -41,14 +41,27 @@ class ImageOutput(BaseInvocationOutput):
schema_extra = {"required": ["type", "image", "width", "height"]}
def build_image_output(
image_type: ImageType, image_name: str, image: Image.Image
) -> ImageOutput:
"""Builds an ImageOutput and its ImageField"""
image_field = ImageField(
image_name=image_name,
image_type=image_type,
)
return ImageOutput(
image=image_field,
width=image.width,
height=image.height,
)
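`build_image_output` is a small convenience so every node returns a consistently shaped `ImageOutput` (an `ImageField` plus pixel dimensions) without repeating itself. A usage sketch with an in-memory image, using the names defined in this file (the image name is hypothetical):

```python
from PIL import Image

img = Image.new("RGB", (512, 384))
out = build_image_output(
    image_type=ImageType.RESULT,   # enum from ..models.image
    image_name="example.png",      # hypothetical name
    image=img,
)
assert (out.width, out.height) == (512, 384)
assert out.image.image_name == "example.png"
```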
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on
class Config:
@@ -67,20 +80,16 @@ class LoadImageInvocation(BaseInvocation):
type: Literal["load_image"] = "load_image"
# Inputs
image: Union[ImageField, None] = Field(
default=None, description="The image to load"
)
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
image = context.services.images.get(self.image_type, self.image_name)
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
image_origin=self.image.image_origin,
),
width=image.width,
height=image.height,
return build_image_output(
image_type=self.image_type,
image_name=self.image_name,
image=image,
)
@@ -90,37 +99,32 @@ class ShowImageInvocation(BaseInvocation):
type: Literal["show_image"] = "show_image"
# Inputs
image: Union[ImageField, None] = Field(
default=None, description="The image to show"
)
image: ImageField = Field(default=None, description="The image to show")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
if image:
image.show()
# TODO: how to handle failure?
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
image_origin=self.image.image_origin,
),
width=image.width,
height=image.height,
return build_image_output(
image_type=self.image.image_type,
image_name=self.image.image_name,
image=image,
)
class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
"""Crops an image to a specified box. The box can be outside of the image."""
# fmt: off
type: Literal["img_crop"] = "img_crop"
type: Literal["crop"] = "crop"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to crop")
image: ImageField = Field(default=None, description="The image to crop")
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
@@ -128,8 +132,8 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image_crop = Image.new(
@@ -137,53 +141,49 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
)
image_crop.paste(image, (-self.x, -self.y))
        image_dto = context.services.images.create(
            image=image_crop,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )
        return ImageOutput(
            image=ImageField(
                image_name=image_dto.image_name,
                image_origin=image_dto.image_origin,
            ),
            width=image_dto.width,
            height=image_dto.height,
        )
        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )
        context.services.images.save(image_type, image_name, image_crop, metadata)
        return build_image_output(
            image_type=image_type,
            image_name=image_name,
            image=image_crop,
        )
class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
"""Pastes an image into another image."""
# fmt: off
type: Literal["img_paste"] = "img_paste"
type: Literal["paste"] = "paste"
# Inputs
base_image: Union[ImageField, None] = Field(default=None, description="The base image")
image: Union[ImageField, None] = Field(default=None, description="The image to paste")
base_image: ImageField = Field(default=None, description="The base image")
image: ImageField = Field(default=None, description="The image to paste")
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image(
self.base_image.image_origin, self.base_image.image_name
base_image = context.services.images.get(
self.base_image.image_type, self.base_image.image_name
)
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
mask = (
None
if self.mask is None
else ImageOps.invert(
context.services.images.get_pil_image(
self.mask.image_origin, self.mask.image_name
)
context.services.images.get(self.mask.image_type, self.mask.image_name)
)
)
# TODO: probably shouldn't invert mask here... should user be required to do it?
@@ -199,22 +199,20 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
new_image.paste(base_image, (abs(min_x), abs(min_y)))
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
image_dto = context.services.images.create(
image=new_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, new_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=new_image,
)
@@ -225,169 +223,47 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["tomask"] = "tomask"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to create the mask from")
image: ImageField = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask")
# fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image_mask = image.split()[-1]
if self.invert:
image_mask = ImageOps.invert(image_mask)
image_dto = context.services.images.create(
image=image_mask,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.MASK,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
return MaskOutput(
mask=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
# fmt: off
type: Literal["img_mul"] = "img_mul"
# Inputs
image1: Union[ImageField, None] = Field(default=None, description="The first image to multiply")
image2: Union[ImageField, None] = Field(default=None, description="The second image to multiply")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image(
self.image1.image_origin, self.image1.image_name
)
image2 = context.services.images.get_pil_image(
self.image2.image_origin, self.image2.image_name
)
multiply_image = ImageChops.multiply(image1, image2)
image_dto = context.services.images.create(
image=multiply_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
)
context.services.images.save(image_type, image_name, image_mask, metadata)
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
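The mask here is simply the image's alpha band: split()[-1] returns the last channel of an RGBA image as a single-band "L" image, and ImageOps.invert flips it when invert is set. A runnable sketch with a synthetic input:

from PIL import Image, ImageOps

rgba = Image.new("RGBA", (64, 64), (255, 0, 0, 128))  # half-transparent red
alpha = rgba.split()[-1]            # the A band, mode "L"
inverted = ImageOps.invert(alpha)   # the invert=True branch
print(alpha.getpixel((0, 0)), inverted.getpixel((0, 0)))  # 128 127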
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
"""Gets a channel from an image."""
# fmt: off
type: Literal["img_chan"] = "img_chan"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to get the channel from")
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
channel_image = image.getchannel(self.channel)
image_dto = context.services.images.create(
image=channel_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
)
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
"""Converts an image to a different mode."""
# fmt: off
type: Literal["img_conv"] = "img_conv"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to convert")
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
converted_image = image.convert(self.mode)
image_dto = context.services.images.create(
image=converted_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
)
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
class BlurInvocation(BaseInvocation, PILInvocationConfig):
"""Blurs an image"""
# fmt: off
type: Literal["img_blur"] = "img_blur"
type: Literal["blur"] = "blur"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to blur")
image: ImageField = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
blur = (
@@ -397,149 +273,36 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
)
blur_image = image.filter(blur)
image_dto = context.services.images.create(
image=blur_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, blur_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=blur_image
)
PIL_RESAMPLING_MODES = Literal[
"nearest",
"box",
"bilinear",
"hamming",
"bicubic",
"lanczos",
]
PIL_RESAMPLING_MAP = {
"nearest": Image.Resampling.NEAREST,
"box": Image.Resampling.BOX,
"bilinear": Image.Resampling.BILINEAR,
"hamming": Image.Resampling.HAMMING,
"bicubic": Image.Resampling.BICUBIC,
"lanczos": Image.Resampling.LANCZOS,
}
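The Literal/dict pair keeps the node schema and the lookup in sync: the string accepted by the field doubles as the key that resolves to a PIL enum member, exactly as the resize and scale invocations below do with PIL_RESAMPLING_MAP[self.resample_mode]. For example, using the map defined above:

from PIL import Image

img = Image.new("RGB", (100, 100))
mode = PIL_RESAMPLING_MAP["lanczos"]             # Image.Resampling.LANCZOS
print(img.resize((64, 64), resample=mode).size)  # (64, 64)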
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
"""Resizes an image to specific dimensions"""
# fmt: off
type: Literal["img_resize"] = "img_resize"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
resize_image = image.resize(
(self.width, self.height),
resample=resample_mode,
)
image_dto = context.services.images.create(
image=resize_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
"""Scales an image by a factor"""
# fmt: off
type: Literal["img_scale"] = "img_scale"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
width = int(image.width * self.scale_factor)
height = int(image.height * self.scale_factor)
resize_image = image.resize(
(width, height),
resample=resample_mode,
)
image_dto = context.services.images.create(
image=resize_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
class LerpInvocation(BaseInvocation, PILInvocationConfig):
"""Linear interpolation of all pixels of an image"""
# fmt: off
type: Literal["img_lerp"] = "img_lerp"
type: Literal["lerp"] = "lerp"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
@@ -547,40 +310,36 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
lerp_image = Image.fromarray(numpy.uint8(image_arr))
image_dto = context.services.images.create(
image=lerp_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, lerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=lerp_image
)
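The arithmetic elided by the hunk marker presumably rescales the normalized array into [min, max]; the standard lerp consistent with the /255 normalization shown above would be:

import numpy as np

arr = np.asarray([0, 128, 255], dtype=np.float32) / 255  # as in the hunk above
mn, mx = 64, 192
print(np.uint8(arr * (mx - mn) + mn))  # [ 64 128 192]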
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
"""Inverse linear interpolation of all pixels of an image"""
# fmt: off
type: Literal["img_ilerp"] = "img_ilerp"
type: Literal["ilerp"] = "ilerp"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image_arr = numpy.asarray(image, dtype=numpy.float32)
@@ -593,20 +352,16 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
image_dto = context.services.images.create(
image=ilerp_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, ilerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=ilerp_image
)

View File

@@ -1,17 +1,17 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Union, get_args
from typing import Literal, Optional, Union, get_args
import numpy as np
import math
from PIL import Image, ImageOps
from pydantic import Field
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.invocations.image import ImageOutput, build_image_output
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
from ..models.image import ColorField, ImageField, ImageType
from .baseinvocation import (
BaseInvocation,
InvocationContext,
@@ -125,40 +125,36 @@ class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color"""
type: Literal["infill_rgba"] = "infill_rgba"
image: Union[ImageField, None] = Field(
default=None, description="The image to infill"
)
color: ColorField = Field(
image: Optional[ImageField] = Field(default=None, description="The image to infill")
color: Optional[ColorField] = Field(
default=ColorField(r=127, g=127, b=127, a=255),
description="The color to use to infill",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
infilled = Image.alpha_composite(solid_bg, image)
infilled.paste(image, (0, 0), image.split()[-1])
image_dto = context.services.images.create(
image=infilled,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=infilled,

)
@@ -167,9 +163,7 @@ class InfillTileInvocation(BaseInvocation):
type: Literal["infill_tile"] = "infill_tile"
image: Union[ImageField, None] = Field(
default=None, description="The image to infill"
)
image: Optional[ImageField] = Field(default=None, description="The image to infill")
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
seed: int = Field(
ge=0,
@@ -179,8 +173,8 @@ class InfillTileInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
infilled = tile_fill_missing(
@@ -188,22 +182,20 @@ class InfillTileInvocation(BaseInvocation):
)
infilled.paste(image, (0, 0), image.split()[-1])
image_dto = context.services.images.create(
image=infilled,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=infilled,
)
@@ -212,13 +204,11 @@ class InfillPatchMatchInvocation(BaseInvocation):
type: Literal["infill_patchmatch"] = "infill_patchmatch"
image: Union[ImageField, None] = Field(
default=None, description="The image to infill"
)
image: Optional[ImageField] = Field(default=None, description="The image to infill")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
if PatchMatch.patchmatch_available():
@@ -226,20 +216,18 @@ class InfillPatchMatchInvocation(BaseInvocation):
else:
raise ValueError("PatchMatch is not available on this system")
image_dto = context.services.images.create(
image=infilled,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=infilled,
)

View File

@@ -1,42 +1,33 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import random
from typing import Literal, Optional, Union
import einops
from typing import Literal, Optional, Union, List
from compel import Compel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field
import torch
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.image import ImageCategory
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from .controlnet_image_processors import ControlField
from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_file_storage import ResourceOrigin
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from .image import ImageField, ImageOutput, build_image_output
from .compel import ConditioningField
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
from diffusers import DiffusionPipeline, ControlNetModel
from diffusers import DiffusionPipeline
class LatentsField(BaseModel):
@@ -92,13 +83,13 @@ SAMPLER_NAME_VALUES = Literal[
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
scheduler_config = model.scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
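get_scheduler stashes the pristine scheduler config under a "_backup" key before layering in the per-scheduler extras, so repeated scheduler swaps always restart from the original config instead of accumulating overrides. The mechanism in isolation (the extra-config dict is a hypothetical example):

original = {"beta_start": 0.00085, "beta_end": 0.012}
extra = {"use_karras_sigmas": True}  # hypothetical scheduler_extra_config

merged = {**original, **extra, "_backup": original}
# a later swap unwraps the backup first, as the `if "_backup"` branch does
base = merged["_backup"] if "_backup" in merged else merged
assert base == original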
@@ -148,17 +139,12 @@ class NoiseInvocation(BaseInvocation):
},
}
@validator("seed", pre=True)
def modulo_seed(cls, v):
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
return v % SEED_MAX
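The pre-validator wraps out-of-range seeds into the valid range rather than rejecting them; negative seeds wrap too, thanks to Python's modulo semantics. Assuming SEED_MAX is 2**32 - 1 (the real constant lives in invokeai.app.util.misc):

SEED_MAX = 2**32 - 1  # assumption; imported in the real code

def modulo_seed(v: int) -> int:
    return v % SEED_MAX

print(modulo_seed(SEED_MAX + 7))  # 7
print(modulo_seed(-1))            # SEED_MAX - 1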
def invoke(self, context: InvocationContext) -> NoiseOutput:
device = torch.device(choose_torch_device())
noise = get_noise(self.width, self.height, device, self.seed)
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.save(name, noise)
context.services.latents.set(name, noise)
return build_noise_output(latents_name=name, latents=noise)
@@ -174,36 +160,20 @@ class TextToLatentsInvocation(BaseInvocation):
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)")
control: Union[ControlField, List[ControlField]] = Field(default=None, description="The control to use")
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
@validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError('cfg_scale must be greater than 1')
else:
if v < 1:
raise ValueError('cfg_scale must be greater than 1')
return v
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"tags": ["latents", "image"],
"type_hints": {
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number"
"model": "model"
}
},
}
@@ -229,17 +199,17 @@ class TextToLatentsInvocation(BaseInvocation):
scheduler_name=self.scheduler
)
# if isinstance(model, DiffusionPipeline):
# for component in [model.unet, model.vae]:
# configure_model_padding(component,
# self.seamless,
# self.seamless_axes
# )
# else:
# configure_model_padding(model,
# self.seamless,
# self.seamless_axes
# )
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
self.seamless,
self.seamless_axes
)
else:
configure_model_padding(model,
self.seamless,
self.seamless_axes
)
return model
@@ -248,20 +218,11 @@ class TextToLatentsInvocation(BaseInvocation):
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
compel = Compel(
tokenizer=model.tokenizer,
text_encoder=model.text_encoder,
textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False,
)
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
text_embeddings=c,
guidance_scale=self.cfg_scale,
extra=extra_conditioning_info,
uc,
c,
self.cfg_scale,
extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=0.0,#threshold,
warmup=0.2,#warmup,
@@ -271,81 +232,6 @@ class TextToLatentsInvocation(BaseInvocation):
).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
return conditioning_data
def prep_control_data(self,
context: InvocationContext,
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
control_input: List[ControlField],
latents_shape: List[int],
do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]:
# assuming fixed dimensional scaling of 8:1 for image:latents
control_height_resize = latents_shape[2] * 8
control_width_resize = latents_shape[3] * 8
if control_input is None:
# print("control input is None")
control_list = None
elif isinstance(control_input, list) and len(control_input) == 0:
# print("control input is empty list")
control_list = None
elif isinstance(control_input, ControlField):
# print("control input is ControlField")
control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
# print("control input is list[ControlField]")
control_list = control_input
else:
# print("input control is unrecognized:", type(self.control))
control_list = None
if (control_list is None):
control_data = None
# from above handling, any control that is not None should now be of type list[ControlField]
else:
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
control_data = []
control_models = []
for control_info in control_list:
# handle control models
if ("," in control_info.control_model):
control_model_split = control_info.control_model.split(",")
control_name = control_model_split[0]
control_subfolder = control_model_split[1]
print("Using HF model subfolders")
print(" control_name: ", control_name)
print(" control_subfolder: ", control_subfolder)
control_model = ControlNetModel.from_pretrained(control_name,
subfolder=control_subfolder,
torch_dtype=model.unet.dtype).to(model.device)
else:
control_model = ControlNetModel.from_pretrained(control_info.control_model,
torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_origin,
control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = model.prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
)
control_item = ControlNetData(model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
@@ -360,26 +246,21 @@ class TextToLatentsInvocation(BaseInvocation):
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(context, model)
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,)
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.save(name, result_latents)
context.services.latents.set(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
@@ -390,7 +271,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
strength: float = Field(default=0.5, description="The strength of the latents to use")
# Schema customisation
class Config(InvocationConfig):
@@ -398,9 +279,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
"control": "control",
"cfg_scale": "number",
"model": "model"
}
},
}
@@ -419,12 +298,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(context, model)
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
@@ -439,7 +312,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback
)
@@ -447,7 +319,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.save(name, result_latents)
context.services.latents.set(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
@@ -484,30 +356,20 @@ class LatentsToImageInvocation(BaseInvocation):
np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0]
# what happened to metadata?
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
torch.cuda.empty_cache()
# new (post Image service refactor) way of using services to save image
# and generate a unique image_name
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
context.services.images.save(image_type, image_name, image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=image
)
@@ -542,8 +404,7 @@ class ResizeLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
context.services.latents.set(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents)
@@ -573,8 +434,7 @@ class ScaleLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
context.services.latents.set(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents)
@@ -598,11 +458,8 @@ class ImageToLatentsInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
# image = context.services.images.get(
# self.image.image_type, self.image.image_name
# )
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
# TODO: this only really needs the vae
@@ -621,6 +478,5 @@ class ImageToLatentsInvocation(BaseInvocation):
)
name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, latents)
context.services.latents.save(name, latents)
context.services.latents.set(name, latents)
return build_latents_output(latents_name=name, latents=latents)

View File

@@ -5,12 +5,7 @@ from typing import Literal
from pydantic import BaseModel, Field
import numpy as np
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
class MathInvocationConfig(BaseModel):
@@ -27,30 +22,19 @@ class MathInvocationConfig(BaseModel):
class IntOutput(BaseInvocationOutput):
"""An integer output"""
# fmt: off
#fmt: off
type: Literal["int_output"] = "int_output"
a: int = Field(default=None, description="The output integer")
# fmt: on
class FloatOutput(BaseInvocationOutput):
"""A float output"""
# fmt: off
type: Literal["float_output"] = "float_output"
param: float = Field(default=None, description="The output float")
# fmt: on
#fmt: on
class AddInvocation(BaseInvocation, MathInvocationConfig):
"""Adds two numbers"""
# fmt: off
#fmt: off
type: Literal["add"] = "add"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a + self.b)
@@ -58,12 +42,11 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
"""Subtracts two numbers"""
# fmt: off
#fmt: off
type: Literal["sub"] = "sub"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a - self.b)
@@ -71,12 +54,11 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
"""Multiplies two numbers"""
# fmt: off
#fmt: off
type: Literal["mul"] = "mul"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a * self.b)
@@ -84,12 +66,11 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
class DivideInvocation(BaseInvocation, MathInvocationConfig):
"""Divides two numbers"""
# fmt: off
#fmt: off
type: Literal["div"] = "div"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=int(self.a / self.b))
@@ -97,13 +78,8 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):
class RandomIntInvocation(BaseInvocation):
"""Outputs a single random integer."""
# fmt: off
#fmt: off
type: Literal["rand_int"] = "rand_int"
low: int = Field(default=0, description="The inclusive low value")
high: int = Field(
default=np.iinfo(np.int32).max, description="The exclusive high value"
)
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=np.random.randint(self.low, self.high))
return IntOutput(a=np.random.randint(0, np.iinfo(np.int32).max))

View File

@@ -1,237 +0,0 @@
import io
from typing import Literal, Optional, Any
# from PIL.Image import Image
import PIL.Image
from matplotlib.ticker import MaxNLocator
from matplotlib.figure import Figure
from pydantic import BaseModel, Field
import numpy as np
import matplotlib.pyplot as plt
from easing_functions import (
LinearInOut,
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
SineEaseInOut, SineEaseIn, SineEaseOut,
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
BackEaseIn, BackEaseInOut, BackEaseOut,
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from ...backend.util.logging import InvokeAILogger
from .collections import FloatCollectionOutput
class FloatLinearRangeInvocation(BaseInvocation):
"""Creates a range"""
type: Literal["float_range"] = "float_range"
# Inputs
start: float = Field(default=5, description="The first value of the range")
stop: float = Field(default=10, description="The last value of the range")
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
param_list = list(np.linspace(self.start, self.stop, self.steps))
return FloatCollectionOutput(
collection=param_list
)
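np.linspace is endpoint-inclusive, which is why the docstring stresses that steps counts both start and stop:

import numpy as np

print(list(np.linspace(5, 10, 6)))  # [5.0, 6.0, 7.0, 8.0, 9.0, 10.0]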
EASING_FUNCTIONS_MAP = {
"Linear": LinearInOut,
"QuadIn": QuadEaseIn,
"QuadOut": QuadEaseOut,
"QuadInOut": QuadEaseInOut,
"CubicIn": CubicEaseIn,
"CubicOut": CubicEaseOut,
"CubicInOut": CubicEaseInOut,
"QuarticIn": QuarticEaseIn,
"QuarticOut": QuarticEaseOut,
"QuarticInOut": QuarticEaseInOut,
"QuinticIn": QuinticEaseIn,
"QuinticOut": QuinticEaseOut,
"QuinticInOut": QuinticEaseInOut,
"SineIn": SineEaseIn,
"SineOut": SineEaseOut,
"SineInOut": SineEaseInOut,
"CircularIn": CircularEaseIn,
"CircularOut": CircularEaseOut,
"CircularInOut": CircularEaseInOut,
"ExponentialIn": ExponentialEaseIn,
"ExponentialOut": ExponentialEaseOut,
"ExponentialInOut": ExponentialEaseInOut,
"ElasticIn": ElasticEaseIn,
"ElasticOut": ElasticEaseOut,
"ElasticInOut": ElasticEaseInOut,
"BackIn": BackEaseIn,
"BackOut": BackEaseOut,
"BackInOut": BackEaseInOut,
"BounceIn": BounceEaseIn,
"BounceOut": BounceEaseOut,
"BounceInOut": BounceEaseInOut,
}
EASING_FUNCTION_KEYS: Any = Literal[
tuple(list(EASING_FUNCTIONS_MAP.keys()))
]
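Building the Literal from the map's keys means adding an easing class to EASING_FUNCTIONS_MAP automatically extends the set of accepted values, at the cost of a type static checkers can't verify (hence the Any annotation). The pattern in miniature:

from typing import Any, Literal

TABLE = {"fast": 1, "slow": 2}                  # hypothetical stand-in map
TABLE_KEYS: Any = Literal[tuple(TABLE.keys())]  # same trick as EASING_FUNCTION_KEYS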
# actually I think for now could just use CollectionOutput (which is list[Any])
class StepParamEasingInvocation(BaseInvocation):
"""Experimental per-step parameter easing for denoising steps"""
type: Literal["step_param_easing"] = "step_param_easing"
# Inputs
# fmt: off
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
num_steps: int = Field(default=20, description="number of denoising steps")
start_value: float = Field(default=0.0, description="easing starting value")
end_value: float = Field(default=1.0, description="easing ending value")
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
# if None, then start_value is used prior to easing start
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
# if None, then end value is used prior to easing end
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
mirror: bool = Field(default=False, description="include mirror of easing function")
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
show_easing_plot: bool = Field(default=False, description="show easing plot")
# fmt: on
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
log_diagnostics = False
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
# start_step = int(np.floor(self.num_steps * self.start_step_percent))
start_step = int(np.round(self.num_steps * self.start_step_percent))
# convert from end_step_percent to nearest step >= (steps * end_step_percent)
# end_step = int(np.ceil((self.num_steps - 1) * self.end_step_percent))
end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))
# end_step = int(np.ceil(self.num_steps * self.end_step_percent))
num_easing_steps = end_step - start_step + 1
# num_presteps = max(start_step - 1, 0)
num_presteps = start_step
num_poststeps = self.num_steps - (num_presteps + num_easing_steps)
prelist = list(num_presteps * [self.pre_start_value])
postlist = list(num_poststeps * [self.post_end_value])
if log_diagnostics:
logger = InvokeAILogger.getLogger(name="StepParamEasing")
logger.debug("start_step: " + str(start_step))
logger.debug("end_step: " + str(end_step))
logger.debug("num_easing_steps: " + str(num_easing_steps))
logger.debug("num_presteps: " + str(num_presteps))
logger.debug("num_poststeps: " + str(num_poststeps))
logger.debug("prelist size: " + str(len(prelist)))
logger.debug("postlist size: " + str(len(postlist)))
logger.debug("prelist: " + str(prelist))
logger.debug("postlist: " + str(postlist))
easing_class = EASING_FUNCTIONS_MAP[self.easing]
if log_diagnostics:
logger.debug("easing class: " + str(easing_class))
easing_list = list()
if self.mirror: # "expected" mirroring
# if number of steps is even, squeeze duration down to (number_of_steps)/2
# and create reverse copy of list to append
# if number of steps is odd, squeeze duration down to ceil(number_of_steps/2)
# and create reverse copy of list[1:end-1]
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration))
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
easing_function = easing_class(start=self.start_value,
end=self.end_value,
duration=base_easing_duration - 1)
base_easing_vals = list()
for step_index in range(base_easing_duration):
easing_val = easing_function.ease(step_index)
base_easing_vals.append(easing_val)
if log_diagnostics:
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
if even_num_steps:
mirror_easing_vals = list(reversed(base_easing_vals))
else:
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
if log_diagnostics:
logger.debug("base easing vals: " + str(base_easing_vals))
logger.debug("mirror easing vals: " + str(mirror_easing_vals))
easing_list = base_easing_vals + mirror_easing_vals
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
# elif self.alt_mirror: # function mirroring (unintuitive behavior (at least to me))
# # half_ease_duration = round(num_easing_steps - 1 / 2)
# half_ease_duration = round((num_easing_steps - 1) / 2)
# easing_function = easing_class(start=self.start_value,
# end=self.end_value,
# duration=half_ease_duration,
# )
#
# mirror_function = easing_class(start=self.end_value,
# end=self.start_value,
# duration=half_ease_duration,
# )
# for step_index in range(num_easing_steps):
# if step_index <= half_ease_duration:
# step_val = easing_function.ease(step_index)
# else:
# step_val = mirror_function.ease(step_index - half_ease_duration)
# easing_list.append(step_val)
# if log_diagnostics: logger.debug(step_index, step_val)
#
else: # no mirroring (default)
easing_function = easing_class(start=self.start_value,
end=self.end_value,
duration=num_easing_steps - 1)
for step_index in range(num_easing_steps):
step_val = easing_function.ease(step_index)
easing_list.append(step_val)
if log_diagnostics:
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
if log_diagnostics:
logger.debug("prelist size: " + str(len(prelist)))
logger.debug("easing_list size: " + str(len(easing_list)))
logger.debug("postlist size: " + str(len(postlist)))
param_list = prelist + easing_list + postlist
if self.show_easing_plot:
plt.figure()
plt.xlabel("Step")
plt.ylabel("Param Value")
plt.title("Per-Step Values Based On Easing: " + self.easing)
plt.bar(range(len(param_list)), param_list)
# plt.plot(param_list)
ax = plt.gca()
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
buf = io.BytesIO()
plt.savefig(buf, format='png')
buf.seek(0)
im = PIL.Image.open(buf)
im.show()
buf.close()
# output array of size num_steps; each entry param_list[i] is the param value for step i
return FloatCollectionOutput(
collection=param_list
)
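The mirroring branch halves the easing duration with ceil and, for odd step counts, drops the peak from the reversed copy so it is emitted exactly once. The bookkeeping in isolation, with a linear ramp standing in for the easing_functions classes:

import numpy as np

num_easing_steps = 5
base_duration = int(np.ceil(num_easing_steps / 2.0))            # 3
vals = [i / (base_duration - 1) for i in range(base_duration)]  # [0.0, 0.5, 1.0]
even = num_easing_steps % 2 == 0
mirror = list(reversed(vals)) if even else list(reversed(vals[:-1]))
print(vals + mirror)  # [0.0, 0.5, 1.0, 0.5, 0.0] -- five values, peak once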

View File

@@ -3,7 +3,7 @@
from typing import Literal
from pydantic import Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from .math import IntOutput, FloatOutput
from .math import IntOutput
# Pass-through parameter nodes - used by subgraphs
@@ -16,13 +16,3 @@ class ParamIntInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a)
class ParamFloatInvocation(BaseInvocation):
"""A float parameter"""
#fmt: off
type: Literal["param_float"] = "param_float"
param: float = Field(default=0.0, description="The float value")
#fmt: on
def invoke(self, context: InvocationContext) -> FloatOutput:
return FloatOutput(param=self.param)

View File

@@ -2,23 +2,21 @@ from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from invokeai.app.models.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
# fmt: off
#fmt: off
type: Literal["restore_face"] = "restore_face"
# Inputs
image: Union[ImageField, None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
# fmt: on
#fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
@@ -28,8 +26,8 @@ class RestoreFaceInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
@@ -41,20 +39,18 @@ class RestoreFaceInvocation(BaseInvocation):
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)

View File

@@ -4,22 +4,22 @@ from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from invokeai.app.models.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
# fmt: off
#fmt: off
type: Literal["upscale"] = "upscale"
# Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None)
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2, 4] = Field(default=2, description="The upscale level")
# fmt: on
#fmt: on
# Schema customisation
class Config(InvocationConfig):
@@ -30,8 +30,8 @@ class UpscaleInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
@@ -43,20 +43,18 @@ class UpscaleInvocation(BaseInvocation):
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)

View File

@@ -2,77 +2,31 @@ from enum import Enum
from typing import Optional, Tuple
from pydantic import BaseModel, Field
from invokeai.app.util.metaenum import MetaEnum
class ImageType(str, Enum):
RESULT = "results"
INTERMEDIATE = "intermediates"
UPLOAD = "uploads"
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
"""The origin of a resource (eg image).
- INTERNAL: The resource was created by the application.
- EXTERNAL: The resource was not created by the application.
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
"""
INTERNAL = "internal"
"""The resource was created by the application."""
EXTERNAL = "external"
"""The resource was not created by the application.
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
"""
class InvalidOriginException(ValueError):
"""Raised when a provided value is not a valid ResourceOrigin.
Subclasses `ValueError`.
"""
def __init__(self, message="Invalid resource origin."):
super().__init__(message)
class ImageCategory(str, Enum, metaclass=MetaEnum):
"""The category of an image.
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
- MASK: The image is a mask image.
- CONTROL: The image is a ControlNet control image.
- USER: The image is a user-provided image.
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
"""
GENERAL = "general"
"""GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
MASK = "mask"
"""MASK: The image is a mask image."""
CONTROL = "control"
"""CONTROL: The image is a ControlNet control image."""
USER = "user"
"""USER: The image is a user-provide image."""
OTHER = "other"
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
class InvalidImageCategoryException(ValueError):
"""Raised when a provided value is not a valid ImageCategory.
Subclasses `ValueError`.
"""
def __init__(self, message="Invalid image category."):
super().__init__(message)
def is_image_type(obj):
try:
ImageType(obj)
except ValueError:
return False
return True
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_origin: ResourceOrigin = Field(
default=ResourceOrigin.INTERNAL, description="The type of the image"
image_type: ImageType = Field(
default=ImageType.RESULT, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {"required": ["image_origin", "image_name"]}
schema_extra = {"required": ["image_type", "image_name"]}
class ColorField(BaseModel):
@@ -83,11 +37,3 @@ class ColorField(BaseModel):
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")

View File

@@ -1,93 +0,0 @@
from typing import Optional, Union, List
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
class ImageMetadata(BaseModel):
"""
Core generation metadata for an image/tensor generated in InvokeAI.
Also includes any metadata from the image's PNG tEXt chunks.
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
of a given node.
Full metadata may be accessed by querying for the session in the `graph_executions` table.
"""
class Config:
extra = Extra.allow
"""
This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService
won't add any fields that are not already defined, but a different metadata service
implementation might.
"""
type: Optional[StrictStr] = Field(
default=None,
description="The type of the ancestor node of the image output node.",
)
"""The type of the ancestor node of the image output node."""
positive_conditioning: Optional[StrictStr] = Field(
default=None, description="The positive conditioning."
)
"""The positive conditioning"""
negative_conditioning: Optional[StrictStr] = Field(
default=None, description="The negative conditioning."
)
"""The negative conditioning"""
width: Optional[StrictInt] = Field(
default=None, description="Width of the image/latents in pixels."
)
"""Width of the image/latents in pixels"""
height: Optional[StrictInt] = Field(
default=None, description="Height of the image/latents in pixels."
)
"""Height of the image/latents in pixels"""
seed: Optional[StrictInt] = Field(
default=None, description="The seed used for noise generation."
)
"""The seed used for noise generation"""
# cfg_scale: Optional[StrictFloat] = Field(
# cfg_scale: Union[float, list[float]] = Field(
cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field(
default=None, description="The classifier-free guidance scale."
)
"""The classifier-free guidance scale"""
steps: Optional[StrictInt] = Field(
default=None, description="The number of steps used for inference."
)
"""The number of steps used for inference"""
scheduler: Optional[StrictStr] = Field(
default=None, description="The scheduler used for inference."
)
"""The scheduler used for inference"""
model: Optional[StrictStr] = Field(
default=None, description="The model used for inference."
)
"""The model used for inference"""
strength: Optional[StrictFloat] = Field(
default=None,
description="The strength used for image-to-image/latents-to-latents.",
)
"""The strength used for image-to-image/latents-to-latents."""
latents: Optional[StrictStr] = Field(
default=None, description="The ID of the initial latents."
)
"""The ID of the initial latents"""
vae: Optional[StrictStr] = Field(
default=None, description="The VAE used for decoding."
)
"""The VAE used for decoding"""
unet: Optional[StrictStr] = Field(
default=None, description="The UNet used dor inference."
)
"""The UNet used dor inference"""
clip: Optional[StrictStr] = Field(
default=None, description="The CLIP Encoder used for conditioning."
)
"""The CLIP Encoder used for conditioning"""
extra: Optional[StrictStr] = Field(
default=None,
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
)
"""Uploaded image metadata, extracted from the PNG tEXt chunk."""

View File

@@ -1,581 +0,0 @@
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
'''InvokeAI configuration system.
Arguments and fields are taken from the pydantic definition of the
model. Defaults can be set by creating a yaml configuration file that
has a top-level key of "InvokeAI" and subheadings for each of the
categories returned by `invokeai --help`. The file looks like this:
[file: invokeai.yaml]
InvokeAI:
Paths:
root: /home/lstein/invokeai-main
conf_path: configs/models.yaml
legacy_conf_dir: configs/stable-diffusion
outdir: outputs
embedding_dir: embeddings
lora_dir: loras
autoconvert_dir: null
gfpgan_model_dir: models/gfpgan/GFPGANv1.4.pth
Models:
model: stable-diffusion-1.5
embeddings: true
Memory/Performance:
xformers_enabled: false
sequential_guidance: false
precision: float16
max_loaded_models: 4
always_use_cpu: false
free_gpu_mem: false
Features:
nsfw_checker: true
restore: true
esrgan: true
patchmatch: true
internet_available: true
log_tokenization: false
Web Server:
host: 127.0.0.1
port: 8081
allow_origins: []
allow_credentials: true
allow_methods:
- '*'
allow_headers:
- '*'
The default name of the configuration file is `invokeai.yaml`, located
in INVOKEAI_ROOT. You can supersede this by providing an
OmegaConf dictionary object at initialization time:
omegaconf = OmegaConf.load('/tmp/init.yaml')
conf = InvokeAIAppConfig()
conf.parse_args(conf=omegaconf)
InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv`
at initialization time. You may pass a list of strings in the optional
`argv` argument to use instead of the system argv:
conf.parse_args(argv=['--xformers_enabled'])
It is also possible to set a value at initialization time. However, if
you call parse_args() it may be overwritten.
conf = InvokeAIAppConfig(xformers_enabled=True)
conf.parse_args(argv=['--no-xformers'])
conf.xformers_enabled
# False
To avoid this, use `get_config()` to retrieve the application-wide
configuration object. This will retain any properties set at object
creation time:
conf = InvokeAIAppConfig.get_config(xformers_enabled=True)
conf.parse_args(argv=['--no-xformers'])
conf.xformers_enabled
# True
Any setting can be overwritten by setting an environment variable of
form: "INVOKEAI_<setting>", as in:
export INVOKEAI_port=8080
Order of precedence (from highest):
1) initialization options
2) command line options
3) environment variable options
4) config file options
5) pydantic defaults
Typical usage at the top level file:
from invokeai.app.services.config import InvokeAIAppConfig
# get global configuration and print its nsfw_checker value
conf = InvokeAIAppConfig.get_config()
conf.parse_args()
print(conf.nsfw_checker)
Typical usage in a backend module:
from invokeai.app.services.config import InvokeAIAppConfig
# get global configuration and print its nsfw_checker value
conf = InvokeAIAppConfig.get_config()
print(conf.nsfw_checker)
Computed properties:
The InvokeAIAppConfig object has a series of properties that
resolve paths relative to the runtime root directory. They each return
a Path object:
root_path - path to InvokeAI root
output_path - path to default outputs directory
model_conf_path - path to models.yaml
conf - alias for the above
embedding_path - path to the embeddings directory
lora_path - path to the LoRA directory
In most cases, you will want to create a single InvokeAIAppConfig
object for the entire application. The InvokeAIAppConfig.get_config() function
does this:
config = InvokeAIAppConfig.get_config()
config.parse_args() # read values from the command line/config file
print(config.root)
# Subclassing
If you wish to create a similar class, please subclass the
`InvokeAISettings` class and define a Literal field named "type",
which is set to the desired top-level name. For example, to create an
"InvokeBatch" configuration, define it like this:
class InvokeBatch(InvokeAISettings):
type: Literal["InvokeBatch"] = "InvokeBatch"
node_count : int = Field(default=1, description="Number of nodes to run on", category='Resources')
cpu_count : int = Field(default=8, description="Number of CPUs to run on per node", category='Resources')
This will now read and write from the "InvokeBatch" section of the
config file, look for environment variables named INVOKEBATCH_*, and
accept the command-line arguments `--node_count` and `--cpu_count`. The
two configs are kept in separate sections of the config file:
# invokeai.yaml
InvokeBatch:
Resources:
node_count: 1
cpu_count: 8
InvokeAI:
Paths:
root: /home/lstein/invokeai-main
conf_path: configs/models.yaml
legacy_conf_dir: configs/stable-diffusion
outdir: outputs
...
'''
from __future__ import annotations
import argparse
import pydoc
import os
import sys
from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig
from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Literal, Type, Union, get_origin, get_type_hints, get_args
INIT_FILE = Path('invokeai.yaml')
DB_FILE = Path('invokeai.db')
LEGACY_INIT_FILE = Path('invokeai.init')
class InvokeAISettings(BaseSettings):
'''
Runtime configuration settings in which default values are
read from an omegaconf .yaml file.
'''
initconf : ClassVar[DictConfig] = None
argparse_groups : ClassVar[Dict] = {}
def parse_args(self, argv: list=sys.argv[1:]):
parser = self.get_parser()
opt = parser.parse_args(argv)
for name in self.__fields__:
if name not in self._excluded():
setattr(self, name, getattr(opt,name))
def to_yaml(self)->str:
"""
Return a YAML string representing our settings. This can be used
as the contents of `invokeai.yaml` to restore settings later.
"""
cls = self.__class__
type = get_args(get_type_hints(cls)['type'])[0]
field_dict = dict({type:dict()})
for name,field in self.__fields__.items():
if name in cls._excluded():
continue
category = field.field_info.extra.get("category") or "Uncategorized"
value = getattr(self,name)
if category not in field_dict[type]:
field_dict[type][category] = dict()
# keep paths as strings to make it easier to read
field_dict[type][category][name] = str(value) if isinstance(value,Path) else value
conf = OmegaConf.create(field_dict)
return OmegaConf.to_yaml(conf)
@classmethod
def add_parser_arguments(cls, parser):
if 'type' in get_type_hints(cls):
settings_stanza = get_args(get_type_hints(cls)['type'])[0]
else:
settings_stanza = "Uncategorized"
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper()
initconf = cls.initconf.get(settings_stanza) \
if cls.initconf and settings_stanza in cls.initconf \
else OmegaConf.create()
# create an upcase version of the environment in
# order to achieve case-insensitive environment
# variables (the way Windows does)
upcase_environ = dict()
for key,value in os.environ.items():
upcase_environ[key.upper()] = value
fields = cls.__fields__
cls.argparse_groups = {}
for name, field in fields.items():
if name not in cls._excluded():
current_default = field.default
category = field.field_info.extra.get("category","Uncategorized")
env_name = env_prefix + '_' + name
if category in initconf and name in initconf.get(category):
field.default = initconf.get(category).get(name)
if env_name.upper() in upcase_environ:
field.default = upcase_environ[env_name.upper()]
cls.add_field_argument(parser, name, field)
field.default = current_default
@classmethod
def cmd_name(self, command_field: str='type')->str:
hints = get_type_hints(self)
if command_field in hints:
return get_args(hints[command_field])[0]
else:
return 'Uncategorized'
@classmethod
def get_parser(cls)->ArgumentParser:
parser = PagingArgumentParser(
prog=cls.cmd_name(),
description=cls.__doc__,
)
cls.add_parser_arguments(parser)
return parser
@classmethod
def add_subparser(cls, parser: argparse.ArgumentParser):
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
@classmethod
def _excluded(self)->List[str]:
return ['type','initconf']
class Config:
env_file_encoding = 'utf-8'
arbitrary_types_allowed = True
case_sensitive = True
@classmethod
def add_field_argument(cls, command_parser, name: str, field, default_override = None):
field_type = get_type_hints(cls).get(name)
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
if category := field.field_info.extra.get("category"):
if category not in cls.argparse_groups:
cls.argparse_groups[category] = command_parser.add_argument_group(category)
argparse_group = cls.argparse_groups[category]
else:
argparse_group = command_parser
if get_origin(field_type) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
argparse_group.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=default,
choices=allowed_values,
help=field.field_info.description,
)
elif get_origin(field_type) == list:
argparse_group.add_argument(
f"--{name}",
dest=name,
nargs='*',
type=field.type_,
default=default,
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
help=field.field_info.description,
)
else:
argparse_group.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=default,
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
help=field.field_info.description,
)
def _find_root()->Path:
if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
elif (
os.environ.get("VIRTUAL_ENV")
and (Path(os.environ.get("VIRTUAL_ENV"), "..", INIT_FILE).exists()
or
Path(os.environ.get("VIRTUAL_ENV"), "..", LEGACY_INIT_FILE).exists()
)
):
root = Path(os.environ.get("VIRTUAL_ENV"), "..").resolve()
else:
root = Path("~/invokeai").expanduser().resolve()
return root
class InvokeAIAppConfig(InvokeAISettings):
'''
Generate images using Stable Diffusion. Use "invokeai" to launch
the command-line client (recommended for experts only), or
"invokeai-web" to launch the web server. Global options
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
setting environment variables INVOKEAI_<setting>.
'''
singleton_config: ClassVar[InvokeAIAppConfig] = None
singleton_init: ClassVar[Dict] = None
#fmt: off
type: Literal["InvokeAI"] = "InvokeAI"
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server')
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Web Server')
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server')
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server')
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features')
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
max_loaded_models : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI textual inversion embeddings directory', category='Paths')
gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.', category='Paths')
controlnet_dir : Path = Field(default="controlnets", description='Path to directory of ControlNet models.', category='Paths')
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
# note - would be better to read the log_format values from logging.py, but this creates circular dependency issues
log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="debug", description="Emit logging messages at this level or higher", category="Logging")
#fmt: on
def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
'''
Update settings with contents of init file, environment, and
command-line settings.
:param conf: alternate OmegaConf dictionary object
:param argv: alternate sys.argv list
:param clobber: overwrite any initialization parameters passed during initialization
'''
# Set the runtime root directory. We parse command-line switches here
# in order to pick up the --root_dir option.
super().parse_args(argv)
if conf is None:
try:
conf = OmegaConf.load(self.root_dir / INIT_FILE)
except Exception:
pass
InvokeAISettings.initconf = conf
# parse args again in order to pick up settings in configuration file
super().parse_args(argv)
if self.singleton_init and not clobber:
hints = get_type_hints(self.__class__)
for k in self.singleton_init:
setattr(self,k,parse_obj_as(hints[k],self.singleton_init[k]))
@classmethod
def get_config(cls,**kwargs)->InvokeAIAppConfig:
'''
This returns a singleton InvokeAIAppConfig configuration object.
'''
if cls.singleton_config is None \
or type(cls.singleton_config)!=cls \
or (kwargs and cls.singleton_init != kwargs):
cls.singleton_config = cls(**kwargs)
cls.singleton_init = kwargs
return cls.singleton_config
@property
def root_path(self)->Path:
'''
Path to the runtime root directory
'''
if self.root:
return Path(self.root).expanduser()
else:
return self.find_root()
@property
def root_dir(self)->Path:
'''
Alias for above.
'''
return self.root_path
def _resolve(self,partial_path:Path)->Path:
return (self.root_path / partial_path).resolve()
@property
def init_file_path(self)->Path:
'''
Path to invokeai.yaml
'''
return self._resolve(INIT_FILE)
@property
def output_path(self)->Path:
'''
Path to defaults outputs directory.
'''
return self._resolve(self.outdir)
@property
def db_path(self)->Path:
'''
Path to the invokeai.db file.
'''
return self._resolve(self.db_dir) / DB_FILE
@property
def model_conf_path(self)->Path:
'''
Path to models configuration file.
'''
return self._resolve(self.conf_path)
@property
def legacy_conf_path(self)->Path:
'''
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
'''
return self._resolve(self.legacy_conf_dir)
@property
def cache_dir(self)->Path:
'''
Path to the global cache directory for HuggingFace hub-managed models
'''
return self.models_dir / "hub"
@property
def models_dir(self)->Path:
'''
Path to the models directory
'''
return self._resolve("models")
@property
def embedding_path(self)->Path:
'''
Path to the textual inversion embeddings directory.
'''
return self._resolve(self.embedding_dir) if self.embedding_dir else None
@property
def lora_path(self)->Path:
'''
Path to the LoRA models directory.
'''
return self._resolve(self.lora_dir) if self.lora_dir else None
@property
def controlnet_path(self)->Path:
'''
Path to the controlnet models directory.
'''
return self._resolve(self.controlnet_dir) if self.controlnet_dir else None
@property
def autoconvert_path(self)->Path:
'''
Path to the directory containing models to be imported automatically at startup.
'''
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
@property
def gfpgan_model_path(self)->Path:
'''
Path to the GFPGAN model.
'''
return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None
# the following methods support legacy calls leftover from the Globals era
@property
def full_precision(self)->bool:
"""Return true if precision set to float32"""
return self.precision=='float32'
@property
def disable_xformers(self)->bool:
"""Return true if xformers_enabled is false"""
return not self.xformers_enabled
@property
def try_patchmatch(self)->bool:
"""Return true if patchmatch true"""
return self.patchmatch
@staticmethod
def find_root()->Path:
'''
Choose the runtime root directory when it is not specified on the
command line or in the init file.
'''
return _find_root()
class PagingArgumentParser(argparse.ArgumentParser):
'''
A custom ArgumentParser that uses pydoc to page its output.
It also supports reading defaults from an init file.
'''
def print_help(self, file=None):
text = self.format_help()
pydoc.pager(text)
def get_invokeai_config(**kwargs)->InvokeAIAppConfig:
'''
Legacy function which returns InvokeAIAppConfig.get_config()
'''
return InvokeAIAppConfig.get_config(**kwargs)
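A minimal sketch of the precedence rules described in the module docstring, assuming no `invokeai.yaml` is present so the environment variable and command-line flag are the only overrides in play; this is illustrative rather than a test of this exact revision:
import os

os.environ["INVOKEAI_port"] = "8080"        # 3) environment variable
conf = InvokeAIAppConfig.get_config()
conf.parse_args(argv=["--port", "9999"])    # 2) command line outranks env
print(conf.port)                            # 9999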

View File

@@ -1,7 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from invokeai.app.models.image import ProgressImage
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.util.misc import get_timestamp

View File

@@ -60,33 +60,6 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
node_input_field = node_inputs.get(field) or None
return node_input_field
from typing import Optional, Union, List, get_args
def is_union_subtype(t1, t2):
t1_args = get_args(t1)
t2_args = get_args(t2)
if not t1_args:
# t1 is a single type
return t1 in t2_args
else:
# t1 is a Union, check that all of its types are in t2_args
return all(arg in t2_args for arg in t1_args)
def is_list_or_contains_list(t):
t_args = get_args(t)
# If the type is a List
if get_origin(t) is list:
return True
# If the type is a Union
elif t_args:
# Check if any of the types in the Union is a List
for arg in t_args:
if get_origin(arg) is list:
return True
return False
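Sketched against concrete `typing` objects (assuming `get_origin` is imported at module scope elsewhere in this file, as both helpers require):
from typing import List, Union

assert is_union_subtype(int, Union[int, str])              # single type vs. Union
assert is_union_subtype(Union[int, str], Union[int, str, float])
assert is_list_or_contains_list(List[int])                 # plain List
assert is_list_or_contains_list(Union[int, List[str]])     # Union containing a List
assert not is_list_or_contains_list(int)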
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if not from_type:
@@ -112,8 +85,7 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if to_type in get_args(from_type):
return True
# if not issubclass(from_type, to_type):
if not is_union_subtype(from_type, to_type):
if not issubclass(from_type, to_type):
return False
else:
return False
@@ -163,7 +135,6 @@ class GraphInvocationOutput(BaseInvocationOutput):
# TODO: Fill this out and move to invocations
class GraphInvocation(BaseInvocation):
"""Execute a graph"""
type: Literal["graph"] = "graph"
# TODO: figure out how to create a default here
@@ -191,7 +162,6 @@ class IterateInvocationOutput(BaseInvocationOutput):
# TODO: Fill this out and move to invocations
class IterateInvocation(BaseInvocation):
"""Iterates over a list of items"""
type: Literal["iterate"] = "iterate"
collection: list[Any] = Field(
@@ -391,7 +361,7 @@ class Graph(BaseModel):
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError:
raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}")
raise InvalidEdgeError("One or both nodes don't exist")
# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
@@ -402,41 +372,41 @@ class Graph(BaseModel):
g = self.nx_graph_flat()
g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g):
raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}')
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
# Validate that the field types are compatible
if not are_connections_compatible(
from_node, edge.source.field, to_node, edge.destination.field
):
raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
raise InvalidEdgeError(f'Fields are incompatible')
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid(
edge.destination.node_id, new_input=edge.source
):
raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid(
edge.source.node_id, new_output=edge.destination
):
raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid(
edge.destination.node_id, new_input=edge.source
):
raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
raise InvalidEdgeError(f'Collector output type does not match collector input type')
# Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
if not self._is_collector_connection_valid(
edge.source.node_id, new_output=edge.destination
):
raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
raise InvalidEdgeError(f'Collector input type does not match collector output type')
def has_node(self, node_path: str) -> bool:
@@ -722,11 +692,7 @@ class Graph(BaseModel):
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
# Verify that all outputs are lists
# if not all((get_origin(f) == list for f in output_fields)):
# return False
# Verify that all outputs are lists
if not all(is_list_or_contains_list(f) for f in output_fields):
if not all((get_origin(f) == list for f in output_fields)):
return False
# Verify that all outputs match the input type (are a base class or the same class)
@@ -745,13 +711,6 @@ class Graph(BaseModel):
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g
def nx_graph_with_data(self) -> nx.DiGraph:
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
g = nx.DiGraph()
g.add_nodes_from([n for n in self.nodes.items()])
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g
def nx_graph_flat(
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
) -> nx.DiGraph:
@@ -857,9 +816,11 @@ class GraphExecutionState(BaseModel):
if next_node is None:
prepared_id = self._prepare()
# Prepare as many nodes as we can
while prepared_id is not None:
prepared_id = self._prepare()
# TODO: prepare multiple nodes at once?
# while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation):
# prepared_id = self._prepare()
if prepared_id is not None:
next_node = self._get_next_node()
# Get values from edges
@@ -1006,30 +967,14 @@ class GraphExecutionState(BaseModel):
# Get flattened source graph
g = self.graph.nx_graph_flat()
# Find next node that:
# - was not already prepared
# - is not an iterate node whose inputs have not been executed
# - does not have an unexecuted iterate ancestor
# Find next unprepared node where all source nodes are executed
sorted_nodes = nx.topological_sort(g)
next_node_id = next(
(
n
for n in sorted_nodes
# exclude nodes that have already been prepared
if n not in self.source_prepared_mapping
# exclude iterate nodes whose inputs have not been executed
and not (
isinstance(self.graph.get_node(n), IterateInvocation) # `n` is an iterate node...
and not all((e[0] in self.executed for e in g.in_edges(n))) # ...that has unexecuted inputs
)
# exclude nodes who have unexecuted iterate ancestors
and not any(
(
isinstance(self.graph.get_node(a), IterateInvocation) # `a` is an iterate ancestor of `n`...
and a not in self.executed # ...that is not executed
for a in nx.ancestors(g, n) # for all ancestors `a` of node `n`
)
)
and all((e[0] in self.executed for e in g.in_edges(n)))
),
None,
)
@@ -1126,22 +1071,9 @@ class GraphExecutionState(BaseModel):
)
def _get_next_node(self) -> Optional[BaseInvocation]:
"""Gets the deepest node that is ready to be executed"""
g = self.execution_graph.nx_graph()
# Depth-first search with pre-order traversal is a depth-first topological sort
sorted_nodes = nx.dfs_preorder_nodes(g)
next_node = next(
(
n
for n in sorted_nodes
if n not in self.executed # the node must not already be executed...
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
),
None,
)
sorted_nodes = nx.topological_sort(g)
next_node = next((n for n in sorted_nodes if n not in self.executed), None)
if next_node is None:
return None
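The hunk above swaps a plain topological sort back in for the DFS pre-order walk. A small sketch of how the two orderings differ on a branching DAG (node names are illustrative, and topological orders are not unique, so the second output may vary):
import networkx as nx

g = nx.DiGraph([("a", "b"), ("b", "c"), ("a", "d")])
print(list(nx.dfs_preorder_nodes(g)))  # ['a', 'b', 'c', 'd'] -- finishes a branch first
print(list(nx.topological_sort(g)))    # e.g. ['a', 'b', 'd', 'c']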

View File

@@ -1,204 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict, Optional
from PIL.Image import Image as PILImageType
from PIL import Image, PngImagePlugin
from send2trash import send2trash
from invokeai.app.models.image import ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
# TODO: Should these exceptions subclass existing python exceptions?
class ImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Image file not found"):
super().__init__(message)
class ImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Image file not saved"):
super().__init__(message)
class ImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Image file not deleted"):
super().__init__(message)
class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def get_path(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
"""Gets the internal path to an image or thumbnail."""
pass
# TODO: We need to validate paths before starlette makes the FileResponse, else we get a
# 500 internal server error. I don't like having this method on the service.
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates the path given for an image or thumbnail."""
pass
@abstractmethod
def save(
self,
image: PILImageType,
image_origin: ResourceOrigin,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
pass
@abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk"""
__output_folder: str
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, PILImageType]
__max_cache_size: int
def __init__(self, output_folder: str):
self.__output_folder = output_folder
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_origin in ResourceOrigin:
Path(os.path.join(output_folder, image_origin)).mkdir(
parents=True, exist_ok=True
)
Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
try:
image_path = self.get_path(image_origin, image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
image = Image.open(image_path)
self.__set_cache(image_path, image)
return image
except FileNotFoundError as e:
raise ImageFileNotFoundException from e
def save(
self,
image: PILImageType,
image_origin: ResourceOrigin,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
) -> None:
try:
image_path = self.get_path(image_origin, image_name)
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("invokeai", metadata.json())
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path, "PNG")
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size)
thumbnail_image.save(thumbnail_path)
self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image)
except Exception as e:
raise ImageFileSaveException from e
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
try:
basename = os.path.basename(image_name)
image_path = self.get_path(image_origin, basename)
if os.path.exists(image_path):
send2trash(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
if os.path.exists(thumbnail_path):
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
except Exception as e:
raise ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if thumbnail:
thumbnail_name = get_thumbnail_name(basename)
path = os.path.join(
self.__output_folder, image_origin, "thumbnails", thumbnail_name
)
else:
path = os.path.join(self.__output_folder, image_origin, basename)
abspath = os.path.abspath(path)
return abspath
def validate_path(self, path: str) -> bool:
"""Validates the path given for an image or thumbnail."""
try:
os.stat(path)
return True
except Exception:
return False
def __get_cache(self, image_name: str) -> PILImageType | None:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: PILImageType):
if image_name not in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(
image_name
) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
if cache_id in self.__cache:
del self.__cache[cache_id]
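The `__set_cache`/`__get_cache` pair above implements a naive FIFO cache, as its TODO admits: keys are enqueued on insert, and the oldest key is evicted once the dict outgrows `__max_cache_size`. A stripped-down sketch of the same policy:
from queue import Queue

cache: dict = {}
ids: Queue = Queue()
max_size = 3

for key in ["a", "b", "c", "d"]:
    if key not in cache:
        cache[key] = object()
        ids.put(key)
    if len(cache) > max_size:
        evicted = ids.get()        # oldest insertion leaves first
        cache.pop(evicted, None)

print(sorted(cache))  # ['b', 'c', 'd'] -- 'a' was evicted first-in-first-out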

View File

@@ -1,419 +0,0 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Generic, Optional, TypeVar, cast
import sqlite3
import threading
from typing import Optional, Union
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.models.image import (
ImageCategory,
ResourceOrigin,
)
from invokeai.app.services.models.image_record import (
ImageRecord,
ImageRecordChanges,
deserialize_image_record,
)
T = TypeVar("T", bound=BaseModel)
class OffsetPaginatedResults(GenericModel, Generic[T]):
"""Offset-paginated results"""
# fmt: off
items: list[T] = Field(description="Items")
offset: int = Field(description="Offset from which to retrieve items")
limit: int = Field(description="Limit of items to get")
total: int = Field(description="Total number of items in result")
# fmt: on
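A sketch of an offset-paginated page as `get_many` below returns it (the numbers are illustrative):
page = OffsetPaginatedResults[ImageRecord](items=[], offset=20, limit=10, total=57)
# The next page starts at offset 30; the last page is reached once
# offset + limit >= total.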
# TODO: Should these exceptions subclass existing python exceptions?
class ImageRecordNotFoundException(Exception):
"""Raised when an image record is not found."""
def __init__(self, message="Image record not found"):
super().__init__(message)
class ImageRecordSaveException(Exception):
"""Raised when an image record cannot be saved."""
def __init__(self, message="Image record not saved"):
super().__init__(message)
class ImageRecordDeleteException(Exception):
"""Raised when an image record cannot be deleted."""
def __init__(self, message="Image record not deleted"):
super().__init__(message)
class ImageRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the image record store."""
@abstractmethod
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
def update(
self,
image_name: str,
image_origin: ResourceOrigin,
changes: ImageRecordChanges,
) -> None:
"""Updates an image record."""
pass
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass
# TODO: The database has a nullable `deleted_at` column, currently unused.
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
@abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
"""Deletes an image record."""
pass
@abstractmethod
def save(
self,
image_name: str,
image_origin: ResourceOrigin,
image_category: ImageCategory,
width: int,
height: int,
session_id: Optional[str],
node_id: Optional[str],
metadata: Optional[ImageMetadata],
is_intermediate: bool = False,
) -> datetime:
"""Saves an image record."""
pass
class SqliteImageRecordStorage(ImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the tables for the `images` database."""
# Create the `images` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility
image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
session_id TEXT,
node_id TEXT,
metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
)
# Create the `images` table indices.
self._cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
UPDATE images SET updated_at = current_timestamp
WHERE image_name = old.image_name;
END;
"""
)
def get(
self, image_origin: ResourceOrigin, image_name: str
) -> Union[ImageRecord, None]:
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
SELECT * FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e
finally:
self._lock.release()
if not result:
raise ImageRecordNotFoundException
return deserialize_image_record(dict(result))
def update(
self,
image_name: str,
image_origin: ResourceOrigin,
changes: ImageRecordChanges,
) -> None:
try:
self._lock.acquire()
# Change the category of the image
if changes.image_category is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
# Change the image's `is_intermediate` flag
if changes.is_intermediate is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
finally:
self._lock.release()
def get_many(
self,
offset: int = 0,
limit: int = 10,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
) -> OffsetPaginatedResults[ImageRecord]:
try:
self._lock.acquire()
# Manually build two queries - one for the count, one for the records
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
query_conditions = ""
query_params = []
if image_origin is not None:
query_conditions += f"""AND image_origin = ?\n"""
query_params.append(image_origin.value)
if categories is not None:
# Convert the enum values to a unique list of strings
category_strings = list(
map(lambda c: c.value, set(categories))
)
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"AND image_category IN ( {placeholders} )\n"
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += f"""AND is_intermediate = ?\n"""
query_params.append(is_intermediate)
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
images_params.append(limit)
images_params.append(offset)
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
self._cursor.execute(count_query, count_params)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(
items=images, offset=offset, limit=limit, total=count
)
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM images
WHERE image_name = ?;
""",
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
image_name: str,
image_origin: ResourceOrigin,
image_category: ImageCategory,
session_id: Optional[str],
width: int,
height: int,
node_id: Optional[str],
metadata: Optional[ImageMetadata],
is_intermediate: bool = False,
) -> datetime:
try:
metadata_json = (
None if metadata is None else metadata.json(exclude_none=True)
)
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
image_name,
image_origin,
image_category,
width,
height,
node_id,
session_id,
metadata,
is_intermediate
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_origin.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata_json,
is_intermediate,
),
)
self._conn.commit()
self._cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
return created_at
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
finally:
self._lock.release()
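The `get_many` method above assembles its SQL from shared pieces: one set of `AND ...` conditions and parameters feeds both the unpaginated COUNT query and the paginated SELECT. A sketch of that assembly with an illustrative category filter:
conditions, params = "", []

categories = ["general", "control"]            # illustrative values
placeholders = ",".join("?" * len(categories))
conditions += f"AND image_category IN ( {placeholders} )\n"
params.extend(categories)

images_query = (
    "SELECT * FROM images WHERE 1=1\n"
    + conditions
    + "ORDER BY created_at DESC LIMIT ? OFFSET ?\n;"
)
count_query = "SELECT COUNT(*) FROM images WHERE 1=1\n" + conditions + ";"

print(images_query, params + [10, 0])  # limit=10, offset=0
print(count_query, params)             # same filter, no pagination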

View File

@@ -0,0 +1,274 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import os
from glob import glob
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict, List
from PIL.Image import Image
import PIL.Image as PILImage
from send2trash import send2trash
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
SavedImage,
)
from invokeai.app.models.image import ImageType
from invokeai.app.services.metadata import (
InvokeAIMetadata,
MetadataServiceBase,
build_invokeai_metadata_pnginfo,
)
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.misc import get_timestamp
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
class ImageStorageBase(ABC):
"""Responsible for storing and retrieving images."""
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> Image:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
"""Gets a paginated list of images."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
"""Gets the internal path to an image or its thumbnail."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_uri(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
"""Gets the external URI to an image or its thumbnail."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates an image path."""
pass
@abstractmethod
def save(
self,
image_type: ImageType,
image_name: str,
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> SavedImage:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
def create_name(self, context_id: str, node_id: str) -> str:
"""Creates a unique contextual image filename."""
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
class DiskImageStorage(ImageStorageBase):
"""Stores images on disk"""
__output_folder: str
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, Image]
__max_cache_size: int
__metadata_service: MetadataServiceBase
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
self.__output_folder = output_folder
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
self.__metadata_service = metadata_service
Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_type in ImageType:
Path(os.path.join(output_folder, image_type)).mkdir(
parents=True, exist_ok=True
)
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
dir_path = os.path.join(self.__output_folder, image_type)
image_paths = glob(f"{dir_path}/*.png")
count = len(image_paths)
sorted_image_paths = sorted(
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
)
page_of_image_paths = sorted_image_paths[
page * per_page : (page + 1) * per_page
]
page_of_images: List[ImageResponse] = []
for path in page_of_image_paths:
filename = os.path.basename(path)
img = PILImage.open(path)
invokeai_metadata = self.__metadata_service.get_metadata(img)
page_of_images.append(
ImageResponse(
image_type=image_type.value,
image_name=filename,
# TODO: DiskImageStorage should not be building URLs...?
image_url=self.get_uri(image_type, filename),
thumbnail_url=self.get_uri(image_type, filename, True),
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
metadata=ImageResponseMetadata(
created=int(os.path.getctime(path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
)
page_count_trunc = int(count / per_page)
page_count_mod = count % per_page
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
return PaginatedResults[ImageResponse](
items=page_of_images,
page=page,
pages=page_count,
per_page=per_page,
total=count,
)
def get(self, image_type: ImageType, image_name: str) -> Image:
image_path = self.get_path(image_type, image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
image = PILImage.open(image_path)
self.__set_cache(image_path, image)
return image
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if is_thumbnail:
path = os.path.join(
self.__output_folder, image_type, "thumbnails", basename
)
else:
path = os.path.join(self.__output_folder, image_type, basename)
abspath = os.path.abspath(path)
return abspath
def get_uri(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if is_thumbnail:
thumbnail_basename = get_thumbnail_name(basename)
uri = f"api/v1/images/{image_type.value}/thumbnails/{thumbnail_basename}"
else:
uri = f"api/v1/images/{image_type.value}/{basename}"
return uri
def validate_path(self, path: str) -> bool:
try:
os.stat(path)
return True
except Exception:
return False
def save(
self,
image_type: ImageType,
image_name: str,
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> SavedImage:
image_path = self.get_path(image_type, image_name)
# TODO: Reading the image and then saving it strips the metadata...
if metadata:
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path) # this saved image has an empty info
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
thumbnail_image = make_thumbnail(image)
thumbnail_image.save(thumbnail_path)
self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image)
return SavedImage(
image_name=image_name,
thumbnail_name=thumbnail_name,
created=int(os.path.getctime(image_path)),
)
def delete(self, image_type: ImageType, image_name: str) -> None:
basename = os.path.basename(image_name)
image_path = self.get_path(image_type, basename)
if os.path.exists(image_path):
send2trash(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
if os.path.exists(thumbnail_path):
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
def __get_cache(self, image_name: str) -> Image | None:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: Image):
if image_name not in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(
image_name
) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
if cache_id in self.__cache:
del self.__cache[cache_id]
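The page-count arithmetic in `list()` above is integer ceiling division written out longhand; an equivalent check:
import math

count, per_page = 25, 10
page_count_trunc = int(count / per_page)
page_count = page_count_trunc if count % per_page == 0 else page_count_trunc + 1
assert page_count == math.ceil(count / per_page) == 3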

View File

@@ -1,393 +0,0 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import Optional, TYPE_CHECKING, Union
from PIL.Image import Image as PILImageType
from invokeai.app.models.image import (
ImageCategory,
ResourceOrigin,
InvalidImageCategoryException,
InvalidOriginException,
)
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException,
ImageRecordNotFoundException,
ImageRecordSaveException,
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.image_record import (
ImageRecord,
ImageDTO,
ImageRecordChanges,
image_record_to_dto,
)
from invokeai.app.services.image_file_storage import (
ImageFileDeleteException,
ImageFileNotFoundException,
ImageFileSaveException,
ImageFileStorageBase,
)
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase
if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState
class ImageServiceABC(ABC):
"""High-level service for image management."""
@abstractmethod
def create(
self,
image: PILImageType,
image_origin: ResourceOrigin,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
intermediate: bool = False,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@abstractmethod
def update(
self,
image_origin: ResourceOrigin,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
"""Updates an image."""
pass
@abstractmethod
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
pass
@abstractmethod
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
pass
@abstractmethod
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
"""Gets an image's path."""
pass
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates an image's path."""
pass
@abstractmethod
def get_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
"""Gets an image's or thumbnail's URL."""
pass
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass
@abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str):
"""Deletes an image."""
pass
class ImageServiceDependencies:
"""Service dependencies for the ImageService."""
records: ImageRecordStorageBase
files: ImageFileStorageBase
metadata: MetadataServiceBase
urls: UrlServiceBase
logger: Logger
names: NameServiceBase
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
def __init__(
self,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self.records = image_record_storage
self.files = image_file_storage
self.metadata = metadata
self.urls = url
self.logger = logger
self.names = names
self.graph_execution_manager = graph_execution_manager
class ImageService(ImageServiceABC):
_services: ImageServiceDependencies
def __init__(
self,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self._services = ImageServiceDependencies(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=url,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
def create(
self,
image: PILImageType,
image_origin: ResourceOrigin,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
is_intermediate: bool = False,
) -> ImageDTO:
if image_origin not in ResourceOrigin:
raise InvalidOriginException
if image_category not in ImageCategory:
raise InvalidImageCategoryException
image_name = self._services.names.create_image_name()
metadata = self._get_metadata(session_id, node_id)
(width, height) = image.size
try:
# TODO: Consider using a transaction here to ensure consistency between storage and database
created_at = self._services.records.save(
# Non-nullable fields
image_name=image_name,
image_origin=image_origin,
image_category=image_category,
width=width,
height=height,
# Meta fields
is_intermediate=is_intermediate,
# Nullable fields
node_id=node_id,
session_id=session_id,
metadata=metadata,
)
self._services.files.save(
image_origin=image_origin,
image_name=image_name,
image=image,
metadata=metadata,
)
image_url = self._services.urls.get_image_url(image_origin, image_name)
thumbnail_url = self._services.urls.get_image_url(
image_origin, image_name, True
)
return ImageDTO(
# Non-nullable fields
image_name=image_name,
image_origin=image_origin,
image_category=image_category,
width=width,
height=height,
# Nullable fields
node_id=node_id,
session_id=session_id,
metadata=metadata,
# Meta fields
created_at=created_at,
updated_at=created_at,  # on creation, updated_at always equals created_at
deleted_at=None,
is_intermediate=is_intermediate,
# Extra non-nullable fields for DTO
image_url=image_url,
thumbnail_url=thumbnail_url,
)
except ImageRecordSaveException:
self._services.logger.error("Failed to save image record")
raise
except ImageFileSaveException:
self._services.logger.error("Failed to save image file")
raise
except Exception as e:
self._services.logger.error("Problem saving image record and file")
raise e
def update(
self,
image_origin: ResourceOrigin,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
try:
self._services.records.update(image_name, image_origin, changes)
return self.get_dto(image_origin, image_name)
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
raise
except Exception as e:
self._services.logger.error("Problem updating image record")
raise e
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
try:
return self._services.files.get(image_origin, image_name)
except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file")
raise
except Exception as e:
self._services.logger.error("Problem getting image file")
raise e
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
try:
return self._services.records.get(image_origin, image_name)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
except Exception as e:
self._services.logger.error("Problem getting image record")
raise e
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
try:
image_record = self._services.records.get(image_origin, image_name)
image_dto = image_record_to_dto(
image_record,
self._services.urls.get_image_url(image_origin, image_name),
self._services.urls.get_image_url(image_origin, image_name, True),
)
return image_dto
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
except Exception as e:
self._services.logger.error("Problem getting image DTO")
raise e
def get_path(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.files.get_path(image_origin, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def validate_path(self, path: str) -> bool:
try:
return self._services.files.validate_path(path)
except Exception as e:
self._services.logger.error("Problem validating image path")
raise e
def get_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def get_many(
self,
offset: int = 0,
limit: int = 10,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self._services.records.get_many(
offset,
limit,
image_origin,
categories,
is_intermediate,
)
image_dtos = list(
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(r.image_origin, r.image_name),
self._services.urls.get_image_url(
r.image_origin, r.image_name, True
),
),
results.items,
)
)
return OffsetPaginatedResults[ImageDTO](
items=image_dtos,
offset=results.offset,
limit=results.limit,
total=results.total,
)
except Exception as e:
self._services.logger.error("Problem getting paginated image DTOs")
raise e
def delete(self, image_origin: ResourceOrigin, image_name: str):
try:
self._services.files.delete(image_origin, image_name)
self._services.records.delete(image_origin, image_name)
except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise
except ImageFileDeleteException:
self._services.logger.error(f"Failed to delete image file")
raise
except Exception as e:
self._services.logger.error("Problem deleting image record and file")
raise e
def _get_metadata(
self, session_id: Optional[str] = None, node_id: Optional[str] = None
) -> Union[ImageMetadata, None]:
"""Get the metadata for a node."""
metadata = None
if node_id is not None and session_id is not None:
session = self._services.graph_execution_manager.get(session_id)
metadata = self._services.metadata.create_image_metadata(session, node_id)
return metadata
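A minimal usage sketch of the service contract above (not part of the diff; the wired-up service instance is assumed). Omitting node_id/session_id means _get_metadata() skips the session lookup entirely:

from PIL import Image

from invokeai.app.models.image import ImageCategory, ResourceOrigin

def save_result_image(images: ImageServiceABC, pil_image: Image.Image) -> str:
    # create() names the image, writes the record first, then the file, and
    # returns a DTO already enriched with image and thumbnail URLs.
    dto = images.create(
        image=pil_image,
        image_origin=ResourceOrigin.INTERNAL,
        image_category=ImageCategory.GENERAL,
        # node_id/session_id omitted: no metadata is attached in that case
    )
    return dto.image_url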

View File

@@ -1,60 +1,59 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from logging import Logger
from invokeai.app.services.images import ImageService
from invokeai.backend import ModelManager
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.restoration_services import RestorationServices
from invokeai.app.services.invocation_queue import InvocationQueueABC
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.config import InvokeAISettings
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.invoker import InvocationProcessorABC
import types
from invokeai.app.services.outputs_session_storage import OutputsSessionStorageABC
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.backend import ModelManager
from .events import EventServiceBase
from .latent_storage import LatentsStorageBase
from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
class InvocationServices:
"""Services that can be used by invocations"""
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
events: "EventServiceBase"
latents: "LatentsStorageBase"
queue: "InvocationQueueABC"
model_manager: "ModelManager"
restoration: "RestorationServices"
configuration: "InvokeAISettings"
images: "ImageService"
events: EventServiceBase
latents: LatentsStorageBase
images: ImageStorageBase
outputs: OutputsSessionStorageABC
metadata: MetadataServiceBase
queue: InvocationQueueABC
model_manager: ModelManager
restoration: RestorationServices
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: "ItemStorageABC"["LibraryGraph"]
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
graph_library: ItemStorageABC["LibraryGraph"]
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
processor: "InvocationProcessorABC"
def __init__(
self,
model_manager: "ModelManager",
events: "EventServiceBase",
logger: "Logger",
latents: "LatentsStorageBase",
images: "ImageService",
queue: "InvocationQueueABC",
graph_library: "ItemStorageABC"["LibraryGraph"],
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: "RestorationServices",
configuration: "InvokeAISettings",
self,
model_manager: ModelManager,
events: EventServiceBase,
logger: types.ModuleType,
latents: LatentsStorageBase,
images: ImageStorageBase,
outputs: OutputsSessionStorageABC,
metadata: MetadataServiceBase,
queue: InvocationQueueABC,
graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: RestorationServices,
):
self.model_manager = model_manager
self.events = events
self.logger = logger
self.latents = latents
self.images = images
self.outputs = outputs
self.metadata = metadata
self.queue = queue
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration
self.configuration = configuration
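For orientation, a hedged sketch of how an invocation reaches this container at runtime; the `context.services` attribute is assumed from the forward-declared InvocationContext types, and the node fields are hypothetical:

def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
    services = context.services          # the InvocationServices container built above
    services.logger.info("node started")
    # attribute names mirror the assignments in __init__ above
    latents = services.latents.get(self.latents.latents_name)  # hypothetical field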

View File

@@ -22,8 +22,7 @@ class Invoker:
def invoke(
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
) -> str | None:
"""Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""
# Get the next invocation
invocation = graph_execution_state.next()
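A small sketch (assumed, restating only what the new docstring promises) of the calling pattern:

# With the default invoke_all=False, a single node is prepared and enqueued per
# call and its id is returned; None signals the graph has nothing left to run.
node_id = invoker.invoke(graph_execution_state)
if node_id is None:
    print("graph complete: no nodes left to enqueue")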

View File

@@ -16,7 +16,7 @@ class LatentsStorageBase(ABC):
pass
@abstractmethod
def save(self, name: str, data: torch.Tensor) -> None:
def set(self, name: str, data: torch.Tensor) -> None:
pass
@abstractmethod
@@ -47,8 +47,8 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
self.__set_cache(name, latent)
return latent
def save(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.save(name, data)
def set(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.set(name, data)
self.__set_cache(name, data)
def delete(self, name: str) -> None:
@@ -80,7 +80,7 @@ class DiskLatentsStorage(LatentsStorageBase):
latent_path = self.get_path(name)
return torch.load(latent_path)
def save(self, name: str, data: torch.Tensor) -> None:
def set(self, name: str, data: torch.Tensor) -> None:
latent_path = self.get_path(name)
torch.save(data, latent_path)
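A usage sketch of the renamed API; the constructor arguments are assumptions, only the set()/get() behavior comes from this diff:

import torch

disk = DiskLatentsStorage("/tmp/invokeai/latents")  # assumed: takes an output folder
latents = ForwardCacheLatentsStorage(disk)          # assumed: wraps the underlying store

latents.set("sample", torch.zeros(4, 64, 64))  # persists to disk and warms the cache
tensor = latents.get("sample")                 # subsequent reads hit the in-memory cache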

View File

@@ -1,142 +1,105 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Union
import networkx as nx
from typing import Any, Dict, Optional, TypedDict
from PIL import Image, PngImagePlugin
from pydantic import BaseModel
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.graph import Graph, GraphExecutionState
from invokeai.app.models.image import ImageType, is_image_type
class MetadataImageField(TypedDict):
"""Pydantic-less ImageField, used for metadata parsing."""
image_type: ImageType
image_name: str
class MetadataLatentsField(TypedDict):
"""Pydantic-less LatentsField, used for metadata parsing."""
latents_name: str
class MetadataColorField(TypedDict):
"""Pydantic-less ColorField, used for metadata parsing"""
r: int
g: int
b: int
a: int
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
NodeMetadata = Dict[
str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField | MetadataColorField
]
class InvokeAIMetadata(TypedDict, total=False):
"""InvokeAI-specific metadata format."""
session_id: Optional[str]
node: Optional[NodeMetadata]
def build_invokeai_metadata_pnginfo(
metadata: InvokeAIMetadata | None,
) -> PngImagePlugin.PngInfo:
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
pnginfo = PngImagePlugin.PngInfo()
if metadata is not None:
pnginfo.add_text("invokeai", json.dumps(metadata))
return pnginfo
class MetadataServiceBase(ABC):
"""Handles building metadata for nodes, images, and outputs."""
@abstractmethod
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
pass
@abstractmethod
def create_image_metadata(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
"""Builds an ImageMetadata object for a node."""
def build_metadata(
self, session_id: str, node: BaseModel
) -> InvokeAIMetadata | None:
"""Builds an InvokeAIMetadata object"""
pass
class CoreMetadataService(MetadataServiceBase):
_ANCESTOR_TYPES = ["t2l", "l2l"]
"""The ancestor types that contain the core metadata"""
class PngMetadataService(MetadataServiceBase):
"""Handles loading and building metadata for images."""
_ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"]
"""The core metadata parameters in the ancestor types"""
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
def _load_metadata(self, image: Image.Image) -> dict | None:
"""Loads a specific info entry from a PIL Image."""
_NOISE_FIELDS = ["seed", "width", "height"]
"""The core metadata parameters in the noise node"""
try:
info = image.info.get("invokeai")
def create_image_metadata(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
metadata = self._build_metadata_from_graph(session, node_id)
if type(info) is not str:
return None
loaded_metadata = json.loads(info)
if type(loaded_metadata) is not dict:
return None
if len(loaded_metadata.items()) == 0:
return None
return loaded_metadata
except:
return None
def get_metadata(self, image: Image.Image) -> dict | None:
"""Retrieves an image's metadata as a dict"""
loaded_metadata = self._load_metadata(image)
return loaded_metadata
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
return metadata
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Union[str, None]:
"""
Finds the id of the nearest ancestor (of a valid type) of a given node.
Parameters:
G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must
have the same data as the execution graph.
node_id (str): The ID of the node.
Returns:
str | None: The ID of the nearest ancestor, or None if there are no valid ancestors.
"""
# Retrieve the node from the graph
node = G.nodes[node_id]
# If the node type is one of the core metadata node types, return its id
if node.get("type") in self._ANCESTOR_TYPES:
return node.get("id")
# Else, look for the ancestor in the predecessor nodes
for predecessor in G.predecessors(node_id):
result = self._find_nearest_ancestor(G, predecessor)
if result:
return result
# If there are no valid ancestors, return None
return None
def _get_additional_metadata(
self, graph: Graph, node_id: str
) -> Union[dict[str, Any], None]:
"""
Returns additional metadata for a given node.
Parameters:
graph (Graph): The execution graph.
node_id (str): The ID of the node.
Returns:
dict[str, Any] | None: A dictionary of additional metadata.
"""
metadata = {}
# Iterate over all edges in the graph
for edge in graph.edges:
dest_node_id = edge.destination.node_id
dest_field = edge.destination.field
source_node_dict = graph.nodes[edge.source.node_id].dict()
# If the destination node ID matches the given node ID, gather necessary metadata
if dest_node_id == node_id:
# Prompt
if dest_field == "positive_conditioning":
metadata["positive_conditioning"] = source_node_dict.get("prompt")
# Negative prompt
if dest_field == "negative_conditioning":
metadata["negative_conditioning"] = source_node_dict.get("prompt")
# Seed, width and height
if dest_field == "noise":
for field in self._NOISE_FIELDS:
metadata[field] = source_node_dict.get(field)
return metadata
def _build_metadata_from_graph(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
"""
Builds an ImageMetadata object for a node.
Parameters:
session (GraphExecutionState): The session.
node_id (str): The ID of the node.
Returns:
ImageMetadata: The metadata for the node.
"""
# We need to do all the traversal on the execution graph
graph = session.execution_graph
# Find the nearest `t2l`/`l2l` ancestor of the given node
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
# If no ancestor was found, return an empty ImageMetadata object
if ancestor_id is None:
return ImageMetadata()
ancestor_node = graph.get_node(ancestor_id)
# Grab all the core metadata from the ancestor node
ancestor_metadata = {
param: val
for param, val in ancestor_node.dict().items()
if param in self._ANCESTOR_PARAMS
}
# Get this image's prompts and noise parameters
addl_metadata = self._get_additional_metadata(graph, ancestor_id)
# If additional metadata was found, add it to the main metadata
if addl_metadata is not None:
ancestor_metadata.update(addl_metadata)
return ImageMetadata(**ancestor_metadata)
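A standalone worked example of the nearest-ancestor search described above, reimplemented with plain networkx (node data mirrors what nx_graph_with_data() attaches):

import networkx as nx

ANCESTOR_TYPES = ["t2l", "l2l"]

def find_nearest_ancestor(G: nx.DiGraph, node_id: str):
    # check the node itself, then walk predecessors recursively
    if G.nodes[node_id].get("type") in ANCESTOR_TYPES:
        return G.nodes[node_id].get("id")
    for predecessor in G.predecessors(node_id):
        result = find_nearest_ancestor(G, predecessor)
        if result:
            return result
    return None

G = nx.DiGraph()
G.add_node("noise", type="noise", id="noise")
G.add_node("t2l", type="t2l", id="t2l")
G.add_node("l2i", type="l2i", id="l2i")
G.add_edges_from([("noise", "t2l"), ("t2l", "l2i")])

# the image-producing l2i node inherits its core metadata from its t2l ancestor
assert find_nearest_ancestor(G, "l2i") == "t2l"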

View File

@@ -2,25 +2,27 @@ import os
import sys
import torch
from argparse import Namespace
from invokeai.backend import Args
from omegaconf import OmegaConf
from pathlib import Path
import types
import invokeai.version
from .config import InvokeAISettings
from ...backend import ModelManager
from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals
# TODO: Replace with an abstract class base ModelManagerBase
def get_model_manager(config: InvokeAISettings, logger: types.ModuleType) -> ModelManager:
model_config = config.model_conf_path
if not model_config.exists():
report_model_error(
config, FileNotFoundError(f"The file {model_config} could not be found."), logger
)
def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
if not config.conf:
config_file = os.path.join(Globals.root, "configs", "models.yaml")
if not os.path.exists(config_file):
report_model_error(
config, FileNotFoundError(f"The file {config_file} could not be found."), logger
)
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
logger.info(f'InvokeAI runtime directory is "{config.root}"')
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
@@ -30,7 +32,20 @@ def get_model_manager(config: InvokeAISettings, logger: types.ModuleType) -> Mod
import diffusers
diffusers.logging.set_verbosity_error()
embedding_path = config.embedding_path
# normalize the config directory relative to root
if not os.path.isabs(config.conf):
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
if config.embeddings:
if not os.path.isabs(config.embedding_path):
embedding_path = os.path.normpath(
os.path.join(Globals.root, config.embedding_path)
)
else:
embedding_path = config.embedding_path
else:
embedding_path = None
# migrate legacy models
ModelManager.migrate_models()
@@ -43,11 +58,11 @@ def get_model_manager(config: InvokeAISettings, logger: types.ModuleType) -> Mod
else choose_precision(device)
model_manager = ModelManager(
OmegaConf.load(config.model_conf_path),
OmegaConf.load(config.conf),
precision=precision,
device_type=device,
max_loaded_models=config.max_loaded_models,
embedding_path = embedding_path,
embedding_path = Path(embedding_path),
logger = logger,
)
except (FileNotFoundError, TypeError, AssertionError) as e:
@@ -58,10 +73,12 @@ def get_model_manager(config: InvokeAISettings, logger: types.ModuleType) -> Mod
# try to autoconvert new models
# autoimport new .ckpt files
if config.autoconvert_path:
model_manager.heuristic_import(
config.autoconvert_path,
if path := config.autoconvert:
model_manager.autoconvert_weights(
conf_path=config.conf,
weights_directory=path,
)
logger.info('Model manager initialized')
return model_manager
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):

View File

@@ -1,148 +0,0 @@
import datetime
from typing import Optional, Union
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp
class ImageRecord(BaseModel):
"""Deserialized image record."""
image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The origin of the image."""
image_category: ImageCategory = Field(description="The category of the image.")
"""The category of the image."""
width: int = Field(description="The width of the image in px.")
"""The actual width of the image in px. This may be different from the width in metadata."""
height: int = Field(description="The height of the image in px.")
"""The actual height of the image in px. This may be different from the height in metadata."""
created_at: Union[datetime.datetime, str] = Field(
description="The created timestamp of the image."
)
"""The created timestamp of the image."""
updated_at: Union[datetime.datetime, str] = Field(
description="The updated timestamp of the image."
)
"""The updated timestamp of the image."""
deleted_at: Union[datetime.datetime, str, None] = Field(
description="The deleted timestamp of the image."
)
"""The deleted timestamp of the image."""
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
"""Whether this is an intermediate image."""
session_id: Optional[str] = Field(
default=None,
description="The session ID that generated this image, if it is a generated image.",
)
"""The session ID that generated this image, if it is a generated image."""
node_id: Optional[str] = Field(
default=None,
description="The node ID that generated this image, if it is a generated image.",
)
"""The node ID that generated this image, if it is a generated image."""
metadata: Optional[ImageMetadata] = Field(
default=None,
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
)
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
"""A set of changes to apply to an image record.
Only limited changes are valid:
- `image_category`: change the category of an image
- `session_id`: change the session associated with an image
- `is_intermediate`: change the image's `is_intermediate` flag
"""
image_category: Optional[ImageCategory] = Field(
description="The image's new category."
)
"""The image's new category."""
session_id: Optional[StrictStr] = Field(
default=None,
description="The image's new session ID.",
)
"""The image's new session ID."""
is_intermediate: Optional[StrictBool] = Field(
default=None, description="The image's new `is_intermediate` flag."
)
"""The image's new `is_intermediate` flag."""
class ImageUrlsDTO(BaseModel):
"""The URLs for an image and its thumbnail."""
image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The origin of the image."""
image_url: str = Field(description="The URL of the image.")
"""The URL of the image."""
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
"""The URL of the image's thumbnail."""
class ImageDTO(ImageRecord, ImageUrlsDTO):
"""Deserialized image record, enriched for the frontend with URLs."""
pass
def image_record_to_dto(
image_record: ImageRecord, image_url: str, thumbnail_url: str
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(
**image_record.dict(),
image_url=image_url,
thumbnail_url=thumbnail_url,
)
def deserialize_image_record(image_dict: dict) -> ImageRecord:
"""Deserializes an image record."""
# Retrieve all the values, setting "reasonable" defaults if they are not present.
image_name = image_dict.get("image_name", "unknown")
image_origin = ResourceOrigin(
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
)
image_category = ImageCategory(
image_dict.get("image_category", ImageCategory.GENERAL.value)
)
width = image_dict.get("width", 0)
height = image_dict.get("height", 0)
session_id = image_dict.get("session_id", None)
node_id = image_dict.get("node_id", None)
created_at = image_dict.get("created_at", get_iso_timestamp())
updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False)
raw_metadata = image_dict.get("metadata")
if raw_metadata is not None:
metadata = ImageMetadata.parse_raw(raw_metadata)
else:
metadata = None
return ImageRecord(
image_name=image_name,
image_origin=image_origin,
image_category=image_category,
width=width,
height=height,
session_id=session_id,
node_id=node_id,
metadata=metadata,
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
is_intermediate=is_intermediate,
)
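An example call with a sparse row dict, as might come out of sqlite; the origin string value is an assumption about the enum:

record = deserialize_image_record(
    {
        "image_name": "abc123.png",
        "image_origin": "internal",  # assumes ResourceOrigin.INTERNAL.value == "internal"
        "width": 512,
        "height": 512,
    }
)
assert record.image_category == ImageCategory.GENERAL  # defaulted as above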

View File

@@ -0,0 +1,59 @@
from abc import ABC, abstractmethod
from typing import Callable, Optional
from pydantic import BaseModel, Field
class PaginatedStringResults(BaseModel):
"""Paginated results"""
#fmt: off
items: list[str] = Field(description="Session IDs")
page: int = Field(description="Current Page")
pages: int = Field(description="Total number of pages")
per_page: int = Field(description="Number of items per page")
total: int = Field(description="Total number of items in result")
#fmt: on
class OutputsSessionStorageABC(ABC):
"""Base storage class mapping output IDs to session IDs."""
_on_changed_callbacks: list[Callable[[str], None]]
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
@abstractmethod
def get(self, output_id: str) -> Optional[str]:
pass
@abstractmethod
def set(self, output_id: str, session_id: str) -> None:
pass
@abstractmethod
def list(self, page: int = 0, per_page: int = 10) -> PaginatedStringResults:
pass
@abstractmethod
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedStringResults:
pass
def on_changed(self, on_changed: Callable[[str], None]) -> None:
"""Register a callback for when an item is changed"""
self._on_changed_callbacks.append(on_changed)
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
"""Register a callback for when an item is deleted"""
self._on_deleted_callbacks.append(on_deleted)
def _on_changed(self, foreign_key_value: str) -> None:
for callback in self._on_changed_callbacks:
callback(foreign_key_value)
def _on_deleted(self, item_id: str) -> None:
for callback in self._on_deleted_callbacks:
callback(item_id)

View File

@@ -0,0 +1,162 @@
import json
import sqlite3
from threading import Lock
from typing import Union
from invokeai.app.services.outputs_session_storage import (
OutputsSessionStorageABC,
PaginatedStringResults,
)
sqlite_memory = ":memory:"
class OutputsSqliteItemStorage(OutputsSessionStorageABC):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: Lock
def __init__(
self,
filename: str,
):
super().__init__()
self._filename = filename
self._lock = Lock()
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution
self._cursor = self._conn.cursor()
self._create_table()
def _create_table(self):
try:
self._lock.acquire()
self._cursor.execute(
f"""CREATE TABLE IF NOT EXISTS outputs (
id TEXT NOT NULL PRIMARY KEY,
session_id TEXT NOT NULL
);"""
)
self._cursor.execute(
f"""CREATE UNIQUE INDEX IF NOT EXISTS outputs_id ON outputs(id);"""
)
finally:
self._lock.release()
def set(self, output_id: str, session_id: str):
try:
self._lock.acquire()
self._cursor.execute(
f"""INSERT OR REPLACE INTO outputs (id, session_id) VALUES (?, ?);""",
(output_id, session_id),
)
self._conn.commit()
finally:
self._lock.release()
self._on_changed(output_id)
def get(self, output_id: str) -> Union[str, None]:
try:
self._lock.acquire()
self._cursor.execute(
f"""
SELECT graph_executions.item session
FROM graph_executions
INNER JOIN outputs ON outputs.session_id = graph_executions.id
WHERE outputs.id = ?;
""",
(output_id,),
)
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
return result[0]
def delete(self, output_id: str):
try:
self._lock.acquire()
self._cursor.execute(
f"""DELETE FROM outputs WHERE id = ?;""", (str(id),)
)
self._conn.commit()
finally:
self._lock.release()
self._on_deleted(output_id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedStringResults:
try:
self._lock.acquire()
self._cursor.execute(
f"""
SELECT graph_executions.item session
FROM graph_executions
INNER JOIN outputs ON outputs.session_id = graph_executions.id
LIMIT ? OFFSET ?;
""",
(per_page, page * per_page),
)
result = self._cursor.fetchall()
items = list(map(lambda r: r[0], result))
self._cursor.execute(
f"""
SELECT count(*)
FROM graph_executions
INNER JOIN outputs ON outputs.session_id = graph_executions.id;
""")
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
page_count = (count + per_page - 1) // per_page  # ceiling division: no phantom page on exact multiples
return PaginatedStringResults(
items=items, page=page, pages=page_count, per_page=per_page, total=count
)
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedStringResults:
try:
self._lock.acquire()
self._cursor.execute(
f"""
SELECT graph_executions.item session
FROM graph_executions
INNER JOIN outputs ON outputs.session_id = graph_executions.id
WHERE outputs.id LIKE ? LIMIT ? OFFSET ?;
""",
(f"%{query}%", per_page, page * per_page),
)
result = self._cursor.fetchall()
items = list(map(lambda r: r[0], result))
self._cursor.execute(
f"""
SELECT count(*)
FROM graph_executions
INNER JOIN outputs ON outputs.session_id = graph_executions.id
WHERE outputs.id LIKE ?;
""",
(f"%{query}%",),
)
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
page_count = (count + per_page - 1) // per_page  # ceiling division: no phantom page on exact multiples
return PaginatedStringResults(
items=items, page=page, pages=page_count, per_page=per_page, total=count
)
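A brief usage sketch. Note that get(), list() and search() all join against the graph_executions table owned by SqliteItemStorage, so in practice both storages must share one database (here, the in-memory constant defined above):

store = OutputsSqliteItemStorage(filename=sqlite_memory)
store.on_changed(lambda output_id: print(f"mapped {output_id}"))
store.set("output-1", "session-1")   # prints "mapped output-1"
# store.get("output-1") joins against graph_executions.item, so it returns the
# session blob only once SqliteItemStorage has created and populated that table
# in the same database; otherwise sqlite raises "no such table".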

View File

@@ -1,30 +0,0 @@
from abc import ABC, abstractmethod
from enum import Enum, EnumMeta
import uuid
class ResourceType(str, Enum, metaclass=EnumMeta):
"""Enum for resource types."""
IMAGE = "image"
LATENT = "latent"
class NameServiceBase(ABC):
"""Low-level service responsible for naming resources (images, latents, etc)."""
# TODO: Add customizable naming schemes
@abstractmethod
def create_image_name(self) -> str:
"""Creates a name for an image."""
pass
class SimpleNameService(NameServiceBase):
"""Creates image names from UUIDs."""
# TODO: Add customizable naming schemes
def create_image_name(self) -> str:
uuid_str = str(uuid.uuid4())
filename = f"{uuid_str}.png"
return filename
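For example (output is random):

SimpleNameService().create_image_name()
# -> "8b0d8193-ca5e-4ad2-9b17-2c6c1d3cfd1c.png" (a UUID4 plus the .png suffix)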

View File

@@ -26,6 +26,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._lock = Lock()
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution

View File

@@ -1,34 +0,0 @@
import os
from abc import ABC, abstractmethod
from invokeai.app.models.image import ResourceOrigin
from invokeai.app.util.thumbnails import get_thumbnail_name
class UrlServiceBase(ABC):
"""Responsible for building URLs for resources."""
@abstractmethod
def get_image_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
"""Gets the URL for an image or thumbnail."""
pass
class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1"):
self._base_url = base_url
def get_image_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
image_basename = os.path.basename(image_name)
# These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail:
return (
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
)
return f"{self._base_url}/images/{image_origin.value}/{image_basename}"

View File

@@ -1,15 +0,0 @@
from enum import EnumMeta
class MetaEnum(EnumMeta):
"""Metaclass to support additional features in Enums.
- `in` operator support: `'value' in MyEnum -> bool`
"""
def __contains__(cls, item):
try:
cls(item)
except ValueError:
return False
return True
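Example of the containment check this metaclass enables:

from enum import Enum

class Color(str, Enum, metaclass=MetaEnum):
    RED = "red"

assert "red" in Color
assert "blue" not in Color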

View File

@@ -6,14 +6,6 @@ def get_timestamp():
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
def get_iso_timestamp() -> str:
return datetime.datetime.utcnow().isoformat()
def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
return datetime.datetime.fromisoformat(iso_timestamp)
SEED_MAX = np.iinfo(np.int32).max
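A quick round-trip of the two ISO helpers (SEED_MAX evaluates to 2147483647, the int32 maximum):

ts = get_iso_timestamp()  # e.g. "2023-05-17T00:42:27.123456"
assert get_datetime_from_iso_timestamp(ts).isoformat() == ts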

View File

@@ -1,5 +1,5 @@
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.models.exceptions import CanceledException
from invokeai.app.models.image import ProgressImage
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator

View File

@@ -1,6 +1,7 @@
"""
Initialization file for invokeai.backend
"""
from .generate import Generate
from .generator import (
InvokeAIGeneratorBasicParams,
InvokeAIGenerator,
@@ -11,3 +12,5 @@ from .generator import (
)
from .model_management import ModelManager, SDModelComponent
from .safety_checker import SafetyChecker
from .args import Args
from .globals import Globals

1391
invokeai/backend/args.py Normal file

File diff suppressed because it is too large

View File

@@ -12,17 +12,17 @@ print("Loading Python libraries...\n",file=sys.stderr)
import argparse
import io
import os
import re
import shutil
import textwrap
import traceback
import warnings
from argparse import Namespace
from pathlib import Path
from shutil import get_terminal_size
from typing import get_type_hints
from urllib import request
import npyscreen
import torch
import transformers
from diffusers import AutoencoderKL
from huggingface_hub import HfFolder
@@ -35,53 +35,57 @@ from transformers import (
CLIPTextModel,
CLIPTokenizer,
)
import invokeai.configs as configs
from invokeai.app.services.config import (
InvokeAIAppConfig,
)
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.widgets import (
from ...frontend.install.model_install import addModelsForm, process_and_execute
from ...frontend.install.widgets import (
CenteredButtonPress,
IntTitleSlider,
set_min_terminal_size,
CyclingForm,
MIN_COLS,
MIN_LINES,
)
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
from invokeai.backend.install.model_install_backend import (
from ..args import PRECISION_CHOICES, Args
from ..globals import Globals, global_cache_dir, global_config_dir, global_config_file
from .model_install_backend import (
default_dataset,
download_from_hf,
hf_download_with_resume,
recommended_datasets,
UserSelections,
)
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
# --------------------------globals-----------------------
config = InvokeAIAppConfig.get_config()
Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"
Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path
# the initial "configs" dir is now bundled in the `invokeai.configs` package
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
PRECISION_CHOICES = ['auto','float16','float32','autocast']
Default_config_file = Path(global_config_dir()) / "models.yaml"
SD_Configs = Path(global_config_dir()) / "stable-diffusion"
Datasets = OmegaConf.load(Dataset_path)
# minimum size for the UI
MIN_COLS = 135
MIN_LINES = 45
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# This is the InvokeAI initialization file, which contains command-line default values.
# Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting
# or renaming it and then running invokeai-configure again.
# Place frequently-used startup commands here, one or more per line.
# Examples:
# --outdir=D:\data\images
# --no-nsfw_checker
# --web --host=0.0.0.0
# --steps=20
# -Ak_euler_a -C10.0
"""
logger=None
# --------------------------------------------
def postscript(errors: None):
@@ -92,13 +96,14 @@ If you installed manually from source or with 'pip install': activate the virtua
then run one of the following commands to start InvokeAI.
Web UI:
invokeai-web
invokeai --web # (connect to http://localhost:9090)
invokeai --web --host 0.0.0.0 # (connect to http://your-lan-ip:9090 from another computer on the local network)
Command-line client:
Command-line interface:
invokeai
If you installed using an installation script, run:
{config.root_path}/invoke.{"bat" if sys.platform == "win32" else "sh"}
{Globals.root}/invoke.{"bat" if sys.platform == "win32" else "sh"}
Add the '--help' argument to see all of the command-line switches available for use.
"""
@@ -210,11 +215,16 @@ def download_realesrgan():
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
model_dest = config.root_path / "models/realesrgan/realesr-general-x4v3.pth"
wdn_model_dest = config.root_path / "models/realesrgan/realesr-general-wdn-x4v3.pth"
model_dest = os.path.join(
Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
)
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
wdn_model_dest = os.path.join(
Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
)
download_with_progress_bar(model_url, model_dest, "RealESRGAN")
download_with_progress_bar(wdn_model_url, wdn_model_dest, "RealESRGANwdn")
def download_gfpgan():
@@ -233,8 +243,8 @@ def download_gfpgan():
"./models/gfpgan/weights/parsing_parsenet.pth",
],
):
model_url, model_dest = model[0], config.root_path / model[1]
download_with_progress_bar(model_url, str(model_dest), "GFPGAN weights")
model_url, model_dest = model[0], os.path.join(Globals.root, model[1])
download_with_progress_bar(model_url, model_dest, "GFPGAN weights")
# ---------------------------------------------
@@ -243,8 +253,8 @@ def download_codeformer():
model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
)
model_dest = config.root_path / "models/codeformer/codeformer.pth"
download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
model_dest = os.path.join(Globals.root, "models/codeformer/codeformer.pth")
download_with_progress_bar(model_url, model_dest, "CodeFormer")
# ---------------------------------------------
@@ -285,7 +295,7 @@ def download_vaes():
# first the diffusers version
repo_id = "stabilityai/sd-vae-ft-mse"
args = dict(
cache_dir=config.cache_dir,
cache_dir=global_cache_dir("hub"),
)
if not AutoencoderKL.from_pretrained(repo_id, **args):
raise Exception(f"download of {repo_id} failed")
@@ -296,7 +306,7 @@ def download_vaes():
if not hf_download_with_resume(
repo_id=repo_id,
model_name=model_name,
model_dir=str(config.root_path / Model_dir / Weights_dir),
model_dir=str(Globals.root / Model_dir / Weights_dir),
):
raise Exception(f"download of {model_name} failed")
except Exception as e:
@@ -311,24 +321,25 @@ def get_root(root: str = None) -> str:
elif os.environ.get("INVOKEAI_ROOT"):
return os.environ.get("INVOKEAI_ROOT")
else:
return str(config.root_path)
return Globals.root
# -------------------------------------
class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
class editOptsForm(npyscreen.FormMultiPage):
# for responsive resizing - disabled
# FIX_MINIMUM_SIZE_WHEN_CREATED = False
def create(self):
program_opts = self.parentApp.program_opts
old_opts = self.parentApp.invokeai_opts
first_time = not (config.root_path / 'invokeai.yaml').exists()
first_time = not (Globals.root / Globals.initfile).exists()
access_token = HfFolder.get_token()
window_width, window_height = get_terminal_size()
label = """Configure startup settings. You can come back and change these later.
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
Use cursor arrows to make a checkbox selection, and space to toggle.
"""
for i in textwrap.wrap(label,width=window_width-6):
for i in [
"Configure startup settings. You can come back and change these later.",
"Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.",
"Use cursor arrows to make a checkbox selection, and space to toggle.",
]:
self.add_widget_intelligent(
npyscreen.FixedText,
value=i,
@@ -355,7 +366,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
self.outdir = self.add_widget_intelligent(
npyscreen.TitleFilename,
name="(<tab> autocompletes, ctrl-N advances):",
value=str(default_output_dir()),
value=old_opts.outdir or str(default_output_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
@@ -370,21 +381,22 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
editable=False,
color="CONTROL",
)
self.nsfw_checker = self.add_widget_intelligent(
self.safety_checker = self.add_widget_intelligent(
npyscreen.Checkbox,
name="NSFW checker",
value=old_opts.nsfw_checker,
value=old_opts.safety_checker,
relx=5,
scroll_exit=True,
)
self.nextrely += 1
label = """If you have an account at HuggingFace you may optionally paste your access token here
to allow InvokeAI to download restricted styles & subjects from the "Concept Library". See https://huggingface.co/settings/tokens.
"""
for line in textwrap.wrap(label,width=window_width-6):
for i in [
"If you have an account at HuggingFace you may paste your access token here",
'to allow InvokeAI to download styles & subjects from the "Concept Library".',
"See https://huggingface.co/settings/tokens",
]:
self.add_widget_intelligent(
npyscreen.FixedText,
value=line,
value=i,
editable=False,
color="CONTROL",
)
@@ -423,10 +435,17 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
relx=5,
scroll_exit=True,
)
self.xformers_enabled = self.add_widget_intelligent(
self.xformers = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Enable xformers support if available",
value=old_opts.xformers_enabled,
value=old_opts.xformers,
relx=5,
scroll_exit=True,
)
self.ckpt_convert = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Load legacy checkpoint models into memory as diffusers models",
value=old_opts.ckpt_convert,
relx=5,
scroll_exit=True,
)
@@ -461,41 +480,19 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Directories containing textual inversion, controlnet and LoRA models (<tab> autocompletes, ctrl-N advances):",
value="Directory containing embedding/textual inversion files:",
editable=False,
color="CONTROL",
)
self.embedding_dir = self.add_widget_intelligent(
self.embedding_path = self.add_widget_intelligent(
npyscreen.TitleFilename,
name=" Textual Inversion Embeddings:",
name="(<tab> autocompletes, ctrl-N advances):",
value=str(default_embedding_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=32,
scroll_exit=True,
)
self.lora_dir = self.add_widget_intelligent(
npyscreen.TitleFilename,
name=" LoRA and LyCORIS:",
value=str(default_lora_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=32,
scroll_exit=True,
)
self.controlnet_dir = self.add_widget_intelligent(
npyscreen.TitleFilename,
name=" ControlNets:",
value=str(default_controlnet_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=32,
begin_entry_at=40,
scroll_exit=True,
)
self.nextrely += 1
@@ -508,11 +505,11 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
scroll_exit=True,
)
self.nextrely -= 1
label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ
AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSE LOCATED AT
https://huggingface.co/spaces/CompVis/stable-diffusion-license
"""
for i in textwrap.wrap(label,width=window_width-6):
for i in [
"BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ",
"AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSE LOCATED AT",
"https://huggingface.co/spaces/CompVis/stable-diffusion-license",
]:
self.add_widget_intelligent(
npyscreen.FixedText,
value=i,
@@ -551,7 +548,7 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
self.editing = False
else:
self.editing = True
def validate_field_values(self, opt: Namespace) -> bool:
bad_fields = []
if not opt.license_acceptance:
@@ -562,9 +559,9 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
bad_fields.append(
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
)
if not Path(opt.embedding_dir).parent.exists():
if not Path(opt.embedding_path).parent.exists():
bad_fields.append(
f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_dir).parent)} is an existing directory."
f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_path).parent)} is an existing directory."
)
if len(bad_fields) > 0:
message = "The following problems were detected and must be corrected:\n"
@@ -579,24 +576,20 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
new_opts = Namespace()
for attr in [
"outdir",
"nsfw_checker",
"free_gpu_mem",
"max_loaded_models",
"xformers_enabled",
"always_use_cpu",
"embedding_dir",
"lora_dir",
"controlnet_dir",
"outdir",
"safety_checker",
"free_gpu_mem",
"max_loaded_models",
"xformers",
"always_use_cpu",
"embedding_path",
"ckpt_convert",
]:
setattr(new_opts, attr, getattr(self, attr).value)
new_opts.hf_token = self.hf_token.value
new_opts.license_acceptance = self.license_acceptance.value
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
# widget library workaround to make max_loaded_models an int rather than a float
new_opts.max_loaded_models = int(new_opts.max_loaded_models)
return new_opts
@@ -615,7 +608,6 @@ class EditOptApplication(npyscreen.NPSAppManaged):
"MAIN",
editOptsForm,
name="InvokeAI Startup Options",
cycle_widgets=True,
)
if not (self.program_opts.skip_sd_weights or self.program_opts.default_only):
self.model_select = self.addForm(
@@ -623,7 +615,6 @@ class EditOptApplication(npyscreen.NPSAppManaged):
addModelsForm,
name="Install Stable Diffusion Models",
multipage=True,
cycle_widgets=True,
)
def new_opts(self):
@@ -637,14 +628,18 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
def default_startup_options(init_file: Path) -> Namespace:
opts = InvokeAIAppConfig.get_config()
opts = Args().parse_args([])
outdir = Path(opts.outdir)
if not outdir.is_absolute():
opts.outdir = str(Globals.root / opts.outdir)
if not init_file.exists():
opts.nsfw_checker = True
opts.safety_checker = True
return opts
def default_user_selections(program_opts: Namespace) -> UserSelections:
return UserSelections(
install_models=default_dataset()
def default_user_selections(program_opts: Namespace) -> Namespace:
return Namespace(
starter_models=default_dataset()
if program_opts.default_only
else recommended_datasets()
if program_opts.yes_to_all
@@ -652,27 +647,26 @@ def default_user_selections(program_opts: Namespace) -> UserSelections:
purge_deleted_models=False,
scan_directory=None,
autoscan_on_startup=None,
import_model_paths=None,
convert_to_diffusers=None,
)
# -------------------------------------
def initialize_rootdir(root: Path, yes_to_all: bool = False):
def initialize_rootdir(root: str, yes_to_all: bool = False):
print("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
for name in (
"models",
"configs",
"embeddings",
"databases",
"loras",
"controlnets",
"text-inversion-output",
"text-inversion-training-data",
"models",
"configs",
"embeddings",
"text-inversion-output",
"text-inversion-training-data",
):
os.makedirs(os.path.join(root, name), exist_ok=True)
configs_src = Path(configs.__path__[0])
configs_dest = root / "configs"
configs_dest = Path(root) / "configs"
if not os.path.samefile(configs_src, configs_dest):
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
@@ -683,17 +677,8 @@ def run_console_ui(
) -> (Namespace, Namespace):
# parse_args() will read from init file if present
invokeai_opts = default_startup_options(initfile)
invokeai_opts.root = program_opts.root
# The third argument is needed in the Windows 11 environment to
# launch a console window running this program.
set_min_terminal_size(MIN_COLS, MIN_LINES,'invokeai-configure')
# the install-models application spawns a subprocess to install
# models, and will crash unless this is set before running.
import torch
torch.multiprocessing.set_start_method("spawn")
set_min_terminal_size(MIN_COLS, MIN_LINES)
editApp = EditOptApplication(program_opts, invokeai_opts)
editApp.run()
if editApp.user_cancelled:
@@ -705,66 +690,70 @@ def run_console_ui(
# -------------------------------------
def write_opts(opts: Namespace, init_file: Path):
"""
Update the invokeai.yaml file with values from current settings.
Update the invokeai.init file with values from opts Namespace
"""
# this will load current settings
new_config = InvokeAIAppConfig.get_config()
new_config.root = config.root
for key,value in opts.__dict__.items():
if hasattr(new_config,key):
setattr(new_config,key,value)
# touch file if it doesn't exist
if not init_file.exists():
with open(init_file, "w") as f:
f.write(INIT_FILE_PREAMBLE)
# We want to write in the changed arguments without clobbering
# any other initialization values the user has entered. There is
# no good way to do this because of the one-way nature of
# argparse: i.e. --outdir could be --outdir, --out, or -o
# initfile needs to be replaced with a fully structured format
# such as yaml; this is a hack that will work much of the time
args_to_skip = re.compile(
"^--?(o|out|no-xformer|xformer|no-ckpt|ckpt|free|no-nsfw|nsfw|prec|max_load|embed|always|ckpt|free_gpu)"
)
# fix windows paths
opts.outdir = opts.outdir.replace("\\", "/")
opts.embedding_path = opts.embedding_path.replace("\\", "/")
new_file = f"{init_file}.new"
try:
lines = [x.strip() for x in open(init_file, "r").readlines()]
with open(new_file, "w") as out_file:
for line in lines:
if len(line) > 0 and not args_to_skip.match(line):
out_file.write(line + "\n")
out_file.write(
f"""
--outdir={opts.outdir}
--embedding_path={opts.embedding_path}
--precision={opts.precision}
--max_loaded_models={int(opts.max_loaded_models)}
--{'no-' if not opts.safety_checker else ''}nsfw_checker
--{'no-' if not opts.xformers else ''}xformers
--{'no-' if not opts.ckpt_convert else ''}ckpt_convert
{'--free_gpu_mem' if opts.free_gpu_mem else ''}
{'--always_use_cpu' if opts.always_use_cpu else ''}
"""
)
except OSError as e:
print(f"** An error occurred while writing the init file: {str(e)}")
os.replace(new_file, init_file)
if opts.hf_token:
HfLogin(opts.hf_token)
with open(init_file,'w', encoding='utf-8') as file:
file.write(new_config.to_yaml())
# -------------------------------------
def default_output_dir() -> Path:
return config.root_path / "outputs"
return Globals.root / "outputs"
# -------------------------------------
def default_embedding_dir() -> Path:
return config.root_path / "embeddings"
return Globals.root / "embeddings"
# -------------------------------------
def default_lora_dir() -> Path:
return config.root_path / "loras"
# -------------------------------------
def default_controlnet_dir() -> Path:
return config.root_path / "controlnets"
# -------------------------------------
def write_default_options(program_opts: Namespace, initfile: Path):
opt = default_startup_options(initfile)
opt.hf_token = HfFolder.get_token()
write_opts(opt, initfile)
# -------------------------------------
# Here we bring in the legacy Args object in order to parse the old init
# file and write out the new yaml format.
def migrate_init_file(legacy_format:Path):
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
new = InvokeAIAppConfig.get_config()
fields = list(get_type_hints(InvokeAIAppConfig).keys())
for attr in fields:
if hasattr(old,attr):
setattr(new,attr,getattr(old,attr))
# a few places where the field names have changed and we have to
# manually add in the new names/values
new.nsfw_checker = old.safety_checker
new.xformers_enabled = old.xformers
new.conf_path = old.conf
new.embedding_dir = old.embedding_path
invokeai_yaml = legacy_format.parent / 'invokeai.yaml'
with open(invokeai_yaml,"w", encoding="utf-8") as outfile:
outfile.write(new.to_yaml())
legacy_format.replace(legacy_format.parent / 'invokeai.init.old')
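To make the migration concrete, a hypothetical legacy invokeai.init containing

--outdir=/data/images
--no-nsfw_checker
--xformers

is parsed by legacy_parser and re-emitted as invokeai.yaml with the renamed fields handled above, roughly:

outdir: /data/images
nsfw_checker: false
xformers_enabled: true

with the original file preserved alongside as invokeai.init.old. The exact yaml layout depends on InvokeAIAppConfig.to_yaml().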
# -------------------------------------
def main():
@@ -820,13 +809,8 @@ def main():
)
opt = parser.parse_args()
invoke_args = []
if opt.root:
invoke_args.extend(['--root',opt.root])
if opt.full_precision:
invoke_args.extend(['--precision','float32'])
config.parse_args(invoke_args)
logger = InvokeAILogger().getLogger(config=config)
# setting a global here
Globals.root = Path(os.path.expanduser(get_root(opt.root) or ""))
errors = set()
@@ -834,26 +818,19 @@ def main():
models_to_download = default_user_selections(opt)
# We check for to see if the runtime directory is correctly initialized.
old_init_file = config.root_path / 'invokeai.init'
new_init_file = config.root_path / 'invokeai.yaml'
if old_init_file.exists() and not new_init_file.exists():
print('** Migrating invokeai.init to invokeai.yaml')
migrate_init_file(old_init_file)
# Load new init file into config
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
if not config.model_conf_path.exists():
initialize_rootdir(config.root_path, opt.yes_to_all)
init_file = Path(Globals.root, Globals.initfile)
if not init_file.exists() or not global_config_file().exists():
initialize_rootdir(Globals.root, opt.yes_to_all)
if opt.yes_to_all:
write_default_options(opt, new_init_file)
write_default_options(opt, init_file)
init_options = Namespace(
precision="float32" if opt.full_precision else "float16"
)
else:
init_options, models_to_download = run_console_ui(opt, new_init_file)
init_options, models_to_download = run_console_ui(opt, init_file)
if init_options:
write_opts(init_options, new_init_file)
write_opts(init_options, init_file)
else:
print(
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
@@ -863,7 +840,7 @@ def main():
if opt.skip_support_models:
print("\n** SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST **")
else:
print("\n** CHECKING/UPDATING SUPPORT MODELS **")
print("\n** DOWNLOADING SUPPORT MODELS **")
download_bert()
download_sd1_clip()
download_sd2_clip()
@@ -881,8 +858,6 @@ def main():
process_and_execute(opt, models_to_download)
postscript(errors=errors)
if not opt.yes_to_all:
input('Press any key to continue...')
except KeyboardInterrupt:
print("\nGoodbye! Come back soon.")

View File

@@ -6,30 +6,26 @@ import re
import shutil
import sys
import warnings
from dataclasses import dataclass,field
from pathlib import Path
from tempfile import TemporaryFile
from typing import List, Dict, Callable
from typing import List
import requests
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_url, HfFolder
from huggingface_hub import hf_hub_url
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from ..globals import Globals, global_cache_dir, global_config_dir
from ..model_management import ModelManager
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util.logging import InvokeAILogger
warnings.filterwarnings("ignore")
# --------------------------globals-----------------------
config = InvokeAIAppConfig.get_config()
Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"
@@ -39,9 +35,6 @@ Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
# initial models omegaconf
Datasets = None
# logger
logger = InvokeAILogger.getLogger(name='InvokeAI')
Config_preamble = """
# This file describes the alternative machine learning models
# available to the InvokeAI script.
@@ -52,98 +45,58 @@ Config_preamble = """
# was trained on.
"""
@dataclass
class ModelInstallList:
'''Class for listing models to be installed/removed'''
install_models: List[str] = field(default_factory=list)
remove_models: List[str] = field(default_factory=list)
@dataclass
class UserSelections():
install_models: List[str] = field(default_factory=list)
remove_models: List[str] = field(default_factory=list)
purge_deleted_models: bool = False
install_cn_models: List[str] = field(default_factory=list)
remove_cn_models: List[str] = field(default_factory=list)
install_lora_models: List[str] = field(default_factory=list)
remove_lora_models: List[str] = field(default_factory=list)
install_ti_models: List[str] = field(default_factory=list)
remove_ti_models: List[str] = field(default_factory=list)
scan_directory: Path = None
autoscan_on_startup: bool = False
import_model_paths: str = None
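For illustration, a minimal sketch of how a UserSelections instance might be populated by the console UI; the model names here are hypothetical, not taken from this diff:
selections = UserSelections(
    install_models=["stable-diffusion-1.5"],               # hypothetical starter model name
    install_cn_models=["lllyasviel/sd-controlnet-canny"],  # hypothetical ControlNet repo_id
    autoscan_on_startup=False,
)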
def default_config_file():
return config.model_conf_path
return Path(global_config_dir()) / "models.yaml"
def sd_configs():
return config.legacy_conf_path
return Path(global_config_dir()) / "stable-diffusion"
def initial_models():
global Datasets
if Datasets:
return Datasets
return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
return (Datasets := OmegaConf.load(Dataset_path))
def install_requested_models(
diffusers: ModelInstallList = None,
controlnet: ModelInstallList = None,
lora: ModelInstallList = None,
ti: ModelInstallList = None,
cn_model_map: Dict[str,str] = None, # temporary - move to model manager
scan_directory: Path = None,
external_models: List[str] = None,
scan_at_startup: bool = False,
precision: str = "float16",
purge_deleted: bool = False,
config_file_path: Path = None,
model_config_file_callback: Callable[[Path],Path] = None
install_initial_models: List[str] = None,
remove_models: List[str] = None,
scan_directory: Path = None,
external_models: List[str] = None,
scan_at_startup: bool = False,
precision: str = "float16",
purge_deleted: bool = False,
config_file_path: Path = None,
):
"""
Entry point for installing/deleting starter models, or installing external models.
"""
access_token = HfFolder.get_token()
config_file_path = config_file_path or default_config_file()
if not config_file_path.exists():
open(config_file_path, "w").close()  # create an empty config file
# prevent circular import here
from ..model_management import ModelManager
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
if controlnet:
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
model_manager.delete_controlnet_models(controlnet.remove_models)
if lora:
model_manager.install_lora_models(lora.install_models, access_token=access_token)
model_manager.delete_lora_models(lora.remove_models)
if remove_models and len(remove_models) > 0:
print("== DELETING UNCHECKED STARTER MODELS ==")
for model in remove_models:
print(f"{model}...")
model_manager.del_model(model, delete_files=purge_deleted)
model_manager.commit(config_file_path)
if ti:
model_manager.install_ti_models(ti.install_models, access_token=access_token)
model_manager.delete_ti_models(ti.remove_models)
if diffusers:
# TODO: Replace next three paragraphs with calls into new model manager
if diffusers.remove_models and len(diffusers.remove_models) > 0:
logger.info("Processing requested deletions")
for model in diffusers.remove_models:
logger.info(f"{model}...")
model_manager.del_model(model, delete_files=purge_deleted)
model_manager.commit(config_file_path)
if diffusers.install_models and len(diffusers.install_models) > 0:
logger.info("Installing requested models")
downloaded_paths = download_weight_datasets(
models=diffusers.install_models,
access_token=None,
precision=precision,
)
successful = {x:v for x,v in downloaded_paths.items() if v is not None}
if len(successful) > 0:
update_config_file(successful, config_file_path)
if len(successful) < len(diffusers.install_models):
unsuccessful = [x for x in downloaded_paths if downloaded_paths[x] is None]
logger.warning(f"Some of the model downloads were not successful: {unsuccessful}")
if install_initial_models and len(install_initial_models) > 0:
print("== INSTALLING SELECTED STARTER MODELS ==")
successfully_downloaded = download_weight_datasets(
models=install_initial_models,
access_token=None,
precision=precision,
) # FIX: for historical reasons, we don't use model manager here
update_config_file(successfully_downloaded, config_file_path)
if len(successfully_downloaded) < len(install_initial_models):
print("** Some of the model downloads were not successful")
# due to above, we have to reload the model manager because conf file
# was changed behind its back
@@ -154,14 +107,12 @@ def install_requested_models(
external_models.append(str(scan_directory))
if len(external_models) > 0:
logger.info("INSTALLING EXTERNAL MODELS")
print("== INSTALLING EXTERNAL MODELS ==")
for path_url_or_repo in external_models:
try:
logger.debug(f'In install_requested_models; callback = {model_config_file_callback}')
model_manager.heuristic_import(
path_url_or_repo,
commit_to_conf=config_file_path,
config_file_callback = model_config_file_callback,
)
except KeyboardInterrupt:
sys.exit(-1)
@@ -169,22 +120,17 @@ def install_requested_models(
pass
if scan_at_startup and scan_directory.is_dir():
update_autoconvert_dir(scan_directory)
else:
update_autoconvert_dir(None)
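A hedged usage sketch of the entry point above, assuming the ModelInstallList-based signature shown on one side of this hunk; the model and LoRA names are illustrative:
install_requested_models(
    diffusers=ModelInstallList(install_models=["stable-diffusion-1.5"]),  # hypothetical
    lora=ModelInstallList(remove_models=["old-lora"]),                    # hypothetical
    precision="float16",
    config_file_path=default_config_file(),
)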
def update_autoconvert_dir(autodir: Path):
'''
Update the "autoconvert_dir" option in invokeai.yaml
'''
invokeai_config_path = config.init_file_path
conf = OmegaConf.load(invokeai_config_path)
conf.InvokeAI.Paths.autoconvert_dir = str(autodir) if autodir else None
yaml = OmegaConf.to_yaml(conf)
tmpfile = invokeai_config_path.parent / "new_config.tmp"
with open(tmpfile, "w", encoding="utf-8") as outfile:
outfile.write(yaml)
tmpfile.replace(invokeai_config_path)
argument = "--autoconvert"
initfile = Path(Globals.root, Globals.initfile)
replacement = Path(Globals.root, f"{Globals.initfile}.new")
directory = str(scan_directory).replace("\\", "/")
with open(initfile, "r") as input:
with open(replacement, "w") as output:
while line := input.readline():
if not line.startswith(argument):
output.writelines([line])
output.writelines([f"{argument} {directory}"])
os.replace(replacement, initfile)
# -------------------------------------
@@ -196,21 +142,33 @@ def yes_or_no(prompt: str, default_yes=True):
else:
return response[0] in ("y", "Y")
# -------------------------------------
def get_root(root: str = None) -> str:
if root:
return root
elif os.environ.get("INVOKEAI_ROOT"):
return os.environ.get("INVOKEAI_ROOT")
else:
return Globals.root
# ---------------------------------------------
def recommended_datasets() -> List['str']:
datasets = set()
def recommended_datasets() -> dict:
datasets = dict()
for ds in initial_models().keys():
if initial_models()[ds].get("recommended", False):
datasets.add(ds)
return list(datasets)
datasets[ds] = True
return datasets
# ---------------------------------------------
def default_dataset() -> dict:
datasets = set()
datasets = dict()
for ds in initial_models().keys():
if initial_models()[ds].get("default", False):
datasets.add(ds)
return list(datasets)
datasets[ds] = True
return datasets
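To make the two return conventions in this hunk concrete, an illustrative comparison, assuming a hypothetical starter model marked recommended in INITIAL_MODELS.yaml:
# list form:  recommended_datasets() -> ["stable-diffusion-1.5"]
# dict form:  recommended_datasets() -> {"stable-diffusion-1.5": True}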
# ---------------------------------------------
@@ -225,14 +183,14 @@ def all_datasets() -> dict:
# look for legacy model.ckpt in models directory and offer to
# normalize its name
def migrate_models_ckpt():
model_path = os.path.join(config.root_dir, Model_dir, Weights_dir)
model_path = os.path.join(Globals.root, Model_dir, Weights_dir)
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
return
new_name = initial_models()["stable-diffusion-1.4"]["file"]
logger.warning(
print(
f'The Stable Diffusion v1.4 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
)
logger.warning(f"model.ckpt => {new_name}")
print(f"model.ckpt => {new_name}")
os.replace(
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
)
@@ -245,7 +203,7 @@ def download_weight_datasets(
migrate_models_ckpt()
successful = dict()
for mod in models:
logger.info(f"Downloading {mod}:")
print(f"Downloading {mod}:")
successful[mod] = _download_repo_or_file(
initial_models()[mod], access_token, precision=precision
)
@@ -266,10 +224,11 @@ def _download_repo_or_file(
)
return path
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
repo_id = mconfig["repo_id"]
filename = mconfig["file"]
cache_dir = os.path.join(config.root_dir, Model_dir, Weights_dir)
cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
return hf_download_with_resume(
repo_id=repo_id,
model_dir=cache_dir,
@@ -280,12 +239,9 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
# ---------------------------------------------
def download_from_hf(
model_class: object, model_name: str, **kwargs
model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs
):
logger = InvokeAILogger.getLogger('InvokeAI')
logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage())
path = config.cache_dir
path = global_cache_dir(cache_subdir)
model = model_class.from_pretrained(
model_name,
cache_dir=path,
@@ -316,10 +272,10 @@ def _download_diffusion_weights(
**extra_args,
)
except OSError as e:
if 'Revision Not Found' in str(e):
if str(e).startswith("fp16 is not a valid"):
pass
else:
logger.error(str(e))
print(f"An unexpected error occurred while downloading the model: {e})")
if path:
break
return path
@@ -327,13 +283,9 @@ def _download_diffusion_weights(
# ---------------------------------------------
def hf_download_with_resume(
repo_id: str,
model_dir: str,
model_name: str,
model_dest: Path = None,
access_token: str = None,
repo_id: str, model_dir: str, model_name: str, access_token: str = None
) -> Path:
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
model_dest = Path(os.path.join(model_dir, model_name))
os.makedirs(model_dir, exist_ok=True)
url = hf_hub_url(repo_id, model_name)
@@ -353,19 +305,20 @@ def hf_download_with_resume(
if (
resp.status_code == 416
): # "range not satisfiable", which means nothing to return
logger.info(f"{model_name}: complete file found. Skipping.")
print(f"* {model_name}: complete file found. Skipping.")
return model_dest
elif resp.status_code == 404:
logger.warning("File not found")
return None
elif resp.status_code != 200:
logger.warning(f"{model_name}: {resp.reason}")
print(f"** An error occurred during downloading {model_name}: {resp.reason}")
elif exist_size > 0:
logger.info(f"{model_name}: partial file found. Resuming...")
print(f"* {model_name}: partial file found. Resuming...")
else:
logger.info(f"{model_name}: Downloading...")
print(f"* {model_name}: Downloading...")
try:
if total < 2000:
print(f"*** ERROR DOWNLOADING {model_name}: {resp.text}")
return None
with open(model_dest, open_mode) as file, tqdm(
desc=model_name,
initial=exist_size,
@@ -378,7 +331,7 @@ def hf_download_with_resume(
size = file.write(data)
bar.update(size)
except Exception as e:
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
print(f"An error occurred while downloading {model_name}: {str(e)}")
return None
return model_dest
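A usage sketch for the download helper above, with an illustrative repo_id and filename (not taken from this diff). On a partial file the helper resumes from exist_size and appends to the existing bytes; an HTTP 416 response means the file is already complete:
dest = hf_download_with_resume(
    repo_id="stabilityai/sd-vae-ft-mse",              # illustrative repo_id
    model_dir="models/ldm/stable-diffusion-v1",
    model_name="diffusion_pytorch_model.safetensors",  # illustrative filename
)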
@@ -403,8 +356,8 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
try:
backup = None
if os.path.exists(config_file):
logger.warning(
f"{config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
print(
f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
)
backup = config_file.with_suffix(".yaml.orig")
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
@@ -421,16 +374,16 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
new_config.write(tmp.read())
except Exception as e:
logger.error(f"Error creating config file {config_file}: {str(e)}")
print(f"**Error creating config file {config_file}: {str(e)} **")
if backup is not None:
logger.info("restoring previous config file")
print("restoring previous config file")
## workaround, for WinError 183, see above
if sys.platform == "win32" and config_file.is_file():
config_file.unlink()
backup.rename(config_file)
return
logger.info(f"Successfully created new configuration file {config_file}")
print(f"Successfully created new configuration file {config_file}")
# ---------------------------------------------
@@ -464,7 +417,7 @@ def new_config_file_contents(
stanza["height"] = mod["height"]
if "file" in mod:
stanza["weights"] = os.path.relpath(
successfully_downloaded[model], start=config.root_dir
successfully_downloaded[model], start=Globals.root
)
stanza["config"] = os.path.normpath(
os.path.join(sd_configs(), mod["config"])
@@ -497,14 +450,14 @@ def delete_weights(model_name: str, conf_stanza: dict):
if re.match("/VAE/", conf_stanza.get("config")):
return
logger.warning(
f"\nThe checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
print(
f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
)
weights = Path(weights)
if not weights.is_absolute():
weights = config.root_dir / weights
weights = Path(Globals.root) / weights
try:
weights.unlink()
except OSError as e:
logger.error(str(e))
print(str(e))

1234
invokeai/backend/generate.py Normal file

File diff suppressed because it is too large

View File

@@ -75,11 +75,9 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def __init__(self,
model_info: dict,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
**kwargs,
):
self.model_info=model_info
self.params=params
self.kwargs = kwargs
def generate(self,
prompt: str='',
@@ -120,12 +118,9 @@ class InvokeAIGenerator(metaclass=ABCMeta):
model=model,
scheduler_name=generator_args.get('scheduler')
)
# get conditioning from prompt via Compel package
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt, model=model)
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
gen_class = self._generator_class()
generator = gen_class(model, self.params.precision, **self.kwargs)
generator = gen_class(model, self.params.precision)
if self.params.variation_amount > 0:
generator.set_variation(generator_args.get('seed'),
generator_args.get('variation_amount'),
@@ -281,7 +276,7 @@ class Generator:
precision: str
model: DiffusionPipeline
def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
def __init__(self, model: DiffusionPipeline, precision: str):
self.model = model
self.precision = precision
self.seed = None

View File

@@ -196,7 +196,7 @@ class Inpaint(Img2Img):
seam_noise = self.get_noise(im.width, im.height)
result = make_image(seam_noise, seed=None)
result = make_image(seam_noise, seed)
return result

View File

@@ -4,10 +4,6 @@ invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
import PIL.Image
import torch
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from ..stable_diffusion import (
ConditioningData,
PostprocessingSettings,
@@ -17,13 +13,8 @@ from .base import Generator
class Txt2Img(Generator):
def __init__(self, model, precision,
control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
**kwargs):
self.control_model = control_model
if isinstance(self.control_model, list):
self.control_model = MultiControlNetModel(self.control_model)
super().__init__(model, precision, **kwargs)
def __init__(self, model, precision):
super().__init__(model, precision)
@torch.no_grad()
def get_make_image(
@@ -51,12 +42,9 @@ class Txt2Img(Generator):
kwargs are 'width' and 'height'
"""
self.perlin = perlin
control_image = kwargs.get("control_image", None)
do_classifier_free_guidance = cfg_scale > 1.0
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.control_model = self.control_model
pipeline.scheduler = sampler
uc, c, extra_conditioning_info = conditioning
@@ -73,37 +61,6 @@ class Txt2Img(Generator):
),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
if control_image is not None:
if isinstance(self.control_model, ControlNetModel):
control_image = pipeline.prepare_control_image(
image=control_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=width,
height=height,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=self.control_model.device,
dtype=self.control_model.dtype,
)
elif isinstance(self.control_model, MultiControlNetModel):
images = []
for image_ in control_image:
image_ = self.model.prepare_control_image(
image=image_,
do_classifier_free_guidance=do_classifier_free_guidance,
width=width,
height=height,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=self.control_model.device,
dtype=self.control_model.dtype,
)
images.append(image_)
control_image = images
kwargs["control_image"] = control_image
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
pipeline_output = pipeline.image_from_embeddings(
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
@@ -111,7 +68,6 @@ class Txt2Img(Generator):
num_inference_steps=steps,
conditioning_data=conditioning_data,
callback=step_callback,
**kwargs,
)
if (

122
invokeai/backend/globals.py Normal file
View File

@@ -0,0 +1,122 @@
"""
invokeai.backend.globals defines a small number of global variables that would
otherwise have to be passed through long and complex call chains.
It defines a Namespace object named "Globals" that contains
the attributes:
- root - the root directory under which "models" and "outputs" can be found
- initfile - path to the initialization file
- try_patchmatch - option to globally disable loading of 'patchmatch' module
- always_use_cpu - force use of CPU even if GPU is available
"""
import os
import os.path as osp
from argparse import Namespace
from pathlib import Path
from typing import Union
Globals = Namespace()
# Where to look for the initialization file and other key components
Globals.initfile = "invokeai.init"
Globals.models_file = "models.yaml"
Globals.models_dir = "models"
Globals.config_dir = "configs"
Globals.autoscan_dir = "weights"
Globals.converted_ckpts_dir = "converted_ckpts"
# Set the default root directory. This can be overwritten by explicitly
# passing the `--root <directory>` argument on the command line.
# logic is:
# 1) use INVOKEAI_ROOT environment variable (no check for this being a valid directory)
# 2) use VIRTUAL_ENV environment variable, with a check for initfile being there
# 3) use ~/invokeai
if os.environ.get("INVOKEAI_ROOT"):
Globals.root = osp.abspath(os.environ.get("INVOKEAI_ROOT"))
elif (
os.environ.get("VIRTUAL_ENV")
and Path(os.environ.get("VIRTUAL_ENV"), "..", Globals.initfile).exists()
):
Globals.root = osp.abspath(osp.join(os.environ.get("VIRTUAL_ENV"), ".."))
else:
Globals.root = osp.abspath(osp.expanduser("~/invokeai"))
# Try loading patchmatch
Globals.try_patchmatch = True
# Use CPU even if GPU is available (main use case is for debugging MPS issues)
Globals.always_use_cpu = False
# Whether the internet is reachable for dynamic downloads
# The CLI will test connectivity at startup time.
Globals.internet_available = True
# Whether to disable xformers
Globals.disable_xformers = False
# Low-memory tradeoff for guidance calculations.
Globals.sequential_guidance = False
# whether we are forcing full precision
Globals.full_precision = False
# whether we should convert ckpt files into diffusers models on the fly
Globals.ckpt_convert = True
# logging tokenization everywhere
Globals.log_tokenization = False
def global_config_file() -> Path:
return Path(Globals.root, Globals.config_dir, Globals.models_file)
def global_config_dir() -> Path:
return Path(Globals.root, Globals.config_dir)
def global_models_dir() -> Path:
return Path(Globals.root, Globals.models_dir)
def global_autoscan_dir() -> Path:
return Path(Globals.root, Globals.autoscan_dir)
def global_converted_ckpts_dir() -> Path:
return Path(global_models_dir(), Globals.converted_ckpts_dir)
def global_set_root(root_dir: Union[str, Path]):
Globals.root = root_dir
def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
"""
Returns Path to the model cache directory. If a subdirectory
is provided, it will be appended to the end of the path, allowing
for Hugging Face-style conventions. Currently, Hugging Face has
moved all models into the "hub" subfolder, so for any pretrained
HF model, use:
global_cache_dir('hub')
The legacy location was global_cache_dir('transformers') for transformers
and global_cache_dir('diffusers') for diffusers.
"""
home: str = os.getenv("HF_HOME")
if home is None:
home = os.getenv("XDG_CACHE_HOME")
if home is not None:
# Set `home` to $XDG_CACHE_HOME/huggingface, which is the default location mentioned in Hugging Face Hub Client Library.
# See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome
home += os.sep + "huggingface"
if home is not None:
return Path(home, subdir)
else:
return Path(Globals.root, "models", subdir)
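A short sketch of the resolution order implemented above; the environment values are examples only:
# HF_HOME set:            HF_HOME=/data/hf        -> /data/hf/hub
# only XDG_CACHE_HOME:    XDG_CACHE_HOME=~/.cache -> ~/.cache/huggingface/hub
# neither set:                                    -> <Globals.root>/models/hub
print(global_cache_dir("hub"))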

View File

@@ -6,8 +6,7 @@ be suppressed or deferred
"""
import numpy as np
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
from invokeai.backend.globals import Globals
class PatchMatch:
"""
@@ -24,7 +23,7 @@ class PatchMatch:
def _load_patch_match(self):
if self.tried_load:
return
if config.try_patchmatch:
if Globals.try_patchmatch:
from patchmatch import patch_match as pm
if pm.patchmatch_available:

View File

@@ -33,11 +33,11 @@ from PIL import Image, ImageOps
from transformers import AutoProcessor, CLIPSegForImageSegmentation
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.globals import global_cache_dir
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
CLIPSEG_SIZE = 352
config = InvokeAIAppConfig.get_config()
class SegmentedGrayscale(object):
def __init__(self, image: Image, heatmap: torch.Tensor):
@@ -88,10 +88,10 @@ class Txt2Mask(object):
# BUG: we are not doing anything with the device option at this time
self.device = device
self.processor = AutoProcessor.from_pretrained(
CLIPSEG_MODEL, cache_dir=config.cache_dir
CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
)
self.model = CLIPSegForImageSegmentation.from_pretrained(
CLIPSEG_MODEL, cache_dir=config.cache_dir
CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
)
@torch.no_grad()

View File

@@ -1,390 +0,0 @@
# Copyright 2023 Lincoln D. Stein and the InvokeAI Team
import argparse
import shlex
from argparse import ArgumentParser
SAMPLER_CHOICES = [
"ddim",
"ddpm",
"deis",
"lms",
"pndm",
"heun",
"heun_k",
"euler",
"euler_k",
"euler_a",
"kdpm_2",
"kdpm_2_a",
"dpmpp_2s",
"dpmpp_2m",
"dpmpp_2m_k",
"unipc",
]
PRECISION_CHOICES = [
"auto",
"float32",
"autocast",
"float16",
]
class FileArgumentParser(ArgumentParser):
"""
Supports reading defaults from an init file.
"""
def convert_arg_line_to_args(self, arg_line):
return shlex.split(arg_line, comments=True)
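For context, how the init-file support is meant to be used: with fromfile_prefix_chars='@', argparse reads extra arguments from the named file, and convert_arg_line_to_args splits each line with shlex so quoting and # comments behave like a shell. An illustrative invocation:
# invokeai.init (one option per line, comments allowed):
#   --outdir "outputs/my images"   # quoted paths survive shlex.split
#   --steps 30
# then:  invoke.py @invokeai.init --web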
legacy_parser = FileArgumentParser(
description=
"""
Generate images using Stable Diffusion.
Use --web to launch the web interface.
Use --from_file to load prompts from a file path or standard input ("-").
Otherwise you will be dropped into an interactive command prompt (type -h for help).
Other command-line arguments are defaults that can usually be overridden
at the command prompt.
""",
fromfile_prefix_chars='@',
)
general_group = legacy_parser.add_argument_group('General')
model_group = legacy_parser.add_argument_group('Model selection')
file_group = legacy_parser.add_argument_group('Input/output')
web_server_group = legacy_parser.add_argument_group('Web server')
render_group = legacy_parser.add_argument_group('Rendering')
postprocessing_group = legacy_parser.add_argument_group('Postprocessing')
deprecated_group = legacy_parser.add_argument_group('Deprecated options')
deprecated_group.add_argument('--laion400m')
deprecated_group.add_argument('--weights') # deprecated
general_group.add_argument(
'--version','-V',
action='store_true',
help='Print InvokeAI version number'
)
model_group.add_argument(
'--root_dir',
default=None,
help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.',
)
model_group.add_argument(
'--config',
'-c',
'-config',
dest='conf',
default='./configs/models.yaml',
help='Path to configuration file for alternate models.',
)
model_group.add_argument(
'--model',
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
)
model_group.add_argument(
'--weight_dirs',
nargs='+',
type=str,
help='List of one or more directories that will be auto-scanned for new model weights to import',
)
model_group.add_argument(
'--png_compression','-z',
type=int,
default=6,
choices=range(0,9),
dest='png_compression',
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
)
model_group.add_argument(
'-F',
'--full_precision',
dest='full_precision',
action='store_true',
help='Deprecated way to set --precision=float32',
)
model_group.add_argument(
'--max_loaded_models',
dest='max_loaded_models',
type=int,
default=2,
help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
)
model_group.add_argument(
'--free_gpu_mem',
dest='free_gpu_mem',
action='store_true',
help='Force free gpu memory before final decoding',
)
model_group.add_argument(
'--sequential_guidance',
dest='sequential_guidance',
action='store_true',
help="Calculate guidance in serial instead of in parallel, lowering memory requirement "
"at the expense of speed",
)
model_group.add_argument(
'--xformers',
action=argparse.BooleanOptionalAction,
default=True,
help='Enable/disable xformers support (default enabled if installed)',
)
model_group.add_argument(
"--always_use_cpu",
dest="always_use_cpu",
action="store_true",
help="Force use of CPU even if GPU is available"
)
model_group.add_argument(
'--precision',
dest='precision',
type=str,
choices=PRECISION_CHOICES,
metavar='PRECISION',
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
default='auto',
)
model_group.add_argument(
'--ckpt_convert',
action=argparse.BooleanOptionalAction,
dest='ckpt_convert',
default=True,
help='Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.'
)
model_group.add_argument(
'--internet',
action=argparse.BooleanOptionalAction,
dest='internet_available',
default=True,
help='Indicate whether internet is available for just-in-time model downloading (default: probe automatically).',
)
model_group.add_argument(
'--nsfw_checker',
'--safety_checker',
action=argparse.BooleanOptionalAction,
dest='safety_checker',
default=False,
help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
)
model_group.add_argument(
'--autoimport',
default=None,
type=str,
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly',
)
model_group.add_argument(
'--autoconvert',
default=None,
type=str,
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models',
)
model_group.add_argument(
'--patchmatch',
action=argparse.BooleanOptionalAction,
default=True,
help='Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.',
)
file_group.add_argument(
'--from_file',
dest='infile',
type=str,
help='If specified, load prompts from this file',
)
file_group.add_argument(
'--outdir',
'-o',
type=str,
help='Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs',
default='outputs',
)
file_group.add_argument(
'--prompt_as_dir',
'-p',
action='store_true',
help='Place images in subdirectories named after the prompt.',
)
render_group.add_argument(
'--fnformat',
default='{prefix}.{seed}.png',
type=str,
help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png',
)
render_group.add_argument(
'-s',
'--steps',
type=int,
default=50,
help='Number of steps'
)
render_group.add_argument(
'-W',
'--width',
type=int,
help='Image width, multiple of 64',
)
render_group.add_argument(
'-H',
'--height',
type=int,
help='Image height, multiple of 64',
)
render_group.add_argument(
'-C',
'--cfg_scale',
default=7.5,
type=float,
help='Classifier-free guidance (CFG) scale - higher values make the output adhere more closely to the prompt.',
)
render_group.add_argument(
'--sampler',
'-A',
'-m',
dest='sampler_name',
type=str,
choices=SAMPLER_CHOICES,
metavar='SAMPLER_NAME',
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
default='k_lms',
)
render_group.add_argument(
'--log_tokenization',
'-t',
action='store_true',
help='shows how the prompt is split into tokens'
)
render_group.add_argument(
'-f',
'--strength',
type=float,
help='img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
)
render_group.add_argument(
'-T',
'-fit',
'--fit',
action=argparse.BooleanOptionalAction,
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
)
render_group.add_argument(
'--grid',
'-g',
action=argparse.BooleanOptionalAction,
help='generate a grid'
)
render_group.add_argument(
'--embedding_directory',
'--embedding_path',
dest='embedding_path',
default='embeddings',
type=str,
help='Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)'
)
render_group.add_argument(
'--lora_directory',
dest='lora_path',
default='loras',
type=str,
help='Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)'
)
render_group.add_argument(
'--embeddings',
action=argparse.BooleanOptionalAction,
default=True,
help='Enable embedding directory (default). Use --no-embeddings to disable.',
)
render_group.add_argument(
'--enable_image_debugging',
action='store_true',
help='Generates debugging image to display'
)
render_group.add_argument(
'--karras_max',
type=int,
default=None,
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]."
)
# Restoration related args
postprocessing_group.add_argument(
'--no_restore',
dest='restore',
action='store_false',
help='Disable face restoration with GFPGAN or codeformer',
)
postprocessing_group.add_argument(
'--no_upscale',
dest='esrgan',
action='store_false',
help='Disable upscaling with ESRGAN',
)
postprocessing_group.add_argument(
'--esrgan_bg_tile',
type=int,
default=400,
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
)
postprocessing_group.add_argument(
'--esrgan_denoise_str',
type=float,
default=0.75,
help='ESRGAN denoise strength, from 0 (no denoising) to 1 (maximum). Default: 0.75',
)
postprocessing_group.add_argument(
'--gfpgan_model_path',
type=str,
default='./models/gfpgan/GFPGANv1.4.pth',
help='Indicates the path to the GFPGAN model',
)
web_server_group.add_argument(
'--web',
dest='web',
action='store_true',
help='Start in web server mode.',
)
web_server_group.add_argument(
'--web_develop',
dest='web_develop',
action='store_true',
help='Start in web server development mode.',
)
web_server_group.add_argument(
"--web_verbose",
action="store_true",
help="Enables verbose logging",
)
web_server_group.add_argument(
"--cors",
nargs="*",
type=str,
help="Additional allowed origins, comma-separated",
)
web_server_group.add_argument(
'--host',
type=str,
default='127.0.0.1',
help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.'
)
web_server_group.add_argument(
'--port',
type=int,
default='9090',
help='Web server: Port to listen on'
)
web_server_group.add_argument(
'--certfile',
type=str,
default=None,
help='Web server: Path to certificate file to use for SSL. Use together with --keyfile'
)
web_server_group.add_argument(
'--keyfile',
type=str,
default=None,
help='Web server: Path to private key file to use for SSL. Use together with --certfile'
)
web_server_group.add_argument(
'--gui',
dest='gui',
action='store_true',
help='Start InvokeAI GUI',
)

View File

@@ -26,7 +26,7 @@ import torch
from safetensors.torch import load_file
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.globals import global_cache_dir, global_config_dir
from .model_manager import ModelManager, SDLegacyType
@@ -74,6 +74,7 @@ from transformers import (
from ..stable_diffusion import StableDiffusionGeneratorPipeline
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
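The body of shave_segments is elided from this hunk; a minimal sketch consistent with its docstring, assuming dotted state-dict key paths (positive counts drop leading segments, negative counts drop trailing ones):
def shave_segments(path, n_shave_prefix_segments=1):
    segments = path.split(".")
    if n_shave_prefix_segments >= 0:
        return ".".join(segments[n_shave_prefix_segments:])
    return ".".join(segments[:n_shave_prefix_segments])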
@@ -842,7 +843,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
def convert_ldm_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=InvokeAIAppConfig.get_config().cache_dir
"openai/clip-vit-large-patch14", cache_dir=global_cache_dir("hub")
)
keys = list(checkpoint.keys())
@@ -897,7 +898,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
def convert_paint_by_example_checkpoint(checkpoint):
cache_dir = InvokeAIAppConfig.get_config().cache_dir
cache_dir = global_cache_dir("hub")
config = CLIPVisionConfig.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=cache_dir
)
@@ -969,7 +970,7 @@ def convert_paint_by_example_checkpoint(checkpoint):
def convert_open_clip_checkpoint(checkpoint):
cache_dir = InvokeAIAppConfig.get_config().cache_dir
cache_dir = global_cache_dir("hub")
text_model = CLIPTextModel.from_pretrained(
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
)
@@ -1092,8 +1093,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
:param vae: A diffusers VAE to load into the pipeline.
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
"""
config = InvokeAIAppConfig.get_config()
cache_dir = config.cache_dir
with warnings.catch_warnings():
warnings.simplefilter("ignore")
@@ -1107,6 +1106,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
else:
checkpoint = load_file(checkpoint_path)
cache_dir = global_cache_dir("hub")
pipeline_class = (
StableDiffusionGeneratorPipeline
if return_generator_pipeline
@@ -1130,23 +1130,25 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
if model_type == SDLegacyType.V2_v:
original_config_file = (
config.legacy_conf_path / "v2-inference-v.yaml"
global_config_dir() / "stable-diffusion" / "v2-inference-v.yaml"
)
if global_step == 110000:
# v2.1 needs to upcast attention
upcast_attention = True
elif model_type == SDLegacyType.V2_e:
original_config_file = (
config.legacy_conf_path / "v2-inference.yaml"
global_config_dir() / "stable-diffusion" / "v2-inference.yaml"
)
elif model_type == SDLegacyType.V1_INPAINT:
original_config_file = (
config.legacy_conf_path / "v1-inpainting-inference.yaml"
global_config_dir()
/ "stable-diffusion"
/ "v1-inpainting-inference.yaml"
)
elif model_type == SDLegacyType.V1:
original_config_file = (
config.legacy_conf_path / "v1-inference.yaml"
global_config_dir() / "stable-diffusion" / "v1-inference.yaml"
)
else:
@@ -1298,7 +1300,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker",
cache_dir=cache_dir,
cache_dir=global_cache_dir("hub"),
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir

View File

@@ -11,16 +11,14 @@ import gc
import hashlib
import os
import re
import shutil
import sys
import textwrap
import time
import traceback
import warnings
from enum import Enum, auto
from pathlib import Path
from shutil import move, rmtree
from typing import Any, Optional, Union, Callable, Dict, List, types
from typing import Any, Optional, Union, Callable, types
import safetensors
import safetensors.torch
@@ -38,6 +36,8 @@ from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path
from invokeai.backend.globals import Globals, global_cache_dir
from transformers import (
CLIPTextModel,
CLIPTokenizer,
@@ -49,13 +49,9 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
from ..stable_diffusion import (
StableDiffusionGeneratorPipeline,
)
from invokeai.app.services.config import InvokeAIAppConfig
from ..install.model_install_backend import (
Dataset_path,
hf_download_with_resume,
)
from ..util import CUDA_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum):
V1 = auto()
V1_INPAINT = auto()
@@ -104,7 +100,6 @@ class ModelManager(object):
if not isinstance(config, DictConfig):
config = OmegaConf.load(config)
self.config = config
self.globals = InvokeAIAppConfig.get_config()
self.precision = precision
self.device = torch.device(device_type)
self.max_loaded_models = max_loaded_models
@@ -297,7 +292,7 @@ class ModelManager(object):
"""
# if we are converting legacy files automatically, then
# there are no legacy ckpts!
if self.globals.ckpt_convert:
if Globals.ckpt_convert:
return False
info = self.model_info(model_name)
if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
@@ -320,7 +315,7 @@ class ModelManager(object):
models = {}
for name in sorted(self.config, key=str.casefold):
stanza = self.config[name]
# don't include VAEs in listing (legacy style)
if "config" in stanza and "/VAE/" in stanza["config"]:
continue
@@ -507,13 +502,13 @@ class ModelManager(object):
# TODO: scan weights maybe?
pipeline_args: dict[str, Any] = dict(
safety_checker=None, local_files_only=not self.globals.internet_available
safety_checker=None, local_files_only=not Globals.internet_available
)
if "vae" in mconfig and mconfig["vae"] is not None:
if vae := self._load_vae(mconfig["vae"]):
pipeline_args.update(vae=vae)
if not isinstance(name_or_path, Path):
pipeline_args.update(cache_dir=self.globals.cache_dir)
pipeline_args.update(cache_dir=global_cache_dir("hub"))
if using_fp16:
pipeline_args.update(torch_dtype=torch.float16)
fp_args_list = [{"revision": "fp16"}, {}]
@@ -532,8 +527,7 @@ class ModelManager(object):
**fp_args,
)
except OSError as e:
if str(e).startswith("fp16 is not a valid") or \
'Invalid rev id: fp16' in str(e):
if str(e).startswith("fp16 is not a valid"):
pass
else:
self.logger.error(
@@ -566,9 +560,10 @@ class ModelManager(object):
width = mconfig.width
height = mconfig.height
root_dir = self.globals.root_dir
config = str(root_dir / config)
weights = str(root_dir / weights)
if not os.path.isabs(config):
config = os.path.join(Globals.root, config)
if not os.path.isabs(weights):
weights = os.path.normpath(os.path.join(Globals.root, weights))
# Convert to diffusers and return a diffusers pipeline
self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
@@ -583,7 +578,11 @@ class ModelManager(object):
vae_path = None
if vae:
vae_path = str(root_dir / vae)
vae_path = (
vae
if os.path.isabs(vae)
else os.path.normpath(os.path.join(Globals.root, vae))
)
if self._has_cuda():
torch.cuda.empty_cache()
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
@@ -615,7 +614,9 @@ class ModelManager(object):
)
if "path" in mconfig and mconfig["path"] is not None:
path = self.globals.root_dir / Path(mconfig["path"])
path = Path(mconfig["path"])
if not path.is_absolute():
path = Path(Globals.root, path).resolve()
return path
elif "repo_id" in mconfig:
return mconfig["repo_id"]
@@ -780,11 +781,11 @@ class ModelManager(object):
"""
model_path: Path = None
thing = str(path_url_or_repo) # to save typing
thing = path_url_or_repo # to save typing
self.logger.info(f"Probing {thing} for import")
if str(thing).startswith(("http:", "https:", "ftp:")):
if thing.startswith(("http:", "https:", "ftp:")):
self.logger.info(f"{thing} appears to be a URL")
model_path = self._resolve_path(
thing, "models/ldm/stable-diffusion-v1"
@@ -820,9 +821,7 @@ class ModelManager(object):
Path(thing).rglob("*.safetensors")
):
if model_name := self.heuristic_import(
str(m),
commit_to_conf=commit_to_conf,
config_file_callback=config_file_callback,
str(m), commit_to_conf=commit_to_conf
):
self.logger.info(f"{model_name} successfully imported")
return model_name
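A hedged usage sketch of heuristic_import, which accepts a local path, URL, or Hugging Face repo_id as shown above; the URL here is illustrative only:
manager = ModelManager(OmegaConf.load(default_config_file()), precision="float16")
name = manager.heuristic_import(
    "https://example.com/some-model.safetensors",  # illustrative URL
    commit_to_conf=default_config_file(),
)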
@@ -865,24 +864,35 @@ class ModelManager(object):
model_type = self.probe_model_type(checkpoint)
if model_type == SDLegacyType.V1:
self.logger.debug("SD-v1 model detected")
model_config_file = self.globals.legacy_conf_path / "v1-inference.yaml"
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
)
elif model_type == SDLegacyType.V1_INPAINT:
self.logger.debug("SD-v1 inpainting model detected")
model_config_file = self.globals.legacy_conf_path / "v1-inpainting-inference.yaml"
model_config_file = Path(
Globals.root,
"configs/stable-diffusion/v1-inpainting-inference.yaml",
)
elif model_type == SDLegacyType.V2_v:
self.logger.debug("SD-v2-v model detected")
model_config_file = self.globals.legacy_conf_path / "v2-inference-v.yaml"
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
)
elif model_type == SDLegacyType.V2_e:
self.logger.debug("SD-v2-e model detected")
model_config_file = self.globals.legacy_conf_path / "v2-inference.yaml"
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
)
elif model_type == SDLegacyType.V2:
self.logger.warning(
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined."
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
)
return
else:
self.logger.warning(
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model."
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
)
return
if not model_config_file and config_file_callback:
model_config_file = config_file_callback(model_path)
@@ -899,7 +909,9 @@ class ModelManager(object):
self.logger.debug(f"Using VAE file {vae_path.name}")
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
diffuser_path = self.globals.root_dir / "models/converted_ckpts" / model_path.stem
diffuser_path = Path(
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
)
model_name = self.convert_and_import(
model_path,
diffusers_path=diffuser_path,
@@ -939,35 +951,34 @@ class ModelManager(object):
from . import convert_ckpt_to_diffusers
if diffusers_path.exists():
self.logger.error(
f"The path {str(diffusers_path)} already exists. Please move or remove it and try again."
)
return
model_name = model_name or diffusers_path.name
model_description = model_description or f"Converted version of {model_name}"
self.logger.debug(f"Converting {model_name} to diffusers (30-60s)")
try:
if diffusers_path.exists():
self.logger.error(
f"The path {str(diffusers_path)} already exists. Installing previously-converted path."
)
else:
self.logger.debug(f"Converting {model_name} to diffusers (30-60s)")
# By passing the specified VAE to the conversion function, the autoencoder
# will be built into the model rather than tacked on afterward via the config file
vae_model = None
if vae:
vae_model = self._load_vae(vae)
vae_path = None
convert_ckpt_to_diffusers(
ckpt_path,
diffusers_path,
extract_ema=True,
original_config_file=original_config_file,
vae=vae_model,
vae_path=vae_path,
scan_needed=scan_needed,
)
self.logger.debug(
f"Success. Converted model is now located at {str(diffusers_path)}"
)
# By passing the specified VAE to the conversion function, the autoencoder
# will be built into the model rather than tacked on afterward via the config file
vae_model = None
if vae:
vae_model = self._load_vae(vae)
vae_path = None
convert_ckpt_to_diffusers(
ckpt_path,
diffusers_path,
extract_ema=True,
original_config_file=original_config_file,
vae=vae_model,
vae_path=vae_path,
scan_needed=scan_needed,
)
self.logger.debug(
f"Success. Converted model is now located at {str(diffusers_path)}"
)
self.logger.debug(f"Writing new config file entry for {model_name}")
new_config = dict(
path=str(diffusers_path),
@@ -979,10 +990,9 @@ class ModelManager(object):
self.add_model(model_name, new_config, True)
if commit_to_conf:
self.commit(commit_to_conf)
self.logger.debug(f"Model {model_name} installed")
self.logger.debug("Conversion succeeded")
except Exception as e:
self.logger.warning(f"Conversion failed: {str(e)}")
self.logger.warning(traceback.format_exc())
self.logger.warning(
"If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
)
@@ -1034,7 +1044,9 @@ class ModelManager(object):
"""
yaml_str = OmegaConf.to_yaml(self.config)
if not os.path.isabs(config_file_path):
config_file_path = self.globals.model_conf_path
config_file_path = os.path.normpath(
os.path.join(Globals.root, config_file_path)
)
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
with open(tmpfile, "w", encoding="utf-8") as outfile:
outfile.write(self.preamble())
@@ -1066,8 +1078,7 @@ class ModelManager(object):
"""
# Three transformer models to check: bert, clip and safety checker, and
# the diffusers as well
config = InvokeAIAppConfig.get_config()
models_dir = config.root_dir / "models"
models_dir = Path(Globals.root, "models")
legacy_locations = [
Path(
models_dir,
@@ -1079,8 +1090,8 @@ class ModelManager(object):
"openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14",
),
]
legacy_cache_dir = config.cache_dir / "../diffusers"
legacy_locations.extend(list(legacy_cache_dir.glob("*")))
legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))
legacy_layout = False
for model in legacy_locations:
legacy_layout = legacy_layout or model.exists()
@@ -1102,7 +1113,7 @@ class ModelManager(object):
# transformer files get moved into the hub directory
if cls._is_huggingface_hub_directory_present():
hub = config.cache_dir
hub = global_cache_dir("hub")
else:
hub = models_dir / "hub"
@@ -1141,12 +1152,13 @@ class ModelManager(object):
if str(source).startswith(("http:", "https:", "ftp:")):
dest_directory = Path(dest_directory)
if not dest_directory.is_absolute():
dest_directory = self.globals.root_dir / dest_directory
dest_directory = Globals.root / dest_directory
dest_directory.mkdir(parents=True, exist_ok=True)
resolved_path = download_with_resume(str(source), dest_directory)
else:
source = self.globals.root_dir / source
resolved_path = source
if not os.path.isabs(source):
source = os.path.join(Globals.root, source)
resolved_path = Path(source)
return resolved_path
def _invalidate_cached_model(self, model_name: str) -> None:
@@ -1196,7 +1208,7 @@ class ModelManager(object):
path = name_or_path
else:
owner, repo = name_or_path.split("/")
path = self.globals.cache_dir / f"models--{owner}--{repo}"
path = Path(global_cache_dir("hub") / f"models--{owner}--{repo}")
if not path.exists():
return None
hashpath = path / "checksum.sha256"
@@ -1257,8 +1269,8 @@ class ModelManager(object):
using_fp16 = self.precision == "float16"
vae_args.update(
cache_dir=self.globals.cache_dir,
local_files_only=not self.globals.internet_available,
cache_dir=global_cache_dir("hub"),
local_files_only=not Globals.internet_available,
)
self.logger.debug(f"Loading diffusers VAE from {name_or_path}")
@@ -1296,7 +1308,7 @@ class ModelManager(object):
@classmethod
def _delete_model_from_cache(cls,repo_id):
cache_info = scan_cache_dir(InvokeAIAppConfig.get_config().cache_dir)
cache_info = scan_cache_dir(global_cache_dir("hub"))
# I'm sure there is a way to do this with comprehensions
# but the code quickly became incomprehensible!
@@ -1313,195 +1325,12 @@ class ModelManager(object):
@staticmethod
def _abs_path(path: str | Path) -> Path:
globals = InvokeAIAppConfig.get_config()
if path is None or Path(path).is_absolute():
return path
return Path(globals.root_dir, path).resolve()
return Path(Globals.root, path).resolve()
@staticmethod
def _is_huggingface_hub_directory_present() -> bool:
return (
os.getenv("HF_HOME") is not None or os.getenv("XDG_CACHE_HOME") is not None
)
def list_lora_models(self)->Dict[str,bool]:
'''Return a dict of installed lora models; key is either the shortname
defined in INITIAL_MODELS, or the basename of the file in the LoRA
directory. Value is True if installed'''
models = OmegaConf.load(Dataset_path).get('lora') or {}
installed_models = {x: False for x in models.keys()}
dir = self.globals.lora_path
installed_models = dict()
for root, dirs, files in os.walk(dir):
for name in files:
if Path(name).suffix not in ['.safetensors','.ckpt','.pt','.bin']:
continue
if name == 'pytorch_lora_weights.bin':
name = Path(root,name).parent.stem #Path(root,name).stem
else:
name = Path(name).stem
installed_models.update({name: True})
return installed_models
def install_lora_models(self, model_names: list[str], access_token:str=None):
'''Download list of LoRA/LyCORIS models'''
short_names = OmegaConf.load(Dataset_path).get('lora') or {}
for name in model_names:
name = short_names.get(name) or name
# HuggingFace style LoRA
if re.match(r"^[\w.+-]+/([\w.+-]+)$", name):
self.logger.info(f'Downloading LoRA/LyCORIS model {name}')
_,dest_dir = name.split("/")
hf_download_with_resume(
repo_id = name,
model_dir = self.globals.lora_path / dest_dir,
model_name = 'pytorch_lora_weights.bin',
access_token = access_token,
)
elif name.startswith(("http:", "https:", "ftp:")):
download_with_resume(name, self.globals.lora_path)
else:
self.logger.error(f"Unknown repo_id or URL: {name}")
def delete_lora_models(self, model_names: List[str]):
'''Remove the list of lora models'''
for name in model_names:
file_or_directory = self.globals.lora_path / name
if file_or_directory.is_dir():
self.logger.info(f'Purging LoRA/LyCORIS {name}')
shutil.rmtree(str(file_or_directory))
else:
for path in self.globals.lora_path.glob(f'{name}.*'):
self.logger.info(f'Purging LoRA/LyCORIS {name}')
path.unlink()
def list_ti_models(self)->Dict[str,bool]:
'''Return a dict of installed textual models; key is either the shortname
defined in INITIAL_MODELS, or the basename of the file in the LoRA
directory. Value is True if installed'''
models = OmegaConf.load(Dataset_path).get('textual_inversion') or {}
installed_models = {x: False for x in models.keys()}
dir = self.globals.embedding_path
for root, dirs, files in os.walk(dir):
for name in files:
if not Path(name).suffix in ['.bin','.pt','.ckpt','.safetensors']:
continue
if name == 'learned_embeds.bin':
name = Path(root,name).parent.stem #Path(root,name).stem
else:
name = Path(name).stem
installed_models.update({name: True})
return installed_models
def install_ti_models(self, model_names: list[str], access_token: str=None):
'''Download list of textual inversion embeddings'''
short_names = OmegaConf.load(Dataset_path).get('textual_inversion') or {}
for name in model_names:
name = short_names.get(name) or name
if re.match(r"^[\w.+-]+/([\w.+-]+)$", name):
self.logger.info(f'Downloading Textual Inversion embedding {name}')
_,dest_dir = name.split("/")
hf_download_with_resume(
repo_id = name,
model_dir = self.globals.embedding_path / dest_dir,
model_name = 'learned_embeds.bin',
access_token = access_token
)
elif name.startswith(('http:','https:','ftp:')):
download_with_resume(name, self.globals.embedding_path)
else:
self.logger.error(f'{name} does not look like either a HuggingFace repo_id or a downloadable URL')
def delete_ti_models(self, model_names: list[str]):
'''Remove TI embeddings from disk'''
for name in model_names:
file_or_directory = self.globals.embedding_path / name
if file_or_directory.is_dir():
self.logger.info(f'Purging textual inversion embedding {name}')
shutil.rmtree(str(file_or_directory))
else:
for path in self.globals.embedding_path.glob(f'{name}.*'):
self.logger.info(f'Purging textual inversion embedding {name}')
path.unlink()
def list_controlnet_models(self)->Dict[str,bool]:
'''Return a dict of installed controlnet models; key is repo_id or short name
of model (defined in INITIAL_MODELS), and value is True if installed'''
cn_models = OmegaConf.load(Dataset_path).get('controlnet') or {}
installed_models = {x: False for x in cn_models.keys()}
cn_dir = self.globals.controlnet_path
for root, dirs, files in os.walk(cn_dir):
for name in dirs:
if Path(root, name, '.download_complete').exists():
installed_models.update({name.replace('--','/'): True})
return installed_models
def install_controlnet_models(self, model_names: list[str], access_token: str=None):
'''Download list of controlnet models; provide either repo_id or short name listed in INITIAL_MODELS.yaml'''
short_names = OmegaConf.load(Dataset_path).get('controlnet') or {}
dest_dir = self.globals.controlnet_path
dest_dir.mkdir(parents=True,exist_ok=True)
# The model file may be fp32 or fp16, and may be either a
# .bin file or a .safetensors. We try each until we get one,
# preferring 'fp16' if using half precision, and preferring
# safetensors over bin.
precisions = ['.fp16',''] if self.precision=='float16' else ['']
formats = ['.safetensors','.bin']
possible_filenames = list()
for p in precisions:
for f in formats:
possible_filenames.append(Path(f'diffusion_pytorch_model{p}{f}'))
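# For float16 precision, the loops above produce this search order:
#   diffusion_pytorch_model.fp16.safetensors
#   diffusion_pytorch_model.fp16.bin
#   diffusion_pytorch_model.safetensors
#   diffusion_pytorch_model.bin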
for directory_name in model_names:
repo_id = short_names.get(directory_name) or directory_name
safe_name = directory_name.replace('/','--')
self.logger.info(f'Downloading ControlNet model {directory_name} ({repo_id})')
hf_download_with_resume(
repo_id = repo_id,
model_dir = dest_dir / safe_name,
model_name = 'config.json',
access_token = access_token
)
path = None
for filename in possible_filenames:
suffix = filename.suffix
dest_filename = Path(f'diffusion_pytorch_model{suffix}')
self.logger.info(f'Checking availability of {directory_name}/{filename}...')
path = hf_download_with_resume(
repo_id = repo_id,
model_dir = dest_dir / safe_name,
model_name = str(filename),
access_token = access_token,
model_dest = Path(dest_dir, safe_name, dest_filename),
)
if path:
(path.parent / '.download_complete').touch()
break
def delete_controlnet_models(self, model_names: List[str]):
'''Remove the list of controlnet models'''
for name in model_names:
safe_name = name.replace('/','--')
directory = self.globals.controlnet_path / safe_name
if directory.exists():
self.logger.info(f'Purging controlnet model {name}')
shutil.rmtree(str(directory))

View File

@@ -20,12 +20,11 @@ from compel.prompt_parser import (
)
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
from invokeai.app.services.config import InvokeAIAppConfig
from ..stable_diffusion import InvokeAIDiffuserComponent
from ..util import torch_dtype
config = InvokeAIAppConfig.get_config()
def get_uc_and_c_and_ec(prompt_string,
model: InvokeAIDiffuserComponent,
@@ -39,7 +38,7 @@ def get_uc_and_c_and_ec(prompt_string,
textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False,
)
)
# get rid of any newline characters
prompt_string = prompt_string.replace("\n", " ")
@@ -57,7 +56,7 @@ def get_uc_and_c_and_ec(prompt_string,
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
if log_tokens or config.log_tokenization:
if log_tokens or getattr(Globals, "log_tokenization", False):
log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
@@ -282,8 +281,6 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
for match in re.finditer(prompt_parser, text)
]
if len(parsed_prompts) == 0:
return []
if skip_normalize:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
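An illustrative example of the weighted-subprompt format handled above, assuming the prompt:weight syntax the regex captures; output shapes are inferred from the parsing code:
# split_weighted_subprompts("mountain:2 valley:1", skip_normalize=True)
#   -> [("mountain", 2.0), ("valley", 1.0)]
# without skip_normalize, each weight is divided by weight_sum (3.0 here)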

View File

@@ -6,7 +6,7 @@ import numpy as np
import torch
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from ..globals import Globals
pretrained_model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
@@ -17,11 +17,11 @@ class CodeFormerRestoration:
def __init__(
self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
) -> None:
if not os.path.isabs(codeformer_dir):
codeformer_dir = os.path.join(Globals.root, codeformer_dir)
self.globals = InvokeAIAppConfig.get_config()
codeformer_dir = self.globals.root_dir / codeformer_dir
self.model_path = codeformer_dir / codeformer_model_path
self.codeformer_model_exists = self.model_path.exists()
self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
self.codeformer_model_exists = os.path.isfile(self.model_path)
if not self.codeformer_model_exists:
logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
@@ -71,7 +71,9 @@ class CodeFormerRestoration:
upscale_factor=1,
use_parse=True,
device=device,
model_rootpath = self.globals.root_dir / "gfpgan" / "weights"
model_rootpath=os.path.join(
Globals.root, "models", "gfpgan", "weights"
),
)
face_helper.clean_all()
face_helper.read_image(bgr_image_array)

View File

@@ -7,13 +7,14 @@ import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.globals import Globals
class GFPGAN:
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
self.globals = InvokeAIAppConfig.get_config()
if not os.path.isabs(gfpgan_model_path):
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
gfpgan_model_path = os.path.abspath(
os.path.join(Globals.root, gfpgan_model_path)
)
self.model_path = gfpgan_model_path
self.gfpgan_model_exists = os.path.isfile(self.model_path)
@@ -32,7 +33,7 @@ class GFPGAN:
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
cwd = os.getcwd()
os.chdir(self.globals.root_dir / 'models')
os.chdir(os.path.join(Globals.root, "models"))
try:
from gfpgan import GFPGANer

View File

@@ -1,3 +1,4 @@
import os
import warnings
import numpy as np
@@ -6,8 +7,7 @@ from PIL import Image
from PIL.Image import Image as ImageType
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
from invokeai.backend.globals import Globals
class ESRGAN:
def __init__(self, bg_tile_size=400) -> None:
@@ -30,8 +30,12 @@ class ESRGAN:
upscale=4,
act_type="prelu",
)
model_path = config.root_dir / "models/realesrgan/realesr-general-x4v3.pth"
wdn_model_path = config.root_dir / "models/realesrgan/realesr-general-wdn-x4v3.pth"
model_path = os.path.join(
Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
)
wdn_model_path = os.path.join(
Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
)
scale = 4
bg_upsampler = RealESRGANer(

View File

@@ -15,11 +15,9 @@ from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from .globals import global_cache_dir
from .util import CPU_DEVICE
config = InvokeAIAppConfig.get_config()
class SafetyChecker(object):
CAUTION_IMG = "caution.png"
@@ -28,10 +26,10 @@ class SafetyChecker(object):
caution = Image.open(path)
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
self.device = device
try:
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_model_path = config.cache_dir
safety_model_path = global_cache_dir("hub")
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
safety_model_id,
local_files_only=True,

View File

@@ -17,17 +17,16 @@ from huggingface_hub import (
hf_hub_url,
)
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.app.services.config import InvokeAIAppConfig
logger = InvokeAILogger.getLogger()
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
class HuggingFaceConceptsLibrary(object):
def __init__(self, root=None):
"""
Initialize the Concepts object. May optionally pass a root directory.
"""
self.config = InvokeAIAppConfig.get_config()
self.root = root or self.config.root
self.root = root or Globals.root
self.hf_api = HfApi()
self.local_concepts = dict()
self.concept_list = None
@@ -59,7 +58,7 @@ class HuggingFaceConceptsLibrary(object):
self.concept_list.extend(list(local_concepts_to_add))
return self.concept_list
return self.concept_list
elif self.config.internet_available is True:
elif Globals.internet_available is True:
try:
models = self.hf_api.list_models(
filter=ModelFilter(model_name="sd-concepts-library/")
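The listing call above can be exercised on its own. This sketch assumes network access and the same `huggingface_hub` API generation this code targets (`HfApi.list_models` with `ModelFilter`); it extracts the bare concept names from the returned model ids.

```python
# Standalone sketch of the concepts query above; assumes the huggingface_hub
# version this code was written against (ModelFilter, ModelInfo.modelId).
from huggingface_hub import HfApi, ModelFilter

models = HfApi().list_models(filter=ModelFilter(model_name='sd-concepts-library/'))
concept_names = [m.modelId.split('/')[1] for m in models]
print(concept_names[:5])
```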

View File

@@ -2,29 +2,23 @@ from __future__ import annotations
import dataclasses
import inspect
import math
import secrets
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel, Field
import einops
import PIL.Image
import numpy as np
from accelerate.utils import set_seed
import psutil
import torch
import torchvision.transforms as T
from compel import EmbeddingsProvider
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline,
)
@@ -33,14 +27,14 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils import PIL_INTERPOLATION
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.globals import Globals
from ..util import CPU_DEVICE, normalize_device
from .diffusion import (
AttentionMapSaver,
@@ -50,6 +44,7 @@ from .diffusion import (
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
from .textual_inversion_manager import TextualInversionManager
@dataclass
class PipelineIntermediateState:
run_id: str
@@ -75,10 +70,10 @@ class AddsMaskLatents:
initial_image_latents: torch.Tensor
def __call__(
self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor, **kwargs,
self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor
) -> torch.Tensor:
model_input = self.add_mask_channels(latents)
return self.forward(model_input, t, text_embeddings, **kwargs)
return self.forward(model_input, t, text_embeddings)
def add_mask_channels(self, latents):
batch_size = latents.size(0)
@@ -214,19 +209,12 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
raise AssertionError("why was that an empty generator?")
return result
@dataclass
class ControlNetData:
model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor= Field(default=None)
weight: Union[float, List[float]]= Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
@dataclass(frozen=True)
class ConditioningData:
unconditioned_embeddings: torch.Tensor
text_embeddings: torch.Tensor
guidance_scale: Union[float, List[float]]
guidance_scale: float
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
@@ -316,7 +304,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
precision: str = "float32",
control_model: ControlNetModel = None,
):
super().__init__(
vae,
@@ -337,8 +324,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
# FIXME: can't currently register control module
# control_model=control_model,
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(
self.unet, self._unet_forward, is_running_diffusers=True
@@ -358,17 +343,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self._model_group = FullyLoadedModelGroup(self.unet.device)
self._model_group.install(*self._submodels)
self.control_model = control_model
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
"""
if xformers is available, use it, otherwise use sliced attention.
"""
config = InvokeAIAppConfig.get_config()
if (
torch.cuda.is_available()
and is_xformers_available()
and not config.disable_xformers
and not Globals.disable_xformers
):
self.enable_xformers_memory_efficient_attention()
else:
@@ -481,7 +464,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None,
**kwargs,
) -> InvokeAIStableDiffusionPipelineOutput:
r"""
Function invoked when calling the pipeline for generation.
@@ -502,7 +484,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise=noise,
run_id=run_id,
callback=callback,
**kwargs,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@@ -527,8 +508,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance: List[Callable] = None,
run_id=None,
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device('cpu')
@@ -549,8 +528,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance=additional_guidance,
run_id=run_id,
callback=callback,
control_data=control_data,
**kwargs,
)
return result.latents, result.attention_map_saver
@@ -563,8 +540,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor,
run_id: str = None,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
**kwargs,
):
self._adjust_memory_efficient_attention(latents)
if run_id is None:
@@ -594,7 +569,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents = self.scheduler.add_noise(latents, noise, batched_t)
attention_map_saver: Optional[AttentionMapSaver] = None
# print("timesteps:", timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t)
step_output = self.step(
@@ -604,8 +579,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index=i,
total_step_count=len(timesteps),
additional_guidance=additional_guidance,
control_data=control_data,
**kwargs,
)
latents = step_output.prev_sample
@@ -646,11 +619,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index: int,
total_step_count: int,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
**kwargs,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
if additional_guidance is None:
additional_guidance = []
@@ -658,56 +630,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
# default is no controlnet, so set controlnet processing output to None
down_block_res_samples, mid_block_res_sample = None, None
if control_data is not None:
# FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
# if conditioning_data.guidance_scale > 1.0:
if conditioning_data.guidance_scale is not None:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?)
latent_control_input = torch.cat([latent_model_input] * 2)
else:
latent_control_input = latent_model_input
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list)
for i, control_datum in enumerate(control_data):
# print("controlnet", i, "==>", type(control_datum))
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
# print("running controlnet", i, "for step", step_index)
if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index]
else:
# if controlnet has a single weight, use it for all steps
controlnet_weight = control_datum.weight
down_samples, mid_sample = control_datum.model(
sample=latent_control_input,
timestep=timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]),
controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight,
# cross_attention_kwargs,
guess_mode=False,
return_dict=False,
)
if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
# add controlnet outputs together if have multiple controlnets
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample
# predict the noise residual
noise_pred = self.invokeai_diffuser.do_diffusion_step(
latent_model_input,
@@ -717,8 +639,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data.guidance_scale,
step_index=step_index,
total_step_count=total_step_count,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
)
# compute the previous noisy sample x_t -> x_t-1
@@ -740,7 +660,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
t,
text_embeddings,
cross_attention_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
):
"""predict the noise residual"""
if is_inpainting_model(self.unet) and latents.size(1) == 4:
@@ -760,8 +679,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# First three args should be positional, not keywords, so torch hooks can see them.
return self.unet(
latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs,
**kwargs,
latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs
).sample
def img2img_from_embeddings(
@@ -811,7 +729,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor,
run_id=None,
callback=None,
) -> InvokeAIStableDiffusionPipelineOutput:
) -> InvokeAIStableDiffusionPipelineOutput:
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
result_latents, result_attention_maps = self.latents_from_embeddings(
latents=initial_latents if strength < 1.0 else torch.zeros_like(
@@ -1023,51 +941,3 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
debug_image(
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
)
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
# Returns torch.Tensor of shape (batch_size, 3, height, width)
def prepare_control_image(
self,
image,
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
# latents,
width=512, # should be 8 * latent.shape[3]
height=512, # should be 8 * latent height[2]
batch_size=1,
num_images_per_prompt=1,
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
):
if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
images = []
for image_ in image:
image_ = image_.convert("RGB")
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
image_ = np.array(image_)
image_ = image_[None, :]
images.append(image_)
image = images
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance:
image = torch.cat([image] * 2)
return image
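The begin/end gating performed by the removed ControlNet loop can be stated in isolation. The sketch below reproduces the same floor/ceil arithmetic; it is an illustration of the rule, not the pipeline code itself.

```python
# Sketch of the per-step ControlNet gating from the removed loop above: a
# controlnet applies only inside its [begin_step_percent, end_step_percent]
# window of the total step count.
import math

def controlnet_active(step_index: int, total_steps: int,
                      begin_pct: float = 0.0, end_pct: float = 1.0) -> bool:
    first_step = math.floor(begin_pct * total_steps)
    last_step = math.ceil(end_pct * total_steps)
    return first_step <= step_index <= last_step

assert controlnet_active(0, 20)                     # full-range control
assert not controlnet_active(2, 20, begin_pct=0.5)  # window starts at step 10
```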

View File

@@ -1,7 +1,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from math import ceil
from typing import Any, Callable, Dict, Optional, Union, List
from typing import Any, Callable, Dict, Optional, Union
import numpy as np
import torch
@@ -10,7 +10,7 @@ from diffusers.models.attention_processor import AttentionProcessor
from typing_extensions import TypeAlias
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.globals import Globals
from .cross_attention_control import (
Arguments,
@@ -32,6 +32,7 @@ ModelForwardCallback: TypeAlias = Union[
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]
@dataclass(frozen=True)
class PostprocessingSettings:
threshold: float
@@ -72,13 +73,12 @@ class InvokeAIDiffuserComponent:
:param model: the unet model to pass through to cross attention control
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
"""
config = InvokeAIAppConfig.get_config()
self.conditioning = None
self.model = model
self.is_running_diffusers = is_running_diffusers
self.model_forward_callback = model_forward_callback
self.cross_attention_control_context = None
self.sequential_guidance = config.sequential_guidance
self.sequential_guidance = Globals.sequential_guidance
@classmethod
@contextmanager
@@ -112,25 +112,23 @@ class InvokeAIDiffuserComponent:
# TODO resuscitate attention map saving
# self.remove_attention_map_saving()
# apparently unused code
# TODO: delete
# def override_cross_attention(
# self, conditioning: ExtraConditioningInfo, step_count: int
# ) -> Dict[str, AttentionProcessor]:
# """
# setup cross attention .swap control. for diffusers this replaces the attention processor, so
# the previous attention processor is returned so that the caller can restore it later.
# """
# self.conditioning = conditioning
# self.cross_attention_control_context = Context(
# arguments=self.conditioning.cross_attention_control_args,
# step_count=step_count,
# )
# return override_cross_attention(
# self.model,
# self.cross_attention_control_context,
# is_running_diffusers=self.is_running_diffusers,
# )
def override_cross_attention(
self, conditioning: ExtraConditioningInfo, step_count: int
) -> Dict[str, AttentionProcessor]:
"""
setup cross attention .swap control. for diffusers this replaces the attention processor, so
the previous attention processor is returned so that the caller can restore it later.
"""
self.conditioning = conditioning
self.cross_attention_control_context = Context(
arguments=self.conditioning.cross_attention_control_args,
step_count=step_count,
)
return override_cross_attention(
self.model,
self.cross_attention_control_context,
is_running_diffusers=self.is_running_diffusers,
)
def restore_default_cross_attention(
self, restore_attention_processor: Optional["AttentionProcessor"] = None
@@ -180,11 +178,9 @@ class InvokeAIDiffuserComponent:
sigma: torch.Tensor,
unconditioning: Union[torch.Tensor, dict],
conditioning: Union[torch.Tensor, dict],
# unconditional_guidance_scale: float,
unconditional_guidance_scale: Union[float, List[float]],
unconditional_guidance_scale: float,
step_index: Optional[int] = None,
total_step_count: Optional[int] = None,
**kwargs,
):
"""
:param x: current latents
@@ -196,11 +192,6 @@ class InvokeAIDiffuserComponent:
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
"""
if isinstance(unconditional_guidance_scale, list):
guidance_scale = unconditional_guidance_scale[step_index]
else:
guidance_scale = unconditional_guidance_scale
cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
@@ -218,7 +209,7 @@ class InvokeAIDiffuserComponent:
if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
x, sigma, unconditioning, conditioning, **kwargs,
x, sigma, unconditioning, conditioning
)
elif wants_cross_attention_control:
(
@@ -230,14 +221,13 @@ class InvokeAIDiffuserComponent:
unconditioning,
conditioning,
cross_attention_control_types_to_do,
**kwargs,
)
elif self.sequential_guidance:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
x, sigma, unconditioning, conditioning, **kwargs,
x, sigma, unconditioning, conditioning
)
else:
@@ -245,12 +235,11 @@ class InvokeAIDiffuserComponent:
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
x, sigma, unconditioning, conditioning, **kwargs,
x, sigma, unconditioning, conditioning
)
combined_next_x = self._combine(
# unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
unconditioned_next_x, conditioned_next_x, guidance_scale
unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
)
return combined_next_x
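For reference, the `_combine` call that `do_diffusion_step` ends with is standard classifier-free guidance. A minimal sketch of the formula, with illustrative tensors:

```python
# Standard classifier-free guidance combine: x = x_uncond + w * (x_cond - x_uncond).
# Illustrative re-statement; the pipeline's own _combine performs the same math.
import torch

def combine_cfg(unconditioned_next_x: torch.Tensor,
                conditioned_next_x: torch.Tensor,
                guidance_scale: float) -> torch.Tensor:
    return unconditioned_next_x + guidance_scale * (conditioned_next_x - unconditioned_next_x)

u, c = torch.zeros(1, 4, 8, 8), torch.ones(1, 4, 8, 8)
assert torch.allclose(combine_cfg(u, c, 7.5), torch.full((1, 4, 8, 8), 7.5))
```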
@@ -293,13 +282,13 @@ class InvokeAIDiffuserComponent:
# methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
# fast batched path
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings = torch.cat([unconditioning, conditioning])
both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings, **kwargs,
x_twice, sigma_twice, both_conditionings
)
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
if conditioned_next_x.device.type == "mps":
@@ -313,17 +302,16 @@ class InvokeAIDiffuserComponent:
sigma,
unconditioning: torch.Tensor,
conditioning: torch.Tensor,
**kwargs,
):
# low-memory sequential path
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
if conditioned_next_x.device.type == "mps":
# prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
assert isinstance(conditioning, dict)
assert isinstance(unconditioning, dict)
x_twice = torch.cat([x] * 2)
@@ -338,7 +326,7 @@ class InvokeAIDiffuserComponent:
else:
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings, **kwargs,
x_twice, sigma_twice, both_conditionings
).chunk(2)
return unconditioned_next_x, conditioned_next_x
@@ -349,7 +337,6 @@ class InvokeAIDiffuserComponent:
unconditioning,
conditioning,
cross_attention_control_types_to_do,
**kwargs,
):
if self.is_running_diffusers:
return self._apply_cross_attention_controlled_conditioning__diffusers(
@@ -358,7 +345,6 @@ class InvokeAIDiffuserComponent:
unconditioning,
conditioning,
cross_attention_control_types_to_do,
**kwargs,
)
else:
return self._apply_cross_attention_controlled_conditioning__compvis(
@@ -367,7 +353,6 @@ class InvokeAIDiffuserComponent:
unconditioning,
conditioning,
cross_attention_control_types_to_do,
**kwargs,
)
def _apply_cross_attention_controlled_conditioning__diffusers(
@@ -377,7 +362,6 @@ class InvokeAIDiffuserComponent:
unconditioning,
conditioning,
cross_attention_control_types_to_do,
**kwargs,
):
context: Context = self.cross_attention_control_context
@@ -393,7 +377,6 @@ class InvokeAIDiffuserComponent:
sigma,
unconditioning,
{"swap_cross_attn_context": cross_attn_processor_context},
**kwargs,
)
# do requested cross attention types for conditioning (positive prompt)
@@ -405,7 +388,6 @@ class InvokeAIDiffuserComponent:
sigma,
conditioning,
{"swap_cross_attn_context": cross_attn_processor_context},
**kwargs,
)
return unconditioned_next_x, conditioned_next_x
@@ -416,7 +398,6 @@ class InvokeAIDiffuserComponent:
unconditioning,
conditioning,
cross_attention_control_types_to_do,
**kwargs,
):
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
# slower non-batched path (20% slower on mac MPS)
@@ -430,13 +411,13 @@ class InvokeAIDiffuserComponent:
context: Context = self.cross_attention_control_context
try:
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
# process x using the original prompt, saving the attention maps
# print("saving attention maps for", cross_attention_control_types_to_do)
for ca_type in cross_attention_control_types_to_do:
context.request_save_attention_maps(ca_type)
_ = self.model_forward_callback(x, sigma, conditioning, **kwargs,)
_ = self.model_forward_callback(x, sigma, conditioning)
context.clear_requests(cleanup=False)
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
@@ -447,7 +428,7 @@ class InvokeAIDiffuserComponent:
self.conditioning.cross_attention_control_args.edited_conditioning
)
conditioned_next_x = self.model_forward_callback(
x, sigma, edited_conditioning, **kwargs,
x, sigma, edited_conditioning
)
context.clear_requests(cleanup=True)
@@ -504,7 +485,7 @@ class InvokeAIDiffuserComponent:
logger.debug(
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
)
logger.debug(
logger.debug(
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
)

View File

@@ -7,6 +7,7 @@
This is the backend to "textual_inversion.py"
"""
import argparse
import logging
import math
import os
@@ -46,7 +47,8 @@ from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
# invokeai stuff
from invokeai.app.services.config import InvokeAIAppConfig,PagingArgumentParser
from ..args import ArgFormatter, PagingArgumentParser
from ..globals import Globals, global_cache_dir
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
@@ -88,9 +90,8 @@ def save_progress(
def parse_args():
config = InvokeAIAppConfig.get_config()
parser = PagingArgumentParser(
description="Textual inversion training"
description="Textual inversion training", formatter_class=ArgFormatter
)
general_group = parser.add_argument_group("General")
model_group = parser.add_argument_group("Models and Paths")
@@ -111,7 +112,7 @@ def parse_args():
"--root_dir",
"--root",
type=Path,
default=config.root,
default=Globals.root,
help="Path to the invokeai runtime directory",
)
general_group.add_argument(
@@ -126,7 +127,7 @@ def parse_args():
general_group.add_argument(
"--output_dir",
type=Path,
default=f"{config.root}/text-inversion-model",
default=f"{Globals.root}/text-inversion-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
model_group.add_argument(
@@ -527,7 +528,6 @@ def get_full_repo_name(
def do_textual_inversion_training(
config: InvokeAIAppConfig,
model: str,
train_data_dir: Path,
output_dir: Path,
@@ -580,7 +580,7 @@ def do_textual_inversion_training(
# setting up things the way invokeai expects them
if not os.path.isabs(output_dir):
output_dir = os.path.join(config.root, output_dir)
output_dir = os.path.join(Globals.root, output_dir)
logging_dir = output_dir / logging_dir
@@ -628,7 +628,7 @@ def do_textual_inversion_training(
elif output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
models_conf = OmegaConf.load(config.model_conf_path)
models_conf = OmegaConf.load(os.path.join(Globals.root, "configs/models.yaml"))
model_conf = models_conf.get(model, None)
assert model_conf is not None, f"Unknown model: {model}"
assert (
@@ -640,7 +640,7 @@ def do_textual_inversion_training(
assert (
pretrained_model_name_or_path
), f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
pipeline_args = dict(cache_dir=config.cache_dir)
pipeline_args = dict(cache_dir=global_cache_dir("hub"))
# Load tokenizer
if tokenizer_name:

View File

@@ -17,5 +17,3 @@ from .util import (
instantiate_from_config,
url_attachment_name,
)

View File

@@ -4,16 +4,17 @@ from contextlib import nullcontext
import torch
from torch import autocast
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.globals import Globals
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")
config = InvokeAIAppConfig.get_config()
def choose_torch_device() -> torch.device:
"""Convenience routine for guessing which GPU device to run model on"""
if config.always_use_cpu:
if Globals.always_use_cpu:
return CPU_DEVICE
if torch.cuda.is_available():
return torch.device("cuda")
@@ -32,7 +33,7 @@ def choose_precision(device: torch.device) -> str:
def torch_dtype(device: torch.device) -> torch.dtype:
if config.full_precision:
if Globals.full_precision:
return torch.float32
if choose_precision(device) == "float16":
return torch.float16
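A short usage sketch of the two helpers above; the import path is an assumption about where they live in the backend util package, as it is not shown in this diff.

```python
# Hypothetical usage of the device/precision helpers above; the import path
# is assumed, not shown in this diff.
import torch
from invokeai.backend.util.devices import choose_torch_device, torch_dtype

device = choose_torch_device()  # CPU when always_use_cpu, else CUDA/MPS if available
dtype = torch_dtype(device)     # float32 when full_precision, else per-device choice
latents = torch.zeros(1, 4, 64, 64, device=device, dtype=dtype)
```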

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
"""
invokeai.util.logging
"""invokeai.util.logging
Logging class for InvokeAI that produces console messages
@@ -12,7 +11,6 @@ from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization
(or)
logger = InvokeAILogger.getLogger(__name__) // To use the filename
logger.configure()
logger.critical('this is critical') // Critical Message
logger.error('this is an error') // Error Message
@@ -30,165 +28,10 @@ Console messages:
Alternate Method (in this case the logger name will be set to InvokeAI):
import invokeai.backend.util.logging as IAILogger
IAILogger.debug('this is a debugging message')
## Configuration
The default configuration will print to stderr on the console. To add
additional logging handlers, call getLogger with an initialized InvokeAIAppConfig
object:
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger.getLogger(config=config)
### Three command-line options control logging:
`--log_handlers <handler1> <handler2> ...`
This option activates one or more log handlers. Options are "console", "file", "syslog" and "http". To specify more than one, separate them by spaces:
```
invokeai-web --log_handlers console syslog=/dev/log file=C:\\Users\\fred\\invokeai.log
```
The format of these options is described below.
### `--log_format {plain|color|legacy|syslog}`
This controls the format of log messages written to the console. Only the "console" log handler is currently affected by this setting.
* "plain" provides formatted messages like this:
```bash
[2023-05-24 23:18:50,352]::[InvokeAI]::DEBUG --> this is a debug message
[2023-05-24 23:18:50,352]::[InvokeAI]::INFO --> this is an informational message
[2023-05-24 23:18:50,352]::[InvokeAI]::WARNING --> this is a warning
[2023-05-24 23:18:50,352]::[InvokeAI]::ERROR --> this is an error
[2023-05-24 23:18:50,352]::[InvokeAI]::CRITICAL --> this is a critical error
```
* "color" produces similar output, but the text will be color coded to indicate the severity of the message.
* "legacy" produces output similar to InvokeAI versions 2.3 and earlier:
```
### this is a critical error
*** this is an error
** this is a warning
>> this is an informational message
| this is a debug message
```
* "syslog" produces messages suitable for syslog entries:
```bash
InvokeAI [2691178] <CRITICAL> this is a critical error
InvokeAI [2691178] <ERROR> this is an error
InvokeAI [2691178] <WARNING> this is a warning
InvokeAI [2691178] <INFO> this is an informational message
InvokeAI [2691178] <DEBUG> this is a debug message
```
(note that the date, time and hostname will be added by the syslog system)
### `--log_level {debug|info|warning|error|critical}`
Providing this command-line option will cause only messages at the specified level or above to be emitted.
## Console logging
When "console" is provided to `--log_handlers`, messages will be written to the command line window in which InvokeAI was launched. By default, the color formatter will be used unless overridden by `--log_format`.
## File logging
When "file" is provided to `--log_handlers`, entries will be written to the file indicated in the path argument. By default, the "plain" format will be used:
```bash
invokeai-web --log_handlers file=/var/log/invokeai.log
```
## Syslog logging
When "syslog" is requested, entries will be sent to the syslog system. There are a variety of ways to control where the log message is sent:
* Send to the local machine using the `/dev/log` socket:
```
invokeai-web --log_handlers syslog=/dev/log
```
* Send to the local machine using a UDP message:
```
invokeai-web --log_handlers syslog=localhost
```
* Send to the local machine using a UDP message on a nonstandard port:
```
invokeai-web --log_handlers syslog=localhost:512
```
* Send to a remote machine named "loghost" on the local LAN using facility LOG_USER and UDP packets:
```
invokeai-web --log_handlers syslog=loghost,facility=LOG_USER,socktype=SOCK_DGRAM
```
This can be abbreviated `syslog=loghost`, as LOG_USER and SOCK_DGRAM are defaults.
* Send to a remote machine named "loghost" using the facility LOCAL0 and using a TCP socket:
```
invokeai-web --log_handlers syslog=loghost,facility=LOG_LOCAL0,socktype=SOCK_STREAM
```
If no arguments are specified (just a bare "syslog"), then the logging system will look for a UNIX socket named `/dev/log`, and if not found try to send a UDP message to `localhost`. The Macintosh OS used to support logging to a socket named `/var/run/syslog`, but this feature has since been disabled.
## Web logging
If you have access to a web server that is configured to log messages when a particular URL is requested, you can log using the "http" method:
```
invokeai-web --log_handlers http=http://my.server/path/to/logger,method=POST
```
The optional [,method=] part can be used to specify whether the URL accepts GET (default) or POST messages.
Currently password authentication and SSL are not supported.
## Using the configuration file
You can set and forget logging options by adding a "Logging" section to `invokeai.yaml`:
```
InvokeAI:
[... other settings...]
Logging:
log_handlers:
- console
- syslog=/dev/log
log_level: info
log_format: color
```
"""
import logging.handlers
import socket
import urllib.parse
import logging
from abc import abstractmethod
from pathlib import Path
from invokeai.app.services.config import InvokeAIAppConfig, get_invokeai_config
try:
import syslog
SYSLOG_AVAILABLE = True
except:
SYSLOG_AVAILABLE = False
# module level functions
def debug(msg, *args, **kwargs):
@@ -219,77 +62,11 @@ def getLogger(name: str = None) -> logging.Logger:
return InvokeAILogger.getLogger(name)
_FACILITY_MAP = dict(
LOG_KERN = syslog.LOG_KERN,
LOG_USER = syslog.LOG_USER,
LOG_MAIL = syslog.LOG_MAIL,
LOG_DAEMON = syslog.LOG_DAEMON,
LOG_AUTH = syslog.LOG_AUTH,
LOG_LPR = syslog.LOG_LPR,
LOG_NEWS = syslog.LOG_NEWS,
LOG_UUCP = syslog.LOG_UUCP,
LOG_CRON = syslog.LOG_CRON,
LOG_SYSLOG = syslog.LOG_SYSLOG,
LOG_LOCAL0 = syslog.LOG_LOCAL0,
LOG_LOCAL1 = syslog.LOG_LOCAL1,
LOG_LOCAL2 = syslog.LOG_LOCAL2,
LOG_LOCAL3 = syslog.LOG_LOCAL3,
LOG_LOCAL4 = syslog.LOG_LOCAL4,
LOG_LOCAL5 = syslog.LOG_LOCAL5,
LOG_LOCAL6 = syslog.LOG_LOCAL6,
LOG_LOCAL7 = syslog.LOG_LOCAL7,
) if SYSLOG_AVAILABLE else dict()
_SOCK_MAP = dict(
SOCK_STREAM = socket.SOCK_STREAM,
SOCK_DGRAM = socket.SOCK_DGRAM,
)
class InvokeAIFormatter(logging.Formatter):
'''
Base class for logging formatter
'''
def format(self, record):
formatter = logging.Formatter(self.log_fmt(record.levelno))
return formatter.format(record)
@abstractmethod
def log_fmt(self, levelno: int)->str:
pass
class InvokeAISyslogFormatter(InvokeAIFormatter):
'''
Formatting for syslog
'''
def log_fmt(self, levelno: int)->str:
return '%(name)s [%(process)d] <%(levelname)s> %(message)s'
class InvokeAILegacyLogFormatter(InvokeAIFormatter):
'''
Formatting for the InvokeAI Logger (legacy version)
'''
FORMATS = {
logging.DEBUG: " | %(message)s",
logging.INFO: ">> %(message)s",
logging.WARNING: "** %(message)s",
logging.ERROR: "*** %(message)s",
logging.CRITICAL: "### %(message)s",
}
def log_fmt(self,levelno:int)->str:
return self.FORMATS.get(levelno)
class InvokeAIPlainLogFormatter(InvokeAIFormatter):
'''
Custom Formatting for the InvokeAI Logger (plain version)
'''
def log_fmt(self, levelno: int)->str:
return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
class InvokeAIColorLogFormatter(InvokeAIFormatter):
class InvokeAILogFormatter(logging.Formatter):
'''
Custom Formatting for the InvokeAI Logger
'''
# Color Codes
grey = "\x1b[38;20m"
yellow = "\x1b[33;20m"
@@ -299,127 +76,35 @@ class InvokeAIColorLogFormatter(InvokeAIFormatter):
reset = "\x1b[0m"
# Log Format
log_format = "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
format = "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
## More Formatting Options: %(pathname)s, %(filename)s, %(module)s, %(lineno)d
# Format Map
FORMATS = {
logging.DEBUG: cyan + log_format + reset,
logging.INFO: grey + log_format + reset,
logging.WARNING: yellow + log_format + reset,
logging.ERROR: red + log_format + reset,
logging.CRITICAL: bold_red + log_format + reset
logging.DEBUG: cyan + format + reset,
logging.INFO: grey + format + reset,
logging.WARNING: yellow + format + reset,
logging.ERROR: red + format + reset,
logging.CRITICAL: bold_red + format + reset
}
def log_fmt(self, levelno: int)->str:
return self.FORMATS.get(levelno)
def format(self, record):
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt, datefmt="%d-%m-%Y %H:%M:%S")
return formatter.format(record)
LOG_FORMATTERS = {
'plain': InvokeAIPlainLogFormatter,
'color': InvokeAIColorLogFormatter,
'syslog': InvokeAISyslogFormatter,
'legacy': InvokeAILegacyLogFormatter,
}
class InvokeAILogger(object):
loggers = dict()
@classmethod
def getLogger(cls,
name: str = 'InvokeAI',
config: InvokeAIAppConfig=InvokeAIAppConfig.get_config())->logging.Logger:
if name in cls.loggers:
logger = cls.loggers[name]
logger.handlers.clear()
else:
def getLogger(self, name: str = 'InvokeAI') -> logging.Logger:
if name not in self.loggers:
logger = logging.getLogger(name)
logger.setLevel(config.log_level.upper()) # yes, strings work here
for ch in cls.getLoggers(config):
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
fmt = InvokeAILogFormatter()
ch.setFormatter(fmt)
logger.addHandler(ch)
cls.loggers[name] = logger
return cls.loggers[name]
@classmethod
def getLoggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
handler_strs = config.log_handlers
handlers = list()
for handler in handler_strs:
handler_name,*args = handler.split('=',2)
args = args[0] if len(args) > 0 else None
# console and file get the fancy formatter.
# syslog gets a simple one
# http gets no custom formatter
formatter = LOG_FORMATTERS[config.log_format]
if handler_name=='console':
ch = logging.StreamHandler()
ch.setFormatter(formatter())
handlers.append(ch)
elif handler_name=='syslog':
ch = cls._parse_syslog_args(args)
ch.setFormatter(InvokeAISyslogFormatter())
handlers.append(ch)
elif handler_name=='file':
ch = cls._parse_file_args(args)
ch.setFormatter(formatter())
handlers.append(ch)
elif handler_name=='http':
handlers.append(cls._parse_http_args(args))
return handlers
@staticmethod
def _parse_syslog_args(
args: str=None
)-> logging.Handler:
if not SYSLOG_AVAILABLE:
raise ValueError("syslog is not available on this system")
if not args:
args='/dev/log' if Path('/dev/log').exists() else 'address:localhost:514'
syslog_args = dict()
try:
for a in args.split(','):
arg_name,*arg_value = a.split(':',2)
if arg_name=='address':
host,*port = arg_value
port = 514 if len(port)==0 else int(port[0])
syslog_args['address'] = (host,port)
elif arg_name=='facility':
syslog_args['facility'] = _FACILITY_MAP[arg_value[0]]
elif arg_name=='socktype':
syslog_args['socktype'] = _SOCK_MAP[arg_value[0]]
else:
syslog_args['address'] = arg_name
except:
raise ValueError(f"{args} is not a value argument list for syslog logging")
return logging.handlers.SysLogHandler(**syslog_args)
@staticmethod
def _parse_file_args(args: str=None)-> logging.Handler:
if not args:
raise ValueError("please provide filename for file logging using format 'file=/path/to/logfile.txt'")
return logging.FileHandler(args)
@staticmethod
def _parse_http_args(args: str=None)-> logging.Handler:
if not args:
raise ValueError("please provide destination for http logging using format 'http=url'")
arg_list = args.split(',')
url = urllib.parse.urlparse(arg_list.pop(0))
if url.scheme != 'http':
raise ValueError(f"the http logging module can only log to HTTP URLs, but {url.scheme} was specified")
host = url.hostname
path = url.path
port = url.port or 80
syslog_args = dict()
for a in arg_list:
arg_name, *arg_value = a.split(':',2)
if arg_name=='method':
arg_value = arg_value[0] if len(arg_value)>0 else 'GET'
syslog_args[arg_name] = arg_value
else: # TODO: Provide support for SSL context and credentials
pass
return logging.handlers.HTTPHandler(f'{host}:{port}',path,**syslog_args)
self.loggers[name] = logger
return self.loggers[name]
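A minimal usage sketch of the simplified logger above, mirroring the call form shown in the module docstring: one logger per name is created, cached, and given a single stream handler with the color formatter.

```python
# Minimal usage mirroring the module docstring: loggers are cached by name,
# each with one StreamHandler and the InvokeAILogFormatter attached.
logger = InvokeAILogger.getLogger(name='InvokeAI')
logger.info('this is an informational message')
logger.debug('this is a debug message')
```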

View File

@@ -322,8 +322,8 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
logger.warning("corrupt existing file found. re-downloading")
os.remove(dest)
exist_size = 0
if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
if resp.status_code == 416 or exist_size == content_length:
logger.warning(f"{dest}: complete file found. Skipping.")
return dest
elif resp.status_code == 206 or exist_size > 0:
@@ -331,7 +331,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
elif resp.status_code != 200:
logger.error(f"An error occurred during downloading {dest}: {resp.reason}")
else:
logger.info(f"{dest}: Downloading...")
logger.error(f"{dest}: Downloading...")
try:
if content_length < 2000:
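The resume logic above follows the usual HTTP Range convention: a 416 response, or an existing file whose size matches the reported content length, means the download is already complete; a 206 response means the server honored the range and the remainder should be appended. Below is a self-contained sketch of the same pattern, written as an illustrative helper rather than the actual `download_with_resume`.

```python
# Self-contained sketch of the HTTP resume pattern used above; illustrative
# helper, not the real download_with_resume implementation.
import requests
from pathlib import Path

def resume_download(url: str, dest: Path) -> Path:
    exist_size = dest.stat().st_size if dest.exists() else 0
    headers = {'Range': f'bytes={exist_size}-'} if exist_size else {}
    resp = requests.get(url, headers=headers, stream=True)
    content_length = int(resp.headers.get('content-length', 0))
    if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
        return dest  # nothing left to fetch
    mode = 'ab' if resp.status_code == 206 else 'wb'  # append only on partial content
    with open(dest, mode) as file:
        for chunk in resp.iter_content(chunk_size=1 << 20):
            file.write(chunk)
    return dest
```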

View File

@@ -1,107 +1,83 @@
# This file predefines a few models that the user may want to install.
diffusers:
stable-diffusion-1.5:
description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
repo_id: runwayml/stable-diffusion-v1-5
format: diffusers
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: True
default: True
sd-inpainting-1.5:
description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
repo_id: runwayml/stable-diffusion-inpainting
format: diffusers
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: True
stable-diffusion-2.1:
description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)
repo_id: stabilityai/stable-diffusion-2-1
format: diffusers
recommended: True
sd-inpainting-2.0:
description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
repo_id: stabilityai/stable-diffusion-2-inpainting
format: diffusers
recommended: False
analog-diffusion-1.0:
description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
repo_id: wavymulder/Analog-Diffusion
format: diffusers
recommended: false
deliberate-1.0:
description: Versatile model that produces detailed images up to 768px (4.27 GB)
format: diffusers
repo_id: XpucT/Deliberate
recommended: False
d&d-diffusion-1.0:
description: Dungeons & Dragons characters (2.13 GB)
format: diffusers
repo_id: 0xJustin/Dungeons-and-Diffusion
recommended: False
dreamlike-photoreal-2.0:
description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)
format: diffusers
repo_id: dreamlike-art/dreamlike-photoreal-2.0
recommended: False
inkpunk-1.0:
description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB)
format: diffusers
repo_id: Envvi/Inkpunk-Diffusion
recommended: False
openjourney-4.0:
description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB)
format: diffusers
repo_id: prompthero/openjourney
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: False
portrait-plus-1.0:
description: An SD-1.5 model trained on close range portraits of people; prompt with "portrait+" (2.13 GB)
format: diffusers
repo_id: wavymulder/portraitplus
recommended: False
seek-art-mega-1.0:
description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)
repo_id: coreco/seek.art_MEGA
format: diffusers
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: False
trinart-2.0:
description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
repo_id: naclbit/trinart_stable_diffusion_v2
format: diffusers
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: False
waifu-diffusion-1.4:
description: An SD-1.5 model trained on 680k anime/manga-style images (2.13 GB)
repo_id: hakurei/waifu-diffusion
format: diffusers
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: False
controlnet:
canny: lllyasviel/control_v11p_sd15_canny
inpaint: lllyasviel/control_v11p_sd15_inpaint
mlsd: lllyasviel/control_v11p_sd15_mlsd
depth: lllyasviel/control_v11f1p_sd15_depth
normal_bae: lllyasviel/control_v11p_sd15_normalbae
seg: lllyasviel/control_v11p_sd15_seg
lineart: lllyasviel/control_v11p_sd15_lineart
lineart_anime: lllyasviel/control_v11p_sd15s2_lineart_anime
scribble: lllyasviel/control_v11p_sd15_scribble
softedge: lllyasviel/control_v11p_sd15_softedge
shuffle: lllyasviel/control_v11e_sd15_shuffle
tile: lllyasviel/control_v11f1e_sd15_tile
ip2p: lllyasviel/control_v11e_sd15_ip2p
textual_inversion:
'EasyNegative': https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
'ahx-beta-453407d': sd-concepts-library/ahx-beta-453407d
lora:
'LowRA': https://civitai.com/api/download/models/63006
'Ink scenery': https://civitai.com/api/download/models/83390
'sd-model-finetuned-lora-t4': sayakpaul/sd-model-finetuned-lora-t4
stable-diffusion-1.5:
description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
repo_id: runwayml/stable-diffusion-v1-5
format: diffusers
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: True
default: True
sd-inpainting-1.5:
description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
repo_id: runwayml/stable-diffusion-inpainting
format: diffusers
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: True
stable-diffusion-2.1:
description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)
repo_id: stabilityai/stable-diffusion-2-1
format: diffusers
recommended: True
sd-inpainting-2.0:
description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
repo_id: stabilityai/stable-diffusion-2-inpainting
format: diffusers
recommended: False
analog-diffusion-1.0:
description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
repo_id: wavymulder/Analog-Diffusion
format: diffusers
recommended: false
deliberate-1.0:
description: Versatile model that produces detailed images up to 768px (4.27 GB)
format: diffusers
repo_id: XpucT/Deliberate
recommended: False
d&d-diffusion-1.0:
description: Dungeons & Dragons characters (2.13 GB)
format: diffusers
repo_id: 0xJustin/Dungeons-and-Diffusion
recommended: False
dreamlike-photoreal-2.0:
description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)
format: diffusers
repo_id: dreamlike-art/dreamlike-photoreal-2.0
recommended: False
inkpunk-1.0:
description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB)
format: diffusers
repo_id: Envvi/Inkpunk-Diffusion
recommended: False
openjourney-4.0:
description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB)
format: diffusers
repo_id: prompthero/openjourney
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: False
portrait-plus-1.0:
description: An SD-1.5 model trained on close range portraits of people; prompt with "portrait+" (2.13 GB)
format: diffusers
repo_id: wavymulder/portraitplus
recommended: False
seek-art-mega-1.0:
description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)
repo_id: coreco/seek.art_MEGA
format: diffusers
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: False
trinart-2.0:
description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
repo_id: naclbit/trinart_stable_diffusion_v2
format: diffusers
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: False
waifu-diffusion-1.4:
description: An SD-1.5 model trained on 680k anime/manga-style images (2.13 GB)
repo_id: hakurei/waifu-diffusion
format: diffusers
vae:
repo_id: stabilityai/sd-vae-ft-mse
recommended: False

1291
invokeai/frontend/CLI/CLI.py Normal file

File diff suppressed because it is too large

View File

@@ -0,0 +1,497 @@
"""
Readline helper functions for invoke.py.
You may import the global singleton `completer` to get access to the
completer object itself. This is useful when you want to autocomplete
seeds:
from invokeai.frontend.CLI.readline import completer
completer.add_seed(18247566)
completer.add_seed(9281839)
"""
import atexit
import os
import re
from ...backend.args import Args
from ...backend.globals import Globals
from ...backend.stable_diffusion import HuggingFaceConceptsLibrary
# ---------------readline utilities---------------------
try:
import readline
readline_available = True
except (ImportError, ModuleNotFoundError) as e:
print(f"** An error occurred when loading the readline module: {str(e)}")
readline_available = False
IMG_EXTENSIONS = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG", ".gif", ".GIF")
WEIGHT_EXTENSIONS = (".ckpt", ".vae", ".safetensors")
TEXT_EXTENSIONS = (".txt", ".TXT")
CONFIG_EXTENSIONS = (".yaml", ".yml")
COMMANDS = (
"--steps",
"-s",
"--seed",
"-S",
"--iterations",
"-n",
"--width",
"-W",
"--height",
"-H",
"--cfg_scale",
"-C",
"--threshold",
"--perlin",
"--grid",
"-g",
"--individual",
"-i",
"--save_intermediates",
"--init_img",
"-I",
"--init_mask",
"-M",
"--init_color",
"--strength",
"-f",
"--variants",
"-v",
"--outdir",
"-o",
"--sampler",
"-A",
"-m",
"--embedding_path",
"--device",
"--grid",
"-g",
"--facetool",
"-ft",
"--facetool_strength",
"-G",
"--codeformer_fidelity",
"-cf",
"--upscale",
"-U",
"-save_orig",
"--save_original",
"--log_tokenization",
"-t",
"--hires_fix",
"--inpaint_replace",
"-r",
"--png_compression",
"-z",
"--text_mask",
"-tm",
"--h_symmetry_time_pct",
"--v_symmetry_time_pct",
"!fix",
"!fetch",
"!replay",
"!history",
"!search",
"!clear",
"!models",
"!switch",
"!import_model",
"!optimize_model",
"!convert_model",
"!edit_model",
"!del_model",
"!mask",
"!triggers",
)
MODEL_COMMANDS = (
"!switch",
"!edit_model",
"!del_model",
)
CKPT_MODEL_COMMANDS = ("!optimize_model",)
WEIGHT_COMMANDS = (
"!import_model",
"!convert_model",
)
IMG_PATH_COMMANDS = ("--outdir[=\s]",)
TEXT_PATH_COMMANDS = ("!replay",)
IMG_FILE_COMMANDS = (
"!fix",
"!fetch",
"!mask",
"--init_img[=\s]",
"-I",
"--init_mask[=\s]",
"-M",
"--init_color[=\s]",
"--embedding_path[=\s]",
)
path_regexp = "(" + "|".join(IMG_PATH_COMMANDS + IMG_FILE_COMMANDS) + ")\s*\S*$"
weight_regexp = "(" + "|".join(WEIGHT_COMMANDS) + ")\s*\S*$"
text_regexp = "(" + "|".join(TEXT_PATH_COMMANDS) + ")\s*\S*$"
class Completer(object):
def __init__(self, options, models={}):
self.options = sorted(options)
self.models = models
self.seeds = set()
self.matches = list()
self.default_dir = None
self.linebuffer = None
self.auto_history_active = True
self.extensions = None
self.concepts = None
self.embedding_terms = set()
return
def complete(self, text, state):
"""
Completes invoke command line.
BUG: it doesn't correctly complete files that have spaces in the name.
"""
buffer = readline.get_line_buffer()
if state == 0:
# extensions defined, so go directly into path completion mode
if self.extensions is not None:
self.matches = self._path_completions(text, state, self.extensions)
# looking for an image file
elif re.search(path_regexp, buffer):
do_shortcut = re.search("^" + "|".join(IMG_FILE_COMMANDS), buffer)
self.matches = self._path_completions(
text, state, IMG_EXTENSIONS, shortcut_ok=do_shortcut
)
# looking for a seed
elif re.search("(-S\s*|--seed[=\s])\d*$", buffer):
self.matches = self._seed_completions(text, state)
# looking for an embedding concept
elif re.search("<[\w-]*$", buffer):
self.matches = self._concept_completions(text, state)
# looking for a model
elif re.match("^" + "|".join(MODEL_COMMANDS), buffer):
self.matches = self._model_completions(text, state)
# looking for a ckpt model
elif re.match("^" + "|".join(CKPT_MODEL_COMMANDS), buffer):
self.matches = self._model_completions(text, state, ckpt_only=True)
elif re.search(weight_regexp, buffer):
self.matches = self._path_completions(
text,
state,
WEIGHT_EXTENSIONS,
default_dir=Globals.root,
)
elif re.search(text_regexp, buffer):
self.matches = self._path_completions(text, state, TEXT_EXTENSIONS)
# This is the first time for this text, so build a match list.
elif text:
self.matches = [s for s in self.options if s and s.startswith(text)]
else:
self.matches = self.options[:]
# Return the state'th item from the match list,
# if we have that many.
try:
response = self.matches[state]
except IndexError:
response = None
return response
def complete_extensions(self, extensions: list):
"""
If called with a list of extensions, will force completer
to do file path completions.
"""
self.extensions = extensions
def add_history(self, line):
"""
Pass thru to readline
"""
if not self.auto_history_active:
readline.add_history(line)
def clear_history(self):
"""
Pass clear_history() thru to readline
"""
readline.clear_history()
def search_history(self, match: str):
"""
Like show_history() but only shows items that
contain the match string.
"""
self.show_history(match)
def remove_history_item(self, pos):
readline.remove_history_item(pos)
def add_seed(self, seed):
"""
Add a seed to the autocomplete list for display when -S is autocompleted.
"""
if seed is not None:
self.seeds.add(str(seed))
def set_default_dir(self, path):
self.default_dir = path
def set_options(self, options):
self.options = options
def get_line(self, index):
try:
line = self.get_history_item(index)
except IndexError:
return None
return line
def get_current_history_length(self):
return readline.get_current_history_length()
def get_history_item(self, index):
return readline.get_history_item(index)
def show_history(self, match=None):
"""
Print the session history using the pydoc pager
"""
import pydoc
lines = list()
h_len = self.get_current_history_length()
if h_len < 1:
print("<empty history>")
return
for i in range(0, h_len):
line = self.get_history_item(i + 1)
if match and match not in line:
continue
lines.append(f"[{i+1}] {line}")
pydoc.pager("\n".join(lines))
def set_line(self, line) -> None:
"""
Set the default string displayed in the next line of input.
"""
self.linebuffer = line
readline.redisplay()
def update_models(self, models: dict) -> None:
"""
update our list of models
"""
self.models = models
def _seed_completions(self, text, state):
m = re.search("(-S\s?|--seed[=\s]?)(\d*)", text)
if m:
switch = m.groups()[0]
partial = m.groups()[1]
else:
switch = ""
partial = text
matches = list()
for s in self.seeds:
if s.startswith(partial):
matches.append(switch + s)
matches.sort()
return matches
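    # Illustrative behaviour, assuming add_seed(42) and add_seed(451)
    # were called earlier in the session:
    #
    #   self._seed_completions("-S4", 0)  # -> ["-S42", "-S451"]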
def add_embedding_terms(self, terms: list[str]):
self.embedding_terms = set(terms)
if self.concepts:
self.embedding_terms.update(set(self.concepts.list_concepts()))
    def _concept_completions(self, text, state):
        if self.concepts is None:
            # cache the HuggingFaceConceptsLibrary() instance so that updates
            # to the concept list can be picked up at runtime
            self.concepts = HuggingFaceConceptsLibrary()
        self.embedding_terms.update(set(self.concepts.list_concepts()))
partial = text[1:] # this removes the leading '<'
if len(partial) == 0:
            return list(self.embedding_terms)  # nothing typed yet, so offer the complete list
matches = list()
for concept in self.embedding_terms:
if concept.startswith(partial):
matches.append(f"<{concept}>")
matches.sort()
return matches
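    # Illustrative behaviour with hypothetical concept names: if the
    # library lists "cartoonish" and "cartography", then
    # _concept_completions("<cart", 0) yields
    # ["<cartography>", "<cartoonish>"], sorted and re-wrapped in <...>.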
def _model_completions(self, text, state, ckpt_only=False):
m = re.search("(!switch\s+)(\w*)", text)
if m:
switch = m.groups()[0]
partial = m.groups()[1]
else:
switch = ""
partial = text
matches = list()
for s in self.models:
            fmt = self.models[s]["format"]  # avoid shadowing the format() builtin
            if fmt == "vae":
                continue
            if ckpt_only and fmt != "ckpt":
                continue
if s.startswith(partial):
matches.append(switch + s)
matches.sort()
return matches
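    # Illustrative behaviour with hypothetical model names: given
    # models = {"sd-1.5": {"format": "diffusers"}, "sd-1.5.ckpt": {"format": "ckpt"}},
    # _model_completions("!switch sd-1", 0) returns
    # ["!switch sd-1.5", "!switch sd-1.5.ckpt"]; with ckpt_only=True the
    # diffusers entry is dropped.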
def _pre_input_hook(self):
if self.linebuffer:
readline.insert_text(self.linebuffer)
readline.redisplay()
self.linebuffer = None
def _path_completions(
self, text, state, extensions, shortcut_ok=True, default_dir: str = ""
):
# separate the switch from the partial path
match = re.search("^(-\w|--\w+=?)(.*)", text)
if match is None:
switch = None
partial_path = text
else:
switch, partial_path = match.groups()
partial_path = partial_path.lstrip()
        matches = list()
        extensions = tuple(extensions)  # str.endswith() accepts a tuple, not a list
path = os.path.expanduser(partial_path)
if os.path.isdir(path):
dir = path
elif os.path.dirname(path) != "":
dir = os.path.dirname(path)
else:
dir = default_dir if os.path.exists(default_dir) else ""
path = os.path.join(dir, path)
dir_list = os.listdir(dir or ".")
        if shortcut_ok and self.default_dir and os.path.exists(self.default_dir) and dir == "":
dir_list += os.listdir(self.default_dir)
for node in dir_list:
if node.startswith(".") and len(node) > 1:
continue
full_path = os.path.join(dir, node)
if not (node.endswith(extensions) or os.path.isdir(full_path)):
continue
if path and not full_path.startswith(path):
continue
if switch is None:
match_path = os.path.join(dir, node)
matches.append(
match_path + "/" if os.path.isdir(full_path) else match_path
)
elif os.path.isdir(full_path):
matches.append(
switch + os.path.join(os.path.dirname(full_path), node) + "/"
)
elif node.endswith(extensions):
matches.append(switch + os.path.join(os.path.dirname(full_path), node))
return matches
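# Illustrative path completion, assuming the completer is called directly
# and the working directory contains "photo.png" and a subdirectory "outputs/":
#
#   completer._path_completions("-I pho", 0, IMG_EXTENSIONS)
#   # -> ["-Iphoto.png"]  (the switch is re-attached without a space,
#   #     because readline treats the space as a word delimiter)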
class DummyCompleter(Completer):
def __init__(self, options):
super().__init__(options)
self.history = list()
def add_history(self, line):
self.history.append(line)
def clear_history(self):
self.history = list()
def get_current_history_length(self):
return len(self.history)
def get_history_item(self, index):
return self.history[index - 1]
def remove_history_item(self, index):
return self.history.pop(index - 1)
def set_line(self, line):
print(f"# {line}")
def generic_completer(commands: list) -> Completer:
if readline_available:
completer = Completer(commands, [])
readline.set_completer(completer.complete)
readline.set_pre_input_hook(completer._pre_input_hook)
readline.set_completer_delims(" ")
readline.parse_and_bind("tab: complete")
readline.parse_and_bind("set print-completions-horizontally off")
readline.parse_and_bind("set page-completions on")
readline.parse_and_bind("set skip-completed-text on")
readline.parse_and_bind("set show-all-if-ambiguous on")
else:
completer = DummyCompleter(commands)
return completer
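# A minimal usage sketch (the command list is illustrative):
#
#   completer = generic_completer(["!help", "!quit"])
#   line = input("> ")  # <TAB> now completes against !help / !quit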
def get_completer(opt: Args, models=[]) -> Completer:
if readline_available:
completer = Completer(COMMANDS, models)
readline.set_completer(completer.complete)
# pyreadline3 does not have a set_auto_history() method
try:
readline.set_auto_history(False)
completer.auto_history_active = False
        except AttributeError:
completer.auto_history_active = True
readline.set_pre_input_hook(completer._pre_input_hook)
readline.set_completer_delims(" ")
readline.parse_and_bind("tab: complete")
readline.parse_and_bind("set print-completions-horizontally off")
readline.parse_and_bind("set page-completions on")
readline.parse_and_bind("set skip-completed-text on")
readline.parse_and_bind("set show-all-if-ambiguous on")
outdir = os.path.expanduser(opt.outdir)
if os.path.isabs(outdir):
histfile = os.path.join(outdir, ".invoke_history")
else:
histfile = os.path.join(Globals.root, outdir, ".invoke_history")
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
except OSError: # file likely corrupted
newname = f"{histfile}.old"
print(
f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
)
os.replace(histfile, newname)
atexit.register(readline.write_history_file, histfile)
else:
completer = DummyCompleter(COMMANDS)
return completer
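# A hedged wiring sketch for CLI startup; the model dict shape is assumed
# here, not taken from this file:
#
#   opt = Args().parse_args()
#   models = {"sd-1.5": {"format": "diffusers"}}
#   completer = get_completer(opt, models=models)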

View File

@@ -0,0 +1,30 @@
'''
This is a modularized version of the sd-metadata.py script,
which retrieves and prints the metadata from a series of generated png files.
'''
import sys
import json
from invokeai.backend.image_util import retrieve_metadata
def print_metadata():
if len(sys.argv) < 2:
print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...")
print("This script opens up the indicated invoke.py-generated PNG file(s) and prints out their metadata.")
exit(-1)
filenames = sys.argv[1:]
for f in filenames:
try:
metadata = retrieve_metadata(f)
print(f'{f}:\n',json.dumps(metadata['sd-metadata'], indent=4))
except FileNotFoundError:
sys.stderr.write(f'{f} not found\n')
continue
except PermissionError:
sys.stderr.write(f'{f} could not be opened due to inadequate permissions\n')
continue
if __name__ == '__main__':
print_metadata()
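# Typical invocation, assuming this module is saved as sd_metadata.py
# (the image paths are illustrative):
#
#   python sd_metadata.py outputs/000001.png outputs/000002.png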

View File

@@ -1,4 +1,4 @@
"""
Wrapper for invokeai.backend.configure.invokeai_configure
"""
from ...backend.install.invokeai_configure import main
from ...backend.config.invokeai_configure import main

View File

@@ -4,14 +4,14 @@ pip install <path_to_git_source>.
'''
import os
import platform
import pkg_resources
import psutil
import requests
from rich import box, print
from rich.console import Console, group
from rich.console import Console, Group, group
from rich.panel import Panel
from rich.prompt import Prompt
from rich.style import Style
from rich.syntax import Syntax
from rich.text import Text
from invokeai.version import __version__
@@ -32,18 +32,6 @@ else:
def get_versions()->dict:
return requests.get(url=INVOKE_AI_REL).json()
def invokeai_is_running()->bool:
for p in psutil.process_iter():
try:
cmdline = p.cmdline()
matches = [x for x in cmdline if x.endswith(('invokeai','invokeai.exe'))]
if matches:
print(f':exclamation: [bold red]An InvokeAI instance appears to be running as process {p.pid}[/red bold]')
return True
except (psutil.AccessDenied,psutil.NoSuchProcess):
continue
return False
def welcome(versions: dict):
@group()
@@ -72,22 +60,8 @@ def welcome(versions: dict):
)
console.line()
def get_extras():
extras = ''
try:
dist = pkg_resources.get_distribution('xformers')
extras = '[xformers]'
except pkg_resources.DistributionNotFound:
pass
return extras
def main():
versions = get_versions()
if invokeai_is_running():
print(f':exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]')
input('Press any key to continue...')
return
welcome(versions)
tag = None
@@ -104,15 +78,13 @@ def main():
elif choice=='4':
branch = Prompt.ask('Enter an InvokeAI branch name')
extras = get_extras()
print(f':crossed_fingers: Upgrading to [yellow]{tag if tag else release}[/yellow]')
if release:
cmd = f"pip install 'invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip' --use-pep517 --upgrade"
cmd = f'pip install {INVOKE_AI_SRC}/{release}.zip --use-pep517 --upgrade'
elif tag:
cmd = f"pip install 'invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip' --use-pep517 --upgrade"
cmd = f'pip install {INVOKE_AI_TAG}/{tag}.zip --use-pep517 --upgrade'
else:
cmd = f"pip install 'invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip' --use-pep517 --upgrade"
cmd = f'pip install {INVOKE_AI_BRANCH}/{branch}.zip --use-pep517 --upgrade'
print('')
print('')
if os.system(cmd)==0:

File diff suppressed because it is too large

View File

@@ -5,75 +5,35 @@ import curses
import math
import os
import platform
import pyperclip
import struct
import subprocess
import sys
import npyscreen
import textwrap
import npyscreen.wgmultiline as wgmultiline
from npyscreen import fmPopup
from shutil import get_terminal_size
from curses import BUTTON2_CLICKED,BUTTON3_CLICKED
# minimum size for UIs
MIN_COLS = 120
MIN_LINES = 50
import npyscreen
# -------------------------------------
def set_terminal_size(columns: int, lines: int, launch_command: str=None):
ts = get_terminal_size()
width = max(columns,ts.columns)
height = max(lines,ts.lines)
def set_terminal_size(columns: int, lines: int):
OS = platform.uname().system
if OS == "Windows":
# The new Windows Terminal doesn't resize, so we relaunch in a CMD window.
# Would prefer to use execvpe() here, but somehow it is not working properly
# in the Windows 10 environment.
if 'IA_RELAUNCHED' not in os.environ:
args=['conhost']
args.extend([launch_command] if launch_command else [sys.argv[0]])
args.extend(sys.argv[1:])
os.environ['IA_RELAUNCHED'] = 'True'
os.execvp('conhost',args)
else:
_set_terminal_size_powershell(width,height)
os.system(f"mode con: cols={columns} lines={lines}")
elif OS in ["Darwin", "Linux"]:
_set_terminal_size_unix(width,height)
import fcntl
import termios
def _set_terminal_size_powershell(width: int, height: int):
script=f'''
$pshost = get-host
$pswindow = $pshost.ui.rawui
$newsize = $pswindow.buffersize
$newsize.height = 3000
$newsize.width = {width}
$pswindow.buffersize = $newsize
$newsize = $pswindow.windowsize
$newsize.height = {height}
$newsize.width = {width}
$pswindow.windowsize = $newsize
'''
subprocess.run(["powershell","-Command","-"],input=script,text=True)
winsize = struct.pack("HHHH", lines, columns, 0, 0)
fcntl.ioctl(sys.stdout.fileno(), termios.TIOCSWINSZ, winsize)
sys.stdout.write("\x1b[8;{rows};{cols}t".format(rows=lines, cols=columns))
sys.stdout.flush()
def _set_terminal_size_unix(width: int, height: int):
import fcntl
import termios
winsize = struct.pack("HHHH", height, width, 0, 0)
fcntl.ioctl(sys.stdout.fileno(), termios.TIOCSWINSZ, winsize)
sys.stdout.write("\x1b[8;{height};{width}t".format(height=height, width=width))
sys.stdout.flush()
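# For example, struct.pack("HHHH", 50, 120, 0, 0) encodes rows=50 and
# cols=120 into the winsize struct consumed by the TIOCSWINSZ ioctl, while
# the "\x1b[8;{height};{width}t" escape asks the terminal emulator itself
# to resize its window.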
def set_min_terminal_size(min_cols: int, min_lines: int, launch_command: str=None):
def set_min_terminal_size(min_cols: int, min_lines: int):
# make sure there's enough room for the ui
term_cols, term_lines = get_terminal_size()
if term_cols >= min_cols and term_lines >= min_lines:
return
cols = max(term_cols, min_cols)
lines = max(term_lines, min_lines)
set_terminal_size(cols, lines, launch_command)
set_terminal_size(cols, lines)
class IntSlider(npyscreen.Slider):
def translate_value(self):
@@ -82,23 +42,7 @@ class IntSlider(npyscreen.Slider):
stri = stri.rjust(l)
return stri
# -------------------------------------
# fix npyscreen form so that cursor wraps both forward and backward
class CyclingForm(object):
def find_previous_editable(self, *args):
done = False
n = self.editw-1
while not done:
if self._widgets__[n].editable and not self._widgets__[n].hidden:
self.editw = n
done = True
n -= 1
if n<0:
if self.cycle_widgets:
n = len(self._widgets__)-1
else:
done = True
# -------------------------------------
class CenteredTitleText(npyscreen.TitleText):
def __init__(self, *args, **keywords):
@@ -149,7 +93,14 @@ class FloatSlider(npyscreen.Slider):
class FloatTitleSlider(npyscreen.TitleText):
_entry_type = FloatSlider
class SelectColumnBase():
class MultiSelectColumns(npyscreen.MultiSelect):
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
self.columns = columns
self.value_cnt = len(values)
self.rows = math.ceil(self.value_cnt / self.columns)
super().__init__(screen, values=values, **keywords)
def make_contained_widgets(self):
self._my_widgets = []
column_width = self.width // self.columns
@@ -199,242 +150,53 @@ class SelectColumnBase():
def h_cursor_line_right(self, ch):
super().h_cursor_line_down(ch)
def handle_mouse_event(self, mouse_event):
mouse_id, rel_x, rel_y, z, bstate = self.interpret_mouse_event(mouse_event)
column_width = self.width // self.columns
column_height = math.ceil(self.value_cnt / self.columns)
column_no = rel_x // column_width
row_no = rel_y // self._contained_widget_height
self.cursor_line = column_no * column_height + row_no
if bstate & curses.BUTTON1_DOUBLE_CLICKED:
if hasattr(self,'on_mouse_double_click'):
self.on_mouse_double_click(self.cursor_line)
self.display()
class MultiSelectColumns( SelectColumnBase, npyscreen.MultiSelect):
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
self.columns = columns
self.value_cnt = len(values)
self.rows = math.ceil(self.value_cnt / self.columns)
super().__init__(screen, values=values, **keywords)
class TextBox(npyscreen.MultiLineEdit):
def update(self, clear=True):
if clear:
self.clear()
def on_mouse_double_click(self, cursor_line):
self.h_select_toggle(cursor_line)
HEIGHT = self.height
WIDTH = self.width
# draw box.
self.parent.curses_pad.hline(self.rely, self.relx, curses.ACS_HLINE, WIDTH)
self.parent.curses_pad.hline(
self.rely + HEIGHT, self.relx, curses.ACS_HLINE, WIDTH
)
self.parent.curses_pad.vline(
self.rely, self.relx, curses.ACS_VLINE, self.height
)
self.parent.curses_pad.vline(
self.rely, self.relx + WIDTH, curses.ACS_VLINE, HEIGHT
)
class SingleSelectWithChanged(npyscreen.SelectOne):
def __init__(self,*args,**kwargs):
super().__init__(*args,**kwargs)
# draw corners
self.parent.curses_pad.addch(
self.rely,
self.relx,
curses.ACS_ULCORNER,
)
self.parent.curses_pad.addch(
self.rely,
self.relx + WIDTH,
curses.ACS_URCORNER,
)
self.parent.curses_pad.addch(
self.rely + HEIGHT,
self.relx,
curses.ACS_LLCORNER,
)
self.parent.curses_pad.addch(
self.rely + HEIGHT,
self.relx + WIDTH,
curses.ACS_LRCORNER,
)
def h_select(self,ch):
super().h_select(ch)
if self.on_changed:
self.on_changed(self.value)
class SingleSelectColumns(SelectColumnBase, SingleSelectWithChanged):
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
self.columns = columns
self.value_cnt = len(values)
self.rows = math.ceil(self.value_cnt / self.columns)
self.on_changed = None
super().__init__(screen, values=values, **keywords)
def when_value_edited(self):
self.h_select(self.cursor_line)
def when_cursor_moved(self):
self.h_select(self.cursor_line)
def h_cursor_line_right(self,ch):
self.h_exit_down('bye bye')
class TextBoxInner(npyscreen.MultiLineEdit):
def __init__(self,*args,**kwargs):
super().__init__(*args,**kwargs)
self.yank = None
self.handlers.update({
"^A": self.h_cursor_to_start,
"^E": self.h_cursor_to_end,
"^K": self.h_kill,
"^F": self.h_cursor_right,
"^B": self.h_cursor_left,
"^Y": self.h_yank,
"^V": self.h_paste,
})
def h_cursor_to_start(self, input):
self.cursor_position = 0
def h_cursor_to_end(self, input):
self.cursor_position = len(self.value)
def h_kill(self, input):
self.yank = self.value[self.cursor_position:]
self.value = self.value[:self.cursor_position]
def h_yank(self, input):
if self.yank:
self.paste(self.yank)
def paste(self, text: str):
self.value = self.value[:self.cursor_position] + text + self.value[self.cursor_position:]
self.cursor_position += len(text)
def h_paste(self, input: int=0):
try:
text = pyperclip.paste()
        except pyperclip.PyperclipException:
            # raised when no system copy/paste mechanism (e.g. xclip on Linux) is available
text = "To paste with the mouse on Linux, please install the 'xclip' program."
self.paste(text)
def handle_mouse_event(self, mouse_event):
mouse_id, rel_x, rel_y, z, bstate = self.interpret_mouse_event(mouse_event)
if bstate & (BUTTON2_CLICKED|BUTTON3_CLICKED):
self.h_paste()
# def update(self, clear=True):
# if clear:
# self.clear()
# HEIGHT = self.height
# WIDTH = self.width
# # draw box.
# self.parent.curses_pad.hline(self.rely, self.relx, curses.ACS_HLINE, WIDTH)
# self.parent.curses_pad.hline(
# self.rely + HEIGHT, self.relx, curses.ACS_HLINE, WIDTH
# )
# self.parent.curses_pad.vline(
# self.rely, self.relx, curses.ACS_VLINE, self.height
# )
# self.parent.curses_pad.vline(
# self.rely, self.relx + WIDTH, curses.ACS_VLINE, HEIGHT
# )
# # draw corners
# self.parent.curses_pad.addch(
# self.rely,
# self.relx,
# curses.ACS_ULCORNER,
# )
# self.parent.curses_pad.addch(
# self.rely,
# self.relx + WIDTH,
# curses.ACS_URCORNER,
# )
# self.parent.curses_pad.addch(
# self.rely + HEIGHT,
# self.relx,
# curses.ACS_LLCORNER,
# )
# self.parent.curses_pad.addch(
# self.rely + HEIGHT,
# self.relx + WIDTH,
# curses.ACS_LRCORNER,
# )
# # fool our superclass into thinking drawing area is smaller - this is really hacky but it seems to work
# (relx, rely, height, width) = (self.relx, self.rely, self.height, self.width)
# self.relx += 1
# self.rely += 1
# self.height -= 1
# self.width -= 1
# super().update(clear=False)
# (self.relx, self.rely, self.height, self.width) = (relx, rely, height, width)
class TextBox(npyscreen.BoxTitle):
_contained_widget = TextBoxInner
class BufferBox(npyscreen.BoxTitle):
_contained_widget = npyscreen.BufferPager
class ConfirmCancelPopup(fmPopup.ActionPopup):
DEFAULT_COLUMNS = 100
def on_ok(self):
self.value = True
def on_cancel(self):
self.value = False
class FileBox(npyscreen.BoxTitle):
_contained_widget = npyscreen.Filename
class PrettyTextBox(npyscreen.BoxTitle):
_contained_widget = TextBox
def _wrap_message_lines(message, line_length):
lines = []
for line in message.split('\n'):
lines.extend(textwrap.wrap(line.rstrip(), line_length))
return lines
def _prepare_message(message):
    if isinstance(message, (list, tuple)):
        return "\n".join([s.rstrip() for s in message])
#return "\n".join(message)
else:
return message
def select_stable_diffusion_config_file(
form_color: str='DANGER',
wrap:bool =True,
model_name:str='Unknown',
):
message = "Please select the correct base model for the V2 checkpoint named {model_name}. Press <CANCEL> to skip installation."
title = "CONFIG FILE SELECTION"
options=[
"An SD v2.x base model (512 pixels; no 'parameterization:' line in its yaml file)",
"An SD v2.x v-predictive model (768 pixels; 'parameterization: \"v\"' line in its yaml file)",
"Skip installation for now and come back later",
"Enter config file path manually",
]
F = ConfirmCancelPopup(
name=title,
color=form_color,
cycle_widgets=True,
lines=16,
)
F.preserve_selected_widget = True
mlw = F.add(
wgmultiline.Pager,
max_height=4,
editable=False,
)
mlw_width = mlw.width-1
if wrap:
message = _wrap_message_lines(message, mlw_width)
mlw.values = message
choice = F.add(
SingleSelectWithChanged,
values = options,
value = [0],
max_height = len(options)+1,
scroll_exit=True,
)
file = F.add(
FileBox,
name='Path to config file',
max_height=3,
hidden=True,
must_exist=True,
scroll_exit=True
)
def toggle_visible(value):
value = value[0]
if value==3:
file.hidden=False
else:
file.hidden=True
F.display()
choice.on_changed = toggle_visible
F.editw = 1
F.edit()
if not F.value:
return None
    assert choice.value[0] in range(0, 4), "invalid choice"
choices = ['epsilon','v','abort',file.value]
return choices[choice.value[0]]
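# Illustrative use: the popup returns 'epsilon' or 'v' for the two SD v2
# parameterizations, 'abort' to skip installation, a manually entered
# config-file path, or None if the user cancels:
#
#   choice = select_stable_diffusion_config_file(model_name='my-v2-model')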
# fool our superclass into thinking drawing area is smaller - this is really hacky but it seems to work
(relx, rely, height, width) = (self.relx, self.rely, self.height, self.width)
self.relx += 1
self.rely += 1
self.height -= 1
self.width -= 1
super().update(clear=False)
(self.relx, self.rely, self.height, self.width) = (relx, rely, height, width)

Some files were not shown because too many files have changed in this diff