mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 11:48:17 -05:00
Compare commits
21 Commits
v5.10.0dev
...
v5.10.0dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d06cc71cd9 | ||
|
|
ecbc1cf85d | ||
|
|
dcad306aee | ||
|
|
bea9d037bc | ||
|
|
eed4260975 | ||
|
|
6a79b1c64c | ||
|
|
4cdfe3e30d | ||
|
|
df294db236 | ||
|
|
52e247cfe0 | ||
|
|
0a13640bf3 | ||
|
|
643b71f56c | ||
|
|
b745411866 | ||
|
|
c3ffb0feed | ||
|
|
f53ff5fa3c | ||
|
|
b2337b56bd | ||
|
|
fb777b4502 | ||
|
|
4c12f5a011 | ||
|
|
d4655ea21a | ||
|
|
752b62d0b5 | ||
|
|
9a0efb308d | ||
|
|
b4c276b50f |
8
.github/CODEOWNERS
vendored
8
.github/CODEOWNERS
vendored
@@ -2,11 +2,11 @@
|
||||
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr @jazzhaiku
|
||||
|
||||
# documentation
|
||||
/docs/ @lstein @blessedcoolant @hipsterusername @psychedelicious
|
||||
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @psychedelicious
|
||||
/docs/ @lstein @blessedcoolant @hipsterusername @Millu
|
||||
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @Millu
|
||||
|
||||
# nodes
|
||||
/invokeai/app/ @blessedcoolant @psychedelicious @brandonrising @hipsterusername @jazzhaiku
|
||||
/invokeai/app/ @Kyle0654 @blessedcoolant @psychedelicious @brandonrising @hipsterusername @jazzhaiku
|
||||
|
||||
# installation and configuration
|
||||
/pyproject.toml @lstein @blessedcoolant @hipsterusername
|
||||
@@ -22,7 +22,7 @@
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
|
||||
|
||||
# generation, model management, postprocessing
|
||||
/invokeai/backend @lstein @blessedcoolant @brandonrising @hipsterusername @jazzhaiku
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising @ryanjdick @hipsterusername @jazzhaiku
|
||||
|
||||
# front ends
|
||||
/invokeai/frontend/CLI @lstein @hipsterusername
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Builds and uploads python build artifacts.
|
||||
# Builds and uploads the installer and python build artifacts.
|
||||
|
||||
name: build wheel
|
||||
name: build installer
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
@@ -27,12 +27,19 @@ jobs:
|
||||
- name: setup frontend
|
||||
uses: ./.github/actions/install-frontend-deps
|
||||
|
||||
- name: build wheel
|
||||
id: build_wheel
|
||||
run: ./scripts/build_wheel.sh
|
||||
- name: create installer
|
||||
id: create_installer
|
||||
run: ./create_installer.sh
|
||||
working-directory: installer
|
||||
|
||||
- name: upload python distribution artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: ${{ steps.build_wheel.outputs.DIST_PATH }}
|
||||
path: ${{ steps.create_installer.outputs.DIST_PATH }}
|
||||
|
||||
- name: upload installer artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: installer
|
||||
path: ${{ steps.create_installer.outputs.INSTALLER_PATH }}
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
always_run: true
|
||||
|
||||
build:
|
||||
uses: ./.github/workflows/build-wheel.yml
|
||||
uses: ./.github/workflows/build-installer.yml
|
||||
|
||||
publish-testpypi:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
10
Makefile
10
Makefile
@@ -16,7 +16,7 @@ help:
|
||||
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
||||
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
||||
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
|
||||
@echo "wheel Build the wheel for the current version"
|
||||
@echo "installer-zip Build the installer .zip file for the current version"
|
||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
|
||||
@echo "docs Serve the mkdocs site with live reload"
|
||||
@@ -64,13 +64,13 @@ frontend-dev:
|
||||
frontend-typegen:
|
||||
cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen
|
||||
|
||||
# Tag the release
|
||||
wheel:
|
||||
cd scripts && ./build_wheel.sh
|
||||
# Installer zip file
|
||||
installer-zip:
|
||||
cd installer && ./create_installer.sh
|
||||
|
||||
# Tag the release
|
||||
tag-release:
|
||||
cd scripts && ./tag_release.sh
|
||||
cd installer && ./tag_release.sh
|
||||
|
||||
# Generate the OpenAPI Schema for the app
|
||||
openapi:
|
||||
|
||||
@@ -64,7 +64,7 @@ The following commands vary depending on the version of Invoke being installed a
|
||||
|
||||
5. Choose a version to install. Review the [GitHub releases page](https://github.com/invoke-ai/InvokeAI/releases).
|
||||
|
||||
6. Determine the package specifier to use when installing. This is a performance optimization.
|
||||
6. Determine the package package specifier to use when installing. This is a performance optimization.
|
||||
|
||||
- If you have an Nvidia 20xx series GPU or older, use `invokeai[xformers]`.
|
||||
- If you have an Nvidia 30xx series GPU or newer, or do not have an Nvidia GPU, use `invokeai`.
|
||||
|
||||
BIN
installer/WinLongPathsEnabled.reg
Normal file
BIN
installer/WinLongPathsEnabled.reg
Normal file
Binary file not shown.
@@ -32,6 +32,12 @@ if [[ ! -z ${CI} ]]; then
|
||||
echo
|
||||
echo -e "${BCYAN}CI environment detected${RESET}"
|
||||
echo
|
||||
else
|
||||
echo
|
||||
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
|
||||
echo "The current working directory is $(pwd)"
|
||||
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
|
||||
echo
|
||||
fi
|
||||
|
||||
echo -e "${BGREEN}HEAD${RESET}:"
|
||||
@@ -71,8 +77,42 @@ fi
|
||||
|
||||
rm -rf ../build
|
||||
|
||||
python3 -m build --outdir ../dist/ ../.
|
||||
python3 -m build --outdir dist/ ../.
|
||||
|
||||
# ----------------------
|
||||
|
||||
echo
|
||||
echo "Building installer zip files for InvokeAI ${VERSION}..."
|
||||
echo
|
||||
|
||||
# get rid of any old ones
|
||||
rm -f *.zip
|
||||
rm -rf InvokeAI-Installer
|
||||
|
||||
# copy content
|
||||
mkdir InvokeAI-Installer
|
||||
for f in templates *.txt *.reg; do
|
||||
cp -r ${f} InvokeAI-Installer/
|
||||
done
|
||||
mkdir InvokeAI-Installer/lib
|
||||
cp lib/*.py InvokeAI-Installer/lib
|
||||
|
||||
# Install scripts
|
||||
# Mac/Linux
|
||||
cp install.sh.in InvokeAI-Installer/install.sh
|
||||
chmod a+x InvokeAI-Installer/install.sh
|
||||
|
||||
# Windows
|
||||
cp install.bat.in InvokeAI-Installer/install.bat
|
||||
cp WinLongPathsEnabled.reg InvokeAI-Installer/
|
||||
|
||||
FILENAME=InvokeAI-installer-$VERSION.zip
|
||||
|
||||
# Zip everything up
|
||||
zip -r ${FILENAME} InvokeAI-Installer
|
||||
|
||||
echo
|
||||
echo -e "${BGREEN}Built installer: ./${FILENAME}${RESET}"
|
||||
echo -e "${BGREEN}Built PyPi distribution: ./dist${RESET}"
|
||||
|
||||
# clean up, but only if we are not in a github action
|
||||
@@ -85,7 +125,9 @@ fi
|
||||
if [[ ! -z ${CI} ]]; then
|
||||
echo
|
||||
echo "Setting GitHub action outputs..."
|
||||
echo "DIST_PATH=./dist/" >>$GITHUB_OUTPUT
|
||||
echo "INSTALLER_FILENAME=${FILENAME}" >>$GITHUB_OUTPUT
|
||||
echo "INSTALLER_PATH=installer/${FILENAME}" >>$GITHUB_OUTPUT
|
||||
echo "DIST_PATH=installer/dist/" >>$GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
exit 0
|
||||
128
installer/install.bat.in
Normal file
128
installer/install.bat.in
Normal file
@@ -0,0 +1,128 @@
|
||||
@echo off
|
||||
setlocal EnableExtensions EnableDelayedExpansion
|
||||
|
||||
@rem This script requires the user to install Python 3.10 or higher. All other
|
||||
@rem requirements are downloaded as needed.
|
||||
|
||||
@rem change to the script's directory
|
||||
PUSHD "%~dp0"
|
||||
|
||||
set "no_cache_dir=--no-cache-dir"
|
||||
if "%1" == "use-cache" (
|
||||
set "no_cache_dir="
|
||||
)
|
||||
|
||||
@rem Config
|
||||
@rem The version in the next line is replaced by an up to date release number
|
||||
@rem when create_installer.sh is run. Change the release number there.
|
||||
set INSTRUCTIONS=https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/
|
||||
set TROUBLESHOOTING=https://invoke-ai.github.io/InvokeAI/help/FAQ/
|
||||
set PYTHON_URL=https://www.python.org/downloads/windows/
|
||||
set MINIMUM_PYTHON_VERSION=3.10.0
|
||||
set PYTHON_URL=https://www.python.org/downloads/release/python-3109/
|
||||
|
||||
set err_msg=An error has occurred and the script could not continue.
|
||||
|
||||
@rem --------------------------- Intro -------------------------------
|
||||
echo This script will install InvokeAI and its dependencies.
|
||||
echo.
|
||||
echo BEFORE YOU START PLEASE MAKE SURE TO DO THE FOLLOWING
|
||||
echo 1. Install python 3.10 or 3.11. Python version 3.9 is no longer supported.
|
||||
echo 2. Double-click on the file WinLongPathsEnabled.reg in order to
|
||||
echo enable long path support on your system.
|
||||
echo 3. Install the Visual C++ core libraries.
|
||||
echo Please download and install the libraries from:
|
||||
echo https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170
|
||||
echo.
|
||||
echo See %INSTRUCTIONS% for more details.
|
||||
echo.
|
||||
echo FOR THE BEST USER EXPERIENCE WE SUGGEST MAXIMIZING THIS WINDOW NOW.
|
||||
pause
|
||||
|
||||
@rem ---------------------------- check Python version ---------------
|
||||
echo ***** Checking and Updating Python *****
|
||||
|
||||
call python --version >.tmp1 2>.tmp2
|
||||
if %errorlevel% == 1 (
|
||||
set err_msg=Please install Python 3.10-11. See %INSTRUCTIONS% for details.
|
||||
goto err_exit
|
||||
)
|
||||
|
||||
for /f "tokens=2" %%i in (.tmp1) do set python_version=%%i
|
||||
if "%python_version%" == "" (
|
||||
set err_msg=No python was detected on your system. Please install Python version %MINIMUM_PYTHON_VERSION% or higher. We recommend Python 3.10.12 from %PYTHON_URL%
|
||||
goto err_exit
|
||||
)
|
||||
|
||||
call :compareVersions %MINIMUM_PYTHON_VERSION% %python_version%
|
||||
if %errorlevel% == 1 (
|
||||
set err_msg=Your version of Python is too low. You need at least %MINIMUM_PYTHON_VERSION% but you have %python_version%. We recommend Python 3.10.12 from %PYTHON_URL%
|
||||
goto err_exit
|
||||
)
|
||||
|
||||
@rem Cleanup
|
||||
del /q .tmp1 .tmp2
|
||||
|
||||
@rem -------------- Install and Configure ---------------
|
||||
|
||||
call python .\lib\main.py
|
||||
pause
|
||||
exit /b
|
||||
|
||||
@rem ------------------------ Subroutines ---------------
|
||||
@rem routine to do comparison of semantic version numbers
|
||||
@rem found at https://stackoverflow.com/questions/15807762/compare-version-numbers-in-batch-file
|
||||
:compareVersions
|
||||
::
|
||||
:: Compares two version numbers and returns the result in the ERRORLEVEL
|
||||
::
|
||||
:: Returns 1 if version1 > version2
|
||||
:: 0 if version1 = version2
|
||||
:: -1 if version1 < version2
|
||||
::
|
||||
:: The nodes must be delimited by . or , or -
|
||||
::
|
||||
:: Nodes are normally strictly numeric, without a 0 prefix. A letter suffix
|
||||
:: is treated as a separate node
|
||||
::
|
||||
setlocal enableDelayedExpansion
|
||||
set "v1=%~1"
|
||||
set "v2=%~2"
|
||||
call :divideLetters v1
|
||||
call :divideLetters v2
|
||||
:loop
|
||||
call :parseNode "%v1%" n1 v1
|
||||
call :parseNode "%v2%" n2 v2
|
||||
if %n1% gtr %n2% exit /b 1
|
||||
if %n1% lss %n2% exit /b -1
|
||||
if not defined v1 if not defined v2 exit /b 0
|
||||
if not defined v1 exit /b -1
|
||||
if not defined v2 exit /b 1
|
||||
goto :loop
|
||||
|
||||
|
||||
:parseNode version nodeVar remainderVar
|
||||
for /f "tokens=1* delims=.,-" %%A in ("%~1") do (
|
||||
set "%~2=%%A"
|
||||
set "%~3=%%B"
|
||||
)
|
||||
exit /b
|
||||
|
||||
|
||||
:divideLetters versionVar
|
||||
for %%C in (a b c d e f g h i j k l m n o p q r s t u v w x y z) do set "%~1=!%~1:%%C=.%%C!"
|
||||
exit /b
|
||||
|
||||
:err_exit
|
||||
echo %err_msg%
|
||||
echo The installer will exit now.
|
||||
pause
|
||||
exit /b
|
||||
|
||||
pause
|
||||
|
||||
:Trim
|
||||
SetLocal EnableDelayedExpansion
|
||||
set Params=%*
|
||||
for /f "tokens=1*" %%a in ("!Params!") do EndLocal & set %1=%%b
|
||||
exit /b
|
||||
40
installer/install.sh.in
Executable file
40
installer/install.sh.in
Executable file
@@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
|
||||
# make sure we are not already in a venv
|
||||
# (don't need to check status)
|
||||
deactivate >/dev/null 2>&1
|
||||
scriptdir=$(dirname "$0")
|
||||
cd $scriptdir
|
||||
|
||||
function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; }
|
||||
|
||||
MINIMUM_PYTHON_VERSION=3.10.0
|
||||
MAXIMUM_PYTHON_VERSION=3.11.100
|
||||
PYTHON=""
|
||||
for candidate in python3.11 python3.10 python3 python ; do
|
||||
if ppath=`which $candidate 2>/dev/null`; then
|
||||
# when using `pyenv`, the executable for an inactive Python version will exist but will not be operational
|
||||
# we check that this found executable can actually run
|
||||
if [ $($candidate --version &>/dev/null; echo ${PIPESTATUS}) -gt 0 ]; then continue; fi
|
||||
|
||||
python_version=$($ppath -V | awk '{ print $2 }')
|
||||
if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then
|
||||
if [ $(version $python_version) -le $(version "$MAXIMUM_PYTHON_VERSION") ]; then
|
||||
PYTHON=$ppath
|
||||
break
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -z "$PYTHON" ]; then
|
||||
echo "A suitable Python interpreter could not be found"
|
||||
echo "Please install Python $MINIMUM_PYTHON_VERSION or higher (maximum $MAXIMUM_PYTHON_VERSION) before running this script. See instructions at $INSTRUCTIONS for help."
|
||||
read -p "Press any key to exit"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
echo "For the best user experience we suggest enlarging or maximizing this window now."
|
||||
|
||||
exec $PYTHON ./lib/main.py ${@}
|
||||
read -p "Press any key to exit"
|
||||
0
installer/lib/__init__.py
Normal file
0
installer/lib/__init__.py
Normal file
438
installer/lib/installer.py
Normal file
438
installer/lib/installer.py
Normal file
@@ -0,0 +1,438 @@
|
||||
# Copyright (c) 2023 Eugene Brodsky (https://github.com/ebr)
|
||||
"""
|
||||
InvokeAI installer script
|
||||
"""
|
||||
|
||||
import locale
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import venv
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional, Tuple
|
||||
|
||||
SUPPORTED_PYTHON = ">=3.10.0,<=3.11.100"
|
||||
INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"]
|
||||
BOOTSTRAP_VENV_PREFIX = "invokeai-installer-tmp"
|
||||
DOCS_URL = "https://invoke-ai.github.io/InvokeAI/"
|
||||
DISCORD_URL = "https://discord.gg/ZmtBAhwWhy"
|
||||
|
||||
OS = platform.uname().system
|
||||
ARCH = platform.uname().machine
|
||||
VERSION = "latest"
|
||||
|
||||
|
||||
def get_version_from_wheel_filename(wheel_filename: str) -> str:
|
||||
match = re.search(r"-(\d+\.\d+\.\d+)", wheel_filename)
|
||||
if match:
|
||||
version = match.group(1)
|
||||
return version
|
||||
else:
|
||||
raise ValueError(f"Could not extract version from wheel filename: {wheel_filename}")
|
||||
|
||||
|
||||
class Installer:
|
||||
"""
|
||||
Deploys an InvokeAI installation into a given path
|
||||
"""
|
||||
|
||||
reqs: list[str] = INSTALLER_REQS
|
||||
|
||||
def __init__(self) -> None:
|
||||
if os.getenv("VIRTUAL_ENV") is not None:
|
||||
print("A virtual environment is already activated. Please 'deactivate' before installation.")
|
||||
sys.exit(-1)
|
||||
self.bootstrap()
|
||||
self.available_releases = get_github_releases()
|
||||
|
||||
def mktemp_venv(self) -> TemporaryDirectory[str]:
|
||||
"""
|
||||
Creates a temporary virtual environment for the installer itself
|
||||
|
||||
:return: path to the created virtual environment directory
|
||||
:rtype: TemporaryDirectory
|
||||
"""
|
||||
|
||||
# Cleaning up temporary directories on Windows results in a race condition
|
||||
# and a stack trace.
|
||||
# `ignore_cleanup_errors` was only added in Python 3.10
|
||||
if OS == "Windows" and int(platform.python_version_tuple()[1]) >= 10:
|
||||
venv_dir = TemporaryDirectory(prefix=BOOTSTRAP_VENV_PREFIX, ignore_cleanup_errors=True)
|
||||
else:
|
||||
venv_dir = TemporaryDirectory(prefix=BOOTSTRAP_VENV_PREFIX)
|
||||
|
||||
venv.create(venv_dir.name, with_pip=True)
|
||||
self.venv_dir = venv_dir
|
||||
set_sys_path(Path(venv_dir.name))
|
||||
|
||||
return venv_dir
|
||||
|
||||
def bootstrap(self, verbose: bool = False) -> TemporaryDirectory[str] | None:
|
||||
"""
|
||||
Bootstrap the installer venv with packages required at install time
|
||||
"""
|
||||
|
||||
print("Initializing the installer. This may take a minute - please wait...")
|
||||
|
||||
venv_dir = self.mktemp_venv()
|
||||
pip = get_pip_from_venv(Path(venv_dir.name))
|
||||
|
||||
cmd = [pip, "install", "--require-virtualenv", "--use-pep517"]
|
||||
cmd.extend(self.reqs)
|
||||
|
||||
try:
|
||||
# upgrade pip to the latest version to avoid a confusing message
|
||||
res = upgrade_pip(Path(venv_dir.name))
|
||||
if verbose:
|
||||
print(res)
|
||||
|
||||
# run the install prerequisites installation
|
||||
res = subprocess.check_output(cmd).decode()
|
||||
|
||||
if verbose:
|
||||
print(res)
|
||||
|
||||
return venv_dir
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e)
|
||||
|
||||
def app_venv(self, venv_parent: Path) -> Path:
|
||||
"""
|
||||
Create a virtualenv for the InvokeAI installation
|
||||
"""
|
||||
|
||||
venv_dir = venv_parent / ".venv"
|
||||
|
||||
# Prefer to copy python executables
|
||||
# so that updates to system python don't break InvokeAI
|
||||
try:
|
||||
venv.create(venv_dir, with_pip=True)
|
||||
# If installing over an existing environment previously created with symlinks,
|
||||
# the executables will fail to copy. Keep symlinks in that case
|
||||
except shutil.SameFileError:
|
||||
venv.create(venv_dir, with_pip=True, symlinks=True)
|
||||
|
||||
return venv_dir
|
||||
|
||||
def install(
|
||||
self,
|
||||
root: str = "~/invokeai",
|
||||
yes_to_all: bool = False,
|
||||
find_links: Optional[str] = None,
|
||||
wheel: Optional[Path] = None,
|
||||
) -> None:
|
||||
"""Install the InvokeAI application into the given runtime path
|
||||
|
||||
Args:
|
||||
root: Destination path for the installation
|
||||
yes_to_all: Accept defaults to all questions
|
||||
find_links: A local directory to search for requirement wheels before going to remote indexes
|
||||
wheel: A wheel file to install
|
||||
"""
|
||||
|
||||
import messages
|
||||
|
||||
if wheel:
|
||||
messages.installing_from_wheel(wheel.name)
|
||||
version = get_version_from_wheel_filename(wheel.name)
|
||||
else:
|
||||
messages.welcome(self.available_releases)
|
||||
version = messages.choose_version(self.available_releases)
|
||||
|
||||
auto_dest = Path(os.environ.get("INVOKEAI_ROOT", root)).expanduser().resolve()
|
||||
destination = auto_dest if yes_to_all else messages.dest_path(root)
|
||||
if destination is None:
|
||||
print("Could not find or create the destination directory. Installation cancelled.")
|
||||
sys.exit(0)
|
||||
|
||||
# create the venv for the app
|
||||
self.venv = self.app_venv(venv_parent=destination)
|
||||
|
||||
self.instance = InvokeAiInstance(runtime=destination, venv=self.venv, version=version)
|
||||
|
||||
# install dependencies and the InvokeAI application
|
||||
(extra_index_url, optional_modules) = get_torch_source() if not yes_to_all else (None, None)
|
||||
self.instance.install(extra_index_url, optional_modules, find_links, wheel)
|
||||
|
||||
# install the launch/update scripts into the runtime directory
|
||||
self.instance.install_user_scripts()
|
||||
|
||||
message = f"""
|
||||
*** Installation Successful ***
|
||||
|
||||
To start the application, run:
|
||||
{destination}/invoke.{"bat" if sys.platform == "win32" else "sh"}
|
||||
|
||||
For more information, troubleshooting and support, visit our docs at:
|
||||
{DOCS_URL}
|
||||
|
||||
Join the community on Discord:
|
||||
{DISCORD_URL}
|
||||
"""
|
||||
print(message)
|
||||
|
||||
|
||||
class InvokeAiInstance:
|
||||
"""
|
||||
Manages an installed instance of InvokeAI, comprising a virtual environment and a runtime directory.
|
||||
The virtual environment *may* reside within the runtime directory.
|
||||
A single runtime directory *may* be shared by multiple virtual environments, though this isn't currently tested or supported.
|
||||
"""
|
||||
|
||||
def __init__(self, runtime: Path, venv: Path, version: str = "stable") -> None:
|
||||
self.runtime = runtime
|
||||
self.venv = venv
|
||||
self.pip = get_pip_from_venv(venv)
|
||||
self.version = version
|
||||
|
||||
set_sys_path(venv)
|
||||
os.environ["INVOKEAI_ROOT"] = str(self.runtime.expanduser().resolve())
|
||||
os.environ["VIRTUAL_ENV"] = str(self.venv.expanduser().resolve())
|
||||
upgrade_pip(venv)
|
||||
|
||||
def get(self) -> tuple[Path, Path]:
|
||||
"""
|
||||
Get the location of the virtualenv directory for this installation
|
||||
|
||||
:return: Paths of the runtime and the venv directory
|
||||
:rtype: tuple[Path, Path]
|
||||
"""
|
||||
|
||||
return (self.runtime, self.venv)
|
||||
|
||||
def install(
|
||||
self,
|
||||
extra_index_url: Optional[str] = None,
|
||||
optional_modules: Optional[str] = None,
|
||||
find_links: Optional[str] = None,
|
||||
wheel: Optional[Path] = None,
|
||||
):
|
||||
"""Install the package from PyPi or a wheel, if provided.
|
||||
|
||||
Args:
|
||||
extra_index_url: the "--extra-index-url ..." line for pip to look in extra indexes.
|
||||
optional_modules: optional modules to install using "[module1,module2]" format.
|
||||
find_links: path to a directory containing wheels to be searched prior to going to the internet
|
||||
wheel: a wheel file to install
|
||||
"""
|
||||
|
||||
import messages
|
||||
|
||||
# not currently used, but may be useful for "install most recent version" option
|
||||
if self.version == "prerelease":
|
||||
version = None
|
||||
pre_flag = "--pre"
|
||||
elif self.version == "stable":
|
||||
version = None
|
||||
pre_flag = None
|
||||
else:
|
||||
version = self.version
|
||||
pre_flag = None
|
||||
|
||||
src = "invokeai"
|
||||
if optional_modules:
|
||||
src += optional_modules
|
||||
if version:
|
||||
src += f"=={version}"
|
||||
|
||||
messages.simple_banner("Installing the InvokeAI Application :art:")
|
||||
|
||||
from plumbum import FG, ProcessExecutionError, local
|
||||
|
||||
pip = local[self.pip]
|
||||
|
||||
# Uninstall xformers if it is present; the correct version of it will be reinstalled if needed
|
||||
_ = pip["uninstall", "-yqq", "xformers"] & FG
|
||||
|
||||
pipeline = pip[
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
"--force-reinstall",
|
||||
"--use-pep517",
|
||||
str(src) if not wheel else str(wheel),
|
||||
"--find-links" if find_links is not None else None,
|
||||
find_links,
|
||||
"--extra-index-url" if extra_index_url is not None else None,
|
||||
extra_index_url,
|
||||
pre_flag if not wheel else None, # Ignore the flag if we are installing a wheel
|
||||
]
|
||||
|
||||
try:
|
||||
_ = pipeline & FG
|
||||
except ProcessExecutionError as e:
|
||||
print(f"Error: {e}")
|
||||
print(
|
||||
"Could not install InvokeAI. Please try downloading the latest version of the installer and install again."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
def install_user_scripts(self):
|
||||
"""
|
||||
Copy the launch and update scripts to the runtime dir
|
||||
"""
|
||||
|
||||
ext = "bat" if OS == "Windows" else "sh"
|
||||
|
||||
scripts = ["invoke"]
|
||||
|
||||
for script in scripts:
|
||||
src = Path(__file__).parent / ".." / "templates" / f"{script}.{ext}.in"
|
||||
dest = self.runtime / f"{script}.{ext}"
|
||||
shutil.copy(src, dest)
|
||||
os.chmod(dest, 0o0755)
|
||||
|
||||
|
||||
### Utility functions ###
|
||||
|
||||
|
||||
def get_pip_from_venv(venv_path: Path) -> str:
|
||||
"""
|
||||
Given a path to a virtual environment, get the absolute path to the `pip` executable
|
||||
in a cross-platform fashion. Does not validate that the pip executable
|
||||
actually exists in the virtualenv.
|
||||
|
||||
:param venv_path: Path to the virtual environment
|
||||
:type venv_path: Path
|
||||
:return: Absolute path to the pip executable
|
||||
:rtype: str
|
||||
"""
|
||||
|
||||
pip = "Scripts\\pip.exe" if OS == "Windows" else "bin/pip"
|
||||
return str(venv_path.expanduser().resolve() / pip)
|
||||
|
||||
|
||||
def upgrade_pip(venv_path: Path) -> str | None:
|
||||
"""
|
||||
Upgrade the pip executable in the given virtual environment
|
||||
"""
|
||||
|
||||
python = "Scripts\\python.exe" if OS == "Windows" else "bin/python"
|
||||
python = str(venv_path.expanduser().resolve() / python)
|
||||
|
||||
try:
|
||||
result = subprocess.check_output([python, "-m", "pip", "install", "--upgrade", "pip"]).decode(
|
||||
encoding=locale.getpreferredencoding()
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e)
|
||||
result = None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def set_sys_path(venv_path: Path) -> None:
|
||||
"""
|
||||
Given a path to a virtual environment, set the sys.path, in a cross-platform fashion,
|
||||
such that packages from the given venv may be imported in the current process.
|
||||
Ensure that the packages from system environment are not visible (emulate
|
||||
the virtual env 'activate' script) - this doesn't work on Windows yet.
|
||||
|
||||
:param venv_path: Path to the virtual environment
|
||||
:type venv_path: Path
|
||||
"""
|
||||
|
||||
# filter out any paths in sys.path that may be system- or user-wide
|
||||
# but leave the temporary bootstrap virtualenv as it contains packages we
|
||||
# temporarily need at install time
|
||||
sys.path = list(filter(lambda p: not p.endswith("-packages") or p.find(BOOTSTRAP_VENV_PREFIX) != -1, sys.path))
|
||||
|
||||
# determine site-packages/lib directory location for the venv
|
||||
lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}"
|
||||
|
||||
# add the site-packages location to the venv
|
||||
sys.path.append(str(Path(venv_path, lib, "site-packages").expanduser().resolve()))
|
||||
|
||||
|
||||
def get_github_releases() -> tuple[list[str], list[str]] | None:
|
||||
"""
|
||||
Query Github for published (pre-)release versions.
|
||||
Return a tuple where the first element is a list of stable releases and the second element is a list of pre-releases.
|
||||
Return None if the query fails for any reason.
|
||||
"""
|
||||
|
||||
import requests
|
||||
|
||||
## get latest releases using github api
|
||||
url = "https://api.github.com/repos/invoke-ai/InvokeAI/releases"
|
||||
releases: list[str] = []
|
||||
pre_releases: list[str] = []
|
||||
try:
|
||||
res = requests.get(url)
|
||||
res.raise_for_status()
|
||||
tag_info = res.json()
|
||||
for tag in tag_info:
|
||||
if not tag["prerelease"]:
|
||||
releases.append(tag["tag_name"].lstrip("v"))
|
||||
else:
|
||||
pre_releases.append(tag["tag_name"].lstrip("v"))
|
||||
except requests.HTTPError as e:
|
||||
print(f"Error: {e}")
|
||||
print("Could not fetch version information from GitHub. Please check your network connection and try again.")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print("An unexpected error occurred while trying to fetch version information from GitHub. Please try again.")
|
||||
return
|
||||
|
||||
releases.sort(reverse=True)
|
||||
pre_releases.sort(reverse=True)
|
||||
|
||||
return releases, pre_releases
|
||||
|
||||
|
||||
def get_torch_source() -> Tuple[str | None, str | None]:
|
||||
"""
|
||||
Determine the extra index URL for pip to use for torch installation.
|
||||
This depends on the OS and the graphics accelerator in use.
|
||||
This is only applicable to Windows and Linux, since PyTorch does not
|
||||
offer accelerated builds for macOS.
|
||||
|
||||
Prefer CUDA-enabled wheels if the user wasn't sure of their GPU, as it will fallback to CPU if possible.
|
||||
|
||||
A NoneType return means just go to PyPi.
|
||||
|
||||
:return: tuple consisting of (extra index url or None, optional modules to load or None)
|
||||
:rtype: list
|
||||
"""
|
||||
|
||||
from messages import GpuType, select_gpu
|
||||
|
||||
# device can be one of: "cuda", "rocm", "cpu", "cuda_and_dml, autodetect"
|
||||
device = select_gpu()
|
||||
|
||||
# The correct extra index URLs for torch are inconsistent, see https://pytorch.org/get-started/locally/#start-locally
|
||||
|
||||
url = None
|
||||
optional_modules: str | None = None
|
||||
if OS == "Linux":
|
||||
if device == GpuType.ROCM:
|
||||
url = "https://download.pytorch.org/whl/rocm6.1"
|
||||
elif device == GpuType.CPU:
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
elif device == GpuType.CUDA:
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
optional_modules = "[onnx-cuda]"
|
||||
elif device == GpuType.CUDA_WITH_XFORMERS:
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
optional_modules = "[xformers,onnx-cuda]"
|
||||
elif OS == "Windows":
|
||||
if device == GpuType.CUDA:
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
optional_modules = "[onnx-cuda]"
|
||||
elif device == GpuType.CUDA_WITH_XFORMERS:
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
optional_modules = "[xformers,onnx-cuda]"
|
||||
elif device.value == "cpu":
|
||||
# CPU uses the default PyPi index, no optional modules
|
||||
pass
|
||||
elif OS == "Darwin":
|
||||
# macOS uses the default PyPi index, no optional modules
|
||||
pass
|
||||
|
||||
# Fall back to defaults
|
||||
|
||||
return (url, optional_modules)
|
||||
57
installer/lib/main.py
Normal file
57
installer/lib/main.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
InvokeAI Installer
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from installer import Installer
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--root",
|
||||
dest="root",
|
||||
type=str,
|
||||
help="Destination path for installation",
|
||||
default=os.environ.get("INVOKEAI_ROOT") or "~/invokeai",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-y",
|
||||
"--yes",
|
||||
"--yes-to-all",
|
||||
dest="yes_to_all",
|
||||
action="store_true",
|
||||
help="Assume default answers to all questions",
|
||||
default=False,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--find-links",
|
||||
dest="find_links",
|
||||
help="Specifies a directory of local wheel files to be searched prior to searching the online repositories.",
|
||||
type=Path,
|
||||
default=None,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--wheel",
|
||||
dest="wheel",
|
||||
help="Specifies a wheel for the InvokeAI package. Used for troubleshooting or testing prereleases.",
|
||||
type=Path,
|
||||
default=None,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
inst = Installer()
|
||||
|
||||
try:
|
||||
inst.install(**args.__dict__)
|
||||
except KeyboardInterrupt:
|
||||
print("\n")
|
||||
print("Ctrl-C pressed. Aborting.")
|
||||
print("Come back soon!")
|
||||
342
installer/lib/messages.py
Normal file
342
installer/lib/messages.py
Normal file
@@ -0,0 +1,342 @@
|
||||
# Copyright (c) 2023 Eugene Brodsky (https://github.com/ebr)
|
||||
"""
|
||||
Installer user interaction
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit.completion import FuzzyWordCompleter, PathCompleter
|
||||
from prompt_toolkit.validation import Validator
|
||||
from rich import box, print
|
||||
from rich.console import Console, Group, group
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Confirm
|
||||
from rich.style import Style
|
||||
from rich.syntax import Syntax
|
||||
from rich.text import Text
|
||||
|
||||
OS = platform.uname().system
|
||||
ARCH = platform.uname().machine
|
||||
|
||||
if OS == "Windows":
|
||||
# Windows terminals look better without a background colour
|
||||
console = Console(style=Style(color="grey74"))
|
||||
else:
|
||||
console = Console(style=Style(color="grey74", bgcolor="grey19"))
|
||||
|
||||
|
||||
def welcome(available_releases: tuple[list[str], list[str]] | None = None) -> None:
|
||||
@group()
|
||||
def text():
|
||||
if (platform_specific := _platform_specific_help()) is not None:
|
||||
yield platform_specific
|
||||
yield ""
|
||||
yield Text.from_markup(
|
||||
"Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.",
|
||||
justify="center",
|
||||
)
|
||||
if available_releases is not None:
|
||||
latest_stable = available_releases[0][0]
|
||||
last_pre = available_releases[1][0]
|
||||
yield ""
|
||||
yield Text.from_markup(
|
||||
f"[red3]🠶[/] Latest stable release (recommended): [b bright_white]{latest_stable}", justify="center"
|
||||
)
|
||||
yield Text.from_markup(
|
||||
f"[red3]🠶[/] Last published pre-release version: [b bright_white]{last_pre}", justify="center"
|
||||
)
|
||||
|
||||
console.rule()
|
||||
print(
|
||||
Panel(
|
||||
title="[bold wheat1]Welcome to the InvokeAI Installer",
|
||||
renderable=text(),
|
||||
box=box.DOUBLE,
|
||||
expand=True,
|
||||
padding=(1, 2),
|
||||
style=Style(bgcolor="grey23", color="orange1"),
|
||||
subtitle=f"[bold grey39]{OS}-{ARCH}",
|
||||
)
|
||||
)
|
||||
console.line()
|
||||
|
||||
|
||||
def installing_from_wheel(wheel_filename: str) -> None:
|
||||
"""Display a message about installing from a wheel"""
|
||||
|
||||
@group()
|
||||
def text():
|
||||
yield Text.from_markup(f"You are installing from a wheel file: [bold]{wheel_filename}\n")
|
||||
yield Text.from_markup(
|
||||
"[bold orange3]If you are not sure why you are doing this, you should cancel and install InvokeAI normally."
|
||||
)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
title="Installing from Wheel",
|
||||
renderable=text(),
|
||||
box=box.DOUBLE,
|
||||
expand=True,
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
should_proceed = Confirm.ask("Do you want to proceed?")
|
||||
|
||||
if not should_proceed:
|
||||
console.print("Installation cancelled.")
|
||||
exit()
|
||||
|
||||
|
||||
def choose_version(available_releases: tuple[list[str], list[str]] | None = None) -> str:
|
||||
"""
|
||||
Prompt the user to choose an Invoke version to install
|
||||
"""
|
||||
|
||||
# short circuit if we couldn't get a version list
|
||||
# still try to install the latest stable version
|
||||
if available_releases is None:
|
||||
return "stable"
|
||||
|
||||
console.print(":grey_question: [orange3]Please choose an Invoke version to install.")
|
||||
|
||||
choices = available_releases[0] + available_releases[1]
|
||||
|
||||
response = prompt(
|
||||
message=f" <Enter> to install the recommended release ({choices[0]}). <Tab> or type to pick a version: ",
|
||||
complete_while_typing=True,
|
||||
completer=FuzzyWordCompleter(choices),
|
||||
)
|
||||
console.print(f" Version {choices[0] if response == '' else response} will be installed.")
|
||||
|
||||
console.line()
|
||||
|
||||
return "stable" if response == "" else response
|
||||
|
||||
|
||||
def confirm_install(dest: Path) -> bool:
|
||||
if dest.exists():
|
||||
print(f":stop_sign: Directory {dest} already exists!")
|
||||
print(" Is this location correct?")
|
||||
default = False
|
||||
else:
|
||||
print(f":file_folder: InvokeAI will be installed in {dest}")
|
||||
default = True
|
||||
|
||||
dest_confirmed = Confirm.ask(" Please confirm:", default=default)
|
||||
|
||||
console.line()
|
||||
|
||||
return dest_confirmed
|
||||
|
||||
|
||||
def dest_path(dest: Optional[str | Path] = None) -> Path | None:
|
||||
"""
|
||||
Prompt the user for the destination path and create the path
|
||||
|
||||
:param dest: a filesystem path, defaults to None
|
||||
:type dest: str, optional
|
||||
:return: absolute path to the created installation directory
|
||||
:rtype: Path
|
||||
"""
|
||||
|
||||
if dest is not None:
|
||||
dest = Path(dest).expanduser().resolve()
|
||||
else:
|
||||
dest = Path.cwd().expanduser().resolve()
|
||||
prev_dest = init_path = dest
|
||||
dest_confirmed = False
|
||||
|
||||
while not dest_confirmed:
|
||||
browse_start = (dest or Path.cwd()).expanduser().resolve()
|
||||
|
||||
path_completer = PathCompleter(
|
||||
only_directories=True,
|
||||
expanduser=True,
|
||||
get_paths=lambda: [str(browse_start)], # noqa: B023
|
||||
# get_paths=lambda: [".."].extend(list(browse_start.iterdir()))
|
||||
)
|
||||
|
||||
console.line()
|
||||
|
||||
console.print(f":grey_question: [orange3]Please select the install destination:[/] \\[{browse_start}]: ")
|
||||
selected = prompt(
|
||||
">>> ",
|
||||
complete_in_thread=True,
|
||||
completer=path_completer,
|
||||
default=str(browse_start) + os.sep,
|
||||
vi_mode=True,
|
||||
complete_while_typing=True,
|
||||
# Test that this is not needed on Windows
|
||||
# complete_style=CompleteStyle.READLINE_LIKE,
|
||||
)
|
||||
prev_dest = dest
|
||||
dest = Path(selected)
|
||||
|
||||
console.line()
|
||||
|
||||
dest_confirmed = confirm_install(dest.expanduser().resolve())
|
||||
|
||||
if not dest_confirmed:
|
||||
dest = prev_dest
|
||||
|
||||
dest = dest.expanduser().resolve()
|
||||
|
||||
try:
|
||||
dest.mkdir(exist_ok=True, parents=True)
|
||||
return dest
|
||||
except PermissionError:
|
||||
console.print(
|
||||
f"Failed to create directory {dest} due to insufficient permissions",
|
||||
style=Style(color="red"),
|
||||
highlight=True,
|
||||
)
|
||||
except OSError:
|
||||
console.print_exception()
|
||||
|
||||
if Confirm.ask("Would you like to try again?"):
|
||||
dest_path(init_path)
|
||||
else:
|
||||
console.rule("Goodbye!")
|
||||
|
||||
|
||||
class GpuType(Enum):
|
||||
CUDA_WITH_XFORMERS = "xformers"
|
||||
CUDA = "cuda"
|
||||
ROCM = "rocm"
|
||||
CPU = "cpu"
|
||||
|
||||
|
||||
def select_gpu() -> GpuType:
|
||||
"""
|
||||
Prompt the user to select the GPU driver
|
||||
"""
|
||||
|
||||
if ARCH == "arm64" and OS != "Darwin":
|
||||
print(f"Only CPU acceleration is available on {ARCH} architecture. Proceeding with that.")
|
||||
return GpuType.CPU
|
||||
|
||||
nvidia = (
|
||||
"an [gold1 b]NVIDIA[/] RTX 3060 or newer GPU using CUDA",
|
||||
GpuType.CUDA,
|
||||
)
|
||||
vintage_nvidia = (
|
||||
"an [gold1 b]NVIDIA[/] RTX 20xx or older GPU using CUDA+xFormers",
|
||||
GpuType.CUDA_WITH_XFORMERS,
|
||||
)
|
||||
amd = (
|
||||
"an [gold1 b]AMD[/] GPU using ROCm",
|
||||
GpuType.ROCM,
|
||||
)
|
||||
cpu = (
|
||||
"Do not install any GPU support, use CPU for generation (slow)",
|
||||
GpuType.CPU,
|
||||
)
|
||||
|
||||
options = []
|
||||
if OS == "Windows":
|
||||
options = [nvidia, vintage_nvidia, cpu]
|
||||
if OS == "Linux":
|
||||
options = [nvidia, vintage_nvidia, amd, cpu]
|
||||
elif OS == "Darwin":
|
||||
options = [cpu]
|
||||
|
||||
if len(options) == 1:
|
||||
return options[0][1]
|
||||
|
||||
options = {str(i): opt for i, opt in enumerate(options, 1)}
|
||||
|
||||
console.rule(":space_invader: GPU (Graphics Card) selection :space_invader:")
|
||||
console.print(
|
||||
Panel(
|
||||
Group(
|
||||
"\n".join(
|
||||
[
|
||||
f"Detected the [gold1]{OS}-{ARCH}[/] platform",
|
||||
"",
|
||||
"See [deep_sky_blue1]https://invoke-ai.github.io/InvokeAI/installation/requirements/[/] to ensure your system meets the minimum requirements.",
|
||||
"",
|
||||
"[red3]🠶[/] [b]Your GPU drivers must be correctly installed before using InvokeAI![/] [red3]🠴[/]",
|
||||
]
|
||||
),
|
||||
"",
|
||||
"Please select the type of GPU installed in your computer.",
|
||||
Panel(
|
||||
"\n".join([f"[dark_goldenrod b i]{i}[/] [dark_red]🢒[/]{opt[0]}" for (i, opt) in options.items()]),
|
||||
box=box.MINIMAL,
|
||||
),
|
||||
),
|
||||
box=box.MINIMAL,
|
||||
padding=(1, 1),
|
||||
)
|
||||
)
|
||||
choice = prompt(
|
||||
"Please make your selection: ",
|
||||
validator=Validator.from_callable(
|
||||
lambda n: n in options.keys(), error_message="Please select one the above options"
|
||||
),
|
||||
)
|
||||
|
||||
return options[choice][1]
|
||||
|
||||
|
||||
def simple_banner(message: str) -> None:
|
||||
"""
|
||||
A simple banner with a message, defined here for styling consistency
|
||||
|
||||
:param message: The message to display
|
||||
:type message: str
|
||||
"""
|
||||
|
||||
console.rule(message)
|
||||
|
||||
|
||||
# TODO this does not yet work correctly
|
||||
def windows_long_paths_registry() -> None:
|
||||
"""
|
||||
Display a message about applying the Windows long paths registry fix
|
||||
"""
|
||||
|
||||
with open(str(Path(__file__).parent / "WinLongPathsEnabled.reg"), "r", encoding="utf-16le") as code:
|
||||
syntax = Syntax(code.read(), line_numbers=True, lexer="regedit")
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Group(
|
||||
"\n".join(
|
||||
[
|
||||
"We will now apply a registry fix to enable long paths on Windows. InvokeAI needs this to function correctly. We are asking your permission to modify the Windows Registry on your behalf.",
|
||||
"",
|
||||
"This is the change that will be applied:",
|
||||
str(syntax),
|
||||
]
|
||||
)
|
||||
),
|
||||
title="Windows Long Paths registry fix",
|
||||
box=box.HORIZONTALS,
|
||||
padding=(1, 1),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _platform_specific_help() -> Text | None:
|
||||
if OS == "Darwin":
|
||||
text = Text.from_markup(
|
||||
"""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/]."""
|
||||
)
|
||||
elif OS == "Windows":
|
||||
text = Text.from_markup(
|
||||
"""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
|
||||
1. Double-click on the file [b wheat1]WinLongPathsEnabled.reg[/] in order to
|
||||
enable long path support on your system.
|
||||
2. Make sure you have the [b wheat1]Visual C++ core libraries[/] installed. If not, install from
|
||||
[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]"""
|
||||
)
|
||||
else:
|
||||
return
|
||||
return text
|
||||
52
installer/readme.txt
Normal file
52
installer/readme.txt
Normal file
@@ -0,0 +1,52 @@
|
||||
InvokeAI
|
||||
|
||||
Project homepage: https://github.com/invoke-ai/InvokeAI
|
||||
|
||||
Preparations:
|
||||
|
||||
You will need to install Python 3.10 or higher for this installer
|
||||
to work. Instructions are given here:
|
||||
https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/
|
||||
|
||||
Before you start the installer, please open up your system's command
|
||||
line window (Terminal or Command) and type the commands:
|
||||
|
||||
python --version
|
||||
|
||||
If all is well, it will print "Python 3.X.X", where the version number
|
||||
is at least 3.10.*, and not higher than 3.11.*.
|
||||
|
||||
If this works, check the version of the Python package manager, pip:
|
||||
|
||||
pip --version
|
||||
|
||||
You should get a message that indicates that the pip package
|
||||
installer was derived from Python 3.10 or 3.11. For example:
|
||||
"pip 22.0.1 from /usr/bin/pip (python 3.10)"
|
||||
|
||||
Long Paths on Windows:
|
||||
|
||||
If you are on Windows, you will need to enable Windows Long Paths to
|
||||
run InvokeAI successfully. If you're not sure what this is, you
|
||||
almost certainly need to do this.
|
||||
|
||||
Simply double-click the "WinLongPathsEnabled.reg" file located in
|
||||
this directory, and approve the Windows warnings. Note that you will
|
||||
need to have admin privileges in order to do this.
|
||||
|
||||
Launching the installer:
|
||||
|
||||
Windows: double-click the 'install.bat' file (while keeping it inside
|
||||
the InvokeAI-Installer folder).
|
||||
|
||||
Linux and Mac: Please open the terminal application and run
|
||||
'./install.sh' (while keeping it inside the InvokeAI-Installer
|
||||
folder).
|
||||
|
||||
The installer will create a directory of your choice and install the
|
||||
InvokeAI application within it. This directory contains everything you need to run
|
||||
invokeai. Once InvokeAI is up and running, you may delete the
|
||||
InvokeAI-Installer folder at your convenience.
|
||||
|
||||
For more information, please see
|
||||
https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/
|
||||
54
installer/templates/invoke.bat.in
Normal file
54
installer/templates/invoke.bat.in
Normal file
@@ -0,0 +1,54 @@
|
||||
@echo off
|
||||
|
||||
PUSHD "%~dp0"
|
||||
setlocal
|
||||
|
||||
call .venv\Scripts\activate.bat
|
||||
set INVOKEAI_ROOT=.
|
||||
|
||||
:start
|
||||
echo Desired action:
|
||||
echo 1. Generate images with the browser-based interface
|
||||
echo 2. Open the developer console
|
||||
echo 3. Command-line help
|
||||
echo Q - Quit
|
||||
echo.
|
||||
echo To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest
|
||||
echo.
|
||||
set /P choice="Please enter 1-4, Q: [1] "
|
||||
if not defined choice set choice=1
|
||||
IF /I "%choice%" == "1" (
|
||||
echo Starting the InvokeAI browser-based UI..
|
||||
python .venv\Scripts\invokeai-web.exe %*
|
||||
) ELSE IF /I "%choice%" == "2" (
|
||||
echo Developer Console
|
||||
echo Python command is:
|
||||
where python
|
||||
echo Python version is:
|
||||
python --version
|
||||
echo *************************
|
||||
echo You are now in the system shell, with the local InvokeAI Python virtual environment activated,
|
||||
echo so that you can troubleshoot this InvokeAI installation as necessary.
|
||||
echo *************************
|
||||
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
||||
call cmd /k
|
||||
) ELSE IF /I "%choice%" == "3" (
|
||||
echo Displaying command line help...
|
||||
python .venv\Scripts\invokeai-web.exe --help %*
|
||||
pause
|
||||
exit /b
|
||||
) ELSE IF /I "%choice%" == "q" (
|
||||
echo Goodbye!
|
||||
goto ending
|
||||
) ELSE (
|
||||
echo Invalid selection
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
goto start
|
||||
|
||||
endlocal
|
||||
pause
|
||||
|
||||
:ending
|
||||
exit /b
|
||||
87
installer/templates/invoke.sh.in
Normal file
87
installer/templates/invoke.sh.in
Normal file
@@ -0,0 +1,87 @@
|
||||
#!/bin/bash
|
||||
|
||||
# MIT License
|
||||
|
||||
# Coauthored by Lincoln Stein, Eugene Brodsky and Joshua Kimsey
|
||||
# Copyright 2023, The InvokeAI Development Team
|
||||
|
||||
####
|
||||
# This launch script assumes that:
|
||||
# 1. it is located in the runtime directory,
|
||||
# 2. the .venv is also located in the runtime directory and is named exactly that
|
||||
#
|
||||
# If both of the above are not true, this script will likely not work as intended.
|
||||
# Activate the virtual environment and run `invoke.py` directly.
|
||||
####
|
||||
|
||||
set -eu
|
||||
|
||||
# Ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
scriptdir=$(dirname $(readlink -f "$0"))
|
||||
cd "$scriptdir"
|
||||
|
||||
. .venv/bin/activate
|
||||
|
||||
export INVOKEAI_ROOT="$scriptdir"
|
||||
|
||||
# Stash the CLI args - when we prompt for user input, `$@` is overwritten
|
||||
PARAMS=$@
|
||||
|
||||
# This setting allows torch to fall back to CPU for operations that are not supported by MPS on macOS.
|
||||
if [ "$(uname -s)" == "Darwin" ]; then
|
||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||
fi
|
||||
|
||||
# Primary function for the case statement to determine user input
|
||||
do_choice() {
|
||||
case $1 in
|
||||
1)
|
||||
clear
|
||||
printf "Generate images with a browser-based interface\n"
|
||||
invokeai-web $PARAMS
|
||||
;;
|
||||
2)
|
||||
clear
|
||||
printf "Open the developer console\n"
|
||||
file_name=$(basename "${BASH_SOURCE[0]}")
|
||||
bash --init-file "$file_name"
|
||||
;;
|
||||
3)
|
||||
clear
|
||||
printf "Command-line help\n"
|
||||
invokeai-web --help
|
||||
;;
|
||||
*)
|
||||
clear
|
||||
printf "Exiting...\n"
|
||||
exit
|
||||
;;
|
||||
esac
|
||||
clear
|
||||
}
|
||||
|
||||
# Command-line interface for launching Invoke functions
|
||||
do_line_input() {
|
||||
clear
|
||||
printf "What would you like to do?\n"
|
||||
printf "1: Generate images using the browser-based interface\n"
|
||||
printf "2: Open the developer console\n"
|
||||
printf "3: Command-line help\n"
|
||||
printf "Q: Quit\n\n"
|
||||
printf "To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest\n\n"
|
||||
read -p "Please enter 1-4, Q: [1] " yn
|
||||
choice=${yn:='1'}
|
||||
do_choice $choice
|
||||
clear
|
||||
}
|
||||
|
||||
# Main IF statement for launching Invoke, and for checking if the user is in the developer console
|
||||
if [ "$0" != "bash" ]; then
|
||||
while true; do
|
||||
do_line_input
|
||||
done
|
||||
else # in developer console
|
||||
python --version
|
||||
printf "Press ^D to exit\n"
|
||||
export PS1="(InvokeAI) \u@\h \w> "
|
||||
fi
|
||||
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
|
||||
from fastapi import Body, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||
@@ -15,7 +15,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
CancelByDestinationResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
FieldIdentifier,
|
||||
PruneResult,
|
||||
RetryItemsResult,
|
||||
SessionQueueCountsByDestination,
|
||||
@@ -35,12 +34,6 @@ class SessionQueueAndProcessorStatus(BaseModel):
|
||||
processor: SessionProcessorStatus
|
||||
|
||||
|
||||
class ValidationRunData(BaseModel):
|
||||
workflow_id: str = Field(description="The id of the workflow being published.")
|
||||
input_fields: list[FieldIdentifier] = Body(description="The input fields for the published workflow")
|
||||
output_fields: list[FieldIdentifier] = Body(description="The output fields for the published workflow")
|
||||
|
||||
|
||||
@session_queue_router.post(
|
||||
"/{queue_id}/enqueue_batch",
|
||||
operation_id="enqueue_batch",
|
||||
@@ -52,10 +45,6 @@ async def enqueue_batch(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch: Batch = Body(description="Batch to process"),
|
||||
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
|
||||
validation_run_data: Optional[ValidationRunData] = Body(
|
||||
default=None,
|
||||
description="The validation run data to use for this batch. This is only used if this is a validation run.",
|
||||
),
|
||||
) -> EnqueueBatchResult:
|
||||
"""Processes a batch and enqueues the output graphs for execution."""
|
||||
|
||||
|
||||
@@ -106,7 +106,6 @@ async def list_workflows(
|
||||
tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"),
|
||||
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
|
||||
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
|
||||
is_published: Optional[bool] = Query(default=None, description="Whether to include/exclude published workflows"),
|
||||
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
|
||||
"""Gets a page of workflows"""
|
||||
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
|
||||
@@ -119,7 +118,6 @@ async def list_workflows(
|
||||
categories=categories,
|
||||
tags=tags,
|
||||
has_been_opened=has_been_opened,
|
||||
is_published=is_published,
|
||||
)
|
||||
for workflow in workflows.items:
|
||||
workflows_with_thumbnails.append(
|
||||
|
||||
@@ -1,128 +0,0 @@
|
||||
# Invocations for ControlNet image preprocessors
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self):
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
|
||||
@invocation_output("control_output")
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
|
||||
# Outputs
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self) -> "ControlNetInvocation":
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
return ControlOutput(
|
||||
control=ControlField(
|
||||
image=self.image,
|
||||
control_model=self.control_model,
|
||||
control_weight=self.control_weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
control_mode=self.control_mode,
|
||||
resize_mode=self.resize_mode,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"heuristic_resize",
|
||||
title="Heuristic Resize",
|
||||
tags=["image, controlnet"],
|
||||
category="image",
|
||||
version="1.0.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class HeuristicResizeInvocation(BaseInvocation):
|
||||
"""Resize an image using a heuristic method. Preserves edge maps."""
|
||||
|
||||
image: ImageField = InputField(description="The image to resize")
|
||||
width: int = InputField(default=512, ge=1, description="The width to resize to (px)")
|
||||
height: int = InputField(default=512, ge=1, description="The height to resize to (px)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
np_img = pil_to_np(image)
|
||||
np_resized = heuristic_resize(np_img, (self.width, self.height))
|
||||
resized = np_to_pil(np_resized)
|
||||
image_dto = context.images.save(image=resized)
|
||||
return ImageOutput.build(image_dto)
|
||||
716
invokeai/app/invocations/controlnet_image_processors.py
Normal file
716
invokeai/app/invocations/controlnet_image_processors.py
Normal file
@@ -0,0 +1,716 @@
|
||||
# Invocations for ControlNet image preprocessors
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||
from builtins import bool, float
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from controlnet_aux import (
|
||||
ContentShuffleDetector,
|
||||
LeresDetector,
|
||||
MediapipeFaceDetector,
|
||||
MidasDetector,
|
||||
MLSDdetector,
|
||||
NormalBaeDetector,
|
||||
PidiNetDetector,
|
||||
SamDetector,
|
||||
ZoeDetector,
|
||||
)
|
||||
from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import DepthEstimationPipeline
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIType,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||
from invokeai.backend.image_util.canny import get_canny_edges
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
||||
from invokeai.backend.image_util.hed import HEDProcessor
|
||||
from invokeai.backend.image_util.lineart import LineartProcessor
|
||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self):
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
|
||||
@invocation_output("control_output")
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
|
||||
# Outputs
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self) -> "ControlNetInvocation":
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
return ControlOutput(
|
||||
control=ControlField(
|
||||
image=self.image,
|
||||
control_model=self.control_model,
|
||||
control_weight=self.control_weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
control_mode=self.control_mode,
|
||||
resize_mode=self.resize_mode,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# This invocation exists for other invocations to subclass it - do not register with @invocation!
|
||||
class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# superclass just passes through image without processing
|
||||
return image
|
||||
|
||||
def load_image(self, context: InvocationContext) -> Image.Image:
|
||||
# allows override for any special formatting specific to the preprocessor
|
||||
return context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
self._context = context
|
||||
raw_image = self.load_image(context)
|
||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||
processed_image = self.run_processor(raw_image)
|
||||
|
||||
# currently can't see processed image in node UI without a showImage node,
|
||||
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
||||
image_dto = context.images.save(image=processed_image)
|
||||
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
processed_image_field = ImageField(image_name=image_dto.image_name)
|
||||
return ImageOutput(
|
||||
image=processed_image_field,
|
||||
# width=processed_image.width,
|
||||
width=image_dto.width,
|
||||
# height=processed_image.height,
|
||||
height=image_dto.height,
|
||||
# mode=processed_image.mode,
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"canny_image_processor",
|
||||
title="Canny Processor",
|
||||
tags=["controlnet", "canny"],
|
||||
category="controlnet",
|
||||
version="1.3.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
low_threshold: int = InputField(
|
||||
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
high_threshold: int = InputField(
|
||||
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
|
||||
def load_image(self, context: InvocationContext) -> Image.Image:
|
||||
# Keep alpha channel for Canny processing to detect edges of transparent areas
|
||||
return context.images.get_pil(self.image.image_name, "RGBA")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
processed_image = get_canny_edges(
|
||||
image,
|
||||
self.low_threshold,
|
||||
self.high_threshold,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"hed_image_processor",
|
||||
title="HED (softedge) Processor",
|
||||
tags=["controlnet", "hed", "softedge"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies HED edge detection to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
hed_processor = HEDProcessor()
|
||||
processed_image = hed_processor.run(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"lineart_image_processor",
|
||||
title="Lineart Processor",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
lineart_processor = LineartProcessor()
|
||||
processed_image = lineart_processor.run(
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"lineart_anime_image_processor",
|
||||
title="Lineart Anime Processor",
|
||||
tags=["controlnet", "lineart", "anime"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art anime processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
processor = LineartAnimeProcessor()
|
||||
processed_image = processor.run(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"midas_depth_image_processor",
|
||||
title="Midas Depth Processor",
|
||||
tags=["controlnet", "midas"],
|
||||
category="controlnet",
|
||||
version="1.2.4",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Midas depth processing to image"""
|
||||
|
||||
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
|
||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = midas_processor(
|
||||
image,
|
||||
a=np.pi * self.a_mult,
|
||||
bg_th=self.bg_th,
|
||||
image_resolution=self.image_resolution,
|
||||
detect_resolution=self.detect_resolution,
|
||||
# dept_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal=self.depth_and_normal,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"normalbae_image_processor",
|
||||
title="Normal BAE Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies NormalBae processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = normalbae_processor(
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"mlsd_image_processor",
|
||||
title="MLSD Processor",
|
||||
tags=["controlnet", "mlsd"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies MLSD processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = mlsd_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
thr_v=self.thr_v,
|
||||
thr_d=self.thr_d,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"pidi_image_processor",
|
||||
title="PIDI Processor",
|
||||
tags=["controlnet", "pidi"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies PIDI processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = pidi_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"content_shuffle_image_processor",
|
||||
title="Content Shuffle Processor",
|
||||
tags=["controlnet", "contentshuffle"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies content shuffle processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
h: int = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
content_shuffle_processor = ContentShuffleDetector()
|
||||
processed_image = content_shuffle_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
h=self.h,
|
||||
w=self.w,
|
||||
f=self.f,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||
@invocation(
|
||||
"zoe_depth_image_processor",
|
||||
title="Zoe (Depth) Processor",
|
||||
tags=["controlnet", "zoe", "depth"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = zoe_depth_processor(image)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"mediapipe_face_processor",
|
||||
title="Mediapipe Face Processor",
|
||||
tags=["controlnet", "mediapipe", "face"],
|
||||
category="controlnet",
|
||||
version="1.2.4",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
|
||||
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
|
||||
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
mediapipe_face_processor = MediapipeFaceDetector()
|
||||
processed_image = mediapipe_face_processor(
|
||||
image,
|
||||
max_faces=self.max_faces,
|
||||
min_confidence=self.min_confidence,
|
||||
image_resolution=self.image_resolution,
|
||||
detect_resolution=self.detect_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"leres_image_processor",
|
||||
title="Leres (Depth) Processor",
|
||||
tags=["controlnet", "leres", "depth"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies leres processing to image"""
|
||||
|
||||
thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
|
||||
thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
|
||||
boost: bool = InputField(default=False, description="Whether to use boost mode")
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = leres_processor(
|
||||
image,
|
||||
thr_a=self.thr_a,
|
||||
thr_b=self.thr_b,
|
||||
boost=self.boost,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"tile_image_processor",
|
||||
title="Tile Resample Processor",
|
||||
tags=["controlnet", "tile"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Tile resampler processor"""
|
||||
|
||||
# res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
|
||||
down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
|
||||
|
||||
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
|
||||
def tile_resample(
|
||||
self,
|
||||
np_img: np.ndarray,
|
||||
res=512, # never used?
|
||||
down_sampling_rate=1.0,
|
||||
):
|
||||
np_img = HWC3(np_img)
|
||||
if down_sampling_rate < 1.1:
|
||||
return np_img
|
||||
H, W, C = np_img.shape
|
||||
H = int(float(H) / float(down_sampling_rate))
|
||||
W = int(float(W) / float(down_sampling_rate))
|
||||
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
|
||||
return np_img
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
np_img = np.array(image, dtype=np.uint8)
|
||||
processed_np_image = self.tile_resample(
|
||||
np_img,
|
||||
# res=self.tile_size,
|
||||
down_sampling_rate=self.down_sampling_rate,
|
||||
)
|
||||
processed_image = Image.fromarray(processed_np_image)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"segment_anything_processor",
|
||||
title="Segment Anything Processor",
|
||||
tags=["controlnet", "segmentanything"],
|
||||
category="controlnet",
|
||||
version="1.2.4",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies segment anything processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
||||
"ybelkada/segment-anything", subfolder="checkpoints"
|
||||
)
|
||||
np_img = np.array(image, dtype=np.uint8)
|
||||
processed_image = segment_anything_processor(
|
||||
np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class SamDetectorReproducibleColors(SamDetector):
|
||||
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
|
||||
# base class show_anns() method randomizes colors,
|
||||
# which seems to also lead to non-reproducible image generation
|
||||
# so using ADE20k color palette instead
|
||||
def show_anns(self, anns: List[Dict]):
|
||||
if len(anns) == 0:
|
||||
return
|
||||
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
|
||||
h, w = anns[0]["segmentation"].shape
|
||||
final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
|
||||
palette = ade_palette()
|
||||
for i, ann in enumerate(sorted_anns):
|
||||
m = ann["segmentation"]
|
||||
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
|
||||
# doing modulo just in case number of annotated regions exceeds number of colors in palette
|
||||
ann_color = palette[i % len(palette)]
|
||||
img[:, :] = ann_color
|
||||
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
|
||||
return np.array(final_img, dtype=np.uint8)
|
||||
|
||||
|
||||
@invocation(
|
||||
"color_map_image_processor",
|
||||
title="Color Map Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a color map from the provided image"""
|
||||
|
||||
color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
np_image = np.array(image, dtype=np.uint8)
|
||||
height, width = np_image.shape[:2]
|
||||
|
||||
width_tile_size = min(self.color_map_tile_size, width)
|
||||
height_tile_size = min(self.color_map_tile_size, height)
|
||||
|
||||
color_map = cv2.resize(
|
||||
np_image,
|
||||
(width // width_tile_size, height // height_tile_size),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
|
||||
color_map = Image.fromarray(color_map)
|
||||
return color_map
|
||||
|
||||
|
||||
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
|
||||
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
|
||||
DEPTH_ANYTHING_MODELS = {
|
||||
"large": "LiheYoung/depth-anything-large-hf",
|
||||
"base": "LiheYoung/depth-anything-base-hf",
|
||||
"small": "LiheYoung/depth-anything-small-hf",
|
||||
"small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
"depth_anything_image_processor",
|
||||
title="Depth Anything Processor",
|
||||
tags=["controlnet", "depth", "depth anything"],
|
||||
category="controlnet",
|
||||
version="1.1.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a depth map based on the Depth Anything algorithm"""
|
||||
|
||||
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
|
||||
default="small_v2", description="The size of the depth model to use"
|
||||
)
|
||||
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
def load_depth_anything(model_path: Path):
|
||||
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
|
||||
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
|
||||
return DepthAnythingPipeline(depth_anything_pipeline)
|
||||
|
||||
with self._context.models.load_remote_model(
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
|
||||
) as depth_anything_detector:
|
||||
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
|
||||
depth_map = depth_anything_detector.generate_depth(image)
|
||||
|
||||
# Resizing to user target specified size
|
||||
new_height = int(image.size[1] * (self.resolution / image.size[0]))
|
||||
depth_map = depth_map.resize((self.resolution, new_height))
|
||||
|
||||
return depth_map
|
||||
|
||||
|
||||
@invocation(
|
||||
"dw_openpose_image_processor",
|
||||
title="DW Openpose Image Processor",
|
||||
tags=["controlnet", "dwpose", "openpose"],
|
||||
category="controlnet",
|
||||
version="1.1.1",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates an openpose pose from an image using DWPose"""
|
||||
|
||||
draw_body: bool = InputField(default=True)
|
||||
draw_face: bool = InputField(default=False)
|
||||
draw_hands: bool = InputField(default=False)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
|
||||
onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
|
||||
|
||||
dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
|
||||
processed_image = dw_openpose(
|
||||
image,
|
||||
draw_face=self.draw_face,
|
||||
draw_hands=self.draw_hands,
|
||||
draw_body=self.draw_body,
|
||||
resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"heuristic_resize",
|
||||
title="Heuristic Resize",
|
||||
tags=["image, controlnet"],
|
||||
category="image",
|
||||
version="1.0.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class HeuristicResizeInvocation(BaseInvocation):
|
||||
"""Resize an image using a heuristic method. Preserves edge maps."""
|
||||
|
||||
image: ImageField = InputField(description="The image to resize")
|
||||
width: int = InputField(default=512, ge=1, description="The width to resize to (px)")
|
||||
height: int = InputField(default=512, ge=1, description="The height to resize to (px)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
np_img = pil_to_np(image)
|
||||
np_resized = heuristic_resize(np_img, (self.width, self.height))
|
||||
resized = np_to_pil(np_resized)
|
||||
image_dto = context.images.save(image=resized)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -22,7 +22,7 @@ from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.controlnet import ControlField
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
DenoiseMaskField,
|
||||
|
||||
@@ -4,7 +4,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
|
||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector2
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -25,20 +25,20 @@ class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
onnx_det_path = context.models.download_and_cache_model(DWOpenposeDetector.get_model_url_det())
|
||||
onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector.get_model_url_pose())
|
||||
onnx_det_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_det())
|
||||
onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_pose())
|
||||
|
||||
loaded_session_det = context.models.load_local_model(
|
||||
onnx_det_path, DWOpenposeDetector.create_onnx_inference_session
|
||||
onnx_det_path, DWOpenposeDetector2.create_onnx_inference_session
|
||||
)
|
||||
loaded_session_pose = context.models.load_local_model(
|
||||
onnx_pose_path, DWOpenposeDetector.create_onnx_inference_session
|
||||
onnx_pose_path, DWOpenposeDetector2.create_onnx_inference_session
|
||||
)
|
||||
|
||||
with loaded_session_det as session_det, loaded_session_pose as session_pose:
|
||||
assert isinstance(session_det, ort.InferenceSession)
|
||||
assert isinstance(session_pose, ort.InferenceSession)
|
||||
detector = DWOpenposeDetector(session_det=session_det, session_pose=session_pose)
|
||||
detector = DWOpenposeDetector2(session_det=session_det, session_pose=session_pose)
|
||||
detected_image = detector.run(
|
||||
image,
|
||||
draw_face=self.draw_face,
|
||||
|
||||
@@ -14,7 +14,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet import ControlField, ControlNetInvocation
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField, ControlNetInvocation
|
||||
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
|
||||
@@ -9,7 +9,7 @@ from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.controlnet import ControlField
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
|
||||
@@ -302,10 +302,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
# We catch this error so that the app can still run if there are invalid model configs in the database.
|
||||
# One reason that an invalid model config might be in the database is if someone had to rollback from a
|
||||
# newer version of the app that added a new model type.
|
||||
row_data = f"{row[0][:64]}..." if len(row[0]) > 64 else row[0]
|
||||
self._logger.warning(
|
||||
f"Found an invalid model config in the database. Ignoring this model. ({row_data})"
|
||||
)
|
||||
self._logger.warning(f"Found an invalid model config in the database. Ignoring this model. ({row[0]})")
|
||||
else:
|
||||
results.append(model_config)
|
||||
|
||||
|
||||
@@ -201,12 +201,6 @@ def get_workflow(queue_item_dict: dict) -> Optional[WorkflowWithoutID]:
|
||||
return None
|
||||
|
||||
|
||||
class FieldIdentifier(BaseModel):
|
||||
kind: Literal["input", "output"] = Field(description="The kind of field")
|
||||
node_id: str = Field(description="The ID of the node")
|
||||
field_name: str = Field(description="The name of the field")
|
||||
|
||||
|
||||
class SessionQueueItemWithoutGraph(BaseModel):
|
||||
"""Session queue item without the full graph. Used for serialization."""
|
||||
|
||||
@@ -243,20 +237,6 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
retried_from_item_id: Optional[int] = Field(
|
||||
default=None, description="The item_id of the queue item that this item was retried from"
|
||||
)
|
||||
is_api_validation_run: bool = Field(
|
||||
default=False,
|
||||
description="Whether this queue item is an API validation run.",
|
||||
)
|
||||
published_workflow_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The ID of the published workflow associated with this queue item",
|
||||
)
|
||||
api_input_fields: Optional[list[FieldIdentifier]] = Field(
|
||||
default=None, description="The fields that were used as input to the API"
|
||||
)
|
||||
api_output_fields: Optional[list[FieldIdentifier]] = Field(
|
||||
default=None, description="The nodes that were used as output from the API"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
|
||||
|
||||
@@ -47,7 +47,6 @@ class WorkflowRecordsStorageBase(ABC):
|
||||
query: Optional[str],
|
||||
tags: Optional[list[str]],
|
||||
has_been_opened: Optional[bool],
|
||||
is_published: Optional[bool],
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
"""Gets many workflows."""
|
||||
pass
|
||||
@@ -57,7 +56,6 @@ class WorkflowRecordsStorageBase(ABC):
|
||||
self,
|
||||
categories: list[WorkflowCategory],
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
"""Gets a dictionary of counts for each of the provided categories."""
|
||||
pass
|
||||
@@ -68,7 +66,6 @@ class WorkflowRecordsStorageBase(ABC):
|
||||
tags: list[str],
|
||||
categories: Optional[list[WorkflowCategory]] = None,
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
"""Gets a dictionary of counts for each of the provided tags."""
|
||||
pass
|
||||
|
||||
@@ -67,7 +67,6 @@ class WorkflowWithoutID(BaseModel):
|
||||
# This is typed as optional to prevent errors when pulling workflows from the DB. The frontend adds a default form if
|
||||
# it is None.
|
||||
form: dict[str, JsonValue] | None = Field(default=None, description="The form of the workflow.")
|
||||
is_published: bool | None = Field(default=None, description="Whether the workflow is published or not.")
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
@@ -102,7 +101,6 @@ class WorkflowRecordDTOBase(BaseModel):
|
||||
opened_at: Optional[Union[datetime.datetime, str]] = Field(
|
||||
default=None, description="The opened timestamp of the workflow."
|
||||
)
|
||||
is_published: bool | None = Field(default=None, description="Whether the workflow is published or not.")
|
||||
|
||||
|
||||
class WorkflowRecordDTO(WorkflowRecordDTOBase):
|
||||
|
||||
@@ -119,7 +119,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
query: Optional[str] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
# sanitize!
|
||||
assert order_by in WorkflowRecordOrderBy
|
||||
@@ -242,7 +241,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
tags: list[str],
|
||||
categories: Optional[list[WorkflowCategory]] = None,
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
if not tags:
|
||||
return {}
|
||||
@@ -294,7 +292,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
self,
|
||||
categories: list[WorkflowCategory],
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
cursor = self._conn.cursor()
|
||||
result: dict[str, int] = {}
|
||||
|
||||
@@ -65,6 +65,9 @@ def apply_monkeypatches() -> None:
|
||||
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
|
||||
def register_mime_types() -> None:
|
||||
"""Register additional mime types for windows."""
|
||||
|
||||
@@ -5,14 +5,62 @@ import huggingface_hub
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
from controlnet_aux.util import resize_image
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.dw_openpose.onnxdet import inference_detector
|
||||
from invokeai.backend.image_util.dw_openpose.onnxpose import inference_pose
|
||||
from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
|
||||
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
|
||||
from invokeai.backend.image_util.util import np_to_pil
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
DWPOSE_MODELS = {
|
||||
"yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
|
||||
"dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
|
||||
}
|
||||
|
||||
|
||||
def draw_pose(
|
||||
pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
|
||||
H: int,
|
||||
W: int,
|
||||
draw_face: bool = True,
|
||||
draw_body: bool = True,
|
||||
draw_hands: bool = True,
|
||||
resolution: int = 512,
|
||||
) -> Image.Image:
|
||||
bodies = pose["bodies"]
|
||||
faces = pose["faces"]
|
||||
hands = pose["hands"]
|
||||
|
||||
assert isinstance(bodies, dict)
|
||||
candidate = bodies["candidate"]
|
||||
|
||||
assert isinstance(bodies, dict)
|
||||
subset = bodies["subset"]
|
||||
|
||||
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
|
||||
|
||||
if draw_body:
|
||||
canvas = draw_bodypose(canvas, candidate, subset)
|
||||
|
||||
if draw_hands:
|
||||
assert isinstance(hands, np.ndarray)
|
||||
canvas = draw_handpose(canvas, hands)
|
||||
|
||||
if draw_face:
|
||||
assert isinstance(hands, np.ndarray)
|
||||
canvas = draw_facepose(canvas, faces) # type: ignore
|
||||
|
||||
dwpose_image: Image.Image = resize_image(
|
||||
canvas,
|
||||
resolution,
|
||||
)
|
||||
dwpose_image = Image.fromarray(dwpose_image)
|
||||
|
||||
return dwpose_image
|
||||
|
||||
|
||||
class DWOpenposeDetector:
|
||||
"""
|
||||
@@ -20,6 +68,62 @@ class DWOpenposeDetector:
|
||||
Credits: https://github.com/IDEA-Research/DWPose
|
||||
"""
|
||||
|
||||
def __init__(self, onnx_det: Path, onnx_pose: Path) -> None:
|
||||
self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw_face: bool = False,
|
||||
draw_body: bool = True,
|
||||
draw_hands: bool = False,
|
||||
resolution: int = 512,
|
||||
) -> Image.Image:
|
||||
np_image = np.array(image)
|
||||
H, W, C = np_image.shape
|
||||
|
||||
with torch.no_grad():
|
||||
candidate, subset = self.pose_estimation(np_image)
|
||||
nums, keys, locs = candidate.shape
|
||||
candidate[..., 0] /= float(W)
|
||||
candidate[..., 1] /= float(H)
|
||||
body = candidate[:, :18].copy()
|
||||
body = body.reshape(nums * 18, locs)
|
||||
score = subset[:, :18]
|
||||
for i in range(len(score)):
|
||||
for j in range(len(score[i])):
|
||||
if score[i][j] > 0.3:
|
||||
score[i][j] = int(18 * i + j)
|
||||
else:
|
||||
score[i][j] = -1
|
||||
|
||||
un_visible = subset < 0.3
|
||||
candidate[un_visible] = -1
|
||||
|
||||
# foot = candidate[:, 18:24]
|
||||
|
||||
faces = candidate[:, 24:92]
|
||||
|
||||
hands = candidate[:, 92:113]
|
||||
hands = np.vstack([hands, candidate[:, 113:]])
|
||||
|
||||
bodies = {"candidate": body, "subset": score}
|
||||
pose = {"bodies": bodies, "hands": hands, "faces": faces}
|
||||
|
||||
return draw_pose(
|
||||
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
|
||||
)
|
||||
|
||||
|
||||
class DWOpenposeDetector2:
|
||||
"""
|
||||
Code from the original implementation of the DW Openpose Detector.
|
||||
Credits: https://github.com/IDEA-Research/DWPose
|
||||
|
||||
This implementation is similar to DWOpenposeDetector, with some alterations to allow the onnx models to be loaded
|
||||
and managed by the model manager.
|
||||
"""
|
||||
|
||||
hf_repo_id = "yzd-v/DWPose"
|
||||
hf_filename_onnx_det = "yolox_l.onnx"
|
||||
hf_filename_onnx_pose = "dw-ll_ucoco_384.onnx"
|
||||
@@ -109,7 +213,7 @@ class DWOpenposeDetector:
|
||||
bodies = {"candidate": body, "subset": score}
|
||||
pose = {"bodies": bodies, "hands": hands, "faces": faces}
|
||||
|
||||
return DWOpenposeDetector.draw_pose(
|
||||
return DWOpenposeDetector2.draw_pose(
|
||||
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body
|
||||
)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import math
|
||||
|
||||
import cv2
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
@@ -126,13 +127,11 @@ def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt:
|
||||
x2 = int(x2 * W)
|
||||
y2 = int(y2 * H)
|
||||
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
|
||||
hsv_color = np.array([[[ie / float(len(edges)) * 180, 255, 255]]], dtype=np.uint8)
|
||||
rgb_color = cv2.cvtColor(hsv_color, cv2.COLOR_HSV2RGB)[0, 0]
|
||||
cv2.line(
|
||||
canvas,
|
||||
(x1, y1),
|
||||
(x2, y2),
|
||||
rgb_color.tolist(),
|
||||
matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
|
||||
thickness=2,
|
||||
)
|
||||
|
||||
|
||||
44
invokeai/backend/image_util/dw_openpose/wholebody.py
Normal file
44
invokeai/backend/image_util/dw_openpose/wholebody.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# Code from the original DWPose Implementation: https://github.com/IDEA-Research/DWPose
|
||||
# Modified pathing to suit Invoke
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.image_util.dw_openpose.onnxdet import inference_detector
|
||||
from invokeai.backend.image_util.dw_openpose.onnxpose import inference_pose
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
config = get_config()
|
||||
|
||||
|
||||
class Wholebody:
|
||||
def __init__(self, onnx_det: Path, onnx_pose: Path):
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
|
||||
|
||||
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
|
||||
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
|
||||
|
||||
def __call__(self, oriImg):
|
||||
det_result = inference_detector(self.session_det, oriImg)
|
||||
keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
|
||||
|
||||
keypoints_info = np.concatenate((keypoints, scores[..., None]), axis=-1)
|
||||
# compute neck joint
|
||||
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
|
||||
# neck score when visualizing pred
|
||||
neck[:, 2:4] = np.logical_and(keypoints_info[:, 5, 2:4] > 0.3, keypoints_info[:, 6, 2:4] > 0.3).astype(int)
|
||||
new_keypoints_info = np.insert(keypoints_info, 17, neck, axis=1)
|
||||
mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]
|
||||
openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]
|
||||
new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx]
|
||||
keypoints_info = new_keypoints_info
|
||||
|
||||
keypoints, scores = keypoints_info[..., :2], keypoints_info[..., 2]
|
||||
|
||||
return keypoints, scores
|
||||
245
invokeai/backend/util/mps_fixes.py
Normal file
245
invokeai/backend/util/mps_fixes.py
Normal file
@@ -0,0 +1,245 @@
|
||||
import math
|
||||
|
||||
import diffusers
|
||||
import torch
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
torch.empty = torch.zeros
|
||||
|
||||
|
||||
_torch_layer_norm = torch.nn.functional.layer_norm
|
||||
|
||||
|
||||
def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
||||
if input.device.type == "mps" and input.dtype == torch.float16:
|
||||
input = input.float()
|
||||
if weight is not None:
|
||||
weight = weight.float()
|
||||
if bias is not None:
|
||||
bias = bias.float()
|
||||
return _torch_layer_norm(input, normalized_shape, weight, bias, eps).half()
|
||||
else:
|
||||
return _torch_layer_norm(input, normalized_shape, weight, bias, eps)
|
||||
|
||||
|
||||
torch.nn.functional.layer_norm = new_layer_norm
|
||||
|
||||
|
||||
_torch_tensor_permute = torch.Tensor.permute
|
||||
|
||||
|
||||
def new_torch_tensor_permute(input, *dims):
|
||||
result = _torch_tensor_permute(input, *dims)
|
||||
if input.device == "mps" and input.dtype == torch.float16:
|
||||
result = result.contiguous()
|
||||
return result
|
||||
|
||||
|
||||
torch.Tensor.permute = new_torch_tensor_permute
|
||||
|
||||
|
||||
_torch_lerp = torch.lerp
|
||||
|
||||
|
||||
def new_torch_lerp(input, end, weight, *, out=None):
|
||||
if input.device.type == "mps" and input.dtype == torch.float16:
|
||||
input = input.float()
|
||||
end = end.float()
|
||||
if isinstance(weight, torch.Tensor):
|
||||
weight = weight.float()
|
||||
if out is not None:
|
||||
out_fp32 = torch.zeros_like(out, dtype=torch.float32)
|
||||
else:
|
||||
out_fp32 = None
|
||||
result = _torch_lerp(input, end, weight, out=out_fp32)
|
||||
if out is not None:
|
||||
out.copy_(out_fp32.half())
|
||||
del out_fp32
|
||||
return result.half()
|
||||
|
||||
else:
|
||||
return _torch_lerp(input, end, weight, out=out)
|
||||
|
||||
|
||||
torch.lerp = new_torch_lerp
|
||||
|
||||
|
||||
_torch_interpolate = torch.nn.functional.interpolate
|
||||
|
||||
|
||||
def new_torch_interpolate(
|
||||
input,
|
||||
size=None,
|
||||
scale_factor=None,
|
||||
mode="nearest",
|
||||
align_corners=None,
|
||||
recompute_scale_factor=None,
|
||||
antialias=False,
|
||||
):
|
||||
if input.device.type == "mps" and input.dtype == torch.float16:
|
||||
return _torch_interpolate(
|
||||
input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
|
||||
).half()
|
||||
else:
|
||||
return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
|
||||
|
||||
|
||||
torch.nn.functional.interpolate = new_torch_interpolate
|
||||
|
||||
# TODO: refactor it
|
||||
_SlicedAttnProcessor = diffusers.models.attention_processor.SlicedAttnProcessor
|
||||
|
||||
|
||||
class ChunkedSlicedAttnProcessor:
|
||||
r"""
|
||||
Processor for implementing sliced attention.
|
||||
|
||||
Args:
|
||||
slice_size (`int`, *optional*):
|
||||
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
||||
`attention_head_dim` must be a multiple of the `slice_size`.
|
||||
"""
|
||||
|
||||
def __init__(self, slice_size):
|
||||
assert isinstance(slice_size, int)
|
||||
slice_size = 1 # TODO: maybe implement chunking in batches too when enough memory
|
||||
self.slice_size = slice_size
|
||||
self._sliced_attn_processor = _SlicedAttnProcessor(slice_size)
|
||||
|
||||
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
if self.slice_size != 1 or attn.upcast_attention:
|
||||
return self._sliced_attn_processor(attn, hidden_states, encoder_hidden_states, attention_mask)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
dim = query.shape[-1]
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
batch_size_attention, query_tokens, _ = query.shape
|
||||
hidden_states = torch.zeros(
|
||||
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
||||
)
|
||||
|
||||
chunk_tmp_tensor = torch.empty(
|
||||
self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
||||
)
|
||||
|
||||
for i in range(batch_size_attention // self.slice_size):
|
||||
start_idx = i * self.slice_size
|
||||
end_idx = (i + 1) * self.slice_size
|
||||
|
||||
query_slice = query[start_idx:end_idx]
|
||||
key_slice = key[start_idx:end_idx]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||
|
||||
self.get_attention_scores_chunked(
|
||||
attn,
|
||||
query_slice,
|
||||
key_slice,
|
||||
attn_mask_slice,
|
||||
hidden_states[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
chunk_tmp_tensor,
|
||||
)
|
||||
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def get_attention_scores_chunked(self, attn, query, key, attention_mask, hidden_states, value, chunk):
|
||||
# batch size = 1
|
||||
assert query.shape[0] == 1
|
||||
assert key.shape[0] == 1
|
||||
assert value.shape[0] == 1
|
||||
assert hidden_states.shape[0] == 1
|
||||
|
||||
# dtype = query.dtype
|
||||
if attn.upcast_attention:
|
||||
query = query.float()
|
||||
key = key.float()
|
||||
|
||||
# out_item_size = query.dtype.itemsize
|
||||
# if attn.upcast_attention:
|
||||
# out_item_size = torch.float32.itemsize
|
||||
out_item_size = query.element_size()
|
||||
if attn.upcast_attention:
|
||||
out_item_size = 4
|
||||
|
||||
chunk_size = 2**29
|
||||
|
||||
out_size = query.shape[1] * key.shape[1] * out_item_size
|
||||
chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))
|
||||
chunk_step = max(1, int(query.shape[1] / chunks_count))
|
||||
|
||||
key = key.transpose(-1, -2)
|
||||
|
||||
def _get_chunk_view(tensor, start, length):
|
||||
if start + length > tensor.shape[1]:
|
||||
length = tensor.shape[1] - start
|
||||
# print(f"view: [{tensor.shape[0]},{tensor.shape[1]},{tensor.shape[2]}] - start: {start}, length: {length}")
|
||||
return tensor[:, start : start + length]
|
||||
|
||||
for chunk_pos in range(0, query.shape[1], chunk_step):
|
||||
if attention_mask is not None:
|
||||
torch.baddbmm(
|
||||
_get_chunk_view(attention_mask, chunk_pos, chunk_step),
|
||||
_get_chunk_view(query, chunk_pos, chunk_step),
|
||||
key,
|
||||
beta=1,
|
||||
alpha=attn.scale,
|
||||
out=chunk,
|
||||
)
|
||||
else:
|
||||
torch.baddbmm(
|
||||
torch.zeros((1, 1, 1), device=query.device, dtype=query.dtype),
|
||||
_get_chunk_view(query, chunk_pos, chunk_step),
|
||||
key,
|
||||
beta=0,
|
||||
alpha=attn.scale,
|
||||
out=chunk,
|
||||
)
|
||||
chunk = chunk.softmax(dim=-1)
|
||||
torch.bmm(chunk, value, out=_get_chunk_view(hidden_states, chunk_pos, chunk_step))
|
||||
|
||||
# del chunk
|
||||
|
||||
|
||||
diffusers.models.attention_processor.SlicedAttnProcessor = ChunkedSlicedAttnProcessor
|
||||
@@ -150,7 +150,7 @@
|
||||
"prettier": "^3.3.3",
|
||||
"rollup-plugin-visualizer": "^5.12.0",
|
||||
"storybook": "^8.3.4",
|
||||
"tsafe": "^1.8.5",
|
||||
"tsafe": "^1.7.5",
|
||||
"type-fest": "^4.26.1",
|
||||
"typescript": "^5.6.2",
|
||||
"vite": "^6.1.0",
|
||||
|
||||
8
invokeai/frontend/web/pnpm-lock.yaml
generated
8
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -284,8 +284,8 @@ devDependencies:
|
||||
specifier: ^8.3.4
|
||||
version: 8.3.4
|
||||
tsafe:
|
||||
specifier: ^1.8.5
|
||||
version: 1.8.5
|
||||
specifier: ^1.7.5
|
||||
version: 1.7.5
|
||||
type-fest:
|
||||
specifier: ^4.26.1
|
||||
version: 4.26.1
|
||||
@@ -8791,8 +8791,8 @@ packages:
|
||||
resolution: {integrity: sha512-tLJxacIQUM82IR7JO1UUkKlYuUTmoY9HBJAmNWFzheSlDS5SPMcNIepejHJa4BpPQLAcbRhRf3GDJzyj6rbKvA==}
|
||||
dev: false
|
||||
|
||||
/tsafe@1.8.5:
|
||||
resolution: {integrity: sha512-LFWTWQrW6rwSY+IBNFl2ridGfUzVsPwrZ26T4KUJww/py8rzaQ/SY+MIz6YROozpUCaRcuISqagmlwub9YT9kw==}
|
||||
/tsafe@1.7.5:
|
||||
resolution: {integrity: sha512-tbNyyBSbwfbilFfiuXkSOj82a6++ovgANwcoqBAcO9/REPoZMEQoE8kWPeO0dy5A2D/2Lajr8Ohue5T0ifIvLQ==}
|
||||
dev: true
|
||||
|
||||
/tsconfck@3.1.5(typescript@5.6.2):
|
||||
|
||||
@@ -1706,7 +1706,6 @@
|
||||
"noRecentWorkflows": "No Recent Workflows",
|
||||
"private": "Private",
|
||||
"shared": "Shared",
|
||||
"published": "Published",
|
||||
"browseWorkflows": "Browse Workflows",
|
||||
"deselectAll": "Deselect All",
|
||||
"recommended": "Recommended For You",
|
||||
@@ -1784,39 +1783,7 @@
|
||||
"textPlaceholder": "Empty Text",
|
||||
"workflowBuilderAlphaWarning": "The workflow builder is currently in alpha. There may be breaking changes before the stable release.",
|
||||
"minimum": "Minimum",
|
||||
"maximum": "Maximum",
|
||||
"publish": "Publish",
|
||||
"published": "Published",
|
||||
"unpublish": "Unpublish",
|
||||
"workflowLocked": "Workflow Locked",
|
||||
"workflowLockedPublished": "Published workflows are locked for editing.\nYou can unpublish the workflow to edit it, or make a copy of it.",
|
||||
"workflowLockedDuringPublishing": "Workflow is locked while configuring for publishing.",
|
||||
"selectOutputNode": "Select Output Node",
|
||||
"changeOutputNode": "Change Output Node",
|
||||
"publishedWorkflowOutputs": "Outputs",
|
||||
"publishedWorkflowInputs": "Inputs",
|
||||
"unpublishableInputs": "These unpublishable inputs will be omitted",
|
||||
"noPublishableInputs": "No publishable inputs",
|
||||
"noOutputNodeSelected": "No output node selected",
|
||||
"cannotPublish": "Cannot publish workflow",
|
||||
"publishWarnings": "Warnings",
|
||||
"errorWorkflowHasUnsavedChanges": "Workflow has unsaved changes",
|
||||
"errorWorkflowHasBatchOrGeneratorNodes": "Workflow has batch and/or generator nodes",
|
||||
"errorWorkflowHasInvalidGraph": "Workflow graph invalid (hover Invoke button for details)",
|
||||
"errorWorkflowHasNoOutputNode": "No output node selected",
|
||||
"warningWorkflowHasNoPublishableInputFields": "No publishable input fields selected - published workflow will run with only default values",
|
||||
"warningWorkflowHasUnpublishableInputFields": "Workflow has some unpublishable inputs - these will be omitted from the published workflow",
|
||||
"publishFailed": "Publish failed",
|
||||
"publishFailedDesc": "There was a problem publishing the workflow. Please try again.",
|
||||
"publishSuccess": "Your workflow is being published",
|
||||
"publishSuccessDesc": "Check your <LinkComponent>Project Dashboard</LinkComponent> to see its progress.",
|
||||
"publishInProgress": "Publishing in progress",
|
||||
"publishedWorkflowIsLocked": "Published workflow is locked",
|
||||
"publishingValidationRun": "Publishing Validation Run",
|
||||
"publishingValidationRunInProgress": "Publishing validation run in progress.",
|
||||
"publishedWorkflowsLocked": "Published workflows are locked and cannot be edited or run. Either unpublish the workflow or save a copy to edit or run this workflow.",
|
||||
"selectingOutputNode": "Selecting output node",
|
||||
"selectingOutputNodeDesc": "Click a node to select it as the workflow's output node."
|
||||
"maximum": "Maximum"
|
||||
}
|
||||
},
|
||||
"controlLayers": {
|
||||
|
||||
7
invokeai/frontend/web/src/app/store/actions.ts
Normal file
7
invokeai/frontend/web/src/app/store/actions.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { TabName } from 'features/ui/store/uiTypes';
|
||||
|
||||
export const enqueueRequested = createAction<{
|
||||
tabName: TabName;
|
||||
prepend: boolean;
|
||||
}>('app/enqueueRequested');
|
||||
@@ -10,6 +10,7 @@ import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/l
|
||||
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
|
||||
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
|
||||
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
|
||||
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
|
||||
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
|
||||
import { addGalleryOffsetChangedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryOffsetChanged';
|
||||
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
|
||||
@@ -62,6 +63,7 @@ addGalleryImageClickedListener(startAppListening);
|
||||
addGalleryOffsetChangedListener(startAppListening);
|
||||
|
||||
// User Invoked
|
||||
addEnqueueRequestedNodes(startAppListening);
|
||||
addEnqueueRequestedLinear(startAppListening);
|
||||
addEnqueueRequestedUpscale(startAppListening);
|
||||
addAnyEnqueuedListener(startAppListening);
|
||||
|
||||
@@ -5,7 +5,7 @@ import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAd
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import type { EnqueueBatchArg, ImageDTO } from 'services/api/types';
|
||||
import type { BatchConfig, ImageDTO } from 'services/api/types';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('queue');
|
||||
@@ -19,7 +19,7 @@ export const addAdHocPostProcessingRequestedListener = (startAppListening: AppSt
|
||||
const { imageDTO } = action.payload;
|
||||
const state = getState();
|
||||
|
||||
const enqueueBatchArg: EnqueueBatchArg = {
|
||||
const enqueueBatchArg: BatchConfig = {
|
||||
prepend: true,
|
||||
batch: {
|
||||
graph: await buildAdHocPostProcessingGraph({
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
@@ -17,11 +17,10 @@ import { assert, AssertionError } from 'tsafe';
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const enqueueRequestedCanvas = createAction<{ prepend: boolean }>('app/enqueueRequestedCanvas');
|
||||
|
||||
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: enqueueRequestedCanvas,
|
||||
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'canvas',
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
log.debug('Enqueue requested');
|
||||
const state = getState();
|
||||
|
||||
@@ -1,29 +1,25 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import {
|
||||
$outputNodeId,
|
||||
getPublishInputs,
|
||||
selectFieldIdentifiersWithInvocationTypes,
|
||||
} from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeData, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
|
||||
import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue';
|
||||
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
|
||||
import { groupBy } from 'lodash-es';
|
||||
import { useCallback } from 'react';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import type { Batch, EnqueueBatchArg } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
import type { Batch, BatchConfig } from 'services/api/types';
|
||||
|
||||
const enqueueRequestedWorkflows = createAction('app/enqueueRequestedWorkflows');
|
||||
const log = logger('generation');
|
||||
|
||||
export const useEnqueueWorkflows = () => {
|
||||
const { getState, dispatch } = useAppStore();
|
||||
const enqueue = useCallback(
|
||||
async (prepend: boolean, isApiValidationRun: boolean) => {
|
||||
dispatch(enqueueRequestedWorkflows());
|
||||
export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'workflows',
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const nodesState = selectNodesSlice(state);
|
||||
const workflow = state.workflow;
|
||||
@@ -95,7 +91,7 @@ export const useEnqueueWorkflows = () => {
|
||||
}
|
||||
}
|
||||
|
||||
const batchConfig: EnqueueBatchArg = {
|
||||
const batchConfig: BatchConfig = {
|
||||
batch: {
|
||||
graph,
|
||||
workflow: builtWorkflow,
|
||||
@@ -104,57 +100,18 @@ export const useEnqueueWorkflows = () => {
|
||||
destination: 'gallery',
|
||||
data,
|
||||
},
|
||||
prepend,
|
||||
prepend: action.payload.prepend,
|
||||
};
|
||||
|
||||
if (isApiValidationRun) {
|
||||
// Derive the input fields from the builder's selected node field elements
|
||||
const fieldIdentifiers = selectFieldIdentifiersWithInvocationTypes(state);
|
||||
const inputs = getPublishInputs(fieldIdentifiers, templates);
|
||||
const api_input_fields = inputs.publishable.map(({ nodeId, fieldName }) => {
|
||||
return {
|
||||
kind: 'input',
|
||||
node_id: nodeId,
|
||||
field_name: fieldName,
|
||||
} as const;
|
||||
});
|
||||
|
||||
// Derive the output fields from the builder's selected output node
|
||||
const outputNodeId = $outputNodeId.get();
|
||||
assert(outputNodeId !== null, 'Output node not selected');
|
||||
const outputNodeType = selectNodeData(selectNodesSlice(state), outputNodeId).type;
|
||||
const outputNodeTemplate = templates[outputNodeType];
|
||||
assert(outputNodeTemplate, `Template for node type ${outputNodeType} not found`);
|
||||
const outputFieldNames = Object.keys(outputNodeTemplate.outputs);
|
||||
const api_output_fields = outputFieldNames.map((fieldName) => {
|
||||
return {
|
||||
kind: 'output',
|
||||
node_id: outputNodeId,
|
||||
field_name: fieldName,
|
||||
} as const;
|
||||
});
|
||||
|
||||
assert(workflow.id, 'Workflow without ID cannot be used for API validation run');
|
||||
|
||||
batchConfig.validation_run_data = {
|
||||
workflow_id: workflow.id,
|
||||
input_fields: api_input_fields,
|
||||
output_fields: api_output_fields,
|
||||
};
|
||||
|
||||
// If the batch is an API validation run, we only want to run it once
|
||||
batchConfig.batch.runs = 1;
|
||||
const req = dispatch(queueApi.endpoints.enqueueBatch.initiate(batchConfig, enqueueMutationFixedCacheKeyOptions));
|
||||
try {
|
||||
await req.unwrap();
|
||||
log.debug(parseify({ batchConfig }), 'Enqueued batch');
|
||||
} catch (error) {
|
||||
log.error({ error: serializeError(error) }, 'Failed to enqueue batch');
|
||||
} finally {
|
||||
req.reset();
|
||||
}
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false })
|
||||
);
|
||||
|
||||
const enqueueResult = await req.unwrap();
|
||||
return { batchConfig, enqueueResult };
|
||||
},
|
||||
[dispatch, getState]
|
||||
);
|
||||
|
||||
return enqueue;
|
||||
});
|
||||
};
|
||||
@@ -1,5 +1,5 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
@@ -9,11 +9,10 @@ import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endp
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const enqueueRequestedUpscaling = createAction<{ prepend: boolean }>('app/enqueueRequestedUpscaling');
|
||||
|
||||
export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: enqueueRequestedUpscaling,
|
||||
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'upscaling',
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const { prepend } = action.payload;
|
||||
|
||||
@@ -3,7 +3,6 @@ import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/too
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
|
||||
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
|
||||
import { getDebugLoggerMiddleware } from 'app/store/middleware/debugLoggerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
|
||||
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
@@ -176,7 +175,6 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
|
||||
.concat(api.middleware)
|
||||
.concat(dynamicMiddlewares)
|
||||
.concat(authToastMiddleware)
|
||||
.concat(getDebugLoggerMiddleware())
|
||||
.prepend(listenerMiddleware.middleware),
|
||||
enhancers: (getDefaultEnhancers) => {
|
||||
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());
|
||||
|
||||
@@ -74,7 +74,6 @@ export type AppConfig = {
|
||||
allowPrivateBoards: boolean;
|
||||
allowPrivateStylePresets: boolean;
|
||||
allowClientSideUpload: boolean;
|
||||
allowPublishWorkflows: boolean;
|
||||
disabledTabs: TabName[];
|
||||
disabledFeatures: AppFeature[];
|
||||
disabledSDFeatures: SDFeature[];
|
||||
|
||||
@@ -14,7 +14,7 @@ export const useGlobalHotkeys = () => {
|
||||
useRegisteredHotkeys({
|
||||
id: 'invoke',
|
||||
category: 'app',
|
||||
callback: queue.enqueueBack,
|
||||
callback: queue.queueBack,
|
||||
options: {
|
||||
enabled: !queue.isDisabled && !queue.isLoading,
|
||||
preventDefault: true,
|
||||
@@ -26,7 +26,7 @@ export const useGlobalHotkeys = () => {
|
||||
useRegisteredHotkeys({
|
||||
id: 'invokeFront',
|
||||
category: 'app',
|
||||
callback: queue.enqueueFront,
|
||||
callback: queue.queueFront,
|
||||
options: {
|
||||
enabled: !queue.isDisabled && !queue.isLoading,
|
||||
preventDefault: true,
|
||||
|
||||
@@ -54,7 +54,7 @@ import { atom, computed } from 'nanostores';
|
||||
import type { Logger } from 'roarr';
|
||||
import { getImageDTO } from 'services/api/endpoints/images';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import type { EnqueueBatchArg, ImageDTO, S } from 'services/api/types';
|
||||
import type { BatchConfig, ImageDTO, S } from 'services/api/types';
|
||||
import { QueueError } from 'services/events/errors';
|
||||
import type { Param0 } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
@@ -291,7 +291,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
|
||||
*/
|
||||
const origin = getPrefixedId(graph.id);
|
||||
|
||||
const batch: EnqueueBatchArg = {
|
||||
const batch: BatchConfig = {
|
||||
prepend,
|
||||
batch: {
|
||||
graph: graph.getGraph(),
|
||||
|
||||
@@ -2,9 +2,7 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { FocusRegionWrapper } from 'common/components/FocusRegionWrapper';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { AddNodeCmdk } from 'features/nodes/components/flow/AddNodeCmdk/AddNodeCmdk';
|
||||
import { TopCenterPanel } from 'features/nodes/components/flow/panels/TopPanel/TopCenterPanel';
|
||||
import { TopLeftPanel } from 'features/nodes/components/flow/panels/TopPanel/TopLeftPanel';
|
||||
import { TopRightPanel } from 'features/nodes/components/flow/panels/TopPanel/TopRightPanel';
|
||||
import TopPanel from 'features/nodes/components/flow/panels/TopPanel/TopPanel';
|
||||
import WorkflowEditorSettings from 'features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -34,9 +32,7 @@ const NodeEditor = () => {
|
||||
<>
|
||||
<Flow />
|
||||
<AddNodeCmdk />
|
||||
<TopLeftPanel />
|
||||
<TopCenterPanel />
|
||||
<TopRightPanel />
|
||||
<TopPanel />
|
||||
<BottomLeftPanel />
|
||||
<MinimapPanel />
|
||||
</>
|
||||
|
||||
@@ -18,7 +18,6 @@ import { CommandEmpty, CommandItem, CommandList, CommandRoot } from 'cmdk';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import {
|
||||
$addNodeCmdk,
|
||||
$cursorPos,
|
||||
@@ -147,7 +146,6 @@ export const AddNodeCmdk = memo(() => {
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const addNode = useAddNode();
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
// Filtering the list is expensive - debounce the search term to avoid stutters
|
||||
const [debouncedSearchTerm] = useDebounce(searchTerm, 300);
|
||||
const isOpen = useStore($addNodeCmdk);
|
||||
@@ -162,8 +160,8 @@ export const AddNodeCmdk = memo(() => {
|
||||
id: 'addNode',
|
||||
category: 'workflows',
|
||||
callback: open,
|
||||
options: { enabled: tab === 'workflows' && !isLocked, preventDefault: true },
|
||||
dependencies: [open, tab, isLocked],
|
||||
options: { enabled: tab === 'workflows', preventDefault: true },
|
||||
dependencies: [open, tab],
|
||||
});
|
||||
|
||||
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
|
||||
@@ -4,7 +4,6 @@ import type {
|
||||
EdgeChange,
|
||||
HandleType,
|
||||
NodeChange,
|
||||
NodeMouseHandler,
|
||||
OnEdgesChange,
|
||||
OnInit,
|
||||
OnMoveEnd,
|
||||
@@ -17,10 +16,8 @@ import type {
|
||||
import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from '@xyflow/react';
|
||||
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { $isSelectingOutputNode, $outputNodeId } from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
import { useConnection } from 'features/nodes/hooks/useConnection';
|
||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { useNodeCopyPaste } from 'features/nodes/hooks/useNodeCopyPaste';
|
||||
import { useSyncExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
|
||||
import {
|
||||
@@ -47,7 +44,7 @@ import {
|
||||
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||
import { selectSelectionMode, selectShouldSnapToGrid } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import { type AnyEdge, type AnyNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import type { CSSProperties, MouseEvent } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
@@ -95,8 +92,6 @@ export const Flow = memo(() => {
|
||||
const updateNodeInternals = useUpdateNodeInternals();
|
||||
const store = useAppStore();
|
||||
const isWorkflowsFocused = useIsRegionFocused('workflows');
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
|
||||
useFocusRegion('workflows', flowWrapper);
|
||||
|
||||
useSyncExecutionState();
|
||||
@@ -220,7 +215,7 @@ export const Flow = memo(() => {
|
||||
id: 'copySelection',
|
||||
category: 'workflows',
|
||||
callback: copySelection,
|
||||
options: { enabled: isWorkflowsFocused && !isLocked, preventDefault: true },
|
||||
options: { preventDefault: true },
|
||||
dependencies: [copySelection],
|
||||
});
|
||||
|
||||
@@ -249,24 +244,24 @@ export const Flow = memo(() => {
|
||||
id: 'selectAll',
|
||||
category: 'workflows',
|
||||
callback: selectAll,
|
||||
options: { enabled: isWorkflowsFocused && !isLocked, preventDefault: true },
|
||||
dependencies: [selectAll, isWorkflowsFocused, isLocked],
|
||||
options: { enabled: isWorkflowsFocused, preventDefault: true },
|
||||
dependencies: [selectAll, isWorkflowsFocused],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'pasteSelection',
|
||||
category: 'workflows',
|
||||
callback: pasteSelection,
|
||||
options: { enabled: isWorkflowsFocused && !isLocked, preventDefault: true },
|
||||
dependencies: [pasteSelection, isLocked, isWorkflowsFocused],
|
||||
options: { enabled: isWorkflowsFocused, preventDefault: true },
|
||||
dependencies: [pasteSelection],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'pasteSelectionWithEdges',
|
||||
category: 'workflows',
|
||||
callback: pasteSelectionWithEdges,
|
||||
options: { enabled: isWorkflowsFocused && !isLocked, preventDefault: true },
|
||||
dependencies: [pasteSelectionWithEdges, isLocked, isWorkflowsFocused],
|
||||
options: { enabled: isWorkflowsFocused, preventDefault: true },
|
||||
dependencies: [pasteSelectionWithEdges],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
@@ -275,8 +270,8 @@ export const Flow = memo(() => {
|
||||
callback: () => {
|
||||
dispatch(undo());
|
||||
},
|
||||
options: { enabled: isWorkflowsFocused && !isLocked && mayUndo, preventDefault: true },
|
||||
dependencies: [mayUndo, isLocked, isWorkflowsFocused],
|
||||
options: { enabled: isWorkflowsFocused && mayUndo, preventDefault: true },
|
||||
dependencies: [mayUndo],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
@@ -285,8 +280,8 @@ export const Flow = memo(() => {
|
||||
callback: () => {
|
||||
dispatch(redo());
|
||||
},
|
||||
options: { enabled: isWorkflowsFocused && !isLocked && mayRedo, preventDefault: true },
|
||||
dependencies: [mayRedo, isLocked, isWorkflowsFocused],
|
||||
options: { enabled: isWorkflowsFocused && mayRedo, preventDefault: true },
|
||||
dependencies: [mayRedo],
|
||||
});
|
||||
|
||||
const onEscapeHotkey = useCallback(() => {
|
||||
@@ -323,22 +318,10 @@ export const Flow = memo(() => {
|
||||
id: 'deleteSelection',
|
||||
category: 'workflows',
|
||||
callback: deleteSelection,
|
||||
options: { preventDefault: true, enabled: isWorkflowsFocused && !isLocked },
|
||||
dependencies: [deleteSelection, isWorkflowsFocused, isLocked],
|
||||
options: { preventDefault: true, enabled: isWorkflowsFocused },
|
||||
dependencies: [deleteSelection, isWorkflowsFocused],
|
||||
});
|
||||
|
||||
const onNodeClick = useCallback<NodeMouseHandler<AnyNode>>((e, node) => {
|
||||
if (!$isSelectingOutputNode.get()) {
|
||||
return;
|
||||
}
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
const { id } = node.data;
|
||||
$outputNodeId.set(id);
|
||||
$isSelectingOutputNode.set(false);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<ReactFlow<AnyNode, AnyEdge>
|
||||
id="workflow-editor"
|
||||
@@ -349,7 +332,6 @@ export const Flow = memo(() => {
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
onInit={onInit}
|
||||
onNodeClick={onNodeClick}
|
||||
onMouseMove={onMouseMove}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
@@ -362,12 +344,6 @@ export const Flow = memo(() => {
|
||||
onMoveEnd={handleMoveEnd}
|
||||
connectionLineComponent={CustomConnectionLine}
|
||||
isValidConnection={isValidConnection}
|
||||
edgesFocusable={!isLocked}
|
||||
edgesReconnectable={!isLocked}
|
||||
nodesDraggable={!isLocked}
|
||||
nodesConnectable={!isLocked}
|
||||
nodesFocusable={!isLocked}
|
||||
elementsSelectable={!isLocked}
|
||||
minZoom={0.1}
|
||||
snapToGrid={shouldSnapToGrid}
|
||||
snapGrid={snapGrid}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Handle, Position } from '@xyflow/react';
|
||||
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { map } from 'lodash-es';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo } from 'react';
|
||||
@@ -19,7 +19,7 @@ const collapsedHandleStyles: CSSProperties = {
|
||||
};
|
||||
|
||||
const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const template = useNodeTemplate(nodeId);
|
||||
|
||||
if (!template) {
|
||||
return null;
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Flex, Icon, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { compare } from 'compare-versions';
|
||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
|
||||
import { useInvocationNodeNotes } from 'features/nodes/hooks/useNodeNotes';
|
||||
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
|
||||
import { useNodeUserTitleSafe } from 'features/nodes/hooks/useNodeUserTitleSafe';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -27,9 +27,9 @@ InvocationNodeInfoIcon.displayName = 'InvocationNodeInfoIcon';
|
||||
|
||||
const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const notes = useInvocationNodeNotes(nodeId);
|
||||
const label = useNodeUserTitleSafe(nodeId);
|
||||
const label = useNodeLabel(nodeId);
|
||||
const version = useNodeVersion(nodeId);
|
||||
const nodeTemplate = useNodeTemplateOrThrow(nodeId);
|
||||
const nodeTemplate = useNodeTemplate(nodeId);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const title = useMemo(() => {
|
||||
|
||||
@@ -8,7 +8,7 @@ import {
|
||||
Textarea,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useInputFieldUserDescriptionSafe } from 'features/nodes/hooks/useInputFieldUserDescriptionSafe';
|
||||
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
|
||||
import { fieldDescriptionChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { ChangeEvent } from 'react';
|
||||
@@ -48,7 +48,7 @@ InputFieldDescriptionPopover.displayName = 'InputFieldDescriptionPopover';
|
||||
const Content = memo(({ nodeId, fieldName }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const description = useInputFieldUserDescriptionSafe(nodeId, fieldName);
|
||||
const description = useInputFieldDescriptionSafe(nodeId, fieldName);
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(fieldDescriptionChanged({ nodeId, fieldName, val: e.target.value }));
|
||||
|
||||
@@ -7,7 +7,7 @@ import { InputFieldResetToDefaultValueIconButton } from 'features/nodes/componen
|
||||
import { useNodeFieldDnd } from 'features/nodes/components/sidePanel/builder/dnd-hooks';
|
||||
import { useInputFieldIsConnected } from 'features/nodes/hooks/useInputFieldIsConnected';
|
||||
import { useInputFieldIsInvalid } from 'features/nodes/hooks/useInputFieldIsInvalid';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { NO_DRAG_CLASS } from 'features/nodes/types/constants';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useRef } from 'react';
|
||||
@@ -100,7 +100,7 @@ const DirectField = memo(({ nodeId, fieldName, isInvalid, isConnected, fieldTemp
|
||||
const draggableRef = useRef<HTMLDivElement>(null);
|
||||
const dragHandleRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const isDragging = useNodeFieldDnd(nodeId, fieldName, fieldTemplate, draggableRef, dragHandleRef);
|
||||
const isDragging = useNodeFieldDnd({ nodeId, fieldName }, fieldTemplate, draggableRef, dragHandleRef);
|
||||
|
||||
return (
|
||||
<InputFieldWrapper>
|
||||
|
||||
@@ -7,8 +7,7 @@ import {
|
||||
useIsConnectionInProgress,
|
||||
useIsConnectionStartField,
|
||||
} from 'features/nodes/hooks/useFieldConnectionState';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
@@ -106,16 +105,9 @@ type HandleCommonProps = {
|
||||
};
|
||||
|
||||
const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
return (
|
||||
<Tooltip label={fieldTypeName} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
|
||||
<Handle
|
||||
type="target"
|
||||
id={fieldTemplate.name}
|
||||
position={Position.Left}
|
||||
style={handleStyles}
|
||||
isConnectable={!isLocked}
|
||||
>
|
||||
<Handle type="target" id={fieldTemplate.name} position={Position.Left} style={handleStyles}>
|
||||
<Box
|
||||
sx={sx}
|
||||
data-cardinality={fieldTemplate.type.cardinality}
|
||||
@@ -138,7 +130,6 @@ const ConnectionInProgressHandle = memo(
|
||||
const { t } = useTranslation();
|
||||
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
|
||||
const connectionError = useConnectionErrorTKey(nodeId, fieldName, 'target');
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
if (connectionError !== null) {
|
||||
@@ -149,13 +140,7 @@ const ConnectionInProgressHandle = memo(
|
||||
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
|
||||
<Handle
|
||||
type="target"
|
||||
id={fieldTemplate.name}
|
||||
position={Position.Left}
|
||||
style={handleStyles}
|
||||
isConnectable={!isLocked}
|
||||
>
|
||||
<Handle type="target" id={fieldTemplate.name} position={Position.Left} style={handleStyles}>
|
||||
<Box
|
||||
sx={sx}
|
||||
data-cardinality={fieldTemplate.type.cardinality}
|
||||
|
||||
@@ -17,7 +17,7 @@ import { StringFieldDropdown } from 'features/nodes/components/flow/nodes/Invoca
|
||||
import { StringFieldInput } from 'features/nodes/components/flow/nodes/Invocation/fields/StringField/StringFieldInput';
|
||||
import { StringFieldTextarea } from 'features/nodes/components/flow/nodes/Invocation/fields/StringField/StringFieldTextarea';
|
||||
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import {
|
||||
isBoardFieldInputInstance,
|
||||
isBoardFieldInputTemplate,
|
||||
|
||||
@@ -9,8 +9,8 @@ import {
|
||||
useIsConnectionStartField,
|
||||
} from 'features/nodes/hooks/useFieldConnectionState';
|
||||
import { useInputFieldIsConnected } from 'features/nodes/hooks/useInputFieldIsConnected';
|
||||
import { useInputFieldTemplateTitleOrThrow } from 'features/nodes/hooks/useInputFieldTemplateTitleOrThrow';
|
||||
import { useInputFieldUserTitleSafe } from 'features/nodes/hooks/useInputFieldUserTitleSafe';
|
||||
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
|
||||
import { useInputFieldTemplateTitle } from 'features/nodes/hooks/useInputFieldTemplateTitle';
|
||||
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY, NO_FIT_ON_DOUBLE_CLICK_CLASS } from 'features/nodes/types/constants';
|
||||
import type { MouseEvent } from 'react';
|
||||
@@ -43,8 +43,8 @@ interface Props {
|
||||
export const InputFieldTitle = memo((props: Props) => {
|
||||
const { nodeId, fieldName, isInvalid, isDragging } = props;
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const label = useInputFieldUserTitleSafe(nodeId, fieldName);
|
||||
const fieldTemplateTitle = useInputFieldTemplateTitleOrThrow(nodeId, fieldName);
|
||||
const label = useInputFieldLabelSafe(nodeId, fieldName);
|
||||
const fieldTemplateTitle = useInputFieldTemplateTitle(nodeId, fieldName);
|
||||
const { t } = useTranslation();
|
||||
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
|
||||
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Flex, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
|
||||
import { useInputFieldErrors } from 'features/nodes/hooks/useInputFieldErrors';
|
||||
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import { startCase } from 'lodash-es';
|
||||
import { memo, useMemo } from 'react';
|
||||
|
||||
@@ -7,7 +7,6 @@ import {
|
||||
useIsConnectionInProgress,
|
||||
useIsConnectionStartField,
|
||||
} from 'features/nodes/hooks/useFieldConnectionState';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { useOutputFieldTemplate } from 'features/nodes/hooks/useOutputFieldTemplate';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
@@ -106,17 +105,9 @@ type HandleCommonProps = {
|
||||
};
|
||||
|
||||
const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
|
||||
return (
|
||||
<Tooltip label={fieldTypeName} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
|
||||
<Handle
|
||||
type="source"
|
||||
id={fieldTemplate.name}
|
||||
position={Position.Right}
|
||||
style={handleStyles}
|
||||
isConnectable={!isLocked}
|
||||
>
|
||||
<Handle type="source" id={fieldTemplate.name} position={Position.Right} style={handleStyles}>
|
||||
<Box
|
||||
sx={sx}
|
||||
data-cardinality={fieldTemplate.type.cardinality}
|
||||
@@ -139,7 +130,6 @@ const ConnectionInProgressHandle = memo(
|
||||
const { t } = useTranslation();
|
||||
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
|
||||
const connectionErrorTKey = useConnectionErrorTKey(nodeId, fieldName, 'target');
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
if (connectionErrorTKey !== null) {
|
||||
@@ -150,13 +140,7 @@ const ConnectionInProgressHandle = memo(
|
||||
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
|
||||
<Handle
|
||||
type="source"
|
||||
id={fieldTemplate.name}
|
||||
position={Position.Right}
|
||||
style={handleStyles}
|
||||
isConnectable={!isLocked}
|
||||
>
|
||||
<Handle type="source" id={fieldTemplate.name} position={Position.Right} style={handleStyles}>
|
||||
<Box
|
||||
sx={sx}
|
||||
data-cardinality={fieldTemplate.type.cardinality}
|
||||
|
||||
@@ -3,8 +3,8 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useEditable } from 'common/hooks/useEditable';
|
||||
import { useBatchGroupColorToken } from 'features/nodes/hooks/useBatchGroupColorToken';
|
||||
import { useBatchGroupId } from 'features/nodes/hooks/useBatchGroupId';
|
||||
import { useNodeTemplateTitleSafe } from 'features/nodes/hooks/useNodeTemplateTitleSafe';
|
||||
import { useNodeUserTitleSafe } from 'features/nodes/hooks/useNodeUserTitleSafe';
|
||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
|
||||
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_FIT_ON_DOUBLE_CLICK_CLASS } from 'features/nodes/types/constants';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
@@ -17,10 +17,10 @@ type Props = {
|
||||
|
||||
const NodeTitle = ({ nodeId, title }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const label = useNodeUserTitleSafe(nodeId);
|
||||
const label = useNodeLabel(nodeId);
|
||||
const batchGroupId = useBatchGroupId(nodeId);
|
||||
const batchGroupColorToken = useBatchGroupColorToken(batchGroupId);
|
||||
const templateTitle = useNodeTemplateTitleSafe(nodeId);
|
||||
const templateTitle = useNodeTemplateTitle(nodeId);
|
||||
const { t } = useTranslation();
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { ChakraProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, useGlobalMenuClose } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { useMouseOverFormField, useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { useNodeExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
|
||||
import { useZoomToNode } from 'features/nodes/hooks/useZoomToNode';
|
||||
@@ -63,12 +62,6 @@ const containerSx: SystemStyleObject = {
|
||||
display: 'block',
|
||||
shadow: '0 0 0 3px var(--invoke-colors-blue-300)',
|
||||
},
|
||||
'&[data-is-editor-locked="true"]': {
|
||||
'& *': {
|
||||
cursor: 'not-allowed',
|
||||
pointerEvents: 'none',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const shadowsSx: SystemStyleObject = {
|
||||
@@ -105,8 +98,7 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
const { nodeId, width, children, selected } = props;
|
||||
const mouseOverNode = useMouseOverNode(nodeId);
|
||||
const mouseOverFormField = useMouseOverFormField(nodeId);
|
||||
const zoomToNode = useZoomToNode(nodeId);
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
const zoomToNode = useZoomToNode();
|
||||
|
||||
const executionState = useNodeExecutionState(nodeId);
|
||||
const isInProgress = executionState?.status === zNodeStatus.enum.IN_PROGRESS;
|
||||
@@ -134,9 +126,9 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
// This target is marked as not fitting the view on double click
|
||||
return;
|
||||
}
|
||||
zoomToNode();
|
||||
zoomToNode(nodeId);
|
||||
},
|
||||
[zoomToNode]
|
||||
[nodeId, zoomToNode]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -149,7 +141,6 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
sx={containerSx}
|
||||
width={width || NODE_WIDTH}
|
||||
opacity={opacity}
|
||||
data-is-editor-locked={isLocked}
|
||||
data-is-selected={selected}
|
||||
data-is-mouse-over-form-field={mouseOverFormField.isMouseOverFormField}
|
||||
>
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { WorkflowName } from 'features/nodes/components/sidePanel/WorkflowName';
|
||||
import { selectWorkflowName } from 'features/nodes/store/workflowSlice';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const TopCenterPanel = memo(() => {
|
||||
const name = useAppSelector(selectWorkflowName);
|
||||
return (
|
||||
<Flex gap={2} top={2} left="50%" transform="translateX(-50%)" position="absolute" pointerEvents="none">
|
||||
{!!name.length && <WorkflowName />}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
TopCenterPanel.displayName = 'TopCenterPanel';
|
||||
@@ -1,64 +0,0 @@
|
||||
import { Alert, AlertDescription, AlertIcon, AlertTitle, Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import AddNodeButton from 'features/nodes/components/flow/panels/TopPanel/AddNodeButton';
|
||||
import UpdateNodesButton from 'features/nodes/components/flow/panels/TopPanel/UpdateNodesButton';
|
||||
import {
|
||||
$isInPublishFlow,
|
||||
$isSelectingOutputNode,
|
||||
useIsValidationRunInProgress,
|
||||
} from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { selectWorkflowIsPublished } from 'features/nodes/store/workflowSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const TopLeftPanel = memo(() => {
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
const isInPublishFlow = useStore($isInPublishFlow);
|
||||
const isPublished = useAppSelector(selectWorkflowIsPublished);
|
||||
const isValidationRunInProgress = useIsValidationRunInProgress();
|
||||
const isSelectingOutputNode = useStore($isSelectingOutputNode);
|
||||
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Flex gap={2} top={2} left={2} position="absolute" alignItems="flex-start" pointerEvents="none">
|
||||
{!isLocked && (
|
||||
<Flex gap="2">
|
||||
<AddNodeButton />
|
||||
<UpdateNodesButton />
|
||||
</Flex>
|
||||
)}
|
||||
{isLocked && (
|
||||
<Alert status="info" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
|
||||
<AlertIcon />
|
||||
<Box>
|
||||
<AlertTitle>{t('workflows.builder.workflowLocked')}</AlertTitle>
|
||||
{isValidationRunInProgress && (
|
||||
<AlertDescription whiteSpace="pre-wrap">
|
||||
{t('workflows.builder.publishingValidationRunInProgress')}
|
||||
</AlertDescription>
|
||||
)}
|
||||
{isInPublishFlow && !isValidationRunInProgress && !isSelectingOutputNode && (
|
||||
<AlertDescription whiteSpace="pre-wrap">
|
||||
{t('workflows.builder.workflowLockedDuringPublishing')}
|
||||
</AlertDescription>
|
||||
)}
|
||||
{isInPublishFlow && !isValidationRunInProgress && isSelectingOutputNode && (
|
||||
<AlertDescription whiteSpace="pre-wrap">
|
||||
{t('workflows.builder.selectingOutputNodeDesc')}
|
||||
</AlertDescription>
|
||||
)}
|
||||
{isPublished && (
|
||||
<AlertDescription whiteSpace="pre-wrap">
|
||||
{t('workflows.builder.workflowLockedPublished')}
|
||||
</AlertDescription>
|
||||
)}
|
||||
</Box>
|
||||
</Alert>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TopLeftPanel.displayName = 'TopLeftPanel';
|
||||
@@ -0,0 +1,40 @@
|
||||
import { Flex, IconButton, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import AddNodeButton from 'features/nodes/components/flow/panels/TopPanel/AddNodeButton';
|
||||
import ClearFlowButton from 'features/nodes/components/flow/panels/TopPanel/ClearFlowButton';
|
||||
import SaveWorkflowButton from 'features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton';
|
||||
import UpdateNodesButton from 'features/nodes/components/flow/panels/TopPanel/UpdateNodesButton';
|
||||
import { useWorkflowEditorSettingsModal } from 'features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings';
|
||||
import { WorkflowName } from 'features/nodes/components/sidePanel/WorkflowName';
|
||||
import { selectWorkflowName } from 'features/nodes/store/workflowSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiGearSixFill } from 'react-icons/pi';
|
||||
|
||||
const TopCenterPanel = () => {
|
||||
const name = useAppSelector(selectWorkflowName);
|
||||
const modal = useWorkflowEditorSettingsModal();
|
||||
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Flex gap={2} top={2} left={2} right={2} position="absolute" alignItems="flex-start" pointerEvents="none">
|
||||
<Flex gap="2">
|
||||
<AddNodeButton />
|
||||
<UpdateNodesButton />
|
||||
</Flex>
|
||||
<Spacer />
|
||||
{!!name.length && <WorkflowName />}
|
||||
<Spacer />
|
||||
<ClearFlowButton />
|
||||
<SaveWorkflowButton />
|
||||
<IconButton
|
||||
pointerEvents="auto"
|
||||
aria-label={t('workflows.workflowEditorMenu')}
|
||||
icon={<PiGearSixFill />}
|
||||
onClick={modal.setTrue}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(TopCenterPanel);
|
||||
@@ -1,34 +0,0 @@
|
||||
import { Flex, IconButton } from '@invoke-ai/ui-library';
|
||||
import ClearFlowButton from 'features/nodes/components/flow/panels/TopPanel/ClearFlowButton';
|
||||
import SaveWorkflowButton from 'features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton';
|
||||
import { useWorkflowEditorSettingsModal } from 'features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiGearSixFill } from 'react-icons/pi';
|
||||
|
||||
export const TopRightPanel = memo(() => {
|
||||
const modal = useWorkflowEditorSettingsModal();
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (isLocked) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex gap={2} top={2} right={2} position="absolute" alignItems="flex-end" pointerEvents="none">
|
||||
<ClearFlowButton />
|
||||
<SaveWorkflowButton />
|
||||
<IconButton
|
||||
pointerEvents="auto"
|
||||
aria-label={t('workflows.workflowEditorMenu')}
|
||||
icon={<PiGearSixFill />}
|
||||
onClick={modal.setTrue}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TopRightPanel.displayName = 'TopRightPanel';
|
||||
@@ -1,4 +1,5 @@
|
||||
import { Box } from '@invoke-ai/ui-library';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { HorizontalResizeHandle } from 'features/ui/components/tabs/ResizeHandle';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
@@ -22,21 +23,23 @@ export const EditModeLeftPanelContent = memo(() => {
|
||||
|
||||
return (
|
||||
<Box position="relative" w="full" h="full">
|
||||
<PanelGroup
|
||||
ref={panelGroupRef}
|
||||
id="workflow-panel-group"
|
||||
autoSaveId="workflow-panel-group"
|
||||
direction="vertical"
|
||||
style={panelGroupStyles}
|
||||
>
|
||||
<Panel id="workflow" collapsible minSize={25}>
|
||||
<WorkflowFieldsLinearViewPanel />
|
||||
</Panel>
|
||||
<HorizontalResizeHandle onDoubleClick={handleDoubleClickHandle} />
|
||||
<Panel id="inspector" collapsible minSize={25}>
|
||||
<WorkflowNodeInspectorPanel />
|
||||
</Panel>
|
||||
</PanelGroup>
|
||||
<ScrollableContent>
|
||||
<PanelGroup
|
||||
ref={panelGroupRef}
|
||||
id="workflow-panel-group"
|
||||
autoSaveId="workflow-panel-group"
|
||||
direction="vertical"
|
||||
style={panelGroupStyles}
|
||||
>
|
||||
<Panel id="workflow" collapsible minSize={25}>
|
||||
<WorkflowFieldsLinearViewPanel />
|
||||
</Panel>
|
||||
<HorizontalResizeHandle onDoubleClick={handleDoubleClickHandle} />
|
||||
<Panel id="inspector" collapsible minSize={25}>
|
||||
<WorkflowNodeInspectorPanel />
|
||||
</Panel>
|
||||
</PanelGroup>
|
||||
</ScrollableContent>
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
|
||||
import { useSaveOrSaveAsWorkflow } from 'features/workflowLibrary/hooks/useSaveOrSaveAsWorkflow';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCopyBold, PiLockOpenBold } from 'react-icons/pi';
|
||||
|
||||
export const PublishedWorkflowPanelContent = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const saveAs = useSaveOrSaveAsWorkflow();
|
||||
return (
|
||||
<Flex flexDir="column" w="full" h="full" gap={2} alignItems="center">
|
||||
<Heading size="md" pt={32}>
|
||||
{t('workflows.builder.workflowLocked')}
|
||||
</Heading>
|
||||
<Text fontSize="md">{t('workflows.builder.publishedWorkflowsLocked')}</Text>
|
||||
<Button size="md" onClick={saveAs} variant="ghost" leftIcon={<PiCopyBold />}>
|
||||
{t('common.saveAs')}
|
||||
</Button>
|
||||
<Button size="md" onClick={undefined} variant="ghost" leftIcon={<PiLockOpenBold />}>
|
||||
{t('workflows.builder.unpublish')}
|
||||
</Button>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
PublishedWorkflowPanelContent.displayName = 'PublishedWorkflowPanelContent';
|
||||
@@ -2,7 +2,7 @@ import { Flex, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { WorkflowListMenuTrigger } from 'features/nodes/components/sidePanel/WorkflowListMenu/WorkflowListMenuTrigger';
|
||||
import { WorkflowViewEditToggleButton } from 'features/nodes/components/sidePanel/WorkflowViewEditToggleButton';
|
||||
import { selectWorkflowIsPublished, selectWorkflowMode } from 'features/nodes/store/workflowSlice';
|
||||
import { selectWorkflowMode } from 'features/nodes/store/workflowSlice';
|
||||
import { WorkflowLibraryMenu } from 'features/workflowLibrary/components/WorkflowLibraryMenu/WorkflowLibraryMenu';
|
||||
import { memo } from 'react';
|
||||
|
||||
@@ -10,13 +10,12 @@ import SaveWorkflowButton from './SaveWorkflowButton';
|
||||
|
||||
export const ActiveWorkflowNameAndActions = memo(() => {
|
||||
const mode = useAppSelector(selectWorkflowMode);
|
||||
const isPublished = useAppSelector(selectWorkflowIsPublished);
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={1} minW={0}>
|
||||
<WorkflowListMenuTrigger />
|
||||
<Spacer />
|
||||
{mode === 'edit' && !isPublished && <SaveWorkflowButton />}
|
||||
{mode === 'edit' && <SaveWorkflowButton />}
|
||||
<WorkflowViewEditToggleButton />
|
||||
<WorkflowLibraryMenu />
|
||||
</Flex>
|
||||
|
||||
@@ -1,30 +1,22 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { EditModeLeftPanelContent } from 'features/nodes/components/sidePanel/EditModeLeftPanelContent';
|
||||
import { PublishedWorkflowPanelContent } from 'features/nodes/components/sidePanel/PublishedWorkflowPanelContent';
|
||||
import { $isInPublishFlow } from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
import { PublishWorkflowPanelContent } from 'features/nodes/components/sidePanel/workflow/PublishWorkflowPanelContent';
|
||||
import { ActiveWorkflowDescription } from 'features/nodes/components/sidePanel/WorkflowListMenu/ActiveWorkflowDescription';
|
||||
import { ActiveWorkflowNameAndActions } from 'features/nodes/components/sidePanel/WorkflowListMenu/ActiveWorkflowNameAndActions';
|
||||
import { selectWorkflowIsPublished, selectWorkflowMode } from 'features/nodes/store/workflowSlice';
|
||||
import { selectWorkflowMode } from 'features/nodes/store/workflowSlice';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { ViewModeLeftPanelContent } from './viewMode/ViewModeLeftPanelContent';
|
||||
|
||||
const WorkflowsTabLeftPanel = () => {
|
||||
const mode = useAppSelector(selectWorkflowMode);
|
||||
const isPublished = useAppSelector(selectWorkflowIsPublished);
|
||||
const isInPublishFlow = useStore($isInPublishFlow);
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" gap={2} flexDir="column">
|
||||
{isInPublishFlow && <PublishWorkflowPanelContent />}
|
||||
{!isInPublishFlow && <ActiveWorkflowNameAndActions />}
|
||||
{!isInPublishFlow && !isPublished && mode === 'view' && <ActiveWorkflowDescription />}
|
||||
{!isInPublishFlow && !isPublished && mode === 'view' && <ViewModeLeftPanelContent />}
|
||||
{!isInPublishFlow && !isPublished && mode === 'edit' && <EditModeLeftPanelContent />}
|
||||
{isPublished && <PublishedWorkflowPanelContent />}
|
||||
<ActiveWorkflowNameAndActions />
|
||||
{mode === 'view' && <ActiveWorkflowDescription />}
|
||||
{mode === 'view' && <ViewModeLeftPanelContent />}
|
||||
{mode === 'edit' && <EditModeLeftPanelContent />}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -67,8 +67,11 @@ FormElementEditModeHeader.displayName = 'FormElementEditModeHeader';
|
||||
const ZoomToNodeButton = memo(({ element }: { element: NodeFieldElement }) => {
|
||||
const { t } = useTranslation();
|
||||
const { nodeId } = element.data.fieldIdentifier;
|
||||
const zoomToNode = useZoomToNode(nodeId);
|
||||
const zoomToNode = useZoomToNode();
|
||||
const mouseOverFormField = useMouseOverFormField(nodeId);
|
||||
const onClick = useCallback(() => {
|
||||
zoomToNode(nodeId);
|
||||
}, [nodeId, zoomToNode]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
@@ -76,7 +79,7 @@ const ZoomToNodeButton = memo(({ element }: { element: NodeFieldElement }) => {
|
||||
onMouseOut={mouseOverFormField.handleMouseOut}
|
||||
tooltip={t('workflows.builder.zoomToNode')}
|
||||
aria-label={t('workflows.builder.zoomToNode')}
|
||||
onClick={zoomToNode}
|
||||
onClick={onClick}
|
||||
icon={<PiGpsFixBold />}
|
||||
variant="link"
|
||||
size="sm"
|
||||
|
||||
@@ -2,8 +2,8 @@ import { FormHelperText, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { linkifyOptions, linkifySx } from 'common/components/linkify';
|
||||
import { useEditable } from 'common/hooks/useEditable';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldUserDescriptionSafe } from 'features/nodes/hooks/useInputFieldUserDescriptionSafe';
|
||||
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { fieldDescriptionChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import Linkify from 'linkify-react';
|
||||
@@ -13,7 +13,7 @@ export const NodeFieldElementDescriptionEditable = memo(({ el }: { el: NodeField
|
||||
const { data } = el;
|
||||
const { fieldIdentifier } = data;
|
||||
const dispatch = useAppDispatch();
|
||||
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const inputRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ export const NodeFieldElementEditMode = memo(({ el }: { el: NodeFieldElement })
|
||||
return (
|
||||
<Flex ref={draggableRef} id={id} className={NODE_FIELD_CLASS_NAME} sx={sx} data-parent-layout={containerCtx.layout}>
|
||||
<NodeFieldElementEditModeContent dragHandleRef={dragHandleRef} el={el} isDragging={isDragging} />
|
||||
<NodeFieldElementOverlay nodeId={el.data.fieldIdentifier.nodeId} />
|
||||
<NodeFieldElementOverlay element={el} />
|
||||
<DndListDropIndicator activeDropRegion={activeDropRegion} gap="var(--invoke-space-4)" />
|
||||
</Flex>
|
||||
);
|
||||
@@ -105,9 +105,9 @@ const nodeFieldOverlaySx: SystemStyleObject = {
|
||||
},
|
||||
};
|
||||
|
||||
export const NodeFieldElementOverlay = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const mouseOverNode = useMouseOverNode(nodeId);
|
||||
const mouseOverFormField = useMouseOverFormField(nodeId);
|
||||
const NodeFieldElementOverlay = memo(({ element }: { element: NodeFieldElement }) => {
|
||||
const mouseOverNode = useMouseOverNode(element.data.fieldIdentifier.nodeId);
|
||||
const mouseOverFormField = useMouseOverFormField(element.data.fieldIdentifier.nodeId);
|
||||
|
||||
return (
|
||||
<Box
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import { Flex, FormLabel, Spacer } from '@invoke-ai/ui-library';
|
||||
import { NodeFieldElementResetToInitialValueIconButton } from 'features/nodes/components/flow/nodes/Invocation/fields/NodeFieldElementResetToInitialValueIconButton';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldUserTitleSafe } from 'features/nodes/hooks/useInputFieldUserTitleSafe';
|
||||
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { memo, useMemo } from 'react';
|
||||
|
||||
export const NodeFieldElementLabel = memo(({ el }: { el: NodeFieldElement }) => {
|
||||
const { data } = el;
|
||||
const { fieldIdentifier } = data;
|
||||
const label = useInputFieldUserTitleSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const label = useInputFieldLabelSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
|
||||
const _label = useMemo(() => label || fieldTemplate.title, [label, fieldTemplate.title]);
|
||||
|
||||
@@ -2,8 +2,8 @@ import { Flex, FormLabel, Input, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useEditable } from 'common/hooks/useEditable';
|
||||
import { NodeFieldElementResetToInitialValueIconButton } from 'features/nodes/components/flow/nodes/Invocation/fields/NodeFieldElementResetToInitialValueIconButton';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldUserTitleSafe } from 'features/nodes/hooks/useInputFieldUserTitleSafe';
|
||||
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
@@ -12,7 +12,7 @@ export const NodeFieldElementLabelEditable = memo(({ el }: { el: NodeFieldElemen
|
||||
const { data } = el;
|
||||
const { fieldIdentifier } = data;
|
||||
const dispatch = useAppDispatch();
|
||||
const label = useInputFieldUserTitleSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const label = useInputFieldLabelSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { NodeFieldElementFloatSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementFloatSettings';
|
||||
import { NodeFieldElementIntegerSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementIntegerSettings';
|
||||
import { NodeFieldElementStringSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementStringSettings';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { formElementNodeFieldDataChanged } from 'features/nodes/store/workflowSlice';
|
||||
import {
|
||||
isFloatFieldInputTemplate,
|
||||
|
||||
@@ -5,9 +5,8 @@ import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/
|
||||
import { InputFieldRenderer } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
|
||||
import { useContainerContext } from 'features/nodes/components/sidePanel/builder/contexts';
|
||||
import { NodeFieldElementLabel } from 'features/nodes/components/sidePanel/builder/NodeFieldElementLabel';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldTemplateSafe } from 'features/nodes/hooks/useInputFieldTemplateSafe';
|
||||
import { useInputFieldUserDescriptionSafe } from 'features/nodes/hooks/useInputFieldUserDescriptionSafe';
|
||||
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
|
||||
import { useInputFieldTemplateOrThrow, useInputFieldTemplateSafe } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { NODE_FIELD_CLASS_NAME } from 'features/nodes/types/workflow';
|
||||
import Linkify from 'linkify-react';
|
||||
@@ -37,7 +36,7 @@ const useFormatFallbackLabel = () => {
|
||||
export const NodeFieldElementViewMode = memo(({ el }: { el: NodeFieldElement }) => {
|
||||
const { id, data } = el;
|
||||
const { fieldIdentifier, showDescription } = data;
|
||||
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const containerCtx = useContainerContext();
|
||||
const formatFallbackLabel = useFormatFallbackLabel();
|
||||
@@ -70,7 +69,7 @@ NodeFieldElementViewMode.displayName = 'NodeFieldElementViewMode';
|
||||
const NodeFieldElementViewModeContent = memo(({ el }: { el: NodeFieldElement }) => {
|
||||
const { data } = el;
|
||||
const { fieldIdentifier, showDescription } = data;
|
||||
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
|
||||
const _description = useMemo(
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
|
||||
import type { DropTargetRecord } from '@atlaskit/pragmatic-drag-and-drop/dist/types/internal-types';
|
||||
import type { ElementDragPayload } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
|
||||
import {
|
||||
draggable,
|
||||
dropTargetForElements,
|
||||
@@ -35,7 +33,7 @@ import {
|
||||
selectFormRootElementId,
|
||||
selectWorkflowSlice,
|
||||
} from 'features/nodes/store/workflowSlice';
|
||||
import type { FieldInputTemplate, StatefulFieldValue } from 'features/nodes/types/field';
|
||||
import type { FieldIdentifier, FieldInputTemplate, StatefulFieldValue } from 'features/nodes/types/field';
|
||||
import type { ElementId, FormElement } from 'features/nodes/types/workflow';
|
||||
import { buildNodeFieldElement, isContainerElement } from 'features/nodes/types/workflow';
|
||||
import type { RefObject } from 'react';
|
||||
@@ -60,27 +58,6 @@ const isFormElementDndData = (data: Record<string | symbol, unknown>): data is F
|
||||
return uniqueFormElementDndKey in data;
|
||||
};
|
||||
|
||||
const uniqueNodeFieldDndKey = Symbol('node-field');
|
||||
type NodeFieldDndData = {
|
||||
[uniqueNodeFieldDndKey]: true;
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
fieldTemplate: FieldInputTemplate;
|
||||
};
|
||||
const buildNodeFieldDndData = (
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
fieldTemplate: FieldInputTemplate
|
||||
): NodeFieldDndData => ({
|
||||
[uniqueNodeFieldDndKey]: true,
|
||||
nodeId,
|
||||
fieldName,
|
||||
fieldTemplate,
|
||||
});
|
||||
const isNodeFieldDndData = (data: Record<string | symbol, unknown>): data is NodeFieldDndData => {
|
||||
return uniqueNodeFieldDndKey in data;
|
||||
};
|
||||
|
||||
/**
|
||||
* Flashes an element by changing its background color. Used to indicate that an element has been moved.
|
||||
* @param elementId The id of the element to flash
|
||||
@@ -156,27 +133,6 @@ const useGetInitialValue = () => {
|
||||
return _getInitialValue;
|
||||
};
|
||||
|
||||
const getSourceElement = (source: ElementDragPayload) => {
|
||||
if (isNodeFieldDndData(source.data)) {
|
||||
const { nodeId, fieldName, fieldTemplate } = source.data;
|
||||
return buildNodeFieldElement(nodeId, fieldName, fieldTemplate.type);
|
||||
}
|
||||
|
||||
if (isFormElementDndData(source.data)) {
|
||||
return source.data.element;
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
const getTargetElement = (target: DropTargetRecord) => {
|
||||
if (isFormElementDndData(target.data)) {
|
||||
return target.data.element;
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
/**
|
||||
* Singleton hook that monitors for builder dnd events and dispatches actions accordingly.
|
||||
*/
|
||||
@@ -200,20 +156,20 @@ export const useBuilderDndMonitor = () => {
|
||||
|
||||
useEffect(() => {
|
||||
return monitorForElements({
|
||||
canMonitor: ({ source }) => isFormElementDndData(source.data) || isNodeFieldDndData(source.data),
|
||||
canMonitor: ({ source }) => isFormElementDndData(source.data),
|
||||
onDrop: ({ location, source }) => {
|
||||
const target = location.current.dropTargets[0];
|
||||
if (!target) {
|
||||
return;
|
||||
}
|
||||
|
||||
const sourceElement = getSourceElement(source);
|
||||
const targetElement = getTargetElement(target);
|
||||
|
||||
if (!sourceElement || !targetElement) {
|
||||
if (!isFormElementDndData(source.data) || !isFormElementDndData(target.data)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const sourceElement = source.data.element;
|
||||
const targetElement = target.data.element;
|
||||
|
||||
if (sourceElement.id === targetElement.id) {
|
||||
// Dropping on self is a no-op
|
||||
return;
|
||||
@@ -403,15 +359,8 @@ export const useFormElementDnd = (
|
||||
element: draggableElement,
|
||||
// TODO(psyche): This causes a kinda jittery behaviour - need a better heuristic to determine stickiness
|
||||
getIsSticky: () => false,
|
||||
canDrop: ({ source }) => {
|
||||
if (isNodeFieldDndData(source.data)) {
|
||||
return true;
|
||||
}
|
||||
if (isFormElementDndData(source.data)) {
|
||||
return source.data.element.id !== getElement(elementId).parentId;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
canDrop: ({ source }) =>
|
||||
isFormElementDndData(source.data) && source.data.element.id !== getElement(elementId).parentId,
|
||||
getData: ({ input }) => {
|
||||
const element = getElement(elementId);
|
||||
|
||||
@@ -474,16 +423,8 @@ export const useRootElementDropTarget = (droppableRef: RefObject<HTMLDivElement>
|
||||
dropTargetForElements({
|
||||
element: droppableElement,
|
||||
getIsSticky: () => false,
|
||||
canDrop: ({ source }) => {
|
||||
const rootElement = getElement(rootElementId, isContainerElement);
|
||||
if (rootElement.data.children.length !== 0) {
|
||||
return false;
|
||||
}
|
||||
if (isNodeFieldDndData(source.data) || isFormElementDndData(source.data)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
canDrop: ({ source }) =>
|
||||
getElement(rootElementId, isContainerElement).data.children.length === 0 && isFormElementDndData(source.data),
|
||||
getData: ({ input }) => {
|
||||
const element = getElement(rootElementId, isContainerElement);
|
||||
|
||||
@@ -514,8 +455,7 @@ export const useRootElementDropTarget = (droppableRef: RefObject<HTMLDivElement>
|
||||
/**
|
||||
* Hook that provides dnd functionality for node fields.
|
||||
*
|
||||
* @param nodeId: The id of the node
|
||||
* @param fieldName: The name of the field
|
||||
* @param fieldIdentifier The identifier of the node field
|
||||
* @param fieldTemplate The template of the node field, required to build the form element
|
||||
* @param draggableRef The ref of the draggable HTML element
|
||||
* @param dragHandleRef The ref of the drag handle HTML element
|
||||
@@ -523,8 +463,7 @@ export const useRootElementDropTarget = (droppableRef: RefObject<HTMLDivElement>
|
||||
* @returns Whether the node field is currently being dragged
|
||||
*/
|
||||
export const useNodeFieldDnd = (
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
fieldIdentifier: FieldIdentifier,
|
||||
fieldTemplate: FieldInputTemplate,
|
||||
draggableRef: RefObject<HTMLElement>,
|
||||
dragHandleRef: RefObject<HTMLElement>
|
||||
@@ -542,7 +481,12 @@ export const useNodeFieldDnd = (
|
||||
draggable({
|
||||
element: draggableElement,
|
||||
dragHandle: dragHandleElement,
|
||||
getInitialData: () => buildNodeFieldDndData(nodeId, fieldName, fieldTemplate),
|
||||
getInitialData: () => {
|
||||
const { nodeId, fieldName } = fieldIdentifier;
|
||||
const { type } = fieldTemplate;
|
||||
const element = buildNodeFieldElement(nodeId, fieldName, type);
|
||||
return buildFormElementDndData(element);
|
||||
},
|
||||
onDragStart: () => {
|
||||
setIsDragging(true);
|
||||
},
|
||||
@@ -551,7 +495,7 @@ export const useNodeFieldDnd = (
|
||||
},
|
||||
})
|
||||
);
|
||||
}, [dragHandleRef, draggableRef, fieldName, fieldTemplate, nodeId]);
|
||||
}, [dragHandleRef, draggableRef, fieldIdentifier, fieldTemplate]);
|
||||
|
||||
return isDragging;
|
||||
};
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { formElementAdded, selectFormRootElementId } from 'features/nodes/store/workflowSlice';
|
||||
import { buildNodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
@@ -5,7 +5,7 @@ import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableCon
|
||||
import { InvocationNodeNotesTextarea } from 'features/nodes/components/flow/nodes/Invocation/InvocationNodeNotesTextarea';
|
||||
import { TemplateGate } from 'features/nodes/components/sidePanel/inspector/NodeTemplateGate';
|
||||
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
|
||||
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
|
||||
import { selectLastSelectedNodeId } from 'features/nodes/store/selectors';
|
||||
import { memo } from 'react';
|
||||
@@ -36,7 +36,7 @@ export default memo(InspectorDetailsTab);
|
||||
const Content = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const { t } = useTranslation();
|
||||
const version = useNodeVersion(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const needsUpdate = useNodeNeedsUpdate(nodeId);
|
||||
|
||||
return (
|
||||
|
||||
@@ -5,7 +5,7 @@ import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableCon
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { TemplateGate } from 'features/nodes/components/sidePanel/inspector/NodeTemplateGate';
|
||||
import { useNodeExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
|
||||
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { selectLastSelectedNodeId } from 'features/nodes/store/selectors';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -37,7 +37,7 @@ const getKey = (result: AnyInvocationOutput, i: number) => `${result.type}-${i}`
|
||||
|
||||
const Content = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const { t } = useTranslation();
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const nes = useNodeExecutionState(nodeId);
|
||||
|
||||
if (!nes || nes.outputs.length === 0) {
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { Flex, Input, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useEditable } from 'common/hooks/useEditable';
|
||||
import { useNodeTemplateTitleSafe } from 'features/nodes/hooks/useNodeTemplateTitleSafe';
|
||||
import { useNodeUserTitleSafe } from 'features/nodes/hooks/useNodeUserTitleSafe';
|
||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
|
||||
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -14,8 +14,8 @@ type Props = {
|
||||
|
||||
const InspectorTabEditableNodeTitle = ({ nodeId, title }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const label = useNodeUserTitleSafe(nodeId);
|
||||
const templateTitle = useNodeTemplateTitleSafe(nodeId);
|
||||
const label = useNodeLabel(nodeId);
|
||||
const templateTitle = useNodeTemplateTitle(nodeId);
|
||||
const { t } = useTranslation();
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const onChange = useCallback(
|
||||
|
||||
@@ -2,7 +2,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { TemplateGate } from 'features/nodes/components/sidePanel/inspector/NodeTemplateGate';
|
||||
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { selectLastSelectedNodeId } from 'features/nodes/store/selectors';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -29,7 +29,7 @@ export default memo(NodeTemplateInspector);
|
||||
|
||||
const Content = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const { t } = useTranslation();
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const template = useNodeTemplate(nodeId);
|
||||
|
||||
return <DataViewer data={template} label={t('nodes.nodeTemplate')} bg="base.850" color="base.200" />;
|
||||
});
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useNodeTemplateSafe } from 'features/nodes/hooks/useNodeTemplateSafe';
|
||||
import { useNodeTemplateSafe } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import type { PropsWithChildren, ReactNode } from 'react';
|
||||
import { memo } from 'react';
|
||||
|
||||
|
||||
@@ -1,445 +0,0 @@
|
||||
import type { ButtonProps } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Button,
|
||||
ButtonGroup,
|
||||
Divider,
|
||||
Flex,
|
||||
ListItem,
|
||||
Spacer,
|
||||
Text,
|
||||
Tooltip,
|
||||
UnorderedList,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { $projectUrl } from 'app/store/nanostores/projectId';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { ExternalLink } from 'features/gallery/components/ImageViewer/NoContentForViewer';
|
||||
import { NodeFieldElementOverlay } from 'features/nodes/components/sidePanel/builder/NodeFieldElementEditMode';
|
||||
import {
|
||||
$isInPublishFlow,
|
||||
$isReadyToDoValidationRun,
|
||||
$isSelectingOutputNode,
|
||||
$outputNodeId,
|
||||
$validationRunBatchId,
|
||||
usePublishInputs,
|
||||
} from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
import { useInputFieldTemplateTitleOrThrow } from 'features/nodes/hooks/useInputFieldTemplateTitleOrThrow';
|
||||
import { useInputFieldUserTitleOrThrow } from 'features/nodes/hooks/useInputFieldUserTitleOrThrow';
|
||||
import { useMouseOverFormField } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { useNodeTemplateTitleOrThrow } from 'features/nodes/hooks/useNodeTemplateTitleOrThrow';
|
||||
import { useNodeUserTitleOrThrow } from 'features/nodes/hooks/useNodeUserTitleOrThrow';
|
||||
import { useOutputFieldNames } from 'features/nodes/hooks/useOutputFieldNames';
|
||||
import { useOutputFieldTemplate } from 'features/nodes/hooks/useOutputFieldTemplate';
|
||||
import { useZoomToNode } from 'features/nodes/hooks/useZoomToNode';
|
||||
import { selectHasBatchOrGeneratorNodes } from 'features/nodes/store/selectors';
|
||||
import { selectIsWorkflowSaved } from 'features/nodes/store/workflowSlice';
|
||||
import { useEnqueueWorkflows } from 'features/queue/hooks/useEnqueueWorkflows';
|
||||
import { $isReadyToEnqueue } from 'features/queue/store/readiness';
|
||||
import { selectAllowPublishWorkflows } from 'features/system/store/configSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import { PiArrowLineRightBold, PiLightningFill, PiXBold } from 'react-icons/pi';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const PublishWorkflowPanelContent = memo(() => {
|
||||
return (
|
||||
<Flex flexDir="column" gap={2} h="full">
|
||||
<ButtonGroup isAttached={false} size="sm" variant="ghost">
|
||||
<Spacer />
|
||||
<CancelPublishButton />
|
||||
<PublishWorkflowButton />
|
||||
</ButtonGroup>
|
||||
<ScrollableContent>
|
||||
<Flex flexDir="column" gap={2} w="full" h="full">
|
||||
<OutputFields />
|
||||
<PublishableInputFields />
|
||||
<UnpublishableInputFields />
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
PublishWorkflowPanelContent.displayName = 'PublishWorkflowPanelContent';
|
||||
|
||||
const OutputFields = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const outputNodeId = useStore($outputNodeId);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
|
||||
<Flex alignItems="center">
|
||||
<Text fontWeight="semibold">{t('workflows.builder.publishedWorkflowOutputs')}</Text>
|
||||
<Spacer />
|
||||
<SelectOutputNodeButton variant="link" size="sm" />
|
||||
</Flex>
|
||||
|
||||
<Divider />
|
||||
{!outputNodeId && (
|
||||
<Text fontWeight="semibold" color="error.300">
|
||||
{t('workflows.builder.noOutputNodeSelected')}
|
||||
</Text>
|
||||
)}
|
||||
{outputNodeId && <OutputFieldsContent outputNodeId={outputNodeId} />}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
OutputFields.displayName = 'OutputFields';
|
||||
|
||||
const OutputFieldsContent = memo(({ outputNodeId }: { outputNodeId: string }) => {
|
||||
const outputFieldNames = useOutputFieldNames(outputNodeId);
|
||||
|
||||
return (
|
||||
<>
|
||||
{outputFieldNames.map((fieldName) => (
|
||||
<NodeOutputFieldPreview key={`${outputNodeId}-${fieldName}`} nodeId={outputNodeId} fieldName={fieldName} />
|
||||
))}
|
||||
</>
|
||||
);
|
||||
});
|
||||
OutputFieldsContent.displayName = 'OutputFieldsContent';
|
||||
|
||||
const PublishableInputFields = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const inputs = usePublishInputs();
|
||||
|
||||
if (inputs.publishable.length === 0) {
|
||||
return (
|
||||
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
|
||||
<Text fontWeight="semibold" color="warning.300">
|
||||
{t('workflows.builder.noPublishableInputs')}
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
|
||||
<Text fontWeight="semibold">{t('workflows.builder.publishedWorkflowInputs')}</Text>
|
||||
<Divider />
|
||||
{inputs.publishable.map(({ nodeId, fieldName }) => {
|
||||
return <NodeInputFieldPreview key={`${nodeId}-${fieldName}`} nodeId={nodeId} fieldName={fieldName} />;
|
||||
})}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
PublishableInputFields.displayName = 'PublishableInputFields';
|
||||
|
||||
const UnpublishableInputFields = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const inputs = usePublishInputs();
|
||||
|
||||
if (inputs.unpublishable.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
|
||||
<Text fontWeight="semibold" color="warning.300">
|
||||
{t('workflows.builder.unpublishableInputs')}
|
||||
</Text>
|
||||
<Divider />
|
||||
{inputs.unpublishable.map(({ nodeId, fieldName }) => {
|
||||
return <NodeInputFieldPreview key={`${nodeId}-${fieldName}`} nodeId={nodeId} fieldName={fieldName} />;
|
||||
})}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
UnpublishableInputFields.displayName = 'UnpublishableInputFields';
|
||||
|
||||
const SelectOutputNodeButton = memo((props: ButtonProps) => {
|
||||
const { t } = useTranslation();
|
||||
const outputNodeId = useStore($outputNodeId);
|
||||
const isSelectingOutputNode = useStore($isSelectingOutputNode);
|
||||
const onClick = useCallback(() => {
|
||||
$outputNodeId.set(null);
|
||||
$isSelectingOutputNode.set(true);
|
||||
}, []);
|
||||
return (
|
||||
<Button
|
||||
leftIcon={<PiArrowLineRightBold />}
|
||||
isDisabled={isSelectingOutputNode}
|
||||
tooltip={isSelectingOutputNode ? t('workflows.builder.selectingOutputNodeDesc') : undefined}
|
||||
onClick={onClick}
|
||||
{...props}
|
||||
>
|
||||
{isSelectingOutputNode
|
||||
? t('workflows.builder.selectingOutputNode')
|
||||
: outputNodeId
|
||||
? t('workflows.builder.changeOutputNode')
|
||||
: t('workflows.builder.selectOutputNode')}
|
||||
</Button>
|
||||
);
|
||||
});
|
||||
SelectOutputNodeButton.displayName = 'SelectOutputNodeButton';
|
||||
|
||||
const CancelPublishButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const onClick = useCallback(() => {
|
||||
$isInPublishFlow.set(false);
|
||||
$isSelectingOutputNode.set(false);
|
||||
$outputNodeId.set(null);
|
||||
}, []);
|
||||
return (
|
||||
<Button leftIcon={<PiXBold />} onClick={onClick}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
);
|
||||
});
|
||||
CancelPublishButton.displayName = 'CancelDeployButton';
|
||||
|
||||
const PublishWorkflowButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const isReadyToDoValidationRun = useStore($isReadyToDoValidationRun);
|
||||
const isReadyToEnqueue = useStore($isReadyToEnqueue);
|
||||
const isWorkflowSaved = useAppSelector(selectIsWorkflowSaved);
|
||||
const hasBatchOrGeneratorNodes = useAppSelector(selectHasBatchOrGeneratorNodes);
|
||||
const outputNodeId = useStore($outputNodeId);
|
||||
const isSelectingOutputNode = useStore($isSelectingOutputNode);
|
||||
const inputs = usePublishInputs();
|
||||
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
|
||||
|
||||
const projectUrl = useStore($projectUrl);
|
||||
|
||||
const enqueue = useEnqueueWorkflows();
|
||||
const onClick = useCallback(async () => {
|
||||
const result = await withResultAsync(() => enqueue(true, true));
|
||||
if (result.isErr()) {
|
||||
toast({
|
||||
id: 'TOAST_PUBLISH_FAILED',
|
||||
status: 'error',
|
||||
title: t('workflows.builder.publishFailed'),
|
||||
description: t('workflows.builder.publishFailedDesc'),
|
||||
duration: null,
|
||||
});
|
||||
log.error({ error: serializeError(result.error) }, 'Failed to enqueue batch');
|
||||
} else {
|
||||
toast({
|
||||
id: 'TOAST_PUBLISH_SUCCESSFUL',
|
||||
status: 'success',
|
||||
title: t('workflows.builder.publishSuccess'),
|
||||
description: (
|
||||
<Trans
|
||||
i18nKey="workflows.builder.publishSuccessDesc"
|
||||
components={{
|
||||
LinkComponent: <ExternalLink href={projectUrl ?? ''} />,
|
||||
}}
|
||||
/>
|
||||
),
|
||||
duration: null,
|
||||
});
|
||||
assert(result.value.enqueueResult.batch.batch_id);
|
||||
$validationRunBatchId.set(result.value.enqueueResult.batch.batch_id);
|
||||
log.debug(parseify(result.value), 'Enqueued batch');
|
||||
}
|
||||
}, [enqueue, projectUrl, t]);
|
||||
|
||||
return (
|
||||
<PublishTooltip
|
||||
isWorkflowSaved={isWorkflowSaved}
|
||||
hasBatchOrGeneratorNodes={hasBatchOrGeneratorNodes}
|
||||
isReadyToEnqueue={isReadyToEnqueue}
|
||||
hasOutputNode={outputNodeId !== null && !isSelectingOutputNode}
|
||||
hasPublishableInputs={inputs.publishable.length > 0}
|
||||
hasUnpublishableInputs={inputs.unpublishable.length > 0}
|
||||
>
|
||||
<Button
|
||||
leftIcon={<PiLightningFill />}
|
||||
isDisabled={
|
||||
!allowPublishWorkflows ||
|
||||
!isReadyToEnqueue ||
|
||||
!isWorkflowSaved ||
|
||||
hasBatchOrGeneratorNodes ||
|
||||
!isReadyToDoValidationRun ||
|
||||
!(outputNodeId !== null && !isSelectingOutputNode)
|
||||
}
|
||||
onClick={onClick}
|
||||
>
|
||||
{t('workflows.builder.publish')}
|
||||
</Button>
|
||||
</PublishTooltip>
|
||||
);
|
||||
});
|
||||
PublishWorkflowButton.displayName = 'DoValidationRunButton';
|
||||
|
||||
const NodeInputFieldPreview = memo(({ nodeId, fieldName }: { nodeId: string; fieldName: string }) => {
|
||||
const mouseOverFormField = useMouseOverFormField(nodeId);
|
||||
const nodeUserTitle = useNodeUserTitleOrThrow(nodeId);
|
||||
const nodeTemplateTitle = useNodeTemplateTitleOrThrow(nodeId);
|
||||
const fieldUserTitle = useInputFieldUserTitleOrThrow(nodeId, fieldName);
|
||||
const fieldTemplateTitle = useInputFieldTemplateTitleOrThrow(nodeId, fieldName);
|
||||
const zoomToNode = useZoomToNode(nodeId);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
flexDir="column"
|
||||
position="relative"
|
||||
p={2}
|
||||
borderRadius="base"
|
||||
onMouseOver={mouseOverFormField.handleMouseOver}
|
||||
onMouseOut={mouseOverFormField.handleMouseOut}
|
||||
onClick={zoomToNode}
|
||||
>
|
||||
<Text fontWeight="semibold">{`${nodeUserTitle || nodeTemplateTitle} -> ${fieldUserTitle || fieldTemplateTitle}`}</Text>
|
||||
<Text variant="subtext">{`${nodeId} -> ${fieldName}`}</Text>
|
||||
<NodeFieldElementOverlay nodeId={nodeId} />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
NodeInputFieldPreview.displayName = 'NodeInputFieldPreview';
|
||||
|
||||
const NodeOutputFieldPreview = memo(({ nodeId, fieldName }: { nodeId: string; fieldName: string }) => {
|
||||
const mouseOverFormField = useMouseOverFormField(nodeId);
|
||||
const nodeUserTitle = useNodeUserTitleOrThrow(nodeId);
|
||||
const nodeTemplateTitle = useNodeTemplateTitleOrThrow(nodeId);
|
||||
const fieldTemplate = useOutputFieldTemplate(nodeId, fieldName);
|
||||
const zoomToNode = useZoomToNode(nodeId);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
flexDir="column"
|
||||
position="relative"
|
||||
p={2}
|
||||
borderRadius="base"
|
||||
onMouseOver={mouseOverFormField.handleMouseOver}
|
||||
onMouseOut={mouseOverFormField.handleMouseOut}
|
||||
onClick={zoomToNode}
|
||||
>
|
||||
<Text fontWeight="semibold">{`${nodeUserTitle || nodeTemplateTitle} -> ${fieldTemplate.title}`}</Text>
|
||||
<Text variant="subtext">{`${nodeId} -> ${fieldName}`}</Text>
|
||||
<NodeFieldElementOverlay nodeId={nodeId} />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
NodeOutputFieldPreview.displayName = 'NodeOutputFieldPreview';
|
||||
|
||||
export const StartPublishFlowButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
|
||||
const isReadyToEnqueue = useStore($isReadyToEnqueue);
|
||||
const isWorkflowSaved = useAppSelector(selectIsWorkflowSaved);
|
||||
const hasBatchOrGeneratorNodes = useAppSelector(selectHasBatchOrGeneratorNodes);
|
||||
const inputs = usePublishInputs();
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
$isInPublishFlow.set(true);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<PublishTooltip
|
||||
isWorkflowSaved={isWorkflowSaved}
|
||||
hasBatchOrGeneratorNodes={hasBatchOrGeneratorNodes}
|
||||
isReadyToEnqueue={isReadyToEnqueue}
|
||||
hasOutputNode={true}
|
||||
hasPublishableInputs={inputs.publishable.length > 0}
|
||||
hasUnpublishableInputs={inputs.unpublishable.length > 0}
|
||||
>
|
||||
<Button
|
||||
onClick={onClick}
|
||||
leftIcon={<PiLightningFill />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
isDisabled={!allowPublishWorkflows || !isReadyToEnqueue || !isWorkflowSaved || hasBatchOrGeneratorNodes}
|
||||
>
|
||||
{t('workflows.builder.publish')}
|
||||
</Button>
|
||||
</PublishTooltip>
|
||||
);
|
||||
});
|
||||
|
||||
StartPublishFlowButton.displayName = 'StartPublishFlowButton';
|
||||
|
||||
const PublishTooltip = memo(
|
||||
({
|
||||
isWorkflowSaved,
|
||||
hasBatchOrGeneratorNodes,
|
||||
isReadyToEnqueue,
|
||||
hasOutputNode,
|
||||
hasPublishableInputs,
|
||||
hasUnpublishableInputs,
|
||||
children,
|
||||
}: PropsWithChildren<{
|
||||
isWorkflowSaved: boolean;
|
||||
hasBatchOrGeneratorNodes: boolean;
|
||||
isReadyToEnqueue: boolean;
|
||||
hasOutputNode: boolean;
|
||||
hasPublishableInputs: boolean;
|
||||
hasUnpublishableInputs: boolean;
|
||||
}>) => {
|
||||
const { t } = useTranslation();
|
||||
const warnings = useMemo(() => {
|
||||
const _warnings: string[] = [];
|
||||
if (!hasPublishableInputs) {
|
||||
_warnings.push(t('workflows.builder.warningWorkflowHasNoPublishableInputFields'));
|
||||
}
|
||||
if (hasUnpublishableInputs) {
|
||||
_warnings.push(t('workflows.builder.warningWorkflowHasUnpublishableInputFields'));
|
||||
}
|
||||
return _warnings;
|
||||
}, [hasPublishableInputs, hasUnpublishableInputs, t]);
|
||||
const errors = useMemo(() => {
|
||||
const _errors: string[] = [];
|
||||
if (!isWorkflowSaved) {
|
||||
_errors.push(t('workflows.builder.errorWorkflowHasUnsavedChanges'));
|
||||
}
|
||||
if (hasBatchOrGeneratorNodes) {
|
||||
_errors.push(t('workflows.builder.errorWorkflowHasBatchOrGeneratorNodes'));
|
||||
}
|
||||
if (!isReadyToEnqueue) {
|
||||
_errors.push(t('workflows.builder.errorWorkflowHasInvalidGraph'));
|
||||
}
|
||||
if (!hasOutputNode) {
|
||||
_errors.push(t('workflows.builder.errorWorkflowHasNoOutputNode'));
|
||||
}
|
||||
return _errors;
|
||||
}, [hasBatchOrGeneratorNodes, hasOutputNode, isReadyToEnqueue, isWorkflowSaved, t]);
|
||||
|
||||
if (errors.length === 0 && warnings.length === 0) {
|
||||
return children;
|
||||
}
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
label={
|
||||
<Flex flexDir="column">
|
||||
{errors.length > 0 && (
|
||||
<>
|
||||
<Text color="error.700" fontWeight="semibold">
|
||||
{t('workflows.builder.cannotPublish')}:
|
||||
</Text>
|
||||
<UnorderedList>
|
||||
{errors.map((problem, index) => (
|
||||
<ListItem key={index}>{problem}</ListItem>
|
||||
))}
|
||||
</UnorderedList>
|
||||
</>
|
||||
)}
|
||||
{warnings.length > 0 && (
|
||||
<>
|
||||
<Text color="warning.700" fontWeight="semibold">
|
||||
{t('workflows.builder.publishWarnings')}:
|
||||
</Text>
|
||||
<UnorderedList>
|
||||
{warnings.map((problem, index) => (
|
||||
<ListItem key={index}>{problem}</ListItem>
|
||||
))}
|
||||
</UnorderedList>
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
}
|
||||
>
|
||||
{children}
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
);
|
||||
PublishTooltip.displayName = 'PublishTooltip';
|
||||
@@ -1,23 +0,0 @@
|
||||
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiLockBold } from 'react-icons/pi';
|
||||
|
||||
export const LockedWorkflowIcon = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Tooltip label={t('workflows.builder.publishedWorkflowsLocked')} closeOnScroll>
|
||||
<IconButton
|
||||
size="sm"
|
||||
cursor="not-allowed"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
aria-label={t('workflows.builder.publishedWorkflowsLocked')}
|
||||
icon={<PiLockBold />}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
});
|
||||
|
||||
LockedWorkflowIcon.displayName = 'LockedWorkflowIcon';
|
||||
@@ -26,7 +26,6 @@ import {
|
||||
workflowLibraryTagToggled,
|
||||
workflowLibraryViewChanged,
|
||||
} from 'features/nodes/store/workflowLibrarySlice';
|
||||
import { selectAllowPublishWorkflows } from 'features/system/store/configSlice';
|
||||
import { NewWorkflowButton } from 'features/workflowLibrary/components/NewWorkflowButton';
|
||||
import { UploadWorkflowButton } from 'features/workflowLibrary/components/UploadWorkflowButton';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
@@ -40,12 +39,13 @@ export const WorkflowLibrarySideNav = () => {
|
||||
const { t } = useTranslation();
|
||||
const categoryOptions = useStore($workflowLibraryCategoriesOptions);
|
||||
const view = useAppSelector(selectWorkflowLibraryView);
|
||||
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
|
||||
|
||||
return (
|
||||
<Flex h="full" minH={0} overflow="hidden" flexDir="column" w={64} gap={0}>
|
||||
<Flex flexDir="column" w="full" pb={2} gap={2}>
|
||||
<Flex flexDir="column" w="full" pb={2}>
|
||||
<WorkflowLibraryViewButton view="recent">{t('workflows.recentlyOpened')}</WorkflowLibraryViewButton>
|
||||
</Flex>
|
||||
<Flex flexDir="column" w="full" pb={2}>
|
||||
<WorkflowLibraryViewButton view="yours">{t('workflows.yourWorkflows')}</WorkflowLibraryViewButton>
|
||||
{categoryOptions.includes('project') && (
|
||||
<Collapse in={view === 'yours' || view === 'shared' || view === 'private'}>
|
||||
@@ -60,9 +60,6 @@ export const WorkflowLibrarySideNav = () => {
|
||||
</Flex>
|
||||
</Collapse>
|
||||
)}
|
||||
{allowPublishWorkflows && (
|
||||
<WorkflowLibraryViewButton view="published">{t('workflows.published')}</WorkflowLibraryViewButton>
|
||||
)}
|
||||
</Flex>
|
||||
<Flex h="full" minH={0} overflow="hidden" flexDir="column">
|
||||
<BrowseWorkflowsButton />
|
||||
|
||||
@@ -36,8 +36,6 @@ const getCategories = (view: WorkflowLibraryView): WorkflowCategory[] => {
|
||||
return ['user'];
|
||||
case 'shared':
|
||||
return ['project'];
|
||||
case 'published':
|
||||
return ['user', 'project', 'default'];
|
||||
default:
|
||||
assert<Equals<typeof view, never>>(false);
|
||||
}
|
||||
@@ -68,7 +66,6 @@ const useInfiniteQueryAry = () => {
|
||||
query: debouncedSearchTerm,
|
||||
tags: view === 'defaults' ? selectedTags : [],
|
||||
has_been_opened: getHasBeenOpened(view),
|
||||
is_published: view === 'published' ? true : undefined,
|
||||
} satisfies Parameters<typeof useListWorkflowsInfiniteInfiniteQuery>[0];
|
||||
}, [orderBy, direction, view, debouncedSearchTerm, selectedTags]);
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Badge, Flex, Icon, Image, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { LockedWorkflowIcon } from 'features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibraryListItemActions/LockedWorkflowIcon';
|
||||
import { ShareWorkflowButton } from 'features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibraryListItemActions/ShareWorkflow';
|
||||
import { selectWorkflowId, workflowModeChanged } from 'features/nodes/store/workflowSlice';
|
||||
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
|
||||
@@ -55,6 +54,7 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
|
||||
position="relative"
|
||||
role="button"
|
||||
onClick={handleClickLoad}
|
||||
cursor="pointer"
|
||||
bg="base.750"
|
||||
borderRadius="base"
|
||||
w="full"
|
||||
@@ -81,7 +81,7 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
|
||||
<Flex gap={2} alignItems="flex-start" justifyContent="space-between" w="full">
|
||||
<Text noOfLines={2}>{workflow.name}</Text>
|
||||
<Flex gap={2} alignItems="center">
|
||||
{isActive && !workflow.is_published && (
|
||||
{isActive && (
|
||||
<Badge
|
||||
color="invokeBlue.400"
|
||||
borderColor="invokeBlue.700"
|
||||
@@ -93,18 +93,6 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
|
||||
{t('workflows.opened')}
|
||||
</Badge>
|
||||
)}
|
||||
{workflow.is_published && (
|
||||
<Badge
|
||||
color="invokeGreen.400"
|
||||
borderColor="invokeGreen.700"
|
||||
borderWidth={1}
|
||||
bg="transparent"
|
||||
flexShrink={0}
|
||||
variant="subtle"
|
||||
>
|
||||
{t('workflows.builder.published')}
|
||||
</Badge>
|
||||
)}
|
||||
{workflow.category === 'project' && <Icon as={PiUsersBold} color="base.200" />}
|
||||
{workflow.category === 'default' && (
|
||||
<Image
|
||||
@@ -131,10 +119,8 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
|
||||
</Text>
|
||||
)}
|
||||
<Spacer />
|
||||
{workflow.category === 'default' && !workflow.is_published && (
|
||||
<ViewWorkflow workflowId={workflow.workflow_id} />
|
||||
)}
|
||||
{workflow.category !== 'default' && !workflow.is_published && (
|
||||
{workflow.category === 'default' && <ViewWorkflow workflowId={workflow.workflow_id} />}
|
||||
{workflow.category !== 'default' && (
|
||||
<>
|
||||
<EditWorkflow workflowId={workflow.workflow_id} />
|
||||
<DownloadWorkflow workflowId={workflow.workflow_id} />
|
||||
@@ -142,7 +128,6 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
|
||||
</>
|
||||
)}
|
||||
{workflow.category === 'project' && <ShareWorkflowButton workflow={workflow} />}
|
||||
{workflow.is_published && <LockedWorkflowIcon />}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import { Spacer, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||
import { WorkflowBuilder } from 'features/nodes/components/sidePanel/builder/WorkflowBuilder';
|
||||
import { StartPublishFlowButton } from 'features/nodes/components/sidePanel/workflow/PublishWorkflowPanelContent';
|
||||
import { selectAllowPublishWorkflows } from 'features/system/store/configSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -11,15 +8,12 @@ import WorkflowJSONTab from './WorkflowJSONTab';
|
||||
|
||||
const WorkflowFieldsLinearViewPanel = () => {
|
||||
const { t } = useTranslation();
|
||||
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
|
||||
return (
|
||||
<Tabs variant="enclosed" display="flex" w="full" h="full" flexDir="column">
|
||||
<TabList>
|
||||
<Tab>{t('workflows.builder.builder')}</Tab>
|
||||
<Tab>{t('common.details')}</Tab>
|
||||
<Tab>JSON</Tab>
|
||||
<Spacer />
|
||||
{allowPublishWorkflows && <StartPublishFlowButton />}
|
||||
</TabList>
|
||||
|
||||
<TabPanels h="full" pt={2}>
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { selectWorkflowFormNodeFieldFieldIdentifiersDeduped } from 'features/nodes/store/workflowSlice';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import { isBoardFieldType } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetBatchStatusQuery } from 'services/api/endpoints/queue';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const $isInPublishFlow = atom(false);
|
||||
export const $outputNodeId = atom<string | null>(null);
|
||||
export const $isSelectingOutputNode = atom(false);
|
||||
export const $isReadyToDoValidationRun = computed(
|
||||
[$isInPublishFlow, $outputNodeId, $isSelectingOutputNode],
|
||||
(isInPublishFlow, outputNodeId, isSelectingOutputNode) => {
|
||||
return isInPublishFlow && outputNodeId !== null && !isSelectingOutputNode;
|
||||
}
|
||||
);
|
||||
export const $validationRunBatchId = atom<string | null>(null);
|
||||
|
||||
export const useIsValidationRunInProgress = () => {
|
||||
const validationRunBatchId = useStore($validationRunBatchId);
|
||||
const { isValidationRunInProgress } = useGetBatchStatusQuery(
|
||||
validationRunBatchId ? { batch_id: validationRunBatchId } : skipToken,
|
||||
{
|
||||
selectFromResult: ({ currentData }) => {
|
||||
if (!currentData) {
|
||||
return { isValidationRunInProgress: false };
|
||||
}
|
||||
if (currentData && currentData.in_progress > 0) {
|
||||
return { isValidationRunInProgress: true };
|
||||
}
|
||||
return { isValidationRunInProgress: false };
|
||||
},
|
||||
}
|
||||
);
|
||||
return validationRunBatchId !== null || isValidationRunInProgress;
|
||||
};
|
||||
|
||||
export const selectFieldIdentifiersWithInvocationTypes = createSelector(
|
||||
selectWorkflowFormNodeFieldFieldIdentifiersDeduped,
|
||||
selectNodesSlice,
|
||||
(fieldIdentifiers, nodes) => {
|
||||
const result: { nodeId: string; fieldName: string; type: string }[] = [];
|
||||
for (const fieldIdentifier of fieldIdentifiers) {
|
||||
const node = nodes.nodes.find((node) => node.id === fieldIdentifier.nodeId);
|
||||
assert(isInvocationNode(node), `Node ${fieldIdentifier.nodeId} not found`);
|
||||
result.push({ nodeId: fieldIdentifier.nodeId, fieldName: fieldIdentifier.fieldName, type: node.data.type });
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
);
|
||||
|
||||
export const getPublishInputs = (fieldIdentifiers: (FieldIdentifier & { type: string })[], templates: Templates) => {
|
||||
// Certain field types are not allowed to be input fields on a published workflow
|
||||
const publishable: FieldIdentifier[] = [];
|
||||
const unpublishable: FieldIdentifier[] = [];
|
||||
for (const fieldIdentifier of fieldIdentifiers) {
|
||||
const fieldTemplate = templates[fieldIdentifier.type]?.inputs[fieldIdentifier.fieldName];
|
||||
if (!fieldTemplate) {
|
||||
unpublishable.push(fieldIdentifier);
|
||||
continue;
|
||||
}
|
||||
if (isBoardFieldType(fieldTemplate.type)) {
|
||||
unpublishable.push(fieldIdentifier);
|
||||
continue;
|
||||
}
|
||||
publishable.push(fieldIdentifier);
|
||||
}
|
||||
return { publishable, unpublishable };
|
||||
};
|
||||
|
||||
export const usePublishInputs = () => {
|
||||
const templates = useStore($templates);
|
||||
const fieldIdentifiersWithInvocationTypes = useAppSelector(selectFieldIdentifiersWithInvocationTypes);
|
||||
const fieldIdentifiers = useMemo(
|
||||
() => getPublishInputs(fieldIdentifiersWithInvocationTypes, templates),
|
||||
[fieldIdentifiersWithInvocationTypes, templates]
|
||||
);
|
||||
|
||||
return fieldIdentifiers;
|
||||
};
|
||||
@@ -1,6 +1,6 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { fieldValueReset } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
|
||||
@@ -11,7 +11,7 @@ import { useMemo } from 'react';
|
||||
* @param nodeId The ID of the node
|
||||
* @param fieldName The name of the field
|
||||
*/
|
||||
export const useInputFieldUserDescriptionSafe = (nodeId: string, fieldName: string) => {
|
||||
export const useInputFieldDescriptionSafe = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
@@ -4,21 +4,21 @@ import { selectFieldInputInstanceSafe, selectNodesSlice } from 'features/nodes/s
|
||||
import { useMemo } from 'react';
|
||||
|
||||
/**
|
||||
* Gets the user-defined title of an input field for a given node.
|
||||
* Gets the user-defined label of an input field for a given node.
|
||||
*
|
||||
* If the node doesn't exist or is not an invocation node, an empty string is returned.
|
||||
*
|
||||
* @param nodeId The ID of the node
|
||||
* @param fieldName The name of the field
|
||||
*/
|
||||
export const useInputFieldUserTitleSafe = (nodeId: string, fieldName: string): string => {
|
||||
export const useInputFieldLabelSafe = (nodeId: string, fieldName: string): string => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => selectFieldInputInstanceSafe(nodes, nodeId, fieldName)?.label ?? ''),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const title = useAppSelector(selector);
|
||||
const label = useAppSelector(selector);
|
||||
|
||||
return title;
|
||||
return label;
|
||||
};
|
||||
@@ -1,11 +1,10 @@
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { isSingleOrCollection } from 'features/nodes/types/field';
|
||||
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
|
||||
|
||||
const isConnectionInputField = (field: FieldInputTemplate) => {
|
||||
return (
|
||||
(field.input === 'connection' && !isSingleOrCollection(field.type)) || !(field.type.name in TEMPLATE_BUILDER_MAP)
|
||||
@@ -20,7 +19,7 @@ const isAnyOrDirectInputField = (field: FieldInputTemplate) => {
|
||||
};
|
||||
|
||||
export const useInputFieldNamesMissing = (nodeId: string) => {
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const node = useNodeData(nodeId);
|
||||
const fieldNames = useMemo(() => {
|
||||
const instanceFields = new Set(Object.keys(node.inputs));
|
||||
@@ -31,7 +30,7 @@ export const useInputFieldNamesMissing = (nodeId: string) => {
|
||||
};
|
||||
|
||||
export const useInputFieldNamesAnyOrDirect = (nodeId: string) => {
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const fieldNames = useMemo(() => {
|
||||
const anyOrDirectFields: string[] = [];
|
||||
for (const [fieldName, fieldTemplate] of Object.entries(template.inputs)) {
|
||||
@@ -45,7 +44,7 @@ export const useInputFieldNamesAnyOrDirect = (nodeId: string) => {
|
||||
};
|
||||
|
||||
export const useInputFieldNamesConnection = (nodeId: string) => {
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const fieldNames = useMemo(() => {
|
||||
const connectionFields: string[] = [];
|
||||
for (const [fieldName, fieldTemplate] of Object.entries(template.inputs)) {
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
|
||||
|
||||
/**
|
||||
* Returns the template for a specific input field of a node.
|
||||
*
|
||||
@@ -14,7 +13,7 @@ import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
|
||||
* @throws Will throw an error if the template for the input field is not found.
|
||||
*/
|
||||
export const useInputFieldTemplateOrThrow = (nodeId: string, fieldName: string): FieldInputTemplate => {
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const fieldTemplate = useMemo(() => {
|
||||
const _fieldTemplate = template.inputs[fieldName];
|
||||
assert(_fieldTemplate, `Template for input field ${fieldName} not found.`);
|
||||
@@ -22,3 +21,17 @@ export const useInputFieldTemplateOrThrow = (nodeId: string, fieldName: string):
|
||||
}, [fieldName, template.inputs]);
|
||||
return fieldTemplate;
|
||||
};
|
||||
|
||||
/**
|
||||
* Returns the template for a specific input field of a node.
|
||||
*
|
||||
* **Note:** This function is a safe version of `useInputFieldTemplate` and will not throw an error if the template is not found.
|
||||
*
|
||||
* @param nodeId - The ID of the node.
|
||||
* @param fieldName - The name of the input field.
|
||||
*/
|
||||
export const useInputFieldTemplateSafe = (nodeId: string, fieldName: string): FieldInputTemplate | null => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const fieldTemplate = useMemo(() => template.inputs[fieldName] ?? null, [fieldName, template.inputs]);
|
||||
return fieldTemplate;
|
||||
};
|
||||
@@ -1,17 +0,0 @@
|
||||
import { useNodeTemplateSafe } from 'features/nodes/hooks/useNodeTemplateSafe';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
/**
|
||||
* Returns the template for a specific input field of a node.
|
||||
*
|
||||
* **Note:** This function is a safe version of `useInputFieldTemplate` and will not throw an error if the template is not found.
|
||||
*
|
||||
* @param nodeId - The ID of the node.
|
||||
* @param fieldName - The name of the input field.
|
||||
*/
|
||||
export const useInputFieldTemplateSafe = (nodeId: string, fieldName: string): FieldInputTemplate | null => {
|
||||
const template = useNodeTemplateSafe(nodeId);
|
||||
const fieldTemplate = useMemo(() => template?.inputs[fieldName] ?? null, [fieldName, template?.inputs]);
|
||||
return fieldTemplate;
|
||||
};
|
||||
@@ -1,10 +1,9 @@
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
|
||||
|
||||
export const useInputFieldTemplateTitleOrThrow = (nodeId: string, fieldName: string): string => {
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
export const useInputFieldTemplateTitle = (nodeId: string, fieldName: string): string => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
|
||||
const title = useMemo(() => {
|
||||
const fieldTemplate = template.inputs[fieldName];
|
||||
@@ -1,23 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
/**
|
||||
* Gets the user-defined title of an input field for a given node.
|
||||
*
|
||||
* If the node doesn't exist or is not an invocation node, an error is thrown.
|
||||
*
|
||||
* @param nodeId The ID of the node
|
||||
* @param fieldName The name of the field
|
||||
*/
|
||||
export const useInputFieldUserTitleOrThrow = (nodeId: string, fieldName: string): string => {
|
||||
const selector = useMemo(
|
||||
() => createSelector(selectNodesSlice, (nodes) => selectFieldInputInstance(nodes, nodeId, fieldName).label),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const title = useAppSelector(selector);
|
||||
|
||||
return title;
|
||||
};
|
||||
@@ -1,10 +1,9 @@
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { isBatchNodeType, isGeneratorNodeType } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
|
||||
|
||||
export const useIsExecutableNode = (nodeId: string) => {
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const isExecutableNode = useMemo(
|
||||
() => !isBatchNodeType(template.type) && !isGeneratorNodeType(template.type),
|
||||
[template]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user