Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-15 10:48:12 -05:00)

Compare commits: 88 commits, psyche/fea ... v5.10.0dev
Commits in this comparison (88), by SHA1:

f34d6099f5, ef9d832b6a, 6c87ea58b0, 0e569364ac, bb6e22606b, 3e200a2ba2, 4610b55a5d, b3b3dbd92d,
6c36b0508b, 2756c539e0, a34383d460, 77f22497d2, 5967d4e1da, 1253ad5053, 5aa08ab09b, 6ce527768b,
fe88012236, 8609b98217, 19f0bf828c, 26cbeccfdf, b5be81b97b, f14d07968b, 525a89900a, d8df31a8ac,
380a41be34, e990afbccb, c591478d24, 30def6a9bd, 6cf88a601d, 5e14545c32, eefbcd2485, 13cc44a22c,
2cca339a5c, 0a7cf6c0ec, 06abc1d40a, 2cde86b7b8, 0a49463c79, f3402b6ce7, 5d3fb822c5, 9e70d8eb6e,
402758d502, b97cc51f23, f6f33b5999, cd873f1fe5, 5f3d398074, e6b366ff61, bcd50ed688, a5966c3197,
f28b054872, 31681f4ad7, aaf042de48, c28e685409, d6ac822a1f, f0a4d7ac7f, 04b0e658df, 68845f4d85,
6df5614b54, 0bd6f0245b, 6c9165046e, 2b5da91beb, 74bede14be, 04ea3c491a, 38e7b23d18, c052846e05,
af3a31dfec, 571710fab6, a175a5c252, 8b3c36c6fa, b9ffacd4bf, ae45fc8a74, 85db9c65e5, ddddaef7ca,
e4678201cb, d66fdfde71, 08ee08557b, 496f1262c6, 188d52e4a5, db03c196a1, 6bc36b697d, b7d71d3028,
fa1ebd9d2f, eed5d02069, 3650d91045, 6c7d08cacb, bb1c40f222, d26b7a1a12, c9992914d6, 3f12a43e75
.dockerignore
@@ -1,9 +1,11 @@
*
!invokeai
!pyproject.toml
!uv.lock
!docker/docker-entrypoint.sh
!LICENSE

**/dist
**/node_modules
**/__pycache__
**/*.egg-info
.github/CODEOWNERS (8 changes)
@@ -2,11 +2,11 @@
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr @jazzhaiku

# documentation
-/docs/ @lstein @blessedcoolant @hipsterusername @Millu
-/mkdocs.yml @lstein @blessedcoolant @hipsterusername @Millu
+/docs/ @lstein @blessedcoolant @hipsterusername @psychedelicious
+/mkdocs.yml @lstein @blessedcoolant @hipsterusername @psychedelicious

# nodes
-/invokeai/app/ @Kyle0654 @blessedcoolant @psychedelicious @brandonrising @hipsterusername @jazzhaiku
+/invokeai/app/ @blessedcoolant @psychedelicious @brandonrising @hipsterusername @jazzhaiku

# installation and configuration
/pyproject.toml @lstein @blessedcoolant @hipsterusername
@@ -22,7 +22,7 @@
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername

# generation, model management, postprocessing
-/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising @ryanjdick @hipsterusername @jazzhaiku
+/invokeai/backend @lstein @blessedcoolant @brandonrising @hipsterusername @jazzhaiku

# front ends
/invokeai/frontend/CLI @lstein @hipsterusername
.github/workflows/build-container.yml (2 changes)
@@ -97,6 +97,8 @@ jobs:
          context: .
          file: docker/Dockerfile
          platforms: ${{ env.PLATFORMS }}
+         build-args: |
+           GPU_DRIVER=${{ matrix.gpu-driver }}
          push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' || github.event.inputs.push-to-registry }}
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
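The workflow above now forwards the matrix's gpu-driver value into the image build as the GPU_DRIVER build argument. A minimal local sketch of the same invocation, assuming you run it from the repository root and pick one of the drivers the Dockerfile understands (cuda, rocm, or cpu); the image tag here is arbitrary:

```sh
# Build the InvokeAI image locally, mirroring the CI build-args.
# GPU_DRIVER selects which PyTorch wheel index the Dockerfile uses.
docker build \
  --file docker/Dockerfile \
  --build-arg GPU_DRIVER=cuda \
  --tag invokeai:local \
  .
```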
GitHub Actions workflow: build installer → build wheel
@@ -1,6 +1,6 @@
-# Builds and uploads the installer and python build artifacts.
+# Builds and uploads python build artifacts.

-name: build installer
+name: build wheel

on:
  workflow_dispatch:
@@ -17,7 +17,7 @@ jobs:
      - name: setup python
        uses: actions/setup-python@v5
        with:
-         python-version: '3.10'
+         python-version: '3.12'
          cache: pip
          cache-dependency-path: pyproject.toml

@@ -27,19 +27,12 @@ jobs:
      - name: setup frontend
        uses: ./.github/actions/install-frontend-deps

-     - name: create installer
-       id: create_installer
-       run: ./create_installer.sh
-       working-directory: installer
+     - name: build wheel
+       id: build_wheel
+       run: ./scripts/build_wheel.sh

      - name: upload python distribution artifact
        uses: actions/upload-artifact@v4
        with:
          name: dist
-         path: ${{ steps.create_installer.outputs.DIST_PATH }}

-     - name: upload installer artifact
-       uses: actions/upload-artifact@v4
-       with:
-         name: installer
-         path: ${{ steps.create_installer.outputs.INSTALLER_PATH }}
+         path: ${{ steps.build_wheel.outputs.DIST_PATH }}
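With this change the release artifact is produced by scripts/build_wheel.sh rather than installer/create_installer.sh. A rough sketch of reproducing the CI step locally, assuming Python 3.12 and the frontend dependencies are already set up as the workflow's earlier steps do:

```sh
# Run the same build step the workflow runs; CI uploads the resulting
# distribution files as the "dist" artifact (path reported via the
# build_wheel step's DIST_PATH output).
./scripts/build_wheel.sh
```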
.github/workflows/release.yml (2 changes)
@@ -49,7 +49,7 @@ jobs:
      always_run: true

  build:
-   uses: ./.github/workflows/build-installer.yml
+   uses: ./.github/workflows/build-wheel.yml

  publish-testpypi:
    runs-on: ubuntu-latest
Makefile (10 changes)
@@ -16,7 +16,7 @@ help:
	@echo "frontend-build Build the frontend in order to run on localhost:9090"
	@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
	@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
-	@echo "installer-zip Build the installer .zip file for the current version"
+	@echo "wheel Build the wheel for the current version"
	@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
	@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
	@echo "docs Serve the mkdocs site with live reload"
@@ -64,13 +64,13 @@ frontend-dev:
frontend-typegen:
	cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen

-# Installer zip file
-installer-zip:
-	cd installer && ./create_installer.sh
+# Tag the release
+wheel:
+	cd scripts && ./build_wheel.sh

# Tag the release
tag-release:
-	cd installer && ./tag_release.sh
+	cd scripts && ./tag_release.sh

# Generate the OpenAPI Schema for the app
openapi:
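With these targets in place, the old installer-zip flow is gone and both wheel builds and release tagging live under scripts/. Typical usage of the updated targets:

```sh
# Build a wheel for the current version (wraps scripts/build_wheel.sh)
make wheel

# Tag the GitHub repository with the current version -- release time only
# (wraps scripts/tag_release.sh)
make tag-release
```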
docker/Dockerfile
@@ -1,77 +1,6 @@
# syntax=docker/dockerfile:1.4

## Builder stage

FROM library/ubuntu:24.04 AS builder

ARG DEBIAN_FRONTEND=noninteractive
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
    --mount=type=cache,target=/var/lib/apt,sharing=locked \
    apt update && apt-get install -y \
    build-essential \
    git

# Install `uv` for package management
COPY --from=ghcr.io/astral-sh/uv:0.6.0 /uv /uvx /bin/

ENV VIRTUAL_ENV=/opt/venv
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
ENV INVOKEAI_SRC=/opt/invokeai
ENV PYTHON_VERSION=3.11
ENV UV_PYTHON=3.11
ENV UV_COMPILE_BYTECODE=1
ENV UV_LINK_MODE=copy
ENV UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
ENV UV_INDEX="https://download.pytorch.org/whl/cu124"

ARG GPU_DRIVER=cuda
# unused but available
ARG BUILDPLATFORM

# Switch to the `ubuntu` user to work around dependency issues with uv-installed python
RUN mkdir -p ${VIRTUAL_ENV} && \
    mkdir -p ${INVOKEAI_SRC} && \
    chmod -R a+w /opt && \
    mkdir ~ubuntu/.cache && chown ubuntu: ~ubuntu/.cache
USER ubuntu

# Install python
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
    uv python install ${PYTHON_VERSION}

WORKDIR ${INVOKEAI_SRC}

# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
# bind-mount instead of copy to defer adding sources to the image until next layer.
#
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
# x86_64/CUDA is the default
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
    --mount=type=bind,source=invokeai/version,target=invokeai/version \
    if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
        UV_INDEX="https://download.pytorch.org/whl/cpu"; \
    elif [ "$GPU_DRIVER" = "rocm" ]; then \
        UV_INDEX="https://download.pytorch.org/whl/rocm6.1"; \
    fi && \
    uv sync --no-install-project

# Now that the bulk of the dependencies have been installed, copy in the project files that change more frequently.
COPY invokeai invokeai
COPY pyproject.toml .

RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
    if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
        UV_INDEX="https://download.pytorch.org/whl/cpu"; \
    elif [ "$GPU_DRIVER" = "rocm" ]; then \
        UV_INDEX="https://download.pytorch.org/whl/rocm6.1"; \
    fi && \
    uv sync


#### Build the Web UI ------------------------------------
#### Web UI ------------------------------------

FROM docker.io/node:22-slim AS web-builder
ENV PNPM_HOME="/pnpm"
@@ -85,69 +14,89 @@ RUN --mount=type=cache,target=/pnpm/store \
    pnpm install --frozen-lockfile
RUN npx vite build

#### Runtime stage ---------------------------------------
## Backend ---------------------------------------

FROM library/ubuntu:24.04 AS runtime
FROM library/ubuntu:24.04

ARG DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
RUN --mount=type=cache,target=/var/cache/apt \
    --mount=type=cache,target=/var/lib/apt \
    apt update && apt install -y --no-install-recommends \
    ca-certificates \
    git \
    gosu \
    libglib2.0-0 \
    libgl1 \
    libglx-mesa0 \
    build-essential \
    libopencv-dev \
    libstdc++-10-dev

RUN apt update && apt install -y --no-install-recommends \
    git \
    curl \
    vim \
    tmux \
    ncdu \
    iotop \
    bzip2 \
    gosu \
    magic-wormhole \
    libglib2.0-0 \
    libgl1 \
    libglx-mesa0 \
    build-essential \
    libopencv-dev \
    libstdc++-10-dev &&\
    apt-get clean && apt-get autoclean
ENV \
    PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    VIRTUAL_ENV=/opt/venv \
    INVOKEAI_SRC=/opt/invokeai \
    PYTHON_VERSION=3.12 \
    UV_PYTHON=3.12 \
    UV_COMPILE_BYTECODE=1 \
    UV_MANAGED_PYTHON=1 \
    UV_LINK_MODE=copy \
    UV_PROJECT_ENVIRONMENT=/opt/venv \
    UV_INDEX="https://download.pytorch.org/whl/cu124" \
    INVOKEAI_ROOT=/invokeai \
    INVOKEAI_HOST=0.0.0.0 \
    INVOKEAI_PORT=9090 \
    PATH="/opt/venv/bin:$PATH" \
    CONTAINER_UID=${CONTAINER_UID:-1000} \
    CONTAINER_GID=${CONTAINER_GID:-1000}

ENV INVOKEAI_SRC=/opt/invokeai
ENV VIRTUAL_ENV=/opt/venv
ENV UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
ENV PYTHON_VERSION=3.11
ENV INVOKEAI_ROOT=/invokeai
ENV INVOKEAI_HOST=0.0.0.0
ENV INVOKEAI_PORT=9090
ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH"
ENV CONTAINER_UID=${CONTAINER_UID:-1000}
ENV CONTAINER_GID=${CONTAINER_GID:-1000}
ARG GPU_DRIVER=cuda

# Install `uv` for package management
# and install python for the ubuntu user (expected to exist on ubuntu >=24.x)
# this is too tiny to optimize with multi-stage builds, but maybe we'll come back to it
COPY --from=ghcr.io/astral-sh/uv:0.6.0 /uv /uvx /bin/
USER ubuntu
RUN uv python install ${PYTHON_VERSION}
USER root
COPY --from=ghcr.io/astral-sh/uv:0.6.9 /uv /uvx /bin/

# --link requires buldkit w/ dockerfile syntax 1.4
COPY --link --from=builder ${INVOKEAI_SRC} ${INVOKEAI_SRC}
COPY --link --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist

# Link amdgpu.ids for ROCm builds
# contributed by https://github.com/Rubonnek
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
    ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
# Install python & allow non-root user to use it by traversing the /root dir without read permissions
RUN --mount=type=cache,target=/root/.cache/uv \
    uv python install ${PYTHON_VERSION} && \
    # chmod --recursive a+rX /root/.local/share/uv/python
    chmod 711 /root

WORKDIR ${INVOKEAI_SRC}

# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
# bind-mount instead of copy to defer adding sources to the image until next layer.
#
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
# x86_64/CUDA is the default
RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
    --mount=type=bind,source=uv.lock,target=uv.lock \
    # this is just to get the package manager to recognize that the project exists, without making changes to the docker layer
    --mount=type=bind,source=invokeai/version,target=invokeai/version \
    if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then UV_INDEX="https://download.pytorch.org/whl/cpu"; \
    elif [ "$GPU_DRIVER" = "rocm" ]; then UV_INDEX="https://download.pytorch.org/whl/rocm6.2"; \
    fi && \
    uv sync --frozen

# build patchmatch
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
RUN python -c "from patchmatch import patch_match"

# Link amdgpu.ids for ROCm builds
# contributed by https://github.com/Rubonnek
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
    ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"

RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}

COPY docker/docker-entrypoint.sh ./
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]
CMD ["invokeai-web"]

# --link requires buldkit w/ dockerfile syntax 1.4, does not work with podman
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist

# add sources last to minimize image changes on code changes
COPY invokeai ${INVOKEAI_SRC}/invokeai
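The runtime stage picks its PyTorch wheel index at build time from GPU_DRIVER and the target platform. The same selection, pulled out of the RUN instruction into a standalone sketch (the UV_INDEX values are the ones the Dockerfile uses; the default is the CUDA cu124 index set in the ENV block):

```sh
#!/usr/bin/env sh
# Sketch of the Dockerfile's PyTorch index selection: arm64 has no CUDA
# wheels, so it falls back to CPU; ROCm gets its own index; everything
# else keeps the default CUDA (cu124) index.
GPU_DRIVER="${GPU_DRIVER:-cuda}"
export UV_INDEX="https://download.pytorch.org/whl/cu124"

if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then
    export UV_INDEX="https://download.pytorch.org/whl/cpu"
elif [ "$GPU_DRIVER" = "rocm" ]; then
    export UV_INDEX="https://download.pytorch.org/whl/rocm6.2"
fi

echo "Resolving dependencies against $UV_INDEX"
uv sync --frozen   # as in the new runtime stage
```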
Docs: developer install instructions
@@ -41,7 +41,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
With the modifications made, the install command should look something like this:

```sh
-uv pip install -e ".[dev,test,docs,xformers]" --python 3.11 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
+uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
```

6. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
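Putting the updated command in context, an editable development install now targets Python 3.12. A rough sketch of the step, assuming you have already cloned the repository and created and activated a venv as the surrounding guide describes:

```sh
# Editable install with dev/test/docs/xformers extras against the CUDA cu124 index.
uv pip install -e ".[dev,test,docs,xformers]" \
  --python 3.12 \
  --python-preference only-managed \
  --index=https://download.pytorch.org/whl/cu124 \
  --reinstall
```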
Docs: manual install instructions
@@ -43,10 +43,10 @@ The following commands vary depending on the version of Invoke being installed a
3. Create a virtual environment in that directory:

   ```sh
-  uv venv --relocatable --prompt invoke --python 3.11 --python-preference only-managed .venv
+  uv venv --relocatable --prompt invoke --python 3.12 --python-preference only-managed .venv
   ```

-  This command creates a portable virtual environment at `.venv` complete with a portable python 3.11. It doesn't matter if your system has no python installed, or has a different version - `uv` will handle everything.
+  This command creates a portable virtual environment at `.venv` complete with a portable python 3.12. It doesn't matter if your system has no python installed, or has a different version - `uv` will handle everything.

4. Activate the virtual environment:

@@ -64,7 +64,7 @@ The following commands vary depending on the version of Invoke being installed a

5. Choose a version to install. Review the [GitHub releases page](https://github.com/invoke-ai/InvokeAI/releases).

-6. Determine the package package specifier to use when installing. This is a performance optimization.
+6. Determine the package specifier to use when installing. This is a performance optimization.

   - If you have an Nvidia 20xx series GPU or older, use `invokeai[xformers]`.
   - If you have an Nvidia 30xx series GPU or newer, or do not have an Nvidia GPU, use `invokeai`.
@@ -88,13 +88,13 @@ The following commands vary depending on the version of Invoke being installed a
8. Install the `invokeai` package. Substitute the package specifier and version.

   ```sh
-  uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.11 --python-preference only-managed --force-reinstall
+  uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --force-reinstall
   ```

   If you determined you needed to use a `PyPI` index URL in the previous step, you'll need to add `--index=<INDEX_URL>` like this:

   ```sh
-  uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.11 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
+  uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
   ```

9. Deactivate and reactivate your venv so that the invokeai-specific commands become available in the environment:
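Taken together, the updated manual-install steps look roughly like this end to end (a sketch; <PACKAGE_SPECIFIER>, <VERSION>, and <INDEX_URL> are placeholders you substitute as the steps above describe):

```sh
# 3. Create a portable venv with a managed Python 3.12
uv venv --relocatable --prompt invoke --python 3.12 --python-preference only-managed .venv

# 4. Activate it (Linux/macOS; Windows uses .venv\Scripts\activate)
source .venv/bin/activate

# 8. Install the chosen package specifier and version, optionally against a PyTorch index
uv pip install <PACKAGE_SPECIFIER>==<VERSION> \
  --python 3.12 \
  --python-preference only-managed \
  --index=<INDEX_URL> \
  --force-reinstall

# 9. Re-activate so the invokeai-specific commands land on PATH
deactivate && source .venv/bin/activate
```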
Docs: installation requirements (Python)
@@ -41,7 +41,7 @@ The requirements below are rough guidelines for best performance. GPUs with less

You don't need to do this if you are installing with the [Invoke Launcher](./quick_start.md).

-Invoke requires python 3.10 or 3.11. If you don't already have one of these versions installed, we suggest installing 3.11, as it will be supported for longer.
+Invoke requires python 3.10 through 3.12. If you don't already have one of these versions installed, we suggest installing 3.12, as it will be supported for longer.

Check that your system has an up-to-date Python installed by running `python3 --version` in the terminal (Linux, macOS) or cmd/powershell (Windows).

@@ -49,19 +49,19 @@ Check that your system has an up-to-date Python installed by running `python3 --

=== "Windows"

-    - Install python 3.11 with [an official installer].
+    - Install python with [an official installer].
    - The installer includes an option to add python to your PATH. Be sure to enable this. If you missed it, re-run the installer, choose to modify an existing installation, and tick that checkbox.
    - You may need to install [Microsoft Visual C++ Redistributable].

=== "macOS"

-    - Install python 3.11 with [an official installer].
+    - Install python with [an official installer].
    - If model installs fail with a certificate error, you may need to run this command (changing the python version to match what you have installed): `/Applications/Python\ 3.10/Install\ Certificates.command`
    - If you haven't already, you will need to install the XCode CLI Tools by running `xcode-select --install` in a terminal.

=== "Linux"

-    - Installing python varies depending on your system. On Ubuntu, you can use the [deadsnakes PPA](https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa).
+    - Installing python varies depending on your system. We recommend [using `uv` to manage your python installation](https://docs.astral.sh/uv/concepts/python-versions/#installing-a-python-version).
    - You'll need to install `libglib2.0-0` and `libgl1-mesa-glx` for OpenCV to work. For example, on a Debian system: `sudo apt update && sudo apt install -y libglib2.0-0 libgl1-mesa-glx`

## Drivers
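Since the docs now point Linux users at uv for managing Python, a minimal sketch of that route (3.12 matches the range the page now recommends; the apt packages are the ones listed above for OpenCV):

```sh
# Install a managed Python 3.12 with uv, then the OpenCV system libraries (Debian/Ubuntu).
uv python install 3.12
sudo apt update && sudo apt install -y libglib2.0-0 libgl1-mesa-glx
```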
Binary file not shown.
Deleted file: Windows batch installer script (128 lines)
@@ -1,128 +0,0 @@
The removed script:
- changed to its own directory, printed introductory instructions (install Python 3.10 or 3.11, double-click WinLongPathsEnabled.reg to enable long path support, install the Visual C++ core libraries), and pointed users at https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/ and https://invoke-ai.github.io/InvokeAI/help/FAQ/.
- checked the detected Python against MINIMUM_PYTHON_VERSION=3.10.0 using a :compareVersions subroutine (adapted from https://stackoverflow.com/questions/15807762/compare-version-numbers-in-batch-file) that splits version strings on ".", "," and "-" and compares node by node, returning 1 / 0 / -1 in ERRORLEVEL, with helper routines :parseNode, :divideLetters and :Trim.
- on success, launched the Python installer with `call python .\lib\main.py`; on failure, printed an error message and exited.
Deleted file: shell installer launcher (40 lines)
@@ -1,40 +0,0 @@
#!/bin/bash

# make sure we are not already in a venv
# (don't need to check status)
deactivate >/dev/null 2>&1
scriptdir=$(dirname "$0")
cd $scriptdir

function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; }

MINIMUM_PYTHON_VERSION=3.10.0
MAXIMUM_PYTHON_VERSION=3.11.100
PYTHON=""
for candidate in python3.11 python3.10 python3 python ; do
    if ppath=`which $candidate 2>/dev/null`; then
        # when using `pyenv`, the executable for an inactive Python version will exist but will not be operational
        # we check that this found executable can actually run
        if [ $($candidate --version &>/dev/null; echo ${PIPESTATUS}) -gt 0 ]; then continue; fi

        python_version=$($ppath -V | awk '{ print $2 }')
        if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then
            if [ $(version $python_version) -le $(version "$MAXIMUM_PYTHON_VERSION") ]; then
                PYTHON=$ppath
                break
            fi
        fi
    fi
done

if [ -z "$PYTHON" ]; then
    echo "A suitable Python interpreter could not be found"
    echo "Please install Python $MINIMUM_PYTHON_VERSION or higher (maximum $MAXIMUM_PYTHON_VERSION) before running this script. See instructions at $INSTRUCTIONS for help."
    read -p "Press any key to exit"
    exit -1
fi

echo "For the best user experience we suggest enlarging or maximizing this window now."

exec $PYTHON ./lib/main.py ${@}
read -p "Press any key to exit"
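The removed launcher's version gate is a small awk trick worth noting: it zero-pads each dotted component so version strings compare as plain integers. A quick illustration of how it behaves (the function body is copied from the deleted script; the sample version is arbitrary):

```sh
version() { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; }

# 3.10.0   -> 3010000000
# 3.11.100 -> 3011100000
# so numeric -ge / -le tests bracket the supported range:
python_version=3.11.6
if [ "$(version "$python_version")" -ge "$(version 3.10.0)" ] && \
   [ "$(version "$python_version")" -le "$(version 3.11.100)" ]; then
    echo "supported"
fi
```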
Deleted file: installer library ("InvokeAI installer script", Copyright (c) 2023 Eugene Brodsky, 438 lines)
@@ -1,438 +0,0 @@
The removed module implemented the interactive installer. Its main pieces:
- Constants: SUPPORTED_PYTHON = ">=3.10.0,<=3.11.100"; INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"]; BOOTSTRAP_VENV_PREFIX = "invokeai-installer-tmp"; the docs URL and Discord URL.
- get_version_from_wheel_filename(wheel_filename): extracts a version from a wheel filename with the regex r"-(\d+\.\d+\.\d+)", raising ValueError otherwise.
- class Installer: refuses to run inside an activated venv; bootstrap() builds a temporary venv (mktemp_venv) and pip-installs the installer's own requirements; app_venv() creates the application's .venv, falling back to symlinks when copying executables over an existing environment fails; install() prompts for a version and destination (or accepts defaults with yes_to_all), creates the venv, and delegates to InvokeAiInstance before printing a success banner with the invoke.bat/invoke.sh launch path.
- class InvokeAiInstance: wraps a runtime directory plus venv; install() builds a pip command via plumbum ("install --require-virtualenv --force-reinstall --use-pep517", plus optional --find-links, --extra-index-url, and --pre for pre-releases), first uninstalling any stale xformers; install_user_scripts() copies the invoke.{bat,sh} launch templates into the runtime directory and makes them executable.
- Utilities: get_pip_from_venv(), upgrade_pip(), set_sys_path() (emulates venv activation for the bootstrap environment), get_github_releases() (queries https://api.github.com/repos/invoke-ai/InvokeAI/releases and splits stable releases from pre-releases), and get_torch_source() (maps the selected GPU to a PyTorch extra index URL: ROCm → rocm6.1, CPU → cpu, CUDA → cu124 with "[onnx-cuda]", CUDA+xFormers → cu124 with "[xformers,onnx-cuda]"; macOS and CPU-on-Windows fall back to PyPI).
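For reference, the pip invocation the removed InvokeAiInstance.install() assembled through plumbum was roughly equivalent to the following shell commands; the bracketed extras and the index URL came from get_torch_source(), and --pre was added only when a pre-release was requested (values shown here assume a CUDA + xformers selection):

```sh
# Rough shell equivalent of the deleted installer's package install step.
pip uninstall -yqq xformers
pip install \
  --require-virtualenv \
  --force-reinstall \
  --use-pep517 \
  "invokeai[xformers,onnx-cuda]==<VERSION>" \
  --extra-index-url "https://download.pytorch.org/whl/cu124"
```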
Deleted file: installer entry point ("InvokeAI Installer", 57 lines)
@@ -1,57 +0,0 @@
The removed entry point parsed command-line arguments (-r/--root destination, defaulting to $INVOKEAI_ROOT or ~/invokeai; -y/--yes/--yes-to-all to accept defaults; --find-links for a local directory of wheels; --wheel to install a specific wheel for troubleshooting or prerelease testing), constructed an Installer, and ran inst.install(**args.__dict__), catching KeyboardInterrupt to exit with a friendly message.
Deleted file: installer user-interaction module (Copyright (c) 2023 Eugene Brodsky, 342 lines)
@@ -1,342 +0,0 @@
The removed module drove the installer's terminal UI with rich and prompt-toolkit:
- welcome() and simple_banner() rendered the greeting panel, platform-specific help (_platform_specific_help covers the Xcode command-line tools on macOS, and WinLongPathsEnabled.reg plus the Visual C++ core libraries on Windows), and the latest stable and pre-release versions.
- installing_from_wheel() confirmed wheel-based installs; choose_version() let the user pick a release with fuzzy word completion, defaulting to the latest stable release.
- confirm_install() and dest_path() prompted for, confirmed, and created the destination directory with path completion, handling PermissionError and other OS errors with a retry prompt.
- GpuType (xformers / cuda / rocm / cpu) and select_gpu() asked which accelerator to install for, offering NVIDIA RTX 3060 or newer (CUDA), NVIDIA RTX 20xx or older (CUDA+xFormers), AMD (ROCm, Linux only), or CPU-only, and defaulting to CPU on non-Darwin arm64.
- windows_long_paths_registry() displayed the registry fix that enables long paths on Windows (noted in the source as not yet working correctly).
@@ -1,52 +0,0 @@
InvokeAI

Project homepage: https://github.com/invoke-ai/InvokeAI

Preparations:

You will need to install Python 3.10 or 3.11 for this installer
to work. Instructions are given here:
https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/

Before you start the installer, please open up your system's command
line window (Terminal on macOS and Linux, or Command Prompt on
Windows) and type the command:

python --version

If all is well, it will print "Python 3.X.X", where the version number
is at least 3.10.*, and not higher than 3.11.*.

If this works, check the version of the Python package manager, pip:

pip --version

You should get a message indicating that the pip package installer
was derived from Python 3.10 or 3.11. For example:
"pip 22.0.1 from /usr/bin/pip (python 3.10)"

Long Paths on Windows:

If you are on Windows, you will need to enable Windows Long Paths to
run InvokeAI successfully. If you're not sure what this is, you
almost certainly need to do this.

Simply double-click the "WinLongPathsEnabled.reg" file located in
this directory, and approve the Windows warnings. Note that you will
need to have admin privileges in order to do this.

Launching the installer:

Windows: double-click the 'install.bat' file (while keeping it inside
the InvokeAI-Installer folder).

Linux and Mac: Please open the terminal application and run
'./install.sh' (while keeping it inside the InvokeAI-Installer
folder).

The installer will create a directory of your choice and install the
InvokeAI application within it. This directory contains everything you
need to run InvokeAI. Once InvokeAI is up and running, you may delete
the InvokeAI-Installer folder at your convenience.

For more information, please see
https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/
|
||||
@@ -1,54 +0,0 @@
@echo off

PUSHD "%~dp0"
setlocal

call .venv\Scripts\activate.bat
set INVOKEAI_ROOT=.

:start
echo Desired action:
echo 1. Generate images with the browser-based interface
echo 2. Open the developer console
echo 3. Command-line help
echo Q - Quit
echo.
echo To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest
echo.
set /P choice="Please enter 1-3, Q: [1] "
if not defined choice set choice=1
IF /I "%choice%" == "1" (
    echo Starting the InvokeAI browser-based UI...
    python .venv\Scripts\invokeai-web.exe %*
) ELSE IF /I "%choice%" == "2" (
    echo Developer Console
    echo Python command is:
    where python
    echo Python version is:
    python --version
    echo *************************
    echo You are now in the system shell, with the local InvokeAI Python virtual environment activated,
    echo so that you can troubleshoot this InvokeAI installation as necessary.
    echo *************************
    echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
    call cmd /k
) ELSE IF /I "%choice%" == "3" (
    echo Displaying command line help...
    python .venv\Scripts\invokeai-web.exe --help %*
    pause
    exit /b
) ELSE IF /I "%choice%" == "q" (
    echo Goodbye!
    goto ending
) ELSE (
    echo Invalid selection
    pause
    exit /b
)
goto start

endlocal
pause

:ending
exit /b
|
||||
@@ -1,87 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# MIT License
|
||||
|
||||
# Coauthored by Lincoln Stein, Eugene Brodsky and Joshua Kimsey
|
||||
# Copyright 2023, The InvokeAI Development Team
|
||||
|
||||
####
|
||||
# This launch script assumes that:
|
||||
# 1. it is located in the runtime directory,
|
||||
# 2. the .venv is also located in the runtime directory and is named exactly that
|
||||
#
|
||||
# If both of the above are not true, this script will likely not work as intended.
|
||||
# Activate the virtual environment and run `invoke.py` directly.
|
||||
####
|
||||
|
||||
set -eu
|
||||
|
||||
# Ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
scriptdir=$(dirname $(readlink -f "$0"))
|
||||
cd "$scriptdir"
|
||||
|
||||
. .venv/bin/activate
|
||||
|
||||
export INVOKEAI_ROOT="$scriptdir"
|
||||
|
||||
# Stash the CLI args - when we prompt for user input, `$@` is overwritten
|
||||
PARAMS=$@
|
||||
|
||||
# This setting allows torch to fall back to CPU for operations that are not supported by MPS on macOS.
|
||||
if [ "$(uname -s)" == "Darwin" ]; then
|
||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||
fi
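For reference, the effect of this export can be reproduced from Python as well; a hedged sketch (PYTORCH_ENABLE_MPS_FALLBACK is a documented PyTorch environment variable, but the tensor code below is purely illustrative):

import os

# Must be set before torch initializes the MPS backend.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch

if torch.backends.mps.is_available():
    x = torch.ones(2, device="mps")
    # With the fallback enabled, ops lacking MPS kernels run on CPU instead of raising.
    print(x * 2)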
|
||||
|
||||
# Primary function for the case statement to determine user input
|
||||
do_choice() {
|
||||
case $1 in
|
||||
1)
|
||||
clear
|
||||
printf "Generate images with a browser-based interface\n"
|
||||
invokeai-web $PARAMS
|
||||
;;
|
||||
2)
|
||||
clear
|
||||
printf "Open the developer console\n"
|
||||
file_name=$(basename "${BASH_SOURCE[0]}")
|
||||
bash --init-file "$file_name"
|
||||
;;
|
||||
3)
|
||||
clear
|
||||
printf "Command-line help\n"
|
||||
invokeai-web --help
|
||||
;;
|
||||
*)
|
||||
clear
|
||||
printf "Exiting...\n"
|
||||
exit
|
||||
;;
|
||||
esac
|
||||
clear
|
||||
}
|
||||
|
||||
# Command-line interface for launching Invoke functions
|
||||
do_line_input() {
|
||||
clear
|
||||
printf "What would you like to do?\n"
|
||||
printf "1: Generate images using the browser-based interface\n"
|
||||
printf "2: Open the developer console\n"
|
||||
printf "3: Command-line help\n"
|
||||
printf "Q: Quit\n\n"
|
||||
printf "To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest\n\n"
|
||||
read -p "Please enter 1-4, Q: [1] " yn
|
||||
choice=${yn:='1'}
|
||||
do_choice $choice
|
||||
clear
|
||||
}
|
||||
|
||||
# Main IF statement for launching Invoke, and for checking if the user is in the developer console
|
||||
if [ "$0" != "bash" ]; then
|
||||
while true; do
|
||||
do_line_input
|
||||
done
|
||||
else # in developer console
|
||||
python --version
|
||||
printf "Press ^D to exit\n"
|
||||
export PS1="(InvokeAI) \u@\h \w> "
|
||||
fi
|
||||
@@ -37,7 +37,13 @@ from invokeai.app.services.style_preset_records.style_preset_records_sqlite impo
|
||||
from invokeai.app.services.urls.urls_default import LocalUrlService
|
||||
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
FLUXConditioningInfo,
|
||||
SD3ConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
@@ -101,10 +107,25 @@ class ApiDependencies:
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
tensors = ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True)
|
||||
ObjectSerializerDisk[torch.Tensor](
|
||||
output_folder / "tensors",
|
||||
safe_globals=[torch.Tensor],
|
||||
ephemeral=True,
|
||||
),
|
||||
max_cache_size=0,
|
||||
)
|
||||
conditioning = ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||
ObjectSerializerDisk[ConditioningFieldData](
|
||||
output_folder / "conditioning",
|
||||
safe_globals=[
|
||||
ConditioningFieldData,
|
||||
BasicConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
FLUXConditioningInfo,
|
||||
SD3ConditioningInfo,
|
||||
],
|
||||
ephemeral=True,
|
||||
),
|
||||
)
|
||||
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
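The safe_globals lists given to ObjectSerializerDisk above are what allow torch to reconstruct these classes when loading with weights_only semantics. A hedged, self-contained sketch of that mechanism (assumes torch >= 2.4; the Payload class is an illustrative stand-in, not an InvokeAI type):

import dataclasses

import torch

@dataclasses.dataclass
class Payload:
    embedding: torch.Tensor

# Allow-list the class so the weights_only unpickler may rebuild it.
torch.serialization.add_safe_globals([Payload])

torch.save(Payload(embedding=torch.zeros(4)), "payload.pt")
loaded = torch.load("payload.pt", weights_only=True)  # no arbitrary code execution
print(loaded.embedding.shape)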
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
|
||||
from fastapi import Body, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||
@@ -15,6 +15,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
CancelByDestinationResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
FieldIdentifier,
|
||||
PruneResult,
|
||||
RetryItemsResult,
|
||||
SessionQueueCountsByDestination,
|
||||
@@ -34,6 +35,12 @@ class SessionQueueAndProcessorStatus(BaseModel):
|
||||
processor: SessionProcessorStatus
|
||||
|
||||
|
||||
class ValidationRunData(BaseModel):
|
||||
workflow_id: str = Field(description="The id of the workflow being published.")
|
||||
input_fields: list[FieldIdentifier] = Body(description="The input fields for the published workflow")
|
||||
output_fields: list[FieldIdentifier] = Body(description="The output fields for the published workflow")
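A ValidationRunData payload might be constructed like this (a sketch with made-up IDs; the FieldIdentifier import path is taken from the session_queue_common import list earlier in this diff, everything else is an assumption):

from invokeai.app.services.session_queue.session_queue_common import FieldIdentifier

run_data = ValidationRunData(
    workflow_id="wf_123",  # made-up workflow id
    input_fields=[FieldIdentifier(kind="input", node_id="positive_prompt", field_name="value")],
    output_fields=[FieldIdentifier(kind="output", node_id="save_image", field_name="image")],
)
print(run_data.model_dump_json())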
|
||||
|
||||
|
||||
@session_queue_router.post(
|
||||
"/{queue_id}/enqueue_batch",
|
||||
operation_id="enqueue_batch",
|
||||
@@ -45,6 +52,10 @@ async def enqueue_batch(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch: Batch = Body(description="Batch to process"),
|
||||
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
|
||||
validation_run_data: Optional[ValidationRunData] = Body(
|
||||
default=None,
|
||||
description="The validation run data to use for this batch. This is only used if this is a validation run.",
|
||||
),
|
||||
) -> EnqueueBatchResult:
|
||||
"""Processes a batch and enqueues the output graphs for execution."""
|
||||
|
||||
|
||||
@@ -106,6 +106,7 @@ async def list_workflows(
|
||||
tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"),
|
||||
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
|
||||
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
|
||||
is_published: Optional[bool] = Query(default=None, description="Whether to include/exclude published workflows"),
|
||||
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
|
||||
"""Gets a page of workflows"""
|
||||
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
|
||||
@@ -118,6 +119,7 @@ async def list_workflows(
|
||||
categories=categories,
|
||||
tags=tags,
|
||||
has_been_opened=has_been_opened,
|
||||
is_published=is_published,
|
||||
)
|
||||
for workflow in workflows.items:
|
||||
workflows_with_thumbnails.append(
|
||||
|
||||
128
invokeai/app/invocations/controlnet.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# Invocations for ControlNet image preprocessors
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self):
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
|
||||
@invocation_output("control_output")
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
|
||||
# Outputs
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self) -> "ControlNetInvocation":
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
return ControlOutput(
|
||||
control=ControlField(
|
||||
image=self.image,
|
||||
control_model=self.control_model,
|
||||
control_weight=self.control_weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
control_mode=self.control_mode,
|
||||
resize_mode=self.resize_mode,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"heuristic_resize",
|
||||
title="Heuristic Resize",
|
||||
tags=["image, controlnet"],
|
||||
category="image",
|
||||
version="1.0.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class HeuristicResizeInvocation(BaseInvocation):
|
||||
"""Resize an image using a heuristic method. Preserves edge maps."""
|
||||
|
||||
image: ImageField = InputField(description="The image to resize")
|
||||
width: int = InputField(default=512, ge=1, description="The width to resize to (px)")
|
||||
height: int = InputField(default=512, ge=1, description="The height to resize to (px)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
np_img = pil_to_np(image)
|
||||
np_resized = heuristic_resize(np_img, (self.width, self.height))
|
||||
resized = np_to_pil(np_resized)
|
||||
image_dto = context.images.save(image=resized)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -1,716 +0,0 @@
|
||||
# Invocations for ControlNet image preprocessors
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||
from builtins import bool, float
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from controlnet_aux import (
|
||||
ContentShuffleDetector,
|
||||
LeresDetector,
|
||||
MediapipeFaceDetector,
|
||||
MidasDetector,
|
||||
MLSDdetector,
|
||||
NormalBaeDetector,
|
||||
PidiNetDetector,
|
||||
SamDetector,
|
||||
ZoeDetector,
|
||||
)
|
||||
from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import DepthEstimationPipeline
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIType,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||
from invokeai.backend.image_util.canny import get_canny_edges
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
||||
from invokeai.backend.image_util.hed import HEDProcessor
|
||||
from invokeai.backend.image_util.lineart import LineartProcessor
|
||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self):
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
|
||||
@invocation_output("control_output")
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
|
||||
# Outputs
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self) -> "ControlNetInvocation":
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
return ControlOutput(
|
||||
control=ControlField(
|
||||
image=self.image,
|
||||
control_model=self.control_model,
|
||||
control_weight=self.control_weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
control_mode=self.control_mode,
|
||||
resize_mode=self.resize_mode,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# This invocation exists for other invocations to subclass it - do not register with @invocation!
|
||||
class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# superclass just passes through image without processing
|
||||
return image
|
||||
|
||||
def load_image(self, context: InvocationContext) -> Image.Image:
|
||||
# allows override for any special formatting specific to the preprocessor
|
||||
return context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
self._context = context
|
||||
raw_image = self.load_image(context)
|
||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||
processed_image = self.run_processor(raw_image)
|
||||
|
||||
# currently can't see processed image in node UI without a showImage node,
|
||||
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
||||
image_dto = context.images.save(image=processed_image)
|
||||
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
processed_image_field = ImageField(image_name=image_dto.image_name)
|
||||
return ImageOutput(
|
||||
image=processed_image_field,
|
||||
# width=processed_image.width,
|
||||
width=image_dto.width,
|
||||
# height=processed_image.height,
|
||||
height=image_dto.height,
|
||||
# mode=processed_image.mode,
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"canny_image_processor",
|
||||
title="Canny Processor",
|
||||
tags=["controlnet", "canny"],
|
||||
category="controlnet",
|
||||
version="1.3.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
low_threshold: int = InputField(
|
||||
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
high_threshold: int = InputField(
|
||||
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
|
||||
def load_image(self, context: InvocationContext) -> Image.Image:
|
||||
# Keep alpha channel for Canny processing to detect edges of transparent areas
|
||||
return context.images.get_pil(self.image.image_name, "RGBA")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
processed_image = get_canny_edges(
|
||||
image,
|
||||
self.low_threshold,
|
||||
self.high_threshold,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"hed_image_processor",
|
||||
title="HED (softedge) Processor",
|
||||
tags=["controlnet", "hed", "softedge"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies HED edge detection to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
hed_processor = HEDProcessor()
|
||||
processed_image = hed_processor.run(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"lineart_image_processor",
|
||||
title="Lineart Processor",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
lineart_processor = LineartProcessor()
|
||||
processed_image = lineart_processor.run(
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"lineart_anime_image_processor",
|
||||
title="Lineart Anime Processor",
|
||||
tags=["controlnet", "lineart", "anime"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art anime processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
processor = LineartAnimeProcessor()
|
||||
processed_image = processor.run(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"midas_depth_image_processor",
|
||||
title="Midas Depth Processor",
|
||||
tags=["controlnet", "midas"],
|
||||
category="controlnet",
|
||||
version="1.2.4",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Midas depth processing to image"""
|
||||
|
||||
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
|
||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = midas_processor(
|
||||
image,
|
||||
a=np.pi * self.a_mult,
|
||||
bg_th=self.bg_th,
|
||||
image_resolution=self.image_resolution,
|
||||
detect_resolution=self.detect_resolution,
|
||||
# dept_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal=self.depth_and_normal,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"normalbae_image_processor",
|
||||
title="Normal BAE Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies NormalBae processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = normalbae_processor(
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"mlsd_image_processor",
|
||||
title="MLSD Processor",
|
||||
tags=["controlnet", "mlsd"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies MLSD processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = mlsd_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
thr_v=self.thr_v,
|
||||
thr_d=self.thr_d,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"pidi_image_processor",
|
||||
title="PIDI Processor",
|
||||
tags=["controlnet", "pidi"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies PIDI processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = pidi_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"content_shuffle_image_processor",
|
||||
title="Content Shuffle Processor",
|
||||
tags=["controlnet", "contentshuffle"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies content shuffle processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
h: int = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
content_shuffle_processor = ContentShuffleDetector()
|
||||
processed_image = content_shuffle_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
h=self.h,
|
||||
w=self.w,
|
||||
f=self.f,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||
@invocation(
|
||||
"zoe_depth_image_processor",
|
||||
title="Zoe (Depth) Processor",
|
||||
tags=["controlnet", "zoe", "depth"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = zoe_depth_processor(image)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"mediapipe_face_processor",
|
||||
title="Mediapipe Face Processor",
|
||||
tags=["controlnet", "mediapipe", "face"],
|
||||
category="controlnet",
|
||||
version="1.2.4",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
|
||||
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
|
||||
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
mediapipe_face_processor = MediapipeFaceDetector()
|
||||
processed_image = mediapipe_face_processor(
|
||||
image,
|
||||
max_faces=self.max_faces,
|
||||
min_confidence=self.min_confidence,
|
||||
image_resolution=self.image_resolution,
|
||||
detect_resolution=self.detect_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"leres_image_processor",
|
||||
title="Leres (Depth) Processor",
|
||||
tags=["controlnet", "leres", "depth"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies leres processing to image"""
|
||||
|
||||
thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
|
||||
thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
|
||||
boost: bool = InputField(default=False, description="Whether to use boost mode")
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = leres_processor(
|
||||
image,
|
||||
thr_a=self.thr_a,
|
||||
thr_b=self.thr_b,
|
||||
boost=self.boost,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"tile_image_processor",
|
||||
title="Tile Resample Processor",
|
||||
tags=["controlnet", "tile"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Tile resampler processor"""
|
||||
|
||||
# res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
|
||||
down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
|
||||
|
||||
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
|
||||
def tile_resample(
|
||||
self,
|
||||
np_img: np.ndarray,
|
||||
res=512, # never used?
|
||||
down_sampling_rate=1.0,
|
||||
):
|
||||
np_img = HWC3(np_img)
|
||||
if down_sampling_rate < 1.1:
|
||||
return np_img
|
||||
H, W, C = np_img.shape
|
||||
H = int(float(H) / float(down_sampling_rate))
|
||||
W = int(float(W) / float(down_sampling_rate))
|
||||
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
|
||||
return np_img
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
np_img = np.array(image, dtype=np.uint8)
|
||||
processed_np_image = self.tile_resample(
|
||||
np_img,
|
||||
# res=self.tile_size,
|
||||
down_sampling_rate=self.down_sampling_rate,
|
||||
)
|
||||
processed_image = Image.fromarray(processed_np_image)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"segment_anything_processor",
|
||||
title="Segment Anything Processor",
|
||||
tags=["controlnet", "segmentanything"],
|
||||
category="controlnet",
|
||||
version="1.2.4",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies segment anything processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
||||
"ybelkada/segment-anything", subfolder="checkpoints"
|
||||
)
|
||||
np_img = np.array(image, dtype=np.uint8)
|
||||
processed_image = segment_anything_processor(
|
||||
np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class SamDetectorReproducibleColors(SamDetector):
|
||||
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
|
||||
# base class show_anns() method randomizes colors,
|
||||
# which seems to also lead to non-reproducible image generation
|
||||
# so using ADE20k color palette instead
|
||||
def show_anns(self, anns: List[Dict]):
|
||||
if len(anns) == 0:
|
||||
return
|
||||
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
|
||||
h, w = anns[0]["segmentation"].shape
|
||||
final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
|
||||
palette = ade_palette()
|
||||
for i, ann in enumerate(sorted_anns):
|
||||
m = ann["segmentation"]
|
||||
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
|
||||
# doing modulo just in case number of annotated regions exceeds number of colors in palette
|
||||
ann_color = palette[i % len(palette)]
|
||||
img[:, :] = ann_color
|
||||
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
|
||||
return np.array(final_img, dtype=np.uint8)
|
||||
|
||||
|
||||
@invocation(
|
||||
"color_map_image_processor",
|
||||
title="Color Map Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a color map from the provided image"""
|
||||
|
||||
color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
np_image = np.array(image, dtype=np.uint8)
|
||||
height, width = np_image.shape[:2]
|
||||
|
||||
width_tile_size = min(self.color_map_tile_size, width)
|
||||
height_tile_size = min(self.color_map_tile_size, height)
|
||||
|
||||
color_map = cv2.resize(
|
||||
np_image,
|
||||
(width // width_tile_size, height // height_tile_size),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
|
||||
color_map = Image.fromarray(color_map)
|
||||
return color_map
|
||||
|
||||
|
||||
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
|
||||
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
|
||||
DEPTH_ANYTHING_MODELS = {
|
||||
"large": "LiheYoung/depth-anything-large-hf",
|
||||
"base": "LiheYoung/depth-anything-base-hf",
|
||||
"small": "LiheYoung/depth-anything-small-hf",
|
||||
"small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
"depth_anything_image_processor",
|
||||
title="Depth Anything Processor",
|
||||
tags=["controlnet", "depth", "depth anything"],
|
||||
category="controlnet",
|
||||
version="1.1.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a depth map based on the Depth Anything algorithm"""
|
||||
|
||||
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
|
||||
default="small_v2", description="The size of the depth model to use"
|
||||
)
|
||||
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
def load_depth_anything(model_path: Path):
|
||||
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
|
||||
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
|
||||
return DepthAnythingPipeline(depth_anything_pipeline)
|
||||
|
||||
with self._context.models.load_remote_model(
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
|
||||
) as depth_anything_detector:
|
||||
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
|
||||
depth_map = depth_anything_detector.generate_depth(image)
|
||||
|
||||
# Resizing to user target specified size
|
||||
new_height = int(image.size[1] * (self.resolution / image.size[0]))
|
||||
depth_map = depth_map.resize((self.resolution, new_height))
|
||||
|
||||
return depth_map
|
||||
|
||||
|
||||
@invocation(
|
||||
"dw_openpose_image_processor",
|
||||
title="DW Openpose Image Processor",
|
||||
tags=["controlnet", "dwpose", "openpose"],
|
||||
category="controlnet",
|
||||
version="1.1.1",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates an openpose pose from an image using DWPose"""
|
||||
|
||||
draw_body: bool = InputField(default=True)
|
||||
draw_face: bool = InputField(default=False)
|
||||
draw_hands: bool = InputField(default=False)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
|
||||
onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
|
||||
|
||||
dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
|
||||
processed_image = dw_openpose(
|
||||
image,
|
||||
draw_face=self.draw_face,
|
||||
draw_hands=self.draw_hands,
|
||||
draw_body=self.draw_body,
|
||||
resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"heuristic_resize",
|
||||
title="Heuristic Resize",
|
||||
tags=["image, controlnet"],
|
||||
category="image",
|
||||
version="1.0.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class HeuristicResizeInvocation(BaseInvocation):
|
||||
"""Resize an image using a heuristic method. Preserves edge maps."""
|
||||
|
||||
image: ImageField = InputField(description="The image to resize")
|
||||
width: int = InputField(default=512, ge=1, description="The width to resize to (px)")
|
||||
height: int = InputField(default=512, ge=1, description="The height to resize to (px)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
np_img = pil_to_np(image)
|
||||
np_resized = heuristic_resize(np_img, (self.width, self.height))
|
||||
resized = np_to_pil(np_resized)
|
||||
image_dto = context.images.save(image=resized)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -22,7 +22,7 @@ from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.controlnet import ControlField
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
DenoiseMaskField,
|
||||
|
||||
@@ -4,7 +4,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector2
|
||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -25,20 +25,20 @@ class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
onnx_det_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_det())
|
||||
onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_pose())
|
||||
onnx_det_path = context.models.download_and_cache_model(DWOpenposeDetector.get_model_url_det())
|
||||
onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector.get_model_url_pose())
|
||||
|
||||
loaded_session_det = context.models.load_local_model(
|
||||
onnx_det_path, DWOpenposeDetector2.create_onnx_inference_session
|
||||
onnx_det_path, DWOpenposeDetector.create_onnx_inference_session
|
||||
)
|
||||
loaded_session_pose = context.models.load_local_model(
|
||||
onnx_pose_path, DWOpenposeDetector2.create_onnx_inference_session
|
||||
onnx_pose_path, DWOpenposeDetector.create_onnx_inference_session
|
||||
)
|
||||
|
||||
with loaded_session_det as session_det, loaded_session_pose as session_pose:
|
||||
assert isinstance(session_det, ort.InferenceSession)
|
||||
assert isinstance(session_pose, ort.InferenceSession)
|
||||
detector = DWOpenposeDetector2(session_det=session_det, session_pose=session_pose)
|
||||
detector = DWOpenposeDetector(session_det=session_det, session_pose=session_pose)
|
||||
detected_image = detector.run(
|
||||
image,
|
||||
draw_face=self.draw_face,
|
||||
|
||||
@@ -14,7 +14,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField, ControlNetInvocation
|
||||
from invokeai.app.invocations.controlnet import ControlField, ControlNetInvocation
|
||||
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
|
||||
@@ -9,7 +9,7 @@ from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.controlnet import ControlField
|
||||
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
|
||||
@@ -302,7 +302,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
# We catch this error so that the app can still run if there are invalid model configs in the database.
|
||||
# One reason that an invalid model config might be in the database is if someone had to rollback from a
|
||||
# newer version of the app that added a new model type.
|
||||
self._logger.warning(f"Found an invalid model config in the database. Ignoring this model. ({row[0]})")
|
||||
row_data = f"{row[0][:64]}..." if len(row[0]) > 64 else row[0]
|
||||
self._logger.warning(
|
||||
f"Found an invalid model config in the database. Ignoring this model. ({row_data})"
|
||||
)
|
||||
else:
|
||||
results.append(model_config)
|
||||
|
||||
|
||||
@@ -21,10 +21,16 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
"""Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.
|
||||
|
||||
:param output_dir: The folder where the serialized objects will be stored
|
||||
:param safe_globals: A list of types to be added to the safe globals for torch serialization
|
||||
:param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir: Path, ephemeral: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
output_dir: Path,
|
||||
safe_globals: list[type],
|
||||
ephemeral: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._ephemeral = ephemeral
|
||||
self._base_output_dir = output_dir
|
||||
@@ -42,6 +48,8 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir
|
||||
self.__obj_class_name: Optional[str] = None
|
||||
|
||||
torch.serialization.add_safe_globals(safe_globals) if safe_globals else None
|
||||
|
||||
def load(self, name: str) -> T:
|
||||
file_path = self._get_path(name)
|
||||
try:
|
||||
|
||||
@@ -201,6 +201,12 @@ def get_workflow(queue_item_dict: dict) -> Optional[WorkflowWithoutID]:
|
||||
return None
|
||||
|
||||
|
||||
class FieldIdentifier(BaseModel):
|
||||
kind: Literal["input", "output"] = Field(description="The kind of field")
|
||||
node_id: str = Field(description="The ID of the node")
|
||||
field_name: str = Field(description="The name of the field")
|
||||
|
||||
|
||||
class SessionQueueItemWithoutGraph(BaseModel):
|
||||
"""Session queue item without the full graph. Used for serialization."""
|
||||
|
||||
@@ -237,6 +243,20 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
retried_from_item_id: Optional[int] = Field(
|
||||
default=None, description="The item_id of the queue item that this item was retried from"
|
||||
)
|
||||
is_api_validation_run: bool = Field(
|
||||
default=False,
|
||||
description="Whether this queue item is an API validation run.",
|
||||
)
|
||||
published_workflow_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The ID of the published workflow associated with this queue item",
|
||||
)
|
||||
api_input_fields: Optional[list[FieldIdentifier]] = Field(
|
||||
default=None, description="The fields that were used as input to the API"
|
||||
)
|
||||
api_output_fields: Optional[list[FieldIdentifier]] = Field(
|
||||
default=None, description="The nodes that were used as output from the API"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
|
||||
|
||||
@@ -47,6 +47,7 @@ class WorkflowRecordsStorageBase(ABC):
|
||||
query: Optional[str],
|
||||
tags: Optional[list[str]],
|
||||
has_been_opened: Optional[bool],
|
||||
is_published: Optional[bool],
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
"""Gets many workflows."""
|
||||
pass
|
||||
@@ -56,6 +57,7 @@ class WorkflowRecordsStorageBase(ABC):
|
||||
self,
|
||||
categories: list[WorkflowCategory],
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
"""Gets a dictionary of counts for each of the provided categories."""
|
||||
pass
|
||||
@@ -66,6 +68,7 @@ class WorkflowRecordsStorageBase(ABC):
|
||||
tags: list[str],
|
||||
categories: Optional[list[WorkflowCategory]] = None,
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
"""Gets a dictionary of counts for each of the provided tags."""
|
||||
pass
|
||||
|
||||
@@ -67,6 +67,7 @@ class WorkflowWithoutID(BaseModel):
|
||||
# This is typed as optional to prevent errors when pulling workflows from the DB. The frontend adds a default form if
|
||||
# it is None.
|
||||
form: dict[str, JsonValue] | None = Field(default=None, description="The form of the workflow.")
|
||||
is_published: bool | None = Field(default=None, description="Whether the workflow is published or not.")
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
@@ -101,6 +102,7 @@ class WorkflowRecordDTOBase(BaseModel):
|
||||
opened_at: Optional[Union[datetime.datetime, str]] = Field(
|
||||
default=None, description="The opened timestamp of the workflow."
|
||||
)
|
||||
is_published: bool | None = Field(default=None, description="Whether the workflow is published or not.")
|
||||
|
||||
|
||||
class WorkflowRecordDTO(WorkflowRecordDTOBase):
|
||||
|
||||
@@ -119,6 +119,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
        query: Optional[str] = None,
        tags: Optional[list[str]] = None,
        has_been_opened: Optional[bool] = None,
        is_published: Optional[bool] = None,
    ) -> PaginatedResults[WorkflowRecordListItemDTO]:
        # sanitize!
        assert order_by in WorkflowRecordOrderBy
@@ -241,6 +242,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
        tags: list[str],
        categories: Optional[list[WorkflowCategory]] = None,
        has_been_opened: Optional[bool] = None,
        is_published: Optional[bool] = None,
    ) -> dict[str, int]:
        if not tags:
            return {}
@@ -292,6 +294,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
        self,
        categories: list[WorkflowCategory],
        has_been_opened: Optional[bool] = None,
        is_published: Optional[bool] = None,
    ) -> dict[str, int]:
        cursor = self._conn.cursor()
        result: dict[str, int] = {}
@@ -65,9 +65,6 @@ def apply_monkeypatches() -> None:

    import invokeai.backend.util.hotfixes  # noqa: F401 (monkeypatching on import)

    if torch.backends.mps.is_available():
        import invokeai.backend.util.mps_fixes  # noqa: F401 (monkeypatching on import)


def register_mime_types() -> None:
    """Register additional mime types for windows."""
@@ -5,62 +5,14 @@ import huggingface_hub
import numpy as np
import onnxruntime as ort
import torch
from controlnet_aux.util import resize_image
from PIL import Image

from invokeai.backend.image_util.dw_openpose.onnxdet import inference_detector
from invokeai.backend.image_util.dw_openpose.onnxpose import inference_pose
from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
from invokeai.backend.image_util.util import np_to_pil
from invokeai.backend.util.devices import TorchDevice

DWPOSE_MODELS = {
    "yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
    "dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
}


def draw_pose(
    pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
    H: int,
    W: int,
    draw_face: bool = True,
    draw_body: bool = True,
    draw_hands: bool = True,
    resolution: int = 512,
) -> Image.Image:
    bodies = pose["bodies"]
    faces = pose["faces"]
    hands = pose["hands"]

    assert isinstance(bodies, dict)
    candidate = bodies["candidate"]

    assert isinstance(bodies, dict)
    subset = bodies["subset"]

    canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)

    if draw_body:
        canvas = draw_bodypose(canvas, candidate, subset)

    if draw_hands:
        assert isinstance(hands, np.ndarray)
        canvas = draw_handpose(canvas, hands)

    if draw_face:
        assert isinstance(hands, np.ndarray)
        canvas = draw_facepose(canvas, faces)  # type: ignore

    dwpose_image: Image.Image = resize_image(
        canvas,
        resolution,
    )
    dwpose_image = Image.fromarray(dwpose_image)

    return dwpose_image
class DWOpenposeDetector:
    """
@@ -68,62 +20,6 @@ class DWOpenposeDetector:
    Credits: https://github.com/IDEA-Research/DWPose
    """

    def __init__(self, onnx_det: Path, onnx_pose: Path) -> None:
        self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose)

    def __call__(
        self,
        image: Image.Image,
        draw_face: bool = False,
        draw_body: bool = True,
        draw_hands: bool = False,
        resolution: int = 512,
    ) -> Image.Image:
        np_image = np.array(image)
        H, W, C = np_image.shape

        with torch.no_grad():
            candidate, subset = self.pose_estimation(np_image)
            nums, keys, locs = candidate.shape
            candidate[..., 0] /= float(W)
            candidate[..., 1] /= float(H)
            body = candidate[:, :18].copy()
            body = body.reshape(nums * 18, locs)
            score = subset[:, :18]
            for i in range(len(score)):
                for j in range(len(score[i])):
                    if score[i][j] > 0.3:
                        score[i][j] = int(18 * i + j)
                    else:
                        score[i][j] = -1

            un_visible = subset < 0.3
            candidate[un_visible] = -1

            # foot = candidate[:, 18:24]

            faces = candidate[:, 24:92]

            hands = candidate[:, 92:113]
            hands = np.vstack([hands, candidate[:, 113:]])

            bodies = {"candidate": body, "subset": score}
            pose = {"bodies": bodies, "hands": hands, "faces": faces}

            return draw_pose(
                pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
            )


class DWOpenposeDetector2:
    """
    Code from the original implementation of the DW Openpose Detector.
    Credits: https://github.com/IDEA-Research/DWPose

    This implementation is similar to DWOpenposeDetector, with some alterations to allow the onnx models to be loaded
    and managed by the model manager.
    """

    hf_repo_id = "yzd-v/DWPose"
    hf_filename_onnx_det = "yolox_l.onnx"
    hf_filename_onnx_pose = "dw-ll_ucoco_384.onnx"
@@ -213,7 +109,7 @@ class DWOpenposeDetector2:
            bodies = {"candidate": body, "subset": score}
            pose = {"bodies": bodies, "hands": hands, "faces": faces}

            return DWOpenposeDetector2.draw_pose(
            return DWOpenposeDetector.draw_pose(
                pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body
            )
@@ -3,7 +3,6 @@
import math

import cv2
import matplotlib
import numpy as np
import numpy.typing as npt

@@ -127,11 +126,13 @@ def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt:
        x2 = int(x2 * W)
        y2 = int(y2 * H)
        if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
            hsv_color = np.array([[[ie / float(len(edges)) * 180, 255, 255]]], dtype=np.uint8)
            rgb_color = cv2.cvtColor(hsv_color, cv2.COLOR_HSV2RGB)[0, 0]
            cv2.line(
                canvas,
                (x1, y1),
                (x2, y2),
                matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
                rgb_color.tolist(),
                thickness=2,
            )
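The change above swaps the per-edge hand color from a matplotlib HSV conversion to an OpenCV one (the import hunk above loses one line, presumably the matplotlib import). A quick standalone check of the equivalence, written for this note rather than taken from the repo; matplotlib expects hue in [0, 1] while OpenCV's uint8 HSV expects hue in [0, 180):

import cv2
import matplotlib.colors
import numpy as np

h = 3 / 20.0  # e.g. edge index 3 of 20
old_color = matplotlib.colors.hsv_to_rgb([h, 1.0, 1.0]) * 255
hsv = np.array([[[h * 180, 255, 255]]], dtype=np.uint8)
new_color = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)[0, 0]
print(old_color.round(), new_color)  # the two agree up to uint8 rounding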
@@ -1,44 +0,0 @@
# Code from the original DWPose Implementation: https://github.com/IDEA-Research/DWPose
# Modified pathing to suit Invoke


from pathlib import Path

import numpy as np
import onnxruntime as ort

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.image_util.dw_openpose.onnxdet import inference_detector
from invokeai.backend.image_util.dw_openpose.onnxpose import inference_pose
from invokeai.backend.util.devices import TorchDevice

config = get_config()


class Wholebody:
    def __init__(self, onnx_det: Path, onnx_pose: Path):
        device = TorchDevice.choose_torch_device()

        providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]

        self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
        self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)

    def __call__(self, oriImg):
        det_result = inference_detector(self.session_det, oriImg)
        keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)

        keypoints_info = np.concatenate((keypoints, scores[..., None]), axis=-1)
        # compute neck joint
        neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
        # neck score when visualizing pred
        neck[:, 2:4] = np.logical_and(keypoints_info[:, 5, 2:4] > 0.3, keypoints_info[:, 6, 2:4] > 0.3).astype(int)
        new_keypoints_info = np.insert(keypoints_info, 17, neck, axis=1)
        mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]
        openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]
        new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx]
        keypoints_info = new_keypoints_info

        keypoints, scores = keypoints_info[..., :2], keypoints_info[..., 2]

        return keypoints, scores
@@ -69,6 +69,9 @@ class SD3ConditioningInfo:

@dataclass
class ConditioningFieldData:
    # If you change this class, adding more types, you _must_ update the instantiation of ObjectSerializerDisk in
    # invokeai/app/api/dependencies.py, adding the types to the list of safe globals. If you do not, torch will be
    # unable to deserialize the object and will raise an error.
    conditionings: (
        List[BasicConditioningInfo]
        | List[SDXLConditioningInfo]
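The comment in this hunk refers to torch's weights_only unpickling allow-list. A minimal sketch of what registering these types as safe globals looks like; the module path is assumed here, and the real registration lives in invokeai/app/api/dependencies.py:

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (  # assumed module path
    BasicConditioningInfo,
    ConditioningFieldData,
    SDXLConditioningInfo,
)

# torch.load(..., weights_only=True) refuses to unpickle classes that are not allow-listed,
# so every conditioning type stored by ObjectSerializerDisk must be registered up front.
torch.serialization.add_safe_globals([ConditioningFieldData, BasicConditioningInfo, SDXLConditioningInfo])
data = torch.load("conditioning.pt", weights_only=True)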
@@ -1,245 +0,0 @@
|
||||
import math
|
||||
|
||||
import diffusers
|
||||
import torch
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
torch.empty = torch.zeros
|
||||
|
||||
|
||||
_torch_layer_norm = torch.nn.functional.layer_norm
|
||||
|
||||
|
||||
def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
||||
if input.device.type == "mps" and input.dtype == torch.float16:
|
||||
input = input.float()
|
||||
if weight is not None:
|
||||
weight = weight.float()
|
||||
if bias is not None:
|
||||
bias = bias.float()
|
||||
return _torch_layer_norm(input, normalized_shape, weight, bias, eps).half()
|
||||
else:
|
||||
return _torch_layer_norm(input, normalized_shape, weight, bias, eps)
|
||||
|
||||
|
||||
torch.nn.functional.layer_norm = new_layer_norm
|
||||
|
||||
|
||||
_torch_tensor_permute = torch.Tensor.permute
|
||||
|
||||
|
||||
def new_torch_tensor_permute(input, *dims):
|
||||
result = _torch_tensor_permute(input, *dims)
|
||||
if input.device == "mps" and input.dtype == torch.float16:
|
||||
result = result.contiguous()
|
||||
return result
|
||||
|
||||
|
||||
torch.Tensor.permute = new_torch_tensor_permute
|
||||
|
||||
|
||||
_torch_lerp = torch.lerp
|
||||
|
||||
|
||||
def new_torch_lerp(input, end, weight, *, out=None):
|
||||
if input.device.type == "mps" and input.dtype == torch.float16:
|
||||
input = input.float()
|
||||
end = end.float()
|
||||
if isinstance(weight, torch.Tensor):
|
||||
weight = weight.float()
|
||||
if out is not None:
|
||||
out_fp32 = torch.zeros_like(out, dtype=torch.float32)
|
||||
else:
|
||||
out_fp32 = None
|
||||
result = _torch_lerp(input, end, weight, out=out_fp32)
|
||||
if out is not None:
|
||||
out.copy_(out_fp32.half())
|
||||
del out_fp32
|
||||
return result.half()
|
||||
|
||||
else:
|
||||
return _torch_lerp(input, end, weight, out=out)
|
||||
|
||||
|
||||
torch.lerp = new_torch_lerp
|
||||
|
||||
|
||||
_torch_interpolate = torch.nn.functional.interpolate
|
||||
|
||||
|
||||
def new_torch_interpolate(
|
||||
input,
|
||||
size=None,
|
||||
scale_factor=None,
|
||||
mode="nearest",
|
||||
align_corners=None,
|
||||
recompute_scale_factor=None,
|
||||
antialias=False,
|
||||
):
|
||||
if input.device.type == "mps" and input.dtype == torch.float16:
|
||||
return _torch_interpolate(
|
||||
input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
|
||||
).half()
|
||||
else:
|
||||
return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
|
||||
|
||||
|
||||
torch.nn.functional.interpolate = new_torch_interpolate
|
||||
|
||||
# TODO: refactor it
|
||||
_SlicedAttnProcessor = diffusers.models.attention_processor.SlicedAttnProcessor
|
||||
|
||||
|
||||
class ChunkedSlicedAttnProcessor:
|
||||
r"""
|
||||
Processor for implementing sliced attention.
|
||||
|
||||
Args:
|
||||
slice_size (`int`, *optional*):
|
||||
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
||||
`attention_head_dim` must be a multiple of the `slice_size`.
|
||||
"""
|
||||
|
||||
def __init__(self, slice_size):
|
||||
assert isinstance(slice_size, int)
|
||||
slice_size = 1 # TODO: maybe implement chunking in batches too when enough memory
|
||||
self.slice_size = slice_size
|
||||
self._sliced_attn_processor = _SlicedAttnProcessor(slice_size)
|
||||
|
||||
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
if self.slice_size != 1 or attn.upcast_attention:
|
||||
return self._sliced_attn_processor(attn, hidden_states, encoder_hidden_states, attention_mask)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
dim = query.shape[-1]
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
batch_size_attention, query_tokens, _ = query.shape
|
||||
hidden_states = torch.zeros(
|
||||
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
||||
)
|
||||
|
||||
chunk_tmp_tensor = torch.empty(
|
||||
self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
||||
)
|
||||
|
||||
for i in range(batch_size_attention // self.slice_size):
|
||||
start_idx = i * self.slice_size
|
||||
end_idx = (i + 1) * self.slice_size
|
||||
|
||||
query_slice = query[start_idx:end_idx]
|
||||
key_slice = key[start_idx:end_idx]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||
|
||||
self.get_attention_scores_chunked(
|
||||
attn,
|
||||
query_slice,
|
||||
key_slice,
|
||||
attn_mask_slice,
|
||||
hidden_states[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
chunk_tmp_tensor,
|
||||
)
|
||||
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def get_attention_scores_chunked(self, attn, query, key, attention_mask, hidden_states, value, chunk):
|
||||
# batch size = 1
|
||||
assert query.shape[0] == 1
|
||||
assert key.shape[0] == 1
|
||||
assert value.shape[0] == 1
|
||||
assert hidden_states.shape[0] == 1
|
||||
|
||||
# dtype = query.dtype
|
||||
if attn.upcast_attention:
|
||||
query = query.float()
|
||||
key = key.float()
|
||||
|
||||
# out_item_size = query.dtype.itemsize
|
||||
# if attn.upcast_attention:
|
||||
# out_item_size = torch.float32.itemsize
|
||||
out_item_size = query.element_size()
|
||||
if attn.upcast_attention:
|
||||
out_item_size = 4
|
||||
|
||||
chunk_size = 2**29
|
||||
|
||||
out_size = query.shape[1] * key.shape[1] * out_item_size
|
||||
chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))
|
||||
chunk_step = max(1, int(query.shape[1] / chunks_count))
|
||||
|
||||
key = key.transpose(-1, -2)
|
||||
|
||||
def _get_chunk_view(tensor, start, length):
|
||||
if start + length > tensor.shape[1]:
|
||||
length = tensor.shape[1] - start
|
||||
# print(f"view: [{tensor.shape[0]},{tensor.shape[1]},{tensor.shape[2]}] - start: {start}, length: {length}")
|
||||
return tensor[:, start : start + length]
|
||||
|
||||
for chunk_pos in range(0, query.shape[1], chunk_step):
|
||||
if attention_mask is not None:
|
||||
torch.baddbmm(
|
||||
_get_chunk_view(attention_mask, chunk_pos, chunk_step),
|
||||
_get_chunk_view(query, chunk_pos, chunk_step),
|
||||
key,
|
||||
beta=1,
|
||||
alpha=attn.scale,
|
||||
out=chunk,
|
||||
)
|
||||
else:
|
||||
torch.baddbmm(
|
||||
torch.zeros((1, 1, 1), device=query.device, dtype=query.dtype),
|
||||
_get_chunk_view(query, chunk_pos, chunk_step),
|
||||
key,
|
||||
beta=0,
|
||||
alpha=attn.scale,
|
||||
out=chunk,
|
||||
)
|
||||
chunk = chunk.softmax(dim=-1)
|
||||
torch.bmm(chunk, value, out=_get_chunk_view(hidden_states, chunk_pos, chunk_step))
|
||||
|
||||
# del chunk
|
||||
|
||||
|
||||
diffusers.models.attention_processor.SlicedAttnProcessor = ChunkedSlicedAttnProcessor
|
||||
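The deleted hotfix's get_attention_scores_chunked sizes its chunks so the temporary attention-score buffer never exceeds chunk_size = 2**29 bytes (512 MiB). A small worked example of that arithmetic with illustrative numbers, not taken from the repo:

import math

query_len = key_len = 4096  # 64x64 latent tokens, i.e. a 512px image, per head-slice
item_size = 2               # fp16 bytes; 4 when upcast_attention is enabled
chunk_size = 2**29

out_size = query_len * key_len * item_size                              # 32 MiB here
chunks_count = min(query_len, math.ceil((out_size - 1) / chunk_size))   # -> 1, fits in one chunk
chunk_step = max(1, int(query_len / chunks_count))                      # -> 4096 query rows per chunk
print(chunks_count, chunk_step)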
@@ -62,7 +62,7 @@
    "@nanostores/react": "^0.7.3",
    "@reduxjs/toolkit": "2.6.1",
    "@roarr/browser-log-writer": "^1.3.0",
    "@xyflow/react": "^12.4.2",
    "@xyflow/react": "^12.5.3",
    "async-mutex": "^0.5.0",
    "chakra-react-select": "^4.9.2",
    "cmdk": "^1.0.0",
@@ -150,7 +150,7 @@
    "prettier": "^3.3.3",
    "rollup-plugin-visualizer": "^5.12.0",
    "storybook": "^8.3.4",
    "tsafe": "^1.7.5",
    "tsafe": "^1.8.5",
    "type-fest": "^4.26.1",
    "typescript": "^5.6.2",
    "vite": "^6.1.0",
@@ -162,5 +162,6 @@
  },
  "engines": {
    "pnpm": "8"
  }
  },
  "packageManager": "pnpm@8.15.9+sha512.499434c9d8fdd1a2794ebf4552b3b25c0a633abcee5bb15e7b5de90f32f47b513aca98cd5cfd001c31f0db454bc3804edccd578501e4ca293a6816166bbd9f81"
}
56  invokeai/frontend/web/pnpm-lock.yaml  (generated)
@@ -36,8 +36,8 @@ dependencies:
|
||||
specifier: ^1.3.0
|
||||
version: 1.3.0
|
||||
'@xyflow/react':
|
||||
specifier: ^12.4.2
|
||||
version: 12.4.2(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1)
|
||||
specifier: ^12.5.3
|
||||
version: 12.5.3(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1)
|
||||
async-mutex:
|
||||
specifier: ^0.5.0
|
||||
version: 0.5.0
|
||||
@@ -284,8 +284,8 @@ devDependencies:
|
||||
specifier: ^8.3.4
|
||||
version: 8.3.4
|
||||
tsafe:
|
||||
specifier: ^1.7.5
|
||||
version: 1.7.5
|
||||
specifier: ^1.8.5
|
||||
version: 1.8.5
|
||||
type-fest:
|
||||
specifier: ^4.26.1
|
||||
version: 4.26.1
|
||||
@@ -3323,7 +3323,7 @@ packages:
|
||||
/@types/d3-drag@3.0.7:
|
||||
resolution: {integrity: sha512-HE3jVKlzU9AaMazNufooRJ5ZpWmLIoc90A37WU2JMmeq28w1FQqCZswHZ3xR+SuxYftzHq6WU6KJHvqxKzTxxQ==}
|
||||
dependencies:
|
||||
'@types/d3-selection': 3.0.10
|
||||
'@types/d3-selection': 3.0.11
|
||||
dev: false
|
||||
|
||||
/@types/d3-interpolate@3.0.4:
|
||||
@@ -3332,21 +3332,21 @@ packages:
|
||||
'@types/d3-color': 3.1.3
|
||||
dev: false
|
||||
|
||||
/@types/d3-selection@3.0.10:
|
||||
resolution: {integrity: sha512-cuHoUgS/V3hLdjJOLTT691+G2QoqAjCVLmr4kJXR4ha56w1Zdu8UUQ5TxLRqudgNjwXeQxKMq4j+lyf9sWuslg==}
|
||||
/@types/d3-selection@3.0.11:
|
||||
resolution: {integrity: sha512-bhAXu23DJWsrI45xafYpkQ4NtcKMwWnAC/vKrd2l+nxMFuvOT3XMYTIj2opv8vq8AO5Yh7Qac/nSeP/3zjTK0w==}
|
||||
dev: false
|
||||
|
||||
/@types/d3-transition@3.0.8:
|
||||
resolution: {integrity: sha512-ew63aJfQ/ms7QQ4X7pk5NxQ9fZH/z+i24ZfJ6tJSfqxJMrYLiK01EAs2/Rtw/JreGUsS3pLPNV644qXFGnoZNQ==}
|
||||
/@types/d3-transition@3.0.9:
|
||||
resolution: {integrity: sha512-uZS5shfxzO3rGlu0cC3bjmMFKsXv+SmZZcgp0KD22ts4uGXp5EVYGzu/0YdwZeKmddhcAccYtREJKkPfXkZuCg==}
|
||||
dependencies:
|
||||
'@types/d3-selection': 3.0.10
|
||||
'@types/d3-selection': 3.0.11
|
||||
dev: false
|
||||
|
||||
/@types/d3-zoom@3.0.8:
|
||||
resolution: {integrity: sha512-iqMC4/YlFCSlO8+2Ii1GGGliCAY4XdeG748w5vQUbevlbDu0zSjH/+jojorQVBK/se0j6DUFNPBGSqD3YWYnDw==}
|
||||
dependencies:
|
||||
'@types/d3-interpolate': 3.0.4
|
||||
'@types/d3-selection': 3.0.10
|
||||
'@types/d3-selection': 3.0.11
|
||||
dev: false
|
||||
|
||||
/@types/diff-match-patch@1.0.36:
|
||||
@@ -3951,28 +3951,28 @@ packages:
|
||||
resolution: {integrity: sha512-N8tkAACJx2ww8vFMneJmaAgmjAG1tnVBZJRLRcx061tmsLRZHSEZSLuGWnwPtunsSLvSqXQ2wfp7Mgqg1I+2dQ==}
|
||||
dev: false
|
||||
|
||||
/@xyflow/react@12.4.2(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-AFJKVc/fCPtgSOnRst3xdYJwiEcUN9lDY7EO/YiRvFHYCJGgfzg+jpvZjkTOnBLGyrMJre9378pRxAc3fsR06A==}
|
||||
/@xyflow/react@12.5.3(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-saovy/aQRoW8qQoIqMFUtmC3F6oEV7n6+J1pVbhSG45NI/hOFvK0qozsIPKqX5Va6lGQnkl/o53NHLja3NiweQ==}
|
||||
peerDependencies:
|
||||
react: '>=17'
|
||||
react-dom: '>=17'
|
||||
dependencies:
|
||||
'@xyflow/system': 0.0.50
|
||||
'@xyflow/system': 0.0.53
|
||||
classcat: 5.0.5
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
zustand: 4.5.5(@types/react@18.3.11)(react@18.3.1)
|
||||
zustand: 4.5.6(@types/react@18.3.11)(react@18.3.1)
|
||||
transitivePeerDependencies:
|
||||
- '@types/react'
|
||||
- immer
|
||||
dev: false
|
||||
|
||||
/@xyflow/system@0.0.50:
|
||||
resolution: {integrity: sha512-HVUZd4LlY88XAaldFh2nwVxDOcdIBxGpQ5txzwfJPf+CAjj2BfYug1fHs2p4yS7YO8H6A3EFJQovBE8YuHkAdg==}
|
||||
/@xyflow/system@0.0.53:
|
||||
resolution: {integrity: sha512-QTWieiTtvNYyQAz1fxpzgtUGXNpnhfh6vvZa7dFWpWS2KOz6bEHODo/DTK3s07lDu0Bq0Db5lx/5M5mNjb9VDQ==}
|
||||
dependencies:
|
||||
'@types/d3-drag': 3.0.7
|
||||
'@types/d3-selection': 3.0.10
|
||||
'@types/d3-transition': 3.0.8
|
||||
'@types/d3-selection': 3.0.11
|
||||
'@types/d3-transition': 3.0.9
|
||||
'@types/d3-zoom': 3.0.8
|
||||
d3-drag: 3.0.0
|
||||
d3-selection: 3.0.0
|
||||
@@ -8791,8 +8791,8 @@ packages:
|
||||
resolution: {integrity: sha512-tLJxacIQUM82IR7JO1UUkKlYuUTmoY9HBJAmNWFzheSlDS5SPMcNIepejHJa4BpPQLAcbRhRf3GDJzyj6rbKvA==}
|
||||
dev: false
|
||||
|
||||
/tsafe@1.7.5:
|
||||
resolution: {integrity: sha512-tbNyyBSbwfbilFfiuXkSOj82a6++ovgANwcoqBAcO9/REPoZMEQoE8kWPeO0dy5A2D/2Lajr8Ohue5T0ifIvLQ==}
|
||||
/tsafe@1.8.5:
|
||||
resolution: {integrity: sha512-LFWTWQrW6rwSY+IBNFl2ridGfUzVsPwrZ26T4KUJww/py8rzaQ/SY+MIz6YROozpUCaRcuISqagmlwub9YT9kw==}
|
||||
dev: true
|
||||
|
||||
/tsconfck@3.1.5(typescript@5.6.2):
|
||||
@@ -9123,6 +9123,14 @@ packages:
|
||||
react: 18.3.1
|
||||
dev: false
|
||||
|
||||
/use-sync-external-store@1.5.0(react@18.3.1):
|
||||
resolution: {integrity: sha512-Rb46I4cGGVBmjamjphe8L/UnvJD+uPPtTkNvX5mZgqdbavhI4EbgIWJiIHXJ8bc/i9EQGPRh4DwEURJ552Do0A==}
|
||||
peerDependencies:
|
||||
react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
|
||||
dependencies:
|
||||
react: 18.3.1
|
||||
dev: false
|
||||
|
||||
/util-deprecate@1.0.2:
|
||||
resolution: {integrity: sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==}
|
||||
dev: true
|
||||
@@ -9567,8 +9575,8 @@ packages:
|
||||
/zod@3.23.8:
|
||||
resolution: {integrity: sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==}
|
||||
|
||||
/zustand@4.5.5(@types/react@18.3.11)(react@18.3.1):
|
||||
resolution: {integrity: sha512-+0PALYNJNgK6hldkgDq2vLrw5f6g/jCInz52n9RTpropGgeAf/ioFUCdtsjCqu4gNhW9D01rUQBROoRjdzyn2Q==}
|
||||
/zustand@4.5.6(@types/react@18.3.11)(react@18.3.1):
|
||||
resolution: {integrity: sha512-ibr/n1hBzLLj5Y+yUcU7dYw8p6WnIVzdJbnX+1YpaScvZVF2ziugqHs+LAmHw4lWO9c/zRj+K1ncgWDQuthEdQ==}
|
||||
engines: {node: '>=12.7.0'}
|
||||
peerDependencies:
|
||||
'@types/react': '>=16.8'
|
||||
@@ -9584,5 +9592,5 @@ packages:
|
||||
dependencies:
|
||||
'@types/react': 18.3.11
|
||||
react: 18.3.1
|
||||
use-sync-external-store: 1.2.2(react@18.3.1)
|
||||
use-sync-external-store: 1.5.0(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
@@ -116,7 +116,10 @@
|
||||
"combinatorial": "Kombinatorisch",
|
||||
"saveChanges": "Änderungen speichern",
|
||||
"error_withCount_one": "{{count}} Fehler",
|
||||
"error_withCount_other": "{{count}} Fehler"
|
||||
"error_withCount_other": "{{count}} Fehler",
|
||||
"value": "Wert",
|
||||
"label": "Label",
|
||||
"systemInformation": "Systeminformationen"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Bildgröße",
|
||||
@@ -695,7 +698,10 @@
|
||||
"guidance": "Führung",
|
||||
"coherenceMode": "Modus",
|
||||
"recallMetadata": "Metadaten abrufen",
|
||||
"gaussianBlur": "Gaußsche Unschärfe"
|
||||
"gaussianBlur": "Gaußsche Unschärfe",
|
||||
"sendToUpscale": "An Hochskalieren senden",
|
||||
"useCpuNoise": "CPU-Rauschen verwenden",
|
||||
"sendToCanvas": "An Leinwand senden"
|
||||
},
|
||||
"settings": {
|
||||
"displayInProgress": "Zwischenbilder anzeigen",
|
||||
@@ -1328,7 +1334,8 @@
|
||||
"loadWorkflowDesc2": "Ihr aktueller Arbeitsablauf enthält nicht gespeicherte Änderungen.",
|
||||
"loadingTemplates": "Lade {{name}}",
|
||||
"missingSourceOrTargetHandle": "Fehlender Quell- oder Zielgriff",
|
||||
"missingSourceOrTargetNode": "Fehlender Quell- oder Zielknoten"
|
||||
"missingSourceOrTargetNode": "Fehlender Quell- oder Zielknoten",
|
||||
"showEdgeLabelsHelp": "Beschriftungen an Kanten anzeigen, um die verknüpften Knoten zu kennzeichnen"
|
||||
},
|
||||
"hrf": {
|
||||
"enableHrf": "Korrektur für hohe Auflösungen",
|
||||
|
||||
@@ -1706,6 +1706,7 @@
|
||||
"noRecentWorkflows": "No Recent Workflows",
|
||||
"private": "Private",
|
||||
"shared": "Shared",
|
||||
"published": "Published",
|
||||
"browseWorkflows": "Browse Workflows",
|
||||
"deselectAll": "Deselect All",
|
||||
"recommended": "Recommended For You",
|
||||
@@ -1783,7 +1784,39 @@
|
||||
"textPlaceholder": "Empty Text",
|
||||
"workflowBuilderAlphaWarning": "The workflow builder is currently in alpha. There may be breaking changes before the stable release.",
|
||||
"minimum": "Minimum",
|
||||
"maximum": "Maximum"
|
||||
"maximum": "Maximum",
|
||||
"publish": "Publish",
|
||||
"published": "Published",
|
||||
"unpublish": "Unpublish",
|
||||
"workflowLocked": "Workflow Locked",
|
||||
"workflowLockedPublished": "Published workflows are locked for editing.\nYou can unpublish the workflow to edit it, or make a copy of it.",
|
||||
"workflowLockedDuringPublishing": "Workflow is locked while configuring for publishing.",
|
||||
"selectOutputNode": "Select Output Node",
|
||||
"changeOutputNode": "Change Output Node",
|
||||
"publishedWorkflowOutputs": "Outputs",
|
||||
"publishedWorkflowInputs": "Inputs",
|
||||
"unpublishableInputs": "These unpublishable inputs will be omitted",
|
||||
"noPublishableInputs": "No publishable inputs",
|
||||
"noOutputNodeSelected": "No output node selected",
|
||||
"cannotPublish": "Cannot publish workflow",
|
||||
"publishWarnings": "Warnings",
|
||||
"errorWorkflowHasUnsavedChanges": "Workflow has unsaved changes",
|
||||
"errorWorkflowHasBatchOrGeneratorNodes": "Workflow has batch and/or generator nodes",
|
||||
"errorWorkflowHasInvalidGraph": "Workflow graph invalid (hover Invoke button for details)",
|
||||
"errorWorkflowHasNoOutputNode": "No output node selected",
|
||||
"warningWorkflowHasNoPublishableInputFields": "No publishable input fields selected - published workflow will run with only default values",
|
||||
"warningWorkflowHasUnpublishableInputFields": "Workflow has some unpublishable inputs - these will be omitted from the published workflow",
|
||||
"publishFailed": "Publish failed",
|
||||
"publishFailedDesc": "There was a problem publishing the workflow. Please try again.",
|
||||
"publishSuccess": "Your workflow is being published",
|
||||
"publishSuccessDesc": "Check your <LinkComponent>Project Dashboard</LinkComponent> to see its progress.",
|
||||
"publishInProgress": "Publishing in progress",
|
||||
"publishedWorkflowIsLocked": "Published workflow is locked",
|
||||
"publishingValidationRun": "Publishing Validation Run",
|
||||
"publishingValidationRunInProgress": "Publishing validation run in progress.",
|
||||
"publishedWorkflowsLocked": "Published workflows are locked and cannot be edited or run. Either unpublish the workflow or save a copy to edit or run this workflow.",
|
||||
"selectingOutputNode": "Selecting output node",
|
||||
"selectingOutputNodeDesc": "Click a node to select it as the workflow's output node."
|
||||
}
|
||||
},
|
||||
"controlLayers": {
|
||||
|
||||
@@ -115,7 +115,8 @@
|
||||
"error_withCount_many": "{{count}} errori",
|
||||
"error_withCount_other": "{{count}} errori",
|
||||
"value": "Valore",
|
||||
"label": "Etichetta"
|
||||
"label": "Etichetta",
|
||||
"systemInformation": "Informazioni di sistema"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Dimensione dell'immagine",
|
||||
@@ -715,7 +716,8 @@
|
||||
"collectionNumberLTMin": "{{value}} < {{minimum}} (incr min)",
|
||||
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (excl max)",
|
||||
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (excl min)",
|
||||
"collectionEmpty": "raccolta vuota"
|
||||
"collectionEmpty": "raccolta vuota",
|
||||
"batchNodeCollectionSizeMismatchNoGroupId": "Dimensione della raccolta di gruppo nel Lotto non corrisponde"
|
||||
},
|
||||
"useCpuNoise": "Usa la CPU per generare rumore",
|
||||
"iterations": "Iterazioni",
|
||||
@@ -2365,8 +2367,9 @@
|
||||
"watchRecentReleaseVideos": "Guarda i video su questa versione",
|
||||
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
|
||||
"items": [
|
||||
"Flussi di lavoro: nuova e migliorata libreria dei flussi di lavoro.",
|
||||
"FLUX: supporto per FLUX Redux e FLUX Fill in Flussi di lavoro e Tela."
|
||||
"Flussi di lavoro: supporto per menu a discesa di stringhe personalizzate nel Generatore di Flussi di lavoro.",
|
||||
"FLUX: supporto per FLUX Fill in Flussi di lavoro e Tela.",
|
||||
"LLaVA OneVision VLLM: supporto beta nei flussi di lavoro."
|
||||
]
|
||||
},
|
||||
"system": {
|
||||
|
||||
@@ -237,7 +237,10 @@
|
||||
"row": "Hàng",
|
||||
"board": "Bảng",
|
||||
"saveChanges": "Lưu Thay Đổi",
|
||||
"error_withCount_other": "{{count}} lỗi"
|
||||
"error_withCount_other": "{{count}} lỗi",
|
||||
"value": "Giá Trị",
|
||||
"label": "Nhãn Tên",
|
||||
"systemInformation": "Thông Tin Hệ Thống"
|
||||
},
|
||||
"prompt": {
|
||||
"addPromptTrigger": "Thêm Prompt Trigger",
|
||||
@@ -2300,7 +2303,10 @@
|
||||
"minimum": "Tối Thiểu",
|
||||
"maximum": "Tối Đa",
|
||||
"containerRowLayout": "Hộp Chứa (bố cục hàng)",
|
||||
"containerColumnLayout": "Hộp Chứa (bố cục cột)"
|
||||
"containerColumnLayout": "Hộp Chứa (bố cục cột)",
|
||||
"resetOptions": "Tải Lại Lựa Chọn",
|
||||
"addOption": "Thêm Lựa Chọn",
|
||||
"dropdown": "Danh Sách Thả Xuống"
|
||||
},
|
||||
"yourWorkflows": "Workflow Của Bạn",
|
||||
"browseWorkflows": "Khám Phá Workflow",
|
||||
@@ -2316,7 +2322,8 @@
|
||||
"view": "Xem",
|
||||
"deselectAll": "Huỷ Chọn Tất Cả",
|
||||
"noRecentWorkflows": "Không Có Workflows Gần Đây",
|
||||
"recommended": "Có Thể Bạn Sẽ Cần"
|
||||
"recommended": "Có Thể Bạn Sẽ Cần",
|
||||
"emptyStringPlaceholder": "<xâu ký tự trống>"
|
||||
},
|
||||
"upscaling": {
|
||||
"missingUpscaleInitialImage": "Thiếu ảnh dùng để upscale",
|
||||
@@ -2352,8 +2359,9 @@
|
||||
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
|
||||
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
|
||||
"items": [
|
||||
"Workflow: Thư Viện Workflow mới và đã được cải tiến.",
|
||||
"FLUX: Hỗ trợ FLUX Redux & FLUX Fill trong Workflow và Canvas."
|
||||
"Workflow: Hỗ trợ xâu ký tự thả xuống tùy chỉnh trong Trình Tạo Vùng Nhập.",
|
||||
"FLUX: Hỗ trợ FLUX Fill trong Workflow và Canvas.",
|
||||
"LLaVA OneVision VLLM: Hỗ trợ phiên bản Beta trong Workflow."
|
||||
]
|
||||
},
|
||||
"upsell": {
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { TabName } from 'features/ui/store/uiTypes';
|
||||
|
||||
export const enqueueRequested = createAction<{
|
||||
tabName: TabName;
|
||||
prepend: boolean;
|
||||
}>('app/enqueueRequested');
|
||||
@@ -10,7 +10,6 @@ import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/l
|
||||
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
|
||||
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
|
||||
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
|
||||
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
|
||||
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
|
||||
import { addGalleryOffsetChangedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryOffsetChanged';
|
||||
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
|
||||
@@ -63,7 +62,6 @@ addGalleryImageClickedListener(startAppListening);
|
||||
addGalleryOffsetChangedListener(startAppListening);
|
||||
|
||||
// User Invoked
|
||||
addEnqueueRequestedNodes(startAppListening);
|
||||
addEnqueueRequestedLinear(startAppListening);
|
||||
addEnqueueRequestedUpscale(startAppListening);
|
||||
addAnyEnqueuedListener(startAppListening);
|
||||
|
||||
@@ -5,7 +5,7 @@ import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAd
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import type { BatchConfig, ImageDTO } from 'services/api/types';
|
||||
import type { EnqueueBatchArg, ImageDTO } from 'services/api/types';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('queue');
|
||||
@@ -19,7 +19,7 @@ export const addAdHocPostProcessingRequestedListener = (startAppListening: AppSt
|
||||
const { imageDTO } = action.payload;
|
||||
const state = getState();
|
||||
|
||||
const enqueueBatchArg: BatchConfig = {
|
||||
const enqueueBatchArg: EnqueueBatchArg = {
|
||||
prepend: true,
|
||||
batch: {
|
||||
graph: await buildAdHocPostProcessingGraph({
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
@@ -17,10 +17,11 @@ import { assert, AssertionError } from 'tsafe';
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const enqueueRequestedCanvas = createAction<{ prepend: boolean }>('app/enqueueRequestedCanvas');
|
||||
|
||||
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'canvas',
|
||||
actionCreator: enqueueRequestedCanvas,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
log.debug('Enqueue requested');
|
||||
const state = getState();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
@@ -9,10 +9,11 @@ import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endp
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const enqueueRequestedUpscaling = createAction<{ prepend: boolean }>('app/enqueueRequestedUpscaling');
|
||||
|
||||
export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'upscaling',
|
||||
actionCreator: enqueueRequestedUpscaling,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const { prepend } = action.payload;
|
||||
|
||||
@@ -3,6 +3,7 @@ import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/too
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
|
||||
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
|
||||
import { getDebugLoggerMiddleware } from 'app/store/middleware/debugLoggerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
|
||||
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
@@ -175,6 +176,7 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
|
||||
.concat(api.middleware)
|
||||
.concat(dynamicMiddlewares)
|
||||
.concat(authToastMiddleware)
|
||||
.concat(getDebugLoggerMiddleware())
|
||||
.prepend(listenerMiddleware.middleware),
|
||||
enhancers: (getDefaultEnhancers) => {
|
||||
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());
|
||||
|
||||
@@ -74,6 +74,7 @@ export type AppConfig = {
|
||||
allowPrivateBoards: boolean;
|
||||
allowPrivateStylePresets: boolean;
|
||||
allowClientSideUpload: boolean;
|
||||
allowPublishWorkflows: boolean;
|
||||
disabledTabs: TabName[];
|
||||
disabledFeatures: AppFeature[];
|
||||
disabledSDFeatures: SDFeature[];
|
||||
|
||||
@@ -14,7 +14,7 @@ export const useGlobalHotkeys = () => {
  useRegisteredHotkeys({
    id: 'invoke',
    category: 'app',
    callback: queue.queueBack,
    callback: queue.enqueueBack,
    options: {
      enabled: !queue.isDisabled && !queue.isLoading,
      preventDefault: true,
@@ -26,7 +26,7 @@ export const useGlobalHotkeys = () => {
  useRegisteredHotkeys({
    id: 'invokeFront',
    category: 'app',
    callback: queue.queueFront,
    callback: queue.enqueueFront,
    options: {
      enabled: !queue.isDisabled && !queue.isLoading,
      preventDefault: true,
@@ -54,7 +54,7 @@ import { atom, computed } from 'nanostores';
|
||||
import type { Logger } from 'roarr';
|
||||
import { getImageDTO } from 'services/api/endpoints/images';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import type { BatchConfig, ImageDTO, S } from 'services/api/types';
|
||||
import type { EnqueueBatchArg, ImageDTO, S } from 'services/api/types';
|
||||
import { QueueError } from 'services/events/errors';
|
||||
import type { Param0 } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
@@ -291,7 +291,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
|
||||
*/
|
||||
const origin = getPrefixedId(graph.id);
|
||||
|
||||
const batch: BatchConfig = {
|
||||
const batch: EnqueueBatchArg = {
|
||||
prepend,
|
||||
batch: {
|
||||
graph: graph.getGraph(),
|
||||
|
||||
@@ -49,7 +49,11 @@ export const useGalleryHotkeys = () => {
|
||||
useRegisteredHotkeys({
|
||||
id: 'galleryNavLeft',
|
||||
category: 'gallery',
|
||||
callback: () => {
|
||||
callback: (e) => {
|
||||
// Skip the hotkey if the user is focused on a tab element - the arrow keys are used to navigate between tabs.
|
||||
if (e.target instanceof HTMLElement && e.target.getAttribute('role') === 'tab') {
|
||||
return;
|
||||
}
|
||||
if (isOnFirstImageOfView && isPrevEnabled && !queryResult.isFetching) {
|
||||
goPrev('arrow');
|
||||
return;
|
||||
@@ -71,7 +75,11 @@ export const useGalleryHotkeys = () => {
|
||||
useRegisteredHotkeys({
|
||||
id: 'galleryNavRight',
|
||||
category: 'gallery',
|
||||
callback: () => {
|
||||
callback: (e) => {
|
||||
// Skip the hotkey if the user is focused on a tab element - the arrow keys are used to navigate between tabs.
|
||||
if (e.target instanceof HTMLElement && e.target.getAttribute('role') === 'tab') {
|
||||
return;
|
||||
}
|
||||
if (isOnLastImageOfView && isNextEnabled && !queryResult.isFetching) {
|
||||
goNext('arrow');
|
||||
return;
|
||||
|
||||
@@ -2,7 +2,9 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { FocusRegionWrapper } from 'common/components/FocusRegionWrapper';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { AddNodeCmdk } from 'features/nodes/components/flow/AddNodeCmdk/AddNodeCmdk';
|
||||
import TopPanel from 'features/nodes/components/flow/panels/TopPanel/TopPanel';
|
||||
import { TopCenterPanel } from 'features/nodes/components/flow/panels/TopPanel/TopCenterPanel';
|
||||
import { TopLeftPanel } from 'features/nodes/components/flow/panels/TopPanel/TopLeftPanel';
|
||||
import { TopRightPanel } from 'features/nodes/components/flow/panels/TopPanel/TopRightPanel';
|
||||
import WorkflowEditorSettings from 'features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -32,7 +34,9 @@ const NodeEditor = () => {
|
||||
<>
|
||||
<Flow />
|
||||
<AddNodeCmdk />
|
||||
<TopPanel />
|
||||
<TopLeftPanel />
|
||||
<TopCenterPanel />
|
||||
<TopRightPanel />
|
||||
<BottomLeftPanel />
|
||||
<MinimapPanel />
|
||||
</>
|
||||
|
||||
@@ -18,6 +18,7 @@ import { CommandEmpty, CommandItem, CommandList, CommandRoot } from 'cmdk';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import {
|
||||
$addNodeCmdk,
|
||||
$cursorPos,
|
||||
@@ -146,6 +147,7 @@ export const AddNodeCmdk = memo(() => {
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const addNode = useAddNode();
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
// Filtering the list is expensive - debounce the search term to avoid stutters
|
||||
const [debouncedSearchTerm] = useDebounce(searchTerm, 300);
|
||||
const isOpen = useStore($addNodeCmdk);
|
||||
@@ -160,8 +162,8 @@ export const AddNodeCmdk = memo(() => {
|
||||
id: 'addNode',
|
||||
category: 'workflows',
|
||||
callback: open,
|
||||
options: { enabled: tab === 'workflows', preventDefault: true },
|
||||
dependencies: [open, tab],
|
||||
options: { enabled: tab === 'workflows' && !isLocked, preventDefault: true },
|
||||
dependencies: [open, tab, isLocked],
|
||||
});
|
||||
|
||||
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
|
||||
@@ -4,6 +4,7 @@ import type {
|
||||
EdgeChange,
|
||||
HandleType,
|
||||
NodeChange,
|
||||
NodeMouseHandler,
|
||||
OnEdgesChange,
|
||||
OnInit,
|
||||
OnMoveEnd,
|
||||
@@ -16,8 +17,10 @@ import type {
|
||||
import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from '@xyflow/react';
|
||||
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { $isSelectingOutputNode, $outputNodeId } from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
import { useConnection } from 'features/nodes/hooks/useConnection';
|
||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { useNodeCopyPaste } from 'features/nodes/hooks/useNodeCopyPaste';
|
||||
import { useSyncExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
|
||||
import {
|
||||
@@ -44,7 +47,7 @@ import {
|
||||
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||
import { selectSelectionMode, selectShouldSnapToGrid } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
|
||||
import { type AnyEdge, type AnyNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import type { CSSProperties, MouseEvent } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
@@ -92,6 +95,8 @@ export const Flow = memo(() => {
|
||||
const updateNodeInternals = useUpdateNodeInternals();
|
||||
const store = useAppStore();
|
||||
const isWorkflowsFocused = useIsRegionFocused('workflows');
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
|
||||
useFocusRegion('workflows', flowWrapper);
|
||||
|
||||
useSyncExecutionState();
|
||||
@@ -215,7 +220,7 @@ export const Flow = memo(() => {
|
||||
id: 'copySelection',
|
||||
category: 'workflows',
|
||||
callback: copySelection,
|
||||
options: { preventDefault: true },
|
||||
options: { enabled: isWorkflowsFocused && !isLocked, preventDefault: true },
|
||||
dependencies: [copySelection],
|
||||
});
|
||||
|
||||
@@ -244,24 +249,24 @@ export const Flow = memo(() => {
|
||||
id: 'selectAll',
|
||||
category: 'workflows',
|
||||
callback: selectAll,
|
||||
options: { enabled: isWorkflowsFocused, preventDefault: true },
|
||||
dependencies: [selectAll, isWorkflowsFocused],
|
||||
options: { enabled: isWorkflowsFocused && !isLocked, preventDefault: true },
|
||||
dependencies: [selectAll, isWorkflowsFocused, isLocked],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'pasteSelection',
|
||||
category: 'workflows',
|
||||
callback: pasteSelection,
|
||||
options: { enabled: isWorkflowsFocused, preventDefault: true },
|
||||
dependencies: [pasteSelection],
|
||||
options: { enabled: isWorkflowsFocused && !isLocked, preventDefault: true },
|
||||
dependencies: [pasteSelection, isLocked, isWorkflowsFocused],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'pasteSelectionWithEdges',
|
||||
category: 'workflows',
|
||||
callback: pasteSelectionWithEdges,
|
||||
options: { enabled: isWorkflowsFocused, preventDefault: true },
|
||||
dependencies: [pasteSelectionWithEdges],
|
||||
options: { enabled: isWorkflowsFocused && !isLocked, preventDefault: true },
|
||||
dependencies: [pasteSelectionWithEdges, isLocked, isWorkflowsFocused],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
@@ -270,8 +275,8 @@ export const Flow = memo(() => {
|
||||
callback: () => {
|
||||
dispatch(undo());
|
||||
},
|
||||
options: { enabled: isWorkflowsFocused && mayUndo, preventDefault: true },
|
||||
dependencies: [mayUndo],
|
||||
options: { enabled: isWorkflowsFocused && !isLocked && mayUndo, preventDefault: true },
|
||||
dependencies: [mayUndo, isLocked, isWorkflowsFocused],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
@@ -280,8 +285,8 @@ export const Flow = memo(() => {
|
||||
callback: () => {
|
||||
dispatch(redo());
|
||||
},
|
||||
options: { enabled: isWorkflowsFocused && mayRedo, preventDefault: true },
|
||||
dependencies: [mayRedo],
|
||||
options: { enabled: isWorkflowsFocused && !isLocked && mayRedo, preventDefault: true },
|
||||
dependencies: [mayRedo, isLocked, isWorkflowsFocused],
|
||||
});
|
||||
|
||||
const onEscapeHotkey = useCallback(() => {
|
||||
@@ -318,10 +323,22 @@ export const Flow = memo(() => {
|
||||
id: 'deleteSelection',
|
||||
category: 'workflows',
|
||||
callback: deleteSelection,
|
||||
options: { preventDefault: true, enabled: isWorkflowsFocused },
|
||||
dependencies: [deleteSelection, isWorkflowsFocused],
|
||||
options: { preventDefault: true, enabled: isWorkflowsFocused && !isLocked },
|
||||
dependencies: [deleteSelection, isWorkflowsFocused, isLocked],
|
||||
});
|
||||
|
||||
const onNodeClick = useCallback<NodeMouseHandler<AnyNode>>((e, node) => {
|
||||
if (!$isSelectingOutputNode.get()) {
|
||||
return;
|
||||
}
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
const { id } = node.data;
|
||||
$outputNodeId.set(id);
|
||||
$isSelectingOutputNode.set(false);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<ReactFlow<AnyNode, AnyEdge>
|
||||
id="workflow-editor"
|
||||
@@ -332,6 +349,7 @@ export const Flow = memo(() => {
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
onInit={onInit}
|
||||
onNodeClick={onNodeClick}
|
||||
onMouseMove={onMouseMove}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
@@ -344,6 +362,12 @@ export const Flow = memo(() => {
|
||||
onMoveEnd={handleMoveEnd}
|
||||
connectionLineComponent={CustomConnectionLine}
|
||||
isValidConnection={isValidConnection}
|
||||
edgesFocusable={!isLocked}
|
||||
edgesReconnectable={!isLocked}
|
||||
nodesDraggable={!isLocked}
|
||||
nodesConnectable={!isLocked}
|
||||
nodesFocusable={!isLocked}
|
||||
elementsSelectable={!isLocked}
|
||||
minZoom={0.1}
|
||||
snapToGrid={shouldSnapToGrid}
|
||||
snapGrid={snapGrid}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Handle, Position } from '@xyflow/react';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
|
||||
import { map } from 'lodash-es';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo } from 'react';
|
||||
@@ -19,7 +19,7 @@ const collapsedHandleStyles: CSSProperties = {
|
||||
};
|
||||
|
||||
const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
|
||||
if (!template) {
|
||||
return null;
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Flex, Icon, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { compare } from 'compare-versions';
|
||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
|
||||
import { useInvocationNodeNotes } from 'features/nodes/hooks/useNodeNotes';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
|
||||
import { useNodeUserTitleSafe } from 'features/nodes/hooks/useNodeUserTitleSafe';
|
||||
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -27,9 +27,9 @@ InvocationNodeInfoIcon.displayName = 'InvocationNodeInfoIcon';
|
||||
|
||||
const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const notes = useInvocationNodeNotes(nodeId);
|
||||
const label = useNodeLabel(nodeId);
|
||||
const label = useNodeUserTitleSafe(nodeId);
|
||||
const version = useNodeVersion(nodeId);
|
||||
const nodeTemplate = useNodeTemplate(nodeId);
|
||||
const nodeTemplate = useNodeTemplateOrThrow(nodeId);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const title = useMemo(() => {
|
||||
|
||||
@@ -8,7 +8,7 @@ import {
|
||||
Textarea,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
|
||||
import { useInputFieldUserDescriptionSafe } from 'features/nodes/hooks/useInputFieldUserDescriptionSafe';
|
||||
import { fieldDescriptionChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { ChangeEvent } from 'react';
|
||||
@@ -48,7 +48,7 @@ InputFieldDescriptionPopover.displayName = 'InputFieldDescriptionPopover';
|
||||
const Content = memo(({ nodeId, fieldName }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const description = useInputFieldDescriptionSafe(nodeId, fieldName);
|
||||
const description = useInputFieldUserDescriptionSafe(nodeId, fieldName);
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(fieldDescriptionChanged({ nodeId, fieldName, val: e.target.value }));
|
||||
|
||||
@@ -7,7 +7,7 @@ import { InputFieldResetToDefaultValueIconButton } from 'features/nodes/componen
|
||||
import { useNodeFieldDnd } from 'features/nodes/components/sidePanel/builder/dnd-hooks';
|
||||
import { useInputFieldIsConnected } from 'features/nodes/hooks/useInputFieldIsConnected';
|
||||
import { useInputFieldIsInvalid } from 'features/nodes/hooks/useInputFieldIsInvalid';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { NO_DRAG_CLASS } from 'features/nodes/types/constants';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useRef } from 'react';
|
||||
@@ -100,7 +100,7 @@ const DirectField = memo(({ nodeId, fieldName, isInvalid, isConnected, fieldTemp
|
||||
const draggableRef = useRef<HTMLDivElement>(null);
|
||||
const dragHandleRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const isDragging = useNodeFieldDnd({ nodeId, fieldName }, fieldTemplate, draggableRef, dragHandleRef);
|
||||
const isDragging = useNodeFieldDnd(nodeId, fieldName, fieldTemplate, draggableRef, dragHandleRef);
|
||||
|
||||
return (
|
||||
<InputFieldWrapper>
|
||||
|
||||
@@ -7,7 +7,8 @@ import {
useIsConnectionInProgress,
useIsConnectionStartField,
} from 'features/nodes/hooks/useFieldConnectionState';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import type { FieldInputTemplate } from 'features/nodes/types/field';
@@ -105,9 +106,16 @@ type HandleCommonProps = {
};

const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
const isLocked = useIsWorkflowEditorLocked();
return (
<Tooltip label={fieldTypeName} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
<Handle type="target" id={fieldTemplate.name} position={Position.Left} style={handleStyles}>
<Handle
type="target"
id={fieldTemplate.name}
position={Position.Left}
style={handleStyles}
isConnectable={!isLocked}
>
<Box
sx={sx}
data-cardinality={fieldTemplate.type.cardinality}
@@ -130,6 +138,7 @@ const ConnectionInProgressHandle = memo(
const { t } = useTranslation();
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
const connectionError = useConnectionErrorTKey(nodeId, fieldName, 'target');
const isLocked = useIsWorkflowEditorLocked();

const tooltip = useMemo(() => {
if (connectionError !== null) {
@@ -140,7 +149,13 @@ const ConnectionInProgressHandle = memo(

return (
<Tooltip label={tooltip} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
<Handle type="target" id={fieldTemplate.name} position={Position.Left} style={handleStyles}>
<Handle
type="target"
id={fieldTemplate.name}
position={Position.Left}
style={handleStyles}
isConnectable={!isLocked}
>
<Box
sx={sx}
data-cardinality={fieldTemplate.type.cardinality}

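// Aside (illustrative, not from the commit): a minimal sketch of the pattern introduced above,
// assuming React Flow's `Handle` and `Position` are in scope as they are in these components.
// `isConnectable` is the standard React Flow handle prop; when it is false, an edge can neither be
// started nor completed at the handle, which is how the locked editor blocks new connections.
const LockAwareTargetHandle = ({ fieldName, isLocked }: { fieldName: string; isLocked: boolean }) => (
  <Handle type="target" id={fieldName} position={Position.Left} isConnectable={!isLocked} />
);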
@@ -17,7 +17,7 @@ import { StringFieldDropdown } from 'features/nodes/components/flow/nodes/Invoca
|
||||
import { StringFieldInput } from 'features/nodes/components/flow/nodes/Invocation/fields/StringField/StringFieldInput';
|
||||
import { StringFieldTextarea } from 'features/nodes/components/flow/nodes/Invocation/fields/StringField/StringFieldTextarea';
|
||||
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import {
|
||||
isBoardFieldInputInstance,
|
||||
isBoardFieldInputTemplate,
|
||||
|
||||
@@ -9,8 +9,8 @@ import {
|
||||
useIsConnectionStartField,
|
||||
} from 'features/nodes/hooks/useFieldConnectionState';
|
||||
import { useInputFieldIsConnected } from 'features/nodes/hooks/useInputFieldIsConnected';
|
||||
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
|
||||
import { useInputFieldTemplateTitle } from 'features/nodes/hooks/useInputFieldTemplateTitle';
|
||||
import { useInputFieldTemplateTitleOrThrow } from 'features/nodes/hooks/useInputFieldTemplateTitleOrThrow';
|
||||
import { useInputFieldUserTitleSafe } from 'features/nodes/hooks/useInputFieldUserTitleSafe';
|
||||
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY, NO_FIT_ON_DOUBLE_CLICK_CLASS } from 'features/nodes/types/constants';
|
||||
import type { MouseEvent } from 'react';
|
||||
@@ -43,8 +43,8 @@ interface Props {
|
||||
export const InputFieldTitle = memo((props: Props) => {
|
||||
const { nodeId, fieldName, isInvalid, isDragging } = props;
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const label = useInputFieldLabelSafe(nodeId, fieldName);
|
||||
const fieldTemplateTitle = useInputFieldTemplateTitle(nodeId, fieldName);
|
||||
const label = useInputFieldUserTitleSafe(nodeId, fieldName);
|
||||
const fieldTemplateTitle = useInputFieldTemplateTitleOrThrow(nodeId, fieldName);
|
||||
const { t } = useTranslation();
|
||||
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
|
||||
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Flex, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
|
||||
import { useInputFieldErrors } from 'features/nodes/hooks/useInputFieldErrors';
|
||||
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import { startCase } from 'lodash-es';
|
||||
import { memo, useMemo } from 'react';
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
useIsConnectionInProgress,
|
||||
useIsConnectionStartField,
|
||||
} from 'features/nodes/hooks/useFieldConnectionState';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { useOutputFieldTemplate } from 'features/nodes/hooks/useOutputFieldTemplate';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
@@ -105,9 +106,17 @@ type HandleCommonProps = {
|
||||
};
|
||||
|
||||
const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
|
||||
return (
|
||||
<Tooltip label={fieldTypeName} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
|
||||
<Handle type="source" id={fieldTemplate.name} position={Position.Right} style={handleStyles}>
|
||||
<Handle
|
||||
type="source"
|
||||
id={fieldTemplate.name}
|
||||
position={Position.Right}
|
||||
style={handleStyles}
|
||||
isConnectable={!isLocked}
|
||||
>
|
||||
<Box
|
||||
sx={sx}
|
||||
data-cardinality={fieldTemplate.type.cardinality}
|
||||
@@ -130,6 +139,7 @@ const ConnectionInProgressHandle = memo(
|
||||
const { t } = useTranslation();
|
||||
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
|
||||
const connectionErrorTKey = useConnectionErrorTKey(nodeId, fieldName, 'target');
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
if (connectionErrorTKey !== null) {
|
||||
@@ -140,7 +150,13 @@ const ConnectionInProgressHandle = memo(
|
||||
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
|
||||
<Handle type="source" id={fieldTemplate.name} position={Position.Right} style={handleStyles}>
|
||||
<Handle
|
||||
type="source"
|
||||
id={fieldTemplate.name}
|
||||
position={Position.Right}
|
||||
style={handleStyles}
|
||||
isConnectable={!isLocked}
|
||||
>
|
||||
<Box
|
||||
sx={sx}
|
||||
data-cardinality={fieldTemplate.type.cardinality}
|
||||
|
||||
@@ -3,8 +3,8 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useEditable } from 'common/hooks/useEditable';
|
||||
import { useBatchGroupColorToken } from 'features/nodes/hooks/useBatchGroupColorToken';
|
||||
import { useBatchGroupId } from 'features/nodes/hooks/useBatchGroupId';
|
||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
|
||||
import { useNodeTemplateTitleSafe } from 'features/nodes/hooks/useNodeTemplateTitleSafe';
|
||||
import { useNodeUserTitleSafe } from 'features/nodes/hooks/useNodeUserTitleSafe';
|
||||
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_FIT_ON_DOUBLE_CLICK_CLASS } from 'features/nodes/types/constants';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
@@ -17,10 +17,10 @@ type Props = {
|
||||
|
||||
const NodeTitle = ({ nodeId, title }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const label = useNodeLabel(nodeId);
|
||||
const label = useNodeUserTitleSafe(nodeId);
|
||||
const batchGroupId = useBatchGroupId(nodeId);
|
||||
const batchGroupColorToken = useBatchGroupColorToken(batchGroupId);
|
||||
const templateTitle = useNodeTemplateTitle(nodeId);
|
||||
const templateTitle = useNodeTemplateTitleSafe(nodeId);
|
||||
const { t } = useTranslation();
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { ChakraProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, useGlobalMenuClose } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { useMouseOverFormField, useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { useNodeExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
|
||||
import { useZoomToNode } from 'features/nodes/hooks/useZoomToNode';
|
||||
@@ -62,6 +63,12 @@ const containerSx: SystemStyleObject = {
|
||||
display: 'block',
|
||||
shadow: '0 0 0 3px var(--invoke-colors-blue-300)',
|
||||
},
|
||||
'&[data-is-editor-locked="true"]': {
|
||||
'& *': {
|
||||
cursor: 'not-allowed',
|
||||
pointerEvents: 'none',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const shadowsSx: SystemStyleObject = {
|
||||
@@ -98,7 +105,8 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
const { nodeId, width, children, selected } = props;
|
||||
const mouseOverNode = useMouseOverNode(nodeId);
|
||||
const mouseOverFormField = useMouseOverFormField(nodeId);
|
||||
const zoomToNode = useZoomToNode();
|
||||
const zoomToNode = useZoomToNode(nodeId);
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
|
||||
const executionState = useNodeExecutionState(nodeId);
|
||||
const isInProgress = executionState?.status === zNodeStatus.enum.IN_PROGRESS;
|
||||
@@ -126,9 +134,9 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
// This target is marked as not fitting the view on double click
|
||||
return;
|
||||
}
|
||||
zoomToNode(nodeId);
|
||||
zoomToNode();
|
||||
},
|
||||
[nodeId, zoomToNode]
|
||||
[zoomToNode]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -141,6 +149,7 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
sx={containerSx}
|
||||
width={width || NODE_WIDTH}
|
||||
opacity={opacity}
|
||||
data-is-editor-locked={isLocked}
|
||||
data-is-selected={selected}
|
||||
data-is-mouse-over-form-field={mouseOverFormField.isMouseOverFormField}
|
||||
>
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { WorkflowName } from 'features/nodes/components/sidePanel/WorkflowName';
|
||||
import { selectWorkflowName } from 'features/nodes/store/workflowSlice';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const TopCenterPanel = memo(() => {
|
||||
const name = useAppSelector(selectWorkflowName);
|
||||
return (
|
||||
<Flex gap={2} top={2} left="50%" transform="translateX(-50%)" position="absolute" pointerEvents="none">
|
||||
{!!name.length && <WorkflowName />}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
TopCenterPanel.displayName = 'TopCenterPanel';
|
||||
@@ -0,0 +1,64 @@
|
||||
import { Alert, AlertDescription, AlertIcon, AlertTitle, Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import AddNodeButton from 'features/nodes/components/flow/panels/TopPanel/AddNodeButton';
|
||||
import UpdateNodesButton from 'features/nodes/components/flow/panels/TopPanel/UpdateNodesButton';
|
||||
import {
|
||||
$isInPublishFlow,
|
||||
$isSelectingOutputNode,
|
||||
useIsValidationRunInProgress,
|
||||
} from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { selectWorkflowIsPublished } from 'features/nodes/store/workflowSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const TopLeftPanel = memo(() => {
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
const isInPublishFlow = useStore($isInPublishFlow);
|
||||
const isPublished = useAppSelector(selectWorkflowIsPublished);
|
||||
const isValidationRunInProgress = useIsValidationRunInProgress();
|
||||
const isSelectingOutputNode = useStore($isSelectingOutputNode);
|
||||
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Flex gap={2} top={2} left={2} position="absolute" alignItems="flex-start" pointerEvents="none">
|
||||
{!isLocked && (
|
||||
<Flex gap="2">
|
||||
<AddNodeButton />
|
||||
<UpdateNodesButton />
|
||||
</Flex>
|
||||
)}
|
||||
{isLocked && (
|
||||
<Alert status="info" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
|
||||
<AlertIcon />
|
||||
<Box>
|
||||
<AlertTitle>{t('workflows.builder.workflowLocked')}</AlertTitle>
|
||||
{isValidationRunInProgress && (
|
||||
<AlertDescription whiteSpace="pre-wrap">
|
||||
{t('workflows.builder.publishingValidationRunInProgress')}
|
||||
</AlertDescription>
|
||||
)}
|
||||
{isInPublishFlow && !isValidationRunInProgress && !isSelectingOutputNode && (
|
||||
<AlertDescription whiteSpace="pre-wrap">
|
||||
{t('workflows.builder.workflowLockedDuringPublishing')}
|
||||
</AlertDescription>
|
||||
)}
|
||||
{isInPublishFlow && !isValidationRunInProgress && isSelectingOutputNode && (
|
||||
<AlertDescription whiteSpace="pre-wrap">
|
||||
{t('workflows.builder.selectingOutputNodeDesc')}
|
||||
</AlertDescription>
|
||||
)}
|
||||
{isPublished && (
|
||||
<AlertDescription whiteSpace="pre-wrap">
|
||||
{t('workflows.builder.workflowLockedPublished')}
|
||||
</AlertDescription>
|
||||
)}
|
||||
</Box>
|
||||
</Alert>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TopLeftPanel.displayName = 'TopLeftPanel';
|
||||
@@ -1,40 +0,0 @@
|
||||
import { Flex, IconButton, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import AddNodeButton from 'features/nodes/components/flow/panels/TopPanel/AddNodeButton';
|
||||
import ClearFlowButton from 'features/nodes/components/flow/panels/TopPanel/ClearFlowButton';
|
||||
import SaveWorkflowButton from 'features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton';
|
||||
import UpdateNodesButton from 'features/nodes/components/flow/panels/TopPanel/UpdateNodesButton';
|
||||
import { useWorkflowEditorSettingsModal } from 'features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings';
|
||||
import { WorkflowName } from 'features/nodes/components/sidePanel/WorkflowName';
|
||||
import { selectWorkflowName } from 'features/nodes/store/workflowSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiGearSixFill } from 'react-icons/pi';
|
||||
|
||||
const TopCenterPanel = () => {
|
||||
const name = useAppSelector(selectWorkflowName);
|
||||
const modal = useWorkflowEditorSettingsModal();
|
||||
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Flex gap={2} top={2} left={2} right={2} position="absolute" alignItems="flex-start" pointerEvents="none">
|
||||
<Flex gap="2">
|
||||
<AddNodeButton />
|
||||
<UpdateNodesButton />
|
||||
</Flex>
|
||||
<Spacer />
|
||||
{!!name.length && <WorkflowName />}
|
||||
<Spacer />
|
||||
<ClearFlowButton />
|
||||
<SaveWorkflowButton />
|
||||
<IconButton
|
||||
pointerEvents="auto"
|
||||
aria-label={t('workflows.workflowEditorMenu')}
|
||||
icon={<PiGearSixFill />}
|
||||
onClick={modal.setTrue}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(TopCenterPanel);
|
||||
@@ -0,0 +1,34 @@
|
||||
import { Flex, IconButton } from '@invoke-ai/ui-library';
|
||||
import ClearFlowButton from 'features/nodes/components/flow/panels/TopPanel/ClearFlowButton';
|
||||
import SaveWorkflowButton from 'features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton';
|
||||
import { useWorkflowEditorSettingsModal } from 'features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiGearSixFill } from 'react-icons/pi';
|
||||
|
||||
export const TopRightPanel = memo(() => {
|
||||
const modal = useWorkflowEditorSettingsModal();
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (isLocked) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex gap={2} top={2} right={2} position="absolute" alignItems="flex-end" pointerEvents="none">
|
||||
<ClearFlowButton />
|
||||
<SaveWorkflowButton />
|
||||
<IconButton
|
||||
pointerEvents="auto"
|
||||
aria-label={t('workflows.workflowEditorMenu')}
|
||||
icon={<PiGearSixFill />}
|
||||
onClick={modal.setTrue}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TopRightPanel.displayName = 'TopRightPanel';
|
||||
@@ -1,5 +1,4 @@
|
||||
import { Box } from '@invoke-ai/ui-library';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { HorizontalResizeHandle } from 'features/ui/components/tabs/ResizeHandle';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
@@ -23,23 +22,21 @@ export const EditModeLeftPanelContent = memo(() => {
|
||||
|
||||
return (
|
||||
<Box position="relative" w="full" h="full">
|
||||
<ScrollableContent>
|
||||
<PanelGroup
|
||||
ref={panelGroupRef}
|
||||
id="workflow-panel-group"
|
||||
autoSaveId="workflow-panel-group"
|
||||
direction="vertical"
|
||||
style={panelGroupStyles}
|
||||
>
|
||||
<Panel id="workflow" collapsible minSize={25}>
|
||||
<WorkflowFieldsLinearViewPanel />
|
||||
</Panel>
|
||||
<HorizontalResizeHandle onDoubleClick={handleDoubleClickHandle} />
|
||||
<Panel id="inspector" collapsible minSize={25}>
|
||||
<WorkflowNodeInspectorPanel />
|
||||
</Panel>
|
||||
</PanelGroup>
|
||||
</ScrollableContent>
|
||||
<PanelGroup
|
||||
ref={panelGroupRef}
|
||||
id="workflow-panel-group"
|
||||
autoSaveId="workflow-panel-group"
|
||||
direction="vertical"
|
||||
style={panelGroupStyles}
|
||||
>
|
||||
<Panel id="workflow" collapsible minSize={25}>
|
||||
<WorkflowFieldsLinearViewPanel />
|
||||
</Panel>
|
||||
<HorizontalResizeHandle onDoubleClick={handleDoubleClickHandle} />
|
||||
<Panel id="inspector" collapsible minSize={25}>
|
||||
<WorkflowNodeInspectorPanel />
|
||||
</Panel>
|
||||
</PanelGroup>
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
|
||||
import { useSaveOrSaveAsWorkflow } from 'features/workflowLibrary/hooks/useSaveOrSaveAsWorkflow';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCopyBold, PiLockOpenBold } from 'react-icons/pi';
|
||||
|
||||
export const PublishedWorkflowPanelContent = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const saveAs = useSaveOrSaveAsWorkflow();
|
||||
return (
|
||||
<Flex flexDir="column" w="full" h="full" gap={2} alignItems="center">
|
||||
<Heading size="md" pt={32}>
|
||||
{t('workflows.builder.workflowLocked')}
|
||||
</Heading>
|
||||
<Text fontSize="md">{t('workflows.builder.publishedWorkflowsLocked')}</Text>
|
||||
<Button size="md" onClick={saveAs} variant="ghost" leftIcon={<PiCopyBold />}>
|
||||
{t('common.saveAs')}
|
||||
</Button>
|
||||
<Button size="md" onClick={undefined} variant="ghost" leftIcon={<PiLockOpenBold />}>
|
||||
{t('workflows.builder.unpublish')}
|
||||
</Button>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
PublishedWorkflowPanelContent.displayName = 'PublishedWorkflowPanelContent';
|
||||
@@ -2,7 +2,7 @@ import { Flex, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { WorkflowListMenuTrigger } from 'features/nodes/components/sidePanel/WorkflowListMenu/WorkflowListMenuTrigger';
|
||||
import { WorkflowViewEditToggleButton } from 'features/nodes/components/sidePanel/WorkflowViewEditToggleButton';
|
||||
import { selectWorkflowMode } from 'features/nodes/store/workflowSlice';
|
||||
import { selectWorkflowIsPublished, selectWorkflowMode } from 'features/nodes/store/workflowSlice';
|
||||
import { WorkflowLibraryMenu } from 'features/workflowLibrary/components/WorkflowLibraryMenu/WorkflowLibraryMenu';
|
||||
import { memo } from 'react';
|
||||
|
||||
@@ -10,12 +10,13 @@ import SaveWorkflowButton from './SaveWorkflowButton';
|
||||
|
||||
export const ActiveWorkflowNameAndActions = memo(() => {
|
||||
const mode = useAppSelector(selectWorkflowMode);
|
||||
const isPublished = useAppSelector(selectWorkflowIsPublished);
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={1} minW={0}>
|
||||
<WorkflowListMenuTrigger />
|
||||
<Spacer />
|
||||
{mode === 'edit' && <SaveWorkflowButton />}
|
||||
{mode === 'edit' && !isPublished && <SaveWorkflowButton />}
|
||||
<WorkflowViewEditToggleButton />
|
||||
<WorkflowLibraryMenu />
|
||||
</Flex>
|
||||
|
||||
@@ -1,22 +1,30 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { EditModeLeftPanelContent } from 'features/nodes/components/sidePanel/EditModeLeftPanelContent';
|
||||
import { PublishedWorkflowPanelContent } from 'features/nodes/components/sidePanel/PublishedWorkflowPanelContent';
|
||||
import { $isInPublishFlow } from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
import { PublishWorkflowPanelContent } from 'features/nodes/components/sidePanel/workflow/PublishWorkflowPanelContent';
|
||||
import { ActiveWorkflowDescription } from 'features/nodes/components/sidePanel/WorkflowListMenu/ActiveWorkflowDescription';
|
||||
import { ActiveWorkflowNameAndActions } from 'features/nodes/components/sidePanel/WorkflowListMenu/ActiveWorkflowNameAndActions';
|
||||
import { selectWorkflowMode } from 'features/nodes/store/workflowSlice';
|
||||
import { selectWorkflowIsPublished, selectWorkflowMode } from 'features/nodes/store/workflowSlice';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { ViewModeLeftPanelContent } from './viewMode/ViewModeLeftPanelContent';
|
||||
|
||||
const WorkflowsTabLeftPanel = () => {
|
||||
const mode = useAppSelector(selectWorkflowMode);
|
||||
const isPublished = useAppSelector(selectWorkflowIsPublished);
|
||||
const isInPublishFlow = useStore($isInPublishFlow);
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" gap={2} flexDir="column">
|
||||
<ActiveWorkflowNameAndActions />
|
||||
{mode === 'view' && <ActiveWorkflowDescription />}
|
||||
{mode === 'view' && <ViewModeLeftPanelContent />}
|
||||
{mode === 'edit' && <EditModeLeftPanelContent />}
|
||||
{isInPublishFlow && <PublishWorkflowPanelContent />}
|
||||
{!isInPublishFlow && <ActiveWorkflowNameAndActions />}
|
||||
{!isInPublishFlow && !isPublished && mode === 'view' && <ActiveWorkflowDescription />}
|
||||
{!isInPublishFlow && !isPublished && mode === 'view' && <ViewModeLeftPanelContent />}
|
||||
{!isInPublishFlow && !isPublished && mode === 'edit' && <EditModeLeftPanelContent />}
|
||||
{isPublished && <PublishedWorkflowPanelContent />}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -67,11 +67,8 @@ FormElementEditModeHeader.displayName = 'FormElementEditModeHeader';
|
||||
const ZoomToNodeButton = memo(({ element }: { element: NodeFieldElement }) => {
|
||||
const { t } = useTranslation();
|
||||
const { nodeId } = element.data.fieldIdentifier;
|
||||
const zoomToNode = useZoomToNode();
|
||||
const zoomToNode = useZoomToNode(nodeId);
|
||||
const mouseOverFormField = useMouseOverFormField(nodeId);
|
||||
const onClick = useCallback(() => {
|
||||
zoomToNode(nodeId);
|
||||
}, [nodeId, zoomToNode]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
@@ -79,7 +76,7 @@ const ZoomToNodeButton = memo(({ element }: { element: NodeFieldElement }) => {
|
||||
onMouseOut={mouseOverFormField.handleMouseOut}
|
||||
tooltip={t('workflows.builder.zoomToNode')}
|
||||
aria-label={t('workflows.builder.zoomToNode')}
|
||||
onClick={onClick}
|
||||
onClick={zoomToNode}
|
||||
icon={<PiGpsFixBold />}
|
||||
variant="link"
|
||||
size="sm"
|
||||
|
||||
@@ -2,8 +2,8 @@ import { FormHelperText, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { linkifyOptions, linkifySx } from 'common/components/linkify';
|
||||
import { useEditable } from 'common/hooks/useEditable';
|
||||
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldUserDescriptionSafe } from 'features/nodes/hooks/useInputFieldUserDescriptionSafe';
|
||||
import { fieldDescriptionChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import Linkify from 'linkify-react';
|
||||
@@ -13,7 +13,7 @@ export const NodeFieldElementDescriptionEditable = memo(({ el }: { el: NodeField
|
||||
const { data } = el;
|
||||
const { fieldIdentifier } = data;
|
||||
const dispatch = useAppDispatch();
|
||||
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const inputRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ export const NodeFieldElementEditMode = memo(({ el }: { el: NodeFieldElement })
|
||||
return (
|
||||
<Flex ref={draggableRef} id={id} className={NODE_FIELD_CLASS_NAME} sx={sx} data-parent-layout={containerCtx.layout}>
|
||||
<NodeFieldElementEditModeContent dragHandleRef={dragHandleRef} el={el} isDragging={isDragging} />
|
||||
<NodeFieldElementOverlay element={el} />
|
||||
<NodeFieldElementOverlay nodeId={el.data.fieldIdentifier.nodeId} />
|
||||
<DndListDropIndicator activeDropRegion={activeDropRegion} gap="var(--invoke-space-4)" />
|
||||
</Flex>
|
||||
);
|
||||
@@ -105,9 +105,9 @@ const nodeFieldOverlaySx: SystemStyleObject = {
|
||||
},
|
||||
};
|
||||
|
||||
const NodeFieldElementOverlay = memo(({ element }: { element: NodeFieldElement }) => {
|
||||
const mouseOverNode = useMouseOverNode(element.data.fieldIdentifier.nodeId);
|
||||
const mouseOverFormField = useMouseOverFormField(element.data.fieldIdentifier.nodeId);
|
||||
export const NodeFieldElementOverlay = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const mouseOverNode = useMouseOverNode(nodeId);
|
||||
const mouseOverFormField = useMouseOverFormField(nodeId);
|
||||
|
||||
return (
|
||||
<Box
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import { Flex, FormLabel, Spacer } from '@invoke-ai/ui-library';
|
||||
import { NodeFieldElementResetToInitialValueIconButton } from 'features/nodes/components/flow/nodes/Invocation/fields/NodeFieldElementResetToInitialValueIconButton';
|
||||
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldUserTitleSafe } from 'features/nodes/hooks/useInputFieldUserTitleSafe';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { memo, useMemo } from 'react';
|
||||
|
||||
export const NodeFieldElementLabel = memo(({ el }: { el: NodeFieldElement }) => {
|
||||
const { data } = el;
|
||||
const { fieldIdentifier } = data;
|
||||
const label = useInputFieldLabelSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const label = useInputFieldUserTitleSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
|
||||
const _label = useMemo(() => label || fieldTemplate.title, [label, fieldTemplate.title]);
|
||||
|
||||
@@ -2,8 +2,8 @@ import { Flex, FormLabel, Input, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useEditable } from 'common/hooks/useEditable';
|
||||
import { NodeFieldElementResetToInitialValueIconButton } from 'features/nodes/components/flow/nodes/Invocation/fields/NodeFieldElementResetToInitialValueIconButton';
|
||||
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldUserTitleSafe } from 'features/nodes/hooks/useInputFieldUserTitleSafe';
|
||||
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
@@ -12,7 +12,7 @@ export const NodeFieldElementLabelEditable = memo(({ el }: { el: NodeFieldElemen
|
||||
const { data } = el;
|
||||
const { fieldIdentifier } = data;
|
||||
const dispatch = useAppDispatch();
|
||||
const label = useInputFieldLabelSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const label = useInputFieldUserTitleSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { NodeFieldElementFloatSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementFloatSettings';
|
||||
import { NodeFieldElementIntegerSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementIntegerSettings';
|
||||
import { NodeFieldElementStringSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementStringSettings';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { formElementNodeFieldDataChanged } from 'features/nodes/store/workflowSlice';
|
||||
import {
|
||||
isFloatFieldInputTemplate,
|
||||
|
||||
@@ -5,8 +5,9 @@ import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/
|
||||
import { InputFieldRenderer } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
|
||||
import { useContainerContext } from 'features/nodes/components/sidePanel/builder/contexts';
|
||||
import { NodeFieldElementLabel } from 'features/nodes/components/sidePanel/builder/NodeFieldElementLabel';
|
||||
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
|
||||
import { useInputFieldTemplateOrThrow, useInputFieldTemplateSafe } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldTemplateSafe } from 'features/nodes/hooks/useInputFieldTemplateSafe';
|
||||
import { useInputFieldUserDescriptionSafe } from 'features/nodes/hooks/useInputFieldUserDescriptionSafe';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { NODE_FIELD_CLASS_NAME } from 'features/nodes/types/workflow';
|
||||
import Linkify from 'linkify-react';
|
||||
@@ -36,7 +37,7 @@ const useFormatFallbackLabel = () => {
|
||||
export const NodeFieldElementViewMode = memo(({ el }: { el: NodeFieldElement }) => {
|
||||
const { id, data } = el;
|
||||
const { fieldIdentifier, showDescription } = data;
|
||||
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const containerCtx = useContainerContext();
|
||||
const formatFallbackLabel = useFormatFallbackLabel();
|
||||
@@ -69,7 +70,7 @@ NodeFieldElementViewMode.displayName = 'NodeFieldElementViewMode';
|
||||
const NodeFieldElementViewModeContent = memo(({ el }: { el: NodeFieldElement }) => {
|
||||
const { data } = el;
|
||||
const { fieldIdentifier, showDescription } = data;
|
||||
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
|
||||
const _description = useMemo(
|
||||
|
||||
@@ -1,4 +1,6 @@
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import type { DropTargetRecord } from '@atlaskit/pragmatic-drag-and-drop/dist/types/internal-types';
import type { ElementDragPayload } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import {
draggable,
dropTargetForElements,
@@ -33,7 +35,7 @@ import {
selectFormRootElementId,
selectWorkflowSlice,
} from 'features/nodes/store/workflowSlice';
import type { FieldIdentifier, FieldInputTemplate, StatefulFieldValue } from 'features/nodes/types/field';
import type { FieldInputTemplate, StatefulFieldValue } from 'features/nodes/types/field';
import type { ElementId, FormElement } from 'features/nodes/types/workflow';
import { buildNodeFieldElement, isContainerElement } from 'features/nodes/types/workflow';
import type { RefObject } from 'react';
@@ -58,6 +60,27 @@ const isFormElementDndData = (data: Record<string | symbol, unknown>): data is F
return uniqueFormElementDndKey in data;
};

const uniqueNodeFieldDndKey = Symbol('node-field');
type NodeFieldDndData = {
[uniqueNodeFieldDndKey]: true;
nodeId: string;
fieldName: string;
fieldTemplate: FieldInputTemplate;
};
const buildNodeFieldDndData = (
nodeId: string,
fieldName: string,
fieldTemplate: FieldInputTemplate
): NodeFieldDndData => ({
[uniqueNodeFieldDndKey]: true,
nodeId,
fieldName,
fieldTemplate,
});
const isNodeFieldDndData = (data: Record<string | symbol, unknown>): data is NodeFieldDndData => {
return uniqueNodeFieldDndKey in data;
};

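// Aside (illustrative, not from the commit): the Symbol-keyed payload above is what lets the builder's
// dnd monitor tell node-field drags apart from form-element drags. A minimal sketch of how the payload
// travels, assuming an HTMLElement `el` and a `fieldTemplate` are in scope; only APIs already used in
// this file (`draggable`, `monitorForElements`) appear here, and the string ids are placeholders.
draggable({
  element: el,
  getInitialData: () => buildNodeFieldDndData('some-node-id', 'image', fieldTemplate),
});
monitorForElements({
  canMonitor: ({ source }) => isNodeFieldDndData(source.data),
  onDrop: ({ source }) => {
    if (isNodeFieldDndData(source.data)) {
      // The symbol-keyed guard narrows source.data to NodeFieldDndData here.
      const { nodeId, fieldName } = source.data;
      console.log(nodeId, fieldName);
    }
  },
});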
/**
* Flashes an element by changing its background color. Used to indicate that an element has been moved.
* @param elementId The id of the element to flash
@@ -133,6 +156,27 @@ const useGetInitialValue = () => {
return _getInitialValue;
};

const getSourceElement = (source: ElementDragPayload) => {
if (isNodeFieldDndData(source.data)) {
const { nodeId, fieldName, fieldTemplate } = source.data;
return buildNodeFieldElement(nodeId, fieldName, fieldTemplate.type);
}

if (isFormElementDndData(source.data)) {
return source.data.element;
}

return null;
};

const getTargetElement = (target: DropTargetRecord) => {
if (isFormElementDndData(target.data)) {
return target.data.element;
}

return null;
};

/**
* Singleton hook that monitors for builder dnd events and dispatches actions accordingly.
*/
@@ -156,20 +200,20 @@ export const useBuilderDndMonitor = () => {

useEffect(() => {
return monitorForElements({
canMonitor: ({ source }) => isFormElementDndData(source.data),
canMonitor: ({ source }) => isFormElementDndData(source.data) || isNodeFieldDndData(source.data),
onDrop: ({ location, source }) => {
const target = location.current.dropTargets[0];
if (!target) {
return;
}

if (!isFormElementDndData(source.data) || !isFormElementDndData(target.data)) {
const sourceElement = getSourceElement(source);
const targetElement = getTargetElement(target);

if (!sourceElement || !targetElement) {
return;
}

const sourceElement = source.data.element;
const targetElement = target.data.element;

if (sourceElement.id === targetElement.id) {
// Dropping on self is a no-op
return;
@@ -359,8 +403,15 @@ export const useFormElementDnd = (
element: draggableElement,
// TODO(psyche): This causes a kinda jittery behaviour - need a better heuristic to determine stickiness
getIsSticky: () => false,
canDrop: ({ source }) =>
isFormElementDndData(source.data) && source.data.element.id !== getElement(elementId).parentId,
canDrop: ({ source }) => {
if (isNodeFieldDndData(source.data)) {
return true;
}
if (isFormElementDndData(source.data)) {
return source.data.element.id !== getElement(elementId).parentId;
}
return false;
},
getData: ({ input }) => {
const element = getElement(elementId);

@@ -423,8 +474,16 @@ export const useRootElementDropTarget = (droppableRef: RefObject<HTMLDivElement>
dropTargetForElements({
element: droppableElement,
getIsSticky: () => false,
canDrop: ({ source }) =>
getElement(rootElementId, isContainerElement).data.children.length === 0 && isFormElementDndData(source.data),
canDrop: ({ source }) => {
const rootElement = getElement(rootElementId, isContainerElement);
if (rootElement.data.children.length !== 0) {
return false;
}
if (isNodeFieldDndData(source.data) || isFormElementDndData(source.data)) {
return true;
}
return false;
},
getData: ({ input }) => {
const element = getElement(rootElementId, isContainerElement);

@@ -455,7 +514,8 @@ export const useRootElementDropTarget = (droppableRef: RefObject<HTMLDivElement>
/**
* Hook that provides dnd functionality for node fields.
*
* @param fieldIdentifier The identifier of the node field
* @param nodeId: The id of the node
* @param fieldName: The name of the field
* @param fieldTemplate The template of the node field, required to build the form element
* @param draggableRef The ref of the draggable HTML element
* @param dragHandleRef The ref of the drag handle HTML element
@@ -463,7 +523,8 @@ export const useRootElementDropTarget = (droppableRef: RefObject<HTMLDivElement>
* @returns Whether the node field is currently being dragged
*/
export const useNodeFieldDnd = (
fieldIdentifier: FieldIdentifier,
nodeId: string,
fieldName: string,
fieldTemplate: FieldInputTemplate,
draggableRef: RefObject<HTMLElement>,
dragHandleRef: RefObject<HTMLElement>
@@ -481,12 +542,7 @@
draggable({
element: draggableElement,
dragHandle: dragHandleElement,
getInitialData: () => {
const { nodeId, fieldName } = fieldIdentifier;
const { type } = fieldTemplate;
const element = buildNodeFieldElement(nodeId, fieldName, type);
return buildFormElementDndData(element);
},
getInitialData: () => buildNodeFieldDndData(nodeId, fieldName, fieldTemplate),
onDragStart: () => {
setIsDragging(true);
},
@@ -495,7 +551,7 @@
},
})
);
}, [dragHandleRef, draggableRef, fieldIdentifier, fieldTemplate]);
}, [dragHandleRef, draggableRef, fieldName, fieldTemplate, nodeId]);

return isDragging;
};

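// Aside (illustrative, not from the commit): call-site sketch for the reworked useNodeFieldDnd
// signature. The real call site is the DirectField change near the top of this diff; the component
// below is hypothetical and assumes React's `useRef` plus the types imported in this file.
const ExampleField = (props: { nodeId: string; fieldName: string; fieldTemplate: FieldInputTemplate }) => {
  const draggableRef = useRef<HTMLDivElement>(null);
  const dragHandleRef = useRef<HTMLDivElement>(null);
  // nodeId and fieldName are now passed separately instead of as a single FieldIdentifier object.
  const isDragging = useNodeFieldDnd(props.nodeId, props.fieldName, props.fieldTemplate, draggableRef, dragHandleRef);
  return (
    <div ref={draggableRef} style={{ opacity: isDragging ? 0.5 : 1 }}>
      <div ref={dragHandleRef}>drag handle</div>
    </div>
  );
};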
@@ -1,6 +1,6 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { formElementAdded, selectFormRootElementId } from 'features/nodes/store/workflowSlice';
|
||||
import { buildNodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
@@ -5,7 +5,7 @@ import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableCon
|
||||
import { InvocationNodeNotesTextarea } from 'features/nodes/components/flow/nodes/Invocation/InvocationNodeNotesTextarea';
|
||||
import { TemplateGate } from 'features/nodes/components/sidePanel/inspector/NodeTemplateGate';
|
||||
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
|
||||
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
|
||||
import { selectLastSelectedNodeId } from 'features/nodes/store/selectors';
|
||||
import { memo } from 'react';
|
||||
@@ -36,7 +36,7 @@ export default memo(InspectorDetailsTab);
|
||||
const Content = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const { t } = useTranslation();
|
||||
const version = useNodeVersion(nodeId);
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const needsUpdate = useNodeNeedsUpdate(nodeId);
|
||||
|
||||
return (
|
||||
|
||||
@@ -5,7 +5,7 @@ import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableCon
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { TemplateGate } from 'features/nodes/components/sidePanel/inspector/NodeTemplateGate';
|
||||
import { useNodeExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
|
||||
import { selectLastSelectedNodeId } from 'features/nodes/store/selectors';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -37,7 +37,7 @@ const getKey = (result: AnyInvocationOutput, i: number) => `${result.type}-${i}`
|
||||
|
||||
const Content = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const { t } = useTranslation();
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const nes = useNodeExecutionState(nodeId);
|
||||
|
||||
if (!nes || nes.outputs.length === 0) {
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { Flex, Input, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useEditable } from 'common/hooks/useEditable';
|
||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
|
||||
import { useNodeTemplateTitleSafe } from 'features/nodes/hooks/useNodeTemplateTitleSafe';
|
||||
import { useNodeUserTitleSafe } from 'features/nodes/hooks/useNodeUserTitleSafe';
|
||||
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -14,8 +14,8 @@ type Props = {
|
||||
|
||||
const InspectorTabEditableNodeTitle = ({ nodeId, title }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const label = useNodeLabel(nodeId);
|
||||
const templateTitle = useNodeTemplateTitle(nodeId);
|
||||
const label = useNodeUserTitleSafe(nodeId);
|
||||
const templateTitle = useNodeTemplateTitleSafe(nodeId);
|
||||
const { t } = useTranslation();
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const onChange = useCallback(
|
||||
|
||||
@@ -2,7 +2,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { TemplateGate } from 'features/nodes/components/sidePanel/inspector/NodeTemplateGate';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
|
||||
import { selectLastSelectedNodeId } from 'features/nodes/store/selectors';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -29,7 +29,7 @@ export default memo(NodeTemplateInspector);
|
||||
|
||||
const Content = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const { t } = useTranslation();
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
|
||||
return <DataViewer data={template} label={t('nodes.nodeTemplate')} bg="base.850" color="base.200" />;
|
||||
});
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useNodeTemplateSafe } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeTemplateSafe } from 'features/nodes/hooks/useNodeTemplateSafe';
|
||||
import type { PropsWithChildren, ReactNode } from 'react';
|
||||
import { memo } from 'react';
|
||||
|
||||
|
||||
@@ -0,0 +1,445 @@
|
||||
import type { ButtonProps } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Button,
|
||||
ButtonGroup,
|
||||
Divider,
|
||||
Flex,
|
||||
ListItem,
|
||||
Spacer,
|
||||
Text,
|
||||
Tooltip,
|
||||
UnorderedList,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { $projectUrl } from 'app/store/nanostores/projectId';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { ExternalLink } from 'features/gallery/components/ImageViewer/NoContentForViewer';
|
||||
import { NodeFieldElementOverlay } from 'features/nodes/components/sidePanel/builder/NodeFieldElementEditMode';
|
||||
import {
|
||||
$isInPublishFlow,
|
||||
$isReadyToDoValidationRun,
|
||||
$isSelectingOutputNode,
|
||||
$outputNodeId,
|
||||
$validationRunBatchId,
|
||||
usePublishInputs,
|
||||
} from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
import { useInputFieldTemplateTitleOrThrow } from 'features/nodes/hooks/useInputFieldTemplateTitleOrThrow';
|
||||
import { useInputFieldUserTitleOrThrow } from 'features/nodes/hooks/useInputFieldUserTitleOrThrow';
|
||||
import { useMouseOverFormField } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { useNodeTemplateTitleOrThrow } from 'features/nodes/hooks/useNodeTemplateTitleOrThrow';
|
||||
import { useNodeUserTitleOrThrow } from 'features/nodes/hooks/useNodeUserTitleOrThrow';
|
||||
import { useOutputFieldNames } from 'features/nodes/hooks/useOutputFieldNames';
|
||||
import { useOutputFieldTemplate } from 'features/nodes/hooks/useOutputFieldTemplate';
|
||||
import { useZoomToNode } from 'features/nodes/hooks/useZoomToNode';
|
||||
import { selectHasBatchOrGeneratorNodes } from 'features/nodes/store/selectors';
|
||||
import { selectIsWorkflowSaved } from 'features/nodes/store/workflowSlice';
|
||||
import { useEnqueueWorkflows } from 'features/queue/hooks/useEnqueueWorkflows';
|
||||
import { $isReadyToEnqueue } from 'features/queue/store/readiness';
|
||||
import { selectAllowPublishWorkflows } from 'features/system/store/configSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import { PiArrowLineRightBold, PiLightningFill, PiXBold } from 'react-icons/pi';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const PublishWorkflowPanelContent = memo(() => {
|
||||
return (
|
||||
<Flex flexDir="column" gap={2} h="full">
|
||||
<ButtonGroup isAttached={false} size="sm" variant="ghost">
|
||||
<Spacer />
|
||||
<CancelPublishButton />
|
||||
<PublishWorkflowButton />
|
||||
</ButtonGroup>
|
||||
<ScrollableContent>
|
||||
<Flex flexDir="column" gap={2} w="full" h="full">
|
||||
<OutputFields />
|
||||
<PublishableInputFields />
|
||||
<UnpublishableInputFields />
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
PublishWorkflowPanelContent.displayName = 'PublishWorkflowPanelContent';
|
||||
|
||||
const OutputFields = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const outputNodeId = useStore($outputNodeId);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
|
||||
<Flex alignItems="center">
|
||||
<Text fontWeight="semibold">{t('workflows.builder.publishedWorkflowOutputs')}</Text>
|
||||
<Spacer />
|
||||
<SelectOutputNodeButton variant="link" size="sm" />
|
||||
</Flex>
|
||||
|
||||
<Divider />
|
||||
{!outputNodeId && (
|
||||
<Text fontWeight="semibold" color="error.300">
|
||||
{t('workflows.builder.noOutputNodeSelected')}
|
||||
</Text>
|
||||
)}
|
||||
{outputNodeId && <OutputFieldsContent outputNodeId={outputNodeId} />}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
OutputFields.displayName = 'OutputFields';
|
||||
|
||||
const OutputFieldsContent = memo(({ outputNodeId }: { outputNodeId: string }) => {
|
||||
const outputFieldNames = useOutputFieldNames(outputNodeId);
|
||||
|
||||
return (
|
||||
<>
|
||||
{outputFieldNames.map((fieldName) => (
|
||||
<NodeOutputFieldPreview key={`${outputNodeId}-${fieldName}`} nodeId={outputNodeId} fieldName={fieldName} />
|
||||
))}
|
||||
</>
|
||||
);
|
||||
});
|
||||
OutputFieldsContent.displayName = 'OutputFieldsContent';
|
||||
|
||||
const PublishableInputFields = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const inputs = usePublishInputs();
|
||||
|
||||
if (inputs.publishable.length === 0) {
|
||||
return (
|
||||
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
|
||||
<Text fontWeight="semibold" color="warning.300">
|
||||
{t('workflows.builder.noPublishableInputs')}
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
|
||||
<Text fontWeight="semibold">{t('workflows.builder.publishedWorkflowInputs')}</Text>
|
||||
<Divider />
|
||||
{inputs.publishable.map(({ nodeId, fieldName }) => {
|
||||
return <NodeInputFieldPreview key={`${nodeId}-${fieldName}`} nodeId={nodeId} fieldName={fieldName} />;
|
||||
})}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
PublishableInputFields.displayName = 'PublishableInputFields';
|
||||
|
||||
const UnpublishableInputFields = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const inputs = usePublishInputs();
|
||||
|
||||
if (inputs.unpublishable.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
|
||||
<Text fontWeight="semibold" color="warning.300">
|
||||
{t('workflows.builder.unpublishableInputs')}
|
||||
</Text>
|
||||
<Divider />
|
||||
{inputs.unpublishable.map(({ nodeId, fieldName }) => {
|
||||
return <NodeInputFieldPreview key={`${nodeId}-${fieldName}`} nodeId={nodeId} fieldName={fieldName} />;
|
||||
})}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
UnpublishableInputFields.displayName = 'UnpublishableInputFields';
|
||||
|
||||
const SelectOutputNodeButton = memo((props: ButtonProps) => {
|
||||
const { t } = useTranslation();
|
||||
const outputNodeId = useStore($outputNodeId);
|
||||
const isSelectingOutputNode = useStore($isSelectingOutputNode);
|
||||
const onClick = useCallback(() => {
|
||||
$outputNodeId.set(null);
|
||||
$isSelectingOutputNode.set(true);
|
||||
}, []);
|
||||
return (
|
||||
<Button
|
||||
leftIcon={<PiArrowLineRightBold />}
|
||||
isDisabled={isSelectingOutputNode}
|
||||
tooltip={isSelectingOutputNode ? t('workflows.builder.selectingOutputNodeDesc') : undefined}
|
||||
onClick={onClick}
|
||||
{...props}
|
||||
>
|
||||
{isSelectingOutputNode
|
||||
? t('workflows.builder.selectingOutputNode')
|
||||
: outputNodeId
|
||||
? t('workflows.builder.changeOutputNode')
|
||||
: t('workflows.builder.selectOutputNode')}
|
||||
</Button>
|
||||
);
|
||||
});
|
||||
SelectOutputNodeButton.displayName = 'SelectOutputNodeButton';
|
||||
|
||||
const CancelPublishButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const onClick = useCallback(() => {
|
||||
$isInPublishFlow.set(false);
|
||||
$isSelectingOutputNode.set(false);
|
||||
$outputNodeId.set(null);
|
||||
}, []);
|
||||
return (
|
||||
<Button leftIcon={<PiXBold />} onClick={onClick}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
);
|
||||
});
|
||||
CancelPublishButton.displayName = 'CancelDeployButton';
|
||||
|
||||
const PublishWorkflowButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const isReadyToDoValidationRun = useStore($isReadyToDoValidationRun);
|
||||
const isReadyToEnqueue = useStore($isReadyToEnqueue);
|
||||
const isWorkflowSaved = useAppSelector(selectIsWorkflowSaved);
|
||||
const hasBatchOrGeneratorNodes = useAppSelector(selectHasBatchOrGeneratorNodes);
|
||||
const outputNodeId = useStore($outputNodeId);
|
||||
const isSelectingOutputNode = useStore($isSelectingOutputNode);
|
||||
const inputs = usePublishInputs();
|
||||
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
|
||||
|
||||
const projectUrl = useStore($projectUrl);
|
||||
|
||||
const enqueue = useEnqueueWorkflows();
|
||||
const onClick = useCallback(async () => {
|
||||
const result = await withResultAsync(() => enqueue(true, true));
|
||||
if (result.isErr()) {
|
||||
toast({
|
||||
id: 'TOAST_PUBLISH_FAILED',
|
||||
status: 'error',
|
||||
title: t('workflows.builder.publishFailed'),
|
||||
description: t('workflows.builder.publishFailedDesc'),
|
||||
duration: null,
|
||||
});
|
||||
log.error({ error: serializeError(result.error) }, 'Failed to enqueue batch');
|
||||
} else {
|
||||
toast({
|
||||
id: 'TOAST_PUBLISH_SUCCESSFUL',
|
||||
status: 'success',
|
||||
title: t('workflows.builder.publishSuccess'),
|
||||
description: (
|
||||
<Trans
|
||||
i18nKey="workflows.builder.publishSuccessDesc"
|
||||
components={{
|
||||
LinkComponent: <ExternalLink href={projectUrl ?? ''} />,
|
||||
}}
|
||||
/>
|
||||
),
|
||||
duration: null,
|
||||
});
|
||||
assert(result.value.enqueueResult.batch.batch_id);
|
||||
$validationRunBatchId.set(result.value.enqueueResult.batch.batch_id);
|
||||
log.debug(parseify(result.value), 'Enqueued batch');
|
||||
}
|
||||
}, [enqueue, projectUrl, t]);
|
||||
|
||||
return (
|
||||
<PublishTooltip
|
||||
isWorkflowSaved={isWorkflowSaved}
|
||||
hasBatchOrGeneratorNodes={hasBatchOrGeneratorNodes}
|
||||
isReadyToEnqueue={isReadyToEnqueue}
|
||||
hasOutputNode={outputNodeId !== null && !isSelectingOutputNode}
|
||||
hasPublishableInputs={inputs.publishable.length > 0}
|
||||
hasUnpublishableInputs={inputs.unpublishable.length > 0}
|
||||
>
|
||||
<Button
|
||||
leftIcon={<PiLightningFill />}
|
||||
isDisabled={
|
||||
!allowPublishWorkflows ||
|
||||
!isReadyToEnqueue ||
|
||||
!isWorkflowSaved ||
|
||||
hasBatchOrGeneratorNodes ||
|
||||
!isReadyToDoValidationRun ||
|
||||
!(outputNodeId !== null && !isSelectingOutputNode)
|
||||
}
|
||||
onClick={onClick}
|
||||
>
|
||||
{t('workflows.builder.publish')}
|
||||
</Button>
|
||||
</PublishTooltip>
|
||||
);
|
||||
});
|
||||
PublishWorkflowButton.displayName = 'DoValidationRunButton';

const NodeInputFieldPreview = memo(({ nodeId, fieldName }: { nodeId: string; fieldName: string }) => {
const mouseOverFormField = useMouseOverFormField(nodeId);
const nodeUserTitle = useNodeUserTitleOrThrow(nodeId);
const nodeTemplateTitle = useNodeTemplateTitleOrThrow(nodeId);
const fieldUserTitle = useInputFieldUserTitleOrThrow(nodeId, fieldName);
const fieldTemplateTitle = useInputFieldTemplateTitleOrThrow(nodeId, fieldName);
const zoomToNode = useZoomToNode(nodeId);

return (
<Flex
flexDir="column"
position="relative"
p={2}
borderRadius="base"
onMouseOver={mouseOverFormField.handleMouseOver}
onMouseOut={mouseOverFormField.handleMouseOut}
onClick={zoomToNode}
>
<Text fontWeight="semibold">{`${nodeUserTitle || nodeTemplateTitle} -> ${fieldUserTitle || fieldTemplateTitle}`}</Text>
<Text variant="subtext">{`${nodeId} -> ${fieldName}`}</Text>
<NodeFieldElementOverlay nodeId={nodeId} />
</Flex>
);
});
NodeInputFieldPreview.displayName = 'NodeInputFieldPreview';

const NodeOutputFieldPreview = memo(({ nodeId, fieldName }: { nodeId: string; fieldName: string }) => {
const mouseOverFormField = useMouseOverFormField(nodeId);
const nodeUserTitle = useNodeUserTitleOrThrow(nodeId);
const nodeTemplateTitle = useNodeTemplateTitleOrThrow(nodeId);
const fieldTemplate = useOutputFieldTemplate(nodeId, fieldName);
const zoomToNode = useZoomToNode(nodeId);

return (
<Flex
flexDir="column"
position="relative"
p={2}
borderRadius="base"
onMouseOver={mouseOverFormField.handleMouseOver}
onMouseOut={mouseOverFormField.handleMouseOut}
onClick={zoomToNode}
>
<Text fontWeight="semibold">{`${nodeUserTitle || nodeTemplateTitle} -> ${fieldTemplate.title}`}</Text>
<Text variant="subtext">{`${nodeId} -> ${fieldName}`}</Text>
<NodeFieldElementOverlay nodeId={nodeId} />
</Flex>
);
});
NodeOutputFieldPreview.displayName = 'NodeOutputFieldPreview';

export const StartPublishFlowButton = memo(() => {
const { t } = useTranslation();
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
const isReadyToEnqueue = useStore($isReadyToEnqueue);
const isWorkflowSaved = useAppSelector(selectIsWorkflowSaved);
const hasBatchOrGeneratorNodes = useAppSelector(selectHasBatchOrGeneratorNodes);
const inputs = usePublishInputs();

const onClick = useCallback(() => {
$isInPublishFlow.set(true);
}, []);

return (
<PublishTooltip
isWorkflowSaved={isWorkflowSaved}
hasBatchOrGeneratorNodes={hasBatchOrGeneratorNodes}
isReadyToEnqueue={isReadyToEnqueue}
hasOutputNode={true}
hasPublishableInputs={inputs.publishable.length > 0}
hasUnpublishableInputs={inputs.unpublishable.length > 0}
>
<Button
onClick={onClick}
leftIcon={<PiLightningFill />}
variant="ghost"
size="sm"
isDisabled={!allowPublishWorkflows || !isReadyToEnqueue || !isWorkflowSaved || hasBatchOrGeneratorNodes}
>
{t('workflows.builder.publish')}
</Button>
</PublishTooltip>
);
});

StartPublishFlowButton.displayName = 'StartPublishFlowButton';

const PublishTooltip = memo(
({
isWorkflowSaved,
hasBatchOrGeneratorNodes,
isReadyToEnqueue,
hasOutputNode,
hasPublishableInputs,
hasUnpublishableInputs,
children,
}: PropsWithChildren<{
isWorkflowSaved: boolean;
hasBatchOrGeneratorNodes: boolean;
isReadyToEnqueue: boolean;
hasOutputNode: boolean;
hasPublishableInputs: boolean;
hasUnpublishableInputs: boolean;
}>) => {
const { t } = useTranslation();
const warnings = useMemo(() => {
const _warnings: string[] = [];
if (!hasPublishableInputs) {
_warnings.push(t('workflows.builder.warningWorkflowHasNoPublishableInputFields'));
}
if (hasUnpublishableInputs) {
_warnings.push(t('workflows.builder.warningWorkflowHasUnpublishableInputFields'));
}
return _warnings;
}, [hasPublishableInputs, hasUnpublishableInputs, t]);
const errors = useMemo(() => {
const _errors: string[] = [];
if (!isWorkflowSaved) {
_errors.push(t('workflows.builder.errorWorkflowHasUnsavedChanges'));
}
if (hasBatchOrGeneratorNodes) {
_errors.push(t('workflows.builder.errorWorkflowHasBatchOrGeneratorNodes'));
}
if (!isReadyToEnqueue) {
_errors.push(t('workflows.builder.errorWorkflowHasInvalidGraph'));
}
if (!hasOutputNode) {
_errors.push(t('workflows.builder.errorWorkflowHasNoOutputNode'));
}
return _errors;
}, [hasBatchOrGeneratorNodes, hasOutputNode, isReadyToEnqueue, isWorkflowSaved, t]);

if (errors.length === 0 && warnings.length === 0) {
return children;
}

return (
<Tooltip
label={
<Flex flexDir="column">
{errors.length > 0 && (
<>
<Text color="error.700" fontWeight="semibold">
{t('workflows.builder.cannotPublish')}:
</Text>
<UnorderedList>
{errors.map((problem, index) => (
<ListItem key={index}>{problem}</ListItem>
))}
</UnorderedList>
</>
)}
{warnings.length > 0 && (
<>
<Text color="warning.700" fontWeight="semibold">
{t('workflows.builder.publishWarnings')}:
</Text>
<UnorderedList>
{warnings.map((problem, index) => (
<ListItem key={index}>{problem}</ListItem>
))}
</UnorderedList>
</>
)}
</Flex>
}
>
{children}
</Tooltip>
);
}
);
PublishTooltip.displayName = 'PublishTooltip';
@@ -0,0 +1,23 @@
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiLockBold } from 'react-icons/pi';

export const LockedWorkflowIcon = memo(() => {
const { t } = useTranslation();

return (
<Tooltip label={t('workflows.builder.publishedWorkflowsLocked')} closeOnScroll>
<IconButton
size="sm"
cursor="not-allowed"
variant="link"
alignSelf="stretch"
aria-label={t('workflows.builder.publishedWorkflowsLocked')}
icon={<PiLockBold />}
/>
</Tooltip>
);
});

LockedWorkflowIcon.displayName = 'LockedWorkflowIcon';
@@ -26,6 +26,7 @@ import {
workflowLibraryTagToggled,
workflowLibraryViewChanged,
} from 'features/nodes/store/workflowLibrarySlice';
import { selectAllowPublishWorkflows } from 'features/system/store/configSlice';
import { NewWorkflowButton } from 'features/workflowLibrary/components/NewWorkflowButton';
import { UploadWorkflowButton } from 'features/workflowLibrary/components/UploadWorkflowButton';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
@@ -39,13 +40,12 @@ export const WorkflowLibrarySideNav = () => {
const { t } = useTranslation();
const categoryOptions = useStore($workflowLibraryCategoriesOptions);
const view = useAppSelector(selectWorkflowLibraryView);
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);

return (
<Flex h="full" minH={0} overflow="hidden" flexDir="column" w={64} gap={0}>
<Flex flexDir="column" w="full" pb={2}>
<Flex flexDir="column" w="full" pb={2} gap={2}>
<WorkflowLibraryViewButton view="recent">{t('workflows.recentlyOpened')}</WorkflowLibraryViewButton>
</Flex>
<Flex flexDir="column" w="full" pb={2}>
<WorkflowLibraryViewButton view="yours">{t('workflows.yourWorkflows')}</WorkflowLibraryViewButton>
{categoryOptions.includes('project') && (
<Collapse in={view === 'yours' || view === 'shared' || view === 'private'}>
@@ -60,6 +60,9 @@ export const WorkflowLibrarySideNav = () => {
</Flex>
</Collapse>
)}
{allowPublishWorkflows && (
<WorkflowLibraryViewButton view="published">{t('workflows.published')}</WorkflowLibraryViewButton>
)}
</Flex>
<Flex h="full" minH={0} overflow="hidden" flexDir="column">
<BrowseWorkflowsButton />

@@ -36,6 +36,8 @@ const getCategories = (view: WorkflowLibraryView): WorkflowCategory[] => {
return ['user'];
case 'shared':
return ['project'];
case 'published':
return ['user', 'project', 'default'];
default:
assert<Equals<typeof view, never>>(false);
}
@@ -66,6 +68,7 @@ const useInfiniteQueryAry = () => {
query: debouncedSearchTerm,
tags: view === 'defaults' ? selectedTags : [],
has_been_opened: getHasBeenOpened(view),
is_published: view === 'published' ? true : undefined,
} satisfies Parameters<typeof useListWorkflowsInfiniteInfiniteQuery>[0];
}, [orderBy, direction, view, debouncedSearchTerm, selectedTags]);

@@ -1,6 +1,7 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Badge, Flex, Icon, Image, Spacer, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { LockedWorkflowIcon } from 'features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibraryListItemActions/LockedWorkflowIcon';
import { ShareWorkflowButton } from 'features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibraryListItemActions/ShareWorkflow';
import { selectWorkflowId, workflowModeChanged } from 'features/nodes/store/workflowSlice';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
@@ -54,7 +55,6 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
position="relative"
role="button"
onClick={handleClickLoad}
cursor="pointer"
bg="base.750"
borderRadius="base"
w="full"
@@ -81,7 +81,7 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
<Flex gap={2} alignItems="flex-start" justifyContent="space-between" w="full">
<Text noOfLines={2}>{workflow.name}</Text>
<Flex gap={2} alignItems="center">
{isActive && (
{isActive && !workflow.is_published && (
<Badge
color="invokeBlue.400"
borderColor="invokeBlue.700"
@@ -93,6 +93,18 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
{t('workflows.opened')}
</Badge>
)}
{workflow.is_published && (
<Badge
color="invokeGreen.400"
borderColor="invokeGreen.700"
borderWidth={1}
bg="transparent"
flexShrink={0}
variant="subtle"
>
{t('workflows.builder.published')}
</Badge>
)}
{workflow.category === 'project' && <Icon as={PiUsersBold} color="base.200" />}
{workflow.category === 'default' && (
<Image
@@ -119,8 +131,10 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
</Text>
)}
<Spacer />
{workflow.category === 'default' && <ViewWorkflow workflowId={workflow.workflow_id} />}
{workflow.category !== 'default' && (
{workflow.category === 'default' && !workflow.is_published && (
<ViewWorkflow workflowId={workflow.workflow_id} />
)}
{workflow.category !== 'default' && !workflow.is_published && (
<>
<EditWorkflow workflowId={workflow.workflow_id} />
<DownloadWorkflow workflowId={workflow.workflow_id} />
@@ -128,6 +142,7 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
</>
)}
{workflow.category === 'project' && <ShareWorkflowButton workflow={workflow} />}
{workflow.is_published && <LockedWorkflowIcon />}
</Flex>
</Flex>
</Flex>

@@ -1,5 +1,8 @@
import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { Spacer, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { WorkflowBuilder } from 'features/nodes/components/sidePanel/builder/WorkflowBuilder';
import { StartPublishFlowButton } from 'features/nodes/components/sidePanel/workflow/PublishWorkflowPanelContent';
import { selectAllowPublishWorkflows } from 'features/system/store/configSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';

@@ -8,12 +11,15 @@ import WorkflowJSONTab from './WorkflowJSONTab';

const WorkflowFieldsLinearViewPanel = () => {
const { t } = useTranslation();
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
return (
<Tabs variant="enclosed" display="flex" w="full" h="full" flexDir="column">
<TabList>
<Tab>{t('workflows.builder.builder')}</Tab>
<Tab>{t('common.details')}</Tab>
<Tab>JSON</Tab>
<Spacer />
{allowPublishWorkflows && <StartPublishFlowButton />}
</TabList>

<TabPanels h="full" pt={2}>

Some files were not shown because too many files have changed in this diff